SiMT: Self-improving Momentum Target
Official PyTorch implementation of "Meta-Learning with Self-Improving Momentum Target" (NeurIPS 2022) by Jihoon Tack, Jongjin Park, Hankook Lee, Jaeho Lee, Jinwoo Shin.
TL;DR: We propose a meta-learning algorithm to generate a target model from which we distill the knowledge to the meta-model, forming a virtuous cycle of improvements.
<p align="center"> <img src=figure/concept_figure.png width="900"> </p>

1. Dependencies
conda create -n simt python=3.8 -y
conda activate simt
pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html
pip install torchmeta tensorboardX
2. Dataset
Download the following datasets and place them in the /data folder:
- Regression
- Classification
  - mini-ImageNet, tiered-ImageNet, CUB, Cars
  - All classification datasets require preprocessing with the torchmeta library.
3. Training
3.1. Training option
The options for the training method are as follows:
- <MODE>: {maml, anil, metasgd, protonet}
- <MODEL>: {conv4, resnet12}
- <DATASET>: {shapenet, pose, miniimagenet, tieredimagenet}; note that pose indicates the Pascal dataset.
- One can use the --simt option to train the backbone meta-learning scheme <MODE> with SiMT.
3.2. Training backbone algorithms
python main.py --mode <MODE> --model <MODEL> --dataset <DATASET>
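As a minimal sketch, the hypothetical helper below assembles the backbone training command from the options listed in 3.1 and rejects values outside those sets (for example, cub is only listed as an evaluation dataset). It only builds the argument list and runs nothing; the name build_train_command is illustrative and not part of the repository.

```python
import shlex

# Option sets copied from the training options listed in 3.1.
TRAIN_MODES = {"maml", "anil", "metasgd", "protonet"}
MODELS = {"conv4", "resnet12"}
TRAIN_DATASETS = {"shapenet", "pose", "miniimagenet", "tieredimagenet"}


def build_train_command(mode, model, dataset):
    """Hypothetical helper: return the main.py backbone training command as a list."""
    if mode not in TRAIN_MODES:
        raise ValueError(f"unknown mode: {mode}")
    if model not in MODELS:
        raise ValueError(f"unknown model: {model}")
    if dataset not in TRAIN_DATASETS:
        raise ValueError(f"unknown training dataset: {dataset}")
    return ["python", "main.py", "--mode", mode, "--model", model, "--dataset", dataset]


# shlex.join (Python 3.8+) quotes arguments only when needed.
print(shlex.join(build_train_command("maml", "conv4", "miniimagenet")))
```

Returning a list rather than a string lets the result be passed directly to subprocess.run without shell quoting issues.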
3.3. Training SiMT
To train SiMT, one should choose appropriate values for the hyperparameters, including the momentum coefficient ETA, the weight hyperparameter LAM, and the dropout probability P.
python main.py --simt --mode <MODE> --model <MODEL> --dataset <DATASET> --eta ETA --lam LAM --drop_p P
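The README describes ETA only as the momentum coefficient. The sketch below assumes the common exponential-moving-average rule for a momentum network, target = ETA * target + (1 - ETA) * meta-model, applied after each meta-model update; this form is an assumption matching the "momentum target" naming, not code taken from this repository, and momentum_update is a hypothetical name. Parameters are plain dicts of floats to keep the example self-contained.

```python
def momentum_update(momentum_params, meta_params, eta):
    """Assumed EMA rule: target <- eta * target + (1 - eta) * meta (per parameter).

    In PyTorch the same rule would be applied tensor by tensor under torch.no_grad().
    """
    if not 0.0 <= eta <= 1.0:
        raise ValueError("eta must lie in [0, 1]")
    return {
        name: eta * momentum_params[name] + (1.0 - eta) * meta_params[name]
        for name in momentum_params
    }


# Toy run: the momentum target trails a meta-model whose weight keeps moving.
target = {"w": 0.0}
for step in range(3):
    meta = {"w": float(step + 1)}  # stand-in for the meta-model after an outer step
    target = momentum_update(target, meta, eta=0.5)
    print(step, target["w"])
```

Under this form a larger ETA makes the target change more slowly; ETA = 1 freezes it and ETA = 0 copies the meta-model.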
4. Evaluation
4.1. Evaluation option
The options for the evaluation are as follows:
- <PATH>: the path of the pre-trained checkpoint with the best validation accuracy (e.g., ./logs/experiment_name/best.model).
- <MODE>: {maml, anil, metasgd, protonet}
- <MODEL>: {conv4, resnet12}
- <DATASET>: {shapenet, pose, miniimagenet, tieredimagenet, cub, cars}; note that pose indicates the Pascal dataset.
- One can use the --simt option to evaluate with the momentum network.
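A small helper for locating candidate values of <PATH> can be sketched with pathlib. It assumes the ./logs/experiment_name/best.model layout given in the example above, so checkpoints saved elsewhere will not be found; find_best_checkpoints is a hypothetical name, not a repository function.

```python
from pathlib import Path


def find_best_checkpoints(log_root="./logs"):
    """Return every best.model file exactly one directory below log_root, sorted."""
    root = Path(log_root)
    if not root.is_dir():
        return []  # no log directory yet, so nothing to evaluate
    return sorted(p for p in root.glob("*/best.model") if p.is_file())


# Print each checkpoint path that could be passed as --load_path.
for ckpt in find_best_checkpoints():
    print(ckpt)
```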
4.2. Evaluating backbone algorithms
python eval.py --mode <MODE> --model <MODEL> --dataset <DATASET> --load_path <PATH>
4.3. Evaluating SiMT
python eval.py --simt --mode <MODE> --model <MODEL> --dataset <DATASET> --load_path <PATH>
Citation
@inproceedings{tack2022meta,
title={Meta-Learning with Self-Improving Momentum Target},
author={Jihoon Tack and Jongjin Park and Hankook Lee and Jaeho Lee and Jinwoo Shin},
booktitle={Advances in Neural Information Processing Systems},
year={2022}
}