Plug-and-Play Learned Proximal Trajectory for 3D Sparse-View X-Ray Computed Tomography

Code for the paper "Plug-and-Play Learned Proximal Trajectory for 3D Sparse-View X-Ray Computed Tomography".

Pre-requisites

You will need the following libraries:

The torch and TIGRE code was run with CUDA 11.3.

pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113

Instructions for installing the TIGRE toolbox for 3D X-ray cone-beam computed tomography are available at https://github.com/CERN/TIGRE. This code uses version 2.1.0.
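Before running the commands further down, it can help to confirm that both libraries are visible to the interpreter. The following is a minimal sketch, not part of the repository; it assumes the TIGRE Python bindings are importable under the module name tigre, and the helper name check_environment is hypothetical.

```python
import importlib.util


def check_environment():
    """Report whether torch and TIGRE can be found, plus torch's CUDA build."""
    report = {}
    # Assumption: the TIGRE Python bindings install as the module `tigre`.
    for name in ("torch", "tigre"):
        report[name] = importlib.util.find_spec(name) is not None
    if report["torch"]:
        try:
            import torch
            report["torch_version"] = torch.__version__
            report["torch_cuda"] = torch.version.cuda      # e.g. "11.3" for a cu113 wheel
            report["cuda_available"] = torch.cuda.is_available()
        except ImportError:
            # Found on disk but failed to import (broken install).
            report["torch"] = False
    return report


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```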

Learned Proximal Trajectory

Saving a pre-defined trajectory

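Running this command saves a pre-defined optimization trajectory on the Walnut-CBCT dataset with step size $\eta = 1/L$, regularization weight $\lambda=5$ and $K=200$ iterations. Replace the placeholders <input_dir>, <output_dir> and <trajectory_output_dir> with your own directories: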

python -m tigre_rc  \
    --dataset_name='walnut' \
    --input_dir='<input_dir>' \
    --input_file='dataset_50p.csv' \
    --split_set='train' \
    --acquisition_id='8' \
    --output_base='train' \
    --output_dir='<output_dir>' \
    --init_mode='FDK' \
    --ideal_denoiser \
    --num_full_proj=1200 \
    --num_proj=50 \
    --block_size=50 \
    --init_lr=1.0 \
    --lambda_reg=5.0 \
    --fista \
    --reconstruction_alg='proximal_splitting' \
    --n_iter=200 \
    --save_trajectory \
    --trajectory_output_dir='<trajectory_output_dir>'
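To make the roles of $\eta = 1/L$, $\lambda$ and $K$ concrete, here is a small NumPy sketch of proximal gradient descent that stores every iterate. It is an illustration only: the dense matrix A stands in for the cone-beam projector, the zero initialisation replaces FDK, and the quadratic regulariser $(\lambda/2)\|x - x_{ref}\|^2$ is a stand-in chosen for this example, not necessarily what --ideal_denoiser implements. The function name pgd_trajectory is hypothetical.

```python
import numpy as np


def pgd_trajectory(A, y, x_ref, lam=5.0, n_iter=200):
    """Minimise 0.5*||Ax - y||^2 + (lam/2)*||x - x_ref||^2 and return all iterates."""
    L = np.linalg.norm(A, 2) ** 2          # Lipschitz constant of the data-fidelity gradient
    eta = 1.0 / L                          # step size eta = 1/L
    x = np.zeros(A.shape[1])               # zero init (the command above uses FDK instead)
    trajectory = [x.copy()]
    for _ in range(n_iter):
        z = x - eta * (A.T @ (A @ x - y))  # gradient step on the data-fidelity term
        # Closed-form proximal step for the stand-in quadratic regulariser:
        # argmin_u 0.5*||u - z||^2 + (eta*lam/2)*||u - x_ref||^2
        x = (z + eta * lam * x_ref) / (1.0 + eta * lam)
        trajectory.append(x.copy())
    return np.stack(trajectory)            # shape (n_iter + 1, n_voxels)
```

Each stored pair (gradient-step output, next iterate) is the kind of input/target sample a learned proximal operator could be trained on.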

Training

Running this command trains a Learned Proximal Operator on the previously saved optimization trajectory. This is the configuration used for the results in Table 1. Replace the placeholders <trajectory_output_dir> and <output_dir> with your own directories.

Training saves a summary.csv log file, images generated during training, and a checkpoint last.pth.tar to use at inference.

python -m train_trajectory  \
    --batch_size=32 \
    --patch_size=256 \
    --model_name='unet' \
    --no-residual_learning \
    --skip_connection \
    --encoder_channels 32 32 64 64 128 \
    --decoder_channels 64 64 32 32 \
    --scale_skip_connections 1 1 1 1 \
    --upscaling_layer='transposeconv_nogroup' \
    --input_dir=<trajectory_output_dir> \
    --output_base='train' \
    --output_dir=<output_dir> \
    --activation='SiLU' \
    --loss='mse' \
    --dropout=0.5 \
    --bias_free \
    --clip_grad_norm=1e-2 \
    --ema \
    --ema_decay=0.999 \
    --timestep_dim=128 \
    --stem_size=5 \
    --num_proj=50 \
    --num_epochs=1000 \
    --num_warmup_epochs=5 \
    --lr_scheduler='CosineAnnealingLR' \
    --init_lr=1e-4 \
    --min_lr=1e-8 \
    --weight_decay=1e-6 \
    --dataset_size=6400 \
    --log_interval=100 \
    --dataset_name='walnut' \
    --validation_interval=50 \
    --final_activation='Identity' \
    --augmentation \
    --optimizer='adam' \
    --drop_last \
    --num_workers=8 \
    --pin_memory \
    --memmap \
    --amp
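The learning-rate flags in the command above (init_lr, min_lr, num_epochs, num_warmup_epochs, CosineAnnealingLR) describe a schedule that can be written out explicitly. The sketch below uses the standard cosine-annealing formula; the linear warmup and the way it is chained with the cosine phase are assumptions, since the README does not say how the training script combines them. The function name lr_at_epoch is hypothetical.

```python
import math


def lr_at_epoch(epoch, init_lr=1e-4, min_lr=1e-8, num_epochs=1000, num_warmup_epochs=5):
    """Learning rate for a 0-indexed epoch: linear warmup, then cosine annealing."""
    if epoch < num_warmup_epochs:
        # Assumed linear warmup that reaches init_lr on the last warmup epoch.
        return init_lr * (epoch + 1) / num_warmup_epochs
    # Fraction of the cosine phase completed, from 0.0 to 1.0.
    t = (epoch - num_warmup_epochs) / max(1, num_epochs - num_warmup_epochs - 1)
    return min_lr + 0.5 * (init_lr - min_lr) * (1.0 + math.cos(math.pi * t))


if __name__ == "__main__":
    for e in (0, 4, 5, 500, 999):
        print(e, lr_at_epoch(e))
```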

Inference

Running this command launches the PnP-PGD algorithm with our Learned Proximal Trajectory procedure. We use the parameters $\eta=1/L$, $\lambda=10$ and $K=500$. The command requires two available GPUs, both for speed and because communication between torch and TIGRE does not work well on a single GPU. Replace the placeholders <input_dir>, <output_dir> and <path/to/last.pth.tar> with your own paths:

python -m tigre_rc \
    --benchmark \
    --dataset_name='walnut' \
    --input_dir=<input_dir> \
    --input_file='dataset_50p.csv' \
    --split_set='validation' \
    --output_base='train' \
    --output_dir=<output_dir> \
    --num_proj=50 \
    --reconstruction_alg='proximal_splitting' \
    --init_lr=1. \
    --init_mode='FDK' \
    --pnp \
    --fista \
    --reg_checkpoint=<path/to/last.pth.tar> \
    --ema \
    --stopping='num_iter' \
    --timesteps \
    --ckpt_fidelity \
    --ckpt_reconstruction \
    --ckpt_criterion \
    --lambda_reg=10.0 \
    --block_size=50 \
    --n_iter=501 \
    --reg_batch_size=16 \
    --output_suffix='something'
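The structure of a plug-and-play proximal gradient loop with FISTA momentum (the --pnp and --fista flags) can be sketched in a few lines of NumPy. This is an illustration under stated assumptions, not the repository's implementation: A is a dense stand-in for the projector, the denoiser is any callable taking the current point and the iteration index (mirroring the idea behind --timesteps), and soft-thresholding replaces the trained network so the example stays self-contained. The names pnp_fista and soft_threshold are hypothetical.

```python
import numpy as np


def soft_threshold(tau):
    """Stand-in 'denoiser': exact prox of tau*||x||_1; ignores the iteration index k."""
    return lambda v, k: np.sign(v) * np.maximum(np.abs(v) - tau, 0.0)


def pnp_fista(A, y, denoiser, eta, n_iter=500):
    """PnP proximal gradient with FISTA momentum; the prox is replaced by `denoiser`."""
    x = np.zeros(A.shape[1])
    z = x.copy()                                   # extrapolated point
    t = 1.0
    for k in range(n_iter):
        grad = A.T @ (A @ z - y)                   # gradient of 0.5*||Az - y||^2
        x_new = denoiser(z - eta * grad, k)        # denoiser plays the role of the prox
        t_new = 0.5 * (1.0 + np.sqrt(1.0 + 4.0 * t * t))
        z = x_new + ((t - 1.0) / t_new) * (x_new - x)   # momentum step
        x, t = x_new, t_new
    return x


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    A = rng.standard_normal((40, 60))
    x_true = np.zeros(60)
    x_true[:5] = 1.0
    eta = 1.0 / np.linalg.norm(A, 2) ** 2          # eta = 1/L
    x_hat = pnp_fista(A, A @ x_true, soft_threshold(eta * 0.1), eta)
    print(np.linalg.norm(x_hat - x_true))
```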

Evaluation

You can evaluate PSNR and SSIM for a PnP experiment by running the following command, replacing the placeholders <input_dir> and <output_dir> with your own directories. It generates a results_<split_set>.csv file with the metrics for each reconstruction and their average over the set.

python -m test  \
    --input_dir=<input_dir> \
    --input_file='dataset_50p.csv' \
    --dataset_name='walnut' \
    --file_format='{acquisition_id}_criterion.raw' \
    --split_set='validation' \
    --memmap \
    --bench_location=<output_dir>
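For reference, PSNR between a ground-truth volume and a reconstruction can be computed as in the sketch below. The data-range convention (max minus min of the reference) is an assumption; the test module may normalise differently, and SSIM is omitted for brevity. The last lines show how a template such as the --file_format value could expand with Python's str.format, which is also an assumption about how the script uses it.

```python
import numpy as np


def psnr(reference, reconstruction, data_range=None):
    """Peak signal-to-noise ratio in dB between two arrays of the same shape."""
    reference = np.asarray(reference, dtype=np.float64)
    reconstruction = np.asarray(reconstruction, dtype=np.float64)
    if data_range is None:
        # Assumed convention: dynamic range of the reference volume.
        data_range = reference.max() - reference.min()
    mse = np.mean((reference - reconstruction) ** 2)
    if mse == 0:
        return float("inf")
    return 10.0 * np.log10(data_range ** 2 / mse)


if __name__ == "__main__":
    # Expanding a file-name template for one acquisition.
    name = "{acquisition_id}_criterion.raw".format(acquisition_id="8")
    print(name)
```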

PnP-$\alpha$PGD

The procedure for training a Gaussian denoising network for PnP-PGD and running inference is very similar to that of our scheme.

Training

python -m train_postp  \
    --batch_size=32 \
    --patch_size=256 \
    --model_name='unet' \
    --residual_learning \
    --skip_connection \
    --encoder_channels 32 32 64 64 128 \
    --decoder_channels 64 64 32 32 \
    --scale_skip_connections 1 1 1 1 \
    --upscaling_layer='transposeconv_nogroup' \
    --input_dir=<input_dir> \
    --input_file='dataset_50p.csv' \
    --output_base='train' \
    --output_dir=<output_dir> \
    --activation='SiLU' \
    --dropout=0.5 \
    --bias_free \
    --clip_grad_norm=1e-2 \
    --pnp \
    --ema \
    --axial_center_crop \
    --center_crop \
    --stem_size=5 \
    --num_proj=50 \
    --init_lr=1e-4 \
    --min_lr=1e-8 \
    --weight_decay=1e-6 \
    --dataset_size=6400 \
    --num_epochs=1000 \
    --num_warmup_epochs=5 \
    --log_interval=100 \
    --dataset_name='walnut' \
    --validation_interval=10 \
    --final_activation='Identity' \
    --augmentation \
    --loss='mse' \
    --optimizer='adam' \
    --lr_scheduler='CosineAnnealingLR' \
    --drop_last \
    --num_workers=8 \
    --pin_memory \
    --memmap
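As a rough picture of what Gaussian denoising training with --residual_learning involves, the sketch below builds one training pair and shows how a residual prediction is turned back into a clean estimate. The noise level sigma is a hypothetical parameter (the README does not give one), and the convention that the network predicts the noise which is then subtracted from the input is an assumption about what the flag means. Both function names are hypothetical.

```python
import numpy as np


def make_denoising_pair(clean_patch, sigma, rng):
    """Return (noisy input, residual target) for additive Gaussian noise of std sigma."""
    noise = rng.normal(0.0, sigma, size=clean_patch.shape)
    noisy = clean_patch + noise
    return noisy, noise            # residual learning: the target is the noise itself


def apply_residual_denoiser(noisy, predicted_noise):
    """Recover the clean estimate from a residual (noise) prediction."""
    return noisy - predicted_noise


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    clean = rng.random((256, 256))                 # stand-in for a 256x256 patch
    noisy, target = make_denoising_pair(clean, 0.05, rng)
    # With a perfect noise prediction the clean patch is recovered exactly.
    print(np.allclose(apply_residual_denoiser(noisy, target), clean))
```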

Inference

python -m tigre_rc \
    --benchmark \
    --dataset_name='walnut' \
    --input_dir=<input_dir> \
    --input_file='dataset_50p.csv' \
    --split_set='validation' \
    --output_base='train' \
    --output_dir=<output_dir> \
    --num_proj=50 \
    --reconstruction_alg='proximal_splitting' \
    --init_lr=1. \
    --init_mode='FDK' \
    --pnp \
    --fista \
    --reg_checkpoint=<path/to/last.pth.tar> \
    --ema \
    --stopping='num_iter' \
    --ckpt_fidelity \
    --ckpt_reconstruction \
    --ckpt_criterion \
    --lambda_reg=10.0 \
    --block_size=50 \
    --n_iter=501 \
    --reg_batch_size=16 \
    --output_suffix='something'

Test

python -m test  \
    --input_dir=<input_dir> \
    --input_file='dataset_50p.csv' \
    --dataset_name='walnut' \
    --file_format='{acquisition_id}_criterion.raw' \
    --split_set='validation' \
    --memmap \
    --bench_location=<output_dir>