# Out-of-distribution Generalization Investigation on Vision Transformers
This repository contains PyTorch evaluation code for our CVPR 2022 paper *Delving Deep into the Generalization of Vision Transformers under Distribution Shifts*.
## Taxonomy of Distribution Shifts
<p align="middle"> <img src="https://github.com/Phoenix1153/ViT_OOD_generalization/raw/main/img/taxonomy.png" width="60%"> <p>Illustration of our taxonomy of distribution shifts. We build the taxonomy upon what kinds of semantic concepts are modified from the original image and divide the distribution shifts into four cases: background shifts, corruption shifts, texture shifts, and style shifts. <img src="http://latex.codecogs.com/gif.latex?{\color{Red} \checkmark}" /> denotes the unmodified vision cues under certain type of distribution shifts. Please refer to the literature for details.
## Dataset
We build OOD-Net, a collection of datasets covering four types of distribution shift together with their in-distribution counterparts, for a comprehensive investigation into models' out-of-distribution generalization properties. The download link is available here. The datasets and the shift type each one covers are listed below; a minimal loading sketch follows the table.
Dataset | Shift Type |
---|---|
ImageNet-9 | Background Shift |
ImageNet-C | Corruption Shift |
Cue Conflict Stimuli | Texture Shift |
Stylized-ImageNet | Texture Shift |
ImageNet-R | Style Shift |
DomainNet | Style Shift |
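The evaluation entry point shown later consumes a `--meta-file` (a `.jsonl` list of test samples) together with a `--root-dir` (the image folder). The following is a minimal sketch of such a dataset wrapper; the JSON field names `filename` and `label` are illustrative assumptions and should be adapted to the actual meta format shipped with the download.

```python
import json
import os

from PIL import Image
from torch.utils.data import Dataset


class JsonlImageDataset(Dataset):
    """Loads (image, label) pairs listed in a .jsonl meta file.

    Each line is assumed to be a JSON object such as
    {"filename": "xxx.jpg", "label": 123}; adapt the keys to the
    actual meta format used by the released metas.
    """

    def __init__(self, meta_file, root_dir, transform=None):
        with open(meta_file, "r") as f:
            self.samples = [json.loads(line) for line in f if line.strip()]
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        record = self.samples[idx]
        image = Image.open(os.path.join(self.root_dir, record["filename"])).convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, record["label"]
```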
## Generalization-Enhanced Vision Transformers
<p align="middle"> <img src="https://github.com/Phoenix1153/ViT_OOD_generalization/raw/main/img/new_DANN-1.png" width="40%"> <img src="https://github.com/Phoenix1153/ViT_OOD_generalization/raw/main/img/new_MME-1.png" width="45%"> <p> <p align="middle"> <img src="https://github.com/Phoenix1153/ViT_OOD_generalization/raw/main/img/new_SSL-1.png" width="90%"> <p>A framework overview of the three designed generalization-enhanced ViTs. All networks use a Vision Transformer <img src="http://latex.codecogs.com/gif.latex?F" /> as feature encoder and a label prediction head <img src="http://latex.codecogs.com/gif.latex?C" /> . Under this setting, the inputs to the models have labeled source examples and unlabeled target examples. top left: T-ADV promotes the network to learn domain-invariant representations by introducing a domain classifier <img src="http://latex.codecogs.com/gif.latex?D" /> for domain adversarial training. top right: T-MME leverage the minimax process on the conditional entropy of target data to reduce the distribution gap while learning discriminative features for the task. The network uses a cosine similarity-based classifier architecture <img src="http://latex.codecogs.com/gif.latex?C" /> to produce class prototypes. bottom: T-SSL is an end-to-end prototype-based self-supervised learning framework. The architecture uses two memory banks <img src="http://latex.codecogs.com/gif.latex?V^s" /> and <img src="http://latex.codecogs.com/gif.latex?V^t" /> to calculate cluster centroids. A cosine classifier <img src="http://latex.codecogs.com/gif.latex?C" /> is used for classification in this framework.
## Run Our Code
### Environment Installation
```
conda create -n vit python=3.6
conda activate vit
conda install pytorch==1.4.0 torchvision==0.5.0 cudatoolkit=10.0 -c pytorch
```
### Before Running
```
conda activate vit
export PYTHONPATH=$PYTHONPATH:.
```
### Evaluation
```
CUDA_VISIBLE_DEVICES=0 python main.py \
    --model deit_small_b16_384 \
    --num-classes 345 \
    --checkpoint data/checkpoints/deit_small_b16_384_baseline_real.pth.tar \
    --meta-file data/metas/DomainNet/sketch_test.jsonl \
    --root-dir data/images/DomainNet/sketch/test
```
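Internally, such an evaluation boils down to loading the checkpoint and measuring top-1 accuracy on the test split. The sketch below is illustrative rather than a copy of `main.py`; it assumes the checkpoint may store weights under a `state_dict` key and that `model` and `loader` are constructed elsewhere (e.g., with the dataset wrapper sketched in the Dataset section).

```python
import torch


@torch.no_grad()
def evaluate_top1(model, loader, checkpoint_path, device="cuda"):
    """Loads a checkpoint and reports top-1 accuracy (%) on the given loader."""
    ckpt = torch.load(checkpoint_path, map_location="cpu")
    state_dict = ckpt.get("state_dict", ckpt) if isinstance(ckpt, dict) else ckpt
    model.load_state_dict(state_dict, strict=False)  # strict=False tolerates renamed heads
    model.to(device).eval()

    correct, total = 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        logits = model(images)
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += labels.numel()
    return 100.0 * correct / total
```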
## Experimental Results
### DomainNet
#### DeiT_small_b16_384
Confusion matrix for the baseline model (accuracy, %):
|          | clipart | painting | real  | sketch |
|----------|---------|----------|-------|--------|
| clipart  | 80.25   | 33.75    | 55.26 | 43.43  |
| painting | 36.89   | 75.32    | 52.08 | 31.14  |
| real     | 50.59   | 45.81    | 84.78 | 39.31  |
| sketch   | 52.16   | 35.27    | 48.19 | 71.92  |
The models used above can be found here.
## Remarks
- These results may slightly differ from those in our paper due to differences in environments.
- We will continuously update this repo.
## Citation
If you find these investigations useful in your research, please consider citing:
```
@article{zhang2021delving,
  title={Delving deep into the generalization of vision transformers under distribution shifts},
  author={Zhang, Chongzhi and Zhang, Mingyuan and Zhang, Shanghang and Jin, Daisheng and Zhou, Qiang and Cai, Zhongang and Zhao, Haiyu and Yi, Shuai and Liu, Xianglong and Liu, Ziwei},
  journal={arXiv preprint arXiv:2106.07617},
  year={2021}
}
```