Learning to Balance: Bayesian Meta-Learning for Imbalanced and Out-of-distribution Tasks

This is the TensorFlow implementation of the paper Learning to Balance: Bayesian Meta-Learning for Imbalanced and Out-of-distribution Tasks (ICLR 2020, oral presentation): https://openreview.net/pdf?id=rkeZIJBYvr.

You can reproduce the results of Table 1 in the main paper.

Abstract

<img align="middle" width="700" src="https://github.com/haebeom-lee/l2b/blob/master/images/concept.png">

While tasks could come with varying numbers of instances and classes in realistic settings, existing meta-learning approaches for few-shot classification assume that the number of instances per task and class is fixed. Due to this restriction, they learn to equally utilize the meta-knowledge across all tasks, even when the number of instances per task and class largely varies. Moreover, they do not consider the distributional difference of unseen tasks, on which the meta-knowledge may be less useful depending on the task relatedness. To overcome these limitations, we propose a novel meta-learning model that adaptively balances the effect of meta-learning and task-specific learning within each task. Through the learning of the balancing variables, we can decide whether to obtain a solution by relying on the meta-knowledge or on task-specific learning. We formulate this objective within a Bayesian inference framework and tackle it using variational inference. We validate our Bayesian Task-Adaptive Meta-Learning (Bayesian TAML) on two realistic task- and class-imbalanced datasets, on which it significantly outperforms existing meta-learning approaches. A further ablation study confirms the effectiveness of each balancing component and of the Bayesian learning framework.

Contribution of this work

Structure of the posterior inference network <img align="middle" width="700" src="https://github.com/haebeom-lee/l2b/blob/master/images/encoder.png">

Prerequisites

If you are not familiar with preparing a conda environment, follow the instructions below:

$ conda create --name py35 python=3.5
$ conda activate py35
$ pip install --upgrade pip
$ pip install tensorflow-gpu==1.12.0
$ conda install -c anaconda cudatoolkit=9.0
$ conda install -c anaconda cudnn
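
To confirm that the GPU build of TensorFlow is picked up by the environment, a quick sanity check such as the following should work (it simply prints the TensorFlow version and whether a GPU is available):

$ python -c "import tensorflow as tf; print(tf.__version__); print(tf.test.is_gpu_available())"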

And for data preprocessing,

$ pip install tqdm
$ pip install requests
$ pip install Pillow
$ pip install scipy
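
Equivalently, these can be installed with a single pip command:

$ pip install tqdm requests Pillow scipy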

Data Preparation

Go to the folder of each dataset (i.e. data/cifar, data/svhn, data/mimgnet, or data/cub) and run python get_data.py there. For example, to download and preprocess the CIFAR-FS dataset:

$ cd ./data/cifar
$ python get_data.py

It will take some time to download and preprocess each dataset.
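
If you want to prepare all four datasets listed above in one go, a simple shell loop like the one below should also work (assuming each get_data.py is run from its own directory, as in the example above):

$ for d in cifar svhn mimgnet cub; do (cd ./data/$d && python get_data.py); done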

Run

Bash scripts for running the Bayesian TAML model. (To reproduce the ablation studies in the main paper, enable only one of the options --omega_on, --gamma_on, and --z_on; an example ablation command is shown after the scripts below.)

1. CIFAR / SVHN experiment
# Meta-training
$ python main.py \
  --gpu_id 0 \
  --savedir "./results/cifar/taml" --id_dataset 'cifar' --ood_dataset 'svhn' \
  --mode 'meta_train' --metabatch 4 --n_steps 5 --way 5 --max_shot 50 --query 15 \
  --n_train_iters 50000 --meta_lr 1e-3 \
  --alpha_on --omega_on --gamma_on --z_on

# Meta-testing
$ python main.py \
  --gpu_id 0 \
  --savedir "./results/cifar/taml" --id_dataset 'cifar' --ood_dataset 'svhn' \
  --mode 'meta_test' --metabatch 4 --n_steps 10 --way 5 --max_shot 50 --query 15 \
  --n_test_episodes 1000 \
  --alpha_on --omega_on --gamma_on --z_on --n_mc_samples 10
2. miniImageNet / CUB experiment
# Meta-training
$ python main.py \
  --gpu_id 0 \
  --savedir "./results/mimgnet/taml" --id_dataset 'mimgnet' --ood_dataset 'cub' \
  --mode 'meta_train' --metabatch 1 --n_steps 5 --way 5 --max_shot 50 --query 15 \
  --n_train_iters 80000 --meta_lr 1e-4 \
  --alpha_on --omega_on --gamma_on --z_on

# Meta-testing
$ python main.py \
  --gpu_id 0 \
  --savedir "./results/mimgnet/taml" --id_dataset 'mimgnet' --ood_dataset 'cub' \
  --mode 'meta_test' --metabatch 1 --n_steps 10 --way 5 --max_shot 50 --query 15 \
  --n_test_episodes 1000 \
  --alpha_on --omega_on --gamma_on --z_on --n_mc_samples 10
3. Multi-dataset experiment
# Meta-training
$ python main.py \
 --gpu_id 0 \
 --id_dataset 'aircraft, quickdraw, vgg_flower' --ood_dataset 'traffic, fashion-mnist' \
 --savedir "./results/multi/taml" --mode 'meta_train' --metabatch 3 --n_steps 5 \
 --way 10 --max_shot 50 --query 15 --n_train_iters 60000 --meta_lr 1e-3 \
 --alpha_on --omega_on --gamma_on --z_on 

# Meta-testing
$ python main.py \
 --gpu_id 0 \
 --id_dataset 'aircraft, quickdraw, vgg_flower' --ood_dataset 'traffic, fashion-mnist' \
 --savedir "./results/multi/taml" --mode 'meta_test' --metabatch 3 --n_steps 10 \
 --way 10 --max_shot 50 --query 15 --n_test_episodes 1000 \
 --alpha_on --omega_on --gamma_on --z_on --n_mc_samples 10
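
As noted above, the ablation studies in the main paper use only one of --omega_on, --gamma_on, and --z_on at a time. For example, a gamma-only ablation on CIFAR / SVHN could be meta-trained with a command along the following lines (keeping --alpha_on as in the full model; the savedir name is only illustrative, and meta-testing uses the corresponding meta-test script with the same subset of flags):

# Meta-training (gamma-only ablation)
$ python main.py \
  --gpu_id 0 \
  --savedir "./results/cifar/taml_gamma" --id_dataset 'cifar' --ood_dataset 'svhn' \
  --mode 'meta_train' --metabatch 4 --n_steps 5 --way 5 --max_shot 50 --query 15 \
  --n_train_iters 50000 --meta_lr 1e-3 \
  --alpha_on --gamma_on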

Results

The results reported in the main paper (averaged over three independent runs, 9000 (= 3 x 3000) episodes in total):

| Model | CIFAR-FS | SVHN | miniImageNet | CUB |
|---|---|---|---|---|
| MAML | 71.55±0.23 | 45.17±0.22 | 66.64±0.22 | 65.77±0.24 |
| Meta-SGD | 72.71±0.21 | 46.45±0.24 | 69.95±0.20 | 65.94±0.22 |
| Bayesian-TAML | 75.15±0.20 | 51.87±0.23 | 71.46±0.19 | 71.71±0.21 |

| Model | Aircraft | Quickdraw | VGG-Flower | Traffic Signs | Fashion-MNIST |
|---|---|---|---|---|---|
| MAML | 48.60±0.17 | 69.02±0.18 | 60.38±0.16 | 51.96±0.22 | 63.10±0.15 |
| Meta-SGD | 49.71±0.17 | 70.26±0.16 | 59.41±0.27 | 52.07±0.35 | 62.71±0.25 |
| Bayesian-TAML | 54.43±0.16 | 72.03±0.16 | 67.72±0.16 | 64.81±0.21 | 68.94±0.13 |

The results from running this repo (single run, 1000 episodes in total):

| Model | CIFAR-FS | SVHN | miniImageNet | CUB |
|---|---|---|---|---|
| MAML | 72.23±0.67 | 47.19±0.63 | 66.95±0.71 | 66.82±0.73 |
| Meta-SGD | 72.93±0.66 | 47.63±0.73 | 68.04±0.67 | 66.45±0.63 |
| Bayesian-TAML | 74.97±0.62 | 52.25±0.68 | 71.27±0.59 | 72.89±0.62 |

| Model | Aircraft | Quickdraw | VGG-Flower | Traffic Signs | Fashion-MNIST |
|---|---|---|---|---|---|
| MAML | 48.17±0.57 | 68.57±0.60 | 60.68±0.53 | 52.37±0.78 | 62.57±0.48 |
| Meta-SGD | 51.76±0.61 | 70.05±0.54 | 64.28±0.60 | 50.89±1.02 | 62.83±0.64 |
| Bayesian-TAML | 55.70±0.53 | 72.40±0.50 | 68.39±0.50 | 64.17±0.74 | 67.60±0.47 |

Balancing Variables

While running the code, you can monitor the tendency of the balancing variables every 1000 iterations. The example below shows the tendency of gamma for each layer over 10 randomly sampled tasks. As you can see, gamma increases as the task size (N) gets larger.

*** Gamma for task imbalance ***
              conv1 conv2 conv3 conv4 dense
task 1: N= 57 0.772 0.775 0.523 0.627 9.438
task 6: N= 75 0.807 0.917 0.834 0.851 4.897
task 3: N= 88 0.785 0.797 0.562 0.658 8.516
task 7: N=112 0.815 0.932 0.895 0.882 4.591
task 8: N=115 0.829 1.001 1.120 1.010 3.469
task 9: N=141 0.831 0.988 1.091 0.990 3.654
task 5: N=142 0.831 0.992 1.104 0.997 3.602
task 0: N=149 0.827 0.961 0.999 0.939 4.094
task 2: N=185 0.853 1.071 1.435 1.162 2.672
task 4: N=245 0.853 1.073 1.443 1.166 2.656

The example below shows the tendency of omega for each class. The left part shows the number of instances per class (C1, C2, ..., C5), and the right part shows the corresponding omega value for each class. As you can see, tail (smaller) classes are emphasized with larger omega values, and vice versa.

*** Omega for class imbalance ***
         C1  C2  C3  C4  C5        C1    C2    C3    C4    C5
task 1:   1   4   6  17  29 --> 0.444 0.273 0.198 0.056 0.029
task 6:  15  15  15  15  15 --> 0.200 0.200 0.200 0.200 0.200
task 3:   1   4  17  31  35 --> 0.539 0.332 0.068 0.033 0.028
task 7:  13  13  18  20  48 --> 0.292 0.292 0.195 0.166 0.055
task 8:  23  23  23  23  23 --> 0.200 0.200 0.200 0.200 0.200
task 9:  14  20  26  38  43 --> 0.384 0.236 0.174 0.113 0.094
task 5:  14  22  26  33  47 --> 0.394 0.206 0.178 0.138 0.084
task 0:  13  13  28  47  48 --> 0.361 0.361 0.140 0.071 0.068
task 2:  37  37  37  37  37 --> 0.200 0.200 0.200 0.200 0.200
task 4:  49  49  49  49  49 --> 0.200 0.200 0.200 0.200 0.200

We summarize their behavior in the main paper as follows: <img align="middle" width="700" src="https://github.com/haebeom-lee/l2b/blob/master/images/tendency.png">

Citation

If you find the provided code useful, please cite our work.

@inproceedings{
    lee2020l2b,
    title={Learning to Balance: Bayesian Meta-Learning for Imbalanced and Out-of-distribution Tasks},
    author={Hae Beom Lee and Hayeon Lee and Donghyun Na and Saehoon Kim and Minseop Park and Eunho Yang and Sung Ju Hwang},
    booktitle={ICLR},
    year={2020}
}