Dynamic Token Normalization Improves Vision Transformers, ICLR 2022

This is the PyTorch implementation of the paper Dynamic Token Normalization Improves Vision Transformers in ICLR 2022.

Dynamic Token Normalization

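We design a novel normalization method, termed Dynamic Token Normalization (DTN), which inherits the advantages of both LayerNorm and InstanceNorm: it normalizes features within each token (intra-token, as LayerNorm does) as well as across different tokens (inter-token, as InstanceNorm does). DTN can be seamlessly plugged into various transformer models and consistently improves their performance.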

<div align=center><img src="DTN_token.png" width="1080" height="210"></div>
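To make the intra-token/inter-token idea concrete, here is a minimal NumPy sketch of a normalizer that mixes LayerNorm-style statistics (over the channels of each token) with InstanceNorm-style statistics (over the tokens of each channel), using one mixing weight per head. It is a simplified illustration under our own assumptions, not the official DTN implementation: the function name `simplified_dtn`, the per-head weights `lam`, and the linear mixing of means and variances are hypothetical choices, and the paper's exact formulation (including how inter-token statistics are computed) may differ.

```python
import numpy as np


def simplified_dtn(x, lam, gamma, beta, eps=1e-5):
    """Hypothetical, simplified DTN-style normalizer (not the official code).

    x:           (T, C) array, T tokens with C channels for one sample.
    lam:         (H,) mixing weights in [0, 1], one per head; C must divide by H.
    gamma, beta: (C,) affine parameters, as in LayerNorm.
    """
    T, C = x.shape
    H = lam.shape[0]
    assert C % H == 0, "channels must split evenly across heads"

    # Intra-token (LayerNorm-style) statistics: over channels, per token -> (T, 1).
    mu_ln = x.mean(axis=1, keepdims=True)
    var_ln = x.var(axis=1, keepdims=True)

    # Inter-token (InstanceNorm-style) statistics: over tokens, per channel -> (1, C).
    mu_in = x.mean(axis=0, keepdims=True)
    var_in = x.var(axis=0, keepdims=True)

    # Assumption: each head's weight is shared by that head's C // H channels -> (1, C).
    w = np.repeat(lam, C // H)[None, :]

    # Broadcast both sets of statistics to (T, C) and mix them head by head.
    mu = w * mu_ln + (1.0 - w) * mu_in
    var = w * var_ln + (1.0 - w) * var_in

    return gamma * (x - mu) / np.sqrt(var + eps) + beta


rng = np.random.default_rng(0)
x = rng.normal(size=(197, 384))  # e.g. 196 patch tokens + 1 class token, width 384
y = simplified_dtn(x, lam=np.full(6, 0.5), gamma=np.ones(384), beta=np.zeros(384))
print(y.shape)  # (197, 384)
```

With `lam` set to all ones the sketch reduces to LayerNorm over channels, and with all zeros to InstanceNorm over tokens; intermediate values interpolate between the two.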

News

2022-05-20: We released the code of DTN for training ViT and PVT. More models with DTN will be released soon.

Main Results

1. Performance of ViT and its variants on ImageNet in terms of FLOPs, parameters, and Top-1/Top-5 accuracy (%). H and C denote the number of heads and the embedding dimension, respectively.

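| Model | Norm | H | C | FLOPs | Params | Top-1 | Top-5 |
|---|---|---|---|---|---|---|---|
| ViT-T | LN | 3 | 192 | 1.26G | 5.7M | 72.2 | 91.3 |
| ViT-T* | LN | 4 | 192 | 1.26G | 5.7M | 72.3 | 91.4 |
| ViT-T* | DTN | 4 | 192 | 1.26G | 5.7M | 73.2 | 91.7 |
| ViT-S* | LN | 6 | 384 | 4.60G | 22.1M | 79.9 | 95.0 |
| ViT-S* | DTN | 6 | 384 | 4.88G | 22.1M | 80.6 | 95.3 |
| ViT-B* | LN | 16 | 768 | 17.58G | 86.5M | 81.7 | 95.0 |
| ViT-B* | DTN | 16 | 768 | 18.13G | 86.5M | 82.5 | 96.1 |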
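As a quick check on the table above, the Top-1 gain of DTN over LN and the relative FLOPs overhead can be computed from each LN/DTN pair (the parameter counts are identical within each pair). The numbers are copied from the table; the helper name `dtn_vs_ln` is illustrative only.

```python
# (FLOPs in G, Top-1 in %) for the LN and DTN rows of each model, copied from the table.
results = {
    "ViT-T*": {"LN": (1.26, 72.3), "DTN": (1.26, 73.2)},
    "ViT-S*": {"LN": (4.60, 79.9), "DTN": (4.88, 80.6)},
    "ViT-B*": {"LN": (17.58, 81.7), "DTN": (18.13, 82.5)},
}


def dtn_vs_ln(results):
    """Return {model: (Top-1 gain in points, FLOPs overhead in %)}."""
    summary = {}
    for model, rows in results.items():
        (ln_flops, ln_top1), (dtn_flops, dtn_top1) = rows["LN"], rows["DTN"]
        gain = round(dtn_top1 - ln_top1, 1)
        overhead = round(100 * (dtn_flops - ln_flops) / ln_flops, 1)
        summary[model] = (gain, overhead)
    return summary


for model, (gain, overhead) in dtn_vs_ln(results).items():
    print(f"{model}: +{gain} Top-1, +{overhead}% FLOPs")
# ViT-T*: +0.9 Top-1, +0.0% FLOPs
# ViT-S*: +0.7 Top-1, +6.1% FLOPs
# ViT-B*: +0.8 Top-1, +3.1% FLOPs
```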

2. Comparison of Top-1 accuracy (%) on ImageNet across normalizers. LN, BN, IN, and GN denote LayerNorm, BatchNorm, InstanceNorm, and GroupNorm; ScN and PN denote ScaleNorm and PowerNorm, respectively.

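| Model | LN | BN | IN | GN | SN | ScN | PN | DTN |
|---|---|---|---|---|---|---|---|---|
| ViT-S | 79.9 | 77.3 | 77.7 | 78.3 | 80.1 | 80.0 | 79.8 | 80.6 |
| ViT-S* | 80.6 | 77.2 | 77.6 | 79.5 | 81.0 | 80.6 | 80.4 | 81.7 |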
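For each model in the table above, DTN's margin over the strongest alternative normalizer can be read off directly. The scores below are copied from the table; the helper name `margin_over_best_baseline` is illustrative only.

```python
# Top-1 accuracy (%) per normalizer, copied from the table.
NORMS = ["LN", "BN", "IN", "GN", "SN", "ScN", "PN", "DTN"]
TOP1 = {
    "ViT-S": [79.9, 77.3, 77.7, 78.3, 80.1, 80.0, 79.8, 80.6],
    "ViT-S*": [80.6, 77.2, 77.6, 79.5, 81.0, 80.6, 80.4, 81.7],
}


def margin_over_best_baseline(norms, scores, target="DTN"):
    """Return (strongest other normalizer, its score, target's margin over it)."""
    table = dict(zip(norms, scores))
    target_score = table.pop(target)
    best = max(table, key=table.get)
    return best, table[best], round(target_score - table[best], 1)


for model, scores in TOP1.items():
    best, score, margin = margin_over_best_baseline(NORMS, scores)
    print(f"{model}: DTN beats {best} ({score}) by {margin} points")
# ViT-S: DTN beats SN (80.1) by 0.5 points
# ViT-S*: DTN beats SN (81.0) by 0.7 points
```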

3. Visualization of the mean attention distance of each head in ViT-S. With DTN, many heads in ViT-S have a small mean attention distance, which indicates that DTN helps the model capture local context.

<div align=center><img src="DTN_Head.png" width="1080" height="210"></div>
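Mean attention distance, as popularized by the original ViT paper (Dosovitskiy et al.), is the attention-weighted spatial distance between a query patch and the patches it attends to, averaged over queries. The NumPy sketch below assumes square patch grids, attention maps with the class token already removed, and Euclidean distance in pixels; the function name `mean_attention_distance` is hypothetical and not taken from this repository.

```python
import numpy as np


def mean_attention_distance(attn, grid_size, patch_size):
    """attn: (H, N, N) attention weights over N = grid_size**2 patch tokens,
    each row summing to 1. Returns an (H,) array of distances in pixels."""
    H, N, _ = attn.shape
    assert N == grid_size * grid_size, "expects patch tokens only (class token removed)"
    # (row, col) position of each patch on the grid, scaled to pixels -> (N, 2).
    rows, cols = np.divmod(np.arange(N), grid_size)
    coords = np.stack([rows, cols], axis=1) * patch_size
    # Pairwise Euclidean distance between patch positions -> (N, N).
    dist = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)
    # Attention-weighted distance per query, then averaged over queries -> (H,).
    return (attn * dist[None]).sum(axis=-1).mean(axis=-1)


grid, patch = 14, 16                 # e.g. a 224x224 input with 16x16 patches
n = grid * grid
uniform = np.full((n, n), 1.0 / n)   # spreads attention evenly over all patches
local = np.eye(n)                    # attends only to its own patch
print(mean_attention_distance(np.stack([uniform, local]), grid, patch))
# the second (purely local) head has distance 0.0
```

A head with a small value attends mostly to nearby patches, which is the sense in which the figure above reads small distances as capturing local context.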

Getting Started

Requirements

```bash
conda install pytorch==1.7.1 torchvision==0.8.2 cudatoolkit=10.1 -c pytorch
pip install timm==0.4.9
```

Data Preparation

Training a model from scratch

An example training script is provided in DTN/scripts/train.sh. To train ViT-S* with DTN, run:

```bash
cd DTN/scripts
sh train.sh layer vit_norm_s_star configs/ViT/vit.yaml
```

The number of GPUs and the configuration file to use can be changed in train.sh.

License

DTN is released under the BSD 3-Clause License.

Acknowledgement

Our code is based on the timm package from PyTorch Image Models: https://github.com/rwightman/pytorch-image-models.

Citation

If our code is helpful to your work, please cite:

```bibtex
@article{shao2021dynamic,
  title={Dynamic Token Normalization Improves Vision Transformer},
  author={Shao, Wenqi and Ge, Yixiao and Zhang, Zhaoyang and Xu, Xuyuan and Wang, Xiaogang and Shan, Ying and Luo, Ping},
  journal={arXiv preprint arXiv:2112.02624},
  year={2021}
}
```