
SiMT: Self-improving Momentum Target

Official PyTorch implementation of "Meta-Learning with Self-Improving Momentum Target" (NeurIPS 2022) by Jihoon Tack, Jongjin Park, Hankook Lee, Jaeho Lee, Jinwoo Shin.

TL;DR: We propose a meta-learning algorithm to generate a target model from which we distill the knowledge to the meta-model, forming a virtuous cycle of improvements.

<p align="center"> <img src=figure/concept_figure.png width="900"> </p>
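The core of the virtuous cycle is a momentum target: an exponential moving average (EMA) of the meta-model that serves as the distillation teacher. A minimal plain-Python sketch of that update (illustrative only — `update_momentum_target` and the dict-of-floats parameters are stand-ins, not this repository's PyTorch implementation):

```python
def update_momentum_target(target_params, meta_params, eta):
    """EMA update: target <- eta * target + (1 - eta) * meta."""
    return {k: eta * target_params[k] + (1.0 - eta) * meta_params[k]
            for k in target_params}

# The target slowly tracks the meta-model, controlled by eta.
target = {"w": 0.0}
meta = {"w": 1.0}
for _ in range(3):
    target = update_momentum_target(target, meta, eta=0.9)
# target["w"] -> 0.271  (i.e., 1 - 0.9**3)
```

A larger eta makes the target change more slowly, giving a more stable teacher for distillation back into the meta-model.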

1. Dependencies

conda create -n simt python=3.8 -y
conda activate simt

pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html
pip install torchmeta tensorboardX

2. Dataset

Download the following datasets and place them in the /data folder.

3. Training

3.1. Training option

The options for the training method are as follows:

3.2. Training backbone algorithms

python main.py --mode <MODE> --model <MODEL> --dataset <DATASET>

3.3. Training SiMT

To train SiMT, choose the appropriate hyperparameters: the momentum coefficient ETA, the distillation loss weight LAM, and the dropout probability P.

python main.py --simt --mode <MODE> --model <MODEL> --dataset <DATASET> --eta ETA --lam LAM --drop_p P
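ETA sets the momentum coefficient of the target's EMA update; LAM and P enter the training objective roughly as sketched below (a minimal plain-Python illustration under our reading of the paper — `simt_objective` and `dropout` are hypothetical names, not functions from this repository):

```python
import random

def simt_objective(task_loss, distill_loss, lam):
    """Meta-objective: task loss plus a lam-weighted (--lam) distillation
    term computed against the momentum target's predictions."""
    return task_loss + lam * distill_loss

def dropout(activations, p, rng):
    """Inverted dropout with probability p (--drop_p): zero each unit with
    probability p, scale survivors by 1/(1-p) to keep the expectation."""
    return [0.0 if rng.random() < p else a / (1.0 - p) for a in activations]

# With lam=0.5, half of the distillation loss is mixed into the objective.
total = simt_objective(task_loss=1.0, distill_loss=0.5, lam=0.5)  # -> 1.25
```

Dropout on the target's features perturbs the teacher, which keeps the distilled signal from collapsing onto the meta-model's own predictions.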

4. Evaluation

4.1. Evaluation option

The options for the evaluation are as follows:

4.2. Evaluating backbone algorithms

python eval.py --mode <MODE> --model <MODEL> --dataset <DATASET> --load_path <PATH>

4.3. Evaluating SiMT

python eval.py --simt --mode <MODE> --model <MODEL> --dataset <DATASET> --load_path <PATH>

Citation

@inproceedings{tack2022meta,
  title={Meta-Learning with Self-Improving Momentum Target},
  author={Jihoon Tack and Jongjin Park and Hankook Lee and Jaeho Lee and Jinwoo Shin},
  booktitle={Advances in Neural Information Processing Systems},
  year={2022}
}
