
Sam2Rad: A Segmentation Model for Medical Images with Learnable Prompts

Requirements

  1. Clone the Repository
git clone https://github.com/aswahd/SamRadiology.git
cd SamRadiology
  2. Set Up a Virtual Environment. Using a virtual environment to manage dependencies is recommended.
python3 -m venv .venv
source .venv/bin/activate
  3. Install Dependencies
pip install -r requirements.txt
  4. Download Pre-trained Weights. Download the pre-trained weights from the official SAM repository and place them in the weights directory (the example training configuration below expects weights/sam2_hiera_tiny.pt).

Quickstart

File structure (each mask in gts has the same filename as its image in imgs):

root
├── Train
│   ├── imgs
│   │   ├── 1.png
│   │   ├── 2.png
│   │   └── ...
│   └── gts
│       ├── 1.png
│       ├── 2.png
│       └── ...
└── Test
    ├── imgs
    │   ├── 1.png
    │   ├── 2.png
    │   └── ...
    └── gts
        ├── 1.png
        ├── 2.png
        └── ...
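Before training, it can help to confirm that every image has a matching mask. The sketch below assumes the layout shown above (images in <split>/imgs, masks in <split>/gts, paired by identical PNG filenames); pair_images_and_masks is a hypothetical helper, not part of Sam2Rad.

```python
from pathlib import Path


def pair_images_and_masks(split_dir: str) -> list[tuple[Path, Path]]:
    """Return (image, mask) pairs for one split directory such as root/Train.

    Hypothetical helper: assumes images live in <split>/imgs and masks in
    <split>/gts, matched by identical filenames.
    """
    split = Path(split_dir)
    image_dir, mask_dir = split / "imgs", split / "gts"
    pairs: list[tuple[Path, Path]] = []
    missing: list[str] = []
    for image_path in sorted(image_dir.glob("*.png")):
        mask_path = mask_dir / image_path.name  # same filename in gts
        if mask_path.exists():
            pairs.append((image_path, mask_path))
        else:
            missing.append(image_path.name)
    if missing:
        # Fail loudly instead of silently dropping unlabeled images.
        raise FileNotFoundError(f"No mask in {mask_dir} for: {', '.join(missing)}")
    return pairs
```

For example, pair_images_and_masks("/path/to/your/dataset/Train") returns the sorted list of pairs or names the images that lack a mask.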

Download Sample Dataset:

Models

Sam2Rad supports several image encoders and mask decoders, which you select by name with the image_encoder and mask_decoder keys of the configuration file (see Training).

Supported Image Encoders

All supported image encoders are available in sam2rad/encoders/build_encoder.py.

Supported Mask Decoders

All supported mask decoders are available in sam2rad/decoders/build_decoder.py.
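Choosing a component by a name string in the configuration is commonly implemented with a lookup table from names to builder functions. The sketch below illustrates that general pattern for encoders only; it is hypothetical and not taken from sam2rad/encoders/build_encoder.py. ENCODER_REGISTRY, register_encoder, build_sam2_tiny_hiera_adapter, and build_encoder_by_name are illustrative names; only the string "sam2_tiny_hiera_adapter" comes from the example configuration.

```python
from typing import Callable, Dict

# Hypothetical registry mapping config names (the image_encoder key)
# to builder functions. Not Sam2Rad's actual code.
ENCODER_REGISTRY: Dict[str, Callable[..., object]] = {}


def register_encoder(name: str):
    """Decorator that records a builder function under a config name."""
    def decorator(builder: Callable[..., object]) -> Callable[..., object]:
        if name in ENCODER_REGISTRY:
            raise ValueError(f"Encoder {name!r} is already registered")
        ENCODER_REGISTRY[name] = builder
        return builder
    return decorator


@register_encoder("sam2_tiny_hiera_adapter")
def build_sam2_tiny_hiera_adapter(checkpoint: str) -> dict:
    # Placeholder: a real builder would construct the network and load weights.
    return {"encoder": "sam2_tiny_hiera_adapter", "checkpoint": checkpoint}


def build_encoder_by_name(name: str, **kwargs) -> object:
    """Look up a registered builder by name and call it with kwargs."""
    try:
        builder = ENCODER_REGISTRY[name]
    except KeyError:
        available = ", ".join(sorted(ENCODER_REGISTRY))
        raise ValueError(f"Unknown encoder {name!r}; available: {available}") from None
    return builder(**kwargs)
```

In this sketch, registering a builder under a new name is enough to make it selectable from the configuration file, and an unknown name fails with the list of valid choices.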

Training

Prepare a configuration file for training. Here is an example configuration file for training on the ACDC dataset:

image_size: 1024
image_encoder: "sam2_tiny_hiera_adapter"
mask_decoder: "sam2_lora_mask_decoder"
sam_checkpoint: "weights/sam2_hiera_tiny.pt"
wandb_project_name: "ACDC"

dataset:
  name: acdc
  root: /path/to/your/dataset
  image_size: 1024
  split: 0.0526  # training split (alternative value: 0.0263)
  seed: 42
  batch_size: 4
  num_workers: 4
  num_classes: 3
  num_tokens: 10

training:
  max_epochs: 200
  save_path: checkpoints/ACDC

inference:
  name: acdc_test
  root: /path/to/your/test_data
  checkpoint_path: /path/to/your/checkpoint

Then activate the environment and start training:

source .venv/bin/activate
CUDA_VISIBLE_DEVICES=0 python train.py --config /path/to/your/config.yaml

Replace /path/to/your/config.yaml with the actual path to your configuration file.
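The split and seed keys suggest that training uses a reproducible fraction of the data. The sketch below shows one way such a split could work, assuming split is the fraction of images kept for training; the README does not state this, and split_files is a hypothetical helper, not Sam2Rad's implementation.

```python
import random
from typing import List, Sequence, Tuple


def split_files(
    files: Sequence[str], train_fraction: float, seed: int
) -> Tuple[List[str], List[str]]:
    """Deterministically split file names into (train, held_out) lists.

    Hypothetical reading of dataset.split / dataset.seed: train_fraction is
    the share of files used for training, and the same seed always gives
    the same split.
    """
    if not 0.0 < train_fraction <= 1.0:
        raise ValueError("train_fraction must be in (0, 1]")
    shuffled = sorted(files)  # sort first so the result ignores input order
    random.Random(seed).shuffle(shuffled)  # local RNG, global state untouched
    n_train = max(1, round(len(shuffled) * train_fraction))
    return shuffled[:n_train], shuffled[n_train:]
```

Using a dedicated random.Random(seed) instead of the global generator keeps the split identical across runs even if other code consumes random numbers.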

Evaluation

Ensure your configuration file points to the correct checkpoint and data paths:

inference:
  model_checkpoint: checkpoints/your_model_checkpoint
  input_images: /path/to/your/test_images
  output_dir: /path/to/save/segmentation_results
  image_size: 1024

Run the evaluation scripts:

python sam2rad/evaluation/eval_bounding_box.py --config /path/to/your/config.yaml
python sam2rad/evaluation/eval_prompt_learner.py --config /path/to/your/config.yaml
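The README does not list the metrics these scripts report. Dice is a common overlap metric for medical image segmentation, so the sketch below computes it per class. It assumes masks are integer label maps with 0 as background and 1..num_classes as foreground classes (reading num_classes: 3 from the ACDC example this way is also an assumption); dice_per_class is a hypothetical helper, not code from sam2rad/evaluation.

```python
import numpy as np


def dice_per_class(pred: np.ndarray, target: np.ndarray, num_classes: int) -> np.ndarray:
    """Dice coefficient for each foreground class of two integer label maps.

    Assumes labels 0 (background) and 1..num_classes (foreground).
    A class absent from both maps scores 1.0 by convention.
    """
    if pred.shape != target.shape:
        raise ValueError("pred and target must have the same shape")
    scores = np.empty(num_classes, dtype=np.float64)
    for c in range(1, num_classes + 1):
        p = pred == c
        t = target == c
        denom = p.sum() + t.sum()
        # Dice = 2|P ∩ T| / (|P| + |T|)
        scores[c - 1] = 1.0 if denom == 0 else 2.0 * np.logical_and(p, t).sum() / denom
    return scores
```

Averaging the returned vector gives a single mean Dice per image; averaging per class across images shows which structures are hardest to segment.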

Citation

If you use Sam2Rad in your research, please consider citing our paper:

@article{wahd2024sam2radsegmentationmodelmedical,
  title={Sam2Rad: A Segmentation Model for Medical Images with Learnable Prompts},
  author={Assefa Seyoum Wahd and Banafshe Felfeliyan and Yuyue Zhou and Shrimanti Ghosh and Adam McArthur and Jiechen Zhang and Jacob L. Jaremko and Abhilash Hareendranathan},
  year={2024},
  eprint={2409.06821},
  archivePrefix={arXiv},
  primaryClass={cs.CV},
  url={https://arxiv.org/abs/2409.06821},
}