Scalable Vision Transformers with Hierarchical Pooling

This is the official PyTorch implementation of ICCV 2021 paper: Scalable Vision Transformers with Hierarchical Pooling.

By Zizheng Pan, Bohan Zhuang, Jing Liu, Haoyu He, and Jianfei Cai.

In our paper, we propose the Hierarchical Visual Transformer (HVT), which progressively pools visual tokens to shrink the sequence length and hence reduce the computational cost, analogous to feature map downsampling in Convolutional Neural Networks (CNNs). Moreover, we empirically find that the average-pooled visual tokens contain more discriminative information than the single class token.
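
The overall scheme can be sketched in a few lines of PyTorch. The snippet below is only an illustration of the idea under simplified assumptions (generic encoder blocks, 1D max pooling between stages, no class token); it is not the implementation in this repository.

import torch
import torch.nn as nn

class PooledTransformer(nn.Module):
    """Toy hierarchical ViT backbone: token pooling between stages and
    average-pooled tokens (no class token) for classification.
    Illustrative only; shapes and hyper-parameters are assumptions, not HVT's code."""

    def __init__(self, embed_dim=192, depth=12, num_stages=2, num_heads=3, num_classes=1000):
        super().__init__()
        def block():
            return nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, batch_first=True)
        blocks_per_stage = depth // num_stages
        self.stages = nn.ModuleList([
            nn.ModuleList([block() for _ in range(blocks_per_stage)])
            for _ in range(num_stages)
        ])
        # 1D max pooling roughly halves the token sequence after each stage
        self.pool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):                        # x: (B, N, C) patch embeddings
        for i, blocks in enumerate(self.stages):
            for blk in blocks:
                x = blk(x)
            if i < len(self.stages) - 1:         # shrink the sequence length between stages
                x = self.pool(x.transpose(1, 2)).transpose(1, 2)
        return self.head(x.mean(dim=1))          # average-pool the tokens, then classify

tokens = torch.randn(2, 196, 192)                # e.g. 14x14 patches from a 224x224 image
logits = PooledTransformer()(tokens)
print(logits.shape)                              # torch.Size([2, 1000])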

If you use this code for a paper, please cite:

@InProceedings{Pan_2021_ICCV,
    author    = {Pan, Zizheng and Zhuang, Bohan and Liu, Jing and He, Haoyu and Cai, Jianfei},
    title     = {Scalable Vision Transformers With Hierarchical Pooling},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
    month     = {October},
    year      = {2021},
    pages     = {377-386}
}

Updates

Usage

First, clone the repository locally:

git clone https://github.com/MonashAI/HVT

Then, install PyTorch 1.7.0+, torchvision 0.8.1+, and pytorch-image-models 0.3.2:

conda install -c pytorch pytorch torchvision
pip install timm==0.3.2
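
To verify the environment before launching any experiments, you can print the installed versions from Python (a quick sanity check, not part of the repository):

import torch, torchvision, timm

print("torch:", torch.__version__)              # expect 1.7.0 or later
print("torchvision:", torchvision.__version__)  # expect 0.8.1 or later
print("timm:", timm.__version__)                # expect 0.3.2
print("CUDA available:", torch.cuda.is_available())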

Data preparation

ImageNet

Download the ImageNet 2012 dataset from here, and prepare the dataset based on this script. The file structure should look like:

imagenet
├── train
│   ├── class1
│   │   ├── img1.jpeg
│   │   ├── img2.jpeg
│   │   └── ...
│   ├── class2
│   │   ├── img3.jpeg
│   │   └── ...
│   └── ...
└── val
    ├── class1
    │   ├── img4.jpeg
    │   ├── img5.jpeg
    │   └── ...
    ├── class2
    │   ├── img6.jpeg
    │   └── ...
    └── ...
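
This layout matches the standard torchvision ImageFolder convention, so you can sanity-check the prepared data outside the training script, for example:

from torchvision import datasets, transforms

# Minimal preprocessing just to confirm the folders are readable;
# the actual pipeline in main.py may use different transforms.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

train_set = datasets.ImageFolder("imagenet/train", transform=transform)
val_set = datasets.ImageFolder("imagenet/val", transform=transform)
print(len(train_set), "train images /", len(val_set), "val images")
print("classes:", len(train_set.classes))   # 1000 for ImageNet-1k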

CIFAR100

Download the CIFAR100 dataset from here.
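
Alternatively, torchvision can download it for you; the snippet below fetches the standard CIFAR-100 archive into a local folder (adjust the root to match the --data-path you pass later):

from torchvision import datasets

# Downloads the CIFAR-100 archive (~170 MB) into ./cifar100 if it is not already there.
datasets.CIFAR100(root="./cifar100", train=True, download=True)
datasets.CIFAR100(root="./cifar100", train=False, download=True)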

Training

To train HVT-Ti-1 on ImageNet with 8 GPUs, run:

python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --config config/hvt-ti-1.json --data-set IMNET --data-path [path/to/imagenet]

We also provide configuration files for HVT-S-1 and Scale HVT-Ti-4 under the config folder.

To train HVT-Ti-1 on CIFAR100, run:

python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --config config/hvt-ti-1.json --data-set CIFAR --data-path [path/to/cifar100]

Evaluation

To evaluate a model on ImageNet, e.g. HVT-S-1, run:

python main.py --config config/hvt-s-1.json --data-set IMNET --data-path [path/to/imagenet] --eval --resume [path/to/hvt_s_1.pth]

Scaling HVT

You can scale an HVT model with various settings supported in the configuration files under the config folder, for example the number of pooling stages (as in HVT-S-0 to HVT-S-4) or the model width (as in Scale HVT-Ti-4).
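
The reason pooling scales well is that the cost of self-attention grows quadratically with the number of tokens, so each stage that roughly halves the sequence cuts the dominant attention term by about four. Below is a back-of-the-envelope estimate of this effect (illustrative only, not the FLOPs counter used for the tables below):

def attention_flops(num_tokens, dim):
    # Rough per-layer cost of multi-head self-attention:
    # Q/K/V/output projections (4*N*d^2) plus attention maps (2*N^2*d).
    return 4 * num_tokens * dim ** 2 + 2 * num_tokens ** 2 * dim

dim = 192          # DeiT-Ti / HVT-Ti embedding width
tokens = 196       # 14x14 patches from a 224x224 image
for stage in range(4):
    mflops = attention_flops(tokens, dim) / 1e6
    print(f"after {stage} pooling stage(s): {tokens:3d} tokens, ~{mflops:.1f} MFLOPs per attention layer")
    tokens = (tokens + 1) // 2   # each pooling stage roughly halves the sequence length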

Results on ImageNet

Main Results

Name           | FLOPs (G) | Params (M) | Top-1 Acc. (%) | Top-5 Acc. (%) | Model  | Log
---------------|-----------|------------|----------------|----------------|--------|----
HVT-Ti-1       | 0.64      | 5.74       | 69.64          | 89.40          | github | log
Scale HVT-Ti-4 | 1.39      | 22.12      | 75.23          | 92.30          | github | log
HVT-S-1        | 2.40      | 22.09      | 78.00          | 93.83          | github | log

Note that the weights and logs for HVT-Ti-1 and HVT-S-1 come from retrained models.

More Pooling Stages with HVT-S

Name    | FLOPs (G) | Params (M) | Top-1 Acc. (%) | Top-5 Acc. (%) | Model  | Log
--------|-----------|------------|----------------|----------------|--------|----
HVT-S-0 | 4.57      | 22.05      | 80.39          | 95.13          | github | log
HVT-S-1 | 2.40      | 22.09      | 78.00          | 93.83          | github | log
HVT-S-2 | 1.94      | 22.11      | 77.36          | 93.55          | github | log
HVT-S-3 | 1.62      | 22.11      | 76.32          | 92.90          | github | log
HVT-S-4 | 1.39      | 22.12      | 75.23          | 92.30          | github | log

For CIFAR-100 results, please refer to our paper.

License

This repository is released under the Apache 2.0 license as found in the LICENSE file.

Acknowledgement

This repository adopts code from DeiT. We thank the authors for open-sourcing their code.