<p align="center" width="100%"> <img src="assets/images/logo.gif" alt="Stable Alignment" style="width: 100%; min-width: 400px; display: block; margin: auto;"> </p>

# Stable Alignment - Alignment Learning in Social Games

License: Apache 2.0

This is the official repo for the Stable Alignment project. We aim to provide an alternative to RLHF that is superior in alignment performance, highly data-efficient, and easy to deploy at scale. Instead of training an extra reward model that can be gamed during optimization, we train directly on the interaction data recorded in simulated social games. We find that high-quality data plus a reliable algorithm is the secret recipe for stable alignment learning.

The repo contains:

- the code for running social simulations in Sandbox,
- the released alignment data recorded from those simulations,
- the training code for Stable Alignment, and
- the released So(cially)-Good language models.

Life is a game. Play by your rules!

<p> <img src="assets/images/overview.png" alt="Overview of Stable Alignment" style="width: 100%; min-width: 200px; display: block; margin: auto;"> </p>

## Sandbox Simulation

### Installation

```bash
# install the dependencies
pip install -r requirements.txt
# install the package itself in editable (development) mode, e.g. for re-distribution
pip install -e .
```
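After the editable install, the package should be importable from anywhere; a quick sanity check (the module name follows the repo layout used in the commands below):

```python
# Sanity check: an editable install makes the package importable
# and its __file__ should point back into this repo.
import stable_alignment
print(stable_alignment.__file__)
```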

### Simulation Setup

### Run Simulation

Navigate to the project root folder and run the simulation with your customized settings:

```bash
python stable_alignment/simulation.py \
    -model_type 'text-davinci-002' \
    -obs_model_type 'gpt-3.5-turbo' \
    -world_id 1 \
    -init_setting 'all_bad' \
    -n_round '2' \
    -size '4' \
    -dataset_name 'hh-rlhf'
```

We present an example simulation result in `assets/sample_world`. It was simulated with 100 text-davinci-003-based social agents and ChatGPT-based observer agents, run for 50 rounds of interaction.
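If you want to poke around the sample world before writing any parsing code, a simple file survey is enough; the sketch below only assumes the directory path mentioned above:

```python
# Survey the sample simulation output: list every file with its size,
# so you can see what a finished run produces.
from pathlib import Path

for path in sorted(Path("assets/sample_world").rglob("*")):
    if path.is_file():
        print(f"{path} ({path.stat().st_size} bytes)")
```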

## Alignment Data Release

<p> <img src="assets/images/back_scatter.png" alt="Back Scatter in SandBox" style="width: 100%; min-width: 200px; display: block; margin: auto;"> </p>

The alignment data used for training is already included at `assets/sandbox_v1.json` and `assets/sandbox_v2.json`. Note that these files are sampled from the full set of interaction data at a ratio of 5:1:1 for Alignment Imitation, Self-Critic, and Realignment data, respectively. The full set of interaction data is available upon request.
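Before training, you may want to confirm what the released files contain. A minimal inspection sketch, assuming each file is a JSON array of record objects (we only print whatever fields are actually there):

```python
import json

# Peek at the released alignment data: record count plus the schema
# of the first record.
with open("assets/sandbox_v1.json") as f:
    data = json.load(f)

print(f"{len(data)} records")
print("fields of first record:", sorted(data[0].keys()))
```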

<details> <summary> <strong> The Statistics of Alignment Data (Full Set) </strong> </summary>

| Data / Social Agent Type | text-davinci-002 | text-davinci-003 | ChatGPT | Total |
|---|---|---|---|---|
| Alignment Imitation | 9.8k | 10k | 10k | 29.8k |
| Self-Critic | 17k | 20k | 20k | 57k |
| Realignment | 3.3k | 3k | 0.7k | 7k |
| Total | 30.1k | 33k | 30.7k | 93.8k |

| Data / Social Agent Type | text-davinci-002 | text-davinci-003 | GPT4 | Total |
|---|---|---|---|---|
| Alignment Imitation | 18.2k | 10.4k | 20.2k | 48.8k |
| Self-Critic | 36.3k | 18.3k | 40k | 94.6k |
| Realignment | 18.2k | 3.4k | 4.0k | 25.6k |
| Total | 72.7k | 32.1k | 64.2k | 169k |

</details>

## Training with Stable Alignment

```bash
# Notes on the flags below:
# - model_name_or_path: path to your SFT model
# - data_path: path to the alignment data
# - per_device_train_batch_size: has to be 1 for alignment training
# - fsdp: change to "full_shard auto_wrap" if OOM
# - model_max_length: change to a shorter length if OOM
# - rating_scale: the scale of the ratings (7 for 1-7, 10 for 1-10, etc.)
# - margin: constant, see the paper
# - max_flow: mean (False) or max (True) for the penalty
# - ratio: controls the ratio of the penalty
torchrun --nproc_per_node=4 --master_port=36646 train_alignment.py \
    --model_name_or_path "/workspace/hhh_sft" \
    --data_path "./assets/sandbox_v1.json" \
    --bf16 True \
    --output_dir "/workspace/<your_output_lm_name>" \
    --num_train_epochs 7 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 8 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 200 \
    --save_total_limit 1 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --fsdp "shard_grad_op auto_wrap" \
    --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
    --tf32 True \
    --model_max_length 360 \
    --rating_scale 7 \
    --margin 10 \
    --max_flow False \
    --ratio 0.2 \
    --num_comp 3
```
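To make `rating_scale`, `margin`, `max_flow`, and `ratio` concrete, here is a minimal, unofficial sketch of how we read the stable alignment objective from the paper: SFT loss on the best-rated response plus a rating-gap-scaled margin penalty over the `num_comp` lower-rated responses. The function and tensor names are ours, not the repo's; `train_alignment.py` is the authoritative implementation.

```python
import torch

def stable_alignment_loss(
    nll_best: torch.Tensor,       # scalar: NLL of the highest-rated response
    nll_lower: torch.Tensor,      # shape (num_comp,): NLLs of lower-rated responses
    rating_best: float,           # rating of the best response, e.g. 7
    ratings_lower: torch.Tensor,  # shape (num_comp,): ratings of the others
    rating_scale: float = 7.0,    # --rating_scale
    margin: float = 10.0,         # --margin
    max_flow: bool = False,       # --max_flow: max (True) vs. mean (False) penalty
    ratio: float = 0.2,           # --ratio: weight of the penalty term
) -> torch.Tensor:
    # The required NLL gap grows with the rating gap, normalized by the scale.
    target = margin * (rating_best - ratings_lower) / rating_scale
    # Hinge penalty: lower-rated responses should be at least `target`
    # worse (higher NLL) than the best response.
    violation = torch.relu(target - (nll_lower - nll_best))
    penalty = violation.max() if max_flow else violation.mean()
    # SFT loss on the best response plus the weighted penalty.
    return nll_best + ratio * penalty
```

Under this reading, a comparison rated 1 against a best response rated 7 must trail by a margin of 10 × 6/7 ≈ 8.6 nats before its penalty vanishes; `max_flow` decides whether only the worst violation or the average violation is penalized.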

## So(cially)-Good Language Model

### Model Release

We have released our models on Hugging Face! 🤗

Released models include:

1. `better-base`: the base model, trained from LLaMA on AlpacaDataCleaned (a cleaned version of the Alpaca instruction-tuning dataset) and codealpaca (a code instruction dataset).

2. `hh-rlhf-sft`: a supervised fine-tuned model built on better-base with the socially aligned demonstrations in the Anthropic HH-RLHF dataset (the accepted samples in the dataset).

3. `socially-good-lm`: the socially aligned language model trained on hh-rlhf-sft with the Stable Alignment method.

After you download the model, you can run inference with the following command:

```bash
python stable_alignment/run_inference.py \
    --model_path './models/socially-good-lm' \
    --device 'cuda:0'
```
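The released checkpoints are LLaMA-style causal LMs, so you can also load them directly with Hugging Face transformers. A minimal sketch; the local path mirrors the command above, and the prompt and generation settings are illustrative, not the repo's defaults:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_PATH = "./models/socially-good-lm"  # same local path as the command above
DEVICE = "cuda:0"

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH, torch_dtype=torch.float16
).to(DEVICE)

prompt = "How should I respond to a coworker who takes credit for my work?"
inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
outputs = model.generate(**inputs, max_new_tokens=128, do_sample=True, temperature=0.7)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```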

## Citation

Please cite our paper if you use the data or code in this repo:

```bibtex
@misc{liu2023sociallyaligned,
      title={Training Socially Aligned Language Models in Simulated Human Society},
      author={Ruibo Liu and Ruixin Yang and Chenyan Jia and Ge Zhang and Denny Zhou and Andrew M. Dai and Diyi Yang and Soroush Vosoughi},
      year={2023},
      eprint={2305.16960},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
```