
Implementation of "FastMIM: Expediting Masked Image Modeling Pre-training for Vision".

<p align="center"> <img src="figs/fastmim.png" > </p> <p align="center"> </p> Comparison among the MAE, SimMIM and our FastMIM frameworks. MAE randomly masks and discards input patches. Although its encoder processes only a small number of patches, MAE can only pre-train isotropic ViTs, which produce single-scale intermediate features. SimMIM preserves the input resolution and can serve as a generic framework for all kinds of vision backbones, but it must process a large number of patches. Our FastMIM simply reduces the input resolution and replaces the pixel target with a HOG (histogram of oriented gradients) target. These modifications are simple yet effective: FastMIM (i) pre-trains faster; (ii) consumes less memory; (iii) can serve as a generic framework for all kinds of architectures; and (iv) achieves comparable or even better performance than previous methods.
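
The two changes described above, a lower pre-training resolution and a HOG target, can be illustrated with a short NumPy sketch. It assumes 16-pixel patches, 8-pixel HOG cells and 9 orientation bins. `num_patches` and `hog_target` are hypothetical helpers written for this example, and the simplified HOG below (grayscale input, no block normalization, no bin interpolation) is not the repository's own HOG implementation.

```python
import numpy as np


def num_patches(resolution: int, patch_size: int = 16) -> int:
    """Number of non-overlapping patch tokens for a square input (assumed 16-px patches)."""
    return (resolution // patch_size) ** 2


def hog_target(img: np.ndarray, cell: int = 8, nbins: int = 9) -> np.ndarray:
    """Simplified per-cell HOG for a 2-D grayscale image (illustrative, not the repo's code).

    Central-difference gradients, unsigned orientations in [0, pi) hard-assigned
    to `nbins` bins, gradient magnitudes summed per `cell` x `cell` cell, then
    L2-normalized per cell. No block normalization or bin interpolation.
    """
    img = img.astype(np.float64)
    gx = np.zeros_like(img)
    gy = np.zeros_like(img)
    gx[:, 1:-1] = img[:, 2:] - img[:, :-2]  # horizontal gradient
    gy[1:-1, :] = img[2:, :] - img[:-2, :]  # vertical gradient
    mag = np.hypot(gx, gy)
    ang = np.mod(np.arctan2(gy, gx), np.pi)  # fold to unsigned orientation
    bins = np.minimum((ang / np.pi * nbins).astype(int), nbins - 1)

    ch, cw = img.shape[0] // cell, img.shape[1] // cell
    mag, bins = mag[: ch * cell, : cw * cell], bins[: ch * cell, : cw * cell]
    hist = np.zeros((ch, cw, nbins))
    for b in range(nbins):
        masked = np.where(bins == b, mag, 0.0)
        # Sum magnitudes inside each cell x cell block.
        hist[..., b] = masked.reshape(ch, cell, cw, cell).sum(axis=(1, 3))
    return hist / (np.linalg.norm(hist, axis=-1, keepdims=True) + 1e-6)


if __name__ == "__main__":
    # Token counts for 16-pixel patches: 64 at 128x128 vs 196 at 224x224.
    print(num_patches(128), num_patches(224))
    # With a 0.75 mask ratio, an encoder that drops masked tokens (as MAE does)
    # would see 16 visible tokens at 128x128.
    print(int(num_patches(128) * (1 - 0.75)))
    target = hog_target(np.random.default_rng(0).random((128, 128)))
    print(target.shape)  # (16, 16, 9)
```

Under these assumptions, going from 224x224 to 128x128 cuts the token count per image from 196 to 64, roughly a 3x reduction.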

Setup

- python==3.x
- cuda==10.x
- torch==1.7.0+
- mmcv-full==1.4.4+

# Other PyTorch/CUDA/timm versions may also work.

# Install the pip dependencies
sh requirement_pip_install.sh

# Build apex (optional)
cd /your_path_to/apex-master/
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
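
Since other versions may also work, a quick way to see what is installed is a small stdlib-only Python check against the minimum versions listed above. This is a hypothetical helper (`parse_version` and `check_environment` are illustrative names, not part of the repository); it compares only the numeric major.minor.patch part and ignores suffixes such as `+cu117` or `rc1`.

```python
import re
from importlib import metadata

# Minimum versions taken from the list above; newer versions may also work.
MINIMUMS = {"torch": (1, 7, 0), "mmcv-full": (1, 4, 4)}


def parse_version(v: str) -> tuple:
    """Turn '1.13.1+cu117' into (1, 13, 1); local/pre-release suffixes are ignored."""
    nums = re.findall(r"\d+", v.split("+")[0])[:3]
    return tuple(int(n) for n in nums) + (0,) * (3 - len(nums))


def check_environment(minimums=MINIMUMS) -> dict:
    """Report each package as 'missing', or its version plus 'ok' / 'too old'."""
    report = {}
    for pkg, minimum in minimums.items():
        try:
            installed = metadata.version(pkg)
        except metadata.PackageNotFoundError:
            report[pkg] = "missing"
            continue
        ok = parse_version(installed) >= minimum
        report[pkg] = f"{installed} ({'ok' if ok else 'too old'})"
    return report


if __name__ == "__main__":
    for pkg, status in check_environment().items():
        print(f"{pkg}: {status}")
```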

Data preparation

Download ImageNet-1K from http://image-net.org/ and extract the train and val images. The expected directory structure is:

path/to/imagenet/
├──train/
│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......
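
The official ILSVRC2012 validation archive extracts into a single flat folder, so the val images have to be sorted into per-class folders to match the layout above. The following stdlib-only sketch checks that `train/` and `val/` share the same class folders and counts the images. The function names and the script name `check_layout.py` are hypothetical, not part of the repository.

```python
import sys
import tempfile
from pathlib import Path

IMG_EXTS = {".jpeg", ".jpg", ".png"}


def summarize_split(root, split):
    """Count images in each class folder under root/split."""
    split_dir = Path(root) / split
    if not split_dir.is_dir():
        raise FileNotFoundError(f"missing split directory: {split_dir}")
    return {
        d.name: sum(1 for f in d.iterdir() if f.suffix.lower() in IMG_EXTS)
        for d in sorted(split_dir.iterdir())
        if d.is_dir()
    }


def check_imagenet(root, expected_classes=1000):
    """Validate the layout and return (num_classes, num_train, num_val)."""
    train = summarize_split(root, "train")
    val = summarize_split(root, "val")
    if set(train) != set(val):
        raise ValueError("train/ and val/ contain different class folders")
    if len(train) != expected_classes:
        raise ValueError(f"expected {expected_classes} classes, found {len(train)}")
    return len(train), sum(train.values()), sum(val.values())


if __name__ == "__main__":
    if len(sys.argv) > 1:
        # e.g. python check_layout.py /path/to/imagenet
        print(check_imagenet(sys.argv[1]))
    else:
        # No path given: build and check a tiny two-class toy layout instead.
        with tempfile.TemporaryDirectory() as tmp:
            for split in ("train", "val"):
                for wnid in ("n01440764", "n01443537"):
                    (Path(tmp) / split / wnid).mkdir(parents=True)
                    (Path(tmp) / split / wnid / "x.JPEG").touch()
            print(check_imagenet(tmp, expected_classes=2))  # (2, 2, 2)
```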

Pre-training on ImageNet-1K

<details> <summary> ViT-B </summary>

To train ViT-B on ImageNet-1K on a single node with 8 GPUs:

python -m torch.distributed.launch --nproc_per_node=8 main_pretrain.py --model mim_vit_base --data_path /your_path_to/data/imagenet/ --epochs 800 --warmup_epochs 20 --blr 1.5e-4 --weight_decay 0.05 --output_dir /your_path_to/fastmim_pretrain_output/ --batch_size 512 --save_ckpt_freq 100 --num_workers 10 --mask_ratio 0.75 --norm_pix_loss --rrc_scale 0.2 1.0 --input_size 128 --decoder_embed_dim 256 --decoder_depth 1 --block_size 16 --mim_loss HOG
</details> <details> <summary> Swin-B </summary>

To train Swin-B on ImageNet-1K on a single node with 8 GPUs:

python -m torch.distributed.launch --nproc_per_node=8 main_pretrain.py --model mim_swin_base --data_path /your_path_to/data/imagenet/ --epochs 400 --warmup_epochs 10 --blr 1.5e-4 --weight_decay 0.05 --output_dir /your_path_to/fastmim_pretrain_output/ --batch_size 256 --save_ckpt_freq 50 --num_workers 10 --mask_ratio 0.75 --norm_pix_loss --input_size 128 --rrc_scale 0.2 1.0 --window_size 4 --decoder_embed_dim 256 --decoder_depth 4 --mim_loss HOG --block_size 32
</details>
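
`--blr` is a base learning rate rather than the actual one. Assuming FastMIM keeps MAE's convention (its pre-training script uses MAE's flag names), the actual rate is `blr * effective_batch_size / 256`, where `--batch_size` is per GPU and the effective batch size is `batch_size * accum_iter * num_gpus`; the rate then follows a linear warmup over `--warmup_epochs` and a half-cycle cosine decay. The sketch below reproduces that arithmetic for the two commands above. `effective_lr` and `lr_at` are illustrative helpers, and `min_lr=0` is an assumed default.

```python
import math


def effective_lr(blr, batch_size_per_gpu, num_gpus, accum_iter=1):
    """Assumed MAE-style linear scaling: lr = blr * effective_batch / 256."""
    eff_batch = batch_size_per_gpu * accum_iter * num_gpus
    return blr * eff_batch / 256


def lr_at(epoch, peak_lr, warmup_epochs, total_epochs, min_lr=0.0):
    """Linear warmup, then half-cycle cosine decay (epoch may be fractional)."""
    if epoch < warmup_epochs:
        return peak_lr * epoch / warmup_epochs
    progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return min_lr + (peak_lr - min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))


if __name__ == "__main__":
    vit_lr = effective_lr(1.5e-4, 512, 8)   # ViT-B: effective batch 4096
    swin_lr = effective_lr(1.5e-4, 256, 8)  # Swin-B: effective batch 2048
    print(f"ViT-B peak lr {vit_lr:.2e}, Swin-B peak lr {swin_lr:.2e}")
    for ep in (0, 10, 20, 410, 800):
        print(ep, f"{lr_at(ep, vit_lr, warmup_epochs=20, total_epochs=800):.2e}")
```

Under these assumptions, ViT-B peaks at 2.4e-3 and Swin-B at 1.2e-3. When training on fewer GPUs, raising `--accum_iter` keeps the effective batch size, and therefore the learning rate, unchanged.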

Fine-tuning on ImageNet-1K

<details> <summary> ViT-B </summary>

To fine-tune ViT-B on ImageNet-1K on a single node with 8 GPUs:

python -m torch.distributed.launch --nproc_per_node=8 main_finetune.py --model vit_base_patch16 --data_path /your_path_to/data/imagenet/ --batch_size 128 --accum_iter 1 --epochs 100 --blr 6e-4 --layer_decay 0.70 --weight_decay 0.05 --drop_path 0.1 --dist_eval --finetune /your_path_to_ckpt/checkpoint-799.pth --output_dir /your_path_to/fastmim_finetune_output/
</details> <details> <summary> Swin-B </summary>

To fine-tune Swin-B on ImageNet-1K on a single node with 8 GPUs:

python -m torch.distributed.launch --nproc_per_node=8 main_finetune.py --model swin_base_patch4_window7_224 --data_path /your_path_to/data/imagenet/ --batch_size 128 --epochs 100 --blr 1.0e-3 --layer_decay 0.80 --weight_decay 0.05 --drop_path 0.1 --dist_eval --finetune /your_path_to_ckpt/checkpoint-399.pth --output_dir /your_path_to/fastmim_finetune_output/
</details>
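
`--layer_decay` applies layer-wise learning-rate decay during fine-tuning: layers closer to the input get smaller learning rates. The sketch below assumes the MAE/BEiT-style scheme and timm-style ViT parameter names (`patch_embed`, `blocks.N`, `head`). `layer_id` and `layer_scales` are illustrative helpers, and the sketch covers only the ViT-B command; the Swin-B model uses different parameter names, so its layer mapping would differ and is not shown here.

```python
def layer_id(param_name: str, num_blocks: int) -> int:
    """Map a timm-style ViT parameter name to a layer index (assumed naming)."""
    if param_name.startswith(("cls_token", "pos_embed", "patch_embed")):
        return 0
    if param_name.startswith("blocks."):
        return int(param_name.split(".")[1]) + 1
    return num_blocks + 1  # final norm and classification head


def layer_scales(num_blocks: int, layer_decay: float) -> list:
    """LR multiplier per layer: layer_decay ** (distance from the head)."""
    num_layers = num_blocks + 1
    return [layer_decay ** (num_layers - i) for i in range(num_layers + 1)]


if __name__ == "__main__":
    base_lr = 6e-4 * (128 * 1 * 8) / 256  # blr * effective batch / 256 -> 2.4e-3
    scales = layer_scales(num_blocks=12, layer_decay=0.70)  # ViT-B has 12 blocks
    for name in ("patch_embed.proj.weight", "blocks.0.attn.qkv.weight",
                 "blocks.11.mlp.fc2.weight", "head.weight"):
        lid = layer_id(name, num_blocks=12)
        print(f"{name:28s} layer {lid:2d}  lr {base_lr * scales[lid]:.2e}")
```

With `--layer_decay 0.70` and 12 blocks, the patch embedding would train at 0.7^13 (about 1%) of the head's learning rate.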

Notice

We build our object detection and semantic segmentation codebases upon mmdet-v2.23 and mmseg-v0.28. However, we also port some features from newer mmdet versions (e.g., simple copy-paste) into our mmdet-v2.23, so if you download mmdet-v2.23 directly from MMDet, the code may report errors.

Results and Models

Classification on ImageNet-1K (ViT-B/Swin-B/PVTv2-b2/CMT-S)

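| Model | #Params | PT Res. | PT Epoch | PT log/ckpt | FT Res. | FT log/ckpt | Top-1 (%) |
|---|---|---|---|---|---|---|---|
| ViT-B | 86M | 128x128 | 800 | log/ckpt | 224x224 | log/ckpt | 83.8 |
| Swin-B | 88M | 128x128 | 400 | log/ckpt | 224x224 | log/ckpt | 84.1 |
| PVTv2-B2 | 25M | 128x128 | 800 | | 224x224 | ckpt | 82.5 |
| CMT-S | 25M | 128x128 | 800 | | 224x224 | ckpt | 83.9 |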

Object Detection on COCO (Swin-B based Mask R-CNN)

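| Model | Backbone | Pretrain | Lr schd | box AP | mask AP | Config | Checkpoint |
|---|---|---|---|---|---|---|---|
| Mask R-CNN | Swin-B | SimMIM | 3x | 52.3 | 46.4 | config | log/ckpt |
| Mask R-CNN | Swin-B | FastMIM | 3x | 52.0 | 46.0 | config | log/ckpt |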

Semantic Segmentation on ADE20K (ViT-B based UPerNet)

| Model | Backbone | Pretrain | Crop Size | Batch Size | Lr schd | mIoU (ss) | Config | Checkpoint |
|---|---|---|---|---|---|---|---|---|
| UPerNet | ViT-B | FastMIM | 512x512 | 16 | 160000 | 49.5 | config | log/ckpt |

Citation

If you find this project useful in your research, please consider citing:

@article{guo2022fastmim,
  title={FastMIM: Expediting Masked Image Modeling Pre-training for Vision},
  author={Guo, Jianyuan and Han, Kai and Wu, Han and Tang, Yehui and Wang, Yunhe and Xu, Chang},
  journal={arXiv preprint arXiv:2212.06593},
  year={2022}
}

Acknowledgement

The classification task in this repo is based on MAE, SimMIM, SlowFast and timm.

The object detection task in this repo is based on MMDet, ViTDet and Swin-Transformer-Object-Detection.

The semantic segmentation task in this repo is based on MMSeg and BEiT.

License

This project is released under the MIT License.