CriDiff
The official implementation of the MICCAI 2024 paper CriDiff: Criss-cross Injection Diffusion Framework via Generative Pre-train for Prostate Segmentation.
Environment Installation
conda create -n CriDiff python=3.8 -y
conda activate CriDiff
git clone https://github.com/LiuTingWed/CriDiff.git
cd CriDiff
pip install -r requirements.txt
Dataset Preparation
Download Datasets
Four datasets need to be downloaded (NCI-ISBI, ProstateX, Promise12, CCH-TRUSPS) from:
Google Drive | Baidu Drive (6666)
I'm not sure about the copyright status of these datasets. If you own any of them, please open an issue to let me know so that I can remove them accordingly.
Check that the data directory is organized like this:
The body and detail are generated by extract_boundary/generate_body_detail.py.
See that script for more details.
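As a rough illustration of what a body/detail split can look like, here is a minimal sketch. It assumes a distance-transform style decomposition, where the body map is the normalized distance to the nearest background pixel and the detail map is the remainder. This is an assumption for illustration only; `split_body_detail` is a hypothetical helper, and extract_boundary/generate_body_detail.py may use a different procedure.

```python
import math


def split_body_detail(mask):
    """Split a binary mask (list of rows of 0/1) into body and detail maps.

    Assumption (not taken from the repository): body is the distance to the
    nearest background pixel scaled to [0, 1], and detail = mask - body.
    """
    h, w = len(mask), len(mask[0])
    background = [(y, x) for y in range(h) for x in range(w) if mask[y][x] == 0]
    dist = [[0.0] * w for _ in range(h)]
    for y in range(h):
        for x in range(w):
            if mask[y][x] and background:
                # Brute-force Euclidean distance; only suitable for toy masks.
                dist[y][x] = min(math.hypot(y - by, x - bx) for by, bx in background)
    peak = max(max(row) for row in dist) or 1.0  # avoid division by zero
    body = [[d / peak for d in row] for row in dist]
    detail = [[mask[y][x] - body[y][x] for x in range(w)] for y in range(h)]
    return body, detail


if __name__ == "__main__":
    toy_mask = [
        [0, 0, 0, 0, 0],
        [0, 1, 1, 1, 0],
        [0, 1, 1, 1, 0],
        [0, 1, 1, 1, 0],
        [0, 0, 0, 0, 0],
    ]
    body, detail = split_body_detail(toy_mask)
    for row in body:
        print(row)
```

In this sketch the body map is largest in the interior of the region and the detail map is largest near the boundary, and the two always sum back to the original mask.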
Download Pre-trained Weights
Training & Inference & Evaluation
Generative pretrain
This stage relies on accelerate; please install and configure it first.
python generative_pretrain/train_generator_accelerate.py --dataset_root xxx/DATASET_NAME/images/train
Training
Before training, please check --dataset_root, --cp_condition_net, --cp_stage1, and --checkpoint_save_dir in train.py.
python -m torch.distributed.launch --nproc_per_node=2 train.py
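With torch.distributed.launch, each spawned process has to learn its local rank. Older PyTorch versions pass it as `--local_rank`, PyTorch 2.0 and later pass `--local-rank`, and torchrun sets the `LOCAL_RANK` environment variable instead. The sketch below shows one way a script can accept all three. Whether train.py already does this is not stated here, and `parse_local_rank` is a hypothetical helper, not part of the repository.

```python
import argparse
import os


def parse_local_rank(argv=None):
    """Read the local rank from the CLI, falling back to the LOCAL_RANK env var."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--local_rank", "--local-rank",  # accept both spellings
        dest="local_rank",
        type=int,
        default=int(os.environ.get("LOCAL_RANK", 0)),
    )
    # parse_known_args ignores any other training flags on the command line.
    args, _unknown = parser.parse_known_args(argv)
    return args.local_rank


if __name__ == "__main__":
    print("local rank:", parse_local_rank())
```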
Why can't the model perform training and validation simultaneously?
The output of a diffusion model depends on the randomly sampled noise: different noise leads to different outputs. I have not resolved the fluctuation in model performance between the training and validation stages; for a detailed description, please refer to this link. I therefore recommend saving all checkpoints and then validating them on two separate GPUs, so that others can also achieve consistent performance. I hope someone smarter than me can tell me why :-).
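The dependence on sampled noise can be illustrated with a toy example. The sketch below is not the CriDiff sampler: `toy_sample` is a hypothetical stand-in that starts from Gaussian noise and drifts toward a conditioning vector. It only shows that a fixed seed reproduces a sample exactly, that a different seed changes it, and that averaging several samples is one common way to reduce the spread. In PyTorch the analogous step would be to seed a `torch.Generator` and pass it to the noise-sampling calls; whether that removes the fluctuation described above is not something this README confirms.

```python
import random


def toy_sample(condition, seed, steps=10):
    """Toy stand-in for a diffusion sampler (assumption, not the real model)."""
    rng = random.Random(seed)  # private generator, so results depend only on seed
    x = [rng.gauss(0.0, 1.0) for _ in condition]  # start from pure noise
    for _ in range(steps):
        # Move halfway toward the condition and inject a little fresh noise.
        x = [0.5 * xi + 0.5 * ci + 0.1 * rng.gauss(0.0, 1.0)
             for xi, ci in zip(x, condition)]
    return x


def ensemble_mean(condition, seeds):
    """Average several samples drawn with different seeds."""
    samples = [toy_sample(condition, s) for s in seeds]
    return [sum(vals) / len(vals) for vals in zip(*samples)]


if __name__ == "__main__":
    cond = [0.0, 1.0, 1.0, 0.0]
    print(toy_sample(cond, seed=0) == toy_sample(cond, seed=0))  # same seed
    print(toy_sample(cond, seed=0) == toy_sample(cond, seed=1))  # different seed
    print(ensemble_mean(cond, seeds=range(5)))
```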
Inference
After training, the directory --checkpoint_save_dir/job_name will contain many .pth files.
Check --loadDir, --loadDer_cp, and --dataset_root in infer_allCp_xxxx.py, then run it.
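Since every saved checkpoint is evaluated, it can help to list them in training order first. The following is a minimal sketch, assuming each .pth file name ends in a step or epoch number (for example `model_10.pth`, a made-up name). `list_checkpoints` is a hypothetical helper and not part of infer_allCp_xxxx.py.

```python
import re
from pathlib import Path


def list_checkpoints(checkpoint_dir):
    """Return the .pth files in checkpoint_dir, sorted by the last integer in the name."""
    def sort_key(path):
        numbers = re.findall(r"\d+", path.stem)
        # Files without a number sort first; ties are broken by file name.
        return (int(numbers[-1]) if numbers else -1, path.name)

    return sorted(Path(checkpoint_dir).glob("*.pth"), key=sort_key)


if __name__ == "__main__":
    for ckpt in list_checkpoints("."):
        print(ckpt)
```

Sorting by the numeric suffix avoids the lexicographic ordering in which `model_10.pth` would come before `model_2.pth`.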
Evaluation
The predictions of CriDiff are available at this link; run eval_dice_iou_hd95_asd/eval.py to evaluate them.
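For readers unfamiliar with the overlap metrics named in the script path, here is a minimal sketch of Dice and IoU on binary masks. It is not the repository's eval.py: `dice_and_iou` is a hypothetical helper, masks are given as flat sequences of 0/1, and scoring two empty masks as a perfect match is a convention chosen for this sketch. HD95 and ASD are not covered here.

```python
def dice_and_iou(pred, target):
    """Compute (Dice, IoU) for two binary masks given as flat sequences of 0/1."""
    pred = [int(bool(p)) for p in pred]
    target = [int(bool(t)) for t in target]
    if len(pred) != len(target):
        raise ValueError("pred and target must have the same number of pixels")

    intersection = sum(p & t for p, t in zip(pred, target))
    total = sum(pred) + sum(target)
    if total == 0:
        # Both masks are empty; treated as a perfect match in this sketch.
        return 1.0, 1.0

    union = total - intersection
    return 2.0 * intersection / total, intersection / union


if __name__ == "__main__":
    dice, iou = dice_and_iou([1, 1, 0, 0], [1, 0, 1, 0])
    print(f"Dice={dice:.3f}  IoU={iou:.3f}")
```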
Thanks
This repository builds on med-seg-diff-pytorch and denoising-diffusion-pytorch. These concise diffusion frameworks were very helpful to me.
Citation
@inproceedings{liu2024cridiff,
title={CriDiff: Criss-cross Injection Diffusion Framework via Generative Pre-train for Prostate Segmentation},
author={Liu, Tingwei and Zhang, Miao and Liu, Leiye and Zhong, Jialong and Wang, Shuyao and Piao, Yongri and Lu, Huchuan},
booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
pages={102--112},
year={2024},
organization={Springer}
}