NTK-SAP: Improving neural network pruning by aligning training dynamics
Yite Wang, Dawei Li, Ruoyu Sun
In ICLR 2023.
Overview
This is the PyTorch implementation of NTK-SAP: Improving neural network pruning by aligning training dynamics.
Installation
To run our code, first install all dependencies:
pip install -r requirements.txt
Running
Below is a description of the major sections of the code base. Run python main.py --help for a complete description of flags and hyperparameters.
1. Prepare the datasets
MNIST, CIFAR-10, CIFAR-100, and Tiny ImageNet will be downloaded automatically. For the ImageNet experiment, please download the dataset to Data/imagenet_raw/, or change the corresponding path in Utils/load.py.
2. Run foresight pruning experiments
Note that ImageNet experiments require running the pruning and training code separately; see the --experiment argument. For all other experiments, models are trained right after pruning. We include a few important arguments:
--experiment: For CIFAR-10, CIFAR-100, and Tiny-ImageNet experiments, you can use either singleshot or multishot. For the ImageNet experiment, please use multishot_ddp_prune to get the mask, then train with multishot_ddp_train.
--dataset: Which dataset to use. To reproduce our results, use cifar10, cifar100, tiny-imagenet, or imagenet.
--model-class: For CIFAR-10 and CIFAR-100 experiments, please use lottery. For Tiny-ImageNet and ImageNet experiments, please use imagenet.
--model: Which model architecture to use. In our experiments, we use resnet20, vgg16-bn, resnet18, and resnet50.
--pruner: Which pruning algorithm to use. Choose from: rand, mag, snip, grasp, synflow, itersnip, NTKSAP.
--prune-batch-size: Batch size of the pruning dataset.
--compression: You can use this argument to change the sparsity for singleshot experiments. Specifically, the target density will be $0.8^{\text{compression}}$. For multishot experiments, please refer to --compression-list.
--prune-train-mode: Set this to True if you use any pruning algorithm except Synflow.
--prune-epochs: Number of pruning iterations $T$.
--ntksap_R: Number of resampling procedures; only change this for the CIFAR-10 experiment.
--ntk_epsilon: Perturbation hyper-parameter used in NTK-SAP.
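To make the density rule for --compression concrete, here is a minimal sketch of the arithmetic. The helper names (target_density, compression_for_density) are hypothetical and not part of this code base; the sketch only assumes the $0.8^{\text{compression}}$ relation stated above.

```python
import math

BASE = 0.8  # per-step density factor, taken from the 0.8 ** compression rule above


def target_density(compression: float) -> float:
    """Fraction of weights kept for a given --compression value."""
    return BASE ** compression


def compression_for_density(density: float) -> float:
    """Inverse of target_density: the compression value that yields `density`."""
    if not 0.0 < density <= 1.0:
        raise ValueError("density must be in (0, 1]")
    return math.log(density) / math.log(BASE)


if __name__ == "__main__":
    # Tabulate density and sparsity for a few illustrative compression values.
    for c in (1, 5, 10, 20):
        d = target_density(c)
        print(f"compression={c:>2}  density={d:.4f}  sparsity={1 - d:.4f}")
```

For instance, a compression of 5 keeps roughly 32.8% of the weights, and a compression of 10 keeps roughly 10.7%.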
A sample script can be found in scripts/run.sh.
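As a rough illustration of how the arguments above fit together, the sketch below assembles command lines for a single-shot CIFAR-10 run and for the two-stage ImageNet workflow. It is not taken from scripts/run.sh. The helper build_command is hypothetical, the pairing of resnet50 with ImageNet is an assumption, and so is the exact value syntax of --prune-train-mode. Any launcher needed for the distributed ImageNet runs is not covered here, so treat scripts/run.sh as the authoritative reference.

```python
import shlex


def build_command(experiment, dataset, model_class, model, pruner, extra=()):
    """Assemble a main.py invocation from flags described in this README.

    Hypothetical helper: it only concatenates arguments and does not
    validate them against main.py.
    """
    cmd = [
        "python", "main.py",
        "--experiment", experiment,
        "--dataset", dataset,
        "--model-class", model_class,
        "--model", model,
        "--pruner", pruner,
    ]
    cmd.extend(extra)
    return cmd


# Single-shot pruning on CIFAR-10; the model is trained right after pruning.
# Passing "True" as a value to --prune-train-mode is an assumption.
cifar_cmd = build_command(
    "singleshot", "cifar10", "lottery", "resnet20", "NTKSAP",
    extra=["--prune-train-mode", "True"],
)

# ImageNet runs in two stages: first obtain the mask, then train with it.
# Using resnet50 here is an assumption about which model goes with ImageNet.
imagenet_cmds = [
    build_command(stage, "imagenet", "imagenet", "resnet50", "NTKSAP")
    for stage in ("multishot_ddp_prune", "multishot_ddp_train")
]

if __name__ == "__main__":
    for cmd in [cifar_cmd, *imagenet_cmds]:
        print(shlex.join(cmd))
```

The printed lines can be compared against scripts/run.sh and adjusted before running; flags such as --compression or --compression-list would be added through the extra parameter.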
Acknowledgement
Our code is developed based on the Synflow code: https://github.com/ganguli-lab/Synaptic-Flow.