Dynamic-Vision-Transformer (NeurIPS 2021)

This repo contains the official PyTorch code and pre-trained models for the Dynamic Vision Transformer (DVT).

We also provide an implementation under the MindSpore framework and train DVT on a cluster of Ascend AI processors. Code and pre-trained models will be available here.

Update on 2021/10/02: Released the training code.

Update on 2021/06/01: Released the pre-trained models and the inference code on ImageNet.

Introduction

<p align="center"> <img src="figures/examples.png" width="400"> </p>

We develop a Dynamic Vision Transformer (DVT) to automatically configure a proper number of tokens for each individual image, leading to a significant improvement in computational efficiency, both theoretically and empirically.

<p align="center"> <img src="figures/overview.png" width="810"> </p>
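
As the overview suggests, DVT is a cascade of transformers that process an increasing number of tokens, with an early exit after each stage. Below is a minimal sketch of the early-exit inference logic only (not the repo's actual code: the models list, the per-exit thresholds, and the max-softmax confidence rule are illustrative assumptions, and the feature reuse between stages described in the paper is omitted):

import torch
import torch.nn.functional as F

@torch.no_grad()
def dvt_predict(models, thresholds, image):
    # models: classifiers trained at increasing token counts (e.g. 7x7, 10x10, 14x14), cheap to expensive
    # thresholds: one confidence threshold per early exit
    # image: a 1xCxHxW tensor (the exit decision is made per image)
    for model, threshold in zip(models[:-1], thresholds):
        probs = F.softmax(model(image), dim=-1)
        confidence, prediction = probs.max(dim=-1)
        if confidence.item() >= threshold:  # confident enough: stop here
            return prediction
    return models[-1](image).argmax(dim=-1)  # fall back to the most expensive model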

Citation

If you find this work valuable or use our code in your own research, please consider citing it with the following BibTeX entry:

@inproceedings{wang2021not,
        title = {Not All Images are Worth 16x16 Words: Dynamic Transformers for Efficient Image Recognition},
       author = {Wang, Yulin and Huang, Rui and Song, Shiji and Huang, Zeyi and Huang, Gao},
    booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
         year = {2021}
}

Results

<p align="center"> <img src="figures/result_main.png" width="810"> </p>
<p align="center"> <img src="figures/cifar.png" width="500"> </p>
<p align="center"> <img src="figures/result_speed.png" width="400"> </p>
<p align="center"> <img src="figures/result_visual.png" width="700"> </p>

Pre-trained Models

| Backbone   | # of Exits | # of Tokens     | Links                         |
|------------|------------|-----------------|-------------------------------|
| T2T-ViT-12 | 3          | 7x7-10x10-14x14 | Tsinghua Cloud / Google Drive |
| T2T-ViT-14 | 3          | 7x7-10x10-14x14 | Tsinghua Cloud / Google Drive |
| DeiT-small | 3          | 7x7-10x10-14x14 | Tsinghua Cloud / Google Drive |
**.pth.tar
├── model_state_dict: state dictionaries of the model
├── flops: a list containing the GFLOPs corresponding to exiting at each exit
├── anytime_classification: Top-1 accuracy of each exit
├── dynamic_threshold: the confidence thresholds used in budgeted batch classification
├── budgeted_batch_classification: results of budgeted batch classification (a two-item list; entries [0] and [1] are the two coordinates of the resulting curve)
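
A checkpoint with this layout can be inspected roughly as follows (a minimal sketch; the filename is a placeholder, and only the keys listed above are assumed):

import torch

ckpt = torch.load('dvt_t2t_vit_12.pth.tar', map_location='cpu')  # placeholder filename
print(ckpt['flops'])                   # GFLOPs of exiting at each exit
print(ckpt['anytime_classification'])  # Top-1 accuracy of each exit
x, y = ckpt['budgeted_batch_classification']  # the two coordinates of the curve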

Requirements

Data Preparation

ImageNet
├── train
│   ├── folder 1 (class 1)
│   ├── folder 2 (class 2)
│   ├── ...
├── val
│   ├── folder 1 (class 1)
│   ├── folder 2 (class 2)
│   ├── ...
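
This is the standard torchvision ImageFolder layout (one subfolder per class), so it can be loaded roughly as follows (a minimal sketch; the transforms are illustrative and not necessarily those used in this repo):

import torchvision.datasets as datasets
import torchvision.transforms as transforms

val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])
val_set = datasets.ImageFolder('ImageNet/val', transform=val_transform)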

Evaluate Pre-trained Models

# eval_mode 0: read the evaluation results stored in the checkpoint (no dataset required)
CUDA_VISIBLE_DEVICES=0 python inference.py --model {DVT_T2t_vit_12, DVT_T2t_vit_14, DVT_Deit_small} --checkpoint_path PATH_TO_CHECKPOINT --eval_mode 0

# eval_mode 1: evaluate the Top-1 accuracy of each exit (anytime classification)
CUDA_VISIBLE_DEVICES=0 python inference.py --data_url PATH_TO_DATASET --batch_size 64 --model {DVT_T2t_vit_12, DVT_T2t_vit_14, DVT_Deit_small} --checkpoint_path PATH_TO_CHECKPOINT --eval_mode 1

# eval_mode 2: evaluate budgeted batch classification using the stored confidence thresholds
CUDA_VISIBLE_DEVICES=0 python inference.py --data_url PATH_TO_DATASET --batch_size 64 --model {DVT_T2t_vit_12, DVT_T2t_vit_14, DVT_Deit_small} --checkpoint_path PATH_TO_CHECKPOINT --eval_mode 2

Train

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 main.py PATH_TO_DATASET --model DVT_T2t_vit_12 --b 128 --lr 2e-3 --weight-decay .03 --amp --img-size 224
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 main.py PATH_TO_DATASET --model DVT_T2t_vit_14 --b 64 --lr 5e-4 --weight-decay .05 --amp --img-size 224
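
Note: with --nproc_per_node=8, the effective batch size is presumably 8x the per-process --b value (i.e., 1024 for DVT_T2t_vit_12 and 512 for DVT_T2t_vit_14), assuming --b sets the per-GPU batch size.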

Transfer DVT to CIFAR-10/100

We fine-tune the pre-trained DVT_T2t_vit_12/14 models on CIFAR-10/100 in the same way as T2T-ViT.

Contact

This is a re-implementation version. If you have any questions, please feel free to contact the authors. Yulin Wang: wang-yl19@mails.tsinghua.edu.cn.

Acknowledgment

Our T2T-ViT code is adapted from here, and our DeiT code is adapted from here.