

<div align="center"> <h2> RSPrompter: Learning to Prompt for Remote Sensing Instance Segmentation based on Visual Foundation Model </h2> </div> <br> <div align="center"> <img src="resources/RSPrompter.png" width="800"/> </div> <br> <div align="center"> <a href="https://kychen.me/RSPrompter"> <span style="font-size: 20px; ">Project Page</span> </a> &nbsp;&nbsp;&nbsp;&nbsp; <a href="https://arxiv.org/abs/2306.16269"> <span style="font-size: 20px; ">arXiv</span> </a> &nbsp;&nbsp;&nbsp;&nbsp; <a href="https://huggingface.co/spaces/KyanChen/RSPrompter"> <span style="font-size: 20px; ">HFSpace</span> </a> </div> <br> <br>


<br> <br> <div align="center">

English | 简体中文

</div>

Introduction

This repository contains the code for the paper RSPrompter: Learning to Prompt for Remote Sensing Instance Segmentation based on Visual Foundation Model. It is built on the MMDetection project.

The current branch has been tested with PyTorch 2.x and CUDA 12.1, supports Python 3.7+, and is compatible with most CUDA versions.

If you find this project helpful, please give us a star ⭐️; your support is our greatest motivation.

<details open> <summary>Main Features</summary> </details>

Update Log

🌟 2023.06.29 Released the RSPrompter project, which implements SAM-seg, SAM-det, RSPrompter, and the other models from the paper on top of Lightning and MMDetection.

🌟 2023.11.25 Updated the RSPrompter code so that it is fully consistent with the MMDetection API and usage.

🌟 2023.11.26 Added LoRA for parameter-efficient fine-tuning and made the input image size configurable, which reduces the model's memory usage.

🌟 2023.11.26 Added reference memory usage figures for each model; see Common Problems for details.

🌟 2023.11.30 Updated the paper; see arXiv for details.

TODO

Table of Contents

Installation

Dependencies

Environment Installation

We recommend using Miniconda for installation. The following commands create a virtual environment named rsprompter and install PyTorch and MMCV.

Note: If you are familiar with PyTorch and already have it installed, you can skip to the next section. Otherwise, follow the steps below.

<details open>

Step 0: Install Miniconda.

Step 1: Create a virtual environment named rsprompter and activate it.

conda create -n rsprompter python=3.10 -y
conda activate rsprompter

Step 2: Install PyTorch 2.1.x.

Linux/Windows:

pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu121

Or

conda install pytorch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 pytorch-cuda=12.1 -c pytorch -c nvidia
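After Step 2, a quick sanity check can confirm that PyTorch was installed with CUDA support and can see a GPU. The sketch below is a hypothetical helper (`torch_env_summary` is not part of this project); it only reads `torch.__version__`, `torch.version.cuda`, and `torch.cuda.is_available()`, and returns `None` if PyTorch is missing.

```python
"""Hypothetical helper: summarize the installed PyTorch/CUDA setup."""
from typing import Optional


def torch_env_summary() -> Optional[dict]:
    """Return PyTorch/CUDA info, or None if PyTorch is not installed."""
    try:
        import torch  # imported lazily so the helper also runs without PyTorch
    except ImportError:
        return None
    return {
        "torch": torch.__version__,                    # version string of the installed build
        "cuda_build": torch.version.cuda,              # CUDA version PyTorch was built with (None for CPU builds)
        "cuda_available": torch.cuda.is_available(),   # True if a usable GPU and driver are found
    }


if __name__ == "__main__":
    print(torch_env_summary())
```

If `cuda_available` is `False` while a GPU is present, the installed wheel or driver likely does not match CUDA 12.1.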

Step 3: Install MMCV 2.1.x.

pip install -U openmim
mim install mmcv==2.1.0

Step 4: Install other dependencies.

pip install -U transformers==4.38.1 wandb==0.16.3 einops pycocotools shapely scipy terminaltables importlib peft==0.8.2 mat4py==0.6.0 mpi4py
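Because Steps 2 to 4 pin specific versions, it can help to compare what is actually installed against those pins. The following is a minimal sketch using only the standard library (`importlib.metadata`); the helper name `check_pins` and the `PINNED` table are illustrative assumptions, copied from the commands above.

```python
"""Hypothetical helper: compare installed package versions with the pins above."""
from importlib.metadata import PackageNotFoundError, version

# Pins copied from Steps 2-4; adjust if you installed different versions.
PINNED = {
    "torch": "2.1.2",
    "torchvision": "0.16.2",
    "torchaudio": "2.1.2",
    "mmcv": "2.1.0",
    "transformers": "4.38.1",
    "wandb": "0.16.3",
    "peft": "0.8.2",
    "mat4py": "0.6.0",
}


def check_pins(pins: dict) -> dict:
    """Map each package to (expected, installed) when they differ or it is missing."""
    problems = {}
    for name, expected in pins.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = None  # package is not installed at all
        # Compare only the public version, so "2.1.2+cu121" matches "2.1.2".
        if installed is None or installed.split("+")[0] != expected:
            problems[name] = (expected, installed)
    return problems


if __name__ == "__main__":
    for pkg, (want, have) in check_pins(PINNED).items():
        print(f"{pkg}: expected {want}, found {have}")
```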

Step 5: [Optional] Install DeepSpeed.

To train models with DeepSpeed, install it as shown below. For installation details, refer to the official DeepSpeed documentation.

pip install deepspeed==0.13.4

Note: DeepSpeed support on Windows is still incomplete, so we recommend using DeepSpeed on Linux.

</details>

Install RSPrompter

Download or clone the RSPrompter repository.

git clone git@github.com:KyanChen/RSPrompter.git
cd RSPrompter

Dataset Preparation

<details open>

Basic Instance Segmentation Dataset

We provide preparation instructions for the instance segmentation datasets used in the paper.

WHU Building Dataset

NWPU VHR-10 Dataset

SSDD Dataset

Note: The data folder of this project contains instance labels for the above datasets, which you can use directly.

Organization Method

You can also download the data from other sources, but the dataset must be organized in the following format:

${DATASET_ROOT} # Dataset root directory, for example: /home/username/data/NWPU
├── annotations
│   ├── train.json
│   ├── val.json
│   └── test.json
└── images
    ├── train
    ├── val
    └── test

Note: The data folder in this project contains examples of how the above datasets are organized.
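Before training, the layout above can be checked programmatically. This is a hypothetical sketch (`missing_entries` and `coco_summary` are illustrative names); `coco_summary` additionally assumes the annotation files use the COCO JSON format (top-level `images`, `annotations`, and `categories` lists), which is what MMDetection's COCO-style datasets expect.

```python
"""Hypothetical helpers: verify the dataset layout and summarize COCO-style labels."""
import json
from pathlib import Path

SPLITS = ("train", "val", "test")


def missing_entries(dataset_root) -> list:
    """Return the expected files/folders that are absent under dataset_root."""
    root = Path(dataset_root)
    missing = []
    for split in SPLITS:
        ann = root / "annotations" / f"{split}.json"
        img_dir = root / "images" / split
        if not ann.is_file():
            missing.append(str(ann.relative_to(root)))
        if not img_dir.is_dir():
            missing.append(str(img_dir.relative_to(root)))
    return missing


def coco_summary(ann_file) -> dict:
    """Count images, instances, and categories in a COCO-style annotation file."""
    with open(ann_file, encoding="utf-8") as f:
        coco = json.load(f)
    return {key: len(coco.get(key, [])) for key in ("images", "annotations", "categories")}


if __name__ == "__main__":
    root = "/home/username/data/NWPU"  # example root from the layout above
    print("missing:", missing_entries(root))
```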

Other Datasets

If you want to use other datasets, refer to the MMDetection documentation to prepare them.
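For a custom dataset organized as shown above, the dataset section of a config might look like the fragment below. This is a sketch assuming the MMDetection 3.x conventions (`CocoDataset`, `data_prefix`, `metainfo`, `CocoMetric`) and COCO-format annotations; `data_root` and the class names are hypothetical placeholders, and a real RSPrompter config may define additional keys (pipelines, samplers) that are inherited from its base files.

```python
# Sketch of an MMDetection 3.x style dataset section for a custom COCO-format
# dataset laid out as above. data_root and class names are placeholders.
data_root = '/home/username/data/MyDataset/'   # hypothetical dataset root
metainfo = dict(classes=('building',))          # hypothetical class list

train_dataloader = dict(
    batch_size=2,
    dataset=dict(
        type='CocoDataset',
        data_root=data_root,
        metainfo=metainfo,
        ann_file='annotations/train.json',      # relative to data_root
        data_prefix=dict(img='images/train/'),  # image folder relative to data_root
    ),
)
val_dataloader = dict(
    batch_size=1,
    dataset=dict(
        type='CocoDataset',
        data_root=data_root,
        metainfo=metainfo,
        ann_file='annotations/val.json',
        data_prefix=dict(img='images/val/'),
        test_mode=True,                         # evaluation split
    ),
)
val_evaluator = dict(
    type='CocoMetric',
    ann_file=data_root + 'annotations/val.json',
    metric=['bbox', 'segm'],                    # box and mask AP for instance segmentation
)
```

When such a file inherits from an existing config via `_base_`, only the keys that differ need to be specified; MMEngine merges nested dicts with the base.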

</details>

Model Training

SAM-based Model

Config File and Main Parameter Parsing

The configuration files for the SAM-based models used in the paper are located in the configs/rsprompter folder. They are fully consistent with the MMDetection config API and usage. Below we explain some of the main parameters; for the meaning of other parameters, refer to the MMDetection documentation.

<details open>

Parameter Parsing:

</details>

Single Card Training

python tools/train.py configs/rsprompter/xxx.py  # xxx.py is the configuration file you want to use

Multi-card Training

sh ./tools/dist_train.sh configs/rsprompter/xxx.py ${GPU_NUM}  # xxx.py is the configuration file you want to use, GPU_NUM is the number of GPUs used
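When launching many runs, it can be convenient to build these commands programmatically. The sketch below is hypothetical (`build_train_cmd` is not part of the project) and assumes, as in MMDetection, that tools/train.py accepts `--cfg-options key=value` overrides and that dist_train.sh forwards extra arguments to it. It invokes dist_train.sh with bash, which also avoids the "Bad substitution" error described in Common Problems.

```python
"""Hypothetical helper: build train commands with optional --cfg-options overrides."""
import shlex
from typing import Optional


def build_train_cmd(config: str, num_gpus: int = 1,
                    overrides: Optional[dict] = None) -> list:
    """Return an argv list for single- or multi-GPU training."""
    if num_gpus < 1:
        raise ValueError("num_gpus must be >= 1")
    if num_gpus == 1:
        cmd = ["python", "tools/train.py", config]
    else:
        # dist_train.sh relies on bash syntax, so run it with bash rather than sh.
        cmd = ["bash", "./tools/dist_train.sh", config, str(num_gpus)]
    if overrides:
        # MMDetection-style overrides, e.g. train_dataloader.batch_size=2
        cmd.append("--cfg-options")
        cmd.extend(f"{key}={value}" for key, value in overrides.items())
    return cmd


if __name__ == "__main__":
    argv = build_train_cmd("configs/rsprompter/xxx.py", 2,
                           {"train_dataloader.batch_size": 2})
    print(shlex.join(argv))  # paste into a shell or pass to subprocess.run(argv)
```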

Other Instance Segmentation Models

<details open>

To use other instance segmentation models, you can train them with MMDetection directly, or place their config files in this project's configs folder and train them as described above.

</details>

Model Testing

Single Card Testing:

python tools/test.py configs/rsprompter/xxx.py ${CHECKPOINT_FILE}  # xxx.py is the configuration file you want to use, CHECKPOINT_FILE is the checkpoint file you want to use

Multi-card Testing:

sh ./tools/dist_test.sh configs/rsprompter/xxx.py ${CHECKPOINT_FILE} ${GPU_NUM}  # xxx.py is the configuration file you want to use, CHECKPOINT_FILE is the checkpoint file you want to use, GPU_NUM is the number of GPUs used

Note: To obtain visualization results, uncomment the visualization entry of default_hooks in the config file.
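For reference, an enabled visualization hook might look like the fragment below. This is a sketch assuming MMDetection 3.x's `DetVisualizationHook`; the exact entry commented out in the RSPrompter configs may differ, and `test_out_dir='vis'` is a hypothetical folder name (the hook places it under the run's work directory).

```python
# Sketch of a default_hooks entry that saves visualizations during testing
# (MMDetection 3.x DetVisualizationHook). The output folder name is a placeholder.
default_hooks = dict(
    visualization=dict(
        type='DetVisualizationHook',
        draw=True,            # draw predictions on the test images
        interval=1,           # visualize every test image
        score_thr=0.3,        # skip predictions below this confidence
        test_out_dir='vis',   # hypothetical folder under the run's work directory
    ),
)
```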

Image Prediction

Single Image Prediction:

python demo/image_demo.py ${IMAGE_FILE}  configs/rsprompter/xxx.py --weights ${CHECKPOINT_FILE} --out-dir ${OUTPUT_DIR}  # IMAGE_FILE is the image file you want to predict, xxx.py is the configuration file you want to use, CHECKPOINT_FILE is the checkpoint file you want to use, OUTPUT_DIR is the output path of the prediction result

Multi-image Prediction:

python demo/image_demo.py ${IMAGE_DIR}  configs/rsprompter/xxx.py --weights ${CHECKPOINT_FILE} --out-dir ${OUTPUT_DIR}  # IMAGE_DIR is the image folder you want to predict, xxx.py is the configuration file you want to use, CHECKPOINT_FILE is the checkpoint file you want to use, OUTPUT_DIR is the output path of the prediction result

Common Problems

<details open>

We list some common problems and their solutions here. If you find that a problem is missing, feel free to open a PR to extend this list. If you cannot find help here, please open an issue and fill in all the required information in the template, which will help us locate the problem faster.

1. Do I need to install MMDetection?

We recommend not installing MMDetection separately: we have modified parts of the MMDetection code, and a separately installed MMDetection may conflict with these modifications and cause errors. If you encounter an error saying that a module has not been registered, please check:
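One way to see which `mmdet` package Python actually resolves is to ask the import system where it would load it from, without importing it. The sketch below is hypothetical and assumes the project's modified MMDetection code lives in an `mmdet` package at the repository root; run it from that root.

```python
"""Hypothetical check: which mmdet package will Python import?"""
import importlib.util
from pathlib import Path
from typing import Optional


def package_origin(name: str) -> Optional[Path]:
    """Return the file a top-level package would be imported from, or None."""
    spec = importlib.util.find_spec(name)
    if spec is None or spec.origin is None:
        return None
    return Path(spec.origin).resolve()


def imported_from_repo(name: str, repo_root: str) -> bool:
    """True if `name` resolves to a file inside repo_root (not site-packages)."""
    origin = package_origin(name)
    if origin is None:
        return False
    return Path(repo_root).resolve() in origin.parents


if __name__ == "__main__":
    # An origin under site-packages suggests a separately installed MMDetection
    # is shadowing the project's modified code.
    print(package_origin("mmdet"), imported_from_repo("mmdet", "."))
```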

2. How to evaluate the model after training with DeepSpeed?

We recommend training with DeepSpeed, since it can greatly speed up training. However, DeepSpeed trains and saves checkpoints differently from MMDetection, so a model trained with DeepSpeed must be converted before it can be evaluated with MMDetection. Specifically, you need to:

python zero_to_fp32.py . $SAVE_CHECKPOINT_NAME -t $CHECKPOINT_DIR  # $SAVE_CHECKPOINT_NAME is the file name of the converted model, $CHECKPOINT_DIR is the name of the checkpoint folder saved by DeepSpeed
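The `-t` argument names the checkpoint folder (DeepSpeed calls it a tag). The sketch below is a hypothetical helper that resolves this tag, assuming the standard DeepSpeed layout in which a `latest` file records the most recent tag; if your run did not write that file, pass the tag explicitly. The returned command is meant to be run from inside the checkpoint directory, mirroring the `.` in the command above.

```python
"""Hypothetical helper: find the tag to pass to zero_to_fp32.py via -t."""
from pathlib import Path
from typing import Optional


def resolve_tag(checkpoint_dir: str, tag: Optional[str] = None) -> str:
    """Return an explicit tag, else the one recorded in DeepSpeed's 'latest' file."""
    root = Path(checkpoint_dir)
    if tag is None:
        latest = root / "latest"
        if not latest.is_file():
            raise FileNotFoundError(f"no 'latest' file in {root}; pass the tag explicitly")
        tag = latest.read_text().strip()
    if not (root / tag).is_dir():
        raise FileNotFoundError(f"checkpoint folder {root / tag} does not exist")
    return tag


def zero_to_fp32_cmd(checkpoint_dir: str, output_file: str,
                     tag: Optional[str] = None) -> list:
    """Argv to run from inside checkpoint_dir: python zero_to_fp32.py . <out> -t <tag>."""
    return ["python", "zero_to_fp32.py", ".", output_file,
            "-t", resolve_tag(checkpoint_dir, tag)]
```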

3. About resource consumption

Here we list the resource consumption of using different models for your reference.

| Model Name | Backbone | Image Size | GPU | Batch Size | Acceleration Strategy | Single Card Memory Usage |
|---|---|---|---|---|---|---|
| SAM-seg (Mask R-CNN) | ViT-B/16 | 1024x1024 | 1x RTX 4090 24G | 8 | AMP FP16 | 19.4 GB |
| SAM-seg (Mask2Former) | ViT-B/16 | 1024x1024 | 1x RTX 4090 24G | 8 | AMP FP16 | 21.5 GB |
| SAM-det | ResNet50 | 1024x1024 | 1x RTX 4090 24G | 8 | FP32 | 16.6 GB |
| RSPrompter-anchor | ViT-B/16 | 1024x1024 | 1x RTX 4090 24G | 2 | AMP FP16 | 20.9 GB |
| RSPrompter-query | ViT-B/16 | 1024x1024 | 1x RTX 4090 24G | 1 | AMP FP16 | OOM |
| RSPrompter-query | ViT-B/16 | 1024x1024 | 8x NVIDIA A100 40G | 1 | ZeRO-2 | 39.6 GB |
| RSPrompter-anchor | ViT-B/16 | 512x512 | 8x RTX 4090 24G | 4 | AMP FP16 | 20.9 GB |
| RSPrompter-query | ViT-B/16 | 512x512 | 8x RTX 4090 24G | 2 | ZeRO-2 | 21.1 GB |

Note: Low-resolution input images can effectively reduce the model's memory usage, but their effect on accuracy has not been verified. For details, please refer to the corresponding config file.
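Part of the memory saving from smaller inputs can be seen with simple arithmetic on a ViT-B/16 backbone, which splits the image into non-overlapping 16x16 patches. The worked example below only counts patch tokens and the entries of one full self-attention map; it does not model weights, optimizer states, or other buffers, so the measured savings in the table above are smaller than these ratios.

```python
"""Worked arithmetic: how input size changes the ViT-B/16 token count."""


def num_patch_tokens(image_size: int, patch_size: int = 16) -> int:
    """Patch tokens produced by a ViT with non-overlapping square patches."""
    if image_size % patch_size:
        raise ValueError("image_size must be divisible by patch_size")
    return (image_size // patch_size) ** 2


def global_attention_entries(image_size: int, patch_size: int = 16) -> int:
    """Entries in one full (global) self-attention map: tokens x tokens."""
    n = num_patch_tokens(image_size, patch_size)
    return n * n


if __name__ == "__main__":
    for size in (1024, 512):
        print(size, num_patch_tokens(size), global_attention_entries(size))
    # 1024x1024 -> 4096 tokens, 512x512 -> 1024 tokens: 4x fewer tokens and
    # 16x fewer entries per global attention map at 512x512.
```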

4. Solution to dist_train.sh: Bad substitution

If you encounter the error Bad substitution when running dist_train.sh, run the script with bash instead of sh, i.e. bash dist_train.sh.

5. Unable to access and download the model on HuggingFace Spaces

If you are unable to access and download the model on HuggingFace Spaces, please download it with the download script, following the official instructions.


6. The segmentation loss is always 0 or becomes NaN

This is usually caused by unstable training due to a small batch size. You can try any one of the following solutions:

  1. Increase the batch size to 2 or 4 (this may exceed the available GPU memory);

  2. Use gradient accumulation (modify optim_wrapper in the config file):

optim_wrapper = dict(
    type='AmpOptimWrapper',
    dtype='float16',  # Change to 'bfloat16' for more stable training
    optimizer=dict(
        type='AdamW',
        lr=base_lr,
        weight_decay=0.05),
    accumulative_counts=4  # Added key: accumulate gradients over 4 iterations (any integer > 1 works)
)
  3. Disable the sine-cosine transformation in the Prompter during decoding (set with_sincos=False in the config file);

  4. Use a peft configuration with an input image size of 512x512 and increase the batch size.
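Solutions 1, 2, and 4 all enlarge the number of samples behind each optimizer step. With gradient accumulation, gradients from `accumulative_counts` iterations are combined before the weights are updated, so the effective batch size is the product below. This is a small illustrative helper (`effective_batch_size` is a hypothetical name); note that layers computing per-iteration statistics, such as batch normalization, still see only the per-GPU batch.

```python
def effective_batch_size(per_gpu_batch: int, num_gpus: int = 1,
                         accumulative_counts: int = 1) -> int:
    """Samples contributing to each optimizer step with gradient accumulation."""
    for value in (per_gpu_batch, num_gpus, accumulative_counts):
        if value < 1:
            raise ValueError("all arguments must be >= 1")
    return per_gpu_batch * num_gpus * accumulative_counts


if __name__ == "__main__":
    # Batch size 1 on one GPU with accumulative_counts=4 updates like batch size 4.
    print(effective_batch_size(1, 1, 4))
```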

</details>

Acknowledgement

This project is built on the MMDetection project. We thank the MMDetection developers.

Citation

If you use the code or performance benchmarks of this project in your research, please cite RSPrompter using the BibTeX entry below.

@article{chen2024rsprompter,
  title={RSPrompter: Learning to prompt for remote sensing instance segmentation based on visual foundation model},
  author={Chen, Keyan and Liu, Chenyang and Chen, Hao and Zhang, Haotian and Li, Wenyuan and Zou, Zhengxia and Shi, Zhenwei},
  journal={IEEE Transactions on Geoscience and Remote Sensing},
  year={2024},
  publisher={IEEE}
}

License

This project is licensed under the Apache 2.0 license.

Contact

If you have any other questions❓, feel free to contact us 👬