Token-Label Alignment for Vision Transformers (ICCV 2023)

Paper

This is the PyTorch implementation of the paper: Token-Label Alignment for Vision Transformers.

Han Xiao*, Wenzhao Zheng*, Zheng Zhu, Jie Zhou, and Jiwen Lu

Highlights

Results

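| Model | Image Size | Params | FLOPs | Top-1 Acc. (%) | Top-5 Acc. (%) |
| --- | --- | --- | --- | --- | --- |
| DeiT-T | $224^2$ | 5.7M | 1.6G | 72.2 | 91.3 |
| +TL-Align | $224^2$ | 5.7M | 1.6G | 73.2 | 91.7 |
| DeiT-S | $224^2$ | 22M | 4.6G | 79.8 | 95.0 |
| +TL-Align | $224^2$ | 22M | 4.6G | 80.6 | 95.0 |
| DeiT-B | $224^2$ | 86M | 17.5G | 81.8 | 95.5 |
| +TL-Align | $224^2$ | 86M | 17.5G | 82.3 | 95.8 |
| Swin-T | $224^2$ | 29M | 4.5G | 81.2 | 95.5 |
| +TL-Align | $224^2$ | 29M | 4.5G | 81.4 | 95.7 |
| Swin-S | $224^2$ | 50M | 8.8G | 83.0 | 96.3 |
| +TL-Align | $224^2$ | 50M | 8.8G | 83.4 | 96.5 |
| Swin-B | $224^2$ | 88M | 15.4G | 83.5 | 96.4 |
| +TL-Align | $224^2$ | 88M | 15.4G | 83.7 | 96.5 |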

Usage

Prerequisites

This repository is built upon the Timm library and the DeiT repository.

Install PyTorch 1.7.0+, torchvision 0.8.1+, and pytorch-image-models (timm) 0.3.2:

conda install -c pytorch pytorch torchvision
pip install timm==0.3.2
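After installing, it can help to confirm that the environment matches the versions listed above. The following is a minimal sketch using only the Python standard library; the helper names (`parse_version`, `meets_minimum`, `check_environment`) are illustrative and not part of this repository, and the version parsing is deliberately simple (it ignores local suffixes such as `+cu117`).

```python
import re
from importlib import metadata


def parse_version(text):
    """Turn a string like '1.13.1+cu117' into (1, 13, 1)."""
    # Drop any local build suffix, then keep the first three numeric fields.
    numbers = re.findall(r"\d+", text.split("+")[0])
    numbers = (numbers + ["0", "0", "0"])[:3]
    return tuple(int(n) for n in numbers)


def meets_minimum(installed, minimum):
    """True if the installed version is at least the minimum version."""
    return parse_version(installed) >= parse_version(minimum)


def check_environment(requirements):
    """Map each package to True/False, or None if it is not installed."""
    report = {}
    for name, minimum in requirements.items():
        try:
            report[name] = meets_minimum(metadata.version(name), minimum)
        except metadata.PackageNotFoundError:
            report[name] = None
    return report


if __name__ == "__main__":
    # Minimum versions taken from the prerequisites above.
    print(check_environment({"torch": "1.7.0", "torchvision": "0.8.1", "timm": "0.3.2"}))
```

Note that timm is pinned to exactly 0.3.2 in the install command, so for that package an equality check (`parse_version(installed) == (0, 3, 2)`) may be the more appropriate test.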

Data preparation

Download and extract the ImageNet train and val images from http://image-net.org/. The directory structure follows the standard layout for torchvision's datasets.ImageFolder, with the training data in the train/ folder and the validation data in the val/ folder:

/path/to/imagenet/
  train/
    class1/
      img1.jpeg
    class2/
      img2.jpeg
  val/
    class1/
      img3.jpeg
    class2/
      img4.jpeg
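Before launching a long training run, the layout above can be sanity-checked with a few lines of Python. This is a minimal sketch that relies only on the standard library; the function names and the set of accepted image suffixes are assumptions made for illustration, not something this repository provides.

```python
from pathlib import Path


def summarize_imagefolder(root):
    """Return {split: {class_name: image_count}} for train/ and val/."""
    image_suffixes = {".jpeg", ".jpg", ".png"}  # assumed set of extensions
    summary = {}
    for split in ("train", "val"):
        split_dir = Path(root) / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing split folder: {split_dir}")
        # Each sub-folder of a split is treated as one class.
        summary[split] = {
            class_dir.name: sum(
                1 for f in class_dir.iterdir() if f.suffix.lower() in image_suffixes
            )
            for class_dir in sorted(split_dir.iterdir())
            if class_dir.is_dir()
        }
    return summary


def classes_match(summary):
    """True if train/ and val/ contain the same class folders."""
    return set(summary["train"]) == set(summary["val"])
```

Calling `summarize_imagefolder("/path/to/imagenet")` and then `classes_match(...)` on the result flags the most common mistakes, such as a missing split folder or a class folder present in only one split.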

Training by token-label alignment

To enable token-label alignment during training, add the --tl-align flag to your training command. For example, for DeiT-small, run:

python3 -m torch.distributed.launch --nproc_per_node=8 --use_env main_tla.py \
    --model deit_small_patch16_224 \
    --batch-size 128 \
    --mixup 0.0 \
    --tl-align \
    --data-path /path/to/imagenet \
    --output_dir /path/to/output

or

bash train_deit_small_tla.sh

This should give 80.6% top-1 accuracy after 300 epochs of training.
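When adapting the launch command above to a different number of GPUs, it is worth keeping the total batch size in mind. The sketch below assumes that main_tla.py follows the DeiT convention, in which --batch-size is the per-process batch size and the learning rate is scaled linearly against a reference batch size of 512; both points are assumptions about this script, and the base learning rate of 5e-4 is DeiT's default rather than a value stated here.

```python
def effective_batch_size(nproc_per_node, batch_size_per_process, nodes=1):
    """Total number of samples processed per optimizer step."""
    return nodes * nproc_per_node * batch_size_per_process


def linearly_scaled_lr(base_lr, total_batch_size, reference_batch_size=512):
    """Linear scaling rule (assumed, following DeiT's convention)."""
    return base_lr * total_batch_size / reference_batch_size


# Values from the launch command above: 8 processes, 128 samples each.
total = effective_batch_size(nproc_per_node=8, batch_size_per_process=128)
print(total)  # 1024

# base_lr=5e-4 is an assumption (DeiT's default), not stated in this README.
print(linearly_scaled_lr(base_lr=5e-4, total_batch_size=total))
```

Under these assumptions, halving the number of GPUs while keeping --batch-size 128 halves the total batch size, so results may differ from the reported numbers unless the per-process batch size is adjusted to compensate.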

Evaluation

Models trained with token-label alignment are evaluated in the same way as in timm. Validation accuracy is also reported during training.

For DeiT-small, run:

python main_tla.py --eval --resume checkpoint.pth --model deit_small_patch16_224 --data-path /path/to/imagenet
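For reference, the top-1 and top-5 accuracies reported in the results are standard top-k metrics: a prediction counts as correct if the ground-truth class is among the k highest-scoring classes. The following is a small illustrative sketch in plain Python with made-up scores; it is not the evaluation code used by main_tla.py.

```python
def topk_accuracy(scores, targets, k=1):
    """Percentage of samples whose target is among the k top-scoring classes."""
    hits = 0
    for row, target in zip(scores, targets):
        # Class indices sorted from highest to lowest score.
        ranked = sorted(range(len(row)), key=lambda c: row[c], reverse=True)
        hits += target in ranked[:k]
    return 100.0 * hits / len(targets)


# Toy example: 4 samples, 3 classes (values are invented for illustration).
scores = [
    [0.10, 0.70, 0.20],  # highest score: class 1
    [0.50, 0.30, 0.20],  # highest score: class 0
    [0.20, 0.30, 0.50],  # highest score: class 2
    [0.40, 0.35, 0.25],  # highest score: class 0
]
targets = [1, 1, 2, 2]

print(topk_accuracy(scores, targets, k=1))  # 50.0
print(topk_accuracy(scores, targets, k=2))  # 75.0
```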

Citation

If you find this project useful in your research, please cite:

@article{xiao2022token,
    title={Token-Label Alignment for Vision Transformers},
    author={Xiao, Han and Zheng, Wenzhao and Zhu, Zheng and Zhou, Jie and Lu, Jiwen},
    journal={arXiv preprint arXiv:2210.06455},
    year={2022}
}