## Introduction
Paper accepted at NeurIPS 2022.
This is the official repository of SemMAE. Our code is based on [MAE](https://github.com/facebookresearch/mae); many thanks to its authors for their outstanding work! For details of our work, see [Semantic-Guided Masking for Learning Masked Autoencoders](https://arxiv.org/abs/2206.10207).
<div align="center">
  <img width="900" src="https://github.com/ucasligang/SemMAE/blob/main/src/figure1.png">
</div>

## Citation
```
@article{li2022semmae,
  title={SemMAE: Semantic-Guided Masking for Learning Masked Autoencoders},
  author={Li, Gang and Zheng, Heliang and Liu, Daqing and Wang, Chaoyue and Su, Bing and Zheng, Changwen},
  journal={arXiv preprint arXiv:2206.10207},
  year={2022}
}
```
## Requirements

This implementation is in PyTorch and requires a GPU.
- This repo is based on `timm==0.3.2`, for which a fix is needed to work with PyTorch 1.8.1+ (see the patch below).
- The repository also needs tensorboard, which can be installed with `pip install tensorboard`.
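The incompatibility comes from `timm==0.3.2` importing `container_abcs` from `torch._six`, which newer PyTorch versions removed. A common patch (the same workaround referenced by the MAE repo; edit `timm/models/layers/helpers.py` in your installed timm package) looks like this:

```python
# timm/models/layers/helpers.py -- patched import for PyTorch 1.8.1+
import torch

TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])

if TORCH_MAJOR == 1 and TORCH_MINOR < 8:
    # Old PyTorch still ships container_abcs in torch._six
    from torch._six import container_abcs
else:
    # Newer PyTorch: use the standard library instead
    import collections.abc as container_abcs
```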
## Data preparation

Process the ImageNet dataset (including part masks and pixel values). Download links for the processed dataset:
<table><tbody>
<tr><td align="left">size</td><th valign="bottom">16x16 patch</th><th valign="bottom">8x8 patch</th></tr>
<tr><td align="left">link</td><td align="center"><a href="https://drive.google.com/file/d/1bDvyl2azHGleaB6HGVPkveN-0mEjyLcV/view?usp=share_link">download</a></td><td align="center"><a href="https://pan.baidu.com/s/1wPRWKkPVdHaSKvY61evoEA">pwd:1tum</a></td></tr>
<tr><td align="left">md5</td><td align="center"><tt>lost</tt></td><td align="center"><tt>waiting</tt></td></tr>
</tbody></table>

## Pretrained models
<table><tbody>
<tr><th valign="bottom">800-epochs</th><th valign="bottom">ViT-Base 16x16 patch</th><th valign="bottom">ViT-Base 8x8 patch</th></tr>
<tr><td align="left">pretrained checkpoint</td><td align="center"><a href="https://drive.google.com/file/d/1GaGWNv8I-ADF8e-Bvftgr2k8qNeyLdTJ/view?usp=share_link">download</a></td><td align="center"><a href="https://drive.google.com/file/d/1X0yHD4kEM8VCYwSmiNcJfK8jni15cvdH/view?usp=share_link">download</a></td></tr>
<tr><td align="left">md5</td><td align="center"><tt>1482ae</tt></td><td align="center"><tt>322b6a</tt></td></tr>
</tbody></table>
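The six-character `md5` values in these tables appear to be the first six hex characters of the full digest (our assumption). A quick Python check for a downloaded file (the filename below is hypothetical; use whatever name you saved the download under):

```python
import hashlib

def md5_prefix(path, n=6, chunk=1 << 20):
    """Compute the first n hex characters of a file's MD5 digest."""
    h = hashlib.md5()
    with open(path, 'rb') as f:
        for block in iter(lambda: f.read(chunk), b''):
            h.update(block)
    return h.hexdigest()[:n]

# Expected to print '1482ae' for the ViT-Base 16x16 pretrained checkpoint.
print(md5_prefix('SemMAE_pretrain_vit_base_patch16.pth'))
```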
## Evaluation

As a sanity check, run evaluation using our ImageNet fine-tuned models:
<table><tbody>
<tr><th valign="bottom">800-epochs</th><th valign="bottom">ViT-Base 16x16 patch</th><th valign="bottom">ViT-Base 8x8 patch</th></tr>
<tr><td align="left">fine-tuned checkpoint</td><td align="center"><a href="https://drive.google.com/file/d/1KD5JCj-cdcsPkGPQ9n5hwaSg2Rrvm88i/view?usp=share_link">download</a></td><td align="center"><a href="https://drive.google.com/file/d/1WB0_Mx0XCPMiwnS1PVVD38lq0u9U49R8/view?usp=share_link">download</a></td></tr>
<tr><td align="left">md5</td><td align="center"><tt>bbc5ef</tt></td><td align="center"><tt>6abd9e</tt></td></tr>
<tr><td align="left">reference ImageNet accuracy (top-1, %)</td><td align="center">83.352</td><td align="center">84.444</td></tr>
</tbody></table>

Evaluate ViT-Base_16 on a single GPU (`${IMAGENET_DIR}` is a directory containing `{train, val}` sets of ImageNet):
```
python main_finetune.py --eval --resume SemMAE_epoch799_vit_base_checkpoint-99.pth --model vit_base_patch16 --batch_size 16 --data_path ${IMAGENET_DIR}
```
This should give:

```
* Acc@1 83.352 Acc@5 96.494 loss 0.745
Accuracy of the network on the 50000 test images: 83.4%
```
Evaluate ViT-Base_8 on a single GPU (`${IMAGENET_DIR}` is a directory containing `{train, val}` sets of ImageNet):
```
python main_finetune.py --eval --resume SemMAE_epoch799_vit_base_checkpoint_patch8-78.pth --model vit_base_patch8 --batch_size 8 --data_path ${IMAGENET_DIR}
```
This should give:

```
* Acc@1 84.444 Acc@5 97.032 loss 0.683
Accuracy of the network on the 50000 test images: 84.44%
```
Note that all of our results are obtained with the 800-epoch pre-training setting. The best checkpoint for vit_base_patch8 was lost (the paper reports 84.5% top-1 accuracy, vs. 84.44% at epoch 78 here).
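If you want to use the fine-tuned weights outside of `main_finetune.py`, here is a minimal loading sketch. It assumes the checkpoint stores its weights under a `'model'` key, as in MAE, and that the fine-tuned encoder roughly matches timm's `vit_base_patch16_224` layout; `strict=False` tolerates any remaining naming differences:

```python
import torch
import timm

# Standard ViT-Base/16 with a 1000-class ImageNet-1k head.
model = timm.create_model('vit_base_patch16_224', num_classes=1000)

# Assumption: like MAE, the checkpoint is a dict with weights under 'model'.
ckpt = torch.load('SemMAE_epoch799_vit_base_checkpoint-99.pth', map_location='cpu')
state_dict = ckpt.get('model', ckpt)

# Inspect what did not match rather than failing hard.
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print('missing keys:', missing)
print('unexpected keys:', unexpected)
model.eval()
```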
## Pre-training
To pre-train ViT-Base (the recommended default) with multi-node distributed training, run the following on 8 nodes with 8 GPUs each:
```
python -m torch.distributed.launch --nproc_per_node=8 --master_port=${MASTER_PORT} \
    --nnodes=${NNODES} --node_rank=${SLURM_NODEID} --master_addr=${MASTER_ADDR} \
    --use_env main_pretrain_setting3.py \
    --output_dir ${OUTPUT_DIR} --log_dir ${OUTPUT_DIR} \
    --batch_size 128 \
    --model mae_vit_base_patch16 \
    --norm_pix_loss \
    --mask_ratio 0.75 \
    --epochs 800 \
    --warmup_epochs 40 \
    --blr 1.5e-4 --weight_decay 0.05 \
    --setting 3 \
    --data_path ${DATA_DIR}
```
Note that `${DATA_DIR}` must point to the processed dataset from the data preparation section above, not the raw ImageNet directory.
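For reference, the learning rate actually used follows MAE's linear scaling rule, `lr = blr * effective_batch_size / 256` (we assume SemMAE inherits this from MAE unchanged). With the settings above:

```python
# MAE-style linear lr scaling (assumed inherited by SemMAE): lr = blr * eff_bs / 256
nnodes, gpus_per_node, batch_per_gpu = 8, 8, 128
blr = 1.5e-4

eff_bs = nnodes * gpus_per_node * batch_per_gpu   # 8 * 8 * 128 = 8192
lr = blr * eff_bs / 256                           # 1.5e-4 * 32 = 4.8e-3
print(f"effective batch size = {eff_bs}, lr = {lr:.2e}")
```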
## Contact
This repo is currently maintained by Gang Li ([@ucasligang](https://github.com/ucasligang)).