SNIP-it / SNAP-it

(Un)structured Pruning via Iterative Ranking of Sensitivity Statistics

Python 3.7 · PyTorch 1.4 · MIT licence

This repository is the official implementation of the paper Pruning via Iterative Ranking of Sensitivity Statistics, currently under review. Please use this preliminary BibTeX entry when referring to our work:

```bibtex
@article{verdenius2020pruning,
       author = {{Verdenius}, Stijn and {Stol}, Maarten and {Forr{\'e}}, Patrick},
        title = "{Pruning via Iterative Ranking of Sensitivity Statistics}",
      journal = {arXiv e-prints},
     keywords = {Computer Science - Machine Learning, Statistics - Machine Learning},
         year = 2020,
        month = jun,
          eid = {arXiv:2006.00896},
        pages = {arXiv:2006.00896},
archivePrefix = {arXiv},
       eprint = {2006.00896},
 primaryClass = {cs.LG},
}
```

Content

The repository implements novel pruning / compression algorithms for deep learning / neural networks. Additionally, it implements the shrinkage of the actual tensors, so that structured pruning really pays off without external hardware libraries. We implement SNIP-it (unstructured pruning) and SNAP-it (structured pruning), as well as the baselines used in the paper (see below).
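To picture what shrinking the actual tensors means, here is a minimal sketch. The repository itself operates on PyTorch modules; `shrink_layer` and its signature are hypothetical illustrations in NumPy, not the codebase's API:

```python
import numpy as np

def shrink_layer(weight, bias, keep_out):
    """Physically drop pruned output nodes: the returned tensors are smaller,
    so subsequent matrix multiplies get cheaper without sparse-hardware support."""
    return weight[keep_out], bias[keep_out]

W = np.arange(32, dtype=float).reshape(4, 8)  # 4 output nodes, 8 inputs
b = np.zeros(4)
keep = np.array([True, False, True, True])    # structured pruning removed node 1
W_small, b_small = shrink_layer(W, b, keep)
print(W_small.shape)  # (3, 8)
```

Because the pruned rows are gone rather than merely masked to zero, the smaller shapes translate directly into fewer FLOPS at train and inference time.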

Setup

```shell
pip3 install virtualenv
virtualenv -p python3 ~/virtualenvs/SNIPIT
source ~/virtualenvs/SNIPIT/bin/activate
pip install -r requirements.txt
```

Training Examples & Results

Some examples of training the models from the paper.

Structured Pruning (SNAP-it)

To run training for SNAP-it, our structured pruning-before-training algorithm, with a VGG16 on CIFAR10, run the following:

```shell
python3 main.py --model VGG16 --data_set CIFAR10 --prune_criterion SNAPit --pruning_limit 0.93 --epochs 80
```

<img src="./pictures/__structured-VGG16-CIFAR10_acc_node_sparse.png" alt="drawing" width="500"/>

| accuracy drop | weight sparsity | node sparsity | cumulative training FLOPS reduction |
| --- | --- | --- | --- |
| -1% | 99% | 93% | 8 times |
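The node ranking behind a run like this can be sketched as follows: per-weight sensitivities are aggregated into one score per output node, and the weakest nodes are dropped whole. This is an illustrative NumPy reading of structured saliency, not the repository's actual criterion code:

```python
import numpy as np

rng = np.random.default_rng(1)
W = rng.normal(size=(4, 8))             # weight matrix: 4 output nodes
G = rng.normal(size=(4, 8))             # gradient of the loss w.r.t. W
node_score = np.abs(G * W).sum(axis=1)  # sensitivity aggregated per node
keep = node_score > node_score.min()    # drop the least sensitive node
print(int(keep.sum()))  # 3
```

Ranking whole nodes (rather than individual weights) is what makes the pruning structured, so the surviving tensors can be shrunk as described above.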

Unstructured Pruning (SNIP-it)

To run training for SNIP-it - our unstructured pruning algorithm - with a ResNet18 on CIFAR10, run one of the following:

```shell
## during training
python3 main.py --model ResNet18 --data_set CIFAR10 --prune_criterion SNIPitDuring --pruning_limit 0.98 --outer_layer_pruning --epochs 80 --prune_delay 4 --prune_freq 4
## before training
python3 main.py --model ResNet18 --data_set CIFAR10 --prune_criterion SNIPit --pruning_limit 0.98 --outer_layer_pruning --epochs 80
```
<img src="./pictures/__unstructured-ResNet18-CIFAR10_acc_weight_sparse.png" alt="drawing" width="400"/>

| | accuracy drop | weight sparsity |
| --- | --- | --- |
| SNIP-it (during) | -0% | 98% |
| SNIP-it (before) | -4% | 98% |
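The iterative ranking itself can be sketched on a toy least-squares model. The linear pruning schedule below is an illustrative assumption for brevity, not necessarily the exact schedule of Algorithm 1 in the paper:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(32, 10))                # toy data for y = X @ w
y = rng.normal(size=32)
w = rng.normal(size=10)
mask = np.ones(10, dtype=bool)

target, steps = 0.8, 4                       # reach 80% sparsity in 4 rounds
for k in range(1, steps + 1):
    grad = 2 * X.T @ (X @ (w * mask) - y) / len(y)  # dL/dw under current mask
    saliency = np.abs(grad * w) * mask              # SNIP-style sensitivity
    n_keep = round(len(w) * (1 - target * k / steps))
    order = np.argsort(-saliency)                   # rank, keep the top scores
    mask[:] = False
    mask[order[:n_keep]] = True
print(int(mask.sum()))  # 2
```

The key idea is that sensitivities are re-ranked after each pruning round, so weights that only look unimportant in the dense network get a second assessment under the sparser one.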

Adversarial Evaluation

To evaluate a model against adversarial attacks (for now only supported for unstructured pruning), run:

```shell
python main.py --eval --model MLP5 --data_set MNIST --checkpoint_name <see_results_folder> --checkpoint_model MLP5_finished --attack CarliniWagner
```

Visualization

Results and saved models are logged to the terminal, to logfiles in the result folders, and to TensorBoard event files in the /gitignored/results/ folder. To launch the TensorBoard interface, run the following:

```shell
tensorboard --logdir ./gitignored/results/
```

Arguments

The regular arguments for running are listed below; some additional ones can be found in utils/config_utils.py.

| argument | description | type |
| --- | --- | --- |
| --model | The neural network architecture from /models/networks/ | str |
| --data_set | The dataset from /utils/dataloaders.py | str |
| --prune_criterion | The pruning criterion from /models/criterions/ | str |
| --batch_size | The batch size | int |
| --optimizer | The optimizer class, one of [ADAM, SGD, RMSPROP] | str |
| --loss | The loss function from /models/losses/ | str |
| --train_scheme | The training scheme from /models/trainers/ (if applicable) | str |
| --test_scheme | The testing scheme from /models/testers/ (if applicable) | str |
| --eval | Add to run in test mode | bool |
| --attack | Name of the adversarial attack if that is the test_scheme | str |
| --device | Device [cuda or cpu] | str |
| --run_name | Extra run identification for the generated run folder | str |
| --checkpoint_name | Load from this checkpoint folder if not None | str |
| --checkpoint_model | Load this model from checkpoint_name | str |
| --outer_layer_pruning | Prunes the outer layers too; use only when pruning unstructured | bool |
| --enable_rewinding | Enables rewinding of weights (for IMP) | bool |
| --rewind_epoch | Epoch to rewind to | int |
| --l0 | Run with L0-regularised layers, overrides some other options | bool |
| --l0_reg | L0 regularisation hyperparameter | float |
| --hoyer_square | Run with HoyerSquare, overrides some other options | bool |
| --group_hoyer_square | Run with GroupHoyerSquare, overrides some other options | bool |
| --hoyer_reg | Hoyer regularisation hyperparameter | float |
| --learning_rate | Learning rate | float |
| --pruning_limit | Target final sparsity for applicable pruning criterions | float |
| --pruning_rate | Outdated alias of pruning_limit, still used for UnstructuredRandom | float |
| --snip_steps | s from Algorithm box 1 in the paper; number of pruning steps | int |
| --epochs | How long to train for | int |
| --prune_delay | tau from Algorithm box 1 in the paper; how long to wait before pruning starts | int |
| --prune_freq | tau from Algorithm box 1 again; how often to prune | int |
| --seed | Random seed to run with | int |
| --tuning | Run with a train and held-out validation set, omitting the test set | bool |
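How --prune_delay and --prune_freq interact can be pictured with a tiny helper. This is one reading of the two flags (both tied to τ in Algorithm box 1), not code taken from the repository:

```python
def should_prune(epoch, prune_delay, prune_freq):
    """Prune at every prune_freq-th epoch once prune_delay epochs have passed."""
    return epoch >= prune_delay and (epoch - prune_delay) % prune_freq == 0

# with --prune_delay 4 --prune_freq 4 (as in the SNIP-it example above):
pruning_epochs = [e for e in range(12) if should_prune(e, 4, 4)]
print(pruning_epochs)  # [4, 8]
```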


Codebase Design

How to run the other baselines

```shell
## unpruned baselines
python3 main.py --model VGG16 --data_set CIFAR10 --prune_criterion EmptyCrit --epochs 80 --pruning_limit 0.0
python3 main.py --model ResNet18 --data_set CIFAR10 --prune_criterion EmptyCrit --epochs 80 --pruning_limit 0.0

## structured baselines
python3 main.py --model VGG16 --data_set CIFAR10 --prune_criterion StructuredRandom --pruning_limit 0.93 --epochs 80
python3 main.py --model VGG16 --data_set CIFAR10 --prune_criterion GateDecorators --pruning_limit 0.93 --epochs 70 --checkpoint_name <unpruned_results_folder> --checkpoint_model VGG16_finished
python3 main.py --model VGG16 --data_set CIFAR10 --prune_criterion EfficientConvNets --pruning_limit 0.93 --epochs 80 --prune_delay 69 --prune_freq 1
python3 main.py --model VGG16 --data_set CIFAR10 --prune_criterion GroupHoyerSquare --hoyer_reg <REG> --epochs 80 --prune_delay 69 --prune_freq 1 --group_hoyer_square
python3 main.py --model VGG16 --data_set CIFAR10 --l0_reg <REG> --epochs 160 --l0
```

```shell
## unstructured baselines
python3 main.py --model ResNet18 --data_set CIFAR10 --prune_criterion UnstructuredRandom --pruning_rate 0.98 --pruning_limit 0.98 --outer_layer_pruning --epochs 80
python3 main.py --model ResNet18 --data_set CIFAR10 --prune_criterion <SNIP or GRASP> --pruning_limit 0.98 --outer_layer_pruning --epochs 80
python3 main.py --model ResNet18 --data_set CIFAR10 --prune_criterion HoyerSquare --hoyer_reg <REG> --outer_layer_pruning --epochs 80 --prune_delay 69 --prune_freq 1 --hoyer_square
python3 main.py --model ResNet18 --data_set CIFAR10 --prune_criterion IMP --pruning_limit 0.98 --outer_layer_pruning --epochs 80 --prune_delay 4 --prune_freq 4 --enable_rewinding --rewind_epoch 6
```
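For the IMP baseline, the effect of --enable_rewinding / --rewind_epoch can be sketched as: snapshot the weights at the rewind epoch, and after a magnitude-pruning round reset the surviving weights to that snapshot. This is a hedged NumPy illustration, not the trainer's actual code:

```python
import numpy as np

rng = np.random.default_rng(2)
w = rng.normal(size=8)
snapshot = w.copy()                        # weights saved at --rewind_epoch
w = w + 0.1 * rng.normal(size=8)           # ... training continues ...
mask = np.abs(w) >= np.median(np.abs(w))   # magnitude pruning (IMP-style)
w = snapshot * mask                        # rewind survivors, zero the rest
```

Rewinding to an early-training snapshot (rather than to initialisation) is what the lottery-ticket-style IMP baseline uses here.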

Licence

MIT Licence