PlainMamba: Improving Non-hierarchical Mamba in Visual Recognition

<p align="center"> <img src="resources/plainmamba_teaser.png"/> </p>

This repository contains the official PyTorch implementation of our paper:

PlainMamba: Improving Non-hierarchical Mamba in Visual Recognition, Chenhongyi Yang*, Zehui Chen*, Miguel Espinosa*, Linus Ericsson, Zhenyu Wang, Jiaming Liu, Elliot J. Crowley, BMVC 2024

Usage

Environment Setup

Our classification codebase is built upon the MMClassification toolkit (old version).

conda create -n plain_mamba python=3.10 -y
source activate plain_mamba
pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 -f https://download.pytorch.org/whl/torch_stable.html --no-cache
conda install -c conda-forge cudatoolkit-dev # Optional: only needed if you hit CUDA errors
pip install -U openmim
mim install mmcv-full
pip install mamba-ssm
pip install mlflow fvcore timm lmdb
cd plain_mamba
pip install -e .

cd downstream/mmdetection  # set up object detection and instance segmentation
pip install -e .
cd ../mmsegmentation  # set up semantic segmentation
pip install -e .
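
Before launching a run, it can help to confirm that the packages installed above are visible to the interpreter. The following is a minimal sketch and not part of this repository: the helper name `installed_versions` is hypothetical, and the distribution names are assumed to match the package names used in the install commands above.

```python
from importlib.metadata import PackageNotFoundError, version

# Distribution names taken from the install commands above (assumed to match
# what pip/mim register; adjust if your environment names them differently).
REQUIRED = ["torch", "torchvision", "mmcv-full", "mamba-ssm",
            "mlflow", "fvcore", "timm", "lmdb"]


def installed_versions(names):
    """Map each distribution name to its installed version, or None if missing."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions(REQUIRED).items():
        print(f"{name:12s} {ver if ver is not None else 'NOT INSTALLED'}")
```

Run it inside the activated `plain_mamba` environment; any line reporting `NOT INSTALLED` points to an install step that needs repeating.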

Data Preparation

For the ImageNet experiments, we convert the dataset to LMDB format for efficient data loading. Convert the dataset by running:

python tools/dataset_tools/create_lmdb_dataset.py \
       --train-img-dir data/imagenet/train \
       --train-out data/imagenet/imagenet_lmdb/train \
       --val-img-dir data/imagenet/val \
       --val-out data/imagenet/imagenet_lmdb/val

You will also need to download the ImageNet metadata from Link.

For downstream tasks, please follow the MMDetection and MMSegmentation instructions to set up your datasets.

After setup, the dataset file structure should look like this:

PlainMamba
|-- ...
|-- data
|   |__ imagenet
|       |-- imagenet_lmdb
|       |   |-- train
|       |   |   |-- data.mdb
|       |   |   |__ lock.mdb
|       |   |-- val
|       |   |   |-- data.mdb
|       |   |   |__ lock.mdb 
|       |__ meta
|           |__ ...
|__ downstream 
    |-- mmsegmentation
    |   |-- ...
    |   |__ data
    |       |__ ade
    |           |__ ADEChallengeData2016
    |               |-- annotations
    |               |   |__ ...
    |               |-- images
    |               |   |__ ...
    |               |-- objectInfo150.txt
    |               |__ sceneCategories.txt
    |   
    |__ mmdetection
        |-- ...
        |__ data
            |__ coco
                |-- train2017
                |   |__ ...
                |-- val2017
                |   |__ ...
                |__ annotations
                    |-- instances_train2017.json
                    |-- instances_val2017.json
                    |__ ...
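
The layout above can be checked programmatically before training. This is a minimal sketch, not a script shipped with the repository: the helper name `missing_paths` is hypothetical, and the path list only covers entries spelled out in the tree (elided `...` entries are not checked).

```python
from pathlib import Path

# Paths copied from the tree above, relative to the repository root.
EXPECTED = [
    "data/imagenet/imagenet_lmdb/train/data.mdb",
    "data/imagenet/imagenet_lmdb/train/lock.mdb",
    "data/imagenet/imagenet_lmdb/val/data.mdb",
    "data/imagenet/imagenet_lmdb/val/lock.mdb",
    "data/imagenet/meta",
    "downstream/mmsegmentation/data/ade/ADEChallengeData2016/annotations",
    "downstream/mmsegmentation/data/ade/ADEChallengeData2016/images",
    "downstream/mmdetection/data/coco/train2017",
    "downstream/mmdetection/data/coco/val2017",
    "downstream/mmdetection/data/coco/annotations/instances_train2017.json",
    "downstream/mmdetection/data/coco/annotations/instances_val2017.json",
]


def missing_paths(root, expected=EXPECTED):
    """Return the expected paths that do not exist under `root`."""
    root = Path(root)
    return [p for p in expected if not (root / p).exists()]


if __name__ == "__main__":
    # Run from the repository root; prints nothing when the layout is complete.
    for path in missing_paths("."):
        print("missing:", path)
```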
      

ImageNet Classification

Training PlainMamba

# Example: Training PlainMamba-L1 model
zsh tools/dist_train.sh plain_mamba_configs/plain_mamba_l1_in1k_300e.py 8

Testing PlainMamba

# Example: Testing PlainMamba-L1 model
zsh tools/dist_test.sh plain_mamba_configs/plain_mamba_l1_in1k_300e.py work_dirs/plain_mamba_l1_in1k_300e/epoch_300.pth 8 --metrics accuracy
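
The checkpoint path in the test command follows from the config name. The sketch below shows that mapping, assuming training writes to the OpenMMLab default location `work_dirs/<config file name without extension>` and saves epoch-based checkpoints as `epoch_<N>.pth`, which is consistent with the example command above. The helper `default_checkpoint` is hypothetical and not part of the repository.

```python
from pathlib import Path


def default_checkpoint(config_path, epoch, work_root="work_dirs"):
    """Build the checkpoint path that pairs with a given config file."""
    stem = Path(config_path).stem  # config file name without ".py"
    return (Path(work_root) / stem / f"epoch_{epoch}.pth").as_posix()


print(default_checkpoint("plain_mamba_configs/plain_mamba_l1_in1k_300e.py", 300))
# work_dirs/plain_mamba_l1_in1k_300e/epoch_300.pth
```

If you trained with a custom work directory, pass it as `work_root` or point the test script at the checkpoint directly.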

COCO Object Detection and Instance Segmentation

Run cd downstream/mmdetection first.

Training Mask R-CNN using PlainMamba-Adapter

# Example: Training PlainMamba-Adapter-L1 Mask R-CNN with 1x schedule
zsh tools/dist_train.sh plain_mamba_det_configs/maskrcnn/l1_maskrcnn_1x.py 8

Training RetinaNet using PlainMamba-Adapter

# Example: Training PlainMamba-Adapter-L1 RetinaNet with 1x schedule
zsh tools/dist_train.sh plain_mamba_det_configs/retinanet/l1_retinanet_1x.py 8

Testing Mask R-CNN

# Example: Testing PlainMamba-Adapter-L1 Mask R-CNN 1x model
zsh tools/dist_test.sh plain_mamba_det_configs/maskrcnn/l1_maskrcnn_1x.py work_dirs/l1_maskrcnn_1x/epoch_12.pth 8 --eval bbox segm

Testing RetinaNet

# Example: Testing PlainMamba-Adapter-L1 RetinaNet 1x model
zsh tools/dist_test.sh plain_mamba_det_configs/retinanet/l1_retinanet_1x.py work_dirs/l1_retinanet_1x/epoch_12.pth 8 --eval bbox

ADE20K Semantic Segmentation

Run cd downstream/mmsegmentation first.

Training UperNet using PlainMamba

# Example: Training PlainMamba-L1 based UperNet
zsh tools/dist_train.sh plain_mamba_seg_configs/l1_upernet.py 8

Testing UperNet

# Example: Testing PlainMamba-L1 based UperNet
zsh tools/dist_test.sh plain_mamba_seg_configs/l1_upernet.py work_dirs/l1_upernet/iter_160000.pth 8 --eval mIoU

Benchmark results

ImageNet-1k Classification

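| Model | #Params (M) | Top-1 Acc | Top-5 Acc | Config | Model |
| --- | --- | --- | --- | --- | --- |
| PlainMamba-L1 | 7.3 | 77.9 | 94.0 | Link | Link |
| PlainMamba-L2 | 25.7 | 81.6 | 95.6 | Link | Link |
| PlainMamba-L3 | 50.5 | 82.3 | 95.9 | Link | Link |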

COCO Mask R-CNN 1x Schedule

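| Model | #Params (M) | AP Box | AP Mask | Config | Model |
| --- | --- | --- | --- | --- | --- |
| PlainMamba-Adapter-L1 | 31 | 44.1 | 39.1 | Link | Link |
| PlainMamba-Adapter-L2 | 53 | 46.0 | 40.6 | Link | Link |
| PlainMamba-Adapter-L3 | 79 | 46.8 | 41.2 | Link | Link |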

COCO RetinaNet 1x Schedule

| Model | #Params (M) | AP Box | Config | Model |
| --- | --- | --- | --- | --- |
| PlainMamba-Adapter-L1 | 19 | 41.7 | Link | Link |
| PlainMamba-Adapter-L2 | 40 | 43.9 | Link | Link |
| PlainMamba-Adapter-L3 | 67 | 44.8 | Link | Link |

ADE20K UperNet

| Model | #Params (M) | mIoU | Config | Model |
| --- | --- | --- | --- | --- |
| PlainMamba-L1 | 35 | 44.1 | Link | Link |
| PlainMamba-L2 | 55 | 46.8 | Link | Link |
| PlainMamba-L3 | 81 | 49.1 | Link | Link |

Citation

@misc{yang2024plainmamba,
      title={PlainMamba: Improving Non-Hierarchical Mamba in Visual Recognition}, 
      author={Chenhongyi Yang and Zehui Chen and Miguel Espinosa and Linus Ericsson and Zhenyu Wang and Jiaming Liu and Elliot J. Crowley},
      year={2024},
      eprint={2403.17695},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}