EMC-Click: Efficient Mask Correction for Click-Based Interactive Image Segmentation (CVPR2023)

The PyTorch code for "Efficient Mask Correction for Click-Based Interactive Image Segmentation", published at CVPR 2023.

Abstract

The goal of click-based interactive image segmentation is to extract target masks from positive/negative input clicks. Every time a new click is placed, existing methods run the whole segmentation network to obtain a corrected mask, which is inefficient since several clicks may be needed to reach satisfactory accuracy. To this end, we propose an efficient method that corrects the mask with a lightweight mask correction network. From the second click onward, the whole network keeps a low computational cost, even with a large backbone. However, a simple correction network with limited capacity is unlikely to achieve performance comparable to a classic segmentation network. Thus, we propose a click-guided self-attention module and a click-guided correlation module to effectively exploit the click information and boost performance. First, several templates are selected based on their semantic similarity with click features. Then the self-attention module propagates the template information to other pixels, while the correlation module directly uses the templates to obtain target outlines. With the efficient architecture and two click-guided modules, our method shows favorable performance and efficiency compared to existing methods.
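To make the template idea concrete, here is a minimal pure-Python sketch of selecting templates by similarity to a clicked pixel and then scoring every pixel against those templates. It is an illustration only, not the repository's implementation: the function names (`cosine`, `select_templates`, `correlation_map`), the use of cosine similarity, and the top-k selection rule are all assumptions, and the self-attention propagation step is omitted.

```python
import math

def cosine(a, b):
    """Cosine similarity between two equal-length feature vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm > 0 else 0.0

def select_templates(pixel_feats, click_index, top_k):
    """Return indices of the top_k pixels most similar to the clicked pixel.

    Hypothetical selection rule: rank all pixels by cosine similarity
    to the feature vector at the click location.
    """
    click_feat = pixel_feats[click_index]
    sims = [cosine(f, click_feat) for f in pixel_feats]
    ranked = sorted(range(len(pixel_feats)), key=lambda i: sims[i], reverse=True)
    return ranked[:top_k]

def correlation_map(pixel_feats, template_indices):
    """Score each pixel by its best cosine match against any template."""
    templates = [pixel_feats[i] for i in template_indices]
    return [max(cosine(f, t) for t in templates) for f in pixel_feats]

# Toy flattened "feature map" with 4 pixels and 2 channels:
# pixels 0 and 1 look alike, pixels 2 and 3 look alike.
feats = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
templates = select_templates(feats, click_index=0, top_k=2)
scores = correlation_map(feats, templates)
```

In this toy case a click on pixel 0 picks pixels 0 and 1 as templates, and the resulting scores are high for those two pixels and low for pixels 2 and 3. A real model would operate on backbone feature tensors rather than Python lists.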

<p align="center"> <img src="assets/firstimage.jpg" width="90%" height="90%"> </p> <br/> <br/>

Environment setup

  1. Install the requirements by executing
     pip install -r requirements.txt
  2. Prepare the dataset and pretrained backbone weights following Data_Weight_Preparation.md.
<br/>

Evaluation

Download the pretrained checkpoints from Releases and put them into the weights directory.

Run

python -m torch.distributed.launch --master_port=4321 --nproc_per_node=8 scripts/evaluate_model.py EMC-Click \
    --model_dir='./weights/' \
    --checkpoint=hr18s.pth,hr18.pth,hr32.pth,segb0.pth,segb3.pth \
    --n-clicks=20 \
    --gpus=0,1,2,3,4,5,6,7 \
    --target-iou=0.9 \
    --thresh=0.5 \
    --eval-mode='emc-click' \
    --datasets=GrabCut,Berkeley,SBD,DAVIS,PascalVOC

to evaluate all models on the GrabCut, Berkeley, SBD, DAVIS, and PascalVOC datasets.
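The `--n-clicks`, `--target-iou`, and `--thresh` flags suggest the usual Number-of-Clicks (NoC) protocol from RITM-style evaluation: predicted probabilities are binarized at a threshold, and the score for a sample is the number of clicks needed to reach the target IoU, capped at the click budget. The sketch below shows that computation under this assumption; the helper names are hypothetical and the evaluation script itself may differ in detail.

```python
def iou(pred, gt):
    """Intersection over union of two flat binary masks (lists of bools)."""
    inter = sum(1 for p, g in zip(pred, gt) if p and g)
    union = sum(1 for p, g in zip(pred, gt) if p or g)
    return inter / union if union else 1.0

def binarize(probs, thresh=0.5):
    """Turn per-pixel probabilities into a binary mask (assumed role of --thresh)."""
    return [p > thresh for p in probs]

def number_of_clicks(probs_per_click, gt, target_iou=0.9, max_clicks=20, thresh=0.5):
    """Clicks needed to reach target_iou; max_clicks if it is never reached."""
    for n, probs in enumerate(probs_per_click[:max_clicks], start=1):
        if iou(binarize(probs, thresh), gt) >= target_iou:
            return n
    return max_clicks

# Toy 4-pixel sample: one list of predicted probabilities per click.
gt = [True, True, True, False]
preds = [
    [0.9, 0.2, 0.1, 0.8],  # after click 1: IoU = 1/4
    [0.9, 0.8, 0.3, 0.2],  # after click 2: IoU = 2/3
    [0.9, 0.8, 0.7, 0.1],  # after click 3: IoU = 1.0
]
noc = number_of_clicks(preds, gt)  # 3
```

Averaging this value over a dataset gives the NoC@90 figure commonly reported for interactive segmentation when the target IoU is 0.9.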

<br/>

Train

Run

CONFIG=models/emcclick/hrnet18s_att_cclvis.py
EXP_NAME=hrnet18s_att_cclvis
nGPUS=4
nBS=64
nWORKERS=4
PORT=`expr $RANDOM + 5000`

python -m torch.distributed.launch --nproc_per_node=$nGPUS --master_port=$PORT \
    train.py $CONFIG \
    --ngpus=$nGPUS \
    --workers=$nWORKERS \
    --batch-size=$nBS \
    --exp-name=$EXP_NAME

to train with the hrnet18s backbone.

You can find a template in ./trainval_scripts/train_xxx.sh.

<br/>

Acknowledgement

The code is implemented based on RITM and ClickSEG. We would like to express our sincere thanks to the contributors.

<br/>

License

The code is released under the MIT License. It is a short, permissive software license. Basically, you can do whatever you want as long as you include the original copyright and license notice in any copy of the software/source.

<br/>

Citation

If you find this work useful for your research, please cite our paper:

@inproceedings{emcclick,
  title={Efficient Mask Correction for Click-Based Interactive Image Segmentation},
  author={Du, Fei and Yuan, Jianlong and Wang, Zhibin and Wang, Fan},
  booktitle={CVPR},
  year={2023}
}