BaseTransformers: Attention over base data-points for One Shot Learning

The code repository for "BaseTransformers: Attention over base data-points for One Shot Learning" [paper] [ArXiv] (accepted at the British Machine Vision Conference 2022), in PyTorch. If you use any content of this repo in your work, please cite the following bib entry:

@inproceedings{Maniparambil_2022_BMVC,
author    = {Mayug Maniparambil and Kevin McGuinness and Noel O Connor},
title     = {BaseTransformers: Attention over base data-points for One Shot Learning},
booktitle = {33rd British Machine Vision Conference 2022, {BMVC} 2022, London, UK, November 21-24, 2022},
publisher = {{BMVA} Press},
year      = {2022},
url       = {https://bmvc2022.mpi-inf.mpg.de/0482.pdf}
}

This repository has been adapted from the code repository of "Few-Shot Learning via Embedding Adaptation with Set-to-Set Functions" [https://github.com/Sha-Lab/FEAT].

BaseTransformers

We propose to make use of the well-trained feature representations of the base dataset that are closest to each support instance to improve its representation at meta-test time. To this end, we propose BaseTransformers, which attend to the most relevant regions of the base dataset feature space and improve support instance representations.

<img src='imgs/base_illustrative_centaur_new.png' width='640' height='280'>
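In essence, each support embedding acts as the query in a cross-attention step over cached base features. Below is a minimal PyTorch sketch of this idea; it is our own illustration, not the repo's API (the class name `BaseAttention` and all shapes are assumptions; see the model code for the actual implementation).

```python
import torch
import torch.nn as nn

class BaseAttention(nn.Module):
    """Illustrative cross-attention: support instances attend to
    pre-computed base-dataset features (hypothetical, simplified)."""
    def __init__(self, dim):
        super().__init__()
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.scale = dim ** -0.5

    def forward(self, support, base_feats):
        # support:    (n_support, dim) embeddings of the support instances
        # base_feats: (n_base, dim)    cached base-dataset features
        q = self.q(support)
        k, v = self.k(base_feats), self.v(base_feats)
        attn = torch.softmax(q @ k.t() * self.scale, dim=-1)
        # residual update: the improved support representation
        return support + attn @ v

# toy usage: 5 support embeddings attend over 750 base feature vectors
layer = BaseAttention(dim=64)
out = layer(torch.randn(5, 64), torch.randn(750, 64))
print(out.shape)  # torch.Size([5, 64])
```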

Standard Few-shot Learning Results

Experimental results on few-shot learning datasets with a ResNet-12 backbone (the ResNet-12 is the same as in this repo). We report average results over 10,000 randomly sampled few-shot learning episodes for stabilized evaluation.
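Averaging over 10,000 episodes keeps the 95% confidence interval around the mean accuracy tight. For reference, a small sketch of how per-episode accuracies are typically aggregated (our own illustration, not code from this repo):

```python
import numpy as np

def episode_stats(accs):
    """Mean accuracy and 95% confidence interval over episodes."""
    accs = np.asarray(accs)
    mean = accs.mean()
    ci95 = 1.96 * accs.std(ddof=1) / np.sqrt(len(accs))
    return mean, ci95

# stand-in per-episode accuracies for 10,000 episodes
accs = np.random.uniform(0.5, 0.8, size=10000)
mean, ci95 = episode_stats(accs)
print(f"{100 * mean:.2f} +- {100 * ci95:.2f}")
```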

MiniImageNet Dataset

<p align="center"> <img src='imgs/mini.png' width='500' height='280'> </p>

TieredImageNet Dataset

<p align="center"> <img src='imgs/tiered.png' width='300' height='280'> </p>

CUB Dataset

<p align="center"> <img src='imgs/cub.png' width='450' height='200'> </p>

Prerequisites

The following packages are required to run the scripts:

Docker

Alternatively, use Docker to re-create the training environment we used. This requires Docker, Docker Compose, and nvidia-docker.

$ cd docker_nvidia/
$ sudo docker compose build
$ sudo docker compose up -d
$ sudo docker exec -it BaseTransformers_n bash

Run the training commands once inside the Docker bash shell.

Dataset

MiniImageNet Dataset

The MiniImageNet dataset is a subset of ImageNet that includes 100 classes with 600 examples per class. We follow the previous setup and use 64 classes as SEEN categories, and 16 and 20 classes as two sets of UNSEEN categories for model validation and evaluation, respectively. We download MiniImageNet from the repo for the paper "Optimization as a Model for Few-Shot Learning".

CUB Dataset

The Caltech-UCSD Birds-200-2011 (CUB) dataset was originally designed for fine-grained classification. It contains 11,788 images of birds over 200 species. On CUB, we randomly sample 100 species as SEEN classes, and two further sets of 50 species each are used as UNSEEN validation and evaluation sets. We crop all images with the given bounding boxes before training. We test CUB with the ConvNet and Res12 backbones in BaseTransformers.
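The bounding-box cropping can be done once before training. A minimal sketch using PIL, assuming the standard CUB-200-2011 release layout (images.txt maps image ids to relative paths; bounding_boxes.txt stores `<id> <x> <y> <width> <height>`); adjust the paths to your copy:

```python
import os
from PIL import Image

root = "CUB_200_2011"  # path to the standard CUB release (assumption)

# images.txt:         "<image_id> <relative/path.jpg>"
with open(os.path.join(root, "images.txt")) as f:
    paths = dict(line.split() for line in f)

# bounding_boxes.txt: "<image_id> <x> <y> <width> <height>"
with open(os.path.join(root, "bounding_boxes.txt")) as f:
    for line in f:
        img_id, x, y, w, h = line.split()
        x, y, w, h = map(float, (x, y, w, h))
        out = os.path.join(root, "images_cropped", paths[img_id])
        os.makedirs(os.path.dirname(out), exist_ok=True)
        img = Image.open(os.path.join(root, "images", paths[img_id]))
        img.crop((x, y, x + w, y + h)).save(out)
```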

TieredImageNet Dataset

TieredImageNet is a large-scale dataset with more categories, containing 351, 97, and 160 categories for model training, validation, and evaluation, respectively. The dataset can also be downloaded from here. We only test TieredImageNet with the ResNet backbone in our work.

Check this for details of data downloading and preprocessing.

Caches

Base 2D feature cache: base features are pre-computed for each backbone:

ConvNet

ResNet-12

Semantic querying cache: the closest base instances are pre-computed for faster training.

Download both the base 2D feature cache and the semantic querying cache and place them in embeds_cache/.
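If you would rather rebuild the caches yourself, the idea is to run the frozen backbone over the base split, save the features with torch.save, and pre-compute the nearest base instances for querying. A hedged sketch under those assumptions (`encoder` and `base_loader` are placeholders; this is not the repo's exact caching script):

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def build_base_cache(encoder, base_loader, out_path, device="cuda"):
    """Pre-compute features of the base split with a frozen backbone
    and store them for reuse (illustrative sketch)."""
    encoder.eval().to(device)
    feats, labels = [], []
    for images, targets in base_loader:
        feats.append(encoder(images.to(device)).cpu())
        labels.append(targets)
    torch.save({"feats": torch.cat(feats), "labels": torch.cat(labels)},
               out_path)  # e.g. "embeds_cache/my_base_cache.pt"

def topk_base(query, cache_feats, k=30):
    """Indices of the k base features closest to each query (cosine)."""
    q = F.normalize(query, dim=-1)
    c = F.normalize(cache_feats, dim=-1)
    return (q @ c.t()).topk(k, dim=-1).indices
```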

Code Structures

To reproduce our experiments with BaseTransformers, please use train_fsl.py. There are four parts in the code.

Model Training and Evaluation

Please use train_fsl.py and follow the instructions below. The script automatically evaluates the model on the meta-test set with 10,000 tasks after the given number of epochs.

Arguments

train_fsl.py takes the following command-line options (details are in model/utils.py):

Task Related Arguments

Optimization Related Arguments

Model Related Arguments

Other Arguments

Running the command without arguments will train the models with the default hyper-parameter values. Loss changes are recorded in a TensorBoard file.
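To monitor the recorded losses, point TensorBoard at the directory containing the event files (the ./checkpoints path below is an assumption; use whatever log directory your run writes to):

$ tensorboard --logdir ./checkpoints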

Training scripts for BaseTransformers

For example, to train the 1-shot/5-shot 5-way BaseTransformers model with the ConvNet backbone on MiniImageNet:

$ python train_fsl.py  --max_epoch 200 --model_class FEATBaseTransformer3_2d --use_euclidean --backbone_class ConvNet --dataset MiniImageNet --way 5 --eval_way 5 --shot 1 --eval_shot 1 --query 15 --eval_query 15 --balance 0.0 --temperature 0.1 --temperature2 0.1 --lr 0.0001 --lr_mul 10 --lr_scheduler step --step_size 20 --gamma 0.5 --gpu 0 --init_weights ./saves/mini_conv4_ver11_113120.pth --eval_interval 1 --k 30 --base_protos 0 --feat_attn 0 --pass_ids 1 --base_wt 0.1 --remove_instances 1 --embed_pool post_loss_avg --orig_imsize 128 --fast_query ./embeds_cache/fastq_imgnet_wordnet_pathsim_random-preset-wts.pt --embeds_cache_2d ./embeds_cache/embeds_cache_cnn4_contrastive-init-ver1-1-corrected_2d.pt --wandb_mode disabled --mixed_precision O2 --z_norm before_tx

$ python train_fsl.py  --max_epoch 200 --model_class FEATBaseTransformer3_2d --use_euclidean --backbone_class ConvNet --dataset MiniImageNet --way 5 --eval_way 5 --shot 5 --eval_shot 5 --query 15 --eval_query 15 --balance 0 --temperature 0.1 --temperature2 0.1 --lr 0.0001 --lr_mul 10 --lr_scheduler step --step_size 20 --gamma 0.5 --gpu 0 --init_weights ./saves/mini_conv4_ver11_113120.pth --eval_interval 1 --k 30 --base_protos 0 --feat_attn 0 --pass_ids 1 --base_wt 0.1 --remove_instances 1 --embed_pool post_loss_avg --orig_imsize 128 --fast_query ./embeds_cache/fastq_imgnet_wordnet_pathsim_random-preset-wts.pt --embeds_cache_2d ./embeds_cache/embeds_cache_cnn4_contrastive-init-ver1-1-corrected_2d.pt --wandb_mode disabled --mixed_precision O2 --z_norm before_tx

To train the 1-shot/5-shot 5-way BaseTransformers model with the ResNet-12 backbone on MiniImageNet:

$ python train_fsl.py  --max_epoch 200 --model_class FEATBaseTransformer3_2d --use_euclidean --backbone_class Res12 --dataset MiniImageNet --way 5 --eval_way 5 --shot 1 --eval_shot 1 --query 15 --eval_query 15 --balance 0.1 --temperature 0.1 --temperature2 0.1 --lr 0.0002 --lr_mul 10 --lr_scheduler step --step_size 40 --gamma 0.5 --gpu 0 --init_weights ./saves/mini_r12_ver2_corrected_140403.pth --eval_interval 1 --k 30 --base_protos 0 --feat_attn 0 --pass_ids 1 --base_wt 0.1 --orig_imsize 128 --embed_pool post_loss_avg --dim_model 640 --remove_instances 1 --fast_query ./embeds_cache/fastq_imgnet_wordnet_pathsim_random-preset-wts.pt --embeds_cache_2d ./embeds_cache/embeds_cache_res12_ver2-640-140403_evalon_2d.pt --baseinstance_2d_norm True --return_simclr 2 --simclr_loss_type ver2.2 --wandb_mode disabled --exp_name mini_1shot --mixed_precision O2 --z_norm before_tx

$ python train_fsl.py  --max_epoch 200 --model_class FEATBaseTransformer3_2d --use_euclidean --backbone_class Res12 --dataset MiniImageNet --way 5 --eval_way 5 --shot 5 --eval_shot 5 --query 15 --eval_query 15 --balance 0 --temperature 0.1 --temperature2 0.1 --lr 0.0005 --lr_mul 10 --lr_scheduler step --step_size 40 --gamma 0.5 --gpu 0 --init_weights ./saves/mini_r12_ver2_corrected_140403.pth --eval_interval 1 --k 10 --base_protos 0 --feat_attn 0 --pass_ids 1 --base_wt 0.1 --orig_imsize 128 --embed_pool post_loss_avg --dim_model 640 --remove_instances 1 --fast_query ./embeds_cache/fastq_imgnet_wordnet_pathsim_random-preset-wts.pt --embeds_cache_2d ./embeds_cache/embeds_cache_res12_ver2-640-140403_evalon_2d.pt --baseinstance_2d_norm True --wandb_mode disabled --exp_name mini_5shot --mixed_precision O2 --z_norm before_tx

To train the 1-shot/5-shot 5-way BaseTransformers model with the ResNet-12 backbone on TieredImageNet:

$ python train_fsl.py  --max_epoch 100 --model_class FEATBaseTransformer3_2d --use_euclidean --backbone_class Res12 --dataset TieredImageNet_og --way 5 --eval_way 5 --shot 1 --eval_shot 1 --query 15 --eval_query 15 --balance 0 --temperature 0.1 --temperature2 0.1 --lr 0.0002 --lr_mul 10 --lr_scheduler step --step_size 40 --gamma 0.5 --gpu 0 --init_weights ./saves/tiered_r12_og_nosimclr_180842.pth --eval_interval 1 --base_protos 0 --feat_attn 0 --pass_ids 1 --base_wt 0.1 --remove_instances 1 --embed_pool post_loss_avg --orig_imsize -1 --dim_model 640 --fast_query ./embeds_cache/fastq_tiered_wordnetdef-hypernyms-bert-closest_classes_randomsample_eqlwts_classes-sampling.pt --embeds_cache_2d ./embeds_cache/ti_og_r12-default-180842_classwise_2d_new.pt --k 30 --mixed_precision O2 --wandb_mode disabled --exp_name tiered_1shot --z_norm before_tx

$ python train_fsl.py  --max_epoch 100 --model_class FEATBaseTransformer3_2d --use_euclidean --backbone_class Res12 --dataset TieredImageNet_og --way 5 --eval_way 5 --shot 5 --eval_shot 5 --query 15 --eval_query 15 --balance 0 --temperature 0.1 --temperature2 0.1 --lr 0.0002 --lr_mul 10 --lr_scheduler step --step_size 40 --gamma 0.5 --gpu 0 --init_weights ./saves/tiered_r12_og_nosimclr_180842.pth --eval_interval 1 --base_protos 0 --feat_attn 0 --pass_ids 1 --base_wt 0.1 --remove_instances 1 --embed_pool post_loss_avg --orig_imsize -1 --dim_model 640 --fast_query ./embeds_cache/fastq_tiered_wordnetdef-hypernyms-bert-closest_classes_randomsample_eqlwts_classes-sampling.pt --embeds_cache_2d ./embeds_cache/ti_og_r12-default-180842_classwise_2d_new.pt --k 30 --mixed_precision O2 --wandb_mode disabled --exp_name tiered_5shot --z_norm before_tx

To train the 1-shot/5-shot 5-way BaseTransformers model with the ConvNet backbone on the CUB dataset:

$ python train_fsl.py  --max_epoch 250 --model_class FEATBaseTransformer3_2d --use_euclidean --backbone_class ConvNet --dataset CUB --way 5 --eval_way 5 --shot 1 --eval_shot 1 --query 15 --eval_query 15 --balance 0 --temperature 0.1 --temperature2 16 --lr 0.0001 --lr_mul 10 --lr_scheduler step --step_size 20 --gamma 0.5 --gpu 0 --init_weights ./saves/cub_bal0.01_jit0.1-0.1_rotate30_simclrfc1-noopt_201711.pt --eval_interval 1 --k 30 --base_protos 0 --feat_attn 0 --pass_ids 1 --base_wt 0.1 --remove_instances 1 --orig_imsize 128 --embed_pool post_loss_avg --mixed_precision O2 --fast_query /notebooks/fastq_cub_semantic_query_top5_random.pt --embeds_cache_2d embeds_cache/cub_bal0.01_jit0.1-0.1_rotate30_simclrfc1-noopt_201711_2d.pt --wandb_mode disabled

$ python train_fsl.py  --max_epoch 250 --model_class FEATBaseTransformer3_2d --use_euclidean --backbone_class ConvNet --dataset CUB --way 5 --eval_way 5 --shot 5 --eval_shot 5 --query 15 --eval_query 15 --balance 0 --temperature 0.1 --temperature2 16 --lr 0.0001 --lr_mul 10 --lr_scheduler step --step_size 20 --gamma 0.5 --gpu 0 --init_weights ./saves/cub_bal0.01_jit0.1-0.1_rotate30_simclrfc1-noopt_201711.pt --eval_interval 1 --k 30 --base_protos 0 --feat_attn 0 --pass_ids 1 --base_wt 0.1 --remove_instances 1 --orig_imsize 128 --embed_pool post_loss_avg --mixed_precision O2 --fast_query /notebooks/fastq_cub_semantic_query_top5_random.pt --embeds_cache_2d embeds_cache/cub_bal0.01_jit0.1-0.1_rotate30_simclrfc1-noopt_201711_2d.pt --wandb_mode disabled

To train the 1-shot/5-shot 5-way BaseTransformers model with the ResNet-12 backbone on the CUB dataset:

$ python train_fsl.py  --max_epoch 250 --model_class FEATBaseTransformer3_2d --use_euclidean --backbone_class Res12 --dataset CUB --way 5 --eval_way 5 --shot 1 --eval_shot 1 --query 15 --eval_query 15 --balance 0.0 --temperature 0.1 --temperature2 0.1 --lr 0.0002 --lr_mul 10 --lr_scheduler step --step_size 40 --gamma 0.5 --gpu 0 --init_weights ./saves/cub_r12_bal0.01_jit0.1-0.1_rotate30_simclrfc1-yesopt_144500.pt --eval_interval 1 --k 30 --base_protos 0 --feat_attn 0 --pass_ids 1 --base_wt 0.1 --orig_imsize 128 --embed_pool post_loss_avg --dim_model 640 --remove_instances 1 --fast_query /notebooks/fastq_cub_semantic_query_top5_random.pt  --embeds_cache_2d ./embeds_cache/cub_r12_bal0.01_jit0.1-0.1_rotate30_simclrfc1-yesopt_144500_2d.pt --baseinstance_2d_norm True --return_simclr 2 --simclr_loss_type ver2.2 --wandb_mode disabled --exp_name cub_1shot --mixed_precision O2 --z_norm before_tx


$ python train_fsl.py  --max_epoch 250 --model_class FEATBaseTransformer3_2d --use_euclidean --backbone_class Res12 --dataset CUB --way 5 --eval_way 5 --shot 5 --eval_shot 5 --query 15 --eval_query 15 --balance 0.0 --temperature 0.1 --temperature2 0.1 --lr 0.0002 --lr_mul 10 --lr_scheduler step --step_size 40 --gamma 0.5 --gpu 0 --init_weights ./saves/cub_r12_bal0.01_jit0.1-0.1_rotate30_simclrfc1-yesopt_144500.pt --eval_interval 1 --k 30 --base_protos 0 --feat_attn 0 --pass_ids 1 --base_wt 0.1 --orig_imsize 128 --embed_pool post_loss_avg --dim_model 640 --remove_instances 1 --fast_query /notebooks/fastq_cub_semantic_query_top5_random.pt  --embeds_cache_2d ./embeds_cache/cub_r12_bal0.01_jit0.1-0.1_rotate30_simclrfc1-yesopt_144500_2d.pt --baseinstance_2d_norm True --return_simclr 2 --simclr_loss_type ver2.2 --wandb_mode disabled --exp_name cub_5shot --mixed_precision O2 --z_norm before_tx

Trained weights for BaseTransformers

Trained weights are available at gdrive_link. The files are named [dataset]_[encoder]_[number of shots]shot.pth (e.g. mini_conv4_1shot.pth).

To check test performance, use the training scripts from the 'Training scripts for BaseTransformers' section above with the --test argument followed by the path to the testing checkpoint.

For example, to test the performance of the Conv4 1-shot model on MiniImageNet, use the following command:

$ python train_fsl.py  --max_epoch 200 --model_class FEATBaseTransformer3_2d --use_euclidean --backbone_class ConvNet --dataset MiniImageNet --way 5 --eval_way 5 --shot 1 --eval_shot 1 --query 15 --eval_query 15 --balance 0.0 --temperature 0.1 --temperature2 0.1 --lr 0.0001 --lr_mul 10 --lr_scheduler step --step_size 20 --gamma 0.5 --gpu 0 --init_weights ./saves/mini_conv4_ver11_113120.pth --eval_interval 1 --k 40 --base_protos 0 --feat_attn 0 --pass_ids 1 --base_wt 0.1 --remove_instances 1 --embed_pool post_loss_avg --orig_imsize 128 --fast_query ./embeds_cache/fastq_imgnet_wordnet_pathsim_random-preset-wts.pt --embeds_cache_2d ./embeds_cache/embeds_cache_cnn4_contrastive-init-ver1-1-corrected_2d.pt --wandb_mode disabled --mixed_precision O2 --z_norm before_tx --test ./test_weights/mini_conv4_1shot.pth

Note: we have observed that using a higher k at test time than during training yields higher performance for 1-shot.
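For reference, a downloaded checkpoint can also be inspected directly in Python. A minimal sketch, assuming the checkpoint is a plain dict saved with torch.save that, FEAT-style, may nest the weights under a "params" key (check train_fsl.py for the exact format):

```python
import torch

ckpt = torch.load("./test_weights/mini_conv4_1shot.pth", map_location="cpu")
state = ckpt.get("params", ckpt)  # fall back to the dict itself if not nested
print(len(state), "tensors; first key:", next(iter(state)))
# model.load_state_dict(state)   # once the matching model has been built
```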

Acknowledgment

We thank the following repos for providing helpful components/functions used in our work.

Contact

Feel free to raise an issue or contact me at mayugmaniparambil@gmail.com for queries and discussions.