<div align="center">

<a href="https://www.tensorflow.org">TensorFlow</a> <a href="https://github.com/EMalagoli92/GCViT-TensorFlow/blob/main/LICENSE">License</a> <a href="https://www.python.org">Python</a>

</div>

GCViT-TensorFlow

TensorFlow 2.X reimplementation of Global Context Vision Transformers by Ali Hatamizadeh, Hongxu (Danny) Yin, Jan Kautz, and Pavlo Molchanov.

<div id="abstract"/>

Abstract

GC ViT achieves state-of-the-art results across image classification, object detection and semantic segmentation tasks. On the ImageNet-1K dataset for classification, the tiny, small and base variants of GC ViT, with 28M, 51M and 90M parameters respectively, surpass comparably-sized prior art such as the CNN-based ConvNeXt and the ViT-based Swin Transformer by a large margin. Pre-trained GC ViT backbones used in the downstream tasks of object detection, instance segmentation, and semantic segmentation on the MS COCO and ADE20K datasets outperform prior work consistently, sometimes by large margins.

<p align = "center"> <sub>Top-1 accuracy vs. model FLOPs/parameter size on ImageNet-1K dataset. GC ViT achieves new SOTA benchmarks for different model sizes as well as FLOPs, outperforming competing approaches by a significant margin.</sub> </p>

<p align = "center"><sub>Architecture of the Global Context ViT. The authors use alternating blocks of local and global context self attention layers in each stage of the architecture.</sub></p>

<div id="results"/>

Results

The TensorFlow implementation and the ported ImageNet weights have been compared with the official PyTorch implementation on the ImageNet-V2 test set.

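| Configuration | Top-1 (Original) | Top-1 (Ported) | Top-5 (Original) | Top-5 (Ported) | #Params |
| --- | --- | --- | --- | --- | --- |
| GCViT-XXTiny | 68.79 | 68.73 | 88.52 | 88.47 | 12M |
| GCViT-XTiny | 70.97 | 71 | 89.8 | 89.79 | 20M |
| GCViT-Tiny | 72.93 | 72.9 | 90.7 | 90.7 | 28M |
| GCViT-Small | 73.46 | 73.5 | 91.14 | 91.08 | 51M |
| GCViT-Base | 74.13 | 74.16 | 91.66 | 91.69 | 90M |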

Mean metrics difference: 3e-4.
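The reported mean difference can be approximately reproduced from the table. The sketch below assumes that the figure is the mean absolute difference between original and ported accuracies, expressed as fractions rather than percentages; this interpretation is an assumption and is not stated in the source.

```python
# (top-1 original, top-1 ported, top-5 original, top-5 ported), in percent,
# copied from the results table.
results = {
    "GCViT-XXTiny": (68.79, 68.73, 88.52, 88.47),
    "GCViT-XTiny": (70.97, 71.0, 89.8, 89.79),
    "GCViT-Tiny": (72.93, 72.9, 90.7, 90.7),
    "GCViT-Small": (73.46, 73.5, 91.14, 91.08),
    "GCViT-Base": (74.13, 74.16, 91.66, 91.69),
}


def mean_abs_difference(table):
    """Mean absolute original-vs-ported gap over all top-1 and top-5 entries."""
    diffs = []
    for top1_orig, top1_port, top5_orig, top5_port in table.values():
        diffs.append(abs(top1_orig - top1_port))
        diffs.append(abs(top5_orig - top5_port))
    # Divide by 100 to convert percentage points to a fraction.
    return sum(diffs) / len(diffs) / 100.0


print(f"{mean_abs_difference(results):.1e}")
```

Under this assumption the value comes out at roughly 3.4e-4, which is consistent with the reported 3e-4 after rounding.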

<div id="installation"/>

Installation

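# Option 1: install from PyPI
pip install gcvit-tensorflow

# Option 2: install from GitHub
pip install git+https://github.com/EMalagoli92/GCViT-TensorFlow

# Option 3: clone the repository and install the required packages
git clone https://github.com/EMalagoli92/GCViT-TensorFlow.git
cd GCViT-TensorFlow
pip install -r requirements.txt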

Tested on Ubuntu 20.04.4 LTS x86_64, Python 3.9.7.

<div id="usage"/>

Usage

from gcvit_tensorflow import GCViT

# Define a custom GCViT configuration
model = GCViT(
    depths=[2, 2, 6, 2],
    num_heads=[2, 4, 8, 16],
    window_size=[7, 7, 14, 7],
    dim=64,
    resolution=224,
    in_chans=3,
    mlp_ratio=3,
    drop_path_rate=0.2,
    data_format="channels_last",
    num_classes=100,
    classifier_activation="softmax",
)

# Use a predefined GCViT configuration
from gcvit_tensorflow import GCViT

model = GCViT(configuration="xxtiny")
model.build((None, 224, 224, 3))
model.summary()  # prints the table shown below
Model: "xxtiny"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 patch_embed (PatchEmbed)    (None, 56, 56, 64)        45632     
                                                                 
 pos_drop (Dropout)          (None, 56, 56, 64)        0         
                                                                 
 levels/0 (GCViTLayer)       (None, 28, 28, 128)       185766    
                                                                 
 levels/1 (GCViTLayer)       (None, 14, 14, 256)       693258    
                                                                 
 levels/2 (GCViTLayer)       (None, 7, 7, 512)         5401104   
                                                                 
 levels/3 (GCViTLayer)       (None, 7, 7, 512)         5400546   
                                                                 
 norm (LayerNorm_)           (None, 7, 7, 512)         1024      
                                                                 
 avgpool (AdaptiveAveragePoo  (None, 512, 1, 1)        0         
 ling2D)                                                         
                                                                 
 head (Linear_)              (None, 1000)              513000    
                                                                 
=================================================================
Total params: 12,240,330
Trainable params: 11,995,428
Non-trainable params: 244,902
_________________________________________________________________
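The output shapes in the summary follow a regular pattern: the patch embedding reduces 224 to 56, and every level except the last halves the spatial size and doubles the channel count. The sketch below reproduces that pattern and checks the parameter totals. The helper `gcvit_stage_shapes` is hypothetical, and the rule is inferred from the printed summary only, not taken from the library source.

```python
def gcvit_stage_shapes(resolution=224, dim=64, num_levels=4):
    """Return (name, height, width, channels) after patch_embed and each level.

    Hypothetical helper: the downsampling rule is inferred from the summary.
    """
    size = resolution // 4  # patch_embed maps 224 -> 56 in the summary
    channels = dim
    shapes = [("patch_embed", size, size, channels)]
    for level in range(num_levels):
        if level < num_levels - 1:
            # Every level except the last halves the resolution
            # and doubles the number of channels.
            size //= 2
            channels *= 2
        shapes.append((f"levels/{level}", size, size, channels))
    return shapes


# Per-layer parameter counts copied from the summary.
layer_params = [45632, 0, 185766, 693258, 5401104, 5400546, 1024, 0, 513000]

for name, height, width, channels in gcvit_stage_shapes():
    print(f"{name:12s} (None, {height}, {width}, {channels})")
print("total params:", sum(layer_params))
```

The per-layer counts sum to 12,240,330, matching the reported total, which is also the sum of the trainable (11,995,428) and non-trainable (244,902) counts.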
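
# Example: compile and train the model
# x and y are placeholders for your own data: a batch of images
# and the matching integer class labels.
model.compile(
    optimizer="sgd",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy", "sparse_top_k_categorical_accuracy"],
)
model.fit(x, y)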

# Example: use the ported ImageNet pretrained weights
# image is a placeholder for a batch of input images.
from gcvit_tensorflow import GCViT

model = GCViT(configuration="base", pretrained=True, classifier_activation="softmax")
y_pred = model(image)

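The Top-1 and Top-5 figures in the results table, and the `sparse_top_k_categorical_accuracy` metric used when compiling the model (which defaults to k=5 in Keras), both measure how often the true label is among the k highest-scoring classes. A minimal pure-Python sketch of that metric follows; `top_k_accuracy` is a hypothetical helper written for illustration and is not part of this package.

```python
def top_k_accuracy(y_true, y_prob, k=5):
    """Fraction of samples whose true label is among the k highest scores.

    y_true: list of integer class labels.
    y_prob: list of per-class score lists, one per sample.
    """
    hits = 0
    for label, probs in zip(y_true, y_prob):
        # Class indices sorted by descending score, truncated to the top k.
        ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
        hits += label in ranked[:k]
    return hits / len(y_true)


# Toy data: three samples, three classes.
labels = [2, 0, 1]
scores = [
    [0.1, 0.2, 0.7],
    [0.3, 0.6, 0.1],
    [0.2, 0.5, 0.3],
]
print(top_k_accuracy(labels, scores, k=1))
print(top_k_accuracy(labels, scores, k=2))
```

On the toy data, the second sample is misclassified at k=1 but counted as correct at k=2, because its true class has the second-highest score.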
<div id="acknowledgement"/>

Acknowledgement

<div id="citations"/>

Citations

@article{hatamizadeh2022global,
  title={Global Context Vision Transformers},
  author={Hatamizadeh, Ali and Yin, Hongxu and Kautz, Jan and Molchanov, Pavlo},
  journal={arXiv preprint arXiv:2206.09959},
  year={2022}
}
<div id="license"/>

License

This work is made available under the MIT License.

The pre-trained weights are shared under CC-BY-NC-SA-4.0.