Home

Awesome

How to Train Vision Transformer on Small-scale Datasets? [BMVC'22]

Hanan Gani, Muzammal Naseer, and Mohammad Yaqub

paper Poster Slides Video

Abstract: Vision Transformer (ViT), a radically different architecture than convolutional neural networks offers multiple advantages including design simplicity, robustness and state-of-the-art performance on many vision tasks. However, in contrast to convolutional neural networks, Vision Transformer lacks inherent inductive biases. Therefore, successful training of such models is mainly attributed to pre-training on large-scale datasets such as ImageNet with 1.2M or JFT with 300M images. This hinders the direct adaption of Vision Transformer for small-scale datasets. In this work, we show that self-supervised inductive biases can be learned directly from small-scale datasets and serve as an effective weight initialization scheme for fine tuning. This allows to train these models without large scale pre-training, changes to model architecture or loss functions. We present thorough experiments to successfully train monolithic and non-monolithic Vision Transformers on five small datasets including CIFAR10/100, CINIC-10, SVHN, Tiny-ImageNet and two fine-grained datasets: Aircarft and Cars. Our approach consistently improves the performance while retaining their properties such as attention to salient regions and higher robustness.

<hr>

Contents

  1. What's New?
  2. Highlights
  3. Model Zoo
  4. Requirements
  5. Self-Supervised Training
  6. Supervised Training
  7. Results
  8. Citation
  9. Contact
  10. References
<hr>

What's New?

(September 30, 2022)

(September 06, 2022)

(August 09, 2022)

(August 08, 2022)

<hr>

Highlights

  1. Vision Transformers, whether monolithic or non-monolithic, both suffer when trained from scratch on small datasets. This is primarily due to the lack of locality, inductive biases and hierarchical structure of the representations which is commonly observed in the Convolutional Neural Networks. As a result, ViTs require large-scale pre-training to learn such properties from the data for better transfer learning to downstream tasks. We show that inductive biases can be learned directly from the small dataset through self-supervision, thus serving as an effective weight initialization for finetuning on the same dataset.
<!-- <img src="assets/final_main_figure.png" height="500" width="700"> --> <!-- ![main_figure](assets/final_main_figure.png) -->
  1. Our proposed self-supervised inductive biases improve the performance of ViTs on small datasets without modifying the network architecture or loss functions.
<hr>

Model Zoo

DatasetInput SizeModelPretrained Weights
CIFAR1032x32ViTLink
CIFAR1032x32SwinLink
CIFAR1032x32CaiTLink
CIFAR10032x32ViTLink
CIFAR10032x32SwinLink
CIFAR10032x32CaiTLink
CINIC1032x32ViTLink
CINIC1032x32SwinLink
CINIC1032x32CaiTLink
SVHN32x32ViTLink
SVHN32x32SwinLink
SVHN32x32CaiTLink
Tiny-Imagenet64x64ViTLink
Tiny-Imagenet64x64SwinLink
Tiny-Imagenet64x64CaiTLink
<hr>

Requirements

pip install -r requirements.txt
<hr>

Self-supervised Training

For Tiny-Imagenet:

With ViT architecture

python -m torch.distributed.launch --nproc_per_node=2 train_ssl.py --arch vit \
                                   --dataset Tiny_Imagenet --image_size 64 \
                                   --datapath "/path/to/tiny-imagenet/train/folder" \
                                   --patch_size 8  \
                                   --mlp_head_in 192 \
                                   --local_crops_number 8 \
                                   --local_crops_scale 0.2 0.4 \
                                   --global_crops_scale 0.5 1. 
                                   --out_dim 1024 \
                                   --batch_size_per_gpu 256  \
                                   --output_dir "/path/for/saving/checkpoints"

With Swin architecture

python -m torch.distributed.launch --nproc_per_node=2 train_ssl.py --arch swin \
                                   --dataset Tiny_Imagenet --image_size 64 \
                                   --datapath "/path/to/tiny-imagenet/train/folder" \
                                   --patch_size 4  \
                                   --mlp_head_in 384 \
                                   --local_crops_number 8 \
                                   --local_crops_scale 0.2 0.4 \
                                   --global_crops_scale 0.5 1. 
                                   --out_dim 1024 \
                                   --batch_size_per_gpu 256  \
                                   --output_dir "/path/for/saving/checkpoints"

For CIFAR based datasets:

With ViT architecture

python -m torch.distributed.launch --nproc_per_node=2 train_ssl.py --arch vit \
                                   --dataset CIFAR10 --image_size 32 \
                                   --patch_size 4  \
                                   --mlp_head_in 192  \
                                   --local_crops_number 8 \
                                   --local_crops_scale 0.2 0.5 \
                                   --global_crops_scale 0.7 1. 
                                   --out_dim 1024 \
                                   --batch_size_per_gpu 256  \
                                   --output_dir "/path/for/saving/checkpoints"

With Swin architecture

python -m torch.distributed.launch --nproc_per_node=2 train_ssl.py --arch swin \
                                   --dataset Tiny_Imagenet --image_size 32 \
                                   --datapath "/path/to/tiny-imagenet/train/folder" \
                                   --patch_size 2  \
                                   --mlp_head_in 384  \
                                   --local_crops_number 8 \
                                   --local_crops_scale 0.2 0.5 \
                                   --global_crops_scale 0.7 1. 
                                   --out_dim 1024 \
                                   --batch_size_per_gpu 256  \
                                   --output_dir "/path/for/saving/checkpoints"

--dataset can be Tiny_Imagenet/CIFAR10/CIFAR100/CINIC/SVHN.

--arch can be vit/swin/cait.

--local_crops_scale and --global_crops_scale vary based on the dataset used.

--mlp_head_in is dimension of the Vision transformer output going into Projection MLP head and varies based on the model used. For ViT/CaiT, keep --mlp_head_in=192. For Swin, keep --mlp_head_in=384

<hr>

Supervised Training

python finetune.py --arch vit  \
                   --dataset Tiny-Imagenet \
                   --datapath "/path/to/data/folder" \
                   --batch_size 256 \
                   --epochs 100 \
                   --pretrained_weights "/path/to/saved/checkpoint"

--arch can be vit/swin/cait . --datasets can be Tiny-Imagenet/CIFAR10/CIFAR100/CINIC/SVHN . Load the corresponding weights for finetuning.

<hr>

Results

We test our approach on 5 small low resolution datasets: Tiny-Imagenet, CIFAR10, CIFAR100, CINIC10 and SVHN. We compare the results of our approach with 4 baselines: ConvNets, Scratch ViT training, Efficient Training of Visual Transformers with Small Datasets (NIPS'21), Vision Transformer for Small-Size Datasets (arXiv'21)

1. Quantitative results :

main_results

2. Results on high resolution inputs as compared to baseline - Efficient Training of Visual Transformers with Small Datasets (NIPS'21)

results_nips

3. Qualitative results - Attention to salient regions

Our proposed self-supervised training is able to capture the shape of the salient objects efficiently with minimal or no attention to the background on unseen test-set samples without any supervision.

<!-- <img src="assets/atten_maps_paper.png" height="250" width="275"> -->

arxiv_heatmaps_figure

<hr>

Citation

If you use our work, please consider citing:

@inproceedings{Gani_2022_BMVC,
author    = {Hanan Gani and Muzammal Naseer and Mohammad Yaqub},
title     = {How to Train Vision Transformer on Small-scale Datasets?},
booktitle = {33rd British Machine Vision Conference 2022, {BMVC} 2022, London, UK, November 21-24, 2022},
publisher = {{BMVA} Press},
year      = {2022},
url       = {https://bmvc2022.mpi-inf.mpg.de/0731.pdf}
}
<hr>

Contact

Should you have any questions, please create an issue in this repository or contact at hanan.ghani@mbzuai.ac.ae

<hr>

References

Our code is build on the repositories of DINO and Vision Transformer for Small-Size Datasets. We thank them for releasing their code.

<hr>