Pyramid Transformer Net-3D (PTNet3D)
Public code for the 3D version of the Pyramid Transformer Network (PTNet3D). PTNet3D achieves consistently better performance than previous SOTA methods on high-resolution and longitudinal infant brain MRI synthesis.
Example of synthesizing a 3-month-old infant T2w scan conditioned on T1w:
Paper link:
Usage:
Training (quick start):
python train.py --name YourProjName --checkpoints_dir YourModelPath --dataroot YourDatasetPath
Additional hyperparameters can be changed in the option scripts. For the complete set of options, see the scripts in ./options.
Data structure:
./opt.dataroot
- train_A # your source domain scans
- train_B # your target domain scans
- test_A # will be used for inference
The model is trained to convert the modality in train_A to the modality in train_B. Please make sure all scans are well co-registered and artifact-free. Corresponding scans in train_A and train_B must have the same file name.
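Since paired training depends on matching file names, it can be worth verifying the layout before launching a run. Below is a minimal sketch of such a check; `check_pairing` is a hypothetical helper, not part of the PTNet3D codebase:

```python
import os

def check_pairing(dataroot):
    """Verify that train_A and train_B contain identically named scans.

    `dataroot` follows the directory layout shown above. Returns two
    sorted lists: names present in train_A but missing from train_B,
    and vice versa (both empty when the pairing is complete).
    """
    a = set(os.listdir(os.path.join(dataroot, "train_A")))
    b = set(os.listdir(os.path.join(dataroot, "train_B")))
    return sorted(a - b), sorted(b - a)
```

Running this before training catches renamed or missing scans early, rather than mid-epoch.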
Inference:
python test.py --name YourProjName --checkpoints_dir YourModelPath --dataroot YourDatasetPath --whichmodel YourModelName
The model specified by opt.whichmodel will be used to convert the scans in opt.dataroot/test_A. opt.whichmodel should be the full file name of the stored checkpoint.
Training tips:
The experiments in our paper were conducted on 2D slices or 3D blocks (64x64x64). With the default configuration, a GPU with 24GB of VRAM should hold a mini-batch of up to 4 blocks of 64x64x64.
In the 2D experiments, we resliced each volumetric scan into axial slices, yielding ~200 slices per scan.
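The axial reslicing above amounts to indexing the volume along one axis. A minimal sketch with a dummy NumPy array (real scans would typically be loaded with a library such as nibabel; the shape here is illustrative):

```python
import numpy as np

# Hypothetical volume; the last dimension is treated as the axial axis.
vol = np.zeros((192, 224, 192), dtype=np.float32)

# One 2D training sample per axial slice, giving vol.shape[2] slices per scan.
slices = [vol[:, :, k] for k in range(vol.shape[2])]
```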
In the 3D experiments, we first used overlapping sliding-window sampling to extract 3D blocks from the whole brain (the non-zero region). This generated over 200 ROIs per scan, so we trained the model for only 10 epochs on dHCP and 6 epochs on BCP, giving ~50,000 ROIs per dataset.
In this repo, to simplify sampling, we use random sampling within the non-zero region of the volumetric scan instead of exhaustive overlapping sampling (details in ./data/aligned_dataset.py). This yields ONE ROI per scan per epoch, so you should train for many more epochs. For instance, to re-train the model on the dHCP dataset, we would use 2,000 epochs instead of 10, because each epoch now sees only 291 randomly sampled ROIs.
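The random-sampling strategy can be sketched like this. It is a simplified illustration of the idea behind ./data/aligned_dataset.py, not a copy of its code; the centre-in-brain criterion is an assumption:

```python
import numpy as np

def random_nonzero_block(vol, block=64, rng=None):
    """Sample one block^3 ROI whose centre voxel lies in the non-zero
    (brain) region and which fits entirely inside the volume."""
    if rng is None:
        rng = np.random.default_rng()
    half = block // 2
    zz, yy, xx = np.nonzero(vol)
    # candidate centres: non-zero voxels far enough from every border
    ok = ((zz >= half) & (zz <= vol.shape[0] - half) &
          (yy >= half) & (yy <= vol.shape[1] - half) &
          (xx >= half) & (xx <= vol.shape[2] - half))
    i = rng.choice(np.flatnonzero(ok))
    z, y, x = zz[i], yy[i], xx[i]
    return vol[z - half:z + half, y - half:y + half, x - half:x + half]
```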
A tip for choosing the number of epochs: epoch num = 500,000 / number of volumetric scans you have. For example, with 1,000 scans you can use 500 epochs (niter = 250 and niter_decay = 250 in train_options.py).
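The rule of thumb above can be written as a small helper; `suggested_epochs` is a hypothetical name, and the even split between the constant-LR phase (niter) and the decay phase (niter_decay) follows the 250/250 example:

```python
def suggested_epochs(num_scans, target_samples=500_000):
    """Epochs ~= 500,000 / num_scans, split evenly between niter
    (constant learning rate) and niter_decay (linear decay)."""
    total = max(1, round(target_samples / num_scans))
    niter = total // 2
    niter_decay = total - niter
    return niter, niter_decay
```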
Citation:
If you use our code in your research, please cite our paper: PTNet3D: A 3D High-Resolution Longitudinal Infant Brain MRI Synthesizer Based on Transformers
@ARTICLE{9774943,
author={Zhang, Xuzhe and He, Xinzi and Guo, Jia and Ettehadi, Nabil and Aw, Natalie and Semanek, David and Posner, Jonathan and Laine, Andrew and Wang, Yun},
journal={IEEE Transactions on Medical Imaging},
title={PTNet3D: A 3D High-Resolution Longitudinal Infant Brain MRI Synthesizer Based on Transformers},
year={2022},
volume={},
number={},
pages={1-1},
doi={10.1109/TMI.2022.3174827}}