[ACCV2022] Network Pruning via Feature Shift Minimization

This is the official code for our ACCV22 paper: Network Pruning via Feature Shift Minimization.

Authors: Yuanzhi Duan, Yue Zhou, Peng He, Qiang Liu, Shukai Duan, Xiaofang Hu.

The diagram of feature shift: [figure]

Environments

The code has been tested in the following environments:

Pre-trained Models

VGG-16 | ResNet-56 | GoogLeNet | ResNet-50
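
The pruning commands below take one of these checkpoints via --resume. As a quick sanity check after downloading, a checkpoint can be inspected with plain PyTorch. A minimal sketch; the file name and the 'state_dict' key are assumptions about the checkpoint layout:

```python
import torch

# Load a downloaded checkpoint on the CPU for inspection.
# NOTE: the file name and the 'state_dict' key are illustrative
# assumptions; adjust them to the actual checkpoint layout.
ckpt = torch.load("vgg_16_bn.pt", map_location="cpu")
state = ckpt.get("state_dict", ckpt)  # unwrap if the dict wraps a state_dict

# Print the first few layer names and weight shapes.
for name, tensor in list(state.items())[:5]:
    print(name, tuple(tensor.shape))
```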

Running Code

Each model is compressed in two runs: the first prunes a pre-trained model (passed via --resume) and saves the result under final_pruned_model/, and the second retrains the pruned network with --from_scratch True. The settings for each architecture are listed below.
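
The --compress_rate argument is a Python-style list expression giving the pruning ratio for each pruned layer group, so it expands with ordinary list arithmetic. A minimal sketch using the first VGG setting, which yields one ratio for each of the 13 convolutional layers of VGG-16 (how main.py parses the string internally is its own detail; the expansion itself is plain Python list semantics):

```python
# --compress_rate strings use Python list syntax, so they expand with
# ordinary list repetition and concatenation.
rate = [0.3] * 5 + [0.35] * 3 + [0.8] * 5

print(len(rate))  # 13 -> one ratio per conv layer of VGG-16
print(rate)       # [0.3, 0.3, 0.3, 0.3, 0.3, 0.35, 0.35, 0.35, 0.8, ...]
```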

1. VGGNet

| Compression Rate | FLOPs ($\downarrow$) | Params ($\downarrow$) | Accuracy |
|---|---|---|---|
| [0.3]*5+[0.35]*3+[0.8]*5 | 106.8M (66.0%) | 2.05M (86.3%) | 93.73% |
| [0.5]*8+[0.8]*5 | 59.7M (81.0%) | 1.41M (90.6%) | 92.86% |

# VGGNet
# All run scripts can be copied and pasted from run.bat or run.sh.
python main.py \
--arch vgg_16_bn \
--resume [pre-trained model dir] \
--compress_rate [0.3]*5+[0.35]*3+[0.8]*5 \
--num_workers 1 \
--epochs 1 \
--lr 0.001 \
--lr_decay_step 1 \
--weight_decay 0. \
--data_dir [dataset dir] \
--dataset CIFAR10 \
--save_id 1 

python main.py \
--arch vgg_16_bn \
--from_scratch True \
--resume final_pruned_model/vgg_16_bn_1.pt \
--num_workers 1 \
--epochs 200 \
--gpu 0 \
--lr 0.01 \
--lr_decay_step 100,150 \
--weight_decay 0. \
--data_dir [dataset dir] \
--dataset CIFAR10 \
--save_id 1 
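
The --lr 0.01 / --lr_decay_step 100,150 pair presumably corresponds to step-wise learning-rate decay at those epochs. A minimal PyTorch sketch of such a schedule; the decay factor of 0.1 is an assumption, and the model here is only a stand-in:

```python
import torch

model = torch.nn.Linear(10, 10)  # stand-in for the pruned network
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # --lr 0.01

# Decay the learning rate at epochs 100 and 150, mirroring
# --lr_decay_step 100,150. The gamma of 0.1 is an assumption.
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[100, 150], gamma=0.1)

for epoch in range(200):  # --epochs 200
    # ... one training epoch over CIFAR-10 ...
    scheduler.step()
```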

2. ResNet-56

| Compression Rate | FLOPs ($\downarrow$) | Params ($\downarrow$) | Accuracy |
|---|---|---|---|
| [0.]+[0.6,0.]*9+[0.6,0.]*9+[0.4,0.]*9 | 61.7M (51.2%) | 0.48M (43.6%) | 93.58% |
| [0.]+[0.7,0.]*9+[0.7,0.]*9+[0.7,0.]*9 | 40.2M (68.2%) | 0.27M (68.5%) | 92.76% |
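
In these ResNet-56 settings every second entry is 0., which presumably leaves the last convolution of each residual block unpruned so that its output width still matches the shortcut branch. The expression expands to 55 ratios: one for the stem convolution plus two per residual block. A sketch of the expansion:

```python
# One leading entry for the stem conv, then (prune, keep) pairs for the
# 27 residual blocks of ResNet-56.
rate = [0.] + [0.6, 0.] * 9 + [0.6, 0.] * 9 + [0.4, 0.] * 9

print(len(rate))  # 55 convolutional layers in total
print(rate[:5])   # [0.0, 0.6, 0.0, 0.6, 0.0]
```
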
# ResNet-56
python main.py \
--arch resnet_56 \
--resume [pre-trained model dir] \
--compress_rate [0.]+[0.6,0.]*9+[0.6,0.]*9+[0.4,0.]*9 \
--num_workers 1 \
--epochs 1 \
--lr 0.001 \
--lr_decay_step 1 \
--weight_decay 0. \
--data_dir [dataset dir] \
--dataset CIFAR10 \
--save_id 1 

python main.py \
--arch resnet_56 \
--from_scratch True \
--resume final_pruned_model/resnet_56_1.pt \
--num_workers 1 \
--epochs 300 \
--gpu 0 \
--lr 0.01 \
--lr_decay_step 150,225 \
--weight_decay 0.0005 \
--data_dir [dataset dir] \
--dataset CIFAR10 \
--save_id 1 

3. GoogLeNet

| Compression Rate | FLOPs ($\downarrow$) | Params ($\downarrow$) | Accuracy |
|---|---|---|---|
| [0.2]+[0.7]*15+[0.8]*9+[0.,0.4,0.] | 0.567B (63.0%) | 2.75M (55.5%) | 94.72% |
| [0.2]+[0.9]*24+[0.,0.4,0.] | 0.376B (75.4%) | 2.19M (64.6%) | 94.29% |

# GoogLeNet
python main.py \
--arch googlenet \
--resume [pre-trained model dir] \
--compress_rate [0.2]+[0.7]*15+[0.8]*9+[0.,0.4,0.] \
--num_workers 1 \
--epochs 1 \
--lr 0.001 \
--weight_decay 0. \
--data_dir [dataset dir] \
--dataset CIFAR10 \
--save_id 1 

python main.py \
--arch googlenet \
--from_scratch True \
--resume final_pruned_model/googlenet_1.pt \
--num_workers 1 \
--epochs 200 \
--lr 0.01 \
--lr_decay_step 100,150 \
--weight_decay 0. \
--data_dir [dataset dir] \
--dataset CIFAR10 \
--save_id 1 
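
To check a pruned checkpoint against the parameter counts reported in the table above (e.g. 2.75M for the first GoogLeNet setting), the tensors in its state dict can be summed. A sketch, assuming the checkpoint is a state_dict or a dict wrapping one:

```python
import torch

ckpt = torch.load("final_pruned_model/googlenet_1.pt", map_location="cpu")
state = ckpt.get("state_dict", ckpt)  # assumption: plain or wrapped state_dict

# Total parameter count, to compare against the table above.
total = sum(t.numel() for t in state.values() if torch.is_tensor(t))
print(f"{total / 1e6:.2f}M parameters")
```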

4. ResNet-50

| Compression Rate | FLOPs ($\downarrow$) | Params ($\downarrow$) | Top-1 Acc | Top-5 Acc |
|---|---|---|---|---|
| [0.]+[0.2,0.2,0.2]*1+[0.65,0.65,0.2]*2+[0.2,0.2,0.2]*1+[0.65,0.65,0.2]*3+[0.2,0.2,0.2]*1+[0.65,0.65,0.2]*5+[0.2,0.2,0.1]+[0.2,0.2,0.1]*2 | 1.76B (57.2%) | 14.6M (42.8%) | 75.43% | 92.45% |

# ResNet-50
python main.py \
--arch resnet_50 \
--resume [pre-trained model dir] \
--data_dir [dataset dir] \
--dataset ImageNet \
--compress_rate [0.]+[0.2,0.2,0.2]*1+[0.65,0.65,0.2]*2+[0.2,0.2,0.2]*1+[0.65,0.65,0.2]*3+[0.2,0.2,0.2]*1+[0.65,0.65,0.2]*5+[0.2,0.2,0.1]+[0.2,0.2,0.1]*2 \
--num_workers 4 \
--batch_size 64 \
--epochs 1 \
--lr_decay_step 1 \
--lr 0.001 \
--weight_decay 0. \
--input_size 224 \
--save_id 1 

python main.py \
--arch resnet_50 \
--from_scratch True \
--resume final_pruned_model/resnet_50_1.pt \
--num_workers 4 \
--epochs 100 \
--lr 0.01 \
--lr_decay_step 30,60,90 \
--batch_size 64 \
--weight_decay 0.0001 \
--input_size 224 \
--data_dir [dataset dir] \
--dataset ImageNet \
--save_id 1
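
--input_size 224 matches the conventional ImageNet preprocessing. A sketch of the standard torchvision evaluation transform; the exact pipeline used by main.py may differ in details:

```python
from torchvision import transforms

# Conventional ImageNet evaluation preprocessing at --input_size 224.
eval_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
```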

Acknowledgments

Our implementation partially reuses HRank's code.