SAViT: Structure-Aware Vision Transformer Pruning via Collaborative Optimization [NeurIPS 2022]
This repository contains the PyTorch implementation of SAViT. For details, see SAViT: Structure-Aware Vision Transformer Pruning via Collaborative Optimization.
SAViT is a structured pruning method that prunes all components of a vision transformer and reaches a 2.05x speedup with only 0.2% accuracy loss.
Setup
Step 1: Create a new conda environment:
conda create -n savit python=3.8
conda activate savit
Step 2: Install relevant packages
cd /path/to/deit_savit
pip install -r requirements.txt
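As a quick sanity check (illustrative, not part of the repo), confirm that PyTorch imports from the new environment and sees a GPU:

import torch

print(torch.__version__)           # should match the version pinned in requirements.txt
print(torch.cuda.is_available())   # pruning and fine-tuning assume a CUDA-capable GPU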
Data preparation
Download and extract ImageNet train and val images from http://image-net.org/.
The directory structure is the standard layout for torchvision's datasets.ImageFolder; the training and validation data are expected to be in the train/ and val/ folders respectively:
/path/to/imagenet/
  train/
    class1/
      img1.jpeg
    class2/
      img2.jpeg
  val/
    class1/
      img3.jpeg
    class2/
      img4.jpeg
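To verify that the layout is readable before launching a long job, here is a minimal sketch using torchvision's ImageFolder (the paths are placeholders):

import torchvision.datasets as datasets

# ImageFolder treats each subdirectory as one class.
train_set = datasets.ImageFolder("/path/to/imagenet/train")
val_set = datasets.ImageFolder("/path/to/imagenet/val")
print(len(train_set), len(val_set), len(train_set.classes))  # ImageNet-1k has 1000 classes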
Prune
The scripts folder contains all the bash commands to replicate the main results in our paper.
Running the following command to prune DeiT-base gives you a pruned model with a mask that indicates which neurons and heads should be removed, corresponding to the results in Table 2.
<details> <summary> Prune deit-base 50% FLOPs </summary>

python main.py \
--finetune=/path/to/deit_base_checkpoint \
--batch-size=32 \
--num_workers=16 \
--data-path=/path/to/ImageNet \
--model=deit_base_patch16_224 \
--pruning_per_iteration=100 \
--pruning_feed_percent=0.1 \
--pruning_method=2 \
--pruning_layers=3 \
--pruning_flops_percentage=0.50 \
--pruning_flops_threshold=0.0001 \
--need_hessian \
--finetune_op=2 \
--epochs=1 \
--output_dir=/path/to/output
</details>
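The pruning masks are stored with the output checkpoint. As a rough sketch (assuming a standard torch.save dict; the exact keys are defined by main.py), you can inspect what was written before fine-tuning:

import torch

# Load on CPU and list the top-level entries; the mask tensors for pruned
# neurons and heads live inside this dict (exact keys depend on main.py).
ckpt = torch.load("/path/to/output/checkpoint.pth", map_location="cpu")
print(ckpt.keys() if isinstance(ckpt, dict) else type(ckpt))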
You can change the FLOPs reduction target or the model as you wish. If you have already computed the pruning importance metrics, you can load them by setting:
--pruning_pickle_from=/path/to/importance
For help with the arguments, please see main.py.
Fine-tune
After pruning, the DeiT-base model needs to be retrained to recover its performance. Run the following command to fine-tune on ImageNet on a single node with 8 GPUs, with a total batch size of 1024, for 300 epochs (distillation uses a RegNet teacher, as in the DeiT paper).
<details> <summary> Fine-tune pruned DeiT-base </summary>

GPU_NUM=8
output_dir=/path/to/output
ck_dir=$output_dir/checkpoint.pth
# resume from the checkpoint if one already exists
if [ -e "$ck_dir" ]; then
  CMD="--resume=${ck_dir}"
else
  CMD="--resume="
fi
python -m torch.distributed.launch --nproc_per_node=${GPU_NUM} --use_env main_deploy.py \
--dist-eval \
$CMD \
--masked_model=/path/to/pruned_model_from_prune_step \
--teacher-path=/path/to/regnet_teacher \
--batch-size=128 \
--num_workers=16 \
--data-path=/path/to/ImageNet \
--model=deit_base_patch16_224_deploy \
--pruning_flops_percentage=0 \
--finetune_op=1 \
--epochs=300 \
--warmup-epochs=0 \
--cooldown-epochs=0 \
--output_dir=$output_dir
</details>
Note: fine-tuning runs main_deploy.py, which builds a physically smaller model according to the pruning mask in the pruned model from the previous Prune step, in order to accelerate fine-tuning.
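To check the speedup of the deployed model yourself, here is a generic throughput measurement sketch (it works for any nn.Module; the batch size and iteration counts are arbitrary choices, not values from the paper):

import time
import torch

@torch.no_grad()
def throughput(model: torch.nn.Module, batch_size: int = 64) -> float:
    """Return images per second on a single GPU for 224x224 inputs."""
    model.eval().cuda()
    x = torch.randn(batch_size, 3, 224, 224, device="cuda")
    for _ in range(10):   # warm-up iterations
        model(x)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(30):   # timed iterations
        model(x)
    torch.cuda.synchronize()
    return batch_size * 30 / (time.time() - start)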
To ease reproduction of our results, we provide pruning and fine-tuning logs in the logs folder. The slight differences between the results in the logs and those in our paper come from the PyTorch version.
Acknowledgement
Our repository is built on DeiT, Taylor_pruning, timm, and flops-counter; we sincerely thank the authors for their nicely organized code!
License
This repository is released under the Apache 2.0 license as found in the LICENSE file.
Citation
If you find this repository helpful, please cite:
@article{zheng2022savit,
title={SAViT: Structure-Aware Vision Transformer Pruning via Collaborative Optimization},
author={Zheng, Chuanyang and Zhang, Kai and Yang, Zhi and Tan, Wenming and Xiao, Jun and Ren, Ye and Pu, Shiliang and others},
journal={Advances in Neural Information Processing Systems},
volume={35},
pages={9010--9023},
year={2022}
}