MAE for Self-supervised ViT

Introduction

This is an unofficial PyTorch implementation of Masked Autoencoders Are Scalable Vision Learners for self-supervised ViT. This repo is mainly based on moco-v3, pytorch-image-models, BEiT and MAE-pytorch.

<img src="figures/mae.png" alt="image-mae" style="zoom: 33%;" />

TODO

Main Result

We support two representations (repre.) for classification: GAP (Global Average Pooling) and the cls-token. According to the paper, MAE works similarly well with both. In cls-token mode, the cls-token is trained in the MAE encoder.
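
The two modes differ only in how a single feature vector is pooled from the encoder tokens. Below is a minimal sketch of the distinction, assuming the usual ViT layout where the cls-token is the first token (names and shapes are illustrative, not this repo's API):

import torch

def pool_features(tokens: torch.Tensor, mode: str = "gap") -> torch.Tensor:
    """Pool encoder tokens of shape [B, 1 + N, C] (cls-token first) into one vector per image."""
    if mode == "gap":
        # Global Average Pooling over the patch tokens, ignoring the cls-token.
        return tokens[:, 1:, :].mean(dim=1)
    if mode == "cls":
        # Use the cls-token, which in this setting is trained inside the MAE encoder.
        return tokens[:, 0, :]
    raise ValueError(f"unknown pooling mode: {mode}")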

For k-NN evaluation, we use k=10 by default.
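
The reported k-NN numbers come from nearest-neighbour voting on frozen encoder features. The sketch below shows the general recipe (cosine similarity, majority vote over the k=10 neighbours); it is an assumption about the procedure, not the exact evaluation code behind run_knn.sh:

import torch
import torch.nn.functional as F

@torch.no_grad()
def knn_predict(train_feats, train_labels, test_feats, k=10, num_classes=1000):
    """Majority-vote k-NN on L2-normalized features (k=10 is the default used here)."""
    train_feats = F.normalize(train_feats, dim=1)
    test_feats = F.normalize(test_feats, dim=1)
    sim = test_feats @ train_feats.t()          # cosine similarities, [N_test, N_train]
    _, idx = sim.topk(k, dim=1)                 # indices of the k nearest training samples
    votes = train_labels[idx]                   # their labels, [N_test, k]
    counts = F.one_hot(votes, num_classes).sum(dim=1)
    return counts.argmax(dim=1)                 # predicted class per test sample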

ViT-Small

| pretrain epoch | repre. | ft. top1 | lin. | k-NN | config | weight | log |
| --- | --- | --- | --- | --- | --- | --- | --- |
| 100 | GAP | 76.58% | 34.65% | 19.7% | pretrain / finetune | pretrain / finetune | pretrain / finetune |
| 100 | Cls-token | 75.77% | 38.95% | 23.7% | pretrain / finetune | pretrain / finetune | pretrain / finetune |
| 200 | GAP | 76.86% | 36.46% | 19.8% | pretrain / finetune | pretrain / finetune | pretrain / finetune |
| 400 | GAP | 77.56% / 80.02% / 80.89% | 36.98% | 20.8% | pretrain / finetune | pretrain / finetune | pretrain / finetune |
| 800 | GAP | 77.93% / 80.87% / 81.11% | 36.88% | 20.7% | pretrain / finetune | pretrain / finetune | pretrain / finetune |
| 1600 | GAP | - | - | | pretrain / finetune | pretrain / finetune | pretrain / finetune |

ViT-Base

| pretrain epoch | repre. | ft. top1 | k-NN | config | weight | log |
| --- | --- | --- | --- | --- | --- | --- |
| 400 | GAP | 83.08% | 28.9% | pretrain / finetune | pretrain / finetune | pretrain / finetune |

ViT-Large

| pretrain epoch | repre. | ft. top1 | lin. | k-NN | config | weight | log |
| --- | --- | --- | --- | --- | --- | --- | --- |
| 100 | GAP | 83.51% | 58.90% | 33.08% | pretrain / finetune | pretrain / finetune | pretrain / finetune |

Usage

Preparation

The code has been tested with CUDA 11.4 and PyTorch 1.8.2.

Notes:

  1. The batch size specified by -b is the per-GPU batch size.
  2. The learning rate specified by --lr is the base lr (defined for a total batch size of 256) and is adjusted by the linear lr scaling rule; see the sketch after this list.
  3. Only multi-GPU DistributedDataParallel training is supported in this repo; single-GPU and DataParallel training are not supported.
  4. We support the cls-token (token) and global average pooling (GAP) representations for classification. Please make sure pretraining and finetuning/linear probing use the same representation. In cls-token mode, the cls-token is trained in the MAE encoder during pretraining.
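
The linear lr scaling rule mentioned in note 2 is a proportional rescaling of the base lr by the total batch size. A minimal sketch with illustrative values (the variable names and numbers are not taken from this repo's configs):

# Linear lr scaling: the effective lr grows linearly with the total batch size.
base_lr = 1.5e-4                                     # --lr, defined for a total batch size of 256 (illustrative value)
batch_size_per_gpu = 512                             # -b
num_gpus = 8
total_batch_size = batch_size_per_gpu * num_gpus     # 4096
effective_lr = base_lr * total_batch_size / 256      # lr actually passed to the optimizer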

Self-supervised Pre-Training

Below is an example of MAE pre-training.

ViT-Small with 1-node (8-GPU, NVIDIA GeForce RTX 3090) training, batch-size 4096, GAP.

sh run_pretrain.sh \
	--config cfgs/pretrain/Vit-S_100E_GAP.yaml \
	--data_path /path/to/train/data

End-to-End Fine-tuning

ViT-Small with 1-node (8-GPU, NVIDIA GeForce RTX 3090) training, 50 epochs, batch-size 4096, GAP.

sh run_finetune.sh \
	--config cfgs/finetune/ViT-S_50E_GAP.yaml \
	--data_path /path/to/data \
	--finetune /path/to/pretrain/model

Linear Classification

According to the paper, we support two training modes for linear probing: SGD with batch size 4096, and LARS with batch size 16384.

ViT-Small with 1-node (8-GPU, NVIDIA GeForce RTX 3090) training, 50 epochs, SGD + batch-size 4096, GAP.

sh run_lincls.sh \
	--config cfgs/lincls/ViT-S_SGD_GAP.yaml \
	--data_path /path/to/data \
	--finetune /path/to/pretrain/model

k-NN Evaluation of the Pretrained Model

ViT-Small with 1-node (8-GPU, NVIDIA GeForce RTX 3090), GAP.

sh run_knn.sh \
	--config cfgs/finetune/ViT-S_50E_GAP.yaml \
	--data_path /path/to/data \
	--finetune /path/to/pretrain/model \
	--save_path /path/to/save/result

Visualization of Reconstruction

ViT-Base Pretrained for 400 Epochs.

python tools/run_mae_vis.py \
	--config cfgs/pretrain/ViT-B_400E_Norm_GAP.yaml \
	--save_path output/restruct/ \
	--model_path /path/to/pretrain/model \
	--img_path /path/to/image

Visualization of Attention Map

ViT-Small w/ CLS-Token Pretrained for 100 Epochs.

python tools/vit_explain.py \
	--config cfgs/finetune/ViT-S_50E_CLS-Token.yaml \
	--finetune /path/to/pretrain/model \
	--image_path /path/to/image \
	--head_fusion max \
	--discard_ratio 0.9
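
The --head_fusion and --discard_ratio flags follow the common attention-rollout recipe: fuse the per-head attention maps of each layer, drop the weakest entries, then multiply the residual-augmented maps across layers. A minimal sketch of that idea (not the exact tools/vit_explain.py implementation):

import torch

def attention_rollout(attentions, head_fusion="max", discard_ratio=0.9):
    """attentions: list of per-layer attention maps, each of shape [B, heads, tokens, tokens]."""
    result = None
    for attn in attentions:
        # Fuse the heads of this layer into a single [B, tokens, tokens] map.
        if head_fusion == "max":
            fused = attn.max(dim=1).values
        elif head_fusion == "min":
            fused = attn.min(dim=1).values
        else:
            fused = attn.mean(dim=1)
        # Zero out the weakest attention entries (a discard_ratio fraction of them).
        fused = fused.clone()
        flat = fused.view(fused.size(0), -1)               # view sharing storage with `fused`
        _, idx = flat.topk(int(flat.size(1) * discard_ratio), dim=1, largest=False)
        flat.scatter_(1, idx, 0.0)
        # Account for the residual connection, renormalize rows, and accumulate across layers.
        eye = torch.eye(fused.size(-1), device=fused.device)
        layer_map = (fused + eye) / 2
        layer_map = layer_map / layer_map.sum(dim=-1, keepdim=True)
        result = layer_map if result is None else layer_map @ result
    # Attention flowing from the cls-token to the patch tokens.
    return result[:, 0, 1:]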

License

This project is under the CC-BY-NC 4.0 license. See LICENSE for details.

Citation

If you use the code of this repo, please cite the original paper and this repo:

@Article{he2021mae,
  author  = {Kaiming He* and Xinlei Chen* and Saining Xie and Yanghao Li and Piotr Dollár and Ross Girshick},
  title   = {Masked Autoencoders Are Scalable Vision Learners},
  journal = {arXiv preprint arXiv:2111.06377},
  year    = {2021},
}
@misc{yang2021maepriv,
  author       = {Lu Yang* and Pu Cao* and Yang Nie and Qing Song},
  title        = {MAE-priv},
  howpublished = {\url{https://github.com/BUPT-PRIV/MAE-priv}},
  year         = {2021},
}