Width and Depth Pruning for Vision Transformers
This is the official implementation of the AAAI 2022 paper [Width and Depth Pruning for Vision Transformers](https://www.aaai.org/AAAI22Papers/AAAI-2102.FangYu.pdf).
Installation
Requirements
- torch>=1.8.0
- torchvision>=0.9.0
- timm==0.4.9
- h5py
- scipy
- scikit-learn
Data preparation: download and extract the ImageNet images from http://image-net.org/. The directory structure should be:
│ILSVRC2012/
├──train/
│ ├── n01440764
│ │ ├── n01440764_10026.JPEG
│ │ ├── n01440764_10027.JPEG
│ │ ├── ......
│ ├── ......
├──val/
│ ├── n01440764
│ │ ├── ILSVRC2012_val_00000293.JPEG
│ │ ├── ILSVRC2012_val_00002138.JPEG
│ │ ├── ......
│ ├── ......
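Before launching a long run, it can help to sanity-check this layout. The following is a minimal sketch, not part of this repository; the helper name `check_imagenet_layout` is hypothetical, and it assumes the ImageFolder-style tree above (class folders named by WordNet ID, each containing `.JPEG` files, with identical class folders in `train/` and `val/`).

```python
from pathlib import Path


def check_imagenet_layout(root):
    """Return (num_train_classes, num_val_classes) after basic sanity checks.

    Hypothetical helper. Assumes root/train/<wnid>/*.JPEG and
    root/val/<wnid>/*.JPEG, with the same class folders in both splits.
    """
    root = Path(root)
    splits = {}
    for split in ("train", "val"):
        split_dir = root / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing split directory: {split_dir}")
        # Each class is a sub-directory named by its WordNet ID, e.g. n01440764.
        classes = sorted(p.name for p in split_dir.iterdir() if p.is_dir())
        if not classes:
            raise ValueError(f"no class folders found in {split_dir}")
        for cls in classes:
            # any() stops at the first match, so large folders are not fully listed.
            if not any((split_dir / cls).glob("*.JPEG")):
                raise ValueError(f"class folder {split_dir / cls} has no .JPEG files")
        splits[split] = classes
    if splits["train"] != splits["val"]:
        raise ValueError("train/ and val/ contain different class folders")
    return len(splits["train"]), len(splits["val"])
```

For a complete ImageNet-1k extraction, `check_imagenet_layout("../data/ILSVRC2012/")` should return `(1000, 1000)`.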
Model preparation: download pre-trained DeiT models for pruning:
sh download_pretrain.sh
Demo
Training on ImageNet
To train DeiT models on ImageNet, run:
DeiT-Base
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --master_port 29500 --use_env main_wdpruning.py --arch deit_base --data-set IMNET --batch-size 128 --data-path ../data/ILSVRC2012/ --output_dir logs --classifiers 10 --R_threshold 0.8
Training on CIFAR-10
To train DeiT models on CIFAR-10, run:
Pruning width and depth for DeiT-Base
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 --master_port 29566 --use_env main_wdpruning.py --arch deit_base --data-set CIFAR10 --batch-size 128 --data-path ../data/ --output_dir logs/cifar --classifiers 10
Pruning only width for DeiT-Base
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 --master_port 29566 --use_env main_wdpruning.py --arch deit_base --data-set CIFAR10 --batch-size 128 --data-path ../data/ --output_dir logs/cifar
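The commands above launch one process per GPU, so the effective batch size is the per-process `--batch-size` multiplied by `--nproc_per_node`. A minimal sketch of that arithmetic follows; it assumes `--batch-size` is per process, as in the DeiT training script whose flags (`--data-set`, `--data-path`, `--output_dir`) these commands mirror (check `main_wdpruning.py` to confirm), and `global_batch_size` is a hypothetical helper.

```python
import os


def global_batch_size(per_gpu_batch, nproc_per_node, cuda_visible_devices=None, nnodes=1):
    """Effective batch size of a torch.distributed.launch run.

    Hypothetical helper. Assumes --batch-size is per process and that
    each process drives exactly one GPU.
    """
    if cuda_visible_devices is None:
        cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "")
    visible = [d for d in cuda_visible_devices.split(",") if d.strip()]
    # Each of the nproc_per_node processes needs its own visible GPU.
    if visible and len(visible) < nproc_per_node:
        raise ValueError(
            f"{nproc_per_node} processes requested but only {len(visible)} GPUs visible"
        )
    return per_gpu_batch * nproc_per_node * nnodes
```

Under that assumption, the ImageNet command gives `global_batch_size(128, 8, "0,1,2,3,4,5,6,7")` = 1024 and the CIFAR-10 commands give `global_batch_size(128, 2, "0,1")` = 256.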
Pruning and Evaluation
Measure the parameter count and GPU throughput of the pruned transformer:
python masked_parameter_count.py --arch deit_base --pretrained_dir logs/checkpoint.pth --eval_batch_size 1024 --classifiers 10 --classifier_choose 10
Note that '--classifier_choose' selects which classifier the model is pruned at; '--classifier_choose 12' selects the last classifier.
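To see how width and depth pruning each reduce the parameter count, here is a back-of-the-envelope sketch for a standard DeiT/ViT layout (pre-norm blocks with a `qkv` projection, an output projection and a two-layer MLP, all with biases). It is an illustration only, not how `masked_parameter_count.py` counts parameters: the helper names are hypothetical, and it ignores the extra intermediate classifiers and any mask bookkeeping. Width pruning is modeled as fewer heads and MLP neurons per block, depth pruning as fewer blocks.

```python
def vit_block_params(dim, heads, head_dim, mlp_hidden):
    """Parameters of one pre-norm ViT/DeiT block, biases included.

    heads and mlp_hidden are the attention heads and MLP neurons
    that survive width pruning.
    """
    inner = heads * head_dim                     # width of the q, k, v projections
    attn = dim * 3 * inner + 3 * inner           # qkv: Linear(dim, 3 * inner)
    attn += inner * dim + dim                    # proj: Linear(inner, dim)
    mlp = dim * mlp_hidden + mlp_hidden          # fc1: Linear(dim, mlp_hidden)
    mlp += mlp_hidden * dim + dim                # fc2: Linear(mlp_hidden, dim)
    norms = 2 * 2 * dim                          # two LayerNorms (weight + bias)
    return attn + mlp + norms


def vit_params(blocks, dim=768, patch=16, in_chans=3, num_tokens=197, num_classes=1000):
    """Total parameters of a DeiT-style model (hypothetical helper).

    blocks: one (heads, head_dim, mlp_hidden) tuple per kept block;
    depth pruning shortens this list, width pruning shrinks its entries.
    """
    embed = in_chans * patch * patch * dim + dim        # patch-embedding conv
    embed += dim + num_tokens * dim                     # cls token + position embedding
    head = 2 * dim + dim * num_classes + num_classes    # final LayerNorm + classifier
    return embed + sum(vit_block_params(dim, *b) for b in blocks) + head
```

For unpruned DeiT-Base, `vit_params([(12, 64, 3072)] * 12)` returns 86,567,656 (about 86.6M); keeping only the first 10 full blocks gives 72,391,912.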
Measure the parameter count and CPU latency of the pruned transformer:
python masked_parameter_count.py --arch deit_base --pretrained_dir logs/checkpoint.pth --no_cuda --eval_batch_size 1 --classifiers 10
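Throughput and latency are both derived from wall-clock timing over repeated forward passes. The sketch below shows one common way to do this; it is not taken from `masked_parameter_count.py`, and `benchmark` is a hypothetical helper. Because GPU work in PyTorch runs asynchronously, it accepts an optional `sync` callable, such as `torch.cuda.synchronize`, that is invoked before the clock is read.

```python
import time


def benchmark(run_batch, batch_size, warmup=10, iters=50, sync=None):
    """Time a callable that processes one batch (hypothetical helper).

    Returns (throughput in images/s, mean latency per batch in ms).
    """
    if iters <= 0:
        raise ValueError("iters must be positive")
    for _ in range(warmup):          # warm up caches, kernels and the allocator
        run_batch()
    if sync is not None:
        sync()                       # make sure warm-up work has finished
    start = time.perf_counter()
    for _ in range(iters):
        run_batch()
    if sync is not None:
        sync()                       # include queued asynchronous work in the timing
    elapsed = time.perf_counter() - start
    return batch_size * iters / elapsed, 1000.0 * elapsed / iters
```

For example, inside `torch.no_grad()`, `benchmark(lambda: model(images), batch_size=1024, sync=torch.cuda.synchronize)` approximates the GPU throughput setting above, and `benchmark(lambda: model(image), batch_size=1)` the CPU latency setting, where `model`, `images` and `image` stand for your own pruned model and input tensors.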
Acknowledgement
Our code is built on top of Movement Pruning.
Citing
If you find this work useful for your research or project, please consider citing our paper:
@inproceedings{yu2022width,
title={Width \& Depth Pruning for Vision Transformers},
author={Yu, Fang and Huang, Kun and Wang, Meng and Cheng, Yuan and Chu, Wei and Cui, Li},
booktitle={AAAI Conference on Artificial Intelligence (AAAI)},
volume={2022},
year={2022}
}