DoReMi🎶: Domain Reweighting with Minimax Optimization

License: MIT

PyTorch implementation of DoReMi, an algorithm for optimizing the data mixture of language modeling datasets. Modern large language models are trained on many domains (web, books, arXiv, etc.), but how much of each domain to train on is unclear, especially since these models are used for a variety of downstream tasks (there is no particular target distribution to optimize for). DoReMi tunes the data mixture to be robust to the target distribution using Distributionally Robust Optimization (DRO). DoReMi trains a small proxy model using DRO, which dynamically upweights or downweights domains based on the proxy model's excess loss compared to a pretrained reference model. The reference model provides an estimate of the best achievable loss, which keeps DRO from being too pessimistic on high-entropy (hard) domains. The tuned data mixture can then be used to train a much larger model more efficiently. In the paper, a 280M proxy model improves the training of an 8B-parameter model (30x larger), allowing it to reach the baseline 8B performance 2.6x faster. The graphic below provides an overview of DoReMi. Check out the paper for more details.

As a black box, this codebase outputs optimized domain weights given a text dataset. Other useful components include a fast, resumable dataloader with domain-level weighted sampling, a simple downstream evaluation harness, and HuggingFace Trainer + FlashAttention2 integration.
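
As a rough illustration of domain-level weighted sampling, the sketch below picks a domain for each example according to the mixture weights and cycles through that domain's examples. It is a hypothetical stand-in, not this repo's dataloader: the function name and the in-memory lists are assumptions, and "resuming" here simply replays the seeded random draws up to start_step, which a genuinely fast dataloader would avoid.

```python
import random
from itertools import islice
from typing import Dict, Iterator, List, Tuple


def weighted_domain_stream(
    domains: Dict[str, List[str]],
    weights: Dict[str, float],
    seed: int = 0,
    start_step: int = 0,
) -> Iterator[Tuple[str, str]]:
    """Yield (domain, example) pairs, choosing each example's domain by weight."""
    names = sorted(domains)              # fixed domain ordering
    probs = [weights[n] for n in names]  # random.choices accepts unnormalized weights
    rng = random.Random(seed)            # same seed -> same sequence of draws
    cursors = {n: 0 for n in names}      # current position within each domain
    step = 0
    while True:
        name = rng.choices(names, weights=probs, k=1)[0]
        examples = domains[name]
        example = examples[cursors[name] % len(examples)]  # cycle through the domain
        cursors[name] += 1
        if step >= start_step:           # earlier draws are replayed but not yielded
            yield name, example
        step += 1


data = {"books": ["b0", "b1"], "web": ["w0", "w1", "w2"]}
mix = {"books": 0.25, "web": 0.75}
full = list(islice(weighted_domain_stream(data, mix, seed=0), 10))
resumed = list(islice(weighted_domain_stream(data, mix, seed=0, start_step=4), 6))
assert resumed == full[4:]  # resuming at step 4 reproduces the rest of the stream
```

Seeding the generator is what makes the stream reproducible across restarts; replaying draws is the simplest way to show that property, at the cost of O(start_step) work on resume.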

High-level overview of DoReMi.
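
For intuition, here is a minimal sketch of the domain-weight update applied while training the proxy model, assuming the exponentiated-gradient form described in the paper: each domain's weight is multiplied by exp(step_size × excess loss), renormalized, and mixed with a small amount of the uniform distribution, and the output mixture is the average of the weights over training. The function names, the per-token loss inputs, and the default step_size and smoothing values are illustrative assumptions, not this repo's API.

```python
import math
from typing import Dict, List, Sequence


def excess_losses(
    proxy_token_losses: Dict[str, Sequence[float]],
    ref_token_losses: Dict[str, Sequence[float]],
) -> Dict[str, float]:
    """Per-domain excess loss: mean over tokens of max(proxy - reference, 0)."""
    result = {}
    for domain, proxy in proxy_token_losses.items():
        ref = ref_token_losses[domain]
        # Clip at zero so tokens the proxy fits as well as the reference add nothing.
        clipped = [max(p - r, 0.0) for p, r in zip(proxy, ref)]
        result[domain] = sum(clipped) / len(clipped) if clipped else 0.0
    return result


def update_domain_weights(
    weights: Dict[str, float],
    excess: Dict[str, float],
    step_size: float = 1.0,   # illustrative value
    smoothing: float = 1e-3,  # illustrative value
) -> Dict[str, float]:
    """One exponentiated-gradient step on the domain weights, mixed with uniform."""
    domains = sorted(weights)
    # Work in log space: log w + step_size * excess loss.
    logits = [math.log(weights[d]) + step_size * excess.get(d, 0.0) for d in domains]
    top = max(logits)  # subtract the max before exp for numerical stability
    unnormalized = [math.exp(x - top) for x in logits]
    total = sum(unnormalized)
    k = len(domains)
    return {
        d: (1.0 - smoothing) * u / total + smoothing / k
        for d, u in zip(domains, unnormalized)
    }


# Toy example: "web" has the larger excess loss, so it gets upweighted.
weights = {"arxiv": 0.5, "web": 0.5}
history: List[Dict[str, float]] = []
for _ in range(3):
    excess = excess_losses(
        {"arxiv": [2.0, 2.2], "web": [3.0, 3.1]},  # proxy per-token losses (toy)
        {"arxiv": [2.0, 2.1], "web": [2.5, 2.6]},  # reference per-token losses (toy)
    )
    weights = update_domain_weights(weights, excess)
    history.append(weights)

# The output mixture averages the weights over all steps.
average = {d: sum(h[d] for h in history) / len(history) for d in weights}
print(average)
```

The uniform smoothing keeps every weight strictly positive, so no domain is ever dropped entirely and the log in the next step stays defined.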

Getting started

To get started, please clone the repo and install it:

git clone git@github.com:sangmichaelxie/doremi.git
pip install -e doremi
cd doremi && bash scripts/setup_flash.sh

The compilation in scripts/setup_flash.sh may take a significant amount of time (hours). All code should be run from the outermost doremi directory. Before you start, write the paths to your cache directories, data directories, etc. in a constants.sh file in the outermost directory of this repo. You can also place any conda or virtualenv activation commands there. Here is an example of the contents of a constants.sh file (provided as sample_constants.sh):

#!/bin/bash
CACHE=/path/to/cache
DOREMI_DIR=/path/to/this/repo
PILE_DIR=/path/to/pile
PREPROCESSED_PILE_DIR=/path/to/preprocessed  # will be created by scripts/run_preprocess_pile.sh
MODEL_OUTPUT_DIR=/path/to/model_output_dir
WANDB_API_KEY=key  # Weights and Biases key for logging
PARTITION=partition # for slurm
mkdir -p ${CACHE}
mkdir -p ${MODEL_OUTPUT_DIR}
source ${DOREMI_DIR}/venv/bin/activate  # if you installed doremi in venv
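
To catch a missing path before launching a long run, a small check like the one below may help. It is a hypothetical helper, not part of this repo: it only understands simple KEY=value lines (no quoting, export, or variable expansion), and the list of required names is taken from the sample file above.

```python
import re
from typing import Dict, List

# Variable names taken from the sample constants.sh above.
REQUIRED_VARS = [
    "CACHE", "DOREMI_DIR", "PILE_DIR", "PREPROCESSED_PILE_DIR",
    "MODEL_OUTPUT_DIR", "WANDB_API_KEY", "PARTITION",
]

# Matches simple assignments like KEY=value; the value stops at the first
# whitespace, which drops any trailing "# comment" text.
ASSIGNMENT = re.compile(r"^\s*([A-Za-z_][A-Za-z0-9_]*)=(\S*)")


def parse_constants(text: str) -> Dict[str, str]:
    """Collect KEY=value assignments, ignoring comments and other commands."""
    values: Dict[str, str] = {}
    for line in text.splitlines():
        match = ASSIGNMENT.match(line)
        if match:
            values[match.group(1)] = match.group(2)
    return values


def missing_vars(text: str, required: List[str] = REQUIRED_VARS) -> List[str]:
    """Names from `required` that are unset or assigned an empty value."""
    values = parse_constants(text)
    return [name for name in required if not values.get(name)]


sample = """#!/bin/bash
CACHE=/path/to/cache
WANDB_API_KEY=key  # Weights and Biases key for logging
mkdir -p ${CACHE}
"""
print(missing_vars(sample))
# ['DOREMI_DIR', 'PILE_DIR', 'PREPROCESSED_PILE_DIR', 'MODEL_OUTPUT_DIR', 'PARTITION']
```

To check your own file, pass its contents, e.g. `missing_vars(open("constants.sh").read())`.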

Here is how to run the sample script for data preprocessing on The Pile, which separates the Pile data into domains and tokenizes it:

bash scripts/run_preprocess_pile.sh
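
The domain-splitting half of this step can be sketched as follows, assuming the raw Pile's JSON Lines format, in which each document is stored as {"text": ..., "meta": {"pile_set_name": ...}}. This hypothetical helper only groups documents by domain; it does not reproduce what scripts/run_preprocess_pile.sh actually does, and it skips tokenization entirely.

```python
import json
from collections import defaultdict
from typing import Dict, Iterable, List


def split_by_domain(jsonl_lines: Iterable[str]) -> Dict[str, List[str]]:
    """Group Pile documents into domains using meta["pile_set_name"]."""
    buckets: Dict[str, List[str]] = defaultdict(list)
    for line in jsonl_lines:
        if not line.strip():
            continue  # skip blank lines
        record = json.loads(line)
        buckets[record["meta"]["pile_set_name"]].append(record["text"])
    return dict(buckets)


# Toy records in the Pile's JSON Lines format.
lines = [
    '{"text": "def f(): pass", "meta": {"pile_set_name": "Github"}}',
    '{"text": "Some web page.", "meta": {"pile_set_name": "Pile-CC"}}',
    '{"text": "Another web page.", "meta": {"pile_set_name": "Pile-CC"}}',
]
print({name: len(docs) for name, docs in split_by_domain(lines).items()})
# {'Github': 1, 'Pile-CC': 2}
```

Because it accepts any iterable of lines, the same function works on an open file handle for a decompressed shard.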

Here is a sample script to run 120M baseline, proxy, and main models (all 3 steps in the DoReMi pipeline), tested on one node with 8 A100 GPUs. This is a small version of the Pile experiments in the paper. The script will automatically run perplexity and few-shot evaluation:

bash scripts/run_pile.sh

These scripts run for 200k steps, following the paper. The DoReMi run outputs domain weights in the configs directory with filename <RUN_NAME>.json.

Running DoReMi on your own dataset

To run DoReMi on your own dataset, provide preprocessed (tokenized) data in the following format:

top_level/
    domain_name_1/
        files...
    domain_name_2/
        files...
    ...

where each inner directory (e.g., domain_name_1) can be loaded via HuggingFace's load_from_disk method. If your data is in a different format, you can add a custom data loading function in doremi/dataloader.py. You will also need to write a config file, save it to configs/, and write run scripts that refer to it, similar to scripts/runs/run_pile_baseline120M.sh and scripts/runs/run_pile_doremi120M.sh. The config file specifies the mapping from domain name to mixture weight. The names do not have to be in order (DoReMi always sorts the domain names first to determine a fixed ordering), and the weights do not have to be normalized.
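
A minimal sketch of how the directory layout and the config weights fit together, assuming the weights are available as a plain mapping from domain name to weight (the repo's actual config file may wrap this mapping differently): the domain subdirectories are discovered and sorted by name, and the weights are normalized to sum to 1. The helper names are hypothetical.

```python
from pathlib import Path
from typing import Dict, List, Tuple


def domain_dirs(top_level: str) -> List[str]:
    """Names of the per-domain subdirectories, sorted to fix the domain order."""
    return sorted(p.name for p in Path(top_level).iterdir() if p.is_dir())


def normalized_weights(
    config: Dict[str, float], domains: List[str]
) -> List[Tuple[str, float]]:
    """Pair each domain (in sorted order) with its config weight, normalized to sum to 1."""
    missing = sorted(set(domains) - set(config))
    if missing:
        raise ValueError(f"no weight given for domains: {missing}")
    ordered = sorted(domains)
    total = sum(config[d] for d in ordered)
    # Each directory in `ordered` would then be loaded with datasets.load_from_disk.
    return [(d, config[d] / total) for d in ordered]


# Weights need not be normalized or listed in order.
print(normalized_weights({"web": 3.0, "books": 1.0}, ["web", "books"]))
# [('books', 0.25), ('web', 0.75)]
```

Sorting both the directory names and the weight lookup by name gives every component the same domain ordering, regardless of how the config lists them.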

Tips and details

Note that there are a few differences between this repo and the paper, which was developed at Google, namely:

Sample run results

Below are results from one round of DoReMi on The Pile using 120M proxy and reference models (with scripts/run_pile.sh). We train a 120M model using the optimized weights (configs/pile_doremi_r1_120M_ref:pile_baseline_50kvocab_nopack_120M.json) and compare it to the baselines (gray). The two baselines represent two slightly different ways of calculating the baseline domain weights in the Pile (nopack counts the number of examples in each domain after padding each document to the context window length, whereas pack concatenates the documents within one domain first), which produce similar-performing models. The model trained with DoReMi domain weights surpasses the baseline one-shot performance very early during training, within 70k steps (3x faster) across all tasks. The DoReMi model surpasses the average baseline one-shot performance within 20k steps, has improved or comparable perplexity on 15/22 domains, and improves both uniformly averaged and worst-case perplexity across domains.
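
Under one reading of the description above, the two baselines differ only in how examples are counted per domain before the counts are normalized into weights. In the sketch below (hypothetical helper names; counting a final partial chunk as a full example is an assumption), nopack chunks and pads every document separately, while pack concatenates a domain's documents first, so domains made of many short documents receive much more weight under nopack.

```python
import math
from typing import Dict, List


def _normalize(counts: Dict[str, int]) -> Dict[str, float]:
    total = sum(counts.values())
    return {d: c / total for d, c in counts.items()}


def nopack_weights(doc_lengths: Dict[str, List[int]], context_len: int) -> Dict[str, float]:
    """Chunk and pad each document on its own: ceil(len / context_len) examples per doc."""
    counts = {
        d: sum(math.ceil(n / context_len) for n in lengths)
        for d, lengths in doc_lengths.items()
    }
    return _normalize(counts)


def pack_weights(doc_lengths: Dict[str, List[int]], context_len: int) -> Dict[str, float]:
    """Concatenate a domain's documents first, then chunk: ceil(total / context_len)."""
    counts = {
        d: math.ceil(sum(lengths) / context_len)
        for d, lengths in doc_lengths.items()
    }
    return _normalize(counts)


# Toy token counts per document.
lengths = {"short_docs": [100] * 8, "long_docs": [2048, 1024]}
nopack = nopack_weights(lengths, context_len=1024)  # short_docs: 8 of 11 examples
pack = pack_weights(lengths, context_len=1024)      # short_docs: 1 of 4 examples
print(nopack, pack)
```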

Few-shot evaluation results for the 120M models on The Pile (fewshot_120M_pile.png).

Citation

If this was useful to you, please cite the paper:

@article{xie2023doremi,
  author = {Sang Michael Xie and Hieu Pham and Xuanyi Dong and Nan Du and Hanxiao Liu and Yifeng Lu and Percy Liang and Quoc V. Le and Tengyu Ma and Adams Wei Yu},
  journal = {arXiv preprint arXiv:2305.10429},
  title = {DoReMi: Optimizing Data Mixtures Speeds Up Language Model Pretraining},
  year = {2023},
}