MS-MLP: Mixing and Shifting in Vision Transformers

This is the official implementation of MS-MLP from our paper "Mixing and Shifting: Exploiting Global and Local Dependencies in Vision MLPs", by Huangjie Zheng, Pengcheng He, Weizhu Chen, and Mingyuan Zhou.

*Figure: multi-scale regional mixing (teaser).*

The proposed mixing and shifting operation exploits both long-range and short-range dependencies without self-attention. In an MLP-based architecture, Mix-Shift-MLP (MS-MLP) grows the size of the local receptive field used for mixing with the relative distance achieved by the spatial shifting, which directly promotes interactions between neighboring and distant tokens.

Model Overview

*Figure: overview of an MS-block.*

In each MS-block, we first split the feature map into several groups along the channel dimension, with the first group serving as the source of query tokens. In the remaining groups, as the centers of the attended regions (marked with yellow stars) grow increasingly distant, we gradually enlarge the mixing spatial range from 1x1 to 7x7. After the mixing operation, we shift the split channel groups to align their mixed center tokens with the query, and then continue the channel-wise mixing with a channel MLP, as sketched below.
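To make the mechanism concrete, here is a minimal PyTorch sketch of the mix-then-shift idea. This is our illustration, not the official implementation: the module name `MixShiftSketch`, the use of depthwise convolutions for local mixing, and the single-axis `torch.roll` shift are simplifying assumptions (the actual block shifts different groups by different offsets and directions).

```python
import torch
import torch.nn as nn

class MixShiftSketch(nn.Module):
    """Illustrative mix-then-shift block (a sketch, not the official code)."""

    def __init__(self, dim, num_groups=4, shift_step=1):
        super().__init__()
        assert dim % num_groups == 0
        gdim = dim // num_groups
        # One depthwise conv per channel group; the mixing window grows
        # 1x1, 3x3, 5x5, 7x7 as the group index (and its shift) increases.
        self.mixers = nn.ModuleList(
            nn.Conv2d(gdim, gdim, kernel_size=2 * i + 1, padding=i, groups=gdim)
            for i in range(num_groups)
        )
        self.shift_step = shift_step
        self.num_groups = num_groups
        # Channel MLP that fuses the re-aligned groups.
        self.channel_mlp = nn.Sequential(
            nn.Conv2d(dim, dim, 1), nn.GELU(), nn.Conv2d(dim, dim, 1)
        )

    def forward(self, x):  # x: (B, C, H, W)
        out = []
        for i, (g, mix) in enumerate(zip(x.chunk(self.num_groups, dim=1), self.mixers)):
            g = mix(g)  # local mixing over a (2i+1) x (2i+1) window
            # Shift group i so its mixed center token lines up with the
            # query group's position (group 0, the query source, stays put).
            g = torch.roll(g, shifts=i * self.shift_step, dims=-1)
            out.append(g)
        return self.channel_mlp(torch.cat(out, dim=1))
```

For example, `MixShiftSketch(64)(torch.randn(2, 64, 14, 14))` returns a tensor of the same shape, with distant tokens mixed in through the shifted groups.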

Image Classification on ImageNet-1K

| Network | Resolution | Top-1 (%) | Params | FLOPs | Throughput (images/s) | Model |
|---------|------------|-----------|--------|-------|-----------------------|-------|
| MS-MLP-Tiny | 224x224 | 82.1 | 28M | 4.9G | 792.0 | download |
| MS-MLP-Small | 224x224 | 83.4 | 50M | 9.0G | 483.8 | download |
| MS-MLP-Base | 224x224 | 83.8 | 88M | 16.1G | 366.5 | download |

Getting Started

Install

```bash
# Clone the repository
git clone https://github.com/JegZheng/MS-MLP
cd MS-MLP

# Create and activate the conda environment
conda create -n msmlp python=3.7 -y
conda activate msmlp

# Install PyTorch 1.7.1 with CUDA 10.1 and matching dependencies
conda install pytorch==1.7.1 torchvision==0.8.2 cudatoolkit=10.1 -c pytorch
pip install timm==0.3.2
pip install cupy-cuda101

# Build and install NVIDIA Apex (for mixed-precision training)
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
cd ..  # return to the repository root

# Remaining Python dependencies
pip install opencv-python==4.4.0.46 termcolor==1.1.0 yacs==0.1.8
```
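After installation, a quick way to confirm the environment is usable (our suggestion, not an official step from the repo) is:

```python
# Sanity-check the installed environment (our suggestion, not an official step).
import torch, torchvision, timm

print("torch:", torch.__version__)              # expect 1.7.1
print("torchvision:", torchvision.__version__)  # expect 0.8.2
print("timm:", timm.__version__)                # expect 0.3.2
print("CUDA available:", torch.cuda.is_available())
```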

Data preparation

We use the standard ImageNet dataset, which you can download from http://image-net.org/. The dataloader expects the usual folder layout; a typical organization is sketched below.
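As a hedged example: codebases derived from Swin-Transformer (which this repo builds on, per the Acknowledgement) typically read ImageNet through the standard torchvision `ImageFolder` convention, with one subfolder per class. The folder and file names below are illustrative; point `<imagenet-path>` at the root:

```
imagenet/
├── train/
│   ├── n01440764/
│   │   ├── n01440764_10026.JPEG
│   │   └── ...
│   └── ...
└── val/
    ├── n01440764/
    │   └── ...
    └── ...
```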

Evaluation

To evaluate a pre-trained MS-MLP on ImageNet val, run:

```bash
python -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> --master_port 12345 main.py --eval \
--cfg <config-file> --resume <checkpoint> --data-path <imagenet-path>
```

For example, to evaluate the MS-MLP-Tiny with a single GPU:

```bash
python -m torch.distributed.launch --nproc_per_node 1 --nnodes=1 --master_port 12345 main.py --eval \
--cfg configs/msmlp_tiny_patch4_shift5_224.yaml --resume <msmlp-tiny.pth> --data-path <imagenet-path>
```

Training from scratch

To train an MS-MLP on ImageNet from scratch, run:

```bash
python -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> --master_port 12345 main.py \
--cfg <config-file> --data-path <imagenet-path> [--batch-size <batch-size-per-gpu> --output <output-directory> --tag <job-tag>]
```

Notes:

- The effective batch size is `<batch-size-per-gpu>` x number of GPUs x `<accumulation-steps>` when gradient accumulation is enabled; in the Swin-style convention this codebase inherits, `--accumulation-steps 0` disables accumulation. The 300-epoch examples below therefore all use an effective batch size of 1024 (128 x 8 for Tiny/Small; 64 x 8 x 2 for Base).

For example, to train MS-MLP with 8 GPUs on a single node for 300 epochs, run:

MS-MLP-Tiny:

```bash
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 main.py \
--cfg configs/msmlp_tiny_patch4_shift5_224.yaml --data-path <imagenet-path> --batch-size 128 --cache-mode no \
--accumulation-steps 0 --output <output-path>
```

MS-MLP-Small:

```bash
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 main.py \
--cfg configs/msmlp_small_patch4_shift5_224.yaml --data-path <imagenet-path> --batch-size 128 --cache-mode no \
--accumulation-steps 0 --output <output-path>
```

MS-MLP-Base:

```bash
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 main.py \
--cfg configs/msmlp_base_patch4_shift5_224.yaml --data-path <imagenet-path> --batch-size 64 --cache-mode no \
--accumulation-steps 2 --output <output-path>
```

For multi-node training, please add the --node_rank, --master_addr, and --master_port options. For example:

```bash
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=$RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT main.py \
--cfg configs/msmlp_base_patch4_shift5_224.yaml --data-path <imagenet-path> --batch-size 64 --cache-mode no \
--accumulation-steps 0 --output <output-path>
```

Throughput

To measure the throughput, run:

```bash
python -m torch.distributed.launch --nproc_per_node 1 --master_port 12345 main.py \
--cfg <config-file> --data-path <imagenet-path> --batch-size 64 --throughput --amp-opt-level O0
```
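Throughput here means images processed per second at inference. For intuition, a minimal sketch of how such a number is typically obtained (our illustration, not the logic of the repo's main.py) is:

```python
import time
import torch

@torch.no_grad()
def measure_throughput(model, batch_size=64, img_size=224, iters=30):
    """Rough images/second estimate on GPU (a sketch, not the repo's code)."""
    model = model.eval().cuda()
    x = torch.randn(batch_size, 3, img_size, img_size, device="cuda")
    for _ in range(10):            # warm-up iterations
        model(x)
    torch.cuda.synchronize()       # wait for all queued kernels to finish
    start = time.time()
    for _ in range(iters):
        model(x)
    torch.cuda.synchronize()
    return iters * batch_size / (time.time() - start)
```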

Citation

If you find this repo useful for your project, please consider citing it with the following BibTeX entry:

```bibtex
@misc{zheng2022mixing,
  title={Mixing and Shifting: Exploiting Global and Local Dependencies in Vision MLPs},
  author={Huangjie Zheng and Pengcheng He and Weizhu Chen and Mingyuan Zhou},
  year={2022},
  eprint={2202.06510},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}
```

Acknowledgement

Our codebase is built upon Swin-Transformer, AS-MLP, and Focal-Transformer. We thank the authors for their nicely organized code!