Good helper is around you: Attention-driven Masked Image Modeling

This repository contains a PyTorch implementation of AMT (Attention-driven Masking and Throwing strategy) with MAE and SimMIM.
For details, see Good helper is around you: Attention-driven Masked Image Modeling.

Preparation

Requirements

To configure the environment, run:

conda create -n AMT python=3.8 -y
conda activate AMT
conda install pytorch=1.11 torchvision cudatoolkit=11.3 -c pytorch -y
pip install -r requirements.txt

Points for Attention

The ViT Attention and Block classes need to be modified so that the attention map can be returned:

class Attention(nn.Module):
...
    return x, attn  # add attn
...
class Block(nn.Module):
...
    def forward(self, x, return_attention=False):  # add return_attention as a flag
        y, attn = self.attn(self.norm1(x))
        if return_attention:
            return attn
        x = x + self.drop_path(y)
        ...
...
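
For reference, the outline above can be fleshed out into a small self-contained module. This is only a minimal sketch, not the repository's code: the layer sizes, the MLP layout, and the omission of drop_path and dropout are assumptions made to keep the example short. It only illustrates how the attention map is threaded out of Attention and returned by Block when return_attention=True.

```python
import torch
import torch.nn as nn


class Attention(nn.Module):
    """Multi-head self-attention that returns the output and the attention map."""

    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)  # each: (B, heads, N, head_dim)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)  # (B, heads, N, N), rows sum to 1
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(x), attn  # add attn


class Block(nn.Module):
    """Simplified transformer block (assumption: no drop_path or dropout)."""

    def __init__(self, dim, num_heads=8, mlp_ratio=4.0):
        super().__init__()
        hidden = int(dim * mlp_ratio)
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, num_heads)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))

    def forward(self, x, return_attention=False):
        y, attn = self.attn(self.norm1(x))
        if return_attention:
            return attn  # skip the rest of the block, only the map is needed
        x = x + y
        x = x + self.mlp(self.norm2(x))
        return x


if __name__ == "__main__":
    block = Block(dim=32, num_heads=4)
    tokens = torch.randn(2, 5, 32)  # (batch, tokens, channels)
    print(block(tokens).shape)
    print(block(tokens, return_attention=True).shape)
```

Calling the block normally returns the transformed tokens, while passing return_attention=True returns a tensor of shape (batch, heads, tokens, tokens).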

Datasets

Please download and organize the datasets in this structure:

/path/to/imagenet/
  train/
    class1/
      img1.jpeg
      ...
    class2/
      img2.jpeg
      ...
  val/
    class1/
      img3.jpeg
      ...
    class2/
      img4.jpeg
      ...
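
Before launching a long run, it can help to confirm that the directory really follows the structure above. The snippet below is a minimal sketch using only the Python standard library; check_layout and summarize_split are hypothetical helper names, not part of this repository, and the check assumes that train/ and val/ are expected to contain the same class folders.

```python
from pathlib import Path


def summarize_split(root, split):
    """Return {class_name: number_of_files} for root/split/<class>/<files>."""
    split_dir = Path(root) / split
    if not split_dir.is_dir():
        raise FileNotFoundError(f"missing directory: {split_dir}")
    return {
        d.name: sum(1 for f in d.iterdir() if f.is_file())
        for d in sorted(split_dir.iterdir())
        if d.is_dir()
    }


def check_layout(root):
    """Check that train/ and val/ exist and hold the same class folders."""
    train = summarize_split(root, "train")
    val = summarize_split(root, "val")
    if set(train) != set(val):
        differing = sorted(set(train) ^ set(val))
        raise ValueError(f"class folders differ between train/ and val/: {differing}")
    return train, val


if __name__ == "__main__":
    import sys

    train_counts, val_counts = check_layout(sys.argv[1])
    print(f"{len(train_counts)} classes, "
          f"{sum(train_counts.values())} train files, "
          f"{sum(val_counts.values())} val files")
```

Run it with the dataset root as the only argument, for example the same path that is later passed as --data_path.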

Getting Started

This repo supports both the attention-driven masking strategy and the full AMT (masking and throwing) strategy; switch between them by setting mask_ratio and throw_ratio.
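
To make the role of the two ratios concrete, the sketch below counts how many patches would be masked, thrown, and kept visible for a given setting. It is an illustration under stated assumptions, not the repository's sampling code: it assumes both ratios are fractions of the total number of patches, that a throw_ratio of 0 corresponds to masking only, and that counts are truncated to integers. split_patch_counts is a hypothetical helper name, and the ratio values in the usage lines are illustrative rather than recommended settings.

```python
def split_patch_counts(num_patches, mask_ratio, throw_ratio):
    """Split a patch budget into masked, thrown, and visible counts.

    Assumption: both ratios are fractions of the total patch count.
    """
    if mask_ratio < 0 or throw_ratio < 0 or mask_ratio + throw_ratio >= 1:
        raise ValueError("ratios must be non-negative and sum to less than 1")
    num_masked = int(num_patches * mask_ratio)
    num_thrown = int(num_patches * throw_ratio)
    num_visible = num_patches - num_masked - num_thrown
    return {"masked": num_masked, "thrown": num_thrown, "visible": num_visible}


if __name__ == "__main__":
    # A 224x224 image with 16x16 patches gives 14 * 14 = 196 patches.
    num_patches = (224 // 16) ** 2
    # Masking only: nothing is thrown away.
    print(split_patch_counts(num_patches, mask_ratio=0.75, throw_ratio=0.0))
    # Masking and throwing (illustrative ratios).
    print(split_patch_counts(num_patches, mask_ratio=0.45, throw_ratio=0.40))
```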

We test on a 4-GPU server; accum_iter is used to maintain the same effective batch size as the original method. The following script is an example of AMT with MAE.

Pre-training with MAE

OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=4 main_pretrain.py \
    --model vit_base_patch16 \
    --batch_size 128 \
    --accum_iter 8 \
    --blr 1.5e-4 \
    --data_path ${IMAGENET_DIR}

Finetuning with MAE

OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=4 main_finetune.py \
    --accum_iter 8 \
    --batch_size 32 \
    --model vit_base_patch16 \
    --finetune ${PRETRAIN_CHKPT} \
    --epochs 100 \
    --blr 1e-3 --layer_decay 0.75 \
    --weight_decay 0.05 --drop_path 0.1 --mixup 0.8 --cutmix 1.0 --reprob 0.25 \
    --dist_eval --data_path ${IMAGENET_DIR}

Linprobing with MAE

OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=4 main_linprobe.py \
    --accum_iter 8 \
    --batch_size 512 \
    --model vit_base_patch16 --cls_token \
    --finetune ${PRETRAIN_CHKPT} \
    --epochs 90 \
    --blr 0.1 \
    --weight_decay 0.0 \
    --dist_eval --data_path ${IMAGENET_DIR}
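
The three commands above combine a per-GPU batch size, accum_iter, and 4 processes. As a quick check of how accum_iter keeps the effective batch size in line with the original method, the sketch below computes it for each command. It assumes the MAE convention that the effective batch size is batch_size * accum_iter * number of GPUs and that the absolute learning rate is blr * effective batch size / 256; whether this repository keeps exactly that rule is an assumption. Both function names are hypothetical helpers.

```python
def effective_batch_size(batch_size, accum_iter, num_gpus):
    """Per-GPU batch size times gradient accumulation steps times GPUs."""
    return batch_size * accum_iter * num_gpus


def absolute_lr(blr, eff_batch_size):
    """Linear scaling rule (assumed, following MAE): lr = blr * eff_batch / 256."""
    return blr * eff_batch_size / 256


if __name__ == "__main__":
    # (batch_size, accum_iter, blr) taken from the commands above, 4 GPUs each.
    settings = {
        "pretrain": (128, 8, 1.5e-4),
        "finetune": (32, 8, 1e-3),
        "linprobe": (512, 8, 0.1),
    }
    for name, (batch_size, accum_iter, blr) in settings.items():
        eff = effective_batch_size(batch_size, accum_iter, num_gpus=4)
        print(f"{name}: effective batch size {eff}, lr {absolute_lr(blr, eff):.4g}")
```

With these settings the effective batch sizes come out to 4096, 1024, and 16384 respectively. When running on a different number of GPUs, accum_iter can be adjusted so the product stays the same.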

This code was written by Zhengqi Liu. If you have questions or find bugs in the code, feel free to contact Zhengqi Liu.

ATTN: This package is free for academic use. Run it at your own risk. For other purposes, please contact Jie Gui (guijie@ustc.edu).