Home

Awesome

KARD: Knowledge-Augmented Reasoning Distillation

Official Code Repository for the paper Knowledge-Augmented Reasoning Distillation for Small Language Models in Knowledge-intensive Tasks (NeurIPS 2023).

In this repository, we initially provide the preprocessing, trainig, and inference code with our KARD method.

Abstract

<img align="middle" width="900" src="https://github.com/Nardien/KARD/blob/main/images/concept_figure.PNG">

Large Language Models (LLMs) have shown promising performance in knowledge-intensive reasoning tasks that require a compound understanding of knowledge. However, deployment of the LLMs in real-world applications can be challenging due to their high computational requirements and concerns on data privacy. Previous studies have focused on building task-specific small Language Models (LMs) by fine-tuning them with labeled data or distilling LLMs. However, these approaches are ill-suited for knowledge-intensive reasoning tasks due to the limited capacity of small LMs in memorizing the knowledge required. Motivated by our theoretical analysis on memorization, we propose Knowledge-Augmented Reasoning Distillation (KARD), a novel method that fine-tunes small LMs to generate rationales obtained from LLMs with augmented knowledge retrieved from an external knowledge base. Moreover, we further propose a neural reranker to obtain documents relevant to rationale generation. We empirically show that KARD significantly improves the performance of small T5 and GPT models on the challenging knowledge-intensive reasoning datasets, namely MedQA-USMLE, StrategyQA, and OpenbookQA. Notably, our method makes the 250M T5 models achieve superior performance against the fine-tuned 3B models, having 12 times larger parameters, on both MedQA-USMLE and StrategyQA benchmarks.

Installation

Python version: 3.8.0

python -m pip install -r requirements.txt

Dataset

Follow below steps to setup the dataset for training.

  1. Download the raw dataset from this link.
  2. Download the rationales from the teacher language model (ChatGPT) from this link.
  3. Run python preprocess_cot.py --dataset {medqa_usmle_hf,strategyqa,obqa}
  4. Run python preprocess_cot_grounded.py --dataset {medqa_usmle_hf,strategyqa,obqa} --knowledge_base {wikipedia,pubmed}

You can also download the preprocessed data from this link.

LM Training Guide

If you want to run Knowledge-Augmented Reasoning Distillation, run the below script:

sh scripts/run_kard.sh {GPU ID} {Batch Size per GPU} {Model Size:base,large,xl} {Dataset:medqa,strategyqa,obqa}

We also provide the script for Reasoning Distillation without knowledge augmnetation.

sh scripts/run_rd.sh {GPU ID} {Batch Size per GPU} {Model Size:base,large,xl} {Dataset:medqa,strategyqa,obqa}

Both training script supports multi-gpu training.

For example, if you want to run KARD on the xl-sized LM training on medqa dataset with 4 gpus with batch size 8 per GPU, run as follows:

sh scripts/run_kard.sh 0,1,2,3 8 xl medqa

Reranker Training Guide

To train the reranker, check the reranker folder.

Inference Guide

After the LM and reranker training, run the following code:

python generate_predict.py --checkpoint_path "/path/to/checkpoint/" --retriever_type {sparse,dense} --dataset {medqa_usmle_hf,strategyqa,obqa} --dense_retriever_path "/path/to/retriever/"

If you train both the LM and reranker following the code above, run the code as follows:

python generate_predict.py --checkpoint_path "./save/flan-t5-base/medqa/kard_wikipedia" --retriever_type dense --dataset medqa_usmle_hf --dense_retriever_path "./reranker/save/colbert_lr1e-4"

You can adjust the following hyperparmeters:

TODOs

Please feel free to leave it on Issues if you find any problems.

Disclaimer

This is not an official product of KRAFTON. This repository is intended solely for research purposes.

Citation

If you found this work useful, please cite our work.

@inproceedings{DBLP:conf/nips/KangLBKS23,
  author       = {Minki Kang and
                  Seanie Lee and
                  Jinheon Baek and
                  Kenji Kawaguchi and
                  Sung Ju Hwang},
  title        = {Knowledge-Augmented Reasoning Distillation for Small Language Models in Knowledge-Intensive Tasks},
  booktitle    = {Advances in Neural Information Processing Systems 37: Annual Conference
                  on Neural Information Processing Systems 2023, NeurIPS 2023, December
                  10-16, 2023, New Orleans},
  year         = {2023},
}