Few-Shot Learning via Embedding Adaptation with Set-to-Set Functions

The code repository for "Few-Shot Learning via Embedding Adaptation with Set-to-Set Functions" [paper] [ArXiv] [slides] [poster] (Accepted by CVPR 2020) in PyTorch. If you use any content of this repo for your work, please cite the following bib entry:

@inproceedings{ye2020fewshot,
  author    = {Han-Jia Ye and
               Hexiang Hu and
               De-Chuan Zhan and
               Fei Sha},
  title     = {Few-Shot Learning via Embedding Adaptation with Set-to-Set Functions},
  booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  pages     = {8808--8817},
  year      = {2020}
}

Embedding Adaptation with Set-to-Set Functions

We propose a novel model-based approach that adapts instance embeddings to the target classification task with a **set-to-set** function, yielding embeddings that are both task-specific and discriminative. We empirically investigated various instantiations of such set-to-set functions and observed that the Transformer is the most effective, as it naturally satisfies key properties of our desired model. We denote our method Few-shot Embedding Adaptation with Transformer (FEAT).

<img src='imgs/architecture.png' width='640' height='280'>
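
The adaptation step sketched in the figure above amounts to one self-attention pass over the support-set embeddings, so that every embedding is conditioned on the whole task. The snippet below is a minimal, illustrative PyTorch sketch of this idea; the module name, `emb_dim`, and the use of `nn.MultiheadAttention` are our assumptions, not the exact layers used in this repo.

```python
import torch
import torch.nn as nn

class SetToSetAdapter(nn.Module):
    """Illustrative set-to-set function: adapt support embeddings jointly
    via self-attention so that each embedding becomes task-specific."""
    def __init__(self, emb_dim=64, num_heads=1):
        super().__init__()
        self.attn = nn.MultiheadAttention(emb_dim, num_heads)
        self.norm = nn.LayerNorm(emb_dim)

    def forward(self, support):            # support: (num_support, emb_dim)
        x = support.unsqueeze(1)           # (seq_len, batch=1, emb_dim)
        adapted, _ = self.attn(x, x, x)    # every embedding attends to the whole set
        return self.norm(adapted + x).squeeze(1)

# toy usage: adapt 5-way 1-shot support embeddings
support = torch.randn(5, 64)
adapted = SetToSetAdapter()(support)       # (5, 64), now conditioned on the task
```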

Standard Few-shot Learning Results

Experimental results on few-shot learning datasets with the ResNet-12 backbone (same as this repo). We report average results over 10,000 randomly sampled few-shot learning episodes for stabilized evaluation.

MiniImageNet Dataset

| Setups   | 1-Shot 5-Way | 5-Shot 5-Way | Link to Weights |
|:--------:|:------------:|:------------:|:---------------:|
| ProtoNet | 62.39        | 80.53        | 1-Shot, 5-Shot  |
| BILSTM   | 63.90        | 80.63        | 1-Shot, 5-Shot  |
| DEEPSETS | 64.14        | 80.93        | 1-Shot, 5-Shot  |
| GCN      | 64.50        | 81.65        | 1-Shot, 5-Shot  |
| FEAT     | 66.78        | 82.05        | 1-Shot, 5-Shot  |

TieredImageNet Dataset

| Setups   | 1-Shot 5-Way | 5-Shot 5-Way | Link to Weights |
|:--------:|:------------:|:------------:|:---------------:|
| ProtoNet | 68.23        | 84.03        | 1-Shot, 5-Shot  |
| BILSTM   | 68.14        | 84.23        | 1-Shot, 5-Shot  |
| DEEPSETS | 68.59        | 84.36        | 1-Shot, 5-Shot  |
| GCN      | 68.20        | 84.64        | 1-Shot, 5-Shot  |
| FEAT     | 70.80        | 84.79        | 1-Shot, 5-Shot  |
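
For reference, the evaluation protocol behind these tables (averaging accuracy over many randomly sampled episodes, often together with a 95% confidence interval of the mean) can be sketched as below. `model` and `sample_episode` are placeholders, not functions from this repo.

```python
import numpy as np

def evaluate(model, sample_episode, num_episodes=10000):
    """Average accuracy over randomly sampled few-shot episodes.

    `sample_episode` should return (support, query, query_labels) for one
    N-way K-shot task, and `model(support, query)` should return predicted
    labels for the query instances (both as NumPy arrays in this sketch).
    """
    accs = []
    for _ in range(num_episodes):
        support, query, labels = sample_episode()
        preds = model(support, query)
        accs.append(float((preds == labels).mean()))   # per-episode accuracy
    accs = np.asarray(accs)
    ci95 = 1.96 * accs.std() / np.sqrt(num_episodes)   # 95% CI of the mean
    return accs.mean(), ci95
```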

Prerequisites

The following packages are required to run the scripts:

Dataset

MiniImageNet Dataset

The MiniImageNet dataset is a subset of ImageNet that contains 100 classes with 600 examples per class. We follow the previous setup and use 64 classes as SEEN categories, and 16 and 20 classes as the two sets of UNSEEN categories for model validation and evaluation, respectively.

CUB Dataset

The Caltech-UCSD Birds (CUB) 200-2011 dataset was originally designed for fine-grained classification. It contains 11,788 bird images from 200 species in total. On CUB, we randomly sample 100 species as SEEN classes, and two further sets of 50 species each serve as the UNSEEN sets for validation and evaluation. We crop all images with the given bounding boxes before training. We only test CUB with the ConvNet backbone in our work.
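
Since CUB images are cropped with the provided bounding boxes before training, a preprocessing step along these lines can be used. This is a sketch with Pillow; the `(x, y, width, height)` format follows the standard CUB-200-2011 `bounding_boxes.txt` annotations, and the paths in the commented example are hypothetical.

```python
from PIL import Image

def crop_with_bbox(image_path, bbox, save_path):
    """Crop one CUB image with its annotated bounding box (x, y, width, height)."""
    x, y, w, h = bbox
    img = Image.open(image_path).convert('RGB')
    img.crop((x, y, x + w, y + h)).save(save_path)

# hypothetical usage; bounding boxes come from CUB's bounding_boxes.txt
# crop_with_bbox('CUB_200_2011/images/001.Black_footed_Albatross/img.jpg',
#                (60.0, 27.0, 325.0, 304.0),
#                'cub_cropped/001.Black_footed_Albatross/img.jpg')
```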

TieredImageNet Dataset

TieredImageNet is a large-scale dataset with more categories, containing 351, 97, and 160 categories for model training, validation, and evaluation, respectively. The dataset can also be downloaded from here. We only test TieredImageNet with the ResNet backbone in our work.

Check this for details of data downloading and preprocessing.

Code Structures

To reproduce our experiments with FEAT, please use train_fsl.py. There are four parts in the code.

Model Training and Evaluation

Please use train_fsl.py and follow the instructions below. FEAT meta-learns the embedding adaptation process so that all training instance embeddings in a task are adapted, based on their contextual task information, with a Transformer. The script automatically evaluates the model on the meta-test set with 10,000 tasks after the given number of epochs.
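
At a high level, each meta-training step embeds a sampled task, adapts the support embeddings, and classifies the query instances against the adapted class centroids. The sketch below is our simplified outline of such an episodic update, not the exact code in train_fsl.py; `encoder`, `adapter`, and `sample_task` are placeholders.

```python
import torch
import torch.nn.functional as F

def meta_train_step(encoder, adapter, optimizer, sample_task, way=5, shot=1):
    """One episodic update: embed, adapt the support set, then classify queries
    by negative squared Euclidean distance to the adapted class centroids."""
    support_x, query_x, query_y = sample_task()           # one N-way K-shot task
    support_emb = adapter(encoder(support_x))             # task-specific embeddings
    # assumes support instances are grouped by class (class-major order)
    prototypes = support_emb.view(way, shot, -1).mean(1)  # one centroid per class
    query_emb = encoder(query_x)
    logits = -torch.cdist(query_emb, prototypes) ** 2     # Euclidean-distance logits
    loss = F.cross_entropy(logits, query_y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```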

Arguments

train_fsl.py takes the following command-line options, grouped as below (details are in model/utils.py; see the sketch after the list):

Task Related Arguments

Optimization Related Arguments

Model Related Arguments

Other Arguments

Running the command without arguments will train the models with the default hyper-parameter values. Loss changes will be recorded as a tensorboard file.
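
For orientation, the options that appear in the training commands below can be grouped as in this argparse sketch. The defaults shown here are placeholders for illustration, not the values defined in model/utils.py.

```python
import argparse

parser = argparse.ArgumentParser()
# task related arguments
parser.add_argument('--dataset', default='MiniImageNet')
parser.add_argument('--way', type=int, default=5)
parser.add_argument('--eval_way', type=int, default=5)
parser.add_argument('--shot', type=int, default=1)
parser.add_argument('--eval_shot', type=int, default=1)
parser.add_argument('--query', type=int, default=15)
parser.add_argument('--eval_query', type=int, default=15)
# optimization related arguments
parser.add_argument('--max_epoch', type=int, default=200)
parser.add_argument('--lr', type=float, default=0.0001)
parser.add_argument('--lr_mul', type=float, default=10)
parser.add_argument('--lr_scheduler', default='step')
parser.add_argument('--step_size', type=int, default=20)
parser.add_argument('--gamma', type=float, default=0.5)
# model related arguments
parser.add_argument('--model_class', default='FEAT')
parser.add_argument('--backbone_class', default='ConvNet')
parser.add_argument('--balance', type=float, default=1.0)
parser.add_argument('--temperature', type=float, default=64)
parser.add_argument('--temperature2', type=float, default=64)
parser.add_argument('--use_euclidean', action='store_true')
# other arguments
parser.add_argument('--gpu', default='0')
parser.add_argument('--init_weights', default=None)
parser.add_argument('--eval_interval', type=int, default=1)
args = parser.parse_args()
```

The recorded loss curves can then be inspected with `tensorboard --logdir <your_log_dir>`, where the log directory depends on your own run configuration.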

Training scripts for FEAT

For example, to train the 1-shot/5-shot 5-way FEAT model with ConvNet backbone on MiniImageNet:

$ python train_fsl.py  --max_epoch 200 --model_class FEAT --use_euclidean --backbone_class ConvNet --dataset MiniImageNet --way 5 --eval_way 5 --shot 1 --eval_shot 1 --query 15 --eval_query 15 --balance 1 --temperature 64 --temperature2 16 --lr 0.0001 --lr_mul 10 --lr_scheduler step --step_size 20 --gamma 0.5 --gpu 8 --init_weights ./saves/initialization/miniimagenet/con-pre.pth --eval_interval 1
$ python train_fsl.py  --max_epoch 200 --model_class FEAT --use_euclidean --backbone_class ConvNet --dataset MiniImageNet --way 5 --eval_way 5 --shot 5 --eval_shot 5 --query 15 --eval_query 15 --balance 0.1 --temperature 32 --temperature2 64 --lr 0.0001 --lr_mul 10 --lr_scheduler step --step_size 20 --gamma 0.5 --gpu 14 --init_weights ./saves/initialization/miniimagenet/con-pre.pth --eval_interval 1

to train the 1-shot/5-shot 5-way FEAT model with ResNet-12 backbone on MiniImageNet:

$ python train_fsl.py  --max_epoch 200 --model_class FEAT  --backbone_class Res12 --dataset MiniImageNet --way 5 --eval_way 5 --shot 1 --eval_shot 1 --query 15 --eval_query 15 --balance 0.01 --temperature 64 --temperature2 64 --lr 0.0002 --lr_mul 10 --lr_scheduler step --step_size 40 --gamma 0.5 --gpu 1 --init_weights ./saves/initialization/miniimagenet/Res12-pre.pth --eval_interval 1 --use_euclidean
$ python train_fsl.py  --max_epoch 200 --model_class FEAT  --backbone_class Res12 --dataset MiniImageNet --way 5 --eval_way 5 --shot 5 --eval_shot 5 --query 15 --eval_query 15 --balance 0.1 --temperature 64 --temperature2 32 --lr 0.0002 --lr_mul 10 --lr_scheduler step --step_size 40 --gamma 0.5 --gpu 0 --init_weights ./saves/initialization/miniimagenet/Res12-pre.pth --eval_interval 1 --use_euclidean

to train the 1-shot/5-shot 5-way FEAT model with ResNet-12 backbone on TieredImageNet:

$ python train_fsl.py  --max_epoch 200 --model_class FEAT  --backbone_class Res12 --dataset TieredImageNet --way 5 --eval_way 5 --shot 1 --eval_shot 1 --query 15 --eval_query 15 --balance 0.1 --temperature 64 --temperature2 64 --lr 0.0002 --lr_mul 10 --lr_scheduler step --step_size 20 --gamma 0.5 --gpu 0 --init_weights ./saves/initialization/tieredimagenet/Res12-pre.pth --eval_interval 1  --use_euclidean
$ python train_fsl.py  --max_epoch 200 --model_class FEAT  --backbone_class Res12 --dataset TieredImageNet --way 5 --eval_way 5 --shot 5 --eval_shot 5 --query 15 --eval_query 15 --balance 0.1 --temperature 32 --temperature2 64 --lr 0.0002 --lr_mul 10 --lr_scheduler step --step_size 40 --gamma 0.5 --gpu 0 --init_weights ./saves/initialization/tieredimagenet/Res12-pre.pth --eval_interval 1  --use_euclidean

Acknowledgment

We thank the following repos for providing helpful components/functions used in our work.