<hr> <h1 align="center"> MambaRoll <br> <sub>Physics-Driven Autoregressive State Space Models for Medical Image Reconstruction</sub> </h1> <div align="center"> <a href="https://bilalkabas.github.io/" target="_blank">Bilal&nbsp;Kabas</a><sup>1,2</sup> &ensp; <b>&middot;</b> &ensp; <a href="https://github.com/fuat-arslan" target="_blank">Fuat&nbsp;Arslan</a><sup>1,2</sup> &ensp; <b>&middot;</b> &ensp; <a href="https://github.com/Valiyeh" target="_blank">Valiyeh&nbsp;A. Nezhad</a><sup>1,2</sup> &ensp; <b>&middot;</b> &ensp; <a href="https://scholar.google.com/citations?hl=en&user=_SujLxcAAAAJ" target="_blank">Saban&nbsp;Ozturk</a><sup>1,2</sup> &ensp; <b>&middot;</b> &ensp; <a href="https://kilyos.ee.bilkent.edu.tr/~saritas/" target="_blank">Emine U.&nbsp;Saritas</a><sup>1,2</sup> &ensp; <b>&middot;</b> &ensp; <a href="https://kilyos.ee.bilkent.edu.tr/~cukur/" target="_blank">Tolga&nbsp;Çukur</a><sup>1,2</sup> &ensp;

<span></span>

<sup>1</sup>Bilkent University &emsp; <sup>2</sup>UMRAM <br>

</div> <hr> <h3 align="center">[<a href="https://arxiv.org/abs/2412.09331">arXiv</a>]</h3>

Official PyTorch implementation of MambaRoll, a novel physics-driven autoregressive state space model for enhanced fidelity in medical image reconstruction. In each cascade of an unrolled architecture, MambaRoll employs an autoregressive framework based on physics-driven state space modules (PSSM). PSSMs efficiently aggregate contextual features at a given spatial scale while maintaining fidelity to acquired data, and autoregressive prediction of next-scale feature maps from earlier spatial scales enhances the capture of multi-scale contextual features.

<p align="center"> <img src="figures/architecture.png" alt="architecture"> </p>

⚙️ Installation

This repository has been developed and tested with CUDA 12.2 and Python 3.12. The commands below create a conda environment with the required packages. Make sure conda is installed first.

conda env create --file requirements.yaml
conda activate mambaroll
<details> <summary>[Optional] Setting Up Faster and Memory-efficient Radon Transform</summary><br>

We use a faster (over 100x) and more memory-efficient (~4.5x) implementation of the Radon transform (torch-radon). To install it, run the commands below within the mambaroll conda environment.

git clone https://github.com/matteo-ronchetti/torch-radon.git
cd torch-radon
python setup.py install
</details>
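Because torch-radon is optional, it can be useful to confirm which packages are importable in the active environment before starting a run. The following is a minimal sketch using only the standard library. The import name `torch_radon` is the one used by the torch-radon project; the helper name `is_installed` is hypothetical, and how MambaRoll itself selects a Radon backend is not shown here.

```python
import importlib.util


def is_installed(module_name: str) -> bool:
    """Return True if a top-level module can be imported in this environment."""
    return importlib.util.find_spec(module_name) is not None


if __name__ == "__main__":
    # Run inside the activated `mambaroll` conda environment.
    for name in ("torch", "torch_radon"):
        status = "found" if is_installed(name) else "missing"
        print(f"{name}: {status}")
```

If `torch_radon` is reported as missing, the optional installation steps above were either skipped or run in a different environment.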

🗂️ Prepare dataset

MambaRoll supports reconstructions for MRI and CT modalities. Therefore, we have two dataset classes: (1) MRIDataset and (2) CTDataset in datasets.py.

1. MRI dataset folder structure

The MRI dataset has a subfolder for each undersampling rate, e.g. us4x, us8x, etc. Within each split, there is a separate .npz file for each contrast.

<details> <summary>Details for npz files</summary><br>

A `<contrast>.npz` file has the following keys:

| Variable key | Description | Shape |
| --- | --- | --- |
| `image_fs` | Coil-combined fully-sampled MR image. | n_slices x 1 x height x width |
| `image_us` | Multi-coil undersampled MR image. | n_slices x n_coils x height x width |
| `us_masks` | K-space undersampling masks. | n_slices x 1 x height x width |
| `coilmaps` | Coil sensitivity maps. | n_slices x n_coils x height x width |
| `subject_ids` | Corresponding subject ID for each slice. | n_slices |
| `us_factor` | Undersampling factor. | (Single integer value) |
</details>
fastMRI/
├── us4x/
│   ├── train/
│   │   ├── T1.npz
│   │   ├── T2.npz
│   │   └── FLAIR.npz
│   ├── test/
│   │   ├── T1.npz
│   │   ├── T2.npz
│   │   └── FLAIR.npz
│   └── val/
│       ├── T1.npz
│       ├── T2.npz
│       └── FLAIR.npz
├── us8x/
│   ├── train/...
│   ├── test/...
│   └── val/...
├── ...
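Before training, it can help to check that a prepared MRI file sits where the folder layout above expects it and that its arrays have mutually consistent shapes. The sketch below encodes the layout and the key table from this section. The helper names `mri_npz_path` and `check_mri_shapes` are hypothetical and are not part of `datasets.py`.

```python
from pathlib import Path

REQUIRED_MRI_KEYS = (
    "image_fs", "image_us", "us_masks", "coilmaps", "subject_ids", "us_factor",
)


def mri_npz_path(dataset_dir, us_factor, split, contrast):
    """Build <dataset_dir>/us<factor>x/<split>/<contrast>.npz."""
    return Path(dataset_dir) / f"us{us_factor}x" / split / f"{contrast}.npz"


def check_mri_shapes(shapes):
    """Return a list of problems in a {key: shape} mapping (empty if none)."""
    problems = [f"missing key: {k}" for k in REQUIRED_MRI_KEYS if k not in shapes]
    if problems:
        return problems
    for key in ("image_fs", "image_us", "us_masks", "coilmaps"):
        if len(shapes[key]) != 4:
            problems.append(f"{key}: expected 4 dimensions, got {len(shapes[key])}")
    if problems:
        return problems
    n_slices, _, height, width = shapes["image_fs"]
    # image_fs and us_masks carry a singleton second dimension.
    for key in ("image_fs", "us_masks"):
        if shapes[key][1] != 1:
            problems.append(f"{key}: expected a singleton second dimension")
    # Multi-coil image and coil maps share n_coils, so their shapes must agree.
    if tuple(shapes["image_us"]) != tuple(shapes["coilmaps"]):
        problems.append("image_us and coilmaps shapes differ")
    for key in ("image_us", "us_masks", "coilmaps"):
        s = shapes[key]
        if (s[0], s[2], s[3]) != (n_slices, height, width):
            problems.append(f"{key}: slices/height/width do not match image_fs")
    if tuple(shapes["subject_ids"]) != (n_slices,):
        problems.append("subject_ids: expected one entry per slice")
    return problems
```

With NumPy available, the shape mapping can be collected from a file with `data = np.load(path)` followed by `{k: data[k].shape for k in data.files}`, and then passed to `check_mri_shapes`.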

2. CT dataset folder structure

Each split in the CT dataset contains the fully-sampled images together with one file per undersampling rate.

<details> <summary>Details for npz files</summary><br>

image_fs.npz files have the fully-sampled data with the following key:

| Variable key | Description | Shape |
| --- | --- | --- |
| `image_fs` | Fully-sampled CT image. | n_slices x 1 x height x width |

A `us<us_factor>x.npz` file has the following keys:

| Variable key | Description | Shape |
| --- | --- | --- |
| `image_us` | Undersampled CT image. | n_slices x 1 x height x width |
| `sinogram_us` | Corresponding sinograms for undersampled CTs. | n_slices x 1 x detector_positions x n_projections |
| `projection_angles` | Projection angles at which the Radon transform was performed on fully-sampled images to obtain undersampled ones. | n_slices x n_projections |
| `subject_ids` | Corresponding subject ID for each slice. | n_slices |
| `us_factor` | Undersampling factor. | (Single integer value) |
</details>
lodopab-ct/
├── train/
│   ├── image_fs.npz
│   ├── us4x.npz
│   └── us6x.npz
├── test/
│   ├── image_fs.npz
│   ├── us4x.npz
│   └── us6x.npz
└── val/
    ├── image_fs.npz
    ├── us4x.npz
    └── us6x.npz
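The CT layout pairs one `image_fs.npz` per split with one file per undersampling rate, so a quick consistency check needs both files. The sketch below mirrors the folder structure and key tables above. The helper names are hypothetical, and the check that `image_us` has the same shape as `image_fs` is an assumption based on both being listed as n_slices x 1 x height x width.

```python
from pathlib import Path


def ct_npz_paths(dataset_dir, split, us_factor):
    """Return (fully-sampled file, undersampled file) for one split."""
    split_dir = Path(dataset_dir) / split
    return split_dir / "image_fs.npz", split_dir / f"us{us_factor}x.npz"


def check_ct_shapes(image_fs, image_us, sinogram_us, projection_angles, subject_ids):
    """Each argument is a shape tuple; return a list of problems (empty if none)."""
    problems = []
    named = (("image_fs", image_fs), ("image_us", image_us), ("sinogram_us", sinogram_us))
    for name, shape in named:
        if len(shape) != 4 or shape[1] != 1:
            problems.append(f"{name}: expected 4 dimensions with a singleton second one")
    if len(projection_angles) != 2:
        problems.append("projection_angles: expected n_slices x n_projections")
    if problems:
        return problems
    n_slices = image_fs[0]
    # Assumption: undersampled and fully-sampled images share the same grid.
    if tuple(image_us) != tuple(image_fs):
        problems.append("image_us: shape differs from image_fs")
    # The last sinogram axis is n_projections and must match the angle count.
    n_projections = sinogram_us[3]
    if sinogram_us[0] != n_slices or tuple(projection_angles) != (n_slices, n_projections):
        problems.append("sinogram_us / projection_angles: slice or projection counts disagree")
    if tuple(subject_ids) != (n_slices,):
        problems.append("subject_ids: expected one entry per slice")
    return problems
```

For example, `ct_npz_paths("../datasets/lodopab-ct", "train", 4)` yields the two files that would be read together for 4x undersampling.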

๐Ÿƒ Training

Run the following command to start or resume training. Model checkpoints are saved under the logs/$EXP_NAME/MambaRoll/checkpoints directory, and sample validation images are saved under logs/$EXP_NAME/MambaRoll/val_samples. The script supports both single- and multi-GPU training and runs on a single GPU by default. To enable multi-GPU training, set the --trainer.devices argument to the list of device ids, e.g. 0,1,2,3. Be aware that multi-GPU training may lead to convergence issues, so using multiple GPUs is recommended only for inference/testing.

python main.py fit \
    --config $CONFIG_PATH \
    --trainer.logger.name $EXP_NAME \
    --model.mode $MODE \
    --data.dataset_dir $DATA_DIR \
    --data.contrast $CONTRAST \
    --data.us_factor $US_FACTOR \
    --data.train_batch_size $BS_TRAIN \
    --data.val_batch_size $BS_VAL \
    [--trainer.max_epoch $N_EPOCHS] \
    [--ckpt_path $CKPT_PATH] \
    [--trainer.devices $DEVICES]

<details> <summary>Example Commands</summary>

MRI reconstruction using fastMRI dataset:

python main.py fit \
  --config configs/config_fastmri.yaml \
  --trainer.logger.name fastmri_t1_us8x \
  --data.dataset_dir ../datasets/fastMRI \
  --data.contrast T1 \
  --data.us_factor 8 \
  --data.train_batch_size 1 \
  --data.val_batch_size 16 \
  --trainer.devices [0]

CT reconstruction using LoDoPaB-CT dataset:

python main.py fit \
  --config configs/config_ct.yaml \
  --trainer.logger.name ct_us4x \
  --data.dataset_dir ../datasets/lodopab-ct/ \
  --data.us_factor 4 \
  --data.train_batch_size 1 \
  --data.val_batch_size 16 \
  --trainer.devices [0]
</details>

Argument descriptions

| Argument | Description |
| --- | --- |
| `--config` | Config file path. Available config files: 'configs/config_fastmri.yaml' and 'configs/config_ct.yaml'. |
| `--trainer.logger.name` | Experiment name. |
| `--model.mode` | Mode depending on data modality. Options: 'mri', 'ct'. |
| `--data.dataset_dir` | Dataset directory. |
| `--data.contrast` | Source contrast, e.g. 'T1', 'T2', ... for MRI. Should match the name of the .npz file for that contrast. |
| `--data.us_factor` | Undersampling factor, e.g. 4, 8. |
| `--data.train_batch_size` | Train set batch size. |
| `--data.val_batch_size` | Validation set batch size. |
| `--trainer.max_epoch` | [Optional] Number of training epochs (default: 50). |
| `--ckpt_path` | [Optional] Model checkpoint path to resume training. |
| `--trainer.devices` | [Optional] Device or list of devices. For multi-GPU, set to the list of device ids, e.g. 0,1,2,3 (default: [0]). |
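When running several experiments, for example one per undersampling factor, assembling the training command programmatically avoids copy-paste mistakes. The sketch below only composes the argument list from the flags documented in this section and does not launch anything. The function name `build_fit_command` and its parameter names are hypothetical.

```python
import shlex


def build_fit_command(config, exp_name, dataset_dir, us_factor,
                      train_batch_size, val_batch_size,
                      mode=None, contrast=None,
                      max_epoch=None, ckpt_path=None, devices=None):
    """Compose the `python main.py fit ...` argument list described above."""
    cmd = [
        "python", "main.py", "fit",
        "--config", config,
        "--trainer.logger.name", exp_name,
        "--data.dataset_dir", dataset_dir,
        "--data.us_factor", str(us_factor),
        "--data.train_batch_size", str(train_batch_size),
        "--data.val_batch_size", str(val_batch_size),
    ]
    # Flags below are appended only when a value is supplied.
    optional = [
        ("--model.mode", mode),
        ("--data.contrast", contrast),
        ("--trainer.max_epoch", max_epoch),
        ("--ckpt_path", ckpt_path),
        ("--trainer.devices", devices),
    ]
    for flag, value in optional:
        if value is not None:
            cmd += [flag, str(value)]
    return cmd


if __name__ == "__main__":
    # Print one command per undersampling factor for the fastMRI T1 example.
    for us in (4, 8):
        cmd = build_fit_command(
            "configs/config_fastmri.yaml", f"fastmri_t1_us{us}x",
            "../datasets/fastMRI", us, 1, 16, contrast="T1",
        )
        print(shlex.join(cmd))
```

The returned list can be passed to `subprocess.run(cmd, check=True)` from the repository root to actually start a run.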

🧪 Testing

Run the following command to start testing. The predicted images are saved under the logs/$EXP_NAME/MambaRoll/test_samples directory. By default, the script runs on a single GPU. To enable multi-GPU testing, set the --trainer.devices argument to the list of device ids, e.g. 0,1,2,3.

python main.py test \
    --config $CONFIG_PATH \
    --model.mode $MODE \
    --data.dataset_dir $DATA_DIR \
    --data.contrast $CONTRAST \
    --data.us_factor $US_FACTOR \
    --data.test_batch_size $BS_TEST \
    --ckpt_path $CKPT_PATH

Argument descriptions

Some arguments are common to both training and testing and are not listed here. For details on those arguments, please refer to the training section.

| Argument | Description |
| --- | --- |
| `--data.test_batch_size` | Test set batch size. |
| `--ckpt_path` | Model checkpoint path. |
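The training and testing sections place all outputs under logs/$EXP_NAME/MambaRoll. A small helper can resolve those locations and list candidate checkpoints to pass as `--ckpt_path`. This is a sketch under the directory names stated above; the function names are hypothetical, and because the checkpoint file naming is not documented here, the listing returns every file in the checkpoints directory rather than filtering by extension.

```python
from pathlib import Path


def experiment_dirs(exp_name, log_root="logs"):
    """Return the output directories documented for one experiment."""
    base = Path(log_root) / exp_name / "MambaRoll"
    return {
        "checkpoints": base / "checkpoints",
        "val_samples": base / "val_samples",
        "test_samples": base / "test_samples",
    }


def list_checkpoints(exp_name, log_root="logs"):
    """Return the files in the checkpoints directory, sorted by name."""
    ckpt_dir = experiment_dirs(exp_name, log_root)["checkpoints"]
    if not ckpt_dir.is_dir():
        return []  # nothing has been trained under this experiment name yet
    return sorted(p for p in ckpt_dir.iterdir() if p.is_file())


if __name__ == "__main__":
    for path in list_checkpoints("fastmri_t1_us8x"):
        print(path)
```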

✒️ Citation

You are encouraged to modify/distribute this code. However, please acknowledge this code and cite the paper appropriately.

@article{kabas2024mambaroll,
  title={Physics-Driven Autoregressive State Space Models for Medical Image Reconstruction}, 
  author={Bilal Kabas and Fuat Arslan and Valiyeh A. Nezhad and Saban Ozturk and Emine U. Saritas and Tolga Çukur},
  year={2024},
  journal={arXiv:2412.09331}
}

💡 Acknowledgments

This repository uses code from the following projects:

<hr> Copyright © 2024, ICON Lab.