[AAAI2022] UCTransNet
This repo is the official implementation of 'UCTransNet: Rethinking the Skip Connections in U-Net from a Channel-wise Perspective with Transformer', accepted at AAAI 2022.
We propose a Channel Transformer module (CTrans) and use it to replace the skip connections in the original U-Net, hence the name 'U-CTrans-Net'.
An online presentation video is available for a brief introduction.
🔥🔥🔥 For an improved version of UCTransNet, please refer to UDTransNet (Narrowing the semantic gaps in U-Net with learnable skip connections: The case of medical image segmentation), which achieves higher performance with lower computational cost. 🔥🔥🔥
Requirements
Install the dependencies from requirements.txt using:
pip install -r requirements.txt
Usage
Note: If you run into problems with the code, the existing issues may help.
1. Data Preparation
1.1. GlaS and MoNuSeg Datasets
The original data can be downloaded from the following links:
- MoNuSeg Dataset - Link (Original)
- GLAS Dataset - Link (Original)
Then prepare the datasets in the following format for easy use of the code:
├── datasets
│   ├── GlaS
│   │   ├── Test_Folder
│   │   │   ├── img
│   │   │   └── labelcol
│   │   ├── Train_Folder
│   │   │   ├── img
│   │   │   └── labelcol
│   │   └── Val_Folder
│   │       ├── img
│   │       └── labelcol
│   └── MoNuSeg
│       ├── Test_Folder
│       │   ├── img
│       │   └── labelcol
│       ├── Train_Folder
│       │   ├── img
│       │   └── labelcol
│       └── Val_Folder
│           ├── img
│           └── labelcol
1.2. Synapse Dataset
The Synapse dataset we used is provided by TransUNet's authors. Please go to https://github.com/Beckschen/TransUNet/blob/main/datasets/README.md for details.
(Optional) 🔥🔥 Using customized datasets.
- If you want to apply UCTransNet to a customized dataset, the easiest way is to organize the file structure in the same way as GlaS, as described above.
- Ensure that the images are in the .jpg format and that each mask filename matches its image filename but with the .png extension.
- Any inconsistencies in the file structure or naming conventions may result in I/O errors (see the sanity-check sketch below).
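As a quick sanity check, the minimal sketch below verifies that every .jpg image has a matching .png mask in the GlaS-style layout described above. 'YourDataset' is a placeholder for your own dataset folder, not a name used by the repo.

```python
import os

# Minimal sketch: check that every .jpg image in img/ has a matching .png mask in labelcol/.
def check_split(split_dir: str) -> None:
    img_dir = os.path.join(split_dir, 'img')
    mask_dir = os.path.join(split_dir, 'labelcol')
    images = sorted(f for f in os.listdir(img_dir) if f.endswith('.jpg'))
    for name in images:
        mask_name = os.path.splitext(name)[0] + '.png'
        if not os.path.exists(os.path.join(mask_dir, mask_name)):
            print(f'Missing mask for {name}: expected {mask_name}')

for split in ('Train_Folder', 'Val_Folder', 'Test_Folder'):
    check_split(os.path.join('datasets', 'YourDataset', split))  # 'YourDataset' is a placeholder
```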
2. Training
As mentioned in the paper, we introduce two strategies to optimize UCTransNet.
The first step is to change the settings in Config.py; all configurations, including the learning rate, batch size, etc., are defined there.
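For illustration only, the snippet below shows the kind of options typically edited in Config.py. The variable names and values here are assumptions, not the repo's exact defaults, so please consult the file itself.

```python
# Illustrative settings only; check Config.py for the actual names and defaults.
task_name     = 'GlaS'            # or 'MoNuSeg'
model_name    = 'UCTransNet'
learning_rate = 1e-3
batch_size    = 4
epochs        = 2000
img_size      = 224
session_name  = 'Test_session'    # reused at test time to locate the checkpoint
```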
2.1 Jointly Training
We optimize the convolution parameters in U-Net and the CTrans parameters together with a single loss. Run:
python train_model.py
2.2 Pre-training
Our method only replaces the skip connections in U-Net, so the U-Net parameters can be reused as part of the pretrained weights.
By first training a classical U-Net with /nets/UNet.py and then using its pretrained weights to train UCTransNet,
the CTrans module starts from better initial features.
This strategy can speed up convergence and may improve the final segmentation performance in some cases.
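As a rough sketch of this weight-transfer step, partial loading with strict=False lets the shared convolutional layers start from the U-Net checkpoint while the CTrans module keeps its own initialization. The import path, constructor arguments, and checkpoint filename below are assumptions, not the repo's exact API.

```python
import torch
import Config as config                   # assumed: the repo's Config.py
from nets.UCTransNet import UCTransNet    # assumed import path

# Assumed constructor signature; adjust to the actual one in nets/UCTransNet.py.
model = UCTransNet(config.get_CTranS_config(), n_channels=3, n_classes=1)

# Load the U-Net checkpoint and copy the tensors whose names match;
# keys present in only one of the two models are ignored with strict=False,
# so the CTrans parameters keep their random initialization.
unet_state = torch.load('unet_pretrained.pth', map_location='cpu')  # assumed checkpoint name
missing, unexpected = model.load_state_dict(unet_state, strict=False)
print(f'{len(unexpected)} U-Net tensors skipped, {len(missing)} tensors left at random init')
```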
3. Testing
3.1. Get Pre-trained Models
Here, we provide pre-trained weights on GlaS and MoNuSeg. If you do not want to train the models yourself, you can download them from the following links:
- GlaS: https://drive.google.com/file/d/1ciAwb2-0G1pZrt_lgSwd-7vH1STmxdYe/view?usp=sharing
- MoNuSeg: https://drive.google.com/file/d/1CJvHoh3VrPsBn_njZDo6SvJF_yAVe5MK/view?usp=sharing
3.2. Test the Model and Visualize the Segmentation Results
First, set the session name in Config.py to match the one used in the training phase.
Then run:
python test_model.py
You can get the Dice and IoU scores and the visualization results.
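For clarity, here is a small, self-contained sketch (not the repo's exact evaluation code) of how Dice and IoU can be computed from a binarized prediction and its ground-truth mask:

```python
import numpy as np

def dice_and_iou(pred: np.ndarray, target: np.ndarray, eps: float = 1e-7):
    """Compute Dice and IoU for a pair of binary masks."""
    pred = pred.astype(bool)
    target = target.astype(bool)
    intersection = np.logical_and(pred, target).sum()
    union = np.logical_or(pred, target).sum()
    dice = (2.0 * intersection + eps) / (pred.sum() + target.sum() + eps)
    iou = (intersection + eps) / (union + eps)
    return dice, iou

pred = np.array([[0, 1], [1, 1]])
target = np.array([[0, 1], [0, 1]])
print(dice_and_iou(pred, target))  # -> approximately (0.8, 0.667)
```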
🔥🔥 The testing results for all classes of the Synapse dataset can be downloaded through this link. 🔥🔥
4. Reproducibility
In our code, we carefully set the random seed and put cuDNN into 'deterministic' mode to eliminate randomness. However, there still exist some factors that may cause different training results, e.g., the CUDA version, the GPU type, the number of GPUs, etc. The GPU used in our experiments is an NVIDIA A40 (48 GB) and the CUDA version is 11.2.
Especially in multi-GPU settings, the upsampling operation is a known source of non-determinism. See https://pytorch.org/docs/stable/notes/randomness.html for more details.
When training, we suggest training the model twice to verify whether the randomness has been eliminated. Because we use an early-stopping strategy, the final performance may change significantly due to randomness.
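For reference, here is a minimal sketch of the seeding and cuDNN-determinism setup described above; the repo's own code may set these options in a different place or with a different seed.

```python
import random

import numpy as np
import torch

def set_deterministic(seed: int = 42) -> None:
    """Fix random seeds and force cuDNN into deterministic mode."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_deterministic()
```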
Reference
- UNet++: https://github.com/qubvel/segmentation_models.pytorch
- Attention U-Net: https://github.com/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets
- MultiResUNet: https://github.com/makifozkanoglu/MultiResUNet-PyTorch
- TransUNet: https://github.com/Beckschen/TransUNet
- Swin-Unet: https://github.com/HuCaoFighting/Swin-Unet
- MedT: https://github.com/jeya-maria-jose/Medical-Transformer
Citations
If this code is helpful for your study, please cite:
@article{UCTransNet,
  title={UCTransNet: Rethinking the Skip Connections in U-Net from a Channel-Wise Perspective with Transformer},
  volume={36},
  url={https://ojs.aaai.org/index.php/AAAI/article/view/20144},
  DOI={10.1609/aaai.v36i3.20144},
  number={3},
  journal={Proceedings of the AAAI Conference on Artificial Intelligence},
  author={Wang, Haonan and Cao, Peng and Wang, Jiaqi and Zaiane, Osmar R.},
  year={2022},
  month={Jun.},
  pages={2441-2449}
}
Contact
Haonan Wang (haonan1wang@gmail.com)