SAIM

Official PyTorch implementation of Exploring Stochastic Autoregressive Image Modeling for Visual Representation, accepted at AAAI 2023.

Introduction

(Figure: overview of the SAIM pipeline)

SAIM is a novel self-supervised pre-training framework that performs autoregressive image modeling with a stochastic permutation strategy. Our method significantly improves the performance of autoregressive image modeling and achieves the best accuracy (83.9%) on a vanilla ViT-Base model among methods that use only ImageNet-1K data.
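
The core idea can be sketched in a few lines. The snippet below is illustrative only (the helper name and mask convention are not the repository's API): each training step samples a random patch order, and a causal attention mask built from that order lets every patch be predicted from only the patches that precede it.

import torch

def stochastic_causal_mask(num_patches: int) -> torch.Tensor:
    """Sample a random patch order and build the matching causal mask.

    mask[i, j] is True iff patch j appears no later than patch i in the
    sampled order, i.e. patch i may attend to patch j.
    """
    perm = torch.randperm(num_patches)       # stochastic factorization order
    rank = torch.empty_like(perm)
    rank[perm] = torch.arange(num_patches)   # rank[p] = position of patch p
    return rank.unsqueeze(0) <= rank.unsqueeze(1)

mask = stochastic_causal_mask(196)  # 14x14 patches for a 224x224 ViT-B/16
# Resampling the permutation at every step is what makes the modeling "stochastic".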

Main Results on ImageNet-1K

The following table provides pretrained checkpoints and logs used in the paper.

| Model | Pretrained checkpoint | Logs |
| --- | --- | --- |
| SAIM-Base | download | download |

The results of finetuning and linear probing on ImageNet-1K are as follows:

| Model | Architecture | Pretrain Epochs | FT acc@1 (%) | LIN acc@1 (%) | FT logs/weights | LIN logs/weights |
| --- | --- | --- | --- | --- | --- | --- |
| BEiT | ViT-B | 800 | 83.2 | 37.6 | - | - |
| MAE | ViT-B | 1600 | 83.6 | 67.8 | - | - |
| SimMIM | ViT-B | 1600 | 83.8 | 56.7 | - | - |
| iGPT | iGPT-L | - | 72.6 | 65.2 | - | - |
| ViT-iGPT | ViT-B | 300 | 82.7 | 20.4 | - | - |
| SAIM | ViT-B | 300 | 83.6 | 58.5 | - | - |
| SAIM | ViT-B | 800 | 83.9 | 62.5 | log/weight | log/weight |

Getting Started

Install

git clone https://github.com/qiy20/SAIM
cd SAIM
conda create -n saim python=3.9
conda activate saim
conda install pytorch torchvision torchaudio pytorch-cuda=11.6 -c pytorch -c nvidia
pip install timm==0.4.5
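
After installation, a quick sanity check (a generic snippet, not part of the repository) confirms the versions and GPU visibility:

import torch
import torchvision
import timm

print("torch:", torch.__version__)
print("torchvision:", torchvision.__version__)
print("timm:", timm.__version__)              # should be 0.4.5 as pinned above
print("CUDA available:", torch.cuda.is_available())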

Data preparation

You can download ImageNet-1K and prepare it in the following format:

imagenet
  ├── train
  │   ├── class1
  │   │   ├── img1.jpeg
  │   │   ├── img2.jpeg
  │   │   └── ...
  │   ├── class2
  │   │   ├── img3.jpeg
  │   │   └── ...
  │   └── ...
  └── val
      ├── class1
      │   ├── img4.jpeg
      │   ├── img5.jpeg
      │   └── ...
      ├── class2
      │   ├── img6.jpeg
      │   └── ...
      └── ...
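
This layout follows torchvision's ImageFolder convention, so the prepared data can be sanity-checked with a few lines (illustrative only; the training scripts do their own loading):

from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

train_set = datasets.ImageFolder("imagenet/train", transform=transform)
val_set = datasets.ImageFolder("imagenet/val", transform=transform)
print(len(train_set), "train images /", len(val_set), "val images,",
      len(train_set.classes), "classes")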

Pretrain

python -m torch.distributed.launch --nproc_per_node 32 main_pretrain.py \
    --batch_size 64 --epochs 800 \
    --model saim_base --query_depth 12 --prediction_head_type MLP \
    --gaussian_kernel_size 9 --gaussian_sigma 1 --norm_pix_loss \
    --blr 2e-4 --warmup_epochs 30 --weight_decay 0.5 \
    --data_path <imagenet-path> --output_dir <output-directory>
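
Note that --blr is a base learning rate; in MAE-style codebases, which this repository builds on, the actual learning rate is typically derived from the effective batch size via the linear scaling rule (an assumed convention here):

# Linear learning-rate scaling, as in MAE-style pretraining (assumed convention):
#   lr = blr * effective_batch_size / 256
nproc_per_node = 32   # GPUs in the launch command above
batch_size = 64       # per-GPU batch size
blr = 2e-4

effective_batch_size = nproc_per_node * batch_size   # 2048
lr = blr * effective_batch_size / 256                # 1.6e-3
print(f"effective batch size {effective_batch_size}, lr {lr:.1e}")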

Finetune

python -m torch.distributed.launch --nproc_per_node 32 main_finetune.py \
    --model vit_base_patch16 --cls_token --batch_size 32 \
    --blr 5e-4 --layer_decay 0.65 --epochs 100 --warmup_epochs 20 \
    --weight_decay 0.05 --drop_path 0.1 --reprob 0.25 --mixup 0.8 --cutmix 1.0 \
    --dist_eval --data_path <imagenet-path> \
    --finetune <pretrained-ckpt> --output_dir <output-directory>
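
To reuse the pretrained encoder outside the provided scripts, loading it into a timm ViT typically looks like the sketch below. The checkpoint filename and the "model" key are assumptions; adjust them to whatever the downloaded file actually contains:

import torch
import timm

ckpt = torch.load("saim_base_pretrain.pth", map_location="cpu")  # hypothetical filename
state_dict = ckpt.get("model", ckpt)  # the "model" key is an assumption

model = timm.create_model("vit_base_patch16_224", num_classes=1000)
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print(f"{len(missing)} missing keys, {len(unexpected)} unexpected keys")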

Linear Probing

python -m torch.distributed.launch --nproc_per_node 32 main_linprobe.py \
    --model vit_base_patch16 --cls_token --batch_size 64 \
    --blr 0.1 --epochs 90 --warmup_epochs 0 --weight_decay 0.0 \
    --dist_eval --data_path <imagenet-path> \
    --finetune <pretrained-ckpt> --output_dir <output-directory>
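
Conceptually, linear probing freezes the pretrained encoder and trains only a linear classifier on top of its features. A minimal sketch of that setup (illustrative, not the repository's script):

import torch
import torch.nn as nn
import timm

backbone = timm.create_model("vit_base_patch16_224", num_classes=0)  # pooled features only
for p in backbone.parameters():
    p.requires_grad = False  # the encoder stays frozen

head = nn.Linear(backbone.num_features, 1000)  # only this layer is trained
optimizer = torch.optim.SGD(head.parameters(), lr=0.1, momentum=0.9)

def train_step(images, labels):
    with torch.no_grad():
        feats = backbone(images)  # frozen forward pass
    loss = nn.functional.cross_entropy(head(feats), labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()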

Visualization

(Figure: attention map comparison between ViT-iGPT and SAIM)

We show example results on the ImageNet validation set. From left to right: (a) the original image, (b) the attention map of ViT-iGPT, (c) the attention map of SAIM. SAIM focuses on the main content of the image and learns human-like attention from unlabeled data.
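
Attention maps like these can be reproduced roughly by hooking the last transformer block of a timm ViT. The sketch below assumes timm's Attention module layout (a single qkv projection); it is not the script used for the paper's figures:

import torch
import timm

model = timm.create_model("vit_base_patch16_224").eval()

def cls_attention_map(model, images):
    """Return the head-averaged [CLS]-to-patch attention of the last block."""
    store = {}

    def hook(module, inputs, output):
        x = inputs[0]
        B, N, C = x.shape
        qkv = (module.qkv(x)
               .reshape(B, N, 3, module.num_heads, C // module.num_heads)
               .permute(2, 0, 3, 1, 4))
        q, k = qkv[0], qkv[1]
        store["attn"] = ((q @ k.transpose(-2, -1)) * module.scale).softmax(dim=-1)

    handle = model.blocks[-1].attn.register_forward_hook(hook)
    with torch.no_grad():
        model(images)
    handle.remove()
    # average heads, take the CLS row, drop the CLS column -> one value per patch
    cls_attn = store["attn"].mean(dim=1)[:, 0, 1:]
    return cls_attn.reshape(-1, 14, 14)  # 14x14 patch grid for 224x224 input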

Acknowledgement

The pretraining and finetuning of our project are based on DeiT, BEiT, and MAE.

License

SAIM is released under the MIT License.

Citation

@inproceedings{qi2023exploring,
  title={Exploring Stochastic Autoregressive Image Modeling for Visual Representation},
  author={Qi, Yu and Yang, Fan and Zhu, Yousong and Liu, Yufei and Wu, Liwei and Zhao, Rui and Li, Wei},
  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
  volume={37},
  number={2},
  pages={2074--2081},
  year={2023}
}