Low-shot Learning by Shrinking and Hallucinating Features

This repository contains code associated with the following paper:

Low-shot Visual Recognition by Shrinking and Hallucinating Features
Bharath Hariharan, Ross Girshick
arXiv, 2016

You can find trained models here.

Prerequisites

This code uses PyTorch, NumPy, and h5py. It requires GPUs and CUDA.

Running the code

Running a low-shot learning experiment involves three or four steps:

  1. Train a ConvNet representation.
  2. Save features from the ConvNet.
  3. (Optional) Train the analogy-based generator.
  4. Use the saved features to train and test on the low-shot learning benchmark.

Each step is described below.

The scripts directory contains scripts required to generate results for the baseline representation, representations trained with the SGM loss or L2 regularization, and results with and without the analogy-based generation strategy.

Training a ConvNet representation

To train the ConvNet, we first need to specify the training and validation sets. The training and validation datasets, together with data-augmentation and preprocessing steps, are specified through yaml files: see base_classes_train_template.yaml and base_classes_val_template.yaml. You will need to specify the path to the directory containing ImageNet in each file.
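
If you want to double-check that your edits to these templates parse correctly, a quick sanity check along these lines may help (a minimal sketch; the keys inside the yaml files are defined by the templates themselves):

import yaml

# Print the parsed config to verify the ImageNet path and other fields.
with open('base_classes_train_template.yaml') as f:
    print(yaml.safe_load(f))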

The main entry point for training a ConvNet representation is main.py. For example, to train a ResNet10 representation with the SGM loss, run:

mkdir -p checkpoints/ResNet10_sgm
python ./main.py --model ResNet10 \
  --traincfg base_classes_train_template.yaml \
  --valcfg base_classes_val_template.yaml \
  --print_freq 10 --save_freq 10 \
  --aux_loss_wt 0.02 --aux_loss_type sgm \
  --checkpoint_dir checkpoints/ResNet10_sgm

Here, aux_loss_type is the kind of auxiliary loss to use (sgm, l2, or batchsgm), aux_loss_wt is the weight attached to this auxiliary loss, and checkpoint_dir is the directory in which checkpoints are saved.

The model checkpoints will be saved as epoch-number.tar. Training by default runs for 90 epochs, so the final model saved will be 89.tar.
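
The .tar files are presumably regular torch.save checkpoints, so a quick way to inspect one (a minimal sketch; the exact layout of the saved dictionary is determined by main.py) is:

import torch

# Load on CPU and list the top-level keys rather than assuming a structure.
ckpt = torch.load('checkpoints/ResNet10_sgm/89.tar', map_location='cpu')
print(list(ckpt.keys()) if isinstance(ckpt, dict) else type(ckpt))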

Saving features from the ConvNet

The next step is to save features from the trained ConvNet. This is fairly straightforward: first, create a directory to save the features in, and then save the features for the training set and the validation set. Thus, for the ResNet10 model trained above:

mkdir -p features/ResNet10_sgm
python ./save_features.py \
  --cfg train_save_data.yaml \
  --outfile features/ResNet10_sgm/train.hdf5 \
  --modelfile checkpoints/ResNet10_sgm/89.tar \
  --model ResNet10
python ./save_features.py \
  --cfg val_save_data.yaml \
  --outfile features/ResNet10_sgm/val.hdf5 \
  --modelfile checkpoints/ResNet10_sgm/89.tar \
  --model ResNet10
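
To verify that the features were written correctly, you can inspect the HDF5 files with h5py (a minimal sketch; the dataset names inside the file are defined by save_features.py, so list them rather than assuming them):

import h5py

# List the datasets and their shapes without assuming particular names.
with h5py.File('features/ResNet10_sgm/train.hdf5', 'r') as f:
    for name, dset in f.items():
        print(name, getattr(dset, 'shape', None))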

Training the analogy-based generator

The entry point for training the analogy-based generator is train_analogy_generator.py. To train the analogy-based generator on the above representation, run:

mkdir -p generation
python ./train_analogy_generator.py \
  --lowshotmeta label_idx.json \
  --trainfile features/ResNet10_sgm/train.hdf5 \
  --outdir generation \
  --networkfile checkpoints/ResNet10_sgm/89.tar \
  --initlr 1

Here, label_idx.json contains the split of base and novel classes, and is used to pick out the saved features corresponding to just the base classes. The analogy-based generation has several steps and maintains a cache. The final generator will be saved in generation/ResNet10_sgm/generator.tar.

Running the low-shot benchmark

The benchmark tests 5 different settings for the number of novel-category examples, n = {1, 2, 5, 10, 20}. The benchmark is organized into 5 experiments, with each experiment corresponding to a fixed choice of n examples for each category.

The main entry point for running the low-shot benchmark is low_shot.py, which runs a single experiment for a single value of n. Running the full benchmark therefore requires 25 runs of low_shot.py; this design choice allows the 25 experiments to be run in parallel.

There is one final wrinkle. To allow cross-validation of hyperparameters, there are two different setups for the benchmark: a validation setup and a test setup, selected with the testsetup flag. Each setup uses its own hyperparameter settings.

To run the benchmark, first create a results directory, and then run each experiment for each value of n. For example, running the first experiment with n=2 on the test setup looks like:

mkdir -p results
python ./low_shot.py --lowshotmeta label_idx.json \
  --experimentpath experiment_cfgs/splitfile_{:d}.json \
  --experimentid  1 --lowshotn 2 \
  --trainfile features/ResNet10_sgm/train.hdf5 \
  --testfile features/ResNet10_sgm/val.hdf5 \
  --outdir results \
  --lr 1 --wd 0.001 \
  --testsetup 1
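
To cover the full benchmark, the same command must be repeated for each of the 5 experiments and each value of n. A minimal driver along these lines may help, assuming the experiment ids run from 1 to 5 (runs are sequential here; swap in your cluster's job launcher to exploit the parallelism mentioned above):

import itertools, subprocess

# One low_shot.py run per (experiment id, n) pair: 25 runs in total.
for expid, n in itertools.product(range(1, 6), [1, 2, 5, 10, 20]):
    subprocess.run(['python', './low_shot.py',
                    '--lowshotmeta', 'label_idx.json',
                    '--experimentpath', 'experiment_cfgs/splitfile_{:d}.json',
                    '--experimentid', str(expid), '--lowshotn', str(n),
                    '--trainfile', 'features/ResNet10_sgm/train.hdf5',
                    '--testfile', 'features/ResNet10_sgm/val.hdf5',
                    '--outdir', 'results',
                    '--lr', '1', '--wd', '0.001',
                    '--testsetup', '1'],
                   check=True)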

If you want to use the analogy-based generator and generate until there are at least 5 examples per category, you can run:

python ./low_shot.py --lowshotmeta label_idx.json \
  --experimentpath experiment_cfgs/splitfile_{:d}.json \
  --experimentid  1 --lowshotn 2 \
  --trainfile features/ResNet10_sgm/train.hdf5 \
  --testfile features/ResNet10_sgm/val.hdf5 \
  --outdir results \
  --lr 1 --wd 0.001 \
  --testsetup 1 \
  --max_per_label 5 \
  --generator_name analogies \
  --generator_file generation/ResNet10_sgm/generator.tar

Here, generator_name is the kind of generator to use. Only analogy-based generation is implemented, but other ways of generating data can easily be added (see Extensions below).

Once all the experiments are done, you can use the quick-and-dirty script parse_results.py to assemble the results:

python ./parse_results.py --resultsdir results \
  --repr ResNet10_sgm \
  --lr 1 --wd 0.001 \
  --max_per_label 5

Extensions

New losses

It is fairly easy to implement new loss functions or forms of regularization. Such losses can be added to losses.py, and can make use of the scores, the features, and even the model weights. Create your own loss function, add it to the dictionary of auxiliary losses in GenericLoss, and specify how it should be called in the __call__ function.
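
As an illustration, a new auxiliary loss could look roughly like this (a minimal sketch; the name featnorm and the exact signature GenericLoss expects are hypothetical, so match them to the existing entries in losses.py):

import torch

def featnorm_loss(feats):
    # Penalize the mean squared L2 norm of the feature vectors; this is the
    # same flavor of regularization as the existing l2 auxiliary loss.
    return (feats ** 2).sum(dim=1).mean()

# Register it alongside the existing auxiliary losses in GenericLoss, e.g.:
# self.aux_losses['featnorm'] = featnorm_loss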

New generation strategies

Implementing new data generation strategies is also straightforward. Any generation strategy should provide two functions:

  1. init_generator should load whatever state the generator needs from a single filename provided as input, and return a generator object.
  2. do_generate should take four arguments: the original set of novel-class features, the novel-class labels, the generator produced by init_generator, and the total number of examples per class to target. It should return a new set of novel-class features and labels that includes both the real and the generated examples.

Add any new generation strategy to generation.py.
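
For instance, a trivial strategy that pads each class with noisy copies of its real examples could look like this (a minimal sketch of the two-function contract above; the argument order and the exact types expected by generation.py should be checked against the analogy-based implementation):

import numpy as np

def init_generator(filename):
    # This toy strategy has no state to load; a real generator might
    # torch.load(filename) here. The noise scale is an arbitrary choice.
    return {'sigma': 0.1}

def do_generate(feats, labels, generator, max_per_label):
    new_feats, new_labels = [feats], [labels]
    for y in np.unique(labels):
        idx = np.where(labels == y)[0]
        n_extra = max_per_label - idx.size
        if n_extra > 0:
            base = feats[np.random.choice(idx, n_extra)]
            noise = generator['sigma'] * np.random.randn(*base.shape)
            new_feats.append(base + noise)
            new_labels.append(np.full(n_extra, y, dtype=labels.dtype))
    # Return real + generated examples together, as the contract requires.
    return np.concatenate(new_feats), np.concatenate(new_labels)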

Matching Networks

This repository also includes an implementation of Matching Networks (Vinyals et al., 2016). Given a saved feature representation (such as the one above), you can train a matching network by running:

python matching_network.py --test 0 \
  --trainfile features/ResNet10_sgm/train.hdf5 \
  --lowshotmeta label_idx.json \
  --modelfile matching_network_sgm.tar

This will save the trained model in matching_network_sgm.tar. Then, test the model using:

python matching_network.py --test 1 \
  --trainfile features/ResNet10_sgm/train.hdf5 \
  --testfile features/ResNet10_sgm/val.hdf5 \
  --lowshotmeta label_idx.json \
  --modelfile matching_network_sgm.tar \
  --lowshotn 1 --experimentid 1 \
  --experimentpath experiment_cfgs/splitfile_{:d}.json \
  --outdir results

As in the benchmark above, this tests a single experiment for a single value of n.

New results

The initial implementation, corresponding to the original paper, was in Lua. For this release, we have switched to PyTorch. As such, there are small differences in the resulting numbers, although the trends are the same. The new numbers are below:

Top-1, Novel classes

Representation        Low-shot phase     n=1    n=2    n=5    n=10   n=20
Baseline              Baseline           2.77   10.78  26.38  35.46  41.49
Baseline              Generation         9.17   15.85  25.47  33.21  40.41
SGM                   Baseline           4.14   13.08  27.83  36.04  41.36
SGM                   Generation         9.85   17.32  27.89  36.17  41.42
Batch SGM             Baseline           4.16   13.01  28.12  36.56  42.07
L2                    Baseline           7.14   16.75  27.73  32.32  35.11
Baseline              Matching Networks  18.33  23.87  31.08  35.27  38.45
Baseline (ResNet 50)  Baseline           6.82   18.37  36.55  46.15  51.99
Baseline (ResNet 50)  Generation         16.58  25.38  36.16  44.53  52.06
SGM (ResNet 50)       Baseline           10.23  21.45  37.25  46.00  51.83
SGM (ResNet 50)       Generation         15.77  24.43  37.22  45.96  51.82

Top-5, Novel classes

Representation        Low-shot phase     n=1    n=2    n=5    n=10   n=20
Baseline              Baseline           14.10  33.34  56.20  66.15  71.52
Baseline              Generation         29.68  42.15  56.13  64.52  70.56
SGM                   Baseline           23.14  42.37  61.68  69.60  73.76
SGM                   Generation         32.80  46.37  61.70  69.71  73.81
Batch SGM             Baseline           22.97  42.35  61.91  69.91  74.45
L2                    Baseline           29.08  47.42  62.33  67.96  70.63
Baseline              Matching Networks  41.27  51.25  62.13  67.82  71.78
Baseline (ResNet 50)  Baseline           28.16  51.03  71.01  78.39  82.32
Baseline (ResNet 50)  Generation         44.76  58.98  71.37  77.65  82.30
SGM (ResNet 50)       Baseline           37.81  57.08  72.78  79.09  82.61
SGM (ResNet 50)       Generation         45.11  58.83  72.76  79.09  82.61

Top-1, Base classes

Representation        Low-shot phase     n=1    n=2    n=5    n=10   n=20
Baseline              Baseline           71.04  69.63  65.67  63.56  62.83
Baseline              Generation         72.38  70.12  68.50  68.11  69.47
SGM                   Baseline           75.76  74.24  70.82  69.02  68.29
SGM                   Generation         72.62  71.05  70.86  68.88  68.24
Batch SGM             Baseline           75.75  74.50  70.83  68.87  68.04
L2                    Baseline           74.50  72.26  69.99  69.62  69.42
Baseline              Matching Networks  48.71  52.10  58.65  62.55  65.25
Baseline (ResNet 50)  Baseline           83.16  81.94  78.36  76.27  75.32
Baseline (ResNet 50)  Generation         79.39  77.81  76.86  76.12  75.27
SGM (ResNet 50)       Baseline           83.96  82.52  79.04  76.78  75.37
SGM (ResNet 50)       Generation         81.17  79.60  79.04  76.84  75.35

Top-5, Base classes

Representation        Low-shot phase     n=1    n=2    n=5    n=10   n=20
Baseline              Baseline           88.90  87.53  84.56  83.23  82.76
Baseline              Generation         88.32  86.81  85.61  85.56  86.97
SGM                   Baseline           91.00  89.32  86.67  85.51  84.97
SGM                   Generation         88.43  87.12  86.62  85.49  84.95
Batch SGM             Baseline           91.13  89.35  86.55  85.38  84.88
L2                    Baseline           90.03  87.84  85.99  85.63  85.60
Baseline              Matching Networks  76.70  77.82  80.59  82.19  83.27
Baseline (ResNet 50)  Baseline           95.11  94.02  91.84  90.71  90.23
Baseline (ResNet 50)  Generation         92.30  91.26  90.48  90.33  90.18
SGM (ResNet 50)       Baseline           95.25  93.81  91.32  89.89  89.28
SGM (ResNet 50)       Generation         92.94  91.67  91.35  89.90  89.21

Top-1, All classes

Representation        Low-shot phase     n=1    n=2    n=5    n=10   n=20
Baseline              Baseline           29.16  33.53  41.57  46.33  49.74
Baseline              Generation         33.60  36.83  42.11  46.70  51.65
SGM                   Baseline           31.83  36.73  44.45  48.79  51.77
SGM                   Generation         34.12  38.09  44.50  48.82  51.79
Batch SGM             Baseline           31.84  36.78  44.63  49.05  52.11
L2                    Baseline           33.18  38.21  44.07  46.74  48.37
Baseline              Matching Networks  30.08  34.78  41.74  45.81  48.81
Baseline (ResNet 50)  Baseline           36.33  42.95  52.71  57.79  61.01
Baseline (ResNet 50)  Generation         40.86  45.65  51.90  56.74  61.03
SGM (ResNet 50)       Baseline           38.73  45.06  53.40  57.90  60.93
SGM (ResNet 50)       Generation         41.05  45.76  53.39  57.90  60.92

Top-5, All classes

Representation        Low-shot phase     n=1    n=2    n=5    n=10   n=20
Baseline              Baseline           43.02  54.29  67.17  72.75  75.86
Baseline              Generation         52.35  59.42  67.53  72.65  76.91
SGM                   Baseline           49.37  60.52  71.34  75.75  78.10
SGM                   Generation         54.31  62.12  71.33  75.81  78.12
Batch SGM             Baseline           49.32  60.52  71.44  75.89  78.48
L2                    Baseline           52.65  63.05  71.48  74.79  76.41
Baseline              Matching Networks  54.97  61.52  69.27  73.38  76.22
Baseline (ResNet 50)  Baseline           54.05  67.65  79.07  83.15  85.37
Baseline (ResNet 50)  Generation         63.14  71.45  78.76  82.55  85.35
SGM (ResNet 50)       Baseline           60.02  71.28  79.95  83.27  85.19
SGM (ResNet 50)       Generation         63.60  71.53  79.95  83.27  85.16