Memory-Efficient 3D Denoising Diffusion Models for Medical Image Processing

We provide the PyTorch implementation of our MIDL 2023 submission "Memory-Efficient 3D Denoising Diffusion Models for Medical Image Processing".

Check out the project page!

The implementation is based on Diffusion Models for Medical Anomaly Detection and openai/guided-diffusion (MIT-License).

Usage

Installation

Install the necessary Python packages as defined in environment.yaml. We recommend using mambaforge. You can create the environment with

mamba env create -n your_env_name --file environment.yaml

If you run into problems, try different versions of the packages listed in environment.yaml.

Training & Inference

You can use the run.sh file to run both training and sampling for the different models. The relevant parameters are collected at the top of the file. Adjust them according to which model you would like to train or sample from, and which part of the data you would like to use.

The training and sampling process is visualized using TensorBoard. The model checkpoints are saved in a subdirectory of the runs folder, generated by TensorBoard. To view and compare the different runs, run tensorboard --logdir=runs --bind_all and open the provided link in your browser.

Data

We provide a torch.utils.data.Dataset implementation for BraTS2020 data, normalized as described in the paper. The implementation assumes that the data is stored in a directory structure like

root
  dataroot
    000001
      brats_train_001_t1_000_w.nii.gz
      brats_train_001_t1ce_000_w.nii.gz
      brats_train_001_t2_000_w.nii.gz
      brats_train_001_flair_000_w.nii.gz
      brats_train_001_seg_000_w.nii.gz
    000002
      brats_train_002_t1_000_w.nii.gz
      brats_train_002_t1ce_000_w.nii.gz
       ...
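
To check that your data matches this layout before training, a small indexing helper can be useful. The following is a minimal sketch, not the Dataset implementation shipped with this repository: the function name index_subjects and the regular expression are assumptions derived only from the file names shown above. It uses only the standard library and does not load or normalize any image data.

```python
import re
from pathlib import Path

# Modality suffixes taken from the file names in the layout above.
MODALITIES = ("t1", "t1ce", "t2", "flair", "seg")

# Matches e.g. "brats_train_001_t1ce_000_w.nii.gz" and captures "t1ce".
# The pattern is an assumption inferred from the example file names.
_PATTERN = re.compile(r"^brats_train_\d+_(t1ce|t1|t2|flair|seg)_\d+_w\.nii\.gz$")


def index_subjects(dataroot):
    """Return one {modality: Path} dict per subject directory, sorted by name."""
    subjects = []
    for subject_dir in sorted(Path(dataroot).iterdir()):
        if not subject_dir.is_dir():
            continue  # skip stray files next to the subject folders
        files = {}
        for path in sorted(subject_dir.iterdir()):
            match = _PATTERN.match(path.name)
            if match:
                files[match.group(1)] = path
        if files:
            subjects.append(files)
    return subjects
```

Calling index_subjects("root/dataroot") returns a list with one dictionary per subject folder. Comparing the keys of each dictionary against MODALITIES is a quick way to spot subjects with missing volumes.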

Copyright

Copyright 2023 Center of Image Analysis and Navigation, University of Basel

Licensed under the Apache License, Version 2.0