Vision Transformer - Pytorch

PyTorch implementation of Vision Transformer. Pretrained PyTorch weights, converted from the original JAX/Flax weights, are provided. This is a project of the ASYML family and CASL.

Introduction

Figure 1 from paper

PyTorch implementation of the paper An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. We provide pretrained PyTorch weights converted from the pretrained JAX/Flax models, along with fine-tuning and evaluation scripts, and achieve results similar to the original implementation.
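As the paper's title suggests, ViT splits an image into fixed-size patches and treats each patch as a token. A quick back-of-the-envelope sketch (plain Python; the function name is ours, not part of this repo) of the sequence length the transformer sees for the 16x16-patch models used below:

```python
def vit_sequence_length(image_size: int, patch_size: int) -> int:
    """Number of tokens fed to the transformer: one per patch plus the class token."""
    assert image_size % patch_size == 0, "image must divide evenly into patches"
    patches_per_side = image_size // patch_size
    return patches_per_side ** 2 + 1  # +1 for the [CLS] token

# ViT-B_16 at the paper's 224px pre-training resolution: 14*14 + 1 tokens
print(vit_sequence_length(224, 16))  # 197
# At the 384px resolution used by the fine-tune/eval commands below: 24*24 + 1 tokens
print(vit_sequence_length(384, 16))  # 577
```

This is why the `--image-size 384` commands below are more expensive than 224px pre-training: the token count grows quadratically with resolution.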

Installation

Create environment:

conda create --name vit --file requirements.txt
conda activate vit

Available Models

We provide PyTorch model weights converted from the original JAX/Flax weights. Download them and put the files under 'weights/pytorch' to use them.

Alternatively, you can download the original JAX/Flax weights and put the files under 'weights/jax'; the weights will be converted automatically when loaded.

Datasets

Currently three datasets are supported: ImageNet2012, CIFAR10, and CIFAR100. To evaluate or fine-tune on these datasets, download the datasets and put them in 'data/dataset_name'.

More datasets will be supported.
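Under this layout, a tiny helper (illustrative only; the supported-name set mirrors the `--dataset` values used in the commands below) to build the expected dataset path:

```python
from pathlib import Path

SUPPORTED_DATASETS = {"ImageNet", "CIFAR10", "CIFAR100"}  # as passed to --dataset

def dataset_dir(data_root: str, dataset: str) -> Path:
    """Expected location of a dataset: <data_root>/<dataset_name>."""
    if dataset not in SUPPORTED_DATASETS:
        raise ValueError(f"unsupported dataset: {dataset!r}")
    return Path(data_root) / dataset
```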

Fine-Tune/Train

python src/train.py --exp-name ft --n-gpu 4 --tensorboard  --model-arch b16 --checkpoint-path weights/pytorch/imagenet21k+imagenet2012_ViT-B_16.pth --image-size 384 --batch-size 32 --data-dir data/ --dataset CIFAR10 --num-classes 10 --train-steps 10000 --lr 0.03 --wd 0.0

Evaluation

Make sure you have downloaded the pretrained weights, either in '.npy' (JAX/Flax) or '.pth' (PyTorch) format.

python src/eval.py --model-arch b16 --checkpoint-path weights/jax/imagenet21k+imagenet2012_ViT-B_16.npy --image-size 384 --batch-size 128 --data-dir data/ImageNet --dataset ImageNet --num-classes 1000

Results and Models

Pretrained Results on ImageNet2012

| upstream | model | dataset | orig. jax acc | pytorch acc | model link |
| --- | --- | --- | --- | --- | --- |
| imagenet21k | ViT-B_16 | imagenet2012 | 84.62 | 83.90 | checkpoint |
| imagenet21k | ViT-B_32 | imagenet2012 | 81.79 | 81.14 | checkpoint |
| imagenet21k | ViT-L_16 | imagenet2012 | 85.07 | 84.94 | checkpoint |
| imagenet21k | ViT-L_32 | imagenet2012 | 82.01 | 81.03 | checkpoint |

Fine-Tune Results on CIFAR10/100

Due to limited GPU resources, the fine-tuning results below were obtained with a batch size of 32, which may slightly reduce accuracy.

| upstream | model | dataset | orig. jax acc | pytorch acc |
| --- | --- | --- | --- | --- |
| imagenet21k | ViT-B_16 | CIFAR10 | 98.92 | 98.90 |
| imagenet21k | ViT-B_16 | CIFAR100 | 92.26 | 91.65 |

TODO

Acknowledgements

  1. https://github.com/google-research/vision_transformer
  2. https://github.com/lucidrains/vit-pytorch
  3. https://github.com/kamalkraj/Vision-Transformer

Contributing

Issues and pull requests for improving this repo are welcome. Please follow the contribution guide.

License

Apache License 2.0

Supporting Companies and Universities

<p float="left"> <img src="https://raw.githubusercontent.com/asyml/forte/master/docs/_static/img/Petuum.png" width="200" align="top"> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; <img src="https://asyml.io/assets/institutions/cmu.png" width="200" align="top"> </p>