Swin for win!

TensorFlow 2.8 Models on TF-Hub

This repository provides TensorFlow / Keras implementations of different Swin Transformer [1, 2] variants by Liu et al. and Chen et al. It also provides the TensorFlow / Keras models populated with the original Swin pre-trained parameters available from [3, 4]. These models are not black-box SavedModels: they can be fully expanded into tf.keras.Model objects, and all the usual utility methods (for example, .summary()) can be called on them.

Refer to the "Using the models" section to get started.

I find Swin Transformers interesting because they introduce a sense of hierarchy by using shifted windows. Multi-scale representations of this kind are crucial for good performance in tasks like object detection and segmentation.

Teaser figure (source: https://github.com/microsoft/Swin-Transformer)
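To make the hierarchy concrete, the sketch below computes the token resolution, channel width, and number of attention windows at each stage for a 224x224 input. It assumes the standard design described in [1], where patch merging between stages halves the resolution and doubles the channels, and it uses the Swin-Tiny embedding width of 96 as a default. The function name `stage_shapes` is made up for illustration and is not part of this repository.

```python
def stage_shapes(image_size=224, patch_size=4, window_size=7, embed_dim=96, num_stages=4):
    """Return (resolution, channels, num_windows) for each Swin stage.

    Assumes the standard design from the Swin paper: patch merging between
    stages halves the spatial resolution and doubles the channel width.
    """
    shapes = []
    resolution = image_size // patch_size  # tokens per side after patch embedding
    channels = embed_dim
    for _ in range(num_stages):
        windows_per_side = resolution // window_size
        shapes.append((resolution, channels, windows_per_side ** 2))
        resolution //= 2  # patch merging fuses each 2x2 group of tokens
        channels *= 2
    return shapes


for stage, (res, ch, n_win) in enumerate(stage_shapes(), start=1):
    print(f"stage {stage}: {res}x{res} tokens, {ch} channels, {n_win} windows")
```

Attention is computed inside each 7x7 window, so the number of windows shrinks from 64 in the first stage to a single window in the last one.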

"Swin for win!" however doesn't portray my architecture bias -- I found it cool and hence kept it.

Conversion

The TensorFlow / Keras implementations are available in swins/models.py. All model configurations are in swins/model_configs.py. Conversion utilities are in convert.py. To run the conversion utilities, first install all the dependencies listed in requirements.txt. Additionally, install timm from source:

pip install -q git+https://github.com/rwightman/pytorch-image-models
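Porting PyTorch weights to Keras largely comes down to copying tensors while accounting for layout differences. As a minimal, library-free illustration (this is not the code in convert.py, and the helper `transpose_2d` is hypothetical), the sketch below shows the transposition needed to move a fully connected layer's weight from PyTorch's `nn.Linear` layout to the Keras `Dense` kernel layout. In practice an array library's transpose would do the same job.

```python
def transpose_2d(weight):
    """Swap the two axes of a matrix stored as a list of rows."""
    return [list(column) for column in zip(*weight)]


# A PyTorch nn.Linear stores its weight as (out_features, in_features).
torch_linear_weight = [
    [1.0, 2.0, 3.0],  # weights feeding output unit 0
    [4.0, 5.0, 6.0],  # weights feeding output unit 1
]

# A Keras Dense layer stores its kernel as (in_features, out_features).
keras_dense_kernel = transpose_2d(torch_linear_weight)
print(keras_dense_kernel)
```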

Models

Find the models on TF-Hub here: https://tfhub.dev/sayakpaul/collections/swin/1. You can fully inspect the architecture of the TF-Hub models like so:

import tensorflow as tf

model_gcs_path = "gs://tfhub-modules/sayakpaul/swin_tiny_patch4_window7_224/1/uncompressed"
model = tf.keras.models.load_model(model_gcs_path)

dummy_inputs = tf.ones((2, 224, 224, 3))
_ = model(dummy_inputs)
# summary() prints the architecture itself and returns None, so no print() is needed.
model.summary(expand_nested=True)

Results

The table below provides a performance summary (ImageNet-1k validation set):

| model_name | top1_acc(%) | top5_acc(%) | orig_top1_acc(%) |
|---|---|---|---|
| swin_base_patch4_window7_224 | 85.134 | 97.48 | 85.2 |
| swin_large_patch4_window7_224 | 86.252 | 97.878 | 86.3 |
| swin_s3_base_224 | 83.958 | 96.532 | 84 |
| swin_s3_small_224 | 83.648 | 96.358 | 83.7 |
| swin_s3_tiny_224 | 82.034 | 95.864 | 82.1 |
| swin_small_patch4_window7_224 | 83.178 | 96.24 | 83.2 |
| swin_tiny_patch4_window7_224 | 81.184 | 95.512 | 81.2 |
| swin_base_patch4_window12_384 | 86.428 | 98.042 | 86.4 |
| swin_large_patch4_window12_384 | 87.272 | 98.242 | 87.3 |

The base and large models were first pre-trained on the ImageNet-22k dataset and then fine-tuned on the ImageNet-1k dataset.

The in1k-eval directory provides details on how these numbers were generated. Original scores for all the models except the s3 ones were gathered from the Swin Transformers GitHub repository [3]. Scores for the s3 models were gathered from the AutoFormerV2 GitHub repository [4].
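As a quick check on how closely the ported models track the originals, the snippet below computes the top-1 gap (ported minus original) for each row of the results table. The numbers are copied from the table; the variable names are only illustrative.

```python
# (top-1 of the TensorFlow port, top-1 reported for the original model),
# copied from the results table.
results = {
    "swin_base_patch4_window7_224": (85.134, 85.2),
    "swin_large_patch4_window7_224": (86.252, 86.3),
    "swin_s3_base_224": (83.958, 84.0),
    "swin_s3_small_224": (83.648, 83.7),
    "swin_s3_tiny_224": (82.034, 82.1),
    "swin_small_patch4_window7_224": (83.178, 83.2),
    "swin_tiny_patch4_window7_224": (81.184, 81.2),
    "swin_base_patch4_window12_384": (86.428, 86.4),
    "swin_large_patch4_window12_384": (87.272, 87.3),
}

# Gap in percentage points; rounding removes floating-point noise.
gaps = {
    name: round(ported - original, 3)
    for name, (ported, original) in results.items()
}

for name, gap in gaps.items():
    print(f"{name}: {gap:+.3f}")

largest_gap = max(abs(gap) for gap in gaps.values())
print(f"largest absolute gap: {largest_gap:.3f}")
```

Every ported model lands within 0.1 percentage points of its original top-1 score.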

Using the models

Pre-trained models:

When doing transfer learning, try using the models that were pre-trained on the ImageNet-22k dataset. All the base and large models listed here were pre-trained on the ImageNet-22k dataset. Refer to the model collection page on TF-Hub to learn more.

These models also output attention weights from each of the Transformer blocks. Refer to this notebook for more details. Additionally, the notebook shows how to obtain the attention maps for a given image.

Randomly initialized models:

import tensorflow as tf

from swins import SwinTransformer

cfg = dict(
    patch_size=4,
    window_size=7,
    embed_dim=128,
    depths=(2, 2, 18, 2),
    num_heads=(4, 8, 16, 32),
)
 
swin_base_patch4_window7_224 = SwinTransformer(
    name="swin_base_patch4_window7_224", **cfg
)
print("Model instantiated, attempting predictions...")
random_tensor = tf.random.normal((2, 224, 224, 3))
outputs = swin_base_patch4_window7_224(random_tensor, training=False)

print(outputs.shape)

print(swin_base_patch4_window7_224.count_params() / 1e6)

To initialize a network with, say, 5 classes:

cfg = dict(
    patch_size=4,
    window_size=7,
    embed_dim=128,
    depths=(2, 2, 18, 2),
    num_heads=(4, 8, 16, 32),
    num_classes=5,
)

swin_base_patch4_window7_224 = SwinTransformer(
    name="swin_base_patch4_window7_224", **cfg
)

To view different model configurations, refer to swins/model_configs.py.
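When writing a custom configuration, a quick consistency check can catch mismatched values before the model is built. The sketch below assumes, as in the standard Swin design [1], that the channel width doubles at every stage, and it verifies that each stage's width divides evenly across its attention heads. The helper `head_dims` is hypothetical and not part of the `swins` package.

```python
def head_dims(embed_dim, num_heads):
    """Per-stage attention head size, assuming channels double at each stage."""
    dims = []
    for stage, heads in enumerate(num_heads):
        channels = embed_dim * 2 ** stage
        if channels % heads != 0:
            raise ValueError(
                f"stage {stage}: {channels} channels not divisible by {heads} heads"
            )
        dims.append(channels // heads)
    return dims


# The base configuration used in the snippets above.
cfg = dict(
    patch_size=4,
    window_size=7,
    embed_dim=128,
    depths=(2, 2, 18, 2),
    num_heads=(4, 8, 16, 32),
)

print(head_dims(cfg["embed_dim"], cfg["num_heads"]))  # [32, 32, 32, 32]
print(sum(cfg["depths"]))  # 24 Transformer blocks in total
```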

References

[1] Swin Transformer: Hierarchical Vision Transformer using Shifted Windows by Liu et al.

[2] Searching the Search Space of Vision Transformer by Chen et al.

[3] Swin Transformers GitHub

[4] AutoFormerV2 GitHub

Acknowledgements