Test-Time Training with Masked Autoencoders<br><sub>Official PyTorch Implementation</sub>
Paper | Project Page
Yossi Gandelsman*, Yu Sun*, Xinlei Chen and Alexei A. Efros
Setup
We provide an environment.yml file that can be used to create a Conda environment:
conda env create -f environment.yml
conda activate ttt
Training MAE
To train a model on the main task, please use the code base from Masked Autoencoders Are Scalable Vision Learners. We also provide self-contained training code here. Please run:
TIME=$(date +%s%3N)
DATA_PATH='...'
OUTPUT_DIR='...'
python -m torch.distributed.launch --nproc_per_node=8 main_pretrain.py \
--data_path ${DATA_PATH} \
--model mae_vit_large_patch16 \
--input_size 224 \
--batch_size 64 \
--mask_ratio 0.75 \
--warmup_epochs 40 \
--epochs 800 \
--blr 1e-3 \
--save_ckpt_freq 100 \
--output_dir ${OUTPUT_DIR} \
--dist_url "file://$OUTPUT_DIR/$TIME"
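To make the `--input_size 224` and `--mask_ratio 0.75` settings concrete, the sketch below counts the patches the encoder actually sees. It assumes a patch size of 16 (inferred from the model name `mae_vit_large_patch16`) and uses per-image random masking as described for MAE. The helpers `num_patches` and `random_mask` are hypothetical, pure-Python stand-ins for illustration, not this repository's tensor code.

```python
import random


def num_patches(input_size: int, patch_size: int) -> int:
    """Number of non-overlapping square patches in a square image."""
    side = input_size // patch_size
    return side * side


def random_mask(total: int, mask_ratio: float, seed: int = 0):
    """Split patch indices into (visible, masked) by sorting random noise.

    Assumption: the patches with the smallest noise values are kept, and
    the number kept is int(total * (1 - mask_ratio)).
    """
    rng = random.Random(seed)
    noise = [rng.random() for _ in range(total)]
    order = sorted(range(total), key=noise.__getitem__)
    keep = int(total * (1 - mask_ratio))
    return sorted(order[:keep]), sorted(order[keep:])


total = num_patches(224, 16)            # 14 x 14 grid of patches
visible, masked = random_mask(total, 0.75)
print(total, len(visible), len(masked))
```

Under these assumptions a 224x224 image yields 196 patches, of which 49 are visible to the encoder and 147 are masked and must be reconstructed.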
Alternatively, you can use a pretrained ViT-Large model from here:
mkdir -p checkpoints
cd checkpoints
wget https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_large_full.pth
Training the classification head
To train the classification head, run this:
TIME=$(date +%s%3N)
DATA_PATH='...'
OUTPUT_DIR='...'
RESUME_MODEL='checkpoints/mae_pretrain_vit_large_full.pth'
python -m torch.distributed.launch --nproc_per_node=8 main_prob.py \
--batch_size 32 \
--accum_iter 4 \
--model mae_vit_large_patch16 \
--finetune ${RESUME_MODEL} \
--epochs 20 \
--input_size 224 \
--head_type vit_head \
--blr 1e-3 \
--norm_pix_loss \
--weight_decay 0.2 \
--dist_eval --data_path ${DATA_PATH} --output_dir ${OUTPUT_DIR}
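The `--blr` flag is a base learning rate, not the rate passed to the optimizer. The MAE code base scales it linearly with the effective batch size (`lr = blr * effective_batch_size / 256`). The sketch below assumes the scripts in this repository keep that rule; the helper names are hypothetical.

```python
def effective_batch_size(batch_size_per_gpu: int, num_gpus: int,
                         accum_iter: int = 1) -> int:
    """Samples contributing to one optimizer step across all GPUs."""
    return batch_size_per_gpu * accum_iter * num_gpus


def absolute_lr(blr: float, batch_size_per_gpu: int, num_gpus: int,
                accum_iter: int = 1) -> float:
    """Linear scaling rule (assumed): lr = blr * effective_batch_size / 256."""
    return blr * effective_batch_size(batch_size_per_gpu, num_gpus, accum_iter) / 256


# Settings taken from the commands in this README (8 GPUs in both cases).
pretrain_lr = absolute_lr(1e-3, batch_size_per_gpu=64, num_gpus=8)
head_lr = absolute_lr(1e-3, batch_size_per_gpu=32, num_gpus=8, accum_iter=4)
print(f"pretraining lr: {pretrain_lr:.4g}, classification head lr: {head_lr:.4g}")
```

If you train on fewer GPUs, this suggests raising `--accum_iter` (or `--batch_size`, memory permitting) so the effective batch size, and therefore the absolute learning rate, stays the same.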
Alternatively, you can use a pretrained model (with slightly different parameters) from here:
mkdir -p checkpoints
cd checkpoints
wget https://dl.fbaipublicfiles.com/mae/ttt/prob_lr1e-3_wd.2_blk12_ep20.pth
Test-time training
To run test-time training, you will first need to download the ImageNet-C dataset from here.
After extracting the dataset, you can run test-time training on each of the test sets:
DATA_PATH_BASE='path_to_imagenet-c'
DATASET='gaussian_noise'
LEVEL='5'
RESUME_MODEL='checkpoints/mae_pretrain_vit_large_full.pth'
RESUME_FINETUNE='checkpoints/prob_lr1e-3_wd.2_blk12_ep20.pth'
OUTPUT_DIR_BASE='...'
# Millisecond timestamp used to build a unique --dist_url below
TIME=$(date +%s%3N)
python main_test_time_training.py \
--data_path "$DATA_PATH_BASE/$DATASET/$LEVEL" \
--model mae_vit_large_patch16 \
--input_size 224 \
--batch_size 128 \
--steps_per_example 20 \
--mask_ratio 0.75 \
--blr 1e-2 \
--norm_pix_loss \
--optimizer_type 'sgd' \
--classifier_depth 12 \
--head_type "vit_head" \
--single_crop \
--dataset_name "imagenet_c" \
--output_dir "$OUTPUT_DIR_BASE/$DATASET/" \
--dist_url "file://$OUTPUT_DIR_BASE/$TIME" \
--finetune_mode 'encoder' \
--resume_model ${RESUME_MODEL} \
--resume_finetune ${RESUME_FINETUNE}
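The command above handles a single corruption type. To cover every ImageNet-C test set, one option is to generate the same command once per corruption. The sketch below only builds and prints the commands; it assumes the `<root>/<corruption>/<severity>` directory layout used above. The helper names and the `ttt_outputs` directory are hypothetical placeholders.

```python
import shlex
import time

# The 15 corruption types of the ImageNet-C benchmark.
CORRUPTIONS = [
    "gaussian_noise", "shot_noise", "impulse_noise",
    "defocus_blur", "glass_blur", "motion_blur", "zoom_blur",
    "snow", "frost", "fog", "brightness",
    "contrast", "elastic_transform", "pixelate", "jpeg_compression",
]


def build_ttt_command(data_root, output_root, corruption, level, timestamp,
                      resume_model="checkpoints/mae_pretrain_vit_large_full.pth",
                      resume_finetune="checkpoints/prob_lr1e-3_wd.2_blk12_ep20.pth"):
    """Return the argument list for one run, mirroring the flags shown above."""
    return [
        "python", "main_test_time_training.py",
        "--data_path", f"{data_root}/{corruption}/{level}",
        "--model", "mae_vit_large_patch16",
        "--input_size", "224",
        "--batch_size", "128",
        "--steps_per_example", "20",
        "--mask_ratio", "0.75",
        "--blr", "1e-2",
        "--norm_pix_loss",
        "--optimizer_type", "sgd",
        "--classifier_depth", "12",
        "--head_type", "vit_head",
        "--single_crop",
        "--dataset_name", "imagenet_c",
        "--output_dir", f"{output_root}/{corruption}/",
        "--dist_url", f"file://{output_root}/{timestamp}",
        "--finetune_mode", "encoder",
        "--resume_model", resume_model,
        "--resume_finetune", resume_finetune,
    ]


def build_all_commands(data_root, output_root, level=5, timestamp=0):
    """One command per corruption type at a fixed severity level (1 to 5)."""
    return [build_ttt_command(data_root, output_root, c, level, timestamp)
            for c in CORRUPTIONS]


if __name__ == "__main__":
    stamp = int(time.time() * 1000)  # same role as TIME=$(date +%s%3N)
    for cmd in build_all_commands("path_to_imagenet-c", "ttt_outputs", timestamp=stamp):
        print(shlex.join(cmd))  # dry run: print instead of executing
```

To actually launch the runs, each list could be passed to `subprocess.run(cmd, check=True)` from the repository root; printing first makes it easy to check the paths before committing GPU time.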
Baseline evaluation
To evaluate the model without applying test-time training, run:
DATA_PATH_BASE='path_to_imagenet-c'
DATASET='gaussian_noise'
LEVEL='5'
RESUME_MODEL='checkpoints/mae_pretrain_vit_large_full.pth'
RESUME_FINETUNE='checkpoints/prob_lr1e-3_wd.2_blk12_ep20.pth'
OUTPUT_DIR_BASE='...'
python test_without_adaptation.py \
--data_path "$DATA_PATH_BASE/$DATASET/$LEVEL" \
--model mae_vit_large_patch16 \
--input_size 224 \
--resume_model ${RESUME_MODEL} \
--resume_finetune ${RESUME_FINETUNE} \
--output_dir "$OUTPUT_DIR_BASE/$DATASET/baseline" \
--classifier_depth 12 \
--head_type "vit_head"
BibTeX
@inproceedings{
gandelsman2022testtime,
title={Test-Time Training with Masked Autoencoders},
author={Yossi Gandelsman and Yu Sun and Xinlei Chen and Alexei A Efros},
booktitle={Advances in Neural Information Processing Systems},
editor={Alice H. Oh and Alekh Agarwal and Danielle Belgrave and Kyunghyun Cho},
year={2022},
url={https://openreview.net/forum?id=SHMi1b7sjXk}
}