<p align="center"> <img src="https://gist.githubusercontent.com/lkevinzc/98afee30a5141d7068a0b35a88901a31/raw/e23f40d33e8a2fa4220e8122c152b356084b8afb/logo.png" width=90% alt="OAT" /> </p>

Installation | Usage | Examples | Benchmarking | Citation


Introduction

Oat 🌾 is a simple yet efficient system for running online LLM alignment algorithms. Its key features include:

LLM alignment as contextual dueling bandits

LLM alignment is essentially an online learning and decision-making problem in which the agent (e.g., the LLM policy, with an optional built-in reward model) interacts with the environment (i.e., humans) to achieve one of two distinct objectives: minimizing cumulative regret in the Explore & Exploit setting, or minimizing anytime regret in the Best Arm Identification setting.

In our paper, we formalize LLM alignment as a contextual dueling bandit (CDB) problem (see illustration below) and propose a sample-efficient alignment approach based on Thompson sampling.

<p align="center"> <img src="https://gist.githubusercontent.com/lkevinzc/98afee30a5141d7068a0b35a88901a31/raw/e0da719024bdc16fb4a993a8405e15cb0cf2b53a/interface.png" width=80%/> </p>

The CDB framework necessitates an efficient online training system to validate the proposed method and compare it with other baselines. Oat 🌾 is developed as part of this research initiative.

Using the CDB framework, existing LLM alignment paradigms can be summarized as follows:

<p align="center"> <img src="https://gist.githubusercontent.com/lkevinzc/98afee30a5141d7068a0b35a88901a31/raw/acbb25a20dd6c1e7619539b0fa449076ade2f873/compare.png" width=95%/> </p>

For more details, please check out our paper!

Installation

In a Python environment with a supported version (>=3.8, <=3.10), you can install oat via PyPI:

pip install vllm==0.6.2 && pip install oat-llm

Alternatively, you can install oat in editable mode for local development:

git clone git@github.com:sail-sg/oat.git
cd oat
pip install vllm==0.6.2 && pip install -e .
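
Either way, a quick smoke test verifies that the package is importable and the experiment entry point is reachable (the --help flag here assumes a standard command-line parser):

python -c "import oat"                  # should exit silently if the install succeeded
python -m oat.experiment.main --help    # lists the available experiment flags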

Usage

Below is an example of aligning a 1B Pythia SFT model on the TL;DR dataset using online SimPO, with PairRM as the preference oracle:

[!WARNING] Aligning with PairRM provides a lightweight example setup. For reproducing results from the paper or developing custom online alignment algorithms, we recommend using stronger reward models (or GPT-as-a-judge) as a preference oracle. This approach better approximates the ideal case of a human population. See the examples.

python -m oat.experiment.main \
    --gpus 2 \
    --collocate \
    --dap-algo SimPO \
    --beta 2 \
    --preference-oracle pairrm \
    --pretrain trl-lib/pythia-1b-deduped-tldr-sft \
    --prompt-data lkevinzc/tldr-with-sft-reference \
    --output_key pythia-1b-reference \
    --sync-params-every 1 \
    --rollout-batch-size-per-device 64 \
    --pi-buffer-maxlen-per-device 64 \
    --train-batch-size-per-device 8 \
    --use-wb \
    --wb-run-name 1b_pairrm_simpo_online

This example completes in less than two hours on two A100-40G GPUs!
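
The --preference-oracle flag decides who judges each duel between candidate responses. As the warning above suggests, a stronger reward model can serve as the oracle; conceptually, any oracle just maps (prompt, response A, response B) to a binary preference. The sketch below is a hypothetical illustration of that idea using an off-the-shelf scalar reward model, not oat's built-in implementation; the model name and scoring recipe are example choices only.

# Hypothetical preference oracle built from an off-the-shelf scalar reward model:
# prefer whichever response the reward model scores higher. Illustrative only,
# not oat's internal implementation.
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

RM_NAME = "OpenAssistant/reward-model-deberta-v3-large-v2"  # example reward model
tokenizer = AutoTokenizer.from_pretrained(RM_NAME)
reward_model = AutoModelForSequenceClassification.from_pretrained(RM_NAME).eval()

def prefers_first(prompt: str, response_a: str, response_b: str) -> bool:
    """Return True if response_a beats response_b according to the reward model."""
    scores = []
    for response in (response_a, response_b):
        inputs = tokenizer(prompt, response, return_tensors="pt", truncation=True)
        with torch.no_grad():
            scores.append(reward_model(**inputs).logits[0].item())
    return scores[0] > scores[1]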

To run an offline SimPO baseline for comparison, we disable weight synchronization from the learner to the actors by raising the --sync-params-every interval beyond the total number of gradient steps, so the actors keep generating responses with the initial SFT policy throughout training:

python -m oat.experiment.main \
    --gpus 2 \
    --collocate \
    --dap-algo SimPO \
    --beta 2 \
    --preference-oracle pairrm \
    --pretrain trl-lib/pythia-1b-deduped-tldr-sft \
    --prompt-data lkevinzc/tldr-with-sft-reference \
    --output_key pythia-1b-reference \
-   --sync-params-every 1 \
+   --sync-params-every 9999 \ # any number > the total gradient steps (50000//128=390)
    --rollout-batch-size-per-device 64 \
    --pi-buffer-maxlen-per-device 64 \
    --train-batch-size-per-device 8 \
    --use-wb \
-   --wb-run-name 1b_pairrm_simpo_online
+   --wb-run-name 1b_pairrm_simpo_offline

Finally, we run SEA SimPO (with $\gamma=1$, see here for the meaning of $\gamma$) to verify its capability of sample-efficient alignment. This experiment utilizes 4 GPUs, with a reduced per-device training batch size to accommodate the training of an additional epistemic reward model. The per-device rollout batch size and buffer length are adjusted to ensure a global batch size of 128. Additionally, 10 response candidates are generated for exploration using BAI Thompson sampling.
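
As a quick check on the batch arithmetic: 4 GPUs × 32 rollouts per device = 128 prompts per learning step, matching the 2 GPUs × 64 rollouts = 128 global batch of the online and offline runs above, so all three runs consume preference data at the same rate.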

python -m oat.experiment.main \
-   --gpus 2 \
+   --gpus 4 \
    --dap-algo SimPO \
    --beta 2 \
    --preference-oracle pairrm \
    --pretrain trl-lib/pythia-1b-deduped-tldr-sft \
    --prompt-data lkevinzc/tldr-with-sft-reference \
    --output_key pythia-1b-reference \
    --sync-params-every 1 \
-   --rollout-batch-size-per-device 64 \
-   --pi-buffer-maxlen-per-device 64 \
-   --train-batch-size-per-device 8 \
+   --rollout-batch-size-per-device 32 \
+   --pi-buffer-maxlen-per-device 32 \
+   --train-batch-size-per-device 1 \
+   --learn-rm \
+   --exp-method EnnBAITS \
+   --num_samples 10 \
    --use-wb \
-   --wb-run-name 1b_pairrm_simpo_online
+   --wb-run-name 1b_pairrm_simpo_sea
<p align="center"> <img src="https://gist.githubusercontent.com/lkevinzc/98afee30a5141d7068a0b35a88901a31/raw/e23f40d33e8a2fa4220e8122c152b356084b8afb/example_result.png" width=55%/> </p>

Check out this tutorial for more examples.

Benchmarking

Our benchmark compares oat with the online DPO implementation from huggingface/trl. Below, we outline the configurations used for oat and present the benchmarking results. Notably, oat 🌾 is up to 2.5x more computationally efficient than trl 🤗.

<p align="center"> <img src="https://gist.githubusercontent.com/lkevinzc/98afee30a5141d7068a0b35a88901a31/raw/e23f40d33e8a2fa4220e8122c152b356084b8afb/system_configs.png" width=97%/> </p> <p align="center"> <img src="https://gist.githubusercontent.com/lkevinzc/98afee30a5141d7068a0b35a88901a31/raw/e23f40d33e8a2fa4220e8122c152b356084b8afb/bench_results.png" width=65% /> </p>

Please refer to Appendix C of our paper for a detailed discussion of the benchmarking methods and results.

Citation

If you find this work useful for your research, please consider citing:

@article{
  liu2024sea,
  title={Sample-Efficient Alignment for LLMs},
  author={Zichen Liu and Changyu Chen and Chao Du and Wee Sun Lee and Min Lin},
  journal={arXiv preprint arXiv:2411.01493},
  year={2024}
}

License

oat is distributed under the terms of the Apache-2.0 license.

Acknowledgement

We thank the many awesome open-source projects that have contributed to the development of oat.

Disclaimer

This is not an official Sea Limited or Garena Online Private Limited product.