# Hybrid Mamba for Few-Shot Segmentation
This repository contains the code for our NeurIPS 2024 paper "Hybrid Mamba for Few-Shot Segmentation", where we design a cross-attention-like Mamba to enable support-query interactions.
Abstract: Many few-shot segmentation (FSS) methods use cross attention to fuse support foreground (FG) features into query features, despite its quadratic complexity. The recent Mamba can also capture intra-sequence dependencies well, yet with only linear complexity. Hence, we aim to devise a cross (attention-like) Mamba to capture inter-sequence dependencies for FSS. A simple idea is to scan on support features to selectively compress them into the hidden state, which is then used as the initial hidden state to sequentially scan query features. Nevertheless, this suffers from (1) the support forgetting issue: query features are also gradually compressed during the scan, so the support information in the hidden state keeps shrinking, and many query pixels cannot fuse sufficient support features; and (2) the intra-class gap issue: query FG is essentially more similar to itself than to support FG, i.e., query pixels may prefer to fuse their own features from the hidden state rather than the support ones, yet the effective use of support information is what drives the success of FSS. To tackle these issues, we design a hybrid Mamba network (HMNet), including (1) a support recapped Mamba that periodically recaps the support features while scanning the query, so the hidden state always contains rich support information; and (2) a query intercepted Mamba that forbids mutual interactions among query pixels and encourages them to fuse more support features from the hidden state. Consequently, the support information is better utilized, leading to better performance. Extensive experiments on two public benchmarks show the superiority of HMNet.
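For intuition, below is a minimal, self-contained sketch of the cross-Mamba idea described above; it is not the repository's actual implementation. The function name, the fixed `decay` gate, and the `recap_period` / `intercept` parameters are illustrative assumptions, whereas the real HMNet uses learned, input-dependent Mamba parameters.

```python
import torch

@torch.no_grad()
def cross_mamba_scan(support, query, recap_period=64, intercept=True):
    """Toy linear-time cross scan (illustrative only, not HMNet's real code).

    support: (Ls, D) support FG features, query: (Lq, D) query features.
    Returns (Lq, D) query features fused with support information.
    """
    d = support.size(1)
    decay = 0.9                       # assumed fixed gate; real Mamba learns input-dependent gates
    h = torch.zeros(d)

    # 1) Scan support features, selectively compressing them into the hidden state.
    for s in support:
        h = decay * h + (1.0 - decay) * s
    h_support = h.clone()             # snapshot of the support-only hidden state

    out = []
    for t, q in enumerate(query):
        # 2) Support recapped Mamba: periodically recap the support state so the
        #    support information in h is not gradually forgotten.
        if t % recap_period == 0:
            h = h_support.clone()
        out.append(q + h)             # the query pixel fuses support info read from h
        if not intercept:
            # Vanilla cross scan: query pixels also write into the hidden state,
            # which dilutes the support features (the "support forgetting" issue).
            h = decay * h + (1.0 - decay) * q
        # 3) Query intercepted Mamba (intercept=True): query pixels never write
        #    into h, so they cannot interact with each other through the state
        #    and are encouraged to fuse support features instead.
    return torch.stack(out)


# Example usage with random features (D=256).
support = torch.randn(400, 256)       # e.g., flattened support FG tokens
query = torch.randn(3600, 256)        # e.g., a 60x60 query feature map
fused = cross_mamba_scan(support, query)
print(fused.shape)                    # torch.Size([3600, 256])
```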
## Dependencies
- Python 3.10
- PyTorch 1.12.0
- CUDA 11.6
- torchvision 0.13.0
> conda env create -f env.yaml
## Datasets
You can download the pre-processed PASCAL-5<sup>i</sup> and COCO-20<sup>i</sup> datasets here and extract them into the data/ folder. Then, create a symbolic link to the pascal/VOCdevkit data folder as follows:
> ln -s <absolute_path>/data/pascal/VOCdevkit <absolute_path>/data/VOCdevkit2012
The directory structure is:
```
../
├── HMNet/
└── data/
    ├── VOCdevkit2012/
    │   └── VOC2012/
    │       ├── JPEGImages/
    │       ├── ...
    │       └── SegmentationClassAug/
    └── MSCOCO2014/
        ├── annotations/
        │   ├── train2014/
        │   └── val2014/
        ├── train2014/
        └── val2014/
```
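If helpful, the following small snippet (a hypothetical helper, not part of the repository) checks that the folders from the tree above are in place; `data_root` is an assumption, adjust it to your own path.

```python
import os

# Hypothetical layout check mirroring the directory tree above.
data_root = "../data"
expected = [
    "VOCdevkit2012/VOC2012/JPEGImages",
    "VOCdevkit2012/VOC2012/SegmentationClassAug",
    "MSCOCO2014/annotations/train2014",
    "MSCOCO2014/annotations/val2014",
    "MSCOCO2014/train2014",
    "MSCOCO2014/val2014",
]
for rel in expected:
    path = os.path.join(data_root, rel)
    print(("OK   " if os.path.isdir(path) else "MISS ") + path)
```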
## Models
- Download the pretrained backbones from here and put them into the initmodel/ directory.
- Download exp.tar.gz to obtain all trained models for PASCAL-5<sup>i</sup> and COCO-20<sup>i</sup>.
## Testing
- Commands:
```sh
sh test_pascal.sh {Split: 0/1/2/3} {Net: resnet50/vgg} {Postfix: manet/manet_5s}
sh test_coco.sh   {Split: 0/1/2/3} {Net: resnet50/vgg} {Postfix: manet/manet_5s}

# e.g., testing split 0 under the 1-shot setting on PASCAL-5i, with ResNet50 as the pretrained backbone:
sh test_pascal.sh 0 resnet50 manet

# e.g., testing split 0 under the 5-shot setting on COCO-20i, with ResNet50 as the pretrained backbone:
sh test_coco.sh 0 resnet50 manet_5s
```
## References
This repo is mainly built upon BAM. Thanks for their great work!