PET Guided Attention Network for segmentation of lung tumors from PET/CT images that accounts for missing PET images
This project is part of my Master Thesis at ETH Zurich, Switzerland.
The complete report can be found here: <br> Master Thesis report
The work was accepted at the ICML 2020 Global Health workshop for a poster presentation. <br> Extended abstract <br> Poster <br>
The work was also accepted at the DAGM-GCPR conference. <br>
Introduction
PET/CT imaging is the gold standard for the diagnosis and staging of lung cancer. However, especially in healthcare systems with limited resources, costly PET/CT images are often not readily available. Conventional machine learning models process either CT or PET/CT images, but not both. Models designed for PET/CT images are hence restricted by the number of available PET images and cannot additionally leverage CT-only data. In this work, we apply the concept of visual soft attention to efficiently learn a model for lung cancer segmentation from only a small fraction of PET/CT scans and a larger pool of CT-only scans. We show that our model is capable of jointly processing PET/CT as well as CT-only images, and that it performs on par with the respective baselines whether or not PET images are available at test time. We then demonstrate that, unlike conventional models, our model learns efficiently from only a few PET/CT scans in a setting where mostly CT-only data is available.
We achieve this robustness by allowing the PET image to act as an optional guide, in addition to the gating/query signal (usually the encoded feature representation) of a simple attention gate.
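As an illustration, the sketch below shows what such a PET-guided attention gate could look like in PyTorch (2D for brevity; the module name, layer sizes, and fusion details are illustrative assumptions, not the thesis implementation):

```python
# Minimal sketch of a PET-guided additive attention gate (illustrative,
# not the thesis implementation). The PET image is an optional extra
# guide on top of the usual gating signal; x, g, and pet are assumed to
# share spatial dimensions.
import torch
import torch.nn as nn

class PETGuidedAttentionGate(nn.Module):
    def __init__(self, in_ch, gate_ch, inter_ch):
        super().__init__()
        self.w_x = nn.Conv2d(in_ch, inter_ch, kernel_size=1)    # skip features
        self.w_g = nn.Conv2d(gate_ch, inter_ch, kernel_size=1)  # gating/query signal
        self.w_p = nn.Conv2d(1, inter_ch, kernel_size=1)        # optional PET guide
        self.psi = nn.Conv2d(inter_ch, 1, kernel_size=1)        # attention coefficients

    def forward(self, x, g, pet=None):
        s = self.w_x(x) + self.w_g(g)      # standard additive attention
        if pet is not None:                # PET acts as an optional extra guide
            s = s + self.w_p(pet)
        alpha = torch.sigmoid(self.psi(torch.relu(s)))
        return x * alpha                   # gated skip features
```

With this structure, the same network can be run with or without a PET image at test time by simply passing pet=None.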
The following figures show the architecture of the network and the attention gate, respectively.
<img src = "architectures/Architecture.png" width=80% align="center" /> <img src = "architectures/Gate.png" width=70% align="center" />

Prerequisites
The baseline models require a conda environment. The main dependencies are:
- PyTorch (GPU)
- cv2 (OpenCV)
- cc3d (see https://github.com/seung-lab/connected-components-3d)
Further dependencies can be found in environment.yml.
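For context, here is a minimal sketch of how cc3d is typically used to label connected tumor components in a binary 3D mask (illustrative; the repo's actual usage lives in the analysis scripts):

```python
# Illustrative use of cc3d: label connected components in a binary 3D mask.
import numpy as np
import cc3d

mask = np.zeros((64, 64, 64), dtype=np.uint8)  # dummy binary prediction mask
mask[10:20, 10:20, 10:20] = 1                  # one dummy tumor blob
labels = cc3d.connected_components(mask, connectivity=26)
n_tumors = labels.max()                        # components are labeled 1..N
```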
Conda environment
Create the conda environment:
conda env create -f environment.yml
Usage
Data pre-processing
To pre-process the data and generate data augmentations, run:
> python data/data_preprocessing.py --save_dir tumor_data/ --data_aug
usage: Data pre-processing for PAG model [-h] [--anim] [--save_dir SAVE_DIR]
                                         [--data_aug]

optional arguments:
  -h, --help           show this help message and exit
  --anim               Generate animations of the pre-processed images?
  --save_dir SAVE_DIR  Where to save the data?
  --data_aug           Perform data augmentation?
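The --data_aug flag presumably applies spatial augmentations along these lines; the sketch below (hypothetical helper, random flips only) shows one such transform applied consistently to CT, PET, and mask:

```python
# Hypothetical sketch of one augmentation --data_aug might apply: a random
# flip applied consistently to CT, PET, and mask. PET may be None for
# CT-only scans.
import numpy as np

def random_flip(ct, pet, mask, p=0.5, axis=0):
    if np.random.rand() < p:
        ct = np.flip(ct, axis=axis).copy()
        mask = np.flip(mask, axis=axis).copy()
        if pet is not None:
            pet = np.flip(pet, axis=axis).copy()
    return ct, pet, mask
```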
Train + Valid + Test
The configurations are set as defined in the config.json file:
> python -u loader.py --config config.json --train # To train the model
> python -u loader.py --config config.json --valid # To generate predictions on validation data
> python -u loader.py --config config.json --test # To generate predictions on test data
To submit the job on an LSF batch system (the model requires a GPU for training, validation, and testing):
> bsub -o output.txt -W 48:00 -R "rusage[mem=15000,ngpus_excl_p=1]" -R "select[gpu_mtotal0>=14000]" python -u loader.py --train
usage: Segmentation of lung tumors PAG model and baselines
       [-h] [--config CONFIG] [--method METHOD] [--train] [--valid] [--test]
       [--exp_name EXP_NAME] [--save_dir SAVE_DIR] [--ckpt_file CKPT_FILE]
       [--fold FOLD] [--n_folds N_FOLDS] [--n_epochs N_EPOCHS] [--lr LR]
optional arguments:
  -h, --help            show this help message and exit
  --config CONFIG       Configuration file
  --method METHOD       Which model?
  --train               Train the model?
  --valid               Generate predictions on validation data?
  --test                Generate predictions on test data?
  --exp_name EXP_NAME   Name of the experiment
  --save_dir SAVE_DIR   Where do you want to save?
  --ckpt_file CKPT_FILE
                        Optional path for checkpoint
  --fold FOLD           Fold no. in CV experiments
  --n_folds N_FOLDS     No. of folds
  --n_epochs N_EPOCHS   Number of epochs
  --lr LR               Learning rate
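The config.json presumably mirrors these command-line arguments; a hypothetical example (keys and values are illustrative and may not match the repo's actual schema):

```json
{
  "method": "PAG",
  "exp_name": "pag_ct_pet",
  "save_dir": "experiments/pag_ct_pet/",
  "ckpt_file": null,
  "fold": 0,
  "n_folds": 5,
  "n_epochs": 100,
  "lr": 0.0001
}
```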
Analyze the results
To generate metrics for the experiments and calculate them on the validation dataset, run:
python validate/validate.py --dirs multi-baselines/ct/cv0/ --detect --analyze --valid ;
usage: Data analysis [-h] --dirs DIRS [DIRS ...] [--anim] [--analyze]
                     [--detect] [--train] [--valid] [--test] [--config CONFIG]
                     [--clusters CLUSTERS]

optional arguments:
  -h, --help            show this help message and exit
  --dirs DIRS [DIRS ...]
                        All the directories you want to analyze
  --anim                Do you want animations?
  --analyze             Calculate metrics?
  --detect              Generate detection metrics?
  --train               Analyze the training data?
  --valid               Analyze the validation data?
  --test                Analyze the test data?
  --config CONFIG       JSON configuration file
  --clusters CLUSTERS   Name of the file to store summary statistics
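For reference, the standard overlap metric for segmentation is the Dice score; below is a minimal sketch (validate.py may report additional metrics):

```python
# Minimal sketch of the Dice overlap score for binary masks.
import numpy as np

def dice_score(pred, target, eps=1e-7):
    pred, target = pred.astype(bool), target.astype(bool)
    inter = np.logical_and(pred, target).sum()
    return (2.0 * inter + eps) / (pred.sum() + target.sum() + eps)
```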
To generate the cross-validation (CV) metrics for a directory:
python validate/cv_analyze.py --dir multi-baselines/ct/;
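cv_analyze.py presumably aggregates the per-fold results; a minimal sketch of such an aggregation (dummy values, purely illustrative):

```python
# Illustrative cross-validation summary: mean +/- std of a metric over folds.
import numpy as np

fold_dice = [0.71, 0.68, 0.74, 0.70, 0.72]  # dummy per-fold Dice scores
print(f"Dice: {np.mean(fold_dice):.3f} +/- {np.std(fold_dice):.3f}")
```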
PAG model results when complete PET/CT data is available:
python validate/dir_analyze.py --dirs \
multi-baselines/ct/ \
multi-baselines/ct_attn/ \
pag/less/complete/PAG-ct/ \
pag/less/complete/PAG-ct-pet/ \
pag/fractions/complete/ \
multi-baselines/bimodal_attn/ --save_dir metrics_plots_v1/
PAG model results for fractions of PET/CT data:
python validate/rstr_analyze.py --save_dir metrics_plots_v1/