Sophia: A Scalable Stochastic Second-order Optimizer for Language Model Pre-training

This is an official implementation of the Sophia-G optimizer from the paper https://arxiv.org/abs/2305.14342, together with GPT-2 training scripts. The code is based on nanoGPT and levanter. Please cite the paper and star this repo if you find Sophia useful. Thanks!

@article{liu2023sophia,
 title={Sophia: A Scalable Stochastic Second-order Optimizer for Language Model Pre-training},
 author={Liu, Hong and Li, Zhiyuan and Hall, David and Liang, Percy and Ma, Tengyu},
 journal={arXiv preprint arXiv:2305.14342},
 year={2023}
}

News and Updates

Dependencies

General Usage

Below is an example code snippet for training a general model with negative log-likelihood (NLL) loss using SophiaG. Please refer to the next section for guidelines on hyperparameter tuning.

import torch
import torch.nn.functional as F
from sophia import SophiaG

# init the model, loss function, and input data (placeholders; plug in your own)
model = Model()      # any model whose forward returns (logits, loss), as in nanoGPT
data_loader = ...    # iterable yielding (X, Y) batches of token ids
block_size = ...     # sequence length of each training example
epochs = ...         # number of passes over the data

# init the optimizer
optimizer = SophiaG(model.parameters(), lr=2e-4, betas=(0.965, 0.99), rho=0.01, weight_decay=1e-1)

total_bs = ...               # total batch size in sequences per optimizer step
bs = total_bs * block_size   # total number of tokens per optimizer step; scales the Hessian term in SophiaG's update
k = 10                       # update the Hessian EMA every k optimizer steps
iter_num = -1

# training loop
for epoch in range(epochs):
    for X, Y in data_loader:
        # standard training code
        logits, loss = model(X, Y)
        loss.backward()
        optimizer.step(bs=bs)
        optimizer.zero_grad(set_to_none=True)
        iter_num += 1

        # every k steps, refresh the Hessian EMA using labels resampled from the model's own output distribution
        if iter_num % k == k - 1:
            logits, _ = model(X, None)
            samp_dist = torch.distributions.Categorical(logits=logits)
            y_sample = samp_dist.sample()
            loss_sampled = F.cross_entropy(logits.view(-1, logits.size(-1)), y_sample.view(-1), ignore_index=-1)
            loss_sampled.backward()
            optimizer.update_hessian()
            optimizer.zero_grad(set_to_none=True)
            model.zero_grad()
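
SophiaG is implemented as a standard torch.optim.Optimizer subclass in sophia.py, so checkpointing works the usual PyTorch way. A minimal sketch (the checkpoint path is a placeholder); saving the optimizer state preserves the momentum and Hessian EMA across restarts:

import torch

# save model and optimizer state (momentum and Hessian EMA included)
torch.save({"model": model.state_dict(), "optimizer": optimizer.state_dict()}, "ckpt.pt")

# later, resume from the checkpoint
ckpt = torch.load("ckpt.pt", map_location="cpu")
model.load_state_dict(ckpt["model"])
optimizer.load_state_dict(ckpt["optimizer"])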

Hyper-parameter Tuning

Definition of learning rate

Tuning the hyperparameter $\rho$

Tuning the learning rate and weight decay

<p align="center" width="100%"> <img src="assets/t5_winrate.png" style="width: 60%; min-width: 200px; display: block; margin: auto;"> </p>

Hyperparameters for GPT-2 models

| Model Size | lr for Adam | lr for Lion | lr for Sophia | $\rho$ for Sophia | weight decay for Sophia |
|------------|-------------|-------------|---------------|-------------------|-------------------------|
| 125M       | 6e-4        | 1e-4        | 6e-4          | 0.05              | 0.2                     |
| 355M       | 3e-4        | 1e-4        | 7e-4          | 0.08              | 0.2                     |
| 770M       | 2e-4        | 8e-5        | 3e-4          | 0.05              | 0.2                     |
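
As an illustration, the 355M row maps onto the SophiaG constructor as follows (a minimal sketch: model is a placeholder, and the betas come from the General Usage snippet above since the table does not list them):

from sophia import SophiaG

# GPT-2 Medium (355M) hyperparameters from the table above
optimizer = SophiaG(
    model.parameters(),    # placeholder: your GPT-2 355M model
    lr=7e-4,               # lr for Sophia
    betas=(0.965, 0.99),   # not in the table; taken from the General Usage example
    rho=0.08,              # rho for Sophia
    weight_decay=0.2,      # weight decay for Sophia
)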

Reproduce GPT-2 Results

Prepare the OpenWebText data following nanoGPT:

$ python data/openwebtext/prepare.py

Start pre-training GPT-2 Small (125M):

If you have a machine with 10 A5000 (24GB) GPUs,

$ torchrun --standalone --nproc_per_node=10 \
      train_sophiag.py \
      config/train_gpt2_small_sophiag.py \
      --batch_size=8 \
      --gradient_accumulation_steps=6

If you have a machine with 8 A100 (40GB) GPUs,

$ torchrun --standalone --nproc_per_node=8 \
      train_sophiag.py \
      config/train_gpt2_small_sophiag.py \
      --batch_size=12 \
      --gradient_accumulation_steps=5

To reproduce the AdamW baseline following nanoGPT:

$ torchrun --standalone --nproc_per_node=10 \
      train_adam.py \
      config/train_gpt2_small_adam.py \
      --batch_size=8 \
      --gradient_accumulation_steps=6

This should reproduce the results shown in the figure below:

<p align="center" width="100%"> <img src="assets/small_100k_plus.png" style="width: 60%; min-width: 200px; display: block; margin: auto;"> </p>

Start pre-training GPT-2 Medium (355M):

If you have a machine with 8 A100 (40GB) GPUs,

$ torchrun --standalone --nproc_per_node=8 \
      train_sophiag.py \
      config/train_gpt2_medium_sophiag.py \
      --batch_size=6 \
      --gradient_accumulation_steps=10

To reproduce the AdamW baseline:

$ torchrun --standalone --nproc_per_node=8 \
      train_adam.py \
      config/train_gpt2_medium_adam.py \
      --batch_size=6 \
      --gradient_accumulation_steps=10

Please adjust nproc_per_node, batch_size, and gradient_accumulation_steps accordingly if you use a different hardware setup; make sure their product equals 480 (see the quick check below).
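
For example, the GPU configurations above all keep that product at 480 sequences per optimizer step; a quick sanity check:

# sequences per optimizer step = nproc_per_node * batch_size * gradient_accumulation_steps
configs = {
    "GPT-2 Small, 10x A5000": 10 * 8 * 6,
    "GPT-2 Small, 8x A100":   8 * 12 * 5,
    "GPT-2 Medium, 8x A100":  8 * 6 * 10,
}
assert all(seqs == 480 for seqs in configs.values())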

This should reproduce the results shown in the figure below:

<p align="center" width="100%"> <img src="assets/medium_100k_plus.png" style="width: 60%; min-width: 200px; display: block; margin: auto;"> </p>

Start pre-training GPT-2 1.5B:

We use the Pile dataset and the GPT-NeoX tokenizer. First set up TPU instances and the environment following levanter, then change GAMMA_SOPHIA_G to 200 in optim.py. The training command for the 1.5B model is:

gcloud compute tpus tpu-vm ssh <instance_name> \
      --zone <zone_name> \
      --worker=all \
      --command 'WANDB_API_KEY=<wandb_api_key> levanter/infra/launch.sh python levanter/examples/gpt2_example.py --config_path levanter/config/gpt2_1536_pile.yaml --trainer.beta1 0.965 --trainer.beta2 0.99 --trainer.min_lr_ratio 0.020 --trainer.weight_decay 0.15 --trainer.learning_rate 2.5e-4 --trainer.warmup_ratio 0.01'

Acknowledgement

The GPT-2 training code is based on nanoGPT, which is elegant and super efficient.