HCTransformers
PyTorch implementation for "Attribute Surrogates Learning and Spectral Tokens Pooling in Transformers for Few-shot Learning".
[arXiv]
<div align="center"> <img width="100%" alt="HCT Network Architecture" src=".github/network.png"> </div>

Code will be continuously updated.
Updates
06/21/2024
Updated the download links for pretrained weights and extracted features.
07/07/2022
- Dataset descriptions and guidelines are updated.
- Features extracted by the pretrained models on miniImageNet are also provided here.
07/01/2022
Provided download links for the pretrained weights and the evaluation command lines.
Prerequisites
This codebase has been developed with Python version 3.8, PyTorch version 1.9.0, CUDA 11.1 and torchvision 0.10.0. It has been tested on Ubuntu 20.04.
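If you want to double-check your environment before training, the short sketch below (plain Python, not part of the codebase) prints the installed versions so you can compare them against the ones listed above:

```python
# Environment sanity check: compare installed versions against the ones this
# codebase was developed with (Python 3.8, PyTorch 1.9.0 with CUDA 11.1,
# torchvision 0.10.0).
import sys
import torch
import torchvision

print(f"Python       : {sys.version.split()[0]}")   # expected 3.8.x
print(f"PyTorch      : {torch.__version__}")        # expected 1.9.0
print(f"torchvision  : {torchvision.__version__}")  # expected 0.10.0
print(f"CUDA (build) : {torch.version.cuda}")       # expected 11.1
print(f"GPU available: {torch.cuda.is_available()}")
```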
Pretrained weights
Pretrained weights on miniImageNet, tieredImageNet, CIFAR-FS and FC100 are available now. Note that for tieredImageNet and FC100 there are only checkpoints for the first stage (without cascaded training). The 5-way 1-shot and 5-way 5-shot accuracies shown in the table are evaluated on the test split and are provided for reference only.
Pretrained weights for the cascaded-trained models on miniImageNet and CIFAR-FS are provided as follows. Note that the path to the first-stage pretrained weights must be specified when evaluating (see Evaluation).
<table>
<tr> <th>dataset</th> <th>1-shot</th> <th>5-shot</th> <th colspan="2">download</th> </tr>
<tr> <td>miniImageNet</td> <td>74.74%</td> <td>89.19%</td> <td rowspan="2"> <a href="https://www.ilanzou.com/s/TfVHcz4">checkpoints_pooling</a> </td> <td> <a href="https://www.ilanzou.com/s/ZUzHCPC">features_mini</a> </td> </tr>
<tr> <td>CIFAR-FS</td> <td>78.89%</td> <td>90.50%</td> <td>-</td> </tr>
</table>

Datasets
miniImageNet
The miniImageNet dataset was proposed by Vinyals et al. for few-shot learning evaluation. Its complexity is high due to the use of ImageNet images, but it requires fewer resources and infrastructure than running on the full ImageNet dataset. In total, there are 100 classes with 600 color images per class. These 100 classes are divided into 64, 16, and 20 classes respectively for sampling tasks for meta-training, meta-validation, and meta-testing. To generate this dataset from ImageNet, you may use the repository miniImageNet tools.
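For readers unfamiliar with the episodic protocol behind these splits, here is a minimal, self-contained sketch (illustration only, not the data loader used in this repository) of how a 5-way 1-shot task could be sampled from one of the class splits:

```python
# Illustrative N-way K-shot episode sampling (not the repo's actual loader).
# `split` is assumed to map each class name to a list of image paths taken
# from one of the 64/16/20 meta-train/val/test class splits.
import random
from typing import Dict, List, Tuple

def sample_episode(split: Dict[str, List[str]],
                   n_way: int = 5, k_shot: int = 1, n_query: int = 15
                   ) -> Tuple[list, list]:
    classes = random.sample(sorted(split), n_way)  # draw 5 novel classes
    support, query = [], []
    for label, cls in enumerate(classes):
        paths = random.sample(split[cls], k_shot + n_query)
        support += [(p, label) for p in paths[:k_shot]]  # labeled examples
        query += [(p, label) for p in paths[k_shot:]]    # examples to classify
    return support, query
```

Reported numbers such as the 5-way 1-shot / 5-shot accuracies above are averaged over many such randomly sampled episodes.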
Note that in our implementation images are resized to 480 × 480, because the data augmentation we use requires the image resolution to be greater than 224 to avoid distortions. Therefore, when generating miniImageNet, you should set --image_resize 0 to keep the original size or --image_resize 480 as we did.
tieredImageNet
The tieredImageNet dataset is a larger subset of ILSVRC-12 with 608 classes (779,165 images) grouped into 34 higher-level nodes in the ImageNet human-curated hierarchy. To generate this dataset from ImageNet, you may use the repository tieredImageNet tools.
Similar to miniImageNet, you should set --image_resize 0 to keep the original size or --image_resize 480 as we did when generating tieredImageNet.
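To spot-check that a generated split was resized as intended, something like the following can be run (it assumes the usual ImageFolder-style layout, i.e. one sub-directory per class under the split directory passed to --data_path):

```python
# Spot-check image resolution in a generated split (assumed layout:
# <split_dir>/<class_name>/<image files>). With --image_resize 480 every
# image should be 480 x 480; resolutions at or below 224 would distort the
# augmentations described above.
import pathlib
from PIL import Image

split_dir = pathlib.Path("/path/to/mini_imagenet/train")  # adjust to your split
for class_dir in sorted(p for p in split_dir.iterdir() if p.is_dir())[:5]:
    sample = next(class_dir.iterdir())                     # first file of the class
    with Image.open(sample) as img:
        print(f"{class_dir.name}: {img.size}")             # expect (480, 480)
```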
Training
We provide the training code for miniImageNet, tieredImageNet and CIFAR-FS, extending the DINO repository.
1 Pre-train the First Transformer
To pre-train the first Transformer with attribute surrogates learning on miniImageNet from scratch with multiple GPUs, run:
python -m torch.distributed.launch --nproc_per_node=8 main_hct_first.py --arch vit_small --data_path /path/to/mini_imagenet/train --output_dir /path/to/saving_dir
2 Train the Hierarchically Cascaded Transformers
To train the Hierarchically Cascaded Transformers with spectral tokens pooling on miniImageNet, run:
python -m torch.distributed.launch --nproc_per_node=8 main_hct_pooling.py --arch vit_small --data_path /path/to/mini_imagenet/train --output_dir /path/to/saving_dir --pretrained_weights /path/to/pretrained_weights
Evaluation
To evaluate the performance of the first Transformer on the miniImageNet 5-way 1-shot task, run:
python eval_hct_first.py --arch vit_small --server mini --partition test --checkpoint_key student --ckp_path /path/to/checkpoint_mini/ --num_shots 1
To evaluate the performance of the Hierarchically Cascaded Transformers on the miniImageNet 5-way 5-shot task, run:
python eval_hct_pooling.py --arch vit_small --server mini_pooling --partition val --checkpoint_key student --ckp_path /path/to/checkpoint_mini_pooling/ --pretrained_weights /path/to/pretrained_weights_of_first_stage --num_shots 5
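Before running the evaluation scripts, it can be useful to peek inside a downloaded checkpoint. The sketch below assumes the file is a regular PyTorch checkpoint dictionary with a student entry (matching --checkpoint_key student above) whose keys may carry a module. prefix from distributed training; the file name is hypothetical:

```python
# Inspect a downloaded checkpoint (assumed: a dict containing a "student"
# state dict, matching --checkpoint_key student; file name is hypothetical).
import torch

ckpt_path = "/path/to/checkpoint_mini/checkpoint.pth"
checkpoint = torch.load(ckpt_path, map_location="cpu")
print("top-level keys:", list(checkpoint.keys()))

state_dict = checkpoint["student"]
# Strip a possible DistributedDataParallel "module." prefix before loading.
state_dict = {k.replace("module.", "", 1): v for k, v in state_dict.items()}
print(f"{len(state_dict)} parameter tensors in the student state dict")
```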
License
This repository is released under the Apache 2.0 license as found in the LICENSE file.
Citation
If you find our code or paper useful for your research, please consider citing our work with the following BibTeX entry:
@inproceedings{he2022attribute,
title={Attribute surrogates learning and spectral tokens pooling in transformers for few-shot learning},
author={He, Yangji and Liang, Weihan and Zhao, Dongyang and Zhou, Hong-Yu and Ge, Weifeng and Yu, Yizhou and Zhang, Wenqiang},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={9119--9129},
year={2022}
}