# PRIOR: Prototype Representation Joint Learning from Medical Images and Reports
Official repository for the paper *PRIOR: Prototype Representation Joint Learning from Medical Images and Reports*, ICCV 2023.
## Introduction
Contrastive-learning-based vision-language joint pre-training has emerged as a successful representation learning strategy. In this paper, we present a prototype representation learning framework incorporating both global and local alignment between medical images and reports. In contrast to standard global multi-modality alignment methods, we employ a local alignment module for fine-grained representation. Furthermore, a cross-modality conditional reconstruction module is designed to interchange information across modalities in the training phase by reconstructing masked images and reports. To reconstruct long reports, a sentence-wise prototype memory bank is constructed, enabling the network to focus on low-level localized visual and high-level clinical linguistic features. Additionally, a non-autoregressive generation paradigm is proposed for reconstructing non-sequential reports. Experimental results on five downstream tasks, including supervised classification, zero-shot classification, image-to-text retrieval, semantic segmentation, and object detection, show that the proposed method outperforms other state-of-the-art methods across multiple datasets and under different dataset size settings.
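For intuition, the sketch below illustrates the general pattern of combining a global contrastive objective with a fine-grained local one. It is a minimal illustration, not the paper's implementation; all names and hyperparameters in it are hypothetical.

```python
# Illustrative PyTorch sketch of global + local image-report alignment.
# NOT PRIOR's actual code; names and hyperparameters are hypothetical.
import torch
import torch.nn.functional as F

def global_alignment_loss(img_emb, txt_emb, temperature=0.07):
    """Symmetric InfoNCE over paired global image/report embeddings."""
    img_emb = F.normalize(img_emb, dim=-1)                 # (B, D)
    txt_emb = F.normalize(txt_emb, dim=-1)                 # (B, D)
    logits = img_emb @ txt_emb.t() / temperature           # (B, B)
    targets = torch.arange(img_emb.size(0), device=img_emb.device)
    return 0.5 * (F.cross_entropy(logits, targets) +
                  F.cross_entropy(logits.t(), targets))

def local_alignment_loss(patch_emb, sent_emb, temperature=0.07):
    """Fine-grained term: match each sentence to its best image patch,
    then contrast matched pairs against in-batch negatives."""
    patch_emb = F.normalize(patch_emb, dim=-1)             # (B, P, D)
    sent_emb = F.normalize(sent_emb, dim=-1)               # (B, S, D)
    # Sentence s of report b vs every patch of image a, best patch kept.
    sim = torch.einsum('apd,bsd->abps', patch_emb, sent_emb).amax(dim=2)
    logits = sim.mean(dim=-1) / temperature                # (B, B) pooled
    targets = torch.arange(logits.size(0), device=logits.device)
    return F.cross_entropy(logits, targets)

# Usage with random features:
img, txt = torch.randn(4, 128), torch.randn(4, 128)
patches, sents = torch.randn(4, 49, 128), torch.randn(4, 5, 128)
loss = global_alignment_loss(img, txt) + local_alignment_loss(patches, sents)
```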
## Setup
Run:

```bash
pip install -r requirements.txt
```
## Data Preparation

### MIMIC-CXR Dataset
- Download Version 2 of MIMIC-CXR-JPG from https://physionet.org/content/mimic-cxr-jpg/2.0.0/ to `<image_dataset_path>`.
- Download the reports from MIMIC-CXR at https://physionet.org/content/mimic-cxr/2.0.0/ to `<report_dataset_path>`.
- Run the script to build the JSON file for pre-training:

```bash
cd codes/
python prior/data/pretrain/mimiccxr.py build --root_image_path <image_dataset_path> --root_report_path <report_dataset_path> --save_path <dataset_json_path> --meta_csv_path <meta_csv_path>
```
- Add `<dataset_json_path>` to `configs/data/mimiccxr.yaml`:

```yaml
_target_: prior.data.pretrain.mimiccxr.MimicCxrDataset
dataset_path: <dataset_json_path> # update this line
image_transform: ???
text_transform: ???
num_colors: 1
rate: 1.0
```
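As a quick sanity check, you can load the generated file and inspect it. The snippet below is a minimal sketch assuming the script writes a standard JSON file; the exact schema is defined in `prior/data/pretrain/mimiccxr.py`.

```python
# Quick sanity check for the generated pre-training JSON (illustrative;
# the exact schema is defined by prior/data/pretrain/mimiccxr.py).
import json

with open("<dataset_json_path>") as f:   # replace with your actual path
    data = json.load(f)

print(f"{len(data)} entries")
# Print one entry to verify image paths and report text look right.
print(next(iter(data.items())) if isinstance(data, dict) else data[0])
```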
## Pre-train

Run:

```bash
cd codes/
python scripts/train.py +experiments/pre_train=train_prior
```
### Pre-trained weights

We have released our pre-trained image encoder at `pretrained/prior_resnet50.pt`. The full image-text model is released on Google Drive.
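A minimal sketch for loading the released encoder into a torchvision ResNet-50 is shown below. It assumes the checkpoint is (or wraps) a plain state dict whose keys roughly match torchvision's layout; inspect the key names if `strict=False` skips everything.

```python
# Load the released image encoder into a torchvision ResNet-50
# (assumption: the checkpoint is a plain state dict; inspect its keys first).
import torch
from torchvision.models import resnet50

ckpt = torch.load("pretrained/prior_resnet50.pt", map_location="cpu")
state_dict = ckpt.get("state_dict", ckpt)   # unwrap if saved by a trainer

model = resnet50()                           # randomly initialized
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print("missing keys:", missing)
print("unexpected keys:", unexpected)
```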
## Downstream tasks

### Supervised finetuning

This part is similar to LOVT: the classification model is based on TorchVision, the segmentation model is based on SMP, and the detection model is based on Lightning Flash. Since this part has no additional technical contribution, we do not provide the code at the moment; we will update this part in the future.
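For reference, the sketch below shows one way to plug a ResNet-50 backbone into an SMP segmentation model and partially load the released encoder. Checkpoint key names may need remapping, so treat this as a starting point rather than our exact setup.

```python
# Sketch: build an SMP U-Net on a ResNet-50 backbone and partially load
# the pre-trained encoder. Key names in the checkpoint may need remapping.
import torch
import segmentation_models_pytorch as smp

model = smp.Unet(encoder_name="resnet50", encoder_weights=None,
                 in_channels=1, classes=1)   # grayscale X-rays, binary mask

ckpt = torch.load("pretrained/prior_resnet50.pt", map_location="cpu")
state_dict = ckpt.get("state_dict", ckpt)
model.encoder.load_state_dict(state_dict, strict=False)  # ignore mismatches
```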
### Zero-shot classification
Before running the code, make sure you have downloaded the CheXpert dataset to `<root_path>` and the full image-text pre-trained weights of PRIOR from Google Drive to `<pretrained_path>`.

Add `<root_path>` and `<dataset_path>` to `configs/data/chexpert_zero_shot.yaml`:
```yaml
_target_: prior.data.zero_shot_classification.chexpert_zls.CheXPertZeroClsDataset
dataset_path: ??? # update this line
transform: ???
num_colors: 1
root_path: ??? # update this line
rate: 1.0
```
Add `<pretrained_path>` to `configs/experiments/zero_shot_classification/test_prior.yaml`:

```yaml
zero_shot_classification_model:
  ......
  pre_trained_path: <pretrained_path> # update this line
  ......
```
Then run:

```bash
cd codes/
python scripts/downstream.py +experiments/zero_shot_classification=test_prior
```
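Conceptually, zero-shot classification scores each image against a text prompt per class and picks the best match. The sketch below shows this general pattern; `encode_image` and `encode_text` are hypothetical stand-ins, not PRIOR's actual API.

```python
# General zero-shot classification pattern (illustrative; encode_image and
# encode_text are hypothetical stand-ins for the model's encoders).
import torch
import torch.nn.functional as F

# The five CheXpert competition classes, phrased as prompts.
prompts = ["There is atelectasis.", "There is cardiomegaly.",
           "There is consolidation.", "There is edema.",
           "There is pleural effusion."]

@torch.no_grad()
def zero_shot_predict(model, image):
    img_emb = F.normalize(model.encode_image(image), dim=-1)    # (1, D)
    txt_emb = F.normalize(model.encode_text(prompts), dim=-1)   # (C, D)
    return (img_emb @ txt_emb.t()).argmax(dim=-1)               # class index
```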
## Acknowledgement

Some of the code is borrowed from LOVT and GLoRIA. Thanks for their great work.
## Citation
If you find this work useful in your research, please cite:
```bibtex
@inproceedings{PRIOR,
  title={PRIOR: Prototype Representation Joint Learning from Medical Images and Reports},
  author={Cheng, Pujin and Lin, Li and Lyu, Junyan and Huang, Yijin and Luo, Wenhan and Tang, Xiaoying},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
  year={2023}
}
```