Medical-CXR-VQA

Medical-CXR-VQA is an LLM-constructed, large-scale chest X-ray dataset for the medical visual question answering task. This repository provides the code for generating the Medical-CXR-VQA dataset, as proposed in our paper, "Interpretable medical image Visual Question Answering via multi-modal relationship graph learning."

For more information about the dataset and the method, please refer to our paper.

For the code of our multi-modal relationship graph learning method, please refer to MMRGL (Coming soon!).

The Medical-CXR-VQA dataset is currently under review on PhysioNet. We will attach the link once it is available.

Data

After downloading from PhysioNet, put the files below into code/data.

Libs

Please put the files below into code/libs.

File explanation

Below are the files provided in code/data in this repository.

  1. "system_text.txt" is the system prompt for ChatGPT.
  2. "simple_system_text.txt" is the simplified system prompt used for fine-tuned LLama 2.
  3. "user_text.txt" is the user input for ChatGPT. In our case, this is the input report. This file is used for function_testing.py only.
  4. "id100.pkl": This file stores the IDs for 100 examples that have been extracted for doctor evaluation. It is used for comparison between different dataset construction methods on the same data.

Steps to generate dataset

Please be aware that the generated dataset may not be exactly the same as the one we provide, due to randomness. The code provided here is for reference purposes.

First, change into the code directory:

cd code

1. (Optional) Prepare 100 training examples.

This step can be skipped because we will provide the annotated golden set for fine-tuning Llama 2. The code for this step is still provided here for reference.

The generated all_diseases_gpt4_100.json is provided in the PhysioNet dataset.

1.1. Prerequisites

  1. Azure OpenAI Service access
    • According to the PhysioNet Credentialed Data Use Agreement, it is prohibited to share access to the data with third parties, including sending it through APIs provided by companies like OpenAI or using it in online platforms like ChatGPT. Azure OpenAI Service is therefore suggested; a minimal client sketch is shown after this list.
  2. Download mimic-cxr-report.zip from the MIMIC-CXR database.
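
As reference only, a minimal sketch (not taken from main.py) of routing one report through Azure OpenAI Service with the official openai Python client; the endpoint, key, API version, and deployment name are placeholders for your own Azure resource:

from openai import AzureOpenAI

# All Azure-specific values below are placeholders for your own resource.
client = AzureOpenAI(
    azure_endpoint="https://<your-resource>.openai.azure.com",
    api_key="<your-azure-key>",
    api_version="2024-02-01",
)

response = client.chat.completions.create(
    model="<your-gpt4-deployment>",  # the Azure deployment name, not "gpt-4"
    messages=[
        {"role": "system", "content": open("data/system_text.txt").read()},
        {"role": "user", "content": "<one MIMIC-CXR report>"},
    ],
)
print(response.choices[0].message.content)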

1.2. Preparing GPT-4 generated training data:

python main.py  --model gpt-4 --system_text system_text.txt --select_num 100 --output_name data/all_diseases_gpt4_100.json

Parameter explanations:

  • --model: the model used for extraction (here, gpt-4).
  • --system_text: the system prompt file under code/data.
  • --select_num: the number of reports to sample (here, the 100 examples for the golden set).
  • --output_name: the path of the generated JSON file.

Then run preprocess_data() in fine_tune_llama.py to generate the training data (gpt4_100_data_effective_finetune_llama2.json) for fine-tuning Llama 2.
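
For intuition, a hypothetical sketch of this conversion step. The target record layout ({"query": ..., "output": ...}) follows the dataset_info.json entry shown in step 2; the field names read from all_diseases_gpt4_100.json are assumptions, so consult fine_tune_llama.py for the actual logic.

import json

# Load the GPT-4 extractions (assumed: a dict keyed by study ID).
with open("data/all_diseases_gpt4_100.json") as f:
    gpt4_records = json.load(f)

simple_prompt = open("data/simple_system_text.txt").read()

examples = []
for study_id, record in gpt4_records.items():
    examples.append({
        "query": simple_prompt + "\n" + record["report"],  # assumed key
        "output": json.dumps(record["key_info"]),          # assumed key
    })

with open("gpt4_100_data_effective_finetune_llama2.json", "w") as f:
    json.dump(examples, f, indent=2)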

2. Set up LLaMA-Factory

Please refer to LLaMA-Factory for installation.

Next,

  1. Download the Llama 2 checkpoint. Please refer to this link. Then store it under the path LLaMA-Factory/model_checkpoints/llama-2-70b-chat.
  2. Move the provided ds_config to the LLaMA-Factory root directory.
  3. Modify dataset_info.json in LLaMA-Factory/data by adding the definition for the newly created fine-tuning dataset. The format is shown below. The file_name needs to match the output name generated using GPT-4.
  "gpt4_100": {
    "file_name": "gpt4_100_data_effective_finetune_llama2.json",
    "columns": {
      "prompt": "query",
      "query": "",
      "response": "output",
      "history": ""
    }
  }
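
Note that LLaMA-Factory resolves file_name relative to its data directory, so gpt4_100_data_effective_finetune_llama2.json should also be copied into LLaMA-Factory/data.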

3. Fine-tune the model using the following command:

deepspeed --num_gpus 6 --master_port=9901 src/train_bash.py \
    --deepspeed ds_config.json \
    --stage sft \
    --model_name_or_path ../model_checkpoints/llama-2-70b-chat \
    --do_train \
    --dataset gpt4_100 \
    --template llama2 \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --output_dir ../model_checkpoints/llama_finetune_gpt4_100 \
    --overwrite_cache \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_steps 1000 \
    --learning_rate 1e-3 \
    --num_train_epochs 3.0 \
    --plot_loss \
    --fp16 \
    --overwrite_output_dir

4. Combine the fine-tuned LoRA weights with the original model using the following command (change the model paths to match your setup):

python src/export_model.py \
    --model_name_or_path ../model_checkpoints/llama-2-70b-chat \
    --template llama2 \
    --finetuning_type lora \
    --checkpoint_dir ../model_checkpoints/llama_finetune_gpt4_100 \
    --output_dir ../model_checkpoints/llama_finetune_gpt4_100_output
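
As a quick sanity check (not part of the repo), the merged checkpoint can be loaded directly with Hugging Face transformers; note that real inference should wrap the prompt in the llama2 chat template used during fine-tuning:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

path = "../model_checkpoints/llama_finetune_gpt4_100_output"
tokenizer = AutoTokenizer.from_pretrained(path)
model = AutoModelForCausalLM.from_pretrained(
    path, torch_dtype=torch.float16, device_map="auto"
)

# "<one report>" is a placeholder for a MIMIC-CXR report.
prompt = open("data/simple_system_text.txt").read() + "\n<one report>"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=512)
print(tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))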

5. Generate the entire dataset using the fine-tuned model

python main.py  --model llama_finetune_gpt4_100_output --output_name data/all_diseases_chatgptRaw.json

6. Follow-up and post-processing

Post-processing: llama2_postprocessing.py

python llama2_postprocessing.py --input_path data/all_diseases_chatgptRaw.json --output_path data/all_diseases_standardized.json

Follow-up question generation: follow_up_gen.py

python follow_up_gen.py --model_name llama_finetune_gpt4_100_output --raw_file data/all_diseases_standardized.json --followup_file data/all_diseases_standardized_fu.json

These two steps can be repeated alternately.
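
After each round, a quick sanity check along these lines can be useful (the layout of the JSON, a dict keyed by study ID, is an assumption):

import json

with open("data/all_diseases_standardized.json") as f:
    data = json.load(f)

print(f"{len(data)} studies standardized")
empty = [k for k, v in data.items() if not v]  # studies with no extractions
print(f"{len(empty)} studies have empty extractions")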

7. Question generation

python question_gen.py --json_path <path_to_the_final_all_diseases_json>
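
For illustration only, a hypothetical sketch of how template-based questions could be derived from the final all_diseases JSON; question_gen.py implements the actual templates, and the record layout used here is an assumption:

def make_questions(diseases):
    # Turn one study's extracted key information (a dict of disease name ->
    # attributes) into (question, answer) pairs.
    qas = []
    for name, attrs in diseases.items():
        qas.append((f"Is there {name} in this image?", "yes"))  # presence
        if attrs.get("location"):
            qas.append((f"Where is the {name}?", attrs["location"]))  # location
        if attrs.get("level"):
            qas.append((f"What is the level of the {name}?", attrs["level"]))  # level
    return qas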