PlainMamba: Improving Non-hierarchical Mamba in Visual Recognition
<p align="center"> <img src="resources/plainmamba_teaser.png"/> </p>

This repository contains the official PyTorch implementation of our paper, PlainMamba: Improving Non-Hierarchical Mamba in Visual Recognition.
Usage
Environment Setup
Our classification codebase is built upon an older version of the MMClassification toolkit.
conda create -n plain_mamba python=3.10 -y
source activate plain_mamba
pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 -f https://download.pytorch.org/whl/torch_stable.html --no-cache
conda install -c conda-forge cudatoolkit-dev # Optional, only needed when facing cuda errors
pip install -U openmim
mim install mmcv-full
pip install mamba-ssm
pip install mlflow fvcore timm lmdb
cd plain_mamba
pip install -e .
cd downstream/mmdetection # set up object detection and instance segmentation
pip install -e .
cd ../mmsegmentation # set up semantic segmentation (path is relative to downstream/mmdetection from the previous step)
pip install -e .
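Once the installation finishes, a quick import check can catch a missing package before a long training run starts. The snippet below is a minimal sketch and not part of the repository. The import names are assumptions based on the packages installed above (for example, that mmcv-full imports as mmcv and mamba-ssm as mamba_ssm).

```python
import importlib.util

# Import names assumed from the pip/mim packages installed above.
REQUIRED_MODULES = [
    "torch", "torchvision", "mmcv", "mamba_ssm",
    "mlflow", "fvcore", "timm", "lmdb",
]


def find_missing(modules):
    """Return the top-level module names that cannot be located."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


missing = find_missing(REQUIRED_MODULES)
print("Missing packages:", ", ".join(missing) if missing else "none")

# Only query CUDA when PyTorch is actually installed.
if "torch" not in missing:
    import torch
    print("PyTorch", torch.__version__, "| CUDA available:", torch.cuda.is_available())
```

If a package is reported missing, re-run the corresponding install command above inside the plain_mamba environment.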
Data Preparation
For the ImageNet experiments, we convert the dataset to LMDB format for efficient data loading. You can convert the dataset by running:
python tools/dataset_tools/create_lmdb_dataset.py \
--train-img-dir data/imagenet/train \
--train-out data/imagenet/imagenet_lmdb/train \
--val-img-dir data/imagenet/val \
--val-out data/imagenet/imagenet_lmdb/val
You will also need to download the ImageNet metadata from Link.
For downstream tasks, please follow the MMDetection and MMSegmentation instructions to set up your datasets.
After setup, the dataset file structure should look like the following:
PlainMamba
|-- ...
|-- data
|   |__ imagenet
|       |-- imagenet_lmdb
|       |   |-- train
|       |   |   |-- data.mdb
|       |   |   |__ lock.mdb
|       |   |-- val
|       |   |   |-- data.mdb
|       |   |   |__ lock.mdb
|       |__ meta
|           |__ ...
|__ downstream
    |-- mmsegmentation
    |   |-- ...
    |   |__ data
    |       |__ ade
    |           |__ ADEChallengeData2016
    |               |-- annotations
    |               |   |__ ...
    |               |-- images
    |               |   |__ ...
    |               |-- objectInfo150.txt
    |               |__ sceneCategories.txt
    |
    |__ mmdetection
        |-- ...
        |__ data
            |__ coco
                |-- train2017
                |   |__ ...
                |-- val2017
                |   |__ ...
                |__ annotations
                    |-- instances_train2017.json
                    |-- instances_val2017.json
                    |__ ...
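To confirm that the datasets are where the configs expect them, the layout above can be checked with a short script. This is a hedged sketch rather than a tool shipped with the repository: the path list is transcribed from the tree above (with meta assumed to sit under data/imagenet), and running it from the repository root is an assumption.

```python
from pathlib import Path

# Entries transcribed from the directory tree above; "..." entries are skipped.
EXPECTED_PATHS = [
    "data/imagenet/imagenet_lmdb/train/data.mdb",
    "data/imagenet/imagenet_lmdb/train/lock.mdb",
    "data/imagenet/imagenet_lmdb/val/data.mdb",
    "data/imagenet/imagenet_lmdb/val/lock.mdb",
    "data/imagenet/meta",
    "downstream/mmsegmentation/data/ade/ADEChallengeData2016/annotations",
    "downstream/mmsegmentation/data/ade/ADEChallengeData2016/images",
    "downstream/mmdetection/data/coco/train2017",
    "downstream/mmdetection/data/coco/val2017",
    "downstream/mmdetection/data/coco/annotations/instances_train2017.json",
    "downstream/mmdetection/data/coco/annotations/instances_val2017.json",
]


def missing_paths(root, expected=EXPECTED_PATHS):
    """Return the expected entries that do not exist under `root`."""
    root = Path(root)
    return [rel for rel in expected if not (root / rel).exists()]


# Assumes the current working directory is the repository root.
for rel in missing_paths("."):
    print("missing:", rel)
```

The script prints nothing when every listed entry is present. Datasets you do not plan to use can simply be removed from the list.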
ImageNet Classification
Training PlainMamba
# Example: Training the PlainMamba-L1 model (the trailing 8 is presumably the GPU count, as in OpenMMLab's dist_train.sh)
zsh tool/dist_train.sh plain_mamba_configs/plain_mamba_l1_in1k_300e.py 8
Testing PlainMamba
# Example: Testing PlainMamba-L1 model
zsh tool/dist_test.sh plain_mamba_configs/plain_mamba_l1_in1k_300e.py work_dirs/plain_mamba_l1_in1k_300e/epoch_300.pth 8 --metrics accuracy
COCO Object Detection and Instance Segmentation
Run `cd downstream/mmdetection` first.
Training Mask R-CNN using PlainMamba-Adapter
# Example: Training PlainMamba-Adapter-L1 Mask R-CNN with 1x schedule
zsh tools/dist_train.sh plain_mamba_det_configs/maskrcnn/l1_maskrcnn_1x.py 8
Training RetinaNet using PlainMamba-Adapter
# Example: Training PlainMamba-Adapter-L1 RetinaNet with 1x schedule
zsh tools/dist_train.sh plain_mamba_det_configs/retinanet/l1_retinanet_1x.py 8
Testing Mask R-CNN
# Example: Testing PlainMamba-Adapter-L1 Mask R-CNN 1x model
zsh tools/dist_test.sh plain_mamba_det_configs/maskrcnn/l1_maskrcnn_1x.py work_dirs/l1_maskrcnn_1x/epoch_12.pth 8 --eval bbox segm
Testing RetinaNet
# Example: Testing PlainMamba-Adapter-L1 RetinaNet 1x model
zsh tools/dist_test.sh plain_mamba_det_configs/retinanet/l1_retinanet_1x.py work_dirs/l1_retinanet_1x/epoch_12.pth 8 --eval bbox
ADE20K Semantic Segmentation
Run `cd downstream/mmsegmentation` first.
Training UperNet using PlainMamba
# Example: Training PlainMamba-L1 based UperNet
zsh tools/dist_train.sh plain_mamba_seg_configs/l1_upernet.py 8
Testing UperNet
# Example: Testing PlainMamba-L1 based UperNet
zsh tools/dist_test.sh plain_mamba_seg_configs/l1_upernet.py work_dirs/l1_upernet/iter_160000.pth 8 --eval mIoU
Benchmark results
ImageNet-1k Classification
Model | #Params (M) | Top-1 Acc (%) | Top-5 Acc (%) | Config | Checkpoint |
---|---|---|---|---|---|
PlainMamba-L1 | 7.3 | 77.9 | 94.0 | Link | Link |
PlainMamba-L2 | 25.7 | 81.6 | 95.6 | Link | Link |
PlainMamba-L3 | 50.5 | 82.3 | 95.9 | Link | Link |
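For readers unfamiliar with the Top-1 and Top-5 columns, the sketch below shows how top-k accuracy is conventionally computed from per-class scores. It is a plain-Python illustration on made-up toy scores and is not the evaluation code used by this repository or by MMClassification.

```python
def topk_accuracy(scores, labels, k=1):
    """Percentage of samples whose true label is among the k highest scores."""
    hits = 0
    for row, label in zip(scores, labels):
        # Class indices sorted by descending score, truncated to the top k.
        topk = sorted(range(len(row)), key=lambda c: row[c], reverse=True)[:k]
        hits += label in topk
    return 100.0 * hits / len(labels)


# Toy scores for 4 samples over 3 classes (illustrative values only).
scores = [
    [0.10, 0.70, 0.20],
    [0.50, 0.30, 0.20],
    [0.20, 0.30, 0.50],
    [0.40, 0.35, 0.25],
]
labels = [1, 1, 2, 2]

print(topk_accuracy(scores, labels, k=1))  # 50.0
print(topk_accuracy(scores, labels, k=2))  # 75.0
```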
COCO Mask R-CNN 1x Schedule
Model | #Params (M) | AP Box | AP Mask | Config | Checkpoint |
---|---|---|---|---|---|
PlainMamba-Adapter-L1 | 31 | 44.1 | 39.1 | Link | Link |
PlainMamba-Adapter-L2 | 53 | 46.0 | 40.6 | Link | Link |
PlainMamba-Adapter-L3 | 79 | 46.8 | 41.2 | Link | Link |
COCO RetinaNet 1x Schedule
Model | #Params (M) | AP Box | Config | Checkpoint |
---|---|---|---|---|
PlainMamba-Adapter-L1 | 19 | 41.7 | Link | Link |
PlainMamba-Adapter-L2 | 40 | 43.9 | Link | Link |
PlainMamba-Adapter-L3 | 67 | 44.8 | Link | Link |
ADE20K UperNet
Model | #Params (M) | mIoU (%) | Config | Checkpoint |
---|---|---|---|---|
PlainMamba-L1 | 35 | 44.1 | Link | Link |
PlainMamba-L2 | 55 | 46.8 | Link | Link |
PlainMamba-L3 | 81 | 49.1 | Link | Link |
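As a reminder of what the mIoU column measures, here is a minimal plain-Python sketch of mean intersection-over-union on flattened label lists. It is an illustration on toy data only: the real MMSegmentation evaluation accumulates statistics over the whole dataset and handles ignored labels, both of which this sketch omits.

```python
def mean_iou(pred, target, num_classes):
    """Mean IoU (in %) over the classes that occur in pred or target."""
    ious = []
    for c in range(num_classes):
        inter = sum(p == c and t == c for p, t in zip(pred, target))
        union = sum(p == c or t == c for p, t in zip(pred, target))
        if union:  # skip classes absent from both prediction and ground truth
            ious.append(inter / union)
    return 100.0 * sum(ious) / len(ious)


# Toy per-pixel labels for 6 pixels and 3 classes (illustrative values only).
pred = [0, 0, 1, 1, 2, 2]
target = [0, 1, 1, 1, 2, 0]

# Per-class IoU: class 0 = 1/3, class 1 = 2/3, class 2 = 1/2, so the mean is 50%.
print(mean_iou(pred, target, num_classes=3))
```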
Citation
@misc{yang2024plainmamba,
title={PlainMamba: Improving Non-Hierarchical Mamba in Visual Recognition},
author={Chenhongyi Yang and Zehui Chen and Miguel Espinosa and Linus Ericsson and Zhenyu Wang and Jiaming Liu and Elliot J. Crowley},
year={2024},
eprint={2403.17695},
archivePrefix={arXiv},
primaryClass={cs.CV}
}