<p align="center"> <img src="RL4LMs_logo.png" width=512px> </p> <h1 align="center"> :robot: RL4LMs :rocket: </h1> <h3 align="center"> A modular RL library to fine-tune language models to human preferences </h3> <br>

We provide easily customizable building blocks for training language models, including implementations of on-policy algorithms, reward functions, metrics, datasets, and LM-based actor-critic policies.

Paper Link: https://arxiv.org/abs/2210.01241

Website Link: https://rl4lms.apps.allenai.org/

Thoroughly tested and benchmarked with over 2000 experiments :fire: (GRUE benchmark :trophy:) on a comprehensive set of:

All of these building blocks are customizable, allowing users to train transformer-based LMs to optimize any arbitrary reward function on any dataset of their choice.

Recent updates (v0.2.0) on 23-Nov-22

Recent updates (v0.2.1)


Install

Local Installation

git clone https://github.com/allenai/RL4LMs.git
cd RL4LMs
pip install -e .

Docker

We also provide a Dockerfile for development inside Docker containers with all dependencies installed.

docker build . -t rl4lms

Additional dependencies

CoreNLP libraries are needed only for certain metric computations (e.g., SPICE) and can be downloaded with: cd rl4lms/envs/text_generation/caption_metrics/spice && bash get_stanford_models.sh


Quick Start - Train PPO/NLPO using pre-defined YAML configs

We provide a simple training API, invoked through the train script, that can train PPO, NLPO, or a supervised model from a YAML config file.

For example, to train T5-base on CNN/DM summarization with PPO, using ROUGE-1 as the reward function, you can run:

python scripts/training/train_text_generation.py --config_path scripts/training/task_configs/summarization/t5_ppo.yml

Config files for all tasks can be found under scripts/training/task_configs.

YAML file schema - Configuring building blocks

The config file specifies the hyperparameter settings for each building block, as described below:


Custom Building Blocks :wrench:

RL4LMs provides complete customizability with respect to adding new tasks/datasets, reward functions, evaluation metrics, on-policy algorithms, and actor-critic policies.

Adding dataset

Users can create their own datasets by subclassing TextGenPool and overriding the class method prepare(cls, split: str, **args) -> 'TextGenPool' so that it returns an instance of TextGenPool. An example is shown below:

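from rl4lms.data_pools.text_generation_pool import Sample, TextGenPool

class MyDataPool(TextGenPool):
    @classmethod
    def prepare(cls, split: str, **args) -> "TextGenPool":
        # load_my_split is a hypothetical loader: replace it with code that returns
        # the raw items for `split`, each a dict with "document" and "target" keys
        dataset = load_my_split(split)
        samples = []
        for ix, item in enumerate(dataset):
            sample = Sample(id=f"{split}_{ix}",
                            prompt_or_input_text=item["document"],
                            references=[item["target"]]
                            )
            samples.append(sample)
        pool_instance = cls(samples)
        return pool_instance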
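The loading step inside prepare() is left open above. As one possibility, here is a minimal sketch that reads one split from a JSON Lines file, assuming each line holds an object with "document" and "target" keys. SampleStub and load_samples_from_jsonl are hypothetical names; SampleStub only mirrors the three Sample fields used in the example so the sketch runs without RL4LMs installed.

```python
import json
from dataclasses import dataclass, field
from typing import List


# Hypothetical stand-in for rl4lms' Sample, mirroring only the fields used above.
@dataclass
class SampleStub:
    id: str
    prompt_or_input_text: str
    references: List[str] = field(default_factory=list)


def load_samples_from_jsonl(path: str, split: str) -> List[SampleStub]:
    """Read one JSON object per line (assumed keys: "document", "target")."""
    samples = []
    with open(path, encoding="utf-8") as f:
        for ix, line in enumerate(f):
            line = line.strip()
            if not line:
                continue  # skip blank lines; ids stay unique because ix still advances
            item = json.loads(line)
            samples.append(SampleStub(id=f"{split}_{ix}",
                                      prompt_or_input_text=item["document"],
                                      references=[item["target"]]))
    return samples
```

In a real pool, the same loop would build Sample objects instead of SampleStub and pass the list to cls(samples).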

Adding reward function

Custom reward functions can be implemented by subclassing RewardFunction (a callable), which takes the previous observation ($s$), the action ($a$), the current observation ($s'$), done (indicating whether the episode has finished), and meta info (containing any other information about the textual input). Here, Observation is a dataclass holding the generated text (at a particular step), the prompt text, the context text (at that step), and the reference texts, which can be used to compute token-level or sentence-level rewards.

from typing import Any, Dict

from rl4lms.envs.text_generation.observation import Observation
from rl4lms.envs.text_generation.reward import RewardFunction


class MyRewardFunction(RewardFunction):
    def __init__(self, *args) -> None:
        super().__init__()

    def __call__(self, prev_observation: Observation,
                 action: int,
                 current_observation: Observation,
                 done: bool,
                 meta_info: Dict[str, Any] = None) -> float:
        if done:
            # episode finished: compute a sentence-level score from current_observation
            reward = 0.0  # replace with your score
            return reward
        # intermediate steps get no reward (sparse, end-of-episode reward)
        return 0.0
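The placeholder score in the done branch could be any sentence-level measure. A minimal sketch of one such measure is shown below: the best unigram F1 between the generated text and any reference, using lowercase whitespace tokenization. unigram_f1 is a hypothetical helper, not part of RL4LMs, and is a simplified stand-in for metrics such as ROUGE-1.

```python
from collections import Counter
from typing import List


def unigram_f1(generated: str, references: List[str]) -> float:
    """Best unigram F1 of `generated` against any reference (0.0 if no overlap)."""
    gen_counts = Counter(generated.lower().split())
    best = 0.0
    for ref in references:
        ref_counts = Counter(ref.lower().split())
        # Counter intersection keeps the minimum count per token (clipped matches)
        overlap = sum((gen_counts & ref_counts).values())
        if overlap == 0:
            continue  # also avoids dividing by zero on empty texts
        precision = overlap / sum(gen_counts.values())
        recall = overlap / sum(ref_counts.values())
        best = max(best, 2 * precision * recall / (precision + recall))
    return best
```

Inside MyRewardFunction.__call__, one could return unigram_f1(...) when done is True, passing the generated text and reference texts carried by current_observation; check the Observation dataclass for the exact field names.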

:bulb: In addition to traditional NLG metrics, for quick prototyping we provide two synthetic reward functions, which train LMs to generate numbers in increasing order and to generate dates. These can be used to quickly test different algorithms and policies. Corresponding configs can be found here (numbers, dates)
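To illustrate the kind of signal a synthetic "increasing numbers" reward can provide, here is a toy sketch: it scores the longest strictly increasing run of integer tokens, normalized by a target length and capped at 1.0. This is a hypothetical formulation chosen for illustration, not the implementation shipped with RL4LMs; increasing_numbers_reward and its min_tokens default of 20 are assumptions.

```python
from typing import List, Optional


def increasing_numbers_reward(text: str, min_tokens: int = 20) -> float:
    """Toy reward: longest strictly increasing run of integer tokens, scaled to [0, 1]."""
    # keep only whitespace-separated tokens that parse as integers
    numbers: List[int] = []
    for tok in text.split():
        try:
            numbers.append(int(tok))
        except ValueError:
            continue
    longest, run = 0, 0
    prev: Optional[int] = None
    for n in numbers:
        # extend the run if this number exceeds the previous one, otherwise restart
        run = run + 1 if prev is not None and n > prev else 1
        longest = max(longest, run)
        prev = n
    # normalize by the (hypothetical) target length and cap at 1.0
    return min(longest / min_tokens, 1.0)
```

Because non-numeric tokens are dropped before checking order, this toy version is lenient; a stricter variant could also penalize them.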

Adding custom metrics

Users can create their own evaluation metrics, which are then used to periodically evaluate the model on the validation split of the dataset. To do so, subclass BaseMetric and implement compute(), which takes prompt texts, generated texts, reference texts, meta_infos, the current LM model, and the split name as inputs, and returns a dict mapping each metric name to a tuple of (sentence-level scores, corpus-level score). An example is as follows:


from typing import Any, Dict, List

from transformers import PreTrainedModel

from rl4lms.envs.text_generation.metric import BaseMetric

class MyMetric(BaseMetric):
    def __init__(self) -> None:
        super().__init__()

    def compute(self,
                prompt_texts: List[str],
                generated_texts: List[str],
                reference_texts: List[List[str]],
                meta_infos: List[Dict[str, Any]] = None,
                model: PreTrainedModel = None,
                split_name: str = None):
        # placeholder values: one score per generated text, plus a corpus-level score
        metric_dict = {
            "custom_metrics/my_metric": ([0.4, 0.7, 0.9], 0.7)
        }
        return metric_dict
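The values in MyMetric are placeholders. A minimal sketch of real logic that produces the same return shape is shown below, using case-insensitive exact match against any reference and taking the corpus-level score as the mean of the sentence-level scores. exact_match_metric is a hypothetical helper; in practice its body would live inside compute().

```python
from typing import Dict, List, Tuple


def exact_match_metric(generated_texts: List[str],
                       reference_texts: List[List[str]]
                       ) -> Dict[str, Tuple[List[float], float]]:
    """Return {name: (sentence_scores, corpus_score)}, the shape compute() returns."""
    sentence_scores = []
    for gen, refs in zip(generated_texts, reference_texts):
        normalized = gen.strip().lower()
        # 1.0 if the generation matches any of its references, else 0.0
        hit = any(normalized == ref.strip().lower() for ref in refs)
        sentence_scores.append(1.0 if hit else 0.0)
    # corpus-level score: mean of sentence-level scores (0.0 for an empty batch)
    corpus_score = sum(sentence_scores) / len(sentence_scores) if sentence_scores else 0.0
    return {"custom_metrics/exact_match": (sentence_scores, corpus_score)}
```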

Adding custom on-policy algorithms

In addition to the supported on-policy algorithms (PPO, NLPO, A2C, TRPO), users can implement their own by subclassing stable-baselines3's OnPolicyAlgorithm. Since we provide wrappers for on-policy algorithms that handle rollouts with LM policies, stepping the environment, computing rewards, etc., users only need to implement the train() method with their custom loss functions.

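from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm

class MyOnPolicyAlgorithm(OnPolicyAlgorithm):
    def __init__(self, *args, n_epochs: int = 5, batch_size: int = 64, **kwargs):
        super().__init__(*args, **kwargs)
        # OnPolicyAlgorithm does not define these, so store them for train();
        # the defaults are only examples
        self.n_epochs = n_epochs
        self.batch_size = batch_size

    def train(self) -> None:
        # train for n_epochs epochs
        for epoch in range(self.n_epochs):
            # do a complete pass over the rollout buffer in minibatches
            for rollout_data in self.rollout_buffer.get(self.batch_size):
                # compute your custom loss from rollout_data, then step the optimizer:
                # self.policy.optimizer.zero_grad(); loss.backward(); self.policy.optimizer.step()
                pass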
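As an example of the kind of custom loss that could be computed per minibatch in train(), here is the PPO-style clipped surrogate objective written in plain Python over lists of per-action log-probabilities and advantages. This is a sketch for clarity only: in an actual train() these quantities would be torch tensors (stable-baselines3's rollout samples expose old_log_prob and advantages), and clipped_surrogate_loss is a hypothetical name.

```python
import math
from typing import Sequence


def clipped_surrogate_loss(new_log_probs: Sequence[float],
                           old_log_probs: Sequence[float],
                           advantages: Sequence[float],
                           clip_range: float = 0.2) -> float:
    """PPO-style clipped policy loss averaged over a minibatch (to be minimized)."""
    losses = []
    for new_lp, old_lp, adv in zip(new_log_probs, old_log_probs, advantages):
        ratio = math.exp(new_lp - old_lp)  # pi_new(a|s) / pi_old(a|s)
        clipped = min(max(ratio, 1.0 - clip_range), 1.0 + clip_range)
        # take the pessimistic (smaller) objective and negate it to get a loss
        losses.append(-min(ratio * adv, clipped * adv))
    return sum(losses) / len(losses)
```

Taking the minimum of the clipped and unclipped terms means the policy gains nothing from moving the probability ratio beyond 1 ± clip_range in the direction the advantage favors, while moves in the unfavorable direction are still fully penalized.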

Adding custom policies

We provide LM-based actor-critic policy implementations that wrap causal LMs and seq2seq LMs. These can also be extended (e.g., to use a different critic architecture) by overriding the appropriate methods (e.g., evaluate_actions()).

Registry

Finally, just register your custom components by adding them to the corresponding registry, after which they can be used directly from configs, just like the pre-defined components :wave:
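The registry idea can be pictured as a mapping from the id used in a config to the class that implements it. The sketch below is a hypothetical, generic version of that pattern, not RL4LMs' registry API (see the library's registry module for the actual classes); the {"id": ..., "args": {...}} config shape is an assumption made for illustration.

```python
from typing import Any, Dict, Type


class ComponentRegistry:
    """Hypothetical minimal registry mapping config ids to component classes."""

    def __init__(self) -> None:
        self._registry: Dict[str, Type] = {}

    def add(self, component_id: str, cls: Type) -> None:
        # refuse silent overwrites so two components cannot share an id
        if component_id in self._registry:
            raise KeyError(f"'{component_id}' is already registered")
        self._registry[component_id] = cls

    def build(self, config: Dict[str, Any]) -> Any:
        # look up the class by id and instantiate it with the optional args
        cls = self._registry[config["id"]]
        return cls(**config.get("args", {}))
```

Keeping construction behind an id lookup is what lets a YAML file name a custom component without the training script importing it explicitly.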

Crowdsourcing templates

We have provided the crowdsourcing templates we used on Amazon Mechanical Turk, along with example inputs, in scripts/crowdworking_templates. You might find these a helpful starting point either for evaluating your own model's generations or for gathering training data for a learned reward function.


Logging and Experiment Results

Additionally, we support WandB logging and warm-starting of training by storing checkpoints and other training artifacts in a user-specified path. This is especially useful for running preemptible jobs on large, scheduled clusters.

Artifacts include:
(1) a jsonl file containing rollout info at specified intervals
(2) a jsonl file containing training info at specified intervals
(3) a jsonl file containing validation metrics at specified intervals
(4) a jsonl file containing test metrics before and after training
(5) a json file with validation predictions at specified intervals
(6) a json file with test predictions before and after training
(7) the trained LM
(8) the config json used to run the experiment
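The jsonl artifacts store one JSON record per line, so they are easy to inspect after a run. The sketch below loads such a file and collects one field across records, for example to plot a validation metric over time. read_jsonl and metric_trajectory are hypothetical helpers, and because the record keys are not specified here, the key you pass must be checked against your own files.

```python
import json
from typing import Any, Dict, List


def read_jsonl(path: str) -> List[Dict[str, Any]]:
    """Load one JSON record per non-empty line."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def metric_trajectory(records: List[Dict[str, Any]], key: str) -> List[Any]:
    """Collect `key` from every record that contains it, in file order."""
    return [record[key] for record in records if key in record]
```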

Complete usage is as follows:

WANDB_API_KEY=<YOUR-WANDB-API-KEY-HERE>  python scripts/training/train_text_generation.py \
--config_path <PATH-TO-CONFIG-FILE> \
--experiment_name <EXPERIMENT-NAME> \
--base_path_to_store_results <PATH-TO-STORE-RESULTS> \
--log_to_wandb

Citation

@article{Ramamurthy2022IsRL,
  title={Is Reinforcement Learning (Not) for Natural Language Processing?: Benchmarks, Baselines, and Building Blocks for Natural Language Policy Optimization},
  author={Rajkumar Ramamurthy and Prithviraj Ammanabrolu and Kiant{\'e} Brantley and Jack Hessel and Rafet Sifa and Christian Bauckhage and Hannaneh Hajishirzi and Yejin Choi},
  journal={arXiv preprint arXiv:2210.01241},
  url={https://arxiv.org/abs/2210.01241},
  year={2022}
}

Questions/Discussion/Ideas?

For discussions, questions, and exchanging ideas, join our Slack channel.