Global Context Vision Transformer (GC ViT)

This repository presents the official PyTorch implementation of Global Context Vision Transformers (ICML 2023).

Global Context Vision Transformers
Ali Hatamizadeh, Hongxu (Danny) Yin, Greg Heinrich, Jan Kautz, and Pavlo Molchanov.

GC ViT achieves state-of-the-art results across image classification, object detection and semantic segmentation tasks. On the ImageNet-1K dataset for classification, GC ViT variants with 51M, 90M and 201M parameters achieve 84.3%, 85.0% and 85.7% Top-1 accuracy, respectively, surpassing comparably-sized prior art such as CNN-based ConvNeXt and ViT-based Swin Transformer.

<p align="center"> <img src="https://github.com/NVlabs/GCVit/assets/26806394/d1820d6d-3aef-470e-a1d3-af370f1c1f77" width=63% height=63% class="center"> </p>

The architecture of GC ViT is shown in the following figure:

[Figure: gc_vit]

💥 News 💥

Introduction

GC ViT leverages global context self-attention modules, joint with local self-attention, to effectively yet efficiently model both long and short-range spatial interactions, without the need for expensive operations such as computing attention masks or shifting local windows.
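The interplay of local and global-context attention described above can be sketched in a few lines of plain Python. This is a minimal 1-D illustration under loud assumptions, not the repository's implementation: tokens are short lists of floats, the query/key/value projections are the identity, and `global_queries` (a hypothetical helper) simply averages windows as a stand-in for the model's learned global query generator. The point it shows is that local attention only sees tokens inside one window, while global-context attention reuses one set of queries, derived from the whole input, against each window's local keys and values, with no masks or window shifts.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attend(queries, keys, values):
    """Scaled dot-product attention over lists of vectors."""
    scale = 1.0 / math.sqrt(len(keys[0]))
    dim = len(values[0])
    outputs = []
    for q in queries:
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
        weights = softmax(scores)
        outputs.append([sum(w * v[i] for w, v in zip(weights, values)) for i in range(dim)])
    return outputs


def split_windows(tokens, window):
    """Cut the token sequence into non-overlapping windows."""
    return [tokens[i:i + window] for i in range(0, len(tokens), window)]


def global_queries(tokens, window):
    """Assumed stand-in for a learned query generator: average position j of
    every window, giving one query per in-window position. Expects
    len(tokens) to be a multiple of `window`."""
    windows = split_windows(tokens, window)
    dim = len(tokens[0])
    return [
        [sum(w[j][d] for w in windows) / len(windows) for d in range(dim)]
        for j in range(window)
    ]


def local_attention(tokens, window):
    """Each window attends only to itself (queries, keys, values all local)."""
    out = []
    for win in split_windows(tokens, window):
        out.extend(attend(win, win, win))
    return out


def global_context_attention(tokens, window):
    """Shared global queries attend to each window's local keys and values."""
    shared_q = global_queries(tokens, window)
    out = []
    for win in split_windows(tokens, window):
        out.extend(attend(shared_q, win, win))
    return out
```

Changing a token in one window leaves the `local_attention` output of every other window untouched, but it alters the `global_context_attention` output everywhere, because the shared queries carry information from the whole sequence.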

<p align="center"> <img src="https://github.com/NVlabs/GCVit/assets/26806394/da64f22a-e7af-4577-8884-b08ba4e24e49" width=72% height=72% class="center"> </p>

ImageNet Benchmarks

ImageNet-1K Pretrained Models

<table> <tr> <th>Model Variant</th> <th>Acc@1</th> <th>#Params(M)</th> <th>FLOPs(G)</th> <th>Download</th> </tr> <tr> <td>GC ViT-XXT</td> <th>79.9</th> <td>12</td> <td>2.1</td> <td><a href="https://drive.google.com/uc?export=download&id=1apSIWQCa5VhWLJws8ugMTuyKzyayw4Eh">model</a></td> </tr> <tr> <td>GC ViT-XT</td> <th>82.0</th> <td>20</td> <td>2.6</td> <td><a href="https://drive.google.com/uc?export=download&id=1OgSbX73AXmE0beStoJf2Jtda1yin9t9m">model</a></td> </tr> <tr> <td>GC ViT-T</td> <th>83.5</th> <td>28</td> <td>4.7</td> <td><a href="https://drive.google.com/uc?export=download&id=11M6AsxKLhfOpD12Nm_c7lOvIIAn9cljy">model</a></td> </tr> <tr> <td>GC ViT-T2</td> <th>83.7</th> <td>34</td> <td>5.5</td> <td><a href="https://drive.google.com/uc?export=download&id=1cTD8VemWFiwAx0FB9cRMT-P4vRuylvmQ">model</a></td> </tr> <tr> <td>GC ViT-S</td> <th>84.3</th> <td>51</td> <td>8.5</td> <td><a href="https://drive.google.com/uc?export=download&id=1Nn6ABKmYjylyWC0I41Q3oExrn4fTzO9Y">model</a></td> </tr> <tr> <td>GC ViT-S2</td> <th>84.8</th> <td>68</td> <td>10.7</td> <td><a href="https://drive.google.com/uc?export=download&id=1E5TtYpTqILznjBLLBTlO5CGq343RbEan">model</a></td> </tr> <tr> <td>GC ViT-B</td> <th>85.0</th> <td>90</td> <td>14.8</td> <td><a href="https://drive.google.com/uc?export=download&id=1PF7qfxKLcv_ASOMetDP75n8lC50gaqyH">model</a></td> </tr> <tr> <td>GC ViT-L</td> <th>85.7</th> <td>201</td> <td>32.6</td> <td><a href="https://drive.google.com/uc?export=download&id=1Lkz1nWKTwCCUR7yQJM6zu_xwN1TR0mxS">model</a></td> </tr> </table>

ImageNet-21K Pretrained Models

<table> <tr> <th>Model Variant</th> <th>Resolution</th> <th>Acc@1</th> <th>#Params(M)</th> <th>FLOPs(G)</th> <th>Download</th> </tr> <tr> <td>GC ViT-L</td> <td>224 x 224</td> <th>86.6</th> <td>201</td> <td>32.6</td> <td><a href="https://drive.google.com/uc?export=download&id=1maGDr6mJkLyRTUkspMzCgSlhDzNRFGEf">model</a></td> </tr> <tr> <td>GC ViT-L</td> <td>384 x 384</td> <th>87.4</th> <td>201</td> <td>120.4</td> <td><a href="https://drive.google.com/uc?export=download&id=1P-IEhvQbJ3FjnunVkM1Z9dEpKw-tsuWv">model</a></td> </tr> <tr> <td>GC ViT-L</td> <td>512 x 512</td> <th>87.6</th> <td>201</td> <td>245.0</td> <td><a href="https://huggingface.co/nvidia/GCViT/resolve/main/gcvit_21k_large_512.pth.tar">model</a></td> </tr> </table>

Installation

The dependencies can be installed by running:

pip install -r requirements.txt

Data Preparation

Please download the ImageNet dataset from its official website. The training and validation images need to have sub-folders for each class with the following structure:

  imagenet
  ├── train
  │   ├── class1
  │   │   ├── img1.jpeg
  │   │   ├── img2.jpeg
  │   │   └── ...
  │   ├── class2
  │   │   ├── img3.jpeg
  │   │   └── ...
  │   └── ...
  └── val
      ├── class1
      │   ├── img4.jpeg
      │   ├── img5.jpeg
      │   └── ...
      ├── class2
      │   ├── img6.jpeg
      │   └── ...
      └── ...
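Before launching a long run, it can help to confirm that the dataset folder matches the structure above. The following is a minimal sketch using only the Python standard library; the function names `summarize_split` and `check_imagenet_layout` are hypothetical helpers and are not part of this repository. It only checks folder structure, not image contents.

```python
from pathlib import Path


def summarize_split(split_dir):
    """Map each class sub-folder name to the number of files it holds."""
    split_dir = Path(split_dir)
    return {
        class_dir.name: sum(1 for f in class_dir.iterdir() if f.is_file())
        for class_dir in sorted(split_dir.iterdir())
        if class_dir.is_dir()
    }


def check_imagenet_layout(root):
    """Return a list of problems; an empty list means the layout looks right."""
    root = Path(root)
    problems = []
    summaries = {}
    for split in ("train", "val"):
        split_dir = root / split
        if not split_dir.is_dir():
            problems.append(f"missing folder: {split_dir}")
            continue
        summaries[split] = summarize_split(split_dir)
        if not summaries[split]:
            problems.append(f"no class sub-folders in {split_dir}")
        for name, count in summaries[split].items():
            if count == 0:
                problems.append(f"empty class folder: {split_dir / name}")
    # Both splits should expose the same set of class folders.
    if len(summaries) == 2 and set(summaries["train"]) != set(summaries["val"]):
        problems.append("train and val contain different class folders")
    return problems
```

Calling `check_imagenet_layout("<imagenet-path>")` and printing the returned list gives a quick report; an empty list means both splits exist, every class folder holds at least one file, and the class names agree between `train` and `val`.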
 

Commands

Training on ImageNet-1K From Scratch (Multi-GPU)

The GC ViT model can be trained on the ImageNet-1K dataset by running:

python -m torch.distributed.launch --nproc_per_node <num-of-gpus> --master_port 11223 train.py \
--config <config-file> --data_dir <imagenet-path> --batch-size <batch-size-per-gpu> --amp --tag <run-tag> --model-ema

To resume training from a pre-trained checkpoint:

python -m torch.distributed.launch --nproc_per_node <num-of-gpus> --master_port 11223 train.py \
--resume <checkpoint-path> --config <config-file> --amp --data_dir <imagenet-path> --batch-size <batch-size-per-gpu> --tag <run-tag> --model-ema
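When launching many runs from a script, the launch arguments can be assembled programmatically. The sketch below is a hypothetical helper (not part of this repository) that uses only the flags shown in the commands above. It also computes the total batch size, assuming one process per GPU and no gradient accumulation, in which case the total is the number of GPUs times `--batch-size`.

```python
def build_train_command(num_gpus, config, data_dir, batch_size_per_gpu, tag, resume=None):
    """Assemble the training launch command as an argument list.

    Only flags that appear in the README commands are used; `resume` is
    optional and adds `--resume <checkpoint-path>` when given.
    """
    cmd = [
        "python", "-m", "torch.distributed.launch",
        "--nproc_per_node", str(num_gpus),
        "--master_port", "11223",
        "train.py",
    ]
    if resume is not None:
        cmd += ["--resume", str(resume)]
    cmd += [
        "--config", str(config),
        "--data_dir", str(data_dir),
        "--batch-size", str(batch_size_per_gpu),
        "--amp",
        "--tag", str(tag),
        "--model-ema",
    ]
    return cmd


def global_batch_size(num_gpus, batch_size_per_gpu):
    """Total samples per step, assuming one process per GPU and no accumulation."""
    return num_gpus * batch_size_per_gpu
```

The returned list can be passed to `subprocess.run(cmd, check=True)`. As an illustration with made-up values, 8 GPUs at 128 images each give `global_batch_size(8, 128) == 1024`.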

Evaluation

To evaluate a pre-trained checkpoint on the ImageNet-1K validation set using a single GPU:

python validate.py --model <model-name> --checkpoint <checkpoint-path> --data_dir <imagenet-path> --batch-size <batch-size-per-gpu>
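For reference, the Acc@1 figure reported in the benchmark tables is Top-1 accuracy: the percentage of validation images whose highest-scoring class equals the ground-truth label. A minimal sketch of that metric in plain Python follows; `top1_accuracy` is an illustrative helper and not the code path used by `validate.py`.

```python
def top1_accuracy(logits, labels):
    """Percentage of samples whose arg-max score matches the label.

    `logits` is a list of per-class score lists, `labels` a list of
    integer class indices of the same length.
    """
    if len(logits) != len(labels) or not labels:
        raise ValueError("logits and labels must be non-empty and equally long")
    correct = 0
    for scores, label in zip(logits, labels):
        predicted = max(range(len(scores)), key=scores.__getitem__)
        correct += int(predicted == label)
    return 100.0 * correct / len(labels)
```

For example, if three of four predictions are correct the function returns 75.0.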

Citation

Please consider citing the GC ViT paper if it is useful for your work:

@inproceedings{hatamizadeh2023global,
  title={Global context vision transformers},
  author={Hatamizadeh, Ali and Yin, Hongxu and Heinrich, Greg and Kautz, Jan and Molchanov, Pavlo},
  booktitle={International Conference on Machine Learning},
  pages={12633--12646},
  year={2023},
  organization={PMLR}
}

Third-party Implementations and Resources

In this section, we list third-party contributions by other users. If you would like to have your work included here, please raise an issue in this repository.

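<table>
<tr> <th>Name</th> <th>Link</th> <th>Contributor</th> <th>Framework</th> </tr>
<tr> <td>timm</td> <td>Link</td> <td>@rwightman</td> <td>PyTorch</td> </tr>
<tr> <td>tfgcvit</td> <td>Link</td> <td>@shkarupa-alex</td> <td>TensorFlow 2.0 (Keras)</td> </tr>
<tr> <td>gcvit-tf</td> <td>Link</td> <td>@awsaf49</td> <td>TensorFlow 2.0 (Keras)</td> </tr>
<tr> <td>GCViT-TensorFlow</td> <td>Link</td> <td>@EMalagoli92</td> <td>TensorFlow 2.0 (Keras)</td> </tr>
<tr> <td>keras_cv_attention_models</td> <td>Link</td> <td>@leondgarse</td> <td>Keras</td> </tr>
<tr> <td>flaim</td> <td>Link</td> <td>@BobMcDear</td> <td>JAX/Flax</td> </tr>
</table>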

Additional Resources

In this section, we list additional GC ViT resources such as notebooks, demos and paper explanations. If you have created similar items and would like them to be included, please raise an issue in this repository.

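<table>
<tr> <th>Name</th> <th>Link</th> <th>Contributor</th> <th>Note</th> </tr>
<tr> <td>Paper Explanation</td> <td>Link</td> <td>@awsaf49</td> <td>Annotated GC ViT</td> </tr>
<tr> <td>Colab Notebook</td> <td>Link</td> <td>@awsaf49</td> <td>Flower classification</td> </tr>
<tr> <td>Kaggle Notebook</td> <td>Link</td> <td>@awsaf49</td> <td>Flower classification</td> </tr>
<tr> <td>Live Demo</td> <td>Link</td> <td>@awsaf49</td> <td>Hugging Face demo</td> </tr>
</table>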

Licenses

Copyright © 2023, NVIDIA Corporation. All rights reserved.

This work is made available under the NVIDIA Source Code License-NC. Click here to view a copy of this license.

The pre-trained models are shared under CC-BY-NC-SA-4.0. If you remix, transform, or build upon the material, you must distribute your contributions under the same license as the original.

For license information regarding timm, please refer to its repository.

For license information regarding the ImageNet dataset, please refer to the ImageNet official website.

Acknowledgement