A Gradient Flow Framework for Analyzing Network Pruning

Codebase for the paper "A Gradient Flow Framework for Analyzing Network Pruning" [ICLR, 2021].

Requirements

The code requires a working Python installation with PyTorch. To install the other dependencies, the following command can be used (uses pip):

./requirements.sh

Organization

The provided modules serve the following purposes: main.py runs the pruning pipeline (see Example execution below), eval.py evaluates trained or pruned models (see Evaluation), and config.py holds the training hyperparameters (see Training Settings).

Trained base models are stored in the directory pretrained and pruned models are saved in pruned_nets. Train/test stats collected during a run are stored in the directory stats.
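For instance, once a base model has been trained, its checkpoint can be inspected directly. A minimal sketch, assuming the checkpoint was written with torch.save (the file name here is hypothetical; actual names depend on the model and run):

import torch

# Hypothetical file name; actual checkpoint names depend on the model and run.
state = torch.load("pretrained/resnet-56.pth", map_location="cpu")

# A checkpoint may be a raw state_dict or a wrapper dict; inspect what is there.
print(type(state), list(state.keys())[:5] if isinstance(state, dict) else "")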

Example execution

To prune a model (e.g., resnet-56) using a particular importance measure (e.g., magnitude-based pruning), run the following command:

python main.py --model=resnet-56 --pruning_type=mag_based --prune_percent=75 --n_rounds=25
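To sweep a setting, the same command can be scripted. A minimal sketch that varies the pruning ratio, using only the flags documented below:

import subprocess

# Run magnitude-based pruning at several target ratios (flags as documented below).
for percent in [25, 50, 75]:
    subprocess.run(
        ["python", "main.py",
         "--model=resnet-56",
         "--pruning_type=mag_based",
         f"--prune_percent={percent}",
         "--n_rounds=25"],
        check=True,
    )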

Summary of basic options

--model=<model_name>

--seed=<change_random_seed>

--pretrained_path=<use_pretrained_model>

--data_path=<path_to_data>

--download=<download_cifar>

--pruning_type=<how_to_estimate_importance>

--prune_percent=<how_much_percent_filters_to_prune>

--n_rounds=<number_of_pruning_rounds>

--T=<temperature>

--grasp_T=<temperature_for_grasp>

--imp_samples=<importance_samples>

--track_stats=<track_train/test_numbers>

Training Settings: To change the number of epochs or the learning rate schedule used for training the base models or the pruned models, change the hyperparameters in config.py. By default, models are trained using SGD with momentum (0.9).
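As a rough illustration, the hyperparameters in config.py are of this kind (all names here are hypothetical; see the actual file for the variables the code reads):

# Hypothetical names for illustration only; config.py defines the real ones.
warmup_epochs = 0                              # epochs of training before pruning
final_epochs = 160                             # epochs of training after pruning
lr_schedule = {0: 0.1, 80: 0.01, 120: 0.001}   # epoch -> learning rate
momentum = 0.9                                 # SGD momentum (the default noted above)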

Stats: The stats are stored as a dict divided into train and test, each of which is further divided into warmup training, pruning, and final training.
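Assuming the stats files are pickled dicts (the serialization format and file name below are assumptions), the structure described above can be walked like this:

import pickle

# Hypothetical file name; stats files live in the stats directory.
with open("stats/resnet-56_mag_based.pkl", "rb") as f:
    stats = pickle.load(f)

# Top level: train / test; each holds warmup training, pruning, and final training,
# assumed here to map to lists of per-epoch numbers.
for split in ("train", "test"):
    for phase, values in stats[split].items():
        print(split, phase, values[-1] if values else None)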

Evaluation

To evaluate a model (e.g., a pruned VGG-13 model), use:

python eval.py --model vgg --pruned True --model_path <path_to_model_file> --test_acc True

Summary of available options for evaluating models:

--model=<model_name>

--pruned=<evaluating_a_pruned_model>

--model_path=<path_to_model>

--data_path=<path_to_dataset>

--train_acc=<evaluate_train_accuracy>

--test_acc=<evaluate_test_accuracy>

--flops=<evaluate_flops_in_model>

--compression=<evaluate_compression_ratio>

--download=<download_standard_dataset>

--num_classes=<num_classes>
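For intuition, the compression ratio reported by --compression is, conceptually, the ratio of parameter counts before and after pruning. A minimal sketch of that idea for any pair of PyTorch models (not the codebase's own implementation, whose exact definition may differ):

import torch.nn as nn

def param_count(model: nn.Module) -> int:
    # Total number of parameters in the model.
    return sum(p.numel() for p in model.parameters())

def compression_ratio(base: nn.Module, pruned: nn.Module) -> float:
    # Conceptually params(base) / params(pruned); eval.py's exact definition may differ.
    return param_count(base) / param_count(pruned)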

Extra functionalities (for experimental purposes)

The codebase contains several functionalities that weren't used in the paper; these allow one to experiment further with the paper's theory. For example, we provide pruned model classes for ResNet-34 and ResNet-18, several other loss-preservation-based importance measures, importance tracking over minibatches, options to warm up a model by training it for a few epochs before pruning, manual pruning thresholds, etc. While base settings for these extra functionalities are already in place based on our limited tests, we encourage users to fiddle around and engage with us to find better settings, or even better importance measures! A brief summary of these options is provided below; a sketch of the minibatch importance tracking is given after the list.

--pruning_type=<how_to_estimate_importance>

--data_path=<path_to_data>

--num_classes=<number_of_classes_in_dataset>

--moment=<momentum_for_training>

--lr_prune=<learning_rate_for_pruning>

--warmup_epochs=<warmup_before_pruning>

--thresholds=<manual_thresholds_for_pruning>

--preserve_moment=<preserve_momentum>

--track_batches=<track_importance>

--alpha=<momentum_for_tracking_importance>

--use_l2=<use_l2>

--use_init=<use_init>

--bypass=<bypass>
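As referenced above, tracking importance over minibatches with a momentum term (the --track_batches and --alpha options) generally amounts to an exponential moving average of per-filter scores. A minimal sketch of that idea, not the repo's exact implementation:

import torch

def update_importance(running, new_scores, alpha=0.9):
    # Exponential moving average of per-filter importance across minibatches;
    # alpha plays the role of the --alpha option above (an assumption about its use).
    if running is None:
        return new_scores.clone()
    return alpha * running + (1 - alpha) * new_scores

# Usage: per minibatch, compute fresh scores (e.g., filter magnitudes) and fold them in.
running = None
for _ in range(3):
    new_scores = torch.rand(64)  # placeholder scores for 64 filters
    running = update_importance(running, new_scores)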