
Rigging the Lottery: Making All Tickets Winners

<img src="https://github.com/google-research/rigl/blob/master/imgs/flops8.jpg" alt="80% Sparse Resnet-50" width="45%" align="middle">

Paper: https://arxiv.org/abs/1911.11134

15min Presentation [pml4dc] [icml]

ML Reproducibility Challenge 2020 report

Colabs for Calculating FLOPs of Sparse Models

MobileNet-v1

ResNet-50

Best Sparse Models

Parameters are stored as 32-bit floats, so each parameter takes 4 bytes. The uniform sparsity distribution keeps the first layer dense and therefore yields slightly larger model sizes and parameter counts. ERK applies to all layers, except in the 99% sparse model, where we set the first layer to be dense, since otherwise we observe much worse performance.
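As a rough sanity check on the Model Size column, only the non-zero float32 weights need to be stored. The sketch below uses an approximate ResNet-50 parameter count (~25.5M, an assumption, not this repo's MicroNet counting code); the sizes in the table are slightly larger because some parameters (e.g. BatchNorm, and the first layer for the uniform distribution) are kept dense.

```python
def sparse_model_mb(n_params, sparsity, bytes_per_param=4):
    """Approximate model size in MB: only non-zero float32 weights are stored."""
    return n_params * (1.0 - sparsity) * bytes_per_param / 1e6

# Dense ResNet-50, ~25.5M parameters (assumed count): ~102 MB, close to the dense row.
print(round(sparse_model_mb(25_500_000, 0.0), 1))  # -> 102.0
# At 80% sparsity only ~20% of the weights remain: ~20.4 MB. The table reports
# 23.68 MB, since dense layers such as BatchNorm are not pruned.
print(round(sparse_model_mb(25_500_000, 0.8), 1))  # -> 20.4
```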

Extended Training Results

Performance of RigL increases significantly with extended training. In this section we extend the training of sparse models by 5x. Note that sparse models require far fewer FLOPs per training iteration, so most of these extended trainings still cost fewer FLOPs than the baseline dense training.
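As a quick consistency check (per-pass costs assumed from the 1x Training Results table below): the 80% sparse ERK model costs about 0.42x of the dense training FLOPs per full pass, so a 5x-extended run should cost roughly 5 × 0.42 ≈ 2.1x, in line with the 2.09x reported in the table.

```python
# Per-pass training cost relative to dense, taken from the 1x Training Results table.
per_pass = {"ERK 0.8": 0.42, "ERK 0.9": 0.24, "Uniform 0.9": 0.13}

# A 5x-extended run multiplies total training FLOPs by ~5.
extended = {name: round(5 * cost, 2) for name, cost in per_pass.items()}
print(extended)  # close to the 2.09x, 1.23x, and 0.66x reported in the table
```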

Observing this improving performance, we wanted to understand where the performance of sparse networks saturates. The longest training we ran was 100x the length of the original 100-epoch ImageNet training. This run costs 5.8x the FLOPs of the original dense training, and the resulting 99% sparse ResNet-50 achieves an impressive 68.15% test accuracy (vs. 61.86% with 5x training).

| Sparsity Distribution | Sparsity | Training FLOPs | Inference FLOPs | Model Size (MB) | Top-1 Acc | Ckpt |
|---|---|---|---|---|---|---|
| - (Dense) | 0 | 3.2e18 | 8.2e9 | 102.122 | 76.8 | - |
| ERK | 0.8 | 2.09x | 0.42x | 23.683 | 77.17 | link |
| Uniform | 0.8 | 1.14x | 0.23x | 23.685 | 76.71 | link |
| ERK | 0.9 | 1.23x | 0.24x | 13.499 | 76.42 | link |
| Uniform | 0.9 | 0.66x | 0.13x | 13.532 | 75.73 | link |
| ERK | 0.95 | 0.63x | 0.12x | 8.399 | 74.63 | link |
| Uniform | 0.95 | 0.42x | 0.08x | 8.433 | 73.22 | link |
| ERK | 0.965 | 0.45x | 0.09x | 6.904 | 72.77 | link |
| Uniform | 0.965 | 0.34x | 0.07x | 6.904 | 71.31 | link |
| ERK | 0.99 | 0.29x | 0.05x | 4.354 | 61.86 | link |
| ERK | 0.99 | 0.58x | 0.05x | 4.354 | 63.89 | link |
| ERK | 0.99 | 2.32x | 0.05x | 4.354 | 66.94 | link |
| ERK | 0.99 | 5.8x | 0.05x | 4.354 | 68.15 | link |

We also ran extended training runs with MobileNet-v1. Again, even after training 100x longer, we were not able to saturate the performance: training longer consistently achieved better results.

| Sparsity Distribution | Sparsity | Training FLOPs | Inference FLOPs | Model Size (MB) | Top-1 Acc | Ckpt |
|---|---|---|---|---|---|---|
| - (Dense) | 0 | 4.5e17 | 1.14e9 | 16.864 | 72.1 | - |
| ERK | 0.89 | 1.39x | 0.21x | 2.392 | 69.31 | link |
| ERK | 0.89 | 2.79x | 0.21x | 2.392 | 70.63 | link |
| Uniform | 0.89 | 1.25x | 0.09x | 2.392 | 69.28 | link |
| Uniform | 0.89 | 6.25x | 0.09x | 2.392 | 70.25 | link |
| Uniform | 0.89 | 12.5x | 0.09x | 2.392 | 70.59 | link |

1x Training Results

| Sparsity Distribution | Sparsity | Training FLOPs | Inference FLOPs | Model Size (MB) | Top-1 Acc | Ckpt |
|---|---|---|---|---|---|---|
| ERK | 0.8 | 0.42x | 0.42x | 23.683 | 75.12 | link |
| Uniform | 0.8 | 0.23x | 0.23x | 23.685 | 74.60 | link |
| ERK | 0.9 | 0.24x | 0.24x | 13.499 | 73.07 | link |
| Uniform | 0.9 | 0.13x | 0.13x | 13.532 | 72.02 | link |

Results w/o label smoothing

| Sparsity Distribution | Sparsity | Training FLOPs | Inference FLOPs | Model Size (MB) | Top-1 Acc | Ckpt |
|---|---|---|---|---|---|---|
| ERK | 0.8 | 0.42x | 0.42x | 23.683 | 75.02 | link |
| ERK | 0.8 | 2.09x | 0.42x | 23.683 | 76.17 | link |
| ERK | 0.9 | 0.24x | 0.24x | 13.499 | 73.4 | link |
| ERK | 0.9 | 1.23x | 0.24x | 13.499 | 75.9 | link |
| ERK | 0.95 | 0.13x | 0.12x | 8.399 | 70.39 | link |
| ERK | 0.95 | 0.63x | 0.12x | 8.399 | 74.36 | link |

Evaluating checkpoints

Download the checkpoints and run the evaluation on ERK checkpoints with the following:

```
python imagenet_train_eval.py --mode=eval_once --output_dir=path/to/ckpt/folder \
    --eval_once_ckpt_prefix=model.ckpt-3200000 --use_folder_stub=False \
    --training_method=rigl --mask_init_method=erdos_renyi_kernel \
    --first_layer_sparsity=-1
```

When evaluating checkpoints trained with the uniform sparsity distribution, use `--mask_init_method=random` and `--first_layer_sparsity=0`. Set `--model_architecture=mobilenet_v1` when evaluating MobileNet checkpoints.

Sparse Training Algorithms

In this repository we implement the following dynamic sparsity strategies:

  1. SET: Implements Sparse Evolutionary Training (SET), which replaces low-magnitude connections randomly with new ones.

  2. SNFS: Implements momentum-based training (Sparse Networks From Scratch) without sparsity re-distribution.

  3. RigL: Our method, RigL, removes a fraction of connections based on weight magnitudes and activates new ones using instantaneous gradient information.

And the following one-shot pruning algorithm:

  1. SNIP: Single-shot Network Pruning based on connection sensitivity prunes the least salient connections before training.
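The RigL update described above can be sketched as follows. This is a minimal NumPy illustration of the drop/grow idea for a single layer, not the repository's TensorFlow implementation: drop the active connections with the smallest weight magnitude, then grow the same number of inactive connections with the largest instantaneous gradient magnitude.

```python
import numpy as np

def rigl_update(weights, mask, grads, drop_fraction=0.3):
    """One RigL drop/grow step on a single layer (illustrative sketch)."""
    mask = mask.astype(bool).ravel().copy()
    w, g = weights.ravel(), grads.ravel()
    n = int(drop_fraction * mask.sum())

    active = np.flatnonzero(mask)
    inactive = np.flatnonzero(~mask)

    # Drop: deactivate the n active connections with the smallest |weight|.
    mask[active[np.argsort(np.abs(w[active]))[:n]]] = False
    # Grow: activate the n inactive connections with the largest |gradient|.
    mask[inactive[np.argsort(-np.abs(g[inactive]))[:n]]] = True

    mask = mask.reshape(weights.shape)
    # Dropped weights are zeroed; newly grown connections start at zero.
    return weights * mask, mask.astype(weights.dtype)
```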

We have code for the following settings:

Setup

First clone this repo.

```
git clone https://github.com/google-research/rigl.git
cd rigl
```

We use the NeurIPS 2019 MicroNet Challenge code for counting the operations and size of our networks. Clone the google_research repo and add the current folder to the Python path.

```
git clone https://github.com/google-research/google-research.git
mv google-research/ google_research/
export PYTHONPATH=$PYTHONPATH:$PWD
```

Now we can run some tests. The following script creates a virtual environment, installs the necessary libraries, and runs a few tests.

```
bash run.sh
```

We need to activate the virtual environment before running an experiment. With that, we are ready to run some trivial MNIST experiments.

```
source env/bin/activate
python rigl/mnist/mnist_train_eval.py
```

You can load and verify the performance of the ResNet-50 checkpoints as follows.

```
python rigl/imagenet_resnet/imagenet_train_eval.py --mode=eval_once \
    --training_method=baseline --eval_batch_size=100 \
    --output_dir=/path/to/folder \
    --eval_once_ckpt_prefix=s80_model.ckpt-1280000 --use_folder_stub=False
```

We use the official TPU code for loading ImageNet data. First clone the tensorflow/tpu repo and then add its models/ folder to the Python path.

```
git clone https://github.com/tensorflow/tpu.git
export PYTHONPATH=$PYTHONPATH:$PWD/tpu/models/
```

Other Implementations

Citation

@incollection{rigl,
 author = {Evci, Utku and Gale, Trevor and Menick, Jacob and Castro, Pablo Samuel and Elsen, Erich},
 booktitle = {Proceedings of Machine Learning and Systems 2020},
 pages = {471--481},
 title = {Rigging the Lottery: Making All Tickets Winners},
 year = {2020}
}

Disclaimer

This is not an official Google product.