<p align="center"> <img src="assets/mobilemamba_logo.png" width="600px" /> </p>Official PyTorch implementation of "MobileMamba: Lightweight Multi-Receptive Visual Mamba Network".
Haoyang He<sup>1*</sup>, Jiangning Zhang<sup>2*</sup>, Yuxuan Cai<sup>3</sup>, Hongxu Chen<sup>1</sup>, Xiaobin Hu<sup>2</sup>,
Zhenye Gan<sup>2</sup>, Yabiao Wang<sup>2</sup>, Chengjie Wang<sup>2</sup>, Yunsheng Wu<sup>2</sup>, Lei Xie<sup>1†</sup>
<sup>1</sup>College of Control Science and Engineering, Zhejiang University, <sup>2</sup>Youtu Lab, Tencent, <sup>3</sup>Huazhong University of Science and Technology
<div align="center"> <img src="assets/motivation.png" width="800px" /> </div>Abstract: Previous research on lightweight models has primarily focused on CNNs and Transformer-based designs. CNNs, with their local receptive fields, struggle to capture long-range dependencies, while Transformers, despite their global modeling capabilities, are limited by quadratic computational complexity in high-resolution scenarios. Recently, state-space models have gained popularity in the visual domain due to their linear computational complexity. Despite their low FLOPs, however, current lightweight Mamba-based models exhibit suboptimal throughput. In this work, we propose the MobileMamba framework, which balances efficiency and performance. We design a three-stage network that significantly improves inference speed. At a fine-grained level, we introduce the Multi-Receptive Field Feature Interaction (MRFFI) module, comprising the Long-Range Wavelet Transform-Enhanced Mamba (WTE-Mamba), Efficient Multi-Kernel Depthwise Convolution (MK-DeConv), and Eliminate Redundant Identity components. This module integrates multi-receptive field information and enhances high-frequency detail extraction. Additionally, we employ training and testing strategies to further improve performance and efficiency. MobileMamba achieves up to 83.6% Top-1 accuracy on ImageNet-1K, surpassing existing state-of-the-art methods while running up to 21× faster than LocalVim on GPU. Extensive experiments on high-resolution downstream tasks demonstrate that MobileMamba surpasses current efficient models, achieving an optimal balance between speed and accuracy.
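The abstract states that WTE-Mamba uses a wavelet transform to enhance high-frequency detail extraction. The snippet below is not the MobileMamba code. It is a generic single-level 2D Haar transform, and the function name `haar_dwt2d` is hypothetical. Because the wavelet family and number of levels used by WTE-Mamba are not given here, this is an assumption made purely for illustration. It shows how a wavelet splits a map into one low-frequency band and three high-frequency bands, each at half resolution.

```python
def haar_dwt2d(img):
    """Single-level orthonormal 2D Haar transform of an even-sized grayscale image.

    img is a list of equal-length rows. Returns four half-resolution sub-bands:
    LL (low-pass approximation) plus three high-frequency detail bands.
    Naming conventions for LH/HL differ between libraries.
    """
    h, w = len(img), len(img[0])
    if h % 2 or w % 2:
        raise ValueError("height and width must both be even")
    bands = {"LL": [], "LH": [], "HL": [], "HH": []}
    for i in range(0, h, 2):
        rows = {k: [] for k in bands}
        for j in range(0, w, 2):
            a, b = img[i][j], img[i][j + 1]          # top-left, top-right
            c, d = img[i + 1][j], img[i + 1][j + 1]  # bottom-left, bottom-right
            rows["LL"].append((a + b + c + d) / 2)   # scaled local average
            rows["LH"].append((a + b - c - d) / 2)   # top vs bottom: horizontal edges
            rows["HL"].append((a - b + c - d) / 2)   # left vs right: vertical edges
            rows["HH"].append((a - b - c + d) / 2)   # diagonal detail
        for k in bands:
            bands[k].append(rows[k])
    return bands


# A vertical step edge: all detail energy lands in the HL band.
edge = [[0, 4, 4, 4],
        [0, 4, 4, 4]]
print(haar_dwt2d(edge))
```

Because the transform is orthonormal, the total energy (sum of squares) of the four bands equals that of the input. The edge shows up only in a high-frequency band, which is the kind of detail a wavelet branch can expose to later layers.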
<div align="center"> <img src="assets/comparewithmamba.png" width="600px" /> </div>Top: Visualization of the Effective Receptive Fields (ERF) for different architectures. Bottom: Performance vs. FLOPs with recent CNN/Transformer/Mamba-based methods.<br>
Accuracy vs. Speed with Mamba-based methods.
Classification Results
Image Classification for ImageNet-1K:
Model | FLOPs | #Params | Resolution | Top-1 (%) | Cfg | Log | Model |
---|---|---|---|---|---|---|---|
MobileMamba-T2 | 255M | 8.8M | 192 x 192 | 71.5 | cfg | log | model |
MobileMamba-T2† | 255M | 8.8M | 192 x 192 | 76.9 | cfg | log | model |
MobileMamba-T4 | 413M | 14.2M | 192 x 192 | 76.1 | cfg | log | model |
MobileMamba-T4† | 413M | 14.2M | 192 x 192 | 78.9 | cfg | log | model |
MobileMamba-S6 | 652M | 15.0M | 224 x 224 | 78.0 | cfg | log | model |
MobileMamba-S6† | 652M | 15.0M | 224 x 224 | 80.7 | cfg | log | model |
MobileMamba-B1 | 1080M | 17.1M | 256 x 256 | 79.9 | cfg | log | model |
MobileMamba-B1† | 1080M | 17.1M | 256 x 256 | 82.2 | cfg | log | model |
MobileMamba-B2 | 2427M | 17.1M | 384 x 384 | 81.6 | cfg | log | model |
MobileMamba-B2† | 2427M | 17.1M | 384 x 384 | 83.3 | cfg | log | model |
MobileMamba-B4 | 4313M | 17.1M | 512 x 512 | 82.5 | cfg | log | model |
MobileMamba-B4† | 4313M | 17.1M | 512 x 512 | 83.6 | cfg | log | model |
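MobileMamba-B1, B2 and B4 all list 17.1M parameters and differ in input resolution (256, 384 and 512). The table does not say whether they are the same network evaluated at different resolutions, but the identical parameter counts suggest it. Under that assumption, their FLOPs should grow roughly with the number of input pixels, that is, with the square of the side length. A quick check against the table (the helper name is hypothetical):

```python
# FLOPs (in millions) for the B-series rows of the ImageNet-1K table, keyed by input side length.
B_SERIES_FLOPS_M = {256: 1080, 384: 2427, 512: 4313}


def quadratic_prediction(base_res, base_flops, new_res):
    """Scale FLOPs by the ratio of pixel counts, i.e. (new_res / base_res) ** 2."""
    return base_flops * (new_res / base_res) ** 2


for res, table_flops in B_SERIES_FLOPS_M.items():
    predicted = quadratic_prediction(256, B_SERIES_FLOPS_M[256], res)
    gap = (predicted - table_flops) / table_flops * 100
    print(f"{res}x{res}: predicted {predicted:.0f}M, table {table_flops}M ({gap:+.2f}%)")
```

The prediction (2430M and 4320M) overshoots the table by only about 0.1 to 0.2 percent. A small resolution-independent term, such as a classifier head, would produce a gap of this kind. That is an inference, not something reported here.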
Downstream Results
Object Detection and Instance Segmentation Results
Object Detection and Instance Segmentation Performance Based on Mask-RCNN for COCO2017:
Backbone | AP<sup>b</sup> | AP<sup>b</sup><sub>50</sub> | AP<sup>b</sup><sub>75</sub> | AP<sup>b</sup><sub>S</sub> | AP<sup>b</sup><sub>M</sub> | AP<sup>b</sup><sub>L</sub> | AP<sup>m</sup> | AP<sup>m</sup><sub>50</sub> | AP<sup>m</sup><sub>75</sub> | AP<sup>m</sup><sub>S</sub> | AP<sup>m</sup><sub>M</sub> | AP<sup>m</sup><sub>L</sub> | #Params | FLOPs | Cfg | Log | Model |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
MobileMamba-B1 | 40.6 | 61.8 | 43.8 | 22.4 | 43.5 | 55.9 | 37.4 | 58.9 | 39.9 | 17.1 | 39.9 | 56.4 | 38.0M | 178G | cfg | log | model |
Object Detection Performance Based on RetinaNet for COCO2017:
Backbone | AP | AP<sub>50</sub> | AP<sub>75</sub> | AP<sub>S</sub> | AP<sub>M</sub> | AP<sub>L</sub> | #Params | FLOPs | Cfg | Log | Model |
---|---|---|---|---|---|---|---|---|---|---|---|
MobileMamba-B1 | 39.6 | 59.8 | 42.4 | 21.5 | 43.4 | 53.9 | 27.1M | 151G | cfg | log | model |
Object Detection Performance Based on SSDLite for COCO2017:
Backbone | AP | AP<sub>50</sub> | AP<sub>75</sub> | AP<sub>S</sub> | AP<sub>M</sub> | AP<sub>L</sub> | #Params | FLOPs | Cfg | Log | Model |
---|---|---|---|---|---|---|---|---|---|---|---|
MobileMamba-B1 | 24.0 | 39.5 | 24.0 | 3.1 | 23.4 | 46.9 | 18.0M | 1.7G | cfg | log | model |
MobileMamba-B1-r512 | 29.5 | 47.7 | 30.4 | 8.9 | 35.0 | 47.0 | 18.0M | 4.4G | cfg | log | model |
Semantic Segmentation Results
Semantic Segmentation Based on Semantic FPN for ADE20k:
Backbone | aAcc | mIoU | mAcc | #Params | FLOPs | Cfg | Log | Model |
---|---|---|---|---|---|---|---|---|
MobileMamba-B4 | 79.9 | 42.5 | 53.7 | 19.8M | 5.6G | cfg | log | model |
Semantic Segmentation Based on DeepLabv3 for ADE20k:
Backbone | aAcc | mIoU | mAcc | #Params | FLOPs | Cfg | Log | Model |
---|---|---|---|---|---|---|---|---|
MobileMamba-B4 | 76.3 | 36.6 | 47.1 | 23.4M | 4.7G | cfg | log | model |
Semantic Segmentation Based on PSPNet for ADE20k:
Backbone | aAcc | mIoU | mAcc | #Params | FLOPs | Cfg | Log | Model |
---|---|---|---|---|---|---|---|---|
MobileMamba-B4 | 76.2 | 36.9 | 47.9 | 20.5M | 4.5G | cfg | log | model |
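In the segmentation tables, aAcc is overall pixel accuracy, mIoU is the mean per-class intersection-over-union, and mAcc is the mean per-class pixel accuracy. All three are accumulated over the whole evaluation set and reported as percentages. The sketch below computes them from a confusion matrix; the function name `seg_metrics` is hypothetical. It skips classes whose ratio would be 0/0, which we understand to mirror the NaN-ignoring mean in mmsegmentation's IoU metric. It is an illustration, not that library's code.

```python
def seg_metrics(conf):
    """aAcc, mIoU and mAcc (as fractions) from a dataset-level confusion matrix.

    conf[i][j] counts pixels whose ground truth is class i and prediction is class j;
    pixels carrying the ignore label are assumed to have been dropped already.
    """
    n = len(conf)
    total = sum(map(sum, conf))
    correct = sum(conf[c][c] for c in range(n))
    ious, accs = [], []
    for c in range(n):
        tp = conf[c][c]
        gt = sum(conf[c])                          # pixels labelled c
        pred = sum(conf[r][c] for r in range(n))   # pixels predicted as c
        union = gt + pred - tp
        if union:   # class absent from both labels and predictions: 0/0, skipped
            ious.append(tp / union)
        if gt:      # per-class accuracy is undefined when the class never occurs
            accs.append(tp / gt)
    return {
        "aAcc": correct / total,
        "mIoU": sum(ious) / len(ious),
        "mAcc": sum(accs) / len(accs),
    }


# Two classes: aAcc 0.8, mIoU about 0.657, mAcc about 0.792.
print(seg_metrics([[3, 1], [1, 5]]))
```

Multiply by 100 to compare with the percentages in the tables. aAcc is dominated by frequent classes such as wall or sky, while mIoU and mAcc weight every class equally. This explains why aAcc is much higher than mIoU in every row above.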
All Pretrained Weights and Logs
The model weights and log files for all classification and downstream tasks are available for download via weights.
Classification
Environments
pip3 install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu118
pip3 install timm==0.6.5 tensorboardX einops torchprofile fvcore==0.1.5.post20221221
cd model/lib_mamba/kernels/selective_scan && pip install . && cd ../../../..
git clone https://github.com/NVIDIA/apex && cd apex && pip3 install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./  # optional
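The classification environment depends on exact version pins, and other installs can silently upgrade them. The sketch below compares installed versions against the pins from the pip commands above using the standard-library `importlib.metadata`. The helper names are hypothetical and not part of this repository. It ignores the local tag (for example `+cu118`) that wheels from the PyTorch CUDA index carry. It does not check the optional apex install or the locally built selective-scan kernel.

```python
from importlib.metadata import PackageNotFoundError, version

# Pins copied from the pip install commands above.
PINS = {
    "torch": "2.1.2",
    "torchvision": "0.16.2",
    "torchaudio": "2.1.2",
    "timm": "0.6.5",
    "fvcore": "0.1.5.post20221221",
}


def installed_version(name):
    """Installed distribution version, or None if the package is absent."""
    try:
        return version(name)
    except PackageNotFoundError:
        return None


def pin_mismatches(pins=PINS, lookup=installed_version):
    """Map package -> installed version (None if missing) for every unsatisfied pin."""
    bad = {}
    for name, wanted in pins.items():
        installed = lookup(name)
        # CUDA wheels report e.g. "2.1.2+cu118"; compare only the public version part.
        if installed is None or installed.split("+", 1)[0] != wanted:
            bad[name] = installed
    return bad


if __name__ == "__main__":
    print(pin_mismatches() or "all pins satisfied")
```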
Prepare ImageNet-1K Dataset
Download and extract the ImageNet-1K dataset into the following directory structure:
├── imagenet
    ├── train
        ├── n01440764
            ├── n01440764_10026.JPEG
            ├── ...
        ├── ...
    ├── train.txt (optional)
    ├── val
        ├── n01440764
            ├── ILSVRC2012_val_00000293.JPEG
            ├── ...
        ├── ...
    └── val.txt (optional)
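Before launching a long run, it can help to confirm that the dataset matches the layout above. The sketch below checks for one sub-folder per WordNet ID (for example `n01440764`) in both `train` and `val`, each containing `.JPEG` files. The function name `check_imagenet_layout` is hypothetical. The sketch assumes the 1000-class folder layout shown. It does not check the optional `train.txt`/`val.txt` files or any other requirement that `run.py` may have.

```python
from pathlib import Path


def check_imagenet_layout(root, expected_classes=1000):
    """Return a list of human-readable problems found under an ImageNet root folder."""
    root = Path(root)
    problems = []
    classes_per_split = {}
    for split in ("train", "val"):
        split_dir = root / split
        if not split_dir.is_dir():
            problems.append(f"missing directory: {split_dir}")
            continue
        classes = sorted(p.name for p in split_dir.iterdir() if p.is_dir())
        classes_per_split[split] = classes
        if expected_classes is not None and len(classes) != expected_classes:
            problems.append(f"{split_dir}: {len(classes)} class folders, expected {expected_classes}")
        for name in classes:
            # ImageNet files use an upper-case .JPEG extension; compare case-insensitively.
            if not any(f.suffix.lower() == ".jpeg" for f in (split_dir / name).iterdir()):
                problems.append(f"no .JPEG images in {split_dir / name}")
    if len(classes_per_split) == 2 and classes_per_split["train"] != classes_per_split["val"]:
        problems.append("train/ and val/ contain different class folders")
    return problems


# Example: print(check_imagenet_layout("/path/to/imagenet") or "layout looks OK")
```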
Test
Test with 8 GPUs in one node:
<details>
<summary>
MobileMamba-T2
</summary>
python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env run.py -c configs/mobilemamba/mobilemamba_t2 -m test model.model_kwargs.checkpoint_path=weights/MobileMamba_T2/mobilemamba_t2.pth
This should give Top-1: 73.638 (Top-5: 91.422)
</details>
<details>
<summary>
MobileMamba-T2†
</summary>
python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env run.py -c configs/mobilemamba/mobilemamba_t2s -m test model.model_kwargs.checkpoint_path=weights/MobileMamba_T2s/mobilemamba_t2s.pth
This should give Top-1: 76.934 (Top-5: 93.100)
</details>
<details>
<summary>
MobileMamba-T4
</summary>
python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env run.py -c configs/mobilemamba/mobilemamba_t4 -m test model.model_kwargs.checkpoint_path=weights/MobileMamba_T4/mobilemamba_t4.pth
This should give Top-1: 76.086 (Top-5: 92.772)
</details>
<details>
<summary>
MobileMamba-T4†
</summary>
python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env run.py -c configs/mobilemamba/mobilemamba_t4s -m test model.model_kwargs.checkpoint_path=weights/MobileMamba_T4s/mobilemamba_t4s.pth
This should give Top-1: 78.914 (Top-5: 94.160)
</details>
<details>
<summary>
MobileMamba-S6
</summary>
python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env run.py -c configs/mobilemamba/mobilemamba_s6 -m test model.model_kwargs.checkpoint_path=weights/MobileMamba_S6/mobilemamba_s6.pth
This should give Top-1: 78.002 (Top-5: 93.992)
</details>
<details>
<summary>
MobileMamba-S6†
</summary>
python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env run.py -c configs/mobilemamba/mobilemamba_s6s -m test model.model_kwargs.checkpoint_path=weights/MobileMamba_S6s/mobilemamba_s6s.pth
This should give Top-1: 80.742 (Top-5: 95.182)
</details>
<details>
<summary>
MobileMamba-B1
</summary>
python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env run.py -c configs/mobilemamba/mobilemamba_b1 -m test model.model_kwargs.checkpoint_path=weights/MobileMamba_B1/mobilemamba_b1.pth
This should give Top-1: 79.948 (Top-5: 94.924)
</details>
<details>
<summary>
MobileMamba-B1†
</summary>
python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env run.py -c configs/mobilemamba/mobilemamba_b1s -m test model.model_kwargs.checkpoint_path=weights/MobileMamba_B1s/mobilemamba_b1s.pth
This should give Top-1: 82.234 (Top-5: 95.872)
</details>
<details>
<summary>
MobileMamba-B2
</summary>
python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env run.py -c configs/mobilemamba/mobilemamba_b2 -m test model.model_kwargs.checkpoint_path=weights/MobileMamba_B2/mobilemamba_b2.pth
This should give Top-1: 81.624 (Top-5: 95.890)
</details>
<details>
<summary>
MobileMamba-B2†
</summary>
python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env run.py -c configs/mobilemamba/mobilemamba_b2s -m test model.model_kwargs.checkpoint_path=weights/MobileMamba_B2s/mobilemamba_b2s.pth
This should give Top-1: 83.260 (Top-5: 96.438)
</details>
<details>
<summary>
MobileMamba-B4
</summary>
python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env run.py -c configs/mobilemamba/mobilemamba_b4 -m test model.model_kwargs.checkpoint_path=weights/MobileMamba_B4/mobilemamba_b4.pth
This should give Top-1: 82.496 (Top-5: 96.252)
</details>
<details>
<summary>
MobileMamba-B4†
</summary>
python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env run.py -c configs/mobilemamba/mobilemamba_b4s -m test model.model_kwargs.checkpoint_path=weights/MobileMamba_B4s/mobilemamba_b4s.pth
This should give Top-1: 83.644 (Top-5: 96.606)
</details>
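The Top-1/Top-5 figures printed by the test runs are top-k accuracies: the share of validation images whose true class is among the k highest-scoring predictions. The ImageNet-1K validation set has 50,000 images, so each image is worth 0.002 percentage points. Every value above is a multiple of 0.002, which is consistent with evaluation on the full set. The sketch below shows the computation. The function name `topk_accuracy` is hypothetical, and how `run.py` computes the metric internally is not shown here.

```python
def topk_accuracy(scores, labels, ks=(1, 5)):
    """Percentage of samples whose true label is among the k highest-scoring classes."""
    hits = {k: 0 for k in ks}
    for row, label in zip(scores, labels):
        # Class indices by descending score; sorted() is stable, so ties keep the lower index first.
        ranked = sorted(range(len(row)), key=lambda c: row[c], reverse=True)
        for k in ks:
            if label in ranked[:k]:
                hits[k] += 1
    n = len(labels)
    return {k: 100.0 * hits[k] / n for k in ks}


# Second sample's label is ranked 2nd: counted for top-2 but not top-1.
print(topk_accuracy([[0.1, 0.7, 0.2], [0.5, 0.3, 0.2]], [1, 1], ks=(1, 2)))
```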
Train
Train with 8 GPUs in one node:
<details>
<summary>
MobileMamba-T2
</summary>
python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env run.py -c configs/mobilemamba/mobilemamba_t2 -m train
</details>
<details>
<summary>
MobileMamba-T2†
</summary>
python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env run.py -c configs/mobilemamba/mobilemamba_t2s -m train
</details>
<details>
<summary>
MobileMamba-T4
</summary>
python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env run.py -c configs/mobilemamba/mobilemamba_t4 -m train
</details>
<details>
<summary>
MobileMamba-T4†
</summary>
python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env run.py -c configs/mobilemamba/mobilemamba_t4s -m train
</details>
<details>
<summary>
MobileMamba-S6
</summary>
python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env run.py -c configs/mobilemamba/mobilemamba_s6 -m train
</details>
<details>
<summary>
MobileMamba-S6†
</summary>
python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env run.py -c configs/mobilemamba/mobilemamba_s6s -m train
</details>
<details>
<summary>
MobileMamba-B1
</summary>
python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env run.py -c configs/mobilemamba/mobilemamba_b1 -m train
</details>
<details>
<summary>
MobileMamba-B1†
</summary>
python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env run.py -c configs/mobilemamba/mobilemamba_b1s -m train
</details>
<details>
<summary>
MobileMamba-B2
</summary>
python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env run.py -c configs/mobilemamba/mobilemamba_b2 -m train
</details>
<details>
<summary>
MobileMamba-B2†
</summary>
python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env run.py -c configs/mobilemamba/mobilemamba_b2s -m train
</details>
<details>
<summary>
MobileMamba-B4
</summary>
python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env run.py -c configs/mobilemamba/mobilemamba_b4 -m train
</details>
<details>
<summary>
MobileMamba-B4†
</summary>
python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env run.py -c configs/mobilemamba/mobilemamba_b4s -m train
</details>
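With `--use_env`, `torch.distributed.launch` exports each worker's local rank as the `LOCAL_RANK` environment variable instead of passing a `--local_rank` argument. It also sets `RANK`, `WORLD_SIZE`, `MASTER_ADDR` and `MASTER_PORT`. The sketch below shows how a training script typically reads these variables. It is not taken from `run.py`, and the function name is hypothetical. On PyTorch 2.x, `torch.distributed.launch` is deprecated in favor of `torchrun`, which exports the same variables. `torchrun --nproc_per_node=8 run.py -c ... -m train` should therefore behave the same, but that substitution is an assumption rather than a documented option of this repository.

```python
import os


def dist_info_from_env(env=None):
    """Read rank information exported by the launcher (launch --use_env, or torchrun)."""
    env = os.environ if env is None else env
    return {
        "rank": int(env.get("RANK", 0)),              # global rank across all nodes
        "local_rank": int(env.get("LOCAL_RANK", 0)),  # GPU index on this node
        "world_size": int(env.get("WORLD_SIZE", 1)),  # total number of processes
    }


if __name__ == "__main__":
    info = dist_info_from_env()
    print(info)
    # With PyTorch installed, a worker would typically continue with:
    #   torch.cuda.set_device(info["local_rank"])
    #   torch.distributed.init_process_group(backend="nccl")  # env:// init reads the variables above
```

With 8 processes on one node, `local_rank` runs from 0 to 7 and `world_size` is 8. The effective batch size is therefore 8 times the per-GPU batch size set in the config.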
Downstream Tasks
Environments
pip3 install terminaltables pycocotools prettytable xtcocotools
pip3 install mmpretrain==1.2.0 mmdet==3.3.0 mmsegmentation==1.2.2
pip3 install mmcv==2.1.0 -f https://download.openmmlab.com/mmcv/dist/cu118/torch2.1/index.html
cd det/backbones/lib_mamba/kernels/selective_scan && pip install . && cd ../../../..
Prepare COCO and ADE20k Datasets
Download and extract the COCO2017 and ADE20k datasets into the following directory structure:
downstream
├── det
│   ├── data
│   │   ├── coco
│   │   │   ├── annotations
│   │   │   ├── train2017
│   │   │   ├── val2017
│   │   │   ├── test2017
├── seg
│   ├── data
│   │   ├── ade
│   │   │   ├── ADEChallengeData2016
│   │   │   │   ├── annotations
│   │   │   │   ├── images
Object Detection
<details>
<summary>
Mask-RCNN
</summary>
Train:
CUDA_VISIBLE_DEVICES=0,1,2,3 ./tools/dist_train.sh configs/mask_rcnn/mask-rcnn_mobilemamba_b1_fpn_1x_coco.py 4
Test:
CUDA_VISIBLE_DEVICES=0,1,2,3 ./tools/dist_test.sh configs/mask_rcnn/mask-rcnn_mobilemamba_b1_fpn_1x_coco.py ../../weights/downstream/det/maskrcnn.pth 4
</details>
<details>
<summary>
RetinaNet
</summary>
Train:
CUDA_VISIBLE_DEVICES=0,1,2,3 ./tools/dist_train.sh configs/retinanet/retinanet_mobilemamba_b1_fpn_1x_coco.py 4
Test:
CUDA_VISIBLE_DEVICES=0,1,2,3 ./tools/dist_test.sh configs/retinanet/retinanet_mobilemamba_b1_fpn_1x_coco.py ../../weights/downstream/det/retinanet.pth 4
</details>
<details>
<summary>
SSDLite
</summary>
Train with 320 x 320 resolution:
./tools/dist_train.sh configs/ssd/ssdlite_mobilemamba_b1_8gpu_2lr_coco.py 8
Test with 320 x 320 resolution:
./tools/dist_test.sh configs/ssd/ssdlite_mobilemamba_b1_8gpu_2lr_coco.py ../../weights/downstream/det/ssdlite.pth 8
Train with 512 x 512 resolution:
./tools/dist_train.sh configs/ssd/ssdlite_mobilemamba_b1_8gpu_2lr_512_coco.py 8
Test with 512 x 512 resolution:
./tools/dist_test.sh configs/ssd/ssdlite_mobilemamba_b1_8gpu_2lr_512_coco.py ../../weights/downstream/det/ssdlite_512.pth 8
</details>
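In the Mask-RCNN and RetinaNet commands, the trailing `4` is the number of GPUs passed to `dist_train.sh`/`dist_test.sh`. It must agree with the devices listed in `CUDA_VISIBLE_DEVICES`. This assumes the scripts follow the usual mmdetection convention of `CONFIG GPUS` for training and `CONFIG CHECKPOINT GPUS` for testing. The hypothetical helper below derives the GPU count from the device list so the two cannot drift apart. Changing the number of GPUs also changes the total batch size, which the configs may have been tuned for.

```python
import os
import shlex


def dist_cmd(script, config, gpu_ids, checkpoint=None):
    """Build argv and env for ./tools/dist_train.sh or ./tools/dist_test.sh.

    The GPU count passed as the last argument is derived from gpu_ids,
    so it always matches CUDA_VISIBLE_DEVICES.
    """
    ids = [str(g) for g in gpu_ids]
    argv = [script, config]
    if checkpoint is not None:       # dist_test.sh: CONFIG CHECKPOINT GPUS
        argv.append(checkpoint)
    argv.append(str(len(ids)))       # dist_train.sh: CONFIG GPUS
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=",".join(ids))
    return argv, env


argv, env = dist_cmd(
    "./tools/dist_test.sh",
    "configs/retinanet/retinanet_mobilemamba_b1_fpn_1x_coco.py",
    [0, 1, 2, 3],
    checkpoint="../../weights/downstream/det/retinanet.pth",
)
print("CUDA_VISIBLE_DEVICES=" + env["CUDA_VISIBLE_DEVICES"], shlex.join(argv))
# To launch for real from the det directory: subprocess.run(argv, env=env, check=True)
```

The printed line reproduces the RetinaNet test command above, which makes it easy to switch to, for example, two GPUs by passing `[0, 1]`.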
Semantic Segmentation
<details>
<summary>
DeepLabv3
</summary>
Train:
CUDA_VISIBLE_DEVICES=0,1,2,3 ./tools/dist_train.sh configs/deeplabv3/deeplabv3_mobilemamba_b4-80k_ade20k-512x512.py 4
Test:
CUDA_VISIBLE_DEVICES=0,1,2,3 ./tools/dist_test.sh configs/deeplabv3/deeplabv3_mobilemamba_b4-80k_ade20k-512x512.py ../../weights/downstream/seg/deeplabv3.pth 4
</details>
<details>
<summary>
Semantic FPN
</summary>
Train:
CUDA_VISIBLE_DEVICES=0,1,2,3 ./tools/dist_train.sh configs/sem_fpn/fpn_mobilemamba_b4-160k_ade20k-512x512.py 4
Test:
CUDA_VISIBLE_DEVICES=0,1,2,3 ./tools/dist_test.sh configs/sem_fpn/fpn_mobilemamba_b4-160k_ade20k-512x512.py ../../weights/downstream/seg/fpn.pth 4
</details>
<details>
<summary>
PSPNet
</summary>
Train:
CUDA_VISIBLE_DEVICES=0,1,2,3 ./tools/dist_train.sh configs/pspnet/pspnet_mobilemamba_b4-80k_ade20k-512x512.py 4
Test:
CUDA_VISIBLE_DEVICES=0,1,2,3 ./tools/dist_test.sh configs/pspnet/pspnet_mobilemamba_b4-80k_ade20k-512x512.py ../../weights/downstream/seg/pspnet.pth 4
</details>
Citation
If our work is helpful for your research, please consider citing:
@article{mobilemamba,
title={MobileMamba: Lightweight Multi-Receptive Visual Mamba Network},
author={Haoyang He and Jiangning Zhang and Yuxuan Cai and Hongxu Chen and Xiaobin Hu and Zhenye Gan and Yabiao Wang and Chengjie Wang and Yunsheng Wu and Lei Xie},
journal={arXiv preprint arXiv:2411.15941},
year={2024}
}
Acknowledgements
We thank the following repositories, among others, for assisting our research: