DG-ADR (WACV 2025)

Official code implementation of "Divergent Domains, Convergent Grading: Enhancing Generalization in Diabetic Retinopathy Grading"

Accepted at WACV 2025

Authors: Sharon Chokuwa and Muhammad Haris Khan

<img src="figures/synthetic_images.png" alt="picture alt" width="100%">

Abstract

Diabetic Retinopathy (DR) constitutes 5% of global blindness cases. While numerous deep learning approaches have sought to enhance traditional DR grading methods, they often falter when confronted with new out-of-distribution data, thereby impeding their widespread application. In this study, we introduce a novel deep learning method for achieving domain generalization (DG) in DR grading and make the following contributions. First, we propose a new way of generating image-to-image diagnostically relevant fundus augmentations conditioned on the grade of the original fundus image. These augmentations are tailored to emulate the types of shifts in DR datasets, thus increasing the model's robustness. Second, we address the limitations of the standard classification loss in DG for DR fundus datasets by proposing a new DG-specific loss, the domain alignment loss, which ensures that the feature vectors from all domains corresponding to the same class converge onto the same manifold for better domain generalization. Third, we tackle the coupled problem of data imbalance across DR domains and classes by proposing to employ Focal loss, which seamlessly integrates with our new alignment loss. Fourth, due to the inevitable observer variability in DR diagnosis that induces label noise, we propose leveraging self-supervised pretraining. This approach ensures that our DG model remains robust against early susceptibility to label noise, even when only a limited dataset of non-DR fundus images is available for pretraining. Our method demonstrates significant improvements over the strong Empirical Risk Minimization baseline and other recently proposed state-of-the-art DG methods for DR grading.

Our Method

<img src="figures/algorithm.png" alt="picture alt" width="60%">
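
The grading objective combines Focal loss with the proposed domain alignment loss. As an illustration only (a NumPy sketch under our own simplifying assumptions, not the exact formulation used in this repository), the two ingredients, a focal term that down-weights easy examples and a toy alignment penalty that pulls each domain's per-class mean features toward the class-wide mean, could look like:

```python
import numpy as np

def focal_loss(probs, labels, gamma=2.0):
    """Focal term: scales cross-entropy by (1 - p_t)^gamma so easy,
    well-classified samples contribute less (helps class imbalance)."""
    pt = probs[np.arange(len(labels)), labels]  # probability of the true class
    return float(np.mean(-((1.0 - pt) ** gamma) * np.log(pt)))

def alignment_penalty(features, labels, domains):
    """Toy alignment term: for each class, penalize how far each domain's
    mean feature vector sits from the class-wide mean, encouraging
    same-class features from different domains to converge."""
    total, count = 0.0, 0
    for c in np.unique(labels):
        class_mean = features[labels == c].mean(axis=0)
        for d in np.unique(domains):
            mask = (labels == c) & (domains == d)
            if mask.any():
                total += float(np.sum((features[mask].mean(axis=0) - class_mean) ** 2))
                count += 1
    return total / max(count, 1)
```

When every domain's class-conditional feature mean coincides with the global class mean, the penalty is zero; the `--loss_alpha` / `--weight_loss_alpha` flags below presumably weight terms of this kind.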

Getting Started

Datasets

  1. For our main OOD generalization evaluation, our study utilized the following datasets: DeepDR, Messidor-2, IDRID, APTOS, FGADR, RLDR and DDR. The table below summarizes the sizes and origins of the datasets used for our DG model:
| Dataset | Dataset Size | Dataset Origin |
| --- | --- | --- |
| DeepDR | 1600 | Different hospitals in China |
| Messidor-2 | 1744 | France |
| IDRID | 516 | India |
| APTOS | 3656 | Rural India |
| FGADR | 1842 | UAE |
| RLDR | 1593 | USA |
| DDR | 12497 | 23 provinces in China |
| EyePACS | 88698 | USA |

Sizes and origins of the datasets used for our DG model.

  2. For finetuning the latent diffusion model, only EyePACS is used. These datasets are collected and processed according to the GDR-Bench Dataset; the dataset sources are also listed in the benchmark dataset.

  3. For our SSL pretraining we used: ORIGA, G1020, ODIR-5K, Drishti-GS, REFUGE, RFMiD, DIARETDB1, DRIONS-DB, DRIVE, JSIEC, CHASE DB1, Cataract dataset, Glaucoma detection dataset, ROC as well as DR1 and DR2. The statistics of the datasets used for the SSL pretraining are presented in the table below:

| Dataset | Dataset Size |
| --- | --- |
| ORIGA | 650 |
| G1020 | 1020 |
| ODIR-5K | 8000 |
| Drishti-GS | 101 |
| REFUGE | 1200 |
| RFMiD | 1200 |
| DIARETDB1 | 89 |
| DRIONS-DB | 110 |
| DRIVE | 40 |
| JSIEC | 997 |
| CHASE-DB1 | 28 |
| ROC | 100 |
| DR1 and DR2 | 2046 |
| cataract_dataset | 601 |
| Fundus_Train_Val_Data | 650 |
| Total | 16832 |

Detailed breakdown of the compositions of the retinal datasets utilized during the SSL pretraining phase.
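
The SSL pretraining follows a SimCLR-style contrastive objective (SimCLR is credited in the References section). As a sketch only, not this repository's implementation, the NT-Xent loss over two augmented views of a batch can be written as:

```python
import numpy as np

def nt_xent(z1, z2, tau=0.5):
    """SimCLR's NT-Xent loss: z1[i] and z2[i] are embeddings of two
    augmented views of the same fundus image; all other pairs in the
    batch act as negatives."""
    z = np.concatenate([z1, z2], axis=0)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)  # cosine similarity space
    sim = z @ z.T / tau
    n = len(z1)
    np.fill_diagonal(sim, -np.inf)                    # exclude self-similarity
    pos = np.concatenate([np.arange(n, 2 * n), np.arange(n)])  # i <-> i + n
    logsumexp = np.log(np.exp(sim).sum(axis=1))
    return float(np.mean(logsumexp - sim[np.arange(2 * n), pos]))
```

Because this objective needs no grade labels, it sidesteps the observer-variability label noise discussed in the abstract; the resulting encoder is what `--checkpoint_path` points to below.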

Dependencies

For the DG model:

    pip install -r requirements.txt

Environment used for our experiments:

    Python: 3.7.2
    CUDA: 12.2
    OS: Ubuntu 22.04

How to Run

The final main results reported are averages over three trials with random seeds (0, 1, 2). Use run.sh to launch training.

    python main.py \
    --root PATH_TO_YOUR_DATASETS \
    --algorithm DG_ADR \
    --desc dg_adr_seed0 \
    --seed 0 \
    --val_epochs 5 \
    --num_epochs 200 \
    --lr 1e-3 \
    --batch_size 128 \
    --val_batch_size 256 \
    --weight_decay 0.0005 \
    --optim sgd \
    --sd_param 0 \
    --project_name dg_adr_seed0 \
    --ssl_pretrained \
    --checkpoint_path PATH_TO_YOUR_SSL_CHECKPOINT \
    --trivial_aug \
    --use_syn \
    --dropout 0 \
    --warm_up_epochs 0 \
    --k 5 \
    --margin 0.1 \
    --loss_alpha 10.0 \
    --weight_loss_alpha 1.0

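Since the reported numbers average the three seeds, aggregating per-seed results amounts to the following (the accuracy values are placeholders for illustration, not actual results):

```python
import statistics

# Hypothetical per-seed OOD accuracies (placeholder values only)
per_seed_acc = {0: 0.71, 1: 0.73, 2: 0.72}

mean_acc = statistics.mean(per_seed_acc.values())
std_acc = statistics.stdev(per_seed_acc.values())
print(f"mean = {mean_acc:.3f}, std = {std_acc:.3f}")
```
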
Main Results

<img src="figures/main_results.png" alt="picture alt" width="100%">

Qualitative Results

<img src="figures/features_tsne.png" alt="picture alt" width="100%">

Pretrained Models

  1. SSL pretrained checkpoint

  2. Textual Inversion embeddings

  3. DreamBooth checkpoints - not shared because they are too large, but they can be reproduced using the same datasets mentioned above and the reference code from the DreamBooth training example. We used the following training scripts for this part: finetune_dreambooth_grade_0.sh, finetune_dreambooth_grade_1.sh, finetune_dreambooth_grade_2.sh, finetune_dreambooth_grade_3.sh, finetune_dreambooth_grade_4.sh.

  4. DG checkpoints

Citation

@misc{chokuwa2024divergentdomainsconvergentgrading,
      title={Divergent Domains, Convergent Grading: Enhancing Generalization in Diabetic Retinopathy Grading}, 
      author={Sharon Chokuwa and Muhammad Haris Khan},
      year={2024},
      eprint={2411.02614},
      archivePrefix={arXiv},
      primaryClass={eess.IV},
      url={https://arxiv.org/abs/2411.02614}, 
}

References

This repository builds on code from DGDR, SimCLR and DA-Fusion.