An open source PyTorch implementation of Google Research's paper Rigging the Lottery: Making All Tickets Winners (RigL), authored by Utku Evci (an AI Resident at Google Brain). The goal is to be as versatile, simple, and fast as possible.

You only need to add two lines of code to your PyTorch project to train your model with RigL sparsity!
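Per the RigL paper, each topology update drops the active weights with the smallest magnitude in a layer and grows the same number of inactive connections with the largest dense-gradient magnitude, so the number of non-zero weights stays constant. The following is a minimal single-layer sketch of that mask update, not rigl-torch's internal code: the helper name `rigl_update_mask` and its signature are hypothetical, it only computes the new mask (the paper initializes newly grown weights to zero), and letting just-dropped positions be regrown is a choice of this sketch.

```python
import torch


def rigl_update_mask(weight: torch.Tensor, dense_grad: torch.Tensor,
                     mask: torch.Tensor, drop_fraction: float) -> torch.Tensor:
    """One RigL-style drop/grow step for a single layer; returns a new bool mask.

    Hypothetical helper for illustration only (not rigl-torch's API).
    Only the mask is computed; newly grown weights would start at zero.
    """
    flat_mask = mask.flatten().bool()
    flat_w = weight.detach().abs().flatten()
    flat_g = dense_grad.detach().abs().flatten()

    n_active = int(flat_mask.sum().item())
    n_update = int(drop_fraction * n_active)
    new_mask = flat_mask.clone()
    if n_update == 0:
        return new_mask.view_as(mask)

    # Drop: among active weights, deactivate the n_update smallest |w|.
    # Inactive positions get +inf so they are never picked for dropping.
    inf = torch.full_like(flat_w, float("inf"))
    drop_scores = torch.where(flat_mask, flat_w, inf)
    drop_idx = torch.topk(drop_scores, n_update, largest=False).indices
    new_mask[drop_idx] = False

    # Grow: among positions inactive after dropping (just-dropped ones included
    # in this sketch), activate the n_update with the largest |dense gradient|.
    grow_scores = torch.where(new_mask, -inf, flat_g)
    grow_idx = torch.topk(grow_scores, n_update).indices
    new_mask[grow_idx] = True
    return new_mask.view_as(mask)


# Example: a tiny 2x3 layer with 4 of 6 weights active, updating half of them
w = torch.tensor([[0.5, -0.1, 0.0], [0.9, 0.0, 0.2]])
m = torch.tensor([[True, True, False], [True, False, True]])
g = torch.tensor([[0.0, 0.0, 0.7], [0.0, 0.3, 0.0]])
print(rigl_update_mask(w, g, m, drop_fraction=0.5))
```

In this example the two smallest active weights (-0.1 and 0.2) are dropped and the two inactive positions with the largest gradients (0.7 and 0.3) are grown, so the layer keeps exactly 4 active weights.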

ImageNet Results

These results are not yet as complete as those in the original paper. If you run RigL on ImageNet with other configurations, or on other datasets, I would love to include your results in the repo!


Architecture | Sparsity % | Sparsity Distribution | Top-1 | Original Top-1 | Model/Ckpt
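Sparsity % is the overall fraction of zeroed weights (1 - `dense_allocation`, so 0.1 dense allocation means 90% sparse), and the sparsity distribution column says how that budget is split across layers. A minimal sketch of a `uniform` split, under stated assumptions: following the RigL paper's uniform setting the first layer is kept fully dense, and linear layers stay dense when `ignore_linear_layers=True`. The helper names and the 4-layer network sizes are hypothetical, and whether rigl-torch handles every edge case the same way is not confirmed here.

```python
from typing import List, Sequence


def uniform_densities(layer_is_linear: Sequence[bool], dense_allocation: float,
                      ignore_linear_layers: bool = False) -> List[float]:
    """Per-layer fraction of weights kept non-zero under a `uniform` distribution.

    Assumptions (illustrative only): the first layer stays fully dense, as in
    the RigL paper's uniform setting (a single-layer model is sparsified as-is),
    and linear layers stay dense when ignore_linear_layers=True.
    """
    densities = []
    for i, is_linear in enumerate(layer_is_linear):
        if i == 0 and len(layer_is_linear) > 1:
            densities.append(1.0)
        elif is_linear and ignore_linear_layers:
            densities.append(1.0)
        else:
            densities.append(dense_allocation)
    return densities


def nonzero_counts(layer_sizes: Sequence[int], densities: Sequence[float]) -> List[int]:
    """Number of weights each layer keeps active."""
    return [int(n * d) for n, d in zip(layer_sizes, densities)]


# Hypothetical network: three conv layers followed by one linear layer
is_linear = [False, False, False, True]
sizes = [864, 36864, 73728, 5120]
dens = uniform_densities(is_linear, dense_allocation=0.1)
print(dens)                         # → [1.0, 0.1, 0.1, 0.1]
print(nonzero_counts(sizes, dens))  # → [864, 3686, 7372, 512]
```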

Other Implementations:

Contributions Beyond the Paper:

Gradient Accumulation:




User Setup:

Contributor Setup:


import torch
from rigl_torch.RigL import RigLScheduler

# first, create your model
model = ...  # note: only tested on torch.hub's ResNet networks (i.e. resnet18 / resnet50)

# create your dataset/dataloader
dataset = ...
dataloader = ...

# define your optimizer (recommended SGD w/ momentum)
optimizer = ...

# RigL runs best when you allow RigL's topology modifications to run for 75% of the total training iterations (batches)
# so, let's calculate T_end according to this
epochs = 100
total_iterations = len(dataloader) * epochs
T_end = int(0.75 * total_iterations)

# ------------------------------------ REQUIRED LINE # 1 ------------------------------------
# now, create the RigLScheduler object
pruner = RigLScheduler(model,                           # model you created
                       optimizer,                       # optimizer (recommended = SGD w/ momentum)
                       dense_allocation=0.1,            # a float between 0 and 1 giving the fraction of weights that stay dense (non-zero)
                                                          # (0.1 dense_allocation = 90% sparse)
                       sparsity_distribution='uniform', # distribution hyperparam within the paper, currently only supports `uniform`
                       T_end=T_end,                     # T_end hyperparam within the paper (recommended = 75% * total_iterations)
                       delta=100,                       # delta hyperparam within the paper (recommended = 100)
                       alpha=0.3,                       # alpha hyperparam within the paper (recommended = 0.3)
                       grad_accumulation_n=1,           # new hyperparam contribution (not in the paper) 
                                                          # for more information, see the `Contributions Beyond the Paper` section
                       static_topo=False,               # if True, the topology will be frozen, in other words RigL will not do its job
                                                          # (for debugging)
                       ignore_linear_layers=False,      # if True, linear layers in the network will be kept fully dense
                       state_dict=None)                 # if you have checkpointing enabled for your training script, you should save 
                                                          # `pruner.state_dict()` and when resuming pass the loaded `state_dict` into 
                                                          # the pruner constructor
# -------------------------------------------------------------------------------------------
criterion = torch.nn.CrossEntropyLoss()  # assumed loss function (e.g. for classification)
# ... more code ...

for epoch in range(epochs):
    for inputs, targets in dataloader:  # assumes the dataloader yields (inputs, targets) batches
        # do forward pass, calculate loss, and backpropagate as usual
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()

        # instead of calling optimizer.step(), wrap it as such:
# ------------------------------------ REQUIRED LINE # 2 ------------------------------------
        if pruner():
# -------------------------------------------------------------------------------------------
            # this block of code will execute according to the given hyperparameter schedule
            # in other words, optimizer.step() is not called after a RigL step
            optimizer.step()

    # it is also recommended that after every epoch you checkpoint your training progress
    # to do so with RigL training you should also save the pruner object state_dict
    torch.save({
        'model': model.state_dict(),
        'pruner': pruner.state_dict(),
        'optimizer': optimizer.state_dict()
    }, 'checkpoint.pth')

# at any time you can print the RigLScheduler object to see the sparsity distributions, number of training steps/rigl steps, etc.
print(pruner)

# save model
torch.save(model.state_dict(), 'model.pth')
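The `delta`, `alpha`, and `T_end` arguments control when and how much the topology changes. Following the RigL paper, an update is scheduled every `delta` iterations until `T_end`, and the fraction of active weights that is dropped and regrown decays from `alpha` toward zero with cosine annealing, alpha/2 * (1 + cos(pi * t / T_end)). A minimal sketch of that schedule, assuming iterations are counted from 1 and updates stop strictly before `T_end`; `iters_per_epoch = 500` is a hypothetical stand-in for `len(dataloader)`, and the helper names are illustrative, not part of rigl-torch's API.

```python
import math


def drop_fraction(t: int, alpha: float, T_end: int) -> float:
    """Cosine-annealed update fraction: alpha/2 * (1 + cos(pi * t / T_end))."""
    return alpha / 2 * (1 + math.cos(math.pi * t / T_end))


def is_rigl_step(t: int, delta: int, T_end: int) -> bool:
    """True if iteration t (counted from 1) is scheduled for a topology update."""
    return t % delta == 0 and t < T_end


# Hypothetical numbers mirroring the template: 100 epochs of 500 batches each
epochs, iters_per_epoch = 100, 500
total_iterations = epochs * iters_per_epoch
T_end = int(0.75 * total_iterations)  # 37500
delta, alpha = 100, 0.3

update_steps = [t for t in range(1, total_iterations + 1) if is_rigl_step(t, delta, T_end)]
print(len(update_steps))  # → 374
print(drop_fraction(update_steps[0], alpha, T_end))   # close to alpha at the start
print(drop_fraction(update_steps[-1], alpha, T_end))  # close to 0 near T_end
```

With these numbers the topology changes 374 times during the first 75% of training, and the remaining 25% trains with a fixed sparse topology.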


    author = {McCreary, Dyllan},
    title = {PyTorch Implementation of Rigging the Lottery: Making All Tickets Winners}, 
    url = {https://github.com/nollied/rigl-torch},
    year = {2020}, 
    month = {Nov},
    note = {Reimplementation/extension of the work done by Google Research: https://github.com/google-research/rigl}

