UNetMamba

👀Introduction

UNetMamba is the official PyTorch implementation of the paper "UNetMamba: An Efficient UNet-Like Mamba for Semantic Segmentation of High-Resolution Remote Sensing Images" (under review at IEEE GRSL).

📂Folder Structure

Prepare the following folders to organize this repo:

UNetMamba-main
├── UNetMamba
│   ├── config
│   ├── tools
│   ├── unetmamba_model
│   ├── train.py
│   ├── loveda_test.py
│   ├── vaihingen_test.py
├── pretrain_weights (pretrained weights of backbones)
├── model_weights (model weights trained on ISPRS Vaihingen, LoveDA, etc.)
├── fig_results (the segmentation results)
├── data
│   ├── LoveDA
│   │   ├── Train
│   │   │   ├── Urban
│   │   │   │   ├── images_png (original)
│   │   │   │   ├── masks_png (original)
│   │   │   │   ├── masks_png_convert (converted masks generated by tools/loveda_mask_convert.py)
│   │   │   │   ├── masks_png_convert_rgb (RGB-format converted masks generated by tools/loveda_mask_convert.py)
│   │   │   ├── Rural
│   │   │   │   ├── images_png
│   │   │   │   ├── masks_png
│   │   │   │   ├── masks_png_convert
│   │   │   │   ├── masks_png_convert_rgb
│   │   ├── Val (same structure as Train)
│   │   ├── Test
│   │   ├── train_val (Train and Val merged)
│   ├── vaihingen (33 original images in total)
│   │   ├── test_images (9 original images, randomly selected)
│   │   ├── test_masks (9 original RGB masks)
│   │   ├── test_masks_eroded (9 eroded RGB masks, xxxx_noBoundary.tif)
│   │   ├── train_images (22 original images, randomly selected from the remaining images)
│   │   ├── train_masks (22 original RGB masks)
│   │   ├── val_images (the remaining 2 original images)
│   │   ├── val_masks (the remaining 2 original RGB masks)
│   │   ├── val_masks_eroded (the remaining 2 eroded RGB masks, xxxx_noBoundary.tif)
│   │   ├── train_1024 (train set at 1024×1024)
│   │   ├── test_1024 (test set at 1024×1024)
│   │   ├── val_1024 (validation set at 1024×1024)
│   │   ├── ...
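If you prefer to create the empty input folders with a script rather than by hand, the following is a minimal sketch. The helper make_skeleton is hypothetical (not part of the repo); it creates only the top-level weight/result folders and the folders you fill with downloaded data, taking the names from the tree above. The UNetMamba code folder comes from the repository itself, and output folders such as masks_png_convert or *_1024 are left to the preprocessing tools.

```python
from pathlib import Path
from typing import List

# Hypothetical helper: folder names are copied from the tree above.
TOP_LEVEL = ["pretrain_weights", "model_weights", "fig_results"]
LOVEDA_INPUTS = ["images_png", "masks_png"]  # original LoveDA files go here
VAIHINGEN_INPUTS = [
    "train_images", "train_masks",
    "val_images", "val_masks", "val_masks_eroded",
    "test_images", "test_masks", "test_masks_eroded",
]


def make_skeleton(root: str = ".") -> List[Path]:
    """Create the folders you populate by hand and return their paths."""
    base = Path(root)
    dirs = [base / name for name in TOP_LEVEL]
    for split in ("Train", "Val", "train_val"):
        for domain in ("Urban", "Rural"):
            dirs += [base / "data" / "LoveDA" / split / domain / sub for sub in LOVEDA_INPUTS]
    dirs.append(base / "data" / "LoveDA" / "Test")
    dirs += [base / "data" / "vaihingen" / sub for sub in VAIHINGEN_INPUTS]
    for d in dirs:
        d.mkdir(parents=True, exist_ok=True)  # safe to re-run: existing folders are kept
    return dirs
```

Call make_skeleton(".") from inside UNetMamba-main so that the relative data/... paths match the preprocessing commands.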

🛠Install

conda create -n UNetMamba-main python=3.8
conda activate UNetMamba-main
pip install -r UNetMamba/requirements.txt

πŸ’Tips: If you're having difficulty in installing "causal_conv1d" or "mamba_ssm", please refer to causal_conv1d or mamba_ssm to download the wheel files and then pip install them. For our UNetMamba, we installed both "causal_conv1d-1.2.0.post2+cu118torch2.0cxx11abiFALSE-cp38-cp38-linux_x86_64.whl" and "mamba_ssm-1.1.1+cu118torch2.0cxx11abiFALSE-cp38-cp38-linux_x86_64.whl". Moreover, UNetMamba is also compatible with the newest version of "causal_conv1d" and "mamba_ssm", please feel free to try😁.

🧩Pretrained Weights of Backbones

pretrain_weights

🧩Pretrained Weights of UNetMamba

model_weights

💿Data Preprocessing

Download the datasets from their official websites and split them as follows.

1️⃣LoveDA (LoveDA official)

python UNetMamba/tools/loveda_mask_convert.py --mask-dir data/LoveDA/Train/Rural/masks_png --output-mask-dir data/LoveDA/Train/Rural/masks_png_convert
python UNetMamba/tools/loveda_mask_convert.py --mask-dir data/LoveDA/Train/Urban/masks_png --output-mask-dir data/LoveDA/Train/Urban/masks_png_convert

python UNetMamba/tools/loveda_mask_convert.py --mask-dir data/LoveDA/Val/Rural/masks_png --output-mask-dir data/LoveDA/Val/Rural/masks_png_convert
python UNetMamba/tools/loveda_mask_convert.py --mask-dir data/LoveDA/Val/Urban/masks_png --output-mask-dir data/LoveDA/Val/Urban/masks_png_convert

python UNetMamba/tools/loveda_mask_convert.py --mask-dir data/LoveDA/train_val/Rural/masks_png --output-mask-dir data/LoveDA/train_val/Rural/masks_png_convert
python UNetMamba/tools/loveda_mask_convert.py --mask-dir data/LoveDA/train_val/Urban/masks_png --output-mask-dir data/LoveDA/train_val/Urban/masks_png_convert
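The six conversions above follow a single pattern (three splits × two domains), so they can also be generated in a loop. This is a minimal sketch that uses only the --mask-dir and --output-mask-dir flags shown above; the helper names are hypothetical, and by default it only prints the commands (a dry run).

```python
import shlex
import subprocess
from typing import List


def loveda_convert_commands(data_root: str = "data/LoveDA") -> List[List[str]]:
    """Build the six loveda_mask_convert.py calls listed above (3 splits x 2 domains)."""
    commands = []
    for split in ("Train", "Val", "train_val"):
        for domain in ("Rural", "Urban"):
            base = f"{data_root}/{split}/{domain}"
            commands.append([
                "python", "UNetMamba/tools/loveda_mask_convert.py",
                "--mask-dir", f"{base}/masks_png",
                "--output-mask-dir", f"{base}/masks_png_convert",
            ])
    return commands


def run_all(commands: List[List[str]], dry_run: bool = True) -> None:
    """Print each command; execute it only when dry_run is False."""
    for cmd in commands:
        print(shlex.join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # raise on the first failing conversion


if __name__ == "__main__":
    run_all(loveda_convert_commands())  # pass dry_run=False to actually convert
```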

2️⃣Vaihingen (Vaihingen official)

Generate the train set.

python UNetMamba/tools/vaihingen_patch_split.py \
--img-dir "data/vaihingen/train_images" --mask-dir "data/vaihingen/train_masks" \
--output-img-dir "data/vaihingen/train_1024/images" --output-mask-dir "data/vaihingen/train_1024/masks" \
--mode "train" --split-size 1024 --stride 1024
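With --split-size 1024 --stride 1024 the patches are 1024 × 1024 pixels and do not overlap (a stride smaller than the split size would make them overlap). The arithmetic below counts patches per image. How vaihingen_patch_split.py handles the image border is not stated here, so the sketch assumes the image is padded until the last patch fits exactly; the function names and the 2000 × 2500 example size are hypothetical.

```python
import math
from typing import List, Tuple


def tile_origins(length: int, split_size: int = 1024, stride: int = 1024) -> Tuple[List[int], int]:
    """Start offsets of the patches along one axis, plus the padded axis length.

    Assumption: the axis is padded so that the last patch ends exactly at the border.
    """
    if length <= split_size:
        n_tiles = 1
    else:
        n_tiles = math.ceil((length - split_size) / stride) + 1
    origins = [i * stride for i in range(n_tiles)]
    return origins, origins[-1] + split_size


def count_tiles(height: int, width: int, split_size: int = 1024, stride: int = 1024) -> int:
    """Number of split_size x split_size patches cut from one height x width image."""
    rows, _ = tile_origins(height, split_size, stride)
    cols, _ = tile_origins(width, split_size, stride)
    return len(rows) * len(cols)


# A hypothetical 2000 x 2500 image with the settings used above:
print(tile_origins(2500))       # ([0, 1024, 2048], 3072)
print(count_tiles(2000, 2500))  # 6
```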

Generate the validation set (note: it uses the eroded masks and the --eroded flag).

python UNetMamba/tools/vaihingen_patch_split.py \
--img-dir "data/vaihingen/val_images" --mask-dir "data/vaihingen/val_masks_eroded" \
--output-img-dir "data/vaihingen/val_1024/images" --output-mask-dir "data/vaihingen/val_1024/masks" \
--mode "val" --split-size 1024 --stride 1024 --eroded

Generate the test set (note: it uses the eroded masks and the --eroded flag).

python UNetMamba/tools/vaihingen_patch_split.py \
--img-dir "data/vaihingen/test_images" --mask-dir "data/vaihingen/test_masks_eroded" \
--output-img-dir "data/vaihingen/test_1024/images" --output-mask-dir "data/vaihingen/test_1024/masks" \
--mode "val" --split-size 1024 --stride 1024 --eroded

Generate the RGB-format ground-truth masks (masks_rgb) for visualization.

python UNetMamba/tools/vaihingen_patch_split.py \
--img-dir "data/vaihingen/val_images" --mask-dir "data/vaihingen/val_masks" \
--output-img-dir "data/vaihingen/val_1024/images" --output-mask-dir "data/vaihingen/val_1024/masks_rgb" \
--mode "val" --split-size 1024 --stride 1024 --gt

python UNetMamba/tools/vaihingen_patch_split.py \
--img-dir "data/vaihingen/test_images" --mask-dir "data/vaihingen/test_masks" \
--output-img-dir "data/vaihingen/test_1024/images" --output-mask-dir "data/vaihingen/test_1024/masks_rgb" \
--mode "val" --split-size 1024 --stride 1024 --gt
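The five Vaihingen commands differ only in the split, the mask folder, the output mask folder, the mode and one extra flag. The sketch below builds them from those parameters and prints them, which makes the pairing explicit: train uses the original masks, val/test use the eroded masks with --eroded, and masks_rgb uses the original masks with --gt. Only flags that appear above are used; split_command and vaihingen_commands are hypothetical helper names.

```python
import shlex
from typing import List

SCRIPT = "UNetMamba/tools/vaihingen_patch_split.py"
ROOT = "data/vaihingen"


def split_command(split: str, mask_dir: str, out_masks: str, mode: str, extra: str = "") -> List[str]:
    """One vaihingen_patch_split.py call with the 1024 x 1024, stride-1024 settings used above."""
    cmd = [
        "python", SCRIPT,
        "--img-dir", f"{ROOT}/{split}_images",
        "--mask-dir", f"{ROOT}/{mask_dir}",
        "--output-img-dir", f"{ROOT}/{split}_1024/images",
        "--output-mask-dir", f"{ROOT}/{split}_1024/{out_masks}",
        "--mode", mode, "--split-size", "1024", "--stride", "1024",
    ]
    if extra:
        cmd.append(extra)  # "--eroded" or "--gt"
    return cmd


def vaihingen_commands() -> List[List[str]]:
    return [
        split_command("train", "train_masks", "masks", "train"),
        split_command("val", "val_masks_eroded", "masks", "val", "--eroded"),
        split_command("test", "test_masks_eroded", "masks", "val", "--eroded"),
        split_command("val", "val_masks", "masks_rgb", "val", "--gt"),
        split_command("test", "test_masks", "masks_rgb", "val", "--gt"),
    ]


for command in vaihingen_commands():
    print(shlex.join(command))
```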

πŸ‹Training

"-c" means the path of the config, use different config to train different models in different datasets.

python UNetMamba/train.py -c UNetMamba/config/loveda/unetmamba.py
python UNetMamba/train.py -c UNetMamba/config/vaihingen/unetmamba.py
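The config passed with -c is a Python file. The sketch below shows one generic way such a file can be executed and its settings read as attributes, using only the standard library's importlib. It illustrates the idea under that assumption and is not necessarily how train.py loads its config; load_py_config is a hypothetical name, and which attributes exist depends entirely on what the config file defines.

```python
import importlib.util
from types import ModuleType


def load_py_config(path: str) -> ModuleType:
    """Execute a .py config file and return its top-level names as module attributes."""
    spec = importlib.util.spec_from_file_location("experiment_config", path)
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot load a Python config from {path!r}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # runs the file once, like an import
    return module
```

For example, load_py_config("UNetMamba/config/vaihingen/unetmamba.py") would execute that file (including any imports it performs from the repo) and return a module object whose attributes are the names defined in it.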

🎯Testing

"-c" specifies the path to the config file; use a different config to test a different model or on a different dataset.

"-o" specifies the output path.

"-t" selects the test-time augmentation (TTA): one of [None, 'lr', 'd4'], default None, where 'lr' is flip TTA and 'd4' is multiscale TTA.

"--rgb" specifies whether to output masks in RGB format.

1️⃣LoveDA (Online Testing)

python UNetMamba/loveda_test.py -c UNetMamba/config/loveda/unetmamba.py -o fig_results/loveda/unetmamba_test
python UNetMamba/loveda_test.py -c UNetMamba/config/loveda/unetmamba.py -o fig_results/loveda/unetmamba_test -t 'd4'
python UNetMamba/loveda_test.py -c UNetMamba/config/loveda/unetmamba.py -o fig_results/loveda/unetmamba_rgb -t 'd4' --rgb --val
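To illustrate what flip ('lr') TTA does, here is a minimal NumPy sketch, assuming a model that returns per-pixel class probabilities. flip_tta and predict are hypothetical names; the repo's own TTA code may be organized differently, and the multiscale 'd4' variant is not shown.

```python
import numpy as np


def flip_tta(predict, image: np.ndarray) -> np.ndarray:
    """Left-right flip TTA: average predictions for the image and its mirror.

    predict is a hypothetical callable mapping an (H, W, C) image to
    (H, W, K) per-pixel class probabilities.
    """
    probs = predict(image)
    mirrored = predict(image[:, ::-1])  # predict on the horizontally flipped image
    return (probs + mirrored[:, ::-1]) / 2.0  # flip back so pixels line up, then average
```

The final label map is then flip_tta(predict, image).argmax(axis=-1). Flipping the prediction back before averaging matters: without it, the mirrored prediction would be averaged with pixels from the opposite side of the image.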

2️⃣Vaihingen

python UNetMamba/vaihingen_test.py -c UNetMamba/config/vaihingen/unetmamba.py -o fig_results/vaihingen/unetmamba_test
python UNetMamba/vaihingen_test.py -c UNetMamba/config/vaihingen/unetmamba.py -o fig_results/vaihingen/unetmamba_test -t 'lr'
python UNetMamba/vaihingen_test.py -c UNetMamba/config/vaihingen/unetmamba.py -o fig_results/vaihingen/unetmamba_rgb --rgb
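The --rgb option writes colour masks instead of class indices. A common way to do this is a palette lookup with NumPy integer indexing. The sketch below uses the standard ISPRS Vaihingen legend colours, but the mapping from class index to colour (the order of the classes) is an assumption and must match the class list in the Vaihingen config; PALETTE and label_to_rgb are hypothetical names.

```python
import numpy as np

# ISPRS Vaihingen legend colours. The index order (0 = impervious surfaces, ...)
# is an assumption and must match the class order used by the config.
PALETTE = np.array([
    [255, 255, 255],  # impervious surfaces
    [0, 0, 255],      # building
    [0, 255, 255],    # low vegetation
    [0, 255, 0],      # tree
    [255, 255, 0],    # car
    [255, 0, 0],      # clutter / background
], dtype=np.uint8)


def label_to_rgb(mask: np.ndarray) -> np.ndarray:
    """Map an (H, W) array of class indices to an (H, W, 3) uint8 RGB image."""
    return PALETTE[mask]  # integer indexing picks one palette row per pixel
```

The result can be saved with Pillow via Image.fromarray(label_to_rgb(mask)).save(...).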

πŸ€Citation

If you find this project useful in your research, please consider citing: UNetMamba: An Efficient UNet-Like Mamba for Semantic Segmentation of High-Resolution Remote Sensing Images.

@article{zhu2024unetmamba,
  title={UNetMamba: An Efficient UNet-Like Mamba for Semantic Segmentation of High-Resolution Remote Sensing Images},
  author={Zhu, Enze and Chen, Zhan and Wang, Dingkai and Shi, Hanru and Liu, Xiaoxuan and Wang, Lei},
  journal={arXiv preprint arXiv:2408.11545},
  year={2024}
}

❀Acknowledgement