The Limited Impact of Medical Adaptation of Large Language and Vision-Language Models

<p align="center"> <img src="./figs/medical-dapt-concept-art.webp" alt="image" width="30%"> </p> <br>

This is the official repository for the EMNLP 2024 paper (Oral):

Daniel P. Jeong, Saurabh Garg, Zachary C. Lipton, and Michael Oberst. Medical Adaptation of Large Language and Vision-Language Models: Are We Making Progress? Proceedings of the 2024 Conference on Empirical Methods in Natural Language Processing (EMNLP).

and its extended version (preprint):

Daniel P. Jeong, Pranav Mani, Saurabh Garg, Zachary C. Lipton, and Michael Oberst. The Limited Impact of Medical Adaptation of Large Language and Vision-Language Models. arXiv:2411.08870.

In the extended version, we include additional results on closed-ended QA tasks based on clinical notes (in addition to medical-exam-style QA), as well as a comparison of performance when using medical versus general-domain models as the initialization for downstream supervised fine-tuning.

We include all of the code used for preprocessing the medical QA datasets and running the main zero-/few-shot prompting and supervised fine-tuning experiments discussed in the paper. For details on the overall experimental setup, see Section 3 of the extended version. For discussion of the results, see Sections 4 (zero-/few-shot prompting) and 5 (supervised fine-tuning) of the extended version.

🔍 Links For Quick Navigation

<br>

🤖 Models

For all medical and general-domain LLMs and VLMs used for evaluation, we use the HuggingFace checkpoints listed below. For each general-domain LLM/VLM, we list the corresponding medical counterpart(s).

LLMs

VLMs

Note that for LLaVA-Med-7B and LLaVA-v0-7B, the provided checkpoints are delta weights that cannot be used directly. Please follow the instructions in the LLaVA-Med repository and the LLaVA repository for merging the delta weights with the base LLaMA-7B weights: https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md#legacy-models-delta-weights.

<br>

📁 Datasets

As detailed in Section 3 and Appendix A.1 of the extended version, we use the datasets listed below for evaluation. The same sections also describe how each dataset was preprocessed.

Textual QA: Medical Knowledge

All of the textual medical knowledge QA datasets are directly accessible via HuggingFace (links included above).

Textual QA: Clinical Notes

Except for the CASI dataset, all of the textual clinical note QA datasets require credentialed access. The i2b2 dataset can be accessed via the Harvard DBMI Portal. The remaining datasets are all available via PhysioNet (links included above). Note that EHRNoteQA also requires downloading the clinical notes in MIMIC-IV (Johnson et al., 2020; PhysioNet). To gain credentials for PhysioNet, follow the instructions here.

Additional Preprocessing Steps:

Visual QA

All of the visual QA datasets are publicly available (links included above).

Additional Preprocessing Steps: For VQA-RAD, PathVQA, and SLAKE, follow the steps in ./notebooks/preprocess-vqa.ipynb.

Configuring the Dataset Paths

For all datasets, make sure to appropriately update the dataset config files (e.g., ./configs/llm/eval/dataset/mednli.yaml) and the default path config files (e.g., ./configs/llm/eval/paths/default.yaml) to point to the correct paths where you have downloaded the data.
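
As a quick sanity check after editing a config, you can parse it with PyYAML and confirm that the paths point to the right locations on your machine. This is a minimal sketch; the key names you see depend on the repository's actual config schema:

import yaml

# Sanity check: parse an edited dataset config and print its contents so
# you can verify the data paths before launching any experiments. The key
# names depend on the repository's config schema.
with open('./configs/llm/eval/dataset/mednli.yaml') as f:
    cfg = yaml.safe_load(f)
print(cfg)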

<br>

🐍 Setting Up the Conda Environment

To set up the conda environment (llm-env) that we used for all of the LLM experiments, run the following:

./scripts/setup/setup_llm.sh

To set up the conda environment (llava-env) that we used for all of the experiments with LLaVA-Med-7B and LLaVA-v0-7B, run the following:

./scripts/setup/setup_llava.sh

To set up the conda environment (open-flamingo-env) that we used for all of the experiments with Med-Flamingo-9B and Open-Flamingo-9B, run the following:

./scripts/setup/setup_flamingo.sh

<br>

📁 Loading the Data

Textual QA Datasets

For all of the textual QA datasets that are available via HuggingFace, instantiating the relevant dataset class (e.g., MedQADataset, MMLUMedicalDataset) in ./src/dataset.py will automatically download and cache the data to the path specified by the hf_cache_dir argument:

dataset = MedQADataset(
    name='medqa', # 5 options (for 4 options, use `medqa-usmle`)
    splits=['train', 'test'], 
    main_split='test',
    hf_cache_dir='/data'
)

You can also load the other datasets that require manual downloading and preprocessing beforehand in the same way, but be sure to update the paths in the dataset config files appropriately.

For zero-shot prompting, calling the following will apply the prompt template specified by the arguments to all of the QA examples in the main_split, in the zero-shot format (i.e., system prompt + question):

dataset.load_and_apply_prompt_template(
    model_name='llama-3-8b', # Use the default prompt format for Llama-3-8B
    prompt_type='zero-shot', # Zero-shot prompting format
    tokenizer=tokenizer # Assuming model tokenizer has been loaded beforehand
)
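
The tokenizer argument above assumes that a HuggingFace tokenizer has already been loaded for the model under evaluation; a minimal sketch (the checkpoint ID is shown purely for illustration):

from transformers import AutoTokenizer

# Load the tokenizer for the model under evaluation from HuggingFace.
# 'meta-llama/Meta-Llama-3-8B-Instruct' is shown purely for illustration;
# substitute the checkpoint you are actually evaluating.
tokenizer = AutoTokenizer.from_pretrained('meta-llama/Meta-Llama-3-8B-Instruct')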

To randomly sample a prompt format using the context-free grammar we discuss in Section 3 and Appendix B, additionally pass the sample_kwargs argument to load_and_apply_prompt_template() with the desired fixed random seed:

dataset.load_and_apply_prompt_template(
    model_name='llama-3-8b', # Use the default prompt format for Llama-3-8B
    sample_kwargs=dict(prompt_template_seed=0),
    prompt_type='zero-shot', 
    tokenizer=tokenizer 
)

For few-shot prompting, call the sample_few_shot_qas() method before calling load_and_apply_prompt_template():

dataset = MedQADataset(
    splits=['train', 'test'], 
    main_split='test',
    hf_cache_dir='/data'
)
dataset.sample_few_shot_qas(n_shot=3, seed=0)
dataset.load_and_apply_prompt_template(
    model_name='llama-3-8b',
    sample_kwargs=dict(prompt_template_seed=0),
    prompt_type='few-shot', 
    tokenizer=tokenizer 
)

Visual QA Datasets

The MMMU-Medical datasets can be directly loaded from HuggingFace, as with all of the textual medical QA datasets. Below is an example of loading the MMMU (Basic Medical Science) dataset for 3-shot prompting with LLaVA-Med-7B:

dataset = MMMUDataset(
    name='mmmu_basic-medical-science',
    splits=['train', 'test'], 
    main_split='test',
    hf_cache_dir='/data'
)
dataset.sample_few_shot_qas(n_shot=3, seed=0)
dataset.load_and_apply_prompt_template(
    model_name='llava-med-7b',
    sample_kwargs=dict(prompt_template_seed=0),
    prompt_type='few-shot', 
    tokenizer=tokenizer 
)

All other visual QA datasets should be downloaded separately from the official repositories beforehand, as detailed here. For these datasets, which contain both closed-ended and open-ended QA examples, we performed additional preprocessing to extract only the closed-ended QA examples and format them into structured .jsonl files, as detailed in ./notebooks/preprocess-vqa.ipynb.
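
If you want to spot-check the preprocessed files, each .jsonl can be read line by line with the standard library. This is a minimal sketch; the file path and field names are illustrative rather than the exact schema produced by the notebook:

import json

# Inspect the first few closed-ended QA records produced by
# ./notebooks/preprocess-vqa.ipynb. The path and field names are
# illustrative; check the notebook for the exact output schema.
with open('/data/vqa-rad/test.jsonl') as f:
    for i, line in enumerate(f):
        record = json.loads(line)  # one QA example per line
        print(record)
        if i == 2:
            break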

After running the notebook to execute all of the preprocessing steps, update data_root_dir in ./configs/vlm/eval/paths/default.yaml to point to the path where the dataset is saved. Then the dataset can be loaded as follows (showing the VQA-RAD dataset as an example):

dataset = VQARADDataset(
    splits=['train', 'test'], 
    main_split='test',
    hf_cache_dir='/data'
)
<br>

📊 Zero-/Few-Shot Prompting Experiments with Model-Specific Prompt Selection (Section 4.1)

Medical LLM vs. General-Domain LLM

To evaluate all pairs of medical and general-domain LLMs on all textual QA datasets in the zero-shot and 3-shot settings, run the following script:

./scripts/eval/llm/compare_medical_general.sh "<gpu_indices>" "<decoding>" "<prompt_optimization_flag>"

All of the results will be automatically saved under the following directories (the angle-bracketed names are placeholders):

# Greedy decoding
./results/llm/<dataset>/<model>/T=0,prompt=zero-shot,constrain_vocab=False,n_seeds=1 # Zero-shot
./results/llm/<dataset>/<model>/T=0,prompt=3-shot,constrain_vocab=False,n_seeds=1 # 3-shot

# Constrained decoding
./results/llm/<dataset>/<model>/T=0,prompt=zero-shot,constrain_vocab=True,n_seeds=1 # Zero-shot
./results/llm/<dataset>/<model>/T=0,prompt=3-shot,constrain_vocab=True,n_seeds=1 # 3-shot

Within each directory, test_results.pkl contains all of the predictions generated for the test set, along with the exact-match accuracies. The best prompt is saved as template_str.yaml in Jinja2 format.
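
To inspect the saved results programmatically, the pickle can be loaded directly. This is a minimal sketch; the dataset and model names in the path are placeholders, and the layout of the loaded object depends on the evaluation code:

import pickle

# Load the saved test-set predictions and exact-match accuracies for one
# dataset/model pair. The dataset ('medqa') and model ('llama-3-8b') names
# are placeholders; the object layout depends on the evaluation code.
results_path = (
    './results/llm/medqa/llama-3-8b/'
    'T=0,prompt=zero-shot,constrain_vocab=False,n_seeds=1/test_results.pkl'
)
with open(results_path, 'rb') as f:
    test_results = pickle.load(f)
print(type(test_results))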

Medical VLM vs. General-Domain VLM

To evaluate all pairs of medical and general-domain VLMs on all visual QA datasets in the zero-shot and 3-shot settings, run the following script:

./scripts/eval/vlm/compare_medical_general.sh "<gpu_indices>" "<decoding>" "<prompt_optimization_flag>"

All of the results will be automatically saved under the following directories:

# Greedy decoding
./results/vlm/<dataset>/<model>/T=0,prompt=zero-shot,constrain_vocab=False,n_seeds=1 # Zero-shot
./results/vlm/<dataset>/<model>/T=0,prompt=3-shot,constrain_vocab=False,n_seeds=1 # 3-shot

# Constrained decoding
./results/vlm/<dataset>/<model>/T=0,prompt=zero-shot,constrain_vocab=True,n_seeds=1 # Zero-shot
./results/vlm/<dataset>/<model>/T=0,prompt=3-shot,constrain_vocab=True,n_seeds=1 # 3-shot

Within each directory, test_results.pkl contains all of the predictions generated for the test set, along with the exact-match accuracies. The best prompt is saved as template_str.yaml in Jinja2 format.

<br>

📊 Zero-/Few-Shot Prompting Experiments with Prompt Optimized Only for the Medical Model (Section 4.2)

Medical LLM vs. General-Domain LLM

After running the LLM experiments with independent prompt selection, run the following script:

./scripts/eval/llm/compare_medical_general_medopt.sh "<gpu_indices>" "<decoding>"

The results are saved in the same format as before; only the test_results.pkl file is updated with the newly computed exact-match accuracies. In the .pkl file, the corresponding entries carry an additional _med suffix to distinguish them from the results of the independent prompt selection experiments.
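
To pull out only the entries produced by this run, you can filter on the _med suffix after loading the pickle. This is a minimal sketch, assuming the results object is dict-like; its actual layout depends on the evaluation code:

import pickle

# Hypothetical inspection: list the entries added by the run that optimizes
# the prompt only for the medical model, which carry a '_med' suffix.
results_path = (
    './results/llm/medqa/llama-3-8b/'
    'T=0,prompt=zero-shot,constrain_vocab=False,n_seeds=1/test_results.pkl'
)
with open(results_path, 'rb') as f:
    test_results = pickle.load(f)
print([k for k in test_results if str(k).endswith('_med')])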

Medical VLM vs. General-Domain VLM

After running the VLM experiments with independent prompt selection, run the following script:

./scripts/eval/vlm/compare_medical_general_medopt.sh "<gpu_indices>" "<decoding>"

The results are saved in the same format as before; only the test_results.pkl file is updated with the newly computed exact-match accuracies. In the .pkl file, the corresponding entries carry an additional _med suffix to distinguish them from the results of the independent prompt selection experiments.

<br>

📊 Supervised Fine-Tuning Experiments (Section 5)

LoRA Fine-Tuning and Evaluation for LLMs

To fine-tune a given medical/general-domain LLM on a textual QA dataset, run the following script:

./scripts/hpo/llm/run_lora_hpo.sh "<model>" "<dataset>" "<lora_r>" "<n_nodes>" "<head_node_ip>" "<gpu_indices>"

After running all of the sweeps, run the following script to select the best checkpoint across all hyperparameter trials:

./scripts/eval/llm/find_best_model.sh "<model>" "<dataset>" "<ft_method>" "<n_gpus>"

To run the final evaluation with the best checkpoint, run the following script:

./scripts/eval/llm/eval_finetuned.sh "<model>" "<dataset>" "<ft_method>" "<n_gpus>"

The evaluation result will be saved in the test_results.pkl file under the ./results/llm/<dataset>/<model>_lora-<dataset>-best/T=0,prompt=zero-shot,constrain_vocab=False,n_seeds=1 directory.
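
The scripts above handle checkpoint loading end to end; if you separately want to inspect a fine-tuned LoRA adapter, it can generally be attached to its base model with PEFT. This is a minimal sketch in which both the base checkpoint ID and the adapter path are purely illustrative:

from peft import PeftModel
from transformers import AutoModelForCausalLM

# Illustrative only: attach a saved LoRA adapter to its base model for
# manual inspection. Both the base checkpoint and the adapter path below
# are placeholders; the evaluation scripts above already handle this.
base_model = AutoModelForCausalLM.from_pretrained('meta-llama/Meta-Llama-3-8B-Instruct')
model = PeftModel.from_pretrained(base_model, '/path/to/best-lora-checkpoint')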

LoRA Fine-Tuning and Evaluation for LLaVA-Med-7B and LLaVA-v0-7B

To fine-tune LLaVA-Med-7B or LLaVA-v0-7B on a visual QA dataset, run the following script:

./scripts/hpo/vlm/run_lora_hpo.sh "<model>" "<dataset>" "<lora_r>" "<n_nodes>" "<head_node_ip>" "<gpu_indices>"

After running all of the sweeps, run the following script to select the best checkpoint across all hyperparameter trials:

./scripts/eval/vlm/find_best_model.sh "<model>" "<dataset>" "<ft_method>" "<n_gpus>"

To run the final evaluation with the best checkpoint, run the following script:

./scripts/eval/vlm/eval_finetuned.sh "<model>" "<dataset>" "<ft_method>" "<n_gpus>"

The evaluation result will be saved in the test_results.pkl file under the ./results/vlm/<dataset>/<model>_lora-<dataset>-best/T=0,prompt=zero-shot,constrain_vocab=False,n_seeds=1 directory.

Parameter-Efficient Fine-Tuning and Evaluation for Med-Flamingo-9B and Open-Flamingo-9B

To fine-tune Med-Flamingo-9B or Open-Flamingo-9B on a visual QA dataset, run the following script:

./scripts/hpo/vlm/run_ft_hpo.sh "<model>" "<dataset>" "<gpu_indices>"

After running all of the sweeps, run the following script to select the best checkpoint across all hyperparameter trials:

./scripts/eval/vlm/find_best_model.sh "<model>" "<dataset>" "<ft_method>" "<n_gpus>"

To run the final evaluation with the best checkpoint, run the following script:

./scripts/eval/vlm/eval_finetuned.sh "<model>" "<dataset>" "<ft_method>" "<n_gpus>"

The evaluation result will be saved in the test_results.pkl file under the ./results/vlm/<dataset>/<model>_ft-<dataset>-best/T=0,prompt=zero-shot,constrain_vocab=False,n_seeds=1 directory.

<br>

🙂 Citing Our Work (BibTeX)

# EMNLP 2024 Version
@inproceedings{jeong-etal-2024-medical,
    title = "Medical Adaptation of Large Language and Vision-Language Models: Are We Making Progress?",
    author = "Jeong, Daniel P and Garg, Saurabh and Lipton, Zachary Chase and Oberst, Michael",
    editor = "Al-Onaizan, Yaser and Bansal, Mohit and Chen, Yun-Nung",
    booktitle = "Proceedings of the 2024 Conference on Empirical Methods in Natural Language Processing",
    month = nov,
    year = "2024",
    address = "Miami, Florida, USA",
    publisher = "Association for Computational Linguistics",
    url = "https://aclanthology.org/2024.emnlp-main.677",
    doi = "10.18653/v1/2024.emnlp-main.677",
    pages = "12143--12170"
}

# Extended Version
@article{jeong-etal-2024-limited,
    title = "The Limited Impact of Medical Adaptation of Large Language and Vision-Language Models",
    author = "Jeong, Daniel P and Mani, Pranav and Garg, Saurabh and Lipton, Zachary Chase and Oberst, Michael",
    journal = "arXiv preprint arXiv:2411.08870",
    year = "2024"
}