Medical-CXR-VQA
Medical-CXR-VQA is an LLM-constructed large-scale chest x-ray dataset for the medical visual question answering task. This repository provides the code for generating the Medical-CXR-VQA dataset, as proposed in our paper, "Interpretable medical image Visual Question Answering via multi-modal relationship graph learning."
For more information about the dataset and the method, please refer to our paper.
For the code of our multi-modal relationship graph learning method, please refer to MMRGL (Coming soon!).
The Medical-CXR-VQA dataset is currently under review at PhysioNet. We will add the link once it becomes available.
Data
After downloading from PhysioNet, put the files below into `code/data`:
- `medical_cxr_vqa_questions.csv`: The generated questions and answers for the Medical-CXR-VQA dataset.
- `all_diseases.json`: The key information extracted from all reports by the fine-tuned LLaMA 2.
- `mimic_all.csv`: All metadata related to the CXR studies.
- `all_diseases_gpt4_100.json`: The 100 examples (key information) generated by GPT-4, used for fine-tuning LLaMA 2.
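If you want to take a quick look at the downloaded files before running any of the scripts below, the minimal Python sketch here assumes only that they are standard CSV/JSON files; the exact column and key names are not documented in this README and should be checked against the released data.

```python
import json

import pandas as pd

# Question/answer pairs and CXR study metadata (plain CSV files).
questions = pd.read_csv("code/data/medical_cxr_vqa_questions.csv")
mimic_all = pd.read_csv("code/data/mimic_all.csv")

# Key information extracted from the reports by the fine-tuned LLaMA 2.
with open("code/data/all_diseases.json") as f:
    all_diseases = json.load(f)

print(questions.shape, mimic_all.shape, len(all_diseases))
```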
Libs
Please put the files below into `code/libs`:
- `disease_lib_llm.csv`: The initial disease name library used for generating questions and answers.
- `level_lib.csv`: The initial level library used for generating questions and answers.
- `location_lib.csv`: The initial location library used for generating questions and answers.
- `type_lib.csv`: The initial type library used for generating questions and answers.
- `postlocation_lib.csv`: The initial postlocation library used for generating questions and answers.
- `position_change.csv`: The initial position change library used for generating questions and answers.
- `entity_dict.json`: Disease names with appearance frequencies in the KeyInfo set.
- `type_dict.json`: Disease types with appearance frequencies in the KeyInfo set.
- `level_dict.json`: Disease levels with appearance frequencies in the KeyInfo set.
- `location_dict.json`: Disease locations with appearance frequencies in the KeyInfo set.
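The four `*_dict.json` files above are frequency statistics over the KeyInfo set. Assuming each file is a flat term-to-count mapping (an assumption, not verified against the released files), they can be inspected like this:

```python
import json

# Assumed layout: {"atelectasis": 1234, ...}; adjust if the released file differs.
with open("code/libs/entity_dict.json") as f:
    entity_freq = json.load(f)

# Ten most frequent disease names in the KeyInfo set.
for name, count in sorted(entity_freq.items(), key=lambda kv: kv[1], reverse=True)[:10]:
    print(f"{name}: {count}")
```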
File explanation
Below are the files provided in `code/data` in this repository:
- `system_text.txt`: the system prompt for ChatGPT.
- `simple_system_text.txt`: the simplified system prompt used for the fine-tuned LLaMA 2.
- `user_text.txt`: the user input for ChatGPT. In our case, this is the input report. This file is used for `function_testing.py` only.
- `id100.pkl`: this file stores the IDs of the 100 examples extracted for doctor evaluation. It is used for comparing different dataset construction methods on the same data.
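`id100.pkl` can be read with the standard `pickle` module; the sketch below assumes it stores a plain Python collection of study IDs (not verified here):

```python
import pickle

with open("code/data/id100.pkl", "rb") as f:
    eval_ids = pickle.load(f)

# Assumed to be a list or set of study IDs for the doctor-evaluation subset.
print(len(eval_ids), list(eval_ids)[:5])
```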
Steps to generate dataset
Please be aware that the generated dataset may not be exactly the same as our provided one due to randomness. The code we are providing here is for reference purposes.
First, `cd code`.
1. (Optional) Prepare the 100 training examples.
This step can be skipped because we will provide the annotated golden set for fine-tuning LLaMA 2. The code for this step is still provided here for reference.
The generated `all_diseases_gpt4_100.json` is provided in the PhysioNet dataset.
1.1. Prerequisites
- Azure OpenAI Service access
  - According to the PhysioNet Credentialed Data Use Agreement, it is prohibited to share access to the data with third parties, including sending it through APIs provided by companies like OpenAI or using it in online platforms like ChatGPT. Therefore, the Azure OpenAI Service is suggested.
- Download `mimic-cxr-report.zip` from the MIMIC-CXR database.
1.2. Preparing the GPT-4-generated training data:
```bash
python main.py --model gpt-4 --system_text system_text.txt --select_num 100 --output_name data/all_diseases_gpt4_100.json
```
Parameter explanations:
- `--model` is the model name. The default value is `gpt-4`; you can also use `gpt-35-turbo`. Please note that the model name here should be the same as your deployment name in the Azure portal, not the underlying model name. In our case, "gpt-4" and "gpt-35-turbo" are the deployment names of our two models (see the sketch at the end of this step).
- `--select_num` defines the number of examples to extract.
- `--output_name` defines the name and path of the output file.
Then run `preprocess_data()` in `fine_tune_llama.py` to generate the training data (`gpt4_100_data_effective_finetune_llama2.json`) for fine-tuning LLaMA 2.
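For reference on the deployment-name point above, below is a minimal, self-contained sketch of calling the Azure OpenAI Service with the `openai` Python package (v1.x). The endpoint, key, deployment name, and report text are placeholders; this is not the exact code path used by `main.py`.

```python
import os

from openai import AzureOpenAI

# Placeholders: point these at your own Azure OpenAI resource.
client = AzureOpenAI(
    azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
    api_key=os.environ["AZURE_OPENAI_API_KEY"],
    api_version="2024-02-01",
)

with open("system_text.txt") as f:
    system_text = f.read()

response = client.chat.completions.create(
    # Must be your *deployment* name from the Azure portal, not the model name.
    model="gpt-4",
    messages=[
        {"role": "system", "content": system_text},
        {"role": "user", "content": "<one de-identified radiology report goes here>"},
    ],
)
print(response.choices[0].message.content)
```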
2. Set up LLaMA-Factory
Please refer to LLaMA-Factory for installation.
Next,
- Download the Llama 2 checkpoint. Please refer to this link. Then store it under the path `LLaMA-Factory/model_checkpoints/llama-2-70b-chat`.
- Move the provided `ds_config.json` to the `LLaMA-Factory` root directory.
- Modify `dataset_info.json` in `LLaMA-Factory/data` by adding the definition for the newly created fine-tuning dataset. The format is shown below. The `file_name` needs to match the output name generated using GPT-4.
"gpt4_100": {
"file_name": "gpt4_100_data_effective_finetune_llama2.json",
"columns": {
"prompt": "query",
"query": "",
"response": "output",
"history": ""
}
}
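Given the column mapping above, each record in `gpt4_100_data_effective_finetune_llama2.json` should expose at least a `query` (prompt) and an `output` (response) field. The check below is a quick sketch based on that assumption rather than on a documented schema; adjust the path to wherever you placed the file.

```python
import json

# Assumed location and layout: a list of {"query": ..., "output": ...} records.
with open("LLaMA-Factory/data/gpt4_100_data_effective_finetune_llama2.json") as f:
    records = json.load(f)

ok = sum(1 for rec in records if {"query", "output"} <= rec.keys())
print(f"{ok}/{len(records)} records expose both 'query' and 'output'")
```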
3. Fine-tune the model using the following command:
```bash
deepspeed --num_gpus 6 --master_port=9901 src/train_bash.py \
    --deepspeed ds_config.json \
    --stage sft \
    --model_name_or_path ../model_checkpoints/llama-2-70b-chat \
    --do_train \
    --dataset gpt4_100 \
    --template llama2 \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --output_dir ../model_checkpoints/llama_finetune_gpt4_100 \
    --overwrite_cache \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_steps 1000 \
    --learning_rate 1e-3 \
    --num_train_epochs 3.0 \
    --plot_loss \
    --fp16 \
    --overwrite_output_dir
```
- `--dataset`: the dataset name defined in `dataset_info.json` in Step 2.
- `--num_gpus`: the number of GPUs used for fine-tuning.
4. Combine the fine-tuned model with the original model using the following command (adjust the model paths as needed):
```bash
python src/export_model.py \
    --model_name_or_path ../model_checkpoints/llama-2-70b-chat \
    --template llama2 \
    --finetuning_type lora \
    --checkpoint_dir ../model_checkpoints/llama_finetune_gpt4_100 \
    --output_dir ../model_checkpoints/llama_finetune_gpt4_100_output
```
5. Generate the entire dataset using the fine-tuned model:
```bash
python main.py --model llama_finetune_gpt4_100_output --output_name data/all_diseases_chatgptRaw.json
```
6. Follow-up questions and post-processing
Post-processing: `llama2_postprocessing.py`
```bash
python llama2_postprocessing.py --input_path data/all_diseases_chatgptRaw.json --output_path data/all_diseases_standardized.json
```
Follow-up question generation: `follow_up_gen.py`
```bash
python follow_up_gen.py --model_name llama_finetune_gpt4_100_output --raw_file data/all_diseases_standardized.json --followup_file data/all_diseases_standardized_fu.json
```
- `--raw_file`: path to the input file
- `--followup_file`: path to the output file

These two steps can be repeated alternately.
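If you want to script the alternation, a small driver like the one below runs the two stages back to back. The number of rounds and the intermediate file names are arbitrary choices for illustration, not part of the released pipeline.

```python
import subprocess

model = "llama_finetune_gpt4_100_output"
current = "data/all_diseases_chatgptRaw.json"

# Arbitrary choice: two post-processing / follow-up rounds; adjust as needed.
for i in range(2):
    standardized = f"data/all_diseases_standardized_r{i}.json"
    followed_up = f"data/all_diseases_standardized_fu_r{i}.json"
    subprocess.run(["python", "llama2_postprocessing.py",
                    "--input_path", current, "--output_path", standardized], check=True)
    subprocess.run(["python", "follow_up_gen.py", "--model_name", model,
                    "--raw_file", standardized, "--followup_file", followed_up], check=True)
    current = followed_up
```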
7. Question generation
```bash
python question_gen.py --json_path <path_to_the_final_all_diseases_json>
```