Directional Connectivity-based Segmentation of Medical Images
PyTorch implementation of the CVPR 2023 paper "Directional Connectivity-based Segmentation of Medical Images".
For another simple connectivity-based method, please also see BiconNet.
Requirements
PyTorch 1.7.0+cu110
Code Structure
The main structure, key files, and important functions of this repository are as follows:
- train.py: main entry point. Define your parameters, GPU selection, etc.
- solver.py: training and testing details.
- connect_loss.py: loss function for DconnNet
  * connectivity_matrix: converts segmentation masks to connectivity masks
  * Bilateral_voting: performs bilateral voting and converts the connectivity-based output into a segmentation map
- data_loader: your data loader files and, if needed, the SDL weights for your dataset
- model: DconnNet model files
- scripts: scripts for training on the different datasets
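To make the connectivity_matrix idea concrete, here is a minimal NumPy sketch of turning a binary segmentation mask into an 8-channel connectivity mask: channel k is 1 at a pixel when both that pixel and its k-th neighbour are foreground. The function name mask_to_connectivity, the neighbour order in OFFSETS, and the zero padding at image borders are assumptions made for illustration; the repository's connectivity_matrix works on batched tensors and may order or pad the directions differently.

```python
import numpy as np

# Hypothetical neighbour order as (row offset, col offset); the repository's order may differ.
OFFSETS = [(-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1)]


def mask_to_connectivity(mask):
    """Convert an (H, W) binary mask into an (8, H, W) connectivity mask.

    conn[k, i, j] == 1 iff pixel (i, j) and its neighbour (i + dy, j + dx)
    are both foreground; pixels outside the image count as background.
    """
    mask = (np.asarray(mask) > 0).astype(np.uint8)
    h, w = mask.shape
    padded = np.pad(mask, 1)  # zero border so edge pixels have 8 neighbours
    conn = np.zeros((8, h, w), dtype=np.uint8)
    for k, (dy, dx) in enumerate(OFFSETS):
        # padded[i + 1 + dy, j + 1 + dx] == mask[i + dy, j + dx]
        neighbour = padded[1 + dy:1 + dy + h, 1 + dx:1 + dx + w]
        conn[k] = mask * neighbour
    return conn
```

Because the relation is symmetric, channel k at a pixel always equals the opposite channel at the corresponding neighbour for a ground-truth mask; bilateral voting relies on that pair agreeing at prediction time.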
Implementation
Training on the datasets used in the paper
For the training details of each dataset, please check general/scripts/.
Please store each dataset under the following paths:
Retouch
/retouch
  /Cirrus          ### device; the same layout applies to Spectralis and Topcon
    /train
      /TRAIN002    ### volume id
        /mask      ### store .png masks here
        /orig      ### store .png images here
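A hypothetical helper that walks this layout and pairs each image with its mask. It assumes the .png files in orig/ and mask/ of a volume share the same file names; that pairing rule and the name list_retouch_pairs are assumptions, not taken from the repository's loader.

```python
from pathlib import Path


def list_retouch_pairs(root, device="Cirrus", split="train"):
    """Return sorted (image_path, mask_path) pairs for one RETOUCH device.

    Expects root/<device>/<split>/<volume id>/{orig,mask}/*.png and matches
    images to masks by file name (an assumption about the naming scheme).
    """
    pairs = []
    split_dir = Path(root) / device / split
    for volume_dir in sorted(p for p in split_dir.iterdir() if p.is_dir()):
        for image_path in sorted((volume_dir / "orig").glob("*.png")):
            mask_path = volume_dir / "mask" / image_path.name
            if mask_path.exists():  # skip slices without a matching mask
                pairs.append((image_path, mask_path))
    return pairs
```

For example, list_retouch_pairs("/retouch", device="Topcon") would collect the Topcon training slices under the same layout.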
ISIC2018
The resized data and the training pipeline used in our paper follow this site. Data layout:
/ISIC2018_npy_all_224_320
  /image
  /label
Hyperparameters:
- Image size: (224, 320)
- Batch size: 10
- Epochs: 200
- Starting learning rate: 1e-4
- Optimizer: Adam with weight decay 1e-8
- LR scheduler: CosineAnnealingWarmRestarts (T_0=15, T_mult=2, eta_min=0.00001)
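To see what this schedule does over 200 epochs, the sketch below re-implements the CosineAnnealingWarmRestarts learning-rate formula in plain Python, assuming the scheduler is stepped once per epoch (the repository may step it differently). In PyTorch the same settings correspond to torch.optim.Adam(params, lr=1e-4, weight_decay=1e-8) and torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=15, T_mult=2, eta_min=1e-5). The helper name lr_at_epoch is hypothetical.

```python
import math


def lr_at_epoch(epoch, base_lr=1e-4, t_0=15, t_mult=2, eta_min=1e-5):
    """Learning rate of cosine annealing with warm restarts at an integer epoch.

    lr = eta_min + (base_lr - eta_min) * (1 + cos(pi * T_cur / T_i)) / 2,
    where T_i is the current cycle length and T_cur the epochs since the last restart.
    """
    if epoch < 0:
        raise ValueError("epoch must be non-negative")
    t_i, t_cur = t_0, epoch
    while t_cur >= t_i:  # walk through completed cycles: 15, 30, 60, ...
        t_cur -= t_i
        t_i *= t_mult
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * t_cur / t_i)) / 2
```

With T_0=15 and T_mult=2 the cycles last 15, 30, 60, ... epochs, so within a 200-epoch run the learning rate restarts to 1e-4 at epochs 15, 45 and 105.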
Note that different ISIC data settings (e.g., image sizes or pipelines) may yield different results. We recommend trying DconnNet in your own ISIC setting by following the guide in the next section.
CHASEDB1
/CHASEDB1
  /img
  /gt
Training on your own dataset using this code
- Make your own dataloader.
- Replace the dataloader in the main() function of train.py with yours. If you need k-fold validation, use exp_id to specify your sub-folds.
- Specify your network settings in train.py.
- Run:
  python train.py
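A minimal sketch of the two pieces this step asks for: a map-style dataset and a k-fold split selected by exp_id. The class SegDataset, the function kfold_split, the round-robin fold assignment, the 0-based exp_id, and the (1, H, W) output shapes are all assumptions; check what main() in train.py and connect_loss.py actually expect before adopting them. torch.utils.data.DataLoader accepts any object with __len__ and __getitem__, and its default collate function converts NumPy arrays into tensors.

```python
import numpy as np


def kfold_split(n_samples, k, exp_id):
    """Return (train_idx, val_idx), holding out fold `exp_id` (0-based, assumed).

    Samples are assigned to folds round-robin: sample i belongs to fold i % k.
    """
    if not 0 <= exp_id < k:
        raise ValueError(f"exp_id must be in [0, {k}), got {exp_id}")
    val_idx = [i for i in range(n_samples) if i % k == exp_id]
    train_idx = [i for i in range(n_samples) if i % k != exp_id]
    return train_idx, val_idx


class SegDataset:
    """Hypothetical map-style dataset yielding (image, mask) NumPy pairs."""

    def __init__(self, images, masks, indices=None):
        if len(images) != len(masks):
            raise ValueError("images and masks must have the same length")
        self.images = images
        self.masks = masks
        self.indices = list(range(len(images))) if indices is None else list(indices)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        idx = self.indices[i]
        image = np.asarray(self.images[idx], dtype=np.float32)
        if image.ndim == 2:  # grayscale (H, W) -> (1, H, W); (C, H, W) is kept as-is
            image = image[None]
        # binarize the mask and add a channel axis: (H, W) -> (1, H, W)
        mask = (np.asarray(self.masks[idx]) > 0).astype(np.float32)[None]
        return image, mask
```

Usage: train_idx, val_idx = kfold_split(len(images), k=5, exp_id=0), then SegDataset(images, masks, train_idx) for training and SegDataset(images, masks, val_idx) for validation.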
Training DconnNet within your own codebase
Important: please follow these steps to ensure a correct implementation:
- Get our model files from /model.
- In the training phase, use connect_loss.py as the loss function:
  * for single-class, use connect_loss.single_class_forward
  * for general multi-class, use connect_loss.multi_class_forward
- In the testing phase, follow our official procedure in test_epoch of /solver.py, based on your number of classes:
  * for single-class, we get the final predictions by sigmoid --> threshold --> Bilateral_voting
  * for general multi-class, we get the final predictions by Bilateral_voting --> topK (softmax + topK)
  * you might also need to create two variables, hori_translation and verti_translation, in this step for matrix shifting; you can follow our code or customize your own shifting method.
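The single-class procedure (sigmoid, threshold, bilateral voting) can be sketched in NumPy as follows, with neighbour shifting done by matrix products in the spirit of hori_translation and verti_translation. This is an illustration under assumptions, not the code in test_epoch: the neighbour order in OFFSETS, the 0.5 threshold, and the final rule (a pixel is foreground if any of its bilaterally agreed connections survives) are chosen for illustration, and the helpers shift and single_class_predict are hypothetical.

```python
import numpy as np

# Hypothetical neighbour order as (row offset, col offset); the repository's order may differ.
OFFSETS = [(-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1)]


def shift(x, dy, dx):
    """Return y with y[i, j] = x[i + dy, j + dx], zero outside the image.

    The shift is two matrix products: a vertical shift matrix on the left and a
    horizontal one on the right (a stand-in for verti/hori_translation).
    """
    h, w = x.shape
    verti = np.eye(h, k=dy)   # verti[i, i + dy] = 1
    hori = np.eye(w, k=-dx)   # hori[j + dx, j] = 1
    return verti @ x @ hori


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def single_class_predict(logits, threshold=0.5):
    """Map (8, H, W) connectivity logits to an (H, W) binary segmentation."""
    conn = (sigmoid(logits) > threshold).astype(np.float64)
    voted = np.zeros_like(conn)
    for k, (dy, dx) in enumerate(OFFSETS):
        opp = OFFSETS.index((-dy, -dx))  # channel of the neighbour pointing back
        # keep the link p -> p + offset only if the neighbour also links back to p
        voted[k] = conn[k] * shift(conn[opp], dy, dx)
    return (voted.max(axis=0) > 0).astype(np.uint8)
```

A one-sided link (a pixel claims a neighbour that does not claim it back) is discarded by the vote, which is the point of the bilateral step. The multi-class topK path is not covered by this sketch.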
Notes on the code
Please always make sure the dimensions of your data are correct. For example, in connect_loss.py we specify the shape of each tensor in the comments. When an issue arises, please check the dimensions first.
Pretrained model
The pretrained models and predictions can be downloaded here.
If you use the SDL loss:
Please pre-compute the mask size distribution and save it as a .npy file (i.e., pos_cnt.npy in the Solver function) with shape (C, N), where C is the number of classes and N is the number of samples (images). For example, index (1, 10) holds the mask size (pixel count) of the second class in the 11th image.
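A minimal sketch for building this array, assuming each mask is an integer label map in which class c is stored as pixel value c (whether background counts as one of the C classes depends on your setup). The function name mask_size_distribution is hypothetical.

```python
import numpy as np


def mask_size_distribution(masks, num_classes):
    """Return a (C, N) array whose entry [c, n] is the pixel count of class c in image n.

    masks: iterable of (H, W) integer label maps with values in [0, num_classes).
    """
    masks = list(masks)
    pos_cnt = np.zeros((num_classes, len(masks)), dtype=np.int64)
    for n, mask in enumerate(masks):
        labels = np.asarray(mask).astype(np.int64).ravel()
        if labels.min() < 0 or labels.max() >= num_classes:
            raise ValueError(f"mask {n} has labels outside [0, {num_classes})")
        pos_cnt[:, n] = np.bincount(labels, minlength=num_classes)
    return pos_cnt
```

The result can then be written with np.save("pos_cnt.npy", pos_cnt).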
Citation
If you find this work useful in your research, please consider citing:
Z. Yang and S. Farsiu, "Directional Connectivity-based Segmentation of Medical Images," in CVPR, 2023, pp. 11525-11535.