RVT: Towards Robust Vision Transformer

News: We have added adversarial training results for RVT; see Adversarially trained weights below.

This repository contains PyTorch code for Robust Vision Transformers.

Note: Since the models were trained on our private platform, this ported code has not been fully tested and may contain bugs. If you run into any problems, feel free to open an issue!

For details, see our paper "Towards Robust Vision Transformer".

Usage

First, clone the repository locally:

git clone https://github.com/vtddggg/Robust-Vision-Transformer.git

Install PyTorch 1.7.0+, torchvision 0.8.1+ and pytorch-image-models (timm) 0.3.2:

conda install -c pytorch pytorch torchvision
pip install timm==0.3.2

In addition, einops and kornia are required by this implementation:

pip install einops
pip install kornia
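
Before training, it can help to confirm that the installed versions match these requirements. Below is a minimal sketch using only the Python standard library (importlib.metadata, Python 3.8+), assuming the packages expose standard distribution metadata. The helper names are hypothetical, and the version parsing is deliberately simplified: it keeps the first three numeric components and ignores suffixes such as +cu110, which is an assumption rather than full PEP 440 handling.

```python
from importlib import metadata


def parse_version(text):
    """Simplified parse: '1.7.1+cu110' -> (1, 7, 1); suffixes are dropped (assumption)."""
    core = text.split("+")[0]
    parts = []
    for piece in core.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


# Requirements from this README: (distribution name, version, exact pin?)
REQUIREMENTS = [
    ("torch", "1.7.0", False),
    ("torchvision", "0.8.1", False),
    ("timm", "0.3.2", True),
    ("einops", None, False),
    ("kornia", None, False),
]


def check_requirements(requirements=REQUIREMENTS):
    """Return a list of human-readable problems; an empty list means all checks passed."""
    problems = []
    for name, wanted, exact in requirements:
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            problems.append(f"{name} is not installed")
            continue
        if wanted is None:
            continue  # only presence is required
        have, need = parse_version(installed), parse_version(wanted)
        if exact and have != need:
            problems.append(f"{name} {installed} found, {wanted} required")
        elif not exact and have < need:
            problems.append(f"{name} {installed} found, >= {wanted} required")
    return problems


if __name__ == "__main__":
    issues = check_requirements()
    print("\n".join(issues) if issues else "All requirements satisfied.")
```

The exact pin on timm mirrors the pip install timm==0.3.2 line; PyTorch and torchvision are only checked against a minimum.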

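Training

We train RVT-Ti, RVT-S and RVT-B on 4 nodes with 8 GPUs each. Run the command for the chosen model on every node; for multi-node jobs, torch.distributed.launch also needs --node_rank (0 to 3 here), --master_addr and --master_port, which are omitted from the commands below.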

RVT-Ti:

python -m torch.distributed.launch --nproc_per_node=8 --nnodes=4 main.py --model rvt_tiny --data-path /path/to/imagenet --output_dir output --dist-eval

RVT-S:

python -m torch.distributed.launch --nproc_per_node=8 --nnodes=4 main.py --model rvt_small --data-path /path/to/imagenet --output_dir output --dist-eval

RVT-B:

python -m torch.distributed.launch --nproc_per_node=8 --nnodes=4 main.py --model rvt_base --data-path /path/to/imagenet --output_dir output --batch-size 32 --dist-eval
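
The sketch below generates one torch.distributed.launch command per node, filling in the --node_rank, --master_addr and --master_port arguments that multi-node launches require. The master address and port are placeholders, the helper names are hypothetical, and the global batch size calculation assumes that --batch-size is a per-GPU value, as in DeiT-style training scripts; check main.py to confirm.

```python
import shlex


def launch_commands(model, data_path, nnodes=4, nproc_per_node=8,
                    master_addr="10.0.0.1", master_port=29500,
                    extra_args=()):
    """Build one torch.distributed.launch command per node.

    master_addr and master_port are placeholders: use the address of the
    rank-0 node and a free port reachable from all nodes.
    """
    commands = []
    for node_rank in range(nnodes):
        argv = [
            "python", "-m", "torch.distributed.launch",
            f"--nproc_per_node={nproc_per_node}",
            f"--nnodes={nnodes}",
            f"--node_rank={node_rank}",
            f"--master_addr={master_addr}",
            f"--master_port={master_port}",
            "main.py",
            "--model", model,
            "--data-path", data_path,
            "--output_dir", "output",
            *extra_args,
            "--dist-eval",
        ]
        commands.append(" ".join(shlex.quote(a) for a in argv))
    return commands


def global_batch_size(nnodes, nproc_per_node, per_gpu_batch):
    """Effective batch size, assuming --batch-size is applied per process (per GPU)."""
    return nnodes * nproc_per_node * per_gpu_batch


if __name__ == "__main__":
    for cmd in launch_commands("rvt_base", "/path/to/imagenet",
                               extra_args=("--batch-size", "32")):
        print(cmd)
    print("global batch:", global_batch_size(4, 8, 32))  # 4 * 8 * 32 = 1024
```

Each printed command is run on the node whose rank it names, with rank 0 on the master node; together they start a 32-process job.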

You can also fine-tune from the pretrained weights by adding --pretrained.

To train RVT-Ti*, RVT-S* or RVT-B*, set --model to rvt_tiny_plus, rvt_small_plus or rvt_base_plus, respectively, and add --use_patch_aug to enable patch-wise augmentation.
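
The naming rule above can be captured in a small lookup. This is an illustrative sketch: the helper model_args and the variant labels (Ti, S, B, with * marking the plus models) are hypothetical, and only the flag values (the --model names, --use_patch_aug and --pretrained) come from this README.

```python
# Hypothetical mapping from the paper's size labels to the --model names above.
BASE_MODELS = {"Ti": "rvt_tiny", "S": "rvt_small", "B": "rvt_base"}


def model_args(variant, pretrained=False):
    """'S' -> rvt_small; 'S*' -> rvt_small_plus with patch-wise augmentation."""
    plus = variant.endswith("*")
    size = variant.rstrip("*")
    if size not in BASE_MODELS:
        raise ValueError(f"unknown variant: {variant!r}")
    args = ["--model", BASE_MODELS[size] + ("_plus" if plus else "")]
    if plus:
        args.append("--use_patch_aug")  # the * models are trained with patch-wise augmentation
    if pretrained:
        args.append("--pretrained")  # fine-tune from the released weights
    return args
```

For example, model_args("S*") yields --model rvt_small_plus --use_patch_aug, which can stand in for the --model argument in the training commands.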

Testing

News: Robustness evaluation is now supported! Because of environmental differences, robustness results may differ from those reported in the paper by about ±0.1 to 0.3%.

RVT-Ti:

python main.py --eval --pretrained --model rvt_tiny --data-path /path/to/imagenet

RVT-Ti*:

python main.py --eval --pretrained --model rvt_tiny_plus --data-path /path/to/imagenet

RVT-S:

python main.py --eval --pretrained --model rvt_small --data-path /path/to/imagenet

RVT-S*:

python main.py --eval --pretrained --model rvt_small_plus --data-path /path/to/imagenet

RVT-B:

python main.py --eval --pretrained --model rvt_base --data-path /path/to/imagenet

RVT-B*:

python main.py --eval --pretrained --model rvt_base_plus --data-path /path/to/imagenet

To enable robustness evaluation, add one of --inc_path /path/to/imagenet-c, --ina_path /path/to/imagenet-a, --inr_path /path/to/imagenet-r or --insk_path /path/to/imagenet-sketch to evaluate on ImageNet-C, ImageNet-A, ImageNet-R or ImageNet-Sketch, respectively.
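
Because one of these robustness flags is added per run, evaluating a model on all four benchmarks means one command per benchmark. A minimal sketch that builds those commands follows; the dictionary and function names and the dataset paths are hypothetical placeholders, while the flags themselves are the ones listed above.

```python
import shlex

# Flag names from this README; the lowercase benchmark keys are hypothetical labels.
ROBUSTNESS_FLAGS = {
    "imagenet-c": "--inc_path",
    "imagenet-a": "--ina_path",
    "imagenet-r": "--inr_path",
    "imagenet-sketch": "--insk_path",
}


def eval_command(model, data_path, benchmark=None, benchmark_path=None):
    """Return the argv list for one evaluation run of a pretrained model."""
    argv = ["python", "main.py", "--eval", "--pretrained",
            "--model", model, "--data-path", data_path]
    if benchmark is not None:
        if benchmark_path is None:
            raise ValueError("benchmark_path is required when benchmark is set")
        argv += [ROBUSTNESS_FLAGS[benchmark.lower()], benchmark_path]
    return argv


if __name__ == "__main__":
    # One command per benchmark for the same model.
    for name in ROBUSTNESS_FLAGS:
        argv = eval_command("rvt_small_plus", "/path/to/imagenet",
                            name, f"/path/to/{name}")
        print(shlex.join(argv))  # shlex.join needs Python 3.8+
```

Passing each argv list to subprocess.run instead of printing it would execute the evaluations one after another.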

To test accuracy under adversarial attacks, add --fgsm_test or --pgd_test.
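
For readers unfamiliar with the two attacks: FGSM perturbs an input by a single step of size eps in the direction of the sign of the loss gradient, and PGD repeats smaller signed steps of size alpha while projecting back into the L-infinity ball of radius eps around the original input. The pure-Python sketch below applies both to a toy logistic-regression model, whose input gradient has the closed form (sigmoid(w.x + b) - y) * w. It illustrates the standard attack definitions only; it is not this repository's implementation, and the eps, step size and number of steps used by --fgsm_test and --pgd_test are not stated here, so the values below are arbitrary.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def input_grad(w, b, x, y):
    """Gradient of binary cross-entropy w.r.t. the input x, for p = sigmoid(w.x + b)."""
    p = sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b)
    return [(p - y) * wi for wi in w]


def sign(v):
    return (v > 0) - (v < 0)


def clip(v, lo, hi):
    return max(lo, min(hi, v))


def fgsm(w, b, x, y, eps):
    """One step: x + eps * sign(grad), kept in the valid pixel range [0, 1]."""
    g = input_grad(w, b, x, y)
    return [clip(xi + eps * sign(gi), 0.0, 1.0) for xi, gi in zip(x, g)]


def pgd(w, b, x, y, eps, alpha, steps):
    """Iterated signed steps of size alpha, projected onto the L-inf ball of radius eps."""
    x_adv = list(x)
    for _ in range(steps):
        g = input_grad(w, b, x_adv, y)
        x_adv = [xa + alpha * sign(gi) for xa, gi in zip(x_adv, g)]
        # Project back into [x - eps, x + eps], then into the pixel range [0, 1].
        x_adv = [clip(clip(xa, xi - eps, xi + eps), 0.0, 1.0)
                 for xa, xi in zip(x_adv, x)]
    return x_adv


if __name__ == "__main__":
    w, b = [2.0, -1.0, 0.5], 0.0
    x, y = [0.6, 0.4, 0.5], 1  # p(y=1) > 0.5, so correctly classified
    for name, x_adv in [("fgsm", fgsm(w, b, x, y, eps=0.1)),
                        ("pgd", pgd(w, b, x, y, eps=0.1, alpha=0.025, steps=10))]:
        score = sum(wi * xi for wi, xi in zip(w, x_adv)) + b
        print(name, [round(v, 3) for v in x_adv], "p(y=1) =", round(sigmoid(score), 3))
```

For a linear model the gradient sign never changes, so PGD ends at the same corner of the eps-ball as FGSM; on a deep network such as RVT the iterated steps generally find stronger perturbations, which is why PGD accuracy is the harder benchmark.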

Pretrained weights

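| Model name | FLOPs | Accuracy (%) | Weights |
| --- | --- | --- | --- |
| rvt_tiny | 1.3G | 78.4 | link |
| rvt_small | 4.7G | 81.7 | link |
| rvt_base (ImageNet-22k) | 17.7G | 83.4 | link |
| rvt_tiny* | 1.3G | 79.3 | link |
| rvt_small* | 4.7G | 81.8 | link |
| rvt_base* (ImageNet-22k) | 17.7G | 83.6 | link |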

Adversarially trained weights

| Model name | Clean accuracy (%) | PGD accuracy (%) | Weights |
| --- | --- | --- | --- |
| adv_deit_tiny | 52.05 | 26.55 | link |
| adv_rvt_tiny | 54.91 | 28.1 | link |
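
To put the adversarial table in numbers, the snippet below computes how much adv_rvt_tiny improves over adv_deit_tiny, in percentage points. It only restates the reported figures; the dictionary layout and the gain helper are hypothetical conveniences.

```python
# Values copied from the adversarially trained weights table (percent).
ADV_RESULTS = {
    "adv_deit_tiny": {"clean": 52.05, "pgd": 26.55},
    "adv_rvt_tiny": {"clean": 54.91, "pgd": 28.1},
}


def gain(metric, model="adv_rvt_tiny", baseline="adv_deit_tiny"):
    """Absolute difference in percentage points, rounded to two decimals."""
    return round(ADV_RESULTS[model][metric] - ADV_RESULTS[baseline][metric], 2)


print("clean:", gain("clean"), "pp")  # 54.91 - 52.05 = 2.86
print("PGD:", gain("pgd"), "pp")      # 28.10 - 26.55 = 1.55
```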

To test these models, run the following commands:

python adv_test.py --model deit_tiny_patch16_224 --ckpt_path adv_deit_tiny.pth
python adv_test.py --model rvt_tiny --ckpt_path adv_rvt_tiny.pth