Stable Alignment - Alignment Learning in Social Games

License: Apache 2.0

This is the official repo for the Stable Alignment project. We aim to provide an RLHF alternative that is superior in alignment performance, highly data-efficient, and easy to deploy in scaled-up settings. Instead of training an extra reward model that can be gamed during optimization, we train directly on the interaction data recorded in simulated social games. We find that high-quality data plus a reliable algorithm is the secret recipe for stable alignment learning.

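The repo contains:

  1. the code for running social simulations in the Sandbox;
  2. the alignment data used for training;
  3. the code for training with Stable Alignment;
  4. the released So(cially)-Good Language Models.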

Sandbox Simulation


```bash
# install development environment
pip install -r requirements.txt
# install dependencies for package re-distribution
pip install -e .
```

Simulation Setup

Run Simulation

Navigate to the project root folder and run the simulation with customized settings:

```bash
python stable_alignment/simulation.py \
    -model_type 'text-davinci-002' \
    -obs_model_type 'gpt-3.5-turbo' \
    -world_id 1 \
    -init_setting 'all_bad' \
    -n_round '2' \
    -size '4' \
    -dataset_name 'hh-rlhf'
```

We present an example simulation result in assets/sample_world. It was simulated with 100 text-davinci-003-based social agents and ChatGPT-based observer agents, over 50 rounds of interaction.

Alignment Data Release

The alignment data used for training is already included in assets/sandbox_v1.json and assets/sandbox_v2.json. Note that these files are sampled from the full set of interaction data at a ratio of 5:1:1 for Alignment Imitation, Self-Critic, and Realignment data, respectively. The full set of interaction data is available upon request.
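As an illustration of what a 5:1:1 ratio means in sample counts, the sketch below splits a target subset size across the three data types. It assumes the ratio describes the composition of the released subset. The category keys are hypothetical labels, not field names from the JSON files. The subset sizes passed in are arbitrary examples, not the real file sizes. Largest-remainder rounding keeps the counts summing exactly to the requested total.

```python
from typing import Dict


def split_by_ratio(total: int, ratio: Dict[str, int]) -> Dict[str, int]:
    """Split `total` samples across categories proportionally to `ratio`,
    using largest-remainder rounding so the counts sum exactly to `total`."""
    weight_sum = sum(ratio.values())
    exact = {k: total * w / weight_sum for k, w in ratio.items()}
    counts = {k: int(v) for k, v in exact.items()}  # floor of each share
    leftover = total - sum(counts.values())
    # Hand the remaining samples to the categories with the largest fractional parts.
    by_remainder = sorted(exact, key=lambda k: exact[k] - counts[k], reverse=True)
    for k in by_remainder[:leftover]:
        counts[k] += 1
    return counts


# Hypothetical category labels for Alignment Imitation, Self-Critic, and Realignment.
RATIO = {"alignment_imitation": 5, "self_critic": 1, "realignment": 1}

print(split_by_ratio(700, RATIO))
# {'alignment_imitation': 500, 'self_critic': 100, 'realignment': 100}
```

When the total is not divisible by 7, the leftover samples go to the categories whose exact share was rounded down the most. This keeps the realized ratio as close to 5:1:1 as integer counts allow.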

<details> <summary> <strong> The Statistics of Alignment Data (Full Set) </strong> </summary>

| Data / Social Agent Type | text-davinci-002 | text-davinci-003 | ChatGPT | Total |
|---|---|---|---|---|
| Alignment Imitation | 9.8k | 10k | 10k | 29.8k |

| Data / Social Agent Type | text-davinci-002 | text-davinci-003 | GPT-4 | Total |
|---|---|---|---|---|
| Alignment Imitation | 18.2k | 10.4k | 20.2k | 48.8k |

</details>

Training with Stable Alignment

```bash
# Argument notes:
#   --model_name_or_path           path to your SFT model
#   --data_path                    path to the alignment data
#   --per_device_train_batch_size  has to be 1 for alignment training
#   --fsdp                         change to "full_shard auto_wrap" if OOM
#   --model_max_length             change to a shorter length if OOM
#   --rating_scale                 the scale of the ratings: 7 for 1-7, 10 for 1-10, etc.
#   --margin                       constant, see the paper
#   --max_flow                     use max (True) or mean (False) for the penalty
#   --ratio                        controls the ratio (weight) of the penalty
# With 4 GPUs, a per-device batch size of 1, and 8 gradient accumulation steps,
# the effective batch size is 4 x 1 x 8 = 32.
torchrun --nproc_per_node=4 --master_port=36646 train_alignment.py \
      --model_name_or_path "/workspace/hhh_sft" \
      --data_path "./assets/sandbox_v1.json" \
      --bf16 True \
      --output_dir "/workspace/<your_output_lm_name>" \
      --num_train_epochs 7 \
      --per_device_train_batch_size 1 \
      --per_device_eval_batch_size 1 \
      --gradient_accumulation_steps 8 \
      --evaluation_strategy "no" \
      --save_strategy "steps" \
      --save_steps 200 \
      --save_total_limit 1 \
      --learning_rate 2e-5 \
      --weight_decay 0. \
      --warmup_ratio 0.03 \
      --lr_scheduler_type "cosine" \
      --logging_steps 1 \
      --fsdp "shard_grad_op auto_wrap" \
      --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
      --tf32 True \
      --model_max_length 360 \
      --rating_scale 7 \
      --margin 10 \
      --max_flow False \
      --ratio 0.2 \
      --num_comp 3
```
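The flags --rating_scale, --margin, --ratio, and --max_flow point to a loss with two parts: a supervised term on the highest-rated response and a rating-aware penalty over lower-rated responses. The penalty is aggregated by mean or max and weighted by a ratio. The framework-free sketch below shows one way these knobs could interact. It is hypothetical, not the implementation in train_alignment.py; the exact formulation is in the paper. Several details are assumptions: the hinge form, the use of per-response negative log-likelihoods, and normalizing the rating gap by rating_scale. --num_comp is not modeled.

```python
from typing import Sequence


def stable_alignment_loss(
    nll: Sequence[float],
    ratings: Sequence[float],
    rating_scale: int = 7,
    margin: float = 10.0,
    ratio: float = 0.2,
    max_flow: bool = False,
) -> float:
    """Hypothetical sketch of a rating-aware contrastive loss.

    nll[i] is the negative log-likelihood the model assigns to response i, and
    ratings[i] is its rating on a 1..rating_scale scale. Index 0 must hold the
    highest-rated response; the rest are lower-rated comparison responses.
    """
    if len(nll) != len(ratings) or not nll:
        raise ValueError("nll and ratings must be non-empty and of equal length")
    best_nll, best_rating = nll[0], ratings[0]

    penalties = []
    for other_nll, other_rating in zip(nll[1:], ratings[1:]):
        # Assumption: the rating gap is scaled by margin and normalized by rating_scale.
        required_gap = margin * (best_rating - other_rating) / rating_scale
        # Hinge term: zero once the best response's NLL is lower than the
        # lower-rated response's NLL by at least required_gap.
        penalties.append(max(0.0, best_nll - other_nll + required_gap))

    if not penalties:
        return best_nll  # no comparisons: plain supervised loss
    # max_flow=True -> max over comparisons, False -> mean.
    penalty = max(penalties) if max_flow else sum(penalties) / len(penalties)
    # Supervised term on the best response plus the weighted penalty.
    return best_nll + ratio * penalty
```

Take nll=[1.0, 2.0, 3.0] with ratings=[7, 4, 1]. Both lower-rated responses are still too likely relative to their rating gap, so both hinge terms are positive. Setting max_flow=True lets only the worst violation drive the penalty.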

So(cially)-Good Language Model

We have released our models on Hugging Face! 🤗

Released models include:

  1. better-base: base model trained from LLaMA on AlpacaDataCleaned (a cleaned and fixed version of the Alpaca instruction-tuning dataset) and codealpaca (a code instruction-following dataset).

  2. hh-rlhf-sft: supervised fine-tuned from better-base on the socially aligned demonstrations in Anthropic's HH-RLHF dataset (the accepted samples in the dataset).

  3. socially-good-lm: socially aligned language model trained from hh-rlhf-sft with the Stable Alignment method.

After you download the model, you can run inference with the following command:

```bash
python stable_alignment/run_inference.py \
    --model_path './models/socially-good-lm' \
    --device 'cuda:0'
```


