Debiasing Cardiac Imaging with Controlled Latent Diffusion Models

Code repository for the paper "Debiasing Cardiac Imaging with Controlled Latent Diffusion Models" (arXiv pre-print).

Please note that this codebase is still a work in progress, and some parts may be missing.

Diffusion model training is based on the well-documented ControlNet repository.

Setup

Create and activate a new conda environment with:

conda env create -f environment.yaml
conda activate debiasing-cardiac-mri

ControlNet model training

Before training, make sure the pretrained Stable Diffusion 2.1 checkpoint "v2-1_512-ema-pruned.ckpt" has been downloaded. After the download, run

python tool_add_control_sd21.py

to attach the ControlNet branch to the vanilla Stable Diffusion model.

To train the Stable Diffusion model with ControlNet, run:

python train.py

For more detailed instructions, see the ControlNet training guide: https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md

Synthetic dataset generation

Once the diffusion model has been fine-tuned, you can generate the unbiased synthetic dataset with the following script:

python generate_synthetic_dataset.py

You can choose one of the generation methods:

Interactive generation

You can run a Gradio app to host your model and generate images interactively from different prompts and masks:

python gradio_mask2image.py

In this application, cardiac masks are stacked as RGB images, so the raw output from the model is also an RGB image. The two columns on the right display the unstacked end-diastolic (ED) and end-systolic (ES) frames.
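The stacking and unstacking described above can be sketched as follows. This is only an illustration: the channel order (ED in channel 0, ES in channel 1, an unused third channel) and the function names `stack_masks` and `unstack_frames` are assumptions, not taken from the repository.

```python
import numpy as np


def stack_masks(ed_mask, es_mask):
    """Stack two 2D masks into one (H, W, 3) uint8 image.

    Assumed layout (hypothetical): channel 0 = ED, channel 1 = ES,
    channel 2 = zeros.
    """
    empty = np.zeros_like(ed_mask)
    return np.stack([ed_mask, es_mask, empty], axis=-1).astype(np.uint8)


def unstack_frames(rgb):
    """Split an (H, W, 3) image back into its ED and ES channels."""
    return rgb[..., 0], rgb[..., 1]
```

If the application uses a different channel order, only the indices in `unstack_frames` would need to change.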

Downstream task: Classification model training

The framework supports training classification models on different cardiac MRI inputs.

Supported training inputs:

Fairness

Training

The framework supports weighted sampling based on one or more sensitive attributes, including sex, age, and BMI.
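One common way to implement this kind of weighted sampling is to give each sample a weight inversely proportional to the size of its subgroup, where a subgroup is a combination of the chosen sensitive attributes. The sketch below assumes that scheme; the record layout and the name `subgroup_weights` are hypothetical, and the repository's actual weighting may differ.

```python
from collections import Counter


def subgroup_weights(records, attributes):
    """Return one sampling weight per record.

    Each subgroup (a unique combination of the given attributes) receives
    the same total probability mass, shared equally among its members.
    """
    keys = [tuple(record[a] for a in attributes) for record in records]
    counts = Counter(keys)
    n_groups = len(counts)
    return [1.0 / (n_groups * counts[key]) for key in keys]


# Hypothetical toy records; real attribute names and values may differ.
records = [
    {"sex": "F", "age_group": "old"},
    {"sex": "M", "age_group": "old"},
    {"sex": "M", "age_group": "old"},
    {"sex": "M", "age_group": "young"},
]
weights = subgroup_weights(records, ["sex"])
```

In a PyTorch pipeline, weights like these could be passed to `torch.utils.data.WeightedRandomSampler`; whether this codebase does so is an assumption.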

Evaluation

Each trained model is evaluated in terms of fairness with the following metrics:

In addition, each sensitive subgroup is evaluated independently with standard performance metrics.

In the paper, we focus on the balanced accuracy metric computed for each subpopulation.

Fairness metrics are provided by the fairlearn library.
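To make the per-subpopulation balanced accuracy concrete, here is a minimal plain-Python sketch. Balanced accuracy is the mean of the per-class recalls, and the sketch computes it separately for each subgroup. The function names and the toy usage are illustrative and are not the repository's evaluation code.

```python
from collections import defaultdict


def balanced_accuracy(y_true, y_pred):
    """Mean of per-class recall over the classes present in y_true."""
    hits, totals = defaultdict(int), defaultdict(int)
    for truth, pred in zip(y_true, y_pred):
        totals[truth] += 1
        hits[truth] += int(truth == pred)
    return sum(hits[c] / totals[c] for c in totals) / len(totals)


def balanced_accuracy_by_group(y_true, y_pred, groups):
    """Compute balanced accuracy independently for each subgroup label."""
    by_group = defaultdict(lambda: ([], []))
    for truth, pred, group in zip(y_true, y_pred, groups):
        by_group[group][0].append(truth)
        by_group[group][1].append(pred)
    return {g: balanced_accuracy(t, p) for g, (t, p) in by_group.items()}
```

The fairlearn `MetricFrame` class offers a comparable per-group breakdown through its `by_group` attribute; how exactly this repository calls it is not shown here.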

Usage

First, set up config.yaml with the training data type, the disease to predict, the batch size, etc.

Then, you can run training with:

python cardioai/training.py

Optionally, you can specify the GPU ID to train on and the directory in which to store the experiment logs:

python cardioai/training.py --gpu_id 0 --experiment_dir ./logs

After training, a model evaluation report is generated in the experiment directory.

Results reproduction

To reproduce the experiments from the paper and obtain full performance and fairness reports, with standard deviations over 8 repeated runs, run:

./repeated.sh
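Aggregating repeated runs into a mean and a standard deviation per metric can be sketched as below. The input layout (one dict of metric values per run) and the name `summarize_runs` are assumptions, and the sketch uses the sample standard deviation (n - 1 in the denominator), which may not match the convention used by repeated.sh.

```python
from statistics import mean, stdev


def summarize_runs(runs):
    """Return {metric: (mean, sample std)} across a list of per-run dicts.

    Every run is assumed to report the same metric names, and at least
    two runs are required for the standard deviation to be defined.
    """
    summary = {}
    for metric in runs[0]:
        values = [run[metric] for run in runs]
        summary[metric] = (mean(values), stdev(values))
    return summary
```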

Structure

The codebase includes several Python scripts and Jupyter notebooks, as well as configuration files and shell scripts.

To run the k-fold cross-validation training, use the kfold.sh script. You need to provide the GPU ID as an argument.
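For readers unfamiliar with k-fold cross-validation, the splitting step can be sketched as below. This is a generic illustration with contiguous, unshuffled folds; whether kfold.sh shuffles or stratifies the data, and how many folds it uses, is not stated here, and `kfold_indices` is a hypothetical name.

```python
def kfold_indices(n_samples, k):
    """Yield (train_indices, val_indices) for k contiguous folds.

    Fold sizes differ by at most one sample when n_samples is not
    divisible by k.
    """
    fold_sizes = [n_samples // k + (1 if i < n_samples % k else 0)
                  for i in range(k)]
    start = 0
    for size in fold_sizes:
        val = list(range(start, start + size))
        train = list(range(0, start)) + list(range(start + size, n_samples))
        yield train, val
        start += size
```

Each sample appears in exactly one validation fold, so every sample is used for validation once and for training k - 1 times.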

Configuration

You can adjust the parameters of the experiment in the config.yaml file. The parameters include the number of epochs, batch size, learning rate, and model type, among others.