LLM Calibration

This repository contains the code for the paper Large Language Models Must Be Taught to Know What They Don't Know by Sanyam Kapoor<sup>* </sup>, Nate Gruver<sup>* </sup>, Manley Roberts, Katherine Collins, Arka Pal, Umang Bhatt, Adrian Weller, Samuel Dooley, Micah Goldblum, and Andrew Gordon Wilson.

<figure> <img src="./assets/explainer_figure.png" alt="Image"> </figure>

We fine-tune various large language models to estimate well-calibrated uncertainties in both multiple-choice and open-ended question-answering settings. Fine-tuning uses a dataset of approximately 20,000 generations from each base model, labeled for correctness. At test/inference time, the predicted probability of correctness defines the model's confidence in its answer.
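To make "well-calibrated" concrete, here is a minimal sketch of expected calibration error (ECE), a common calibration metric. It is illustrative only and is not taken from this repository: the function name, the equal-width binning, and the default of 10 bins are all assumptions.

```python
def expected_calibration_error(confidences, correct, n_bins=10):
    """Equal-width-bin ECE: weighted mean of |accuracy - mean confidence| per bin.

    confidences: predicted probabilities of correctness in [0, 1].
    correct: 0/1 (or bool) labels saying whether each answer was right.
    NOTE: hypothetical helper for illustration, not part of the llm package.
    """
    if len(confidences) != len(correct):
        raise ValueError("confidences and correct must have the same length")
    n = len(confidences)
    if n == 0:
        return 0.0

    bins = [[] for _ in range(n_bins)]
    for p, y in zip(confidences, correct):
        # Clamp the index so that p == 1.0 falls into the last bin.
        idx = min(int(p * n_bins), n_bins - 1)
        bins[idx].append((p, float(y)))

    ece = 0.0
    for items in bins:
        if not items:
            continue
        avg_conf = sum(p for p, _ in items) / len(items)
        accuracy = sum(y for _, y in items) / len(items)
        # Each bin contributes in proportion to the fraction of samples it holds.
        ece += (len(items) / n) * abs(accuracy - avg_conf)
    return ece
```

A model that reports 90% confidence on four answers but gets only three right has a gap of 0.15 in that bin; a perfectly calibrated model scores 0.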

HuggingFace Release

We release the following calibration-tuned models as PEFT adapters via HuggingFace, along with the datasets used to train them.

<table> <tr> <td rowspan=6 valign="center">Open-Ended Generation</td> <td valign="top">Llama 2 7B</td> <td valign="top"><a href="https://huggingface.co/calibration-tuning/Llama-2-7b-hf-ct-oe" target="_blank"><img src="./assets/model-on-hf-md.svg"/></a></td> <td valign="top"><a href="https://huggingface.co/datasets/calibration-tuning/Llama-2-7b-hf-20k-oe" target="_blank"><img src="./assets/dataset-on-hf-md.svg"/></a></td> </tr> <tr> <td valign="top">Llama 2 7B Chat</td> <td valign="top"><a href="https://huggingface.co/calibration-tuning/Llama-2-7b-chat-hf-ct-oe" target="_blank"><img src="./assets/model-on-hf-md.svg"/></a></td> <td valign="top"><a href="https://huggingface.co/datasets/calibration-tuning/Llama-2-7b-chat-hf-20k-oe" target="_blank"><img src="./assets/dataset-on-hf-md.svg"/></a></td> </tr> <tr> <td valign="top">Llama 2 13B</td> <td valign="top"><a href="https://huggingface.co/calibration-tuning/Llama-2-13b-hf-ct-oe" target="_blank"><img src="./assets/model-on-hf-md.svg"/></a></td> <td valign="top"><a href="https://huggingface.co/datasets/calibration-tuning/Llama-2-13b-hf-20k-oe" target="_blank"><img src="./assets/dataset-on-hf-md.svg"/></a></td> </tr> <tr> <td valign="top">Llama 2 13B Chat</td> <td valign="top"><a href="https://huggingface.co/calibration-tuning/Llama-2-13b-chat-hf-ct-oe" target="_blank"><img src="./assets/model-on-hf-md.svg"/></a></td> <td valign="top"><a href="https://huggingface.co/datasets/calibration-tuning/Llama-2-13b-chat-hf-20k-oe" target="_blank"><img src="./assets/dataset-on-hf-md.svg"/></a></td> </tr> <tr> <td valign="top">Mistral 7B</td> <td valign="top"><a href="https://huggingface.co/calibration-tuning/Mistral-7B-v0.1-ct-oe" target="_blank"><img src="./assets/model-on-hf-md.svg"/></a></td> <td valign="top"><a href="https://huggingface.co/datasets/calibration-tuning/Mistral-7B-v0.1-20k-oe" target="_blank"><img src="./assets/dataset-on-hf-md.svg"/></a></td> </tr> <tr> <td valign="top">Mistral 7B Instruct</td> <td 
valign="top"><a href="https://huggingface.co/calibration-tuning/Mistral-7B-Instruct-v0.2-ct-oe" target="_blank"><img src="./assets/model-on-hf-md.svg"/></a></td> <td valign="top"><a href="https://huggingface.co/datasets/calibration-tuning/Mistral-7B-Instruct-v0.2-20k-oe" target="_blank"><img src="./assets/dataset-on-hf-md.svg"/></a></td> </tr> <tr> <td rowspan=6 valign="center">Multiple-Choice Question-Answering</td> <td valign="top">Llama 2 7B</td> <td valign="top"><a href="https://huggingface.co/calibration-tuning/Llama-2-7b-hf-ct-choice" target="_blank"><img src="./assets/model-on-hf-md.svg"/></a></td> <td valign="top"><a href="https://huggingface.co/datasets/calibration-tuning/Llama-2-7b-hf-20k-choice" target="_blank"><img src="./assets/dataset-on-hf-md.svg"/></a></td> </tr> <tr> <td valign="top">Llama 2 7B Chat</td> <td valign="top"><a href="https://huggingface.co/calibration-tuning/Llama-2-7b-chat-hf-ct-choice" target="_blank"><img src="./assets/model-on-hf-md.svg"/></a></td> <td valign="top"><a href="https://huggingface.co/datasets/calibration-tuning/Llama-2-7b-chat-hf-20k-choice" target="_blank"><img src="./assets/dataset-on-hf-md.svg"/></a></td> </tr> <tr> <td valign="top">Llama 2 13B</td> <td valign="top"><a href="https://huggingface.co/calibration-tuning/Llama-2-13b-hf-ct-choice" target="_blank"><img src="./assets/model-on-hf-md.svg"/></a></td> <td valign="top"><a href="https://huggingface.co/datasets/calibration-tuning/Llama-2-13b-hf-20k-choice" target="_blank"><img src="./assets/dataset-on-hf-md.svg"/></a></td> </tr> <tr> <td valign="top">Llama 2 13B Chat</td> <td valign="top"><a href="https://huggingface.co/calibration-tuning/Llama-2-13b-chat-hf-ct-choice" target="_blank"><img src="./assets/model-on-hf-md.svg"/></a></td> <td valign="top"><a href="https://huggingface.co/datasets/calibration-tuning/Llama-2-13b-chat-hf-20k-choice" target="_blank"><img src="./assets/dataset-on-hf-md.svg"/></a></td> </tr> <tr> <td valign="top">Mistral 7B</td> <td 
valign="top"><a href="https://huggingface.co/calibration-tuning/Mistral-7B-v0.1-ct-choice" target="_blank"><img src="./assets/model-on-hf-md.svg"/></a></td> <td valign="top"><a href="https://huggingface.co/datasets/calibration-tuning/Mistral-7B-v0.1-20k-choice" target="_blank"><img src="./assets/dataset-on-hf-md.svg"/></a></td> </tr> <tr> <td valign="top">Mistral 7B Instruct</td> <td valign="top"><a href="https://huggingface.co/calibration-tuning/Mistral-7B-Instruct-v0.2-ct-choice" target="_blank"><img src="./assets/model-on-hf-md.svg"/></a></td> <td valign="top"><a href="https://huggingface.co/datasets/calibration-tuning/Mistral-7B-Instruct-v0.2-20k-choice" target="_blank"><img src="./assets/dataset-on-hf-md.svg"/></a></td> </tr> </table>

See experiments/play.py for an example script of how to load and use the models.
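As a rough sketch of what loading one of the released adapters could look like with the standard `transformers` and `peft` APIs: this assumes the adapters can be attached with `PeftModel.from_pretrained` and that the base model is identified by a Hub id such as `meta-llama/Llama-2-7b-hf` (both are assumptions; `experiments/play.py` remains the authoritative example). The helper names are hypothetical; the adapter naming pattern follows the links in the table above.

```python
def adapter_repo_id(base_name: str, prompt_style: str) -> str:
    """Build the HuggingFace repo id of a released adapter.

    base_name: e.g. "Llama-2-7b-hf" or "Mistral-7B-Instruct-v0.2".
    prompt_style: "oe" (open-ended) or "choice" (multiple-choice).
    """
    if prompt_style not in {"oe", "choice"}:
        raise ValueError("prompt_style must be 'oe' or 'choice'")
    return f"calibration-tuning/{base_name}-ct-{prompt_style}"


def load_calibration_tuned(base_model_id: str, prompt_style: str = "oe"):
    """Load a base model and attach the matching adapter (downloads weights).

    base_model_id is an assumed Hub id, e.g. "meta-llama/Llama-2-7b-hf".
    """
    # Imported lazily so adapter_repo_id works without these packages installed.
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(base_model_id)
    base = AutoModelForCausalLM.from_pretrained(base_model_id)
    adapter_id = adapter_repo_id(base_model_id.split("/")[-1], prompt_style)
    model = PeftModel.from_pretrained(base, adapter_id)
    return tokenizer, model
```

Calling `load_calibration_tuned` downloads the full base model, so it needs access to the base weights and enough memory to hold them.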

Environment Setup

Create a new environment and activate it, e.g. with conda,

conda create -y -n calibration-tuning python=3.11 pip -c conda-forge
conda activate calibration-tuning

And finally run,

pip install -e .

This will install the llm package.

NOTE: If a different PyTorch CUDA build is required, install it from an extra index URL. For example, for CUDA 11.8, run,

pip install --no-cache-dir -U torch torchvision --extra-index-url https://download.pytorch.org/whl/cu118

Usage

All arguments from the main method in each of the scripts below qualify as command line arguments.

Environment Variables:

Arguments:

Dataset Generation

NOTE: The Story Cloze dataset (2018 version) requires manual download. See instructions here. After getting the CSV file, place it at ${HF_HOME}/datasets/story_cloze/2018.

Output Generation

To create a CSV dataset of open-ended generations at <outputs-log-dir>/outputs, run the command below. <outputs-log-dir> is auto-generated or can be explicitly specified using the --log-dir argument.

python experiments/generate.py outputs --dataset=all_20k_uniform --prompt-style=oe --model-name=llama2:13b-chat --max-new-tokens=30 --batch-size=16 --kshot=1

For multiple-choice generations,

python experiments/generate.py outputs --dataset=all_20k_uniform --prompt-style=choice --model-name=llama2:13b-chat --max-new-tokens=1 --batch-size=16

Uncertainty Query Label Generation

To generate the dataset with uncertainty query labels at <labels-log-dir>/labels (auto-generated or specified via --log-dir),

python experiments/generate.py labels --dataset=offline:<outputs-log-dir>/outputs --model-name=llama2:13b-chat --strategy=substring --batch-size=16

Use --strategy=fuzzy_gpt-3.5-turbo-1106 for generating labels via GPT 3.5 Turbo.
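For intuition, a substring-style correctness label can be sketched as below. This is an assumption about what a substring strategy does (mark a generation correct when the reference answer appears inside it after lowercasing and whitespace normalization), not the repository's actual implementation; `substring_label` is a hypothetical name.

```python
def substring_label(generation: str, target: str) -> bool:
    """Return True if the normalized target occurs inside the normalized generation.

    NOTE: illustrative only; the normalization rules here are assumptions.
    """
    def normalize(text: str) -> str:
        # Lowercase and collapse all runs of whitespace to single spaces.
        return " ".join(text.lower().split())

    target_norm = normalize(target)
    if not target_norm:
        # An empty reference answer would match everything, so treat it as incorrect.
        return False
    return target_norm in normalize(generation)
```

Exact substring matching is cheap but brittle: a correct paraphrase such as "the French capital" would not match a reference answer of "Paris", which is the kind of case a fuzzy, model-based grader is meant to handle.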

Training

Checkpoints will be saved in an auto-generated directory <train-log-dir>/checkpoint-<step>, or can be configured via --log-dir.
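When scripting evaluation, it can be handy to pick the most recent checkpoint automatically. The sketch below assumes only the `checkpoint-<step>` directory naming described above; `latest_checkpoint` is a hypothetical helper, not part of this repository.

```python
import re
from pathlib import Path


def latest_checkpoint(train_log_dir):
    """Return the checkpoint-<step> subdirectory with the largest step, or None."""
    pattern = re.compile(r"^checkpoint-(\d+)$")
    best_step, best_path = -1, None
    for child in Path(train_log_dir).iterdir():
        match = pattern.match(child.name)
        # Compare steps numerically so checkpoint-2000 beats checkpoint-900.
        if child.is_dir() and match and int(match.group(1)) > best_step:
            best_step, best_path = int(match.group(1)), child
    return best_path
```

The returned path can then be passed wherever a `<train-log-dir>/checkpoint-<step>` directory is expected.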

LoRA + Prompt

To use the labeled dataset for calibration-tuning (LoRA + Prompt),

torchrun --nnodes=1 --nproc_per_node=auto experiments/calibration_tune.py --dataset=offline:<labels-log-dir>/labels --model-name=llama2:13b-chat --batch-size=4 --kl-decay=1.0 --max-steps=5000

Use --scale-temp for temperature scaling of the uncertainty query predictions.
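Temperature scaling divides logits by a single scalar T fitted on held-out data so that the resulting probabilities are better calibrated. The sketch below illustrates the general idea for a single "is correct" logit per example, using a simple grid search over T; the binary-logit setup, the grid, and the function names are assumptions, and the repository may fit the temperature differently.

```python
import math


def nll(logits, labels, temperature):
    """Mean binary cross-entropy of sigmoid(logit / T) against 0/1 labels."""
    total = 0.0
    for z, y in zip(logits, labels):
        s = z / temperature
        # softplus(-s) = log(1 + exp(-s)), computed in a numerically stable way.
        softplus_neg = max(-s, 0.0) + math.log1p(math.exp(-abs(s)))
        # -log sigmoid(s) = softplus(-s); -log(1 - sigmoid(s)) = softplus(-s) + s.
        total += softplus_neg if y else softplus_neg + s
    return total / len(logits)


def fit_temperature(logits, labels, grid=None):
    """Pick the temperature on a grid (0.05 to 10.0 by default) that minimizes NLL."""
    grid = grid or [0.05 * k for k in range(1, 201)]
    return min(grid, key=lambda t: nll(logits, labels, t))
```

For an overconfident predictor, for example logits of 4.0 on examples that are right only three times out of four, the fitted T comes out well above 1, which pulls the probabilities toward the observed accuracy.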

For other CLI arguments, see the main function of experiments/calibration_tune.py.

Probe / LoRA

To use the labeled dataset for training a classifier head (Probe),

torchrun --nnodes=1 --nproc_per_node=auto experiments/classifier_tune.py --dataset=offline:<labels-log-dir>/labels --model-name=llama2:13b-chat --batch-size=4 --max-steps=5000

Use --scale-temp for temperature scaling of the classifier. Use --with-lora to enable trainable LoRA parameters (LoRA).

For other CLI arguments, see the main function of experiments/classifier_tune.py.

Evaluation

LoRA + Prompt

For open-ended generation evaluations,

python experiments/evaluate.py --dataset=eval:mmlu --prompt-style=oe --model-name=llama2:13b-chat --query-peft-dir=<train-log-dir>/checkpoint-<step> --mode=query

For multiple-choice question-answering evaluations,

python experiments/evaluate.py --dataset=eval:mmlu --prompt-style=choice --model-name=llama2:13b-chat --query-peft-dir=<train-log-dir>/checkpoint-<step> --mode=query_choice

Use --scale-temp=query to use temperature scaling of the uncertainty query logits.

Probe / LoRA

For open-ended generation evaluations,

python experiments/evaluate.py --dataset=eval:mmlu --prompt-style=oe --model-name=llama2:13b-chat --query-peft-dir=<train-log-dir>/checkpoint-<step> --mode=class --with-classifier

For multiple-choice question-answering evaluations,

python experiments/evaluate.py --dataset=eval:mmlu --prompt-style=choice --model-name=llama2:13b-chat --query-peft-dir=<train-log-dir>/checkpoint-<step> --mode=class_choice --with-classifier

Use --scale-temp=probe to use temperature scaling of the classifier logits.

LICENSE

Apache 2.0

Citation

Please cite our work as:

@inproceedings{kapoor2024llmcalibration,
    title={Large Language Models Must Be Taught to Know What They Don't Know},
    author={Sanyam Kapoor and Nate Gruver and Manley Roberts and Katherine Collins and Arka Pal and Umang Bhatt and Adrian Weller and Samuel Dooley and Micah Goldblum and Andrew Gordon Wilson},
    publisher={arXiv},
    year={2024}
}