Awesome
Conditional Diffusion Models for Weakly Supervised Medical Image Segmentation [MICCAI 2023]
[Paper
]
0. Setup
Environment
the required packages are mostly same as openai/improved_diffusion Clone this repository and navigate to it in your terminal. Then run:
pip install -e .
Dataset
supported 2d datasets: Synapse(128x128), BraTS(224*224) The data folder structure is like:
.
├── ...
├── data
│ ├── brats_patch
│ ├── flair
├── flair
├── training
├── normal
├── tumor
├── seg # label mask
├── validation
1. train DDPM models
for tumor images:
MODEL_FLAGS="--num_channels 128 --num_res_blocks 3 "
DIFFUSION_FLAGS="--diffusion_steps 4000 " # --use_kl True
TRAIN_FLAGS="--lr 1e-4 --batch_size 4" #--schedule_sampler loss-second-moment
python scripts/train_tumor_ddfm.py --save_dir runs/results/diff_brats \
--src_dir ${data_dir} --dataset brats \
$MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS
conditional training:
MODEL_FLAGS="--image_size 64 --num_channels 192 --num_res_blocks 3 --learn_sigma True --class_cond True"
DIFFUSION_FLAGS="--diffusion_steps 4000 --noise_schedule cosine --rescale_learned_sigmas False --rescale_timesteps False"
TRAIN_FLAGS="--lr 1e-4 --batch_size 4"
python scripts/train_tumor_ddfm.py --save_dir runs/results/diff_brats \
--src_dir ${data_dir} --dataset brats \
$MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS
2. sampling from gaussian noise
The main file is scripts/image_tumor_sample.py, for different dataset, you need:
python scripts/image_tumor_sample.py
3. guided diffusion
Use an additional classifier model to guide the generation of samples, no need to retrain the DDPM model
Training models
Training a classifier
python scripts/classifier_tumor_train.py --save_dir runs/results/guided_diff_brats_cls_resnet \
--src_dir ${data_dir} --lr 3e-4 --batch_size 8 --model_type resnet50
Sampling:
python scripts/classifier_tumor_sample.py --save_dir runs/results/guided_diff_brats_cls_resnet \
--src_dir ${data_dir}
4. Weakly supervised segmentation
MODEL_FLAGS="--num_channels 128 --num_res_blocks 3 --learn_sigma True --class_cond True " # --learn_sigma True --class_cond True
DIFFUSION_FLAGS="--diffusion_steps 4000 "
TRAIN_FLAGS="--lr 3e-4 --batch_size 2"
python scripts/image_p_seg.py \
--save_dir {where_you_save_classifier} \
--model_path {where_you_save_diffusion} \
$MODEL_FLAGS $DIFFUSION_FLAGS -f 0 --batch_size 1 --guided