
BAdam

The implementation for BAdam: A Memory Efficient Full Parameter Optimization Method for Large Language Models. This paper presents an algorithm named BAdam, which finetunes Llama 2-7B and Llama 3-8B using a single RTX3090 with Adam's update rule and mixed precision training. The core idea of BAdam is to sequentially solve block coordinate optimization sub-problems. From the implementation perspective, the algorithm runs Adam's update on a small portion of the parameters at a time (usually a single transformer layer), thereby requiring much less memory than full parameter Adam finetuning. Using BAdam requires only a one-line modification of the original code.

| Method | Minimum Memory | Actual Memory Cost (Llama 3-8B) | Actual Memory Cost (Llama 2-7B) |
|---|---|---|---|
| Adam | $18M$ | 144 GB+ | 122.8 GB+ |
| BAdam | $2M + \frac{16M}{D}$ | 23.5 GB | 21.8 GB |
<!-- | LoRA | Data | Data | -->

Table 1: Comparison of methods. $M$ is the number of model parameters in billions, $D$ is the number of blocks used in BAdam, and memory is measured in GB. See Table 2 in the paper for a detailed analysis of memory consumption.
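
The "Minimum Memory" column can be checked with a few lines of arithmetic. The sketch below simply evaluates the two Table 1 formulas; the helper names are hypothetical, and using $M = 8$ for Llama 3-8B with one block per transformer layer ($D = 32$) is an approximation of the nominal model size.

```python
def adam_min_memory_gb(m_billion: float) -> float:
    """Minimum memory of full-parameter Adam from Table 1: 18M GB."""
    return 18 * m_billion


def badam_min_memory_gb(m_billion: float, num_blocks: int) -> float:
    """Minimum memory of BAdam from Table 1: 2M + 16M/D GB."""
    return 2 * m_billion + 16 * m_billion / num_blocks


# Llama 3-8B (M approximated as 8), one block per transformer layer (D = 32)
adam_gb = adam_min_memory_gb(8)         # 18 * 8 = 144 GB
badam_gb = badam_min_memory_gb(8, 32)   # 16 + 128 / 32 = 20 GB
print(f"Adam: {adam_gb:.1f} GB, BAdam: {badam_gb:.1f} GB")
```

The measured 23.5 GB for BAdam sits above this 20 GB floor because the formula leaves out activations and system overhead.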

| Method | Llama 3-8B | Llama 2-7B |
|---|---|---|
| Pretrained model | 5.46 | 3.93 |
| LoRA | 6.41 | 5.05 |
| BAdam | 6.67 | 5.21 |
<!-- | LoRA | Data | Data | -->

Table 2: MT-Bench scores. The models are instruction-finetuned on the Alpaca-GPT4 dataset using a single RTX3090. BAdam consistently outperforms LoRA on MT-Bench under various evaluation models.

One can also apply BAdam to larger models, such as 13B, 22B, 30B, and 70B. The memory consumption can be estimated as $2M + \frac{16M}{D}$ (GB), plus a minor additional amount for gradient-checkpointed activations and system use such as PyTorch's pre-allocation. When using model parallelism with $N$ GPUs, the per-GPU memory cost can be estimated as $\frac{2M + 16M/D}{N}$ (GB), plus the additional communication buffers.
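
As a worked example of these estimates, the sketch below evaluates $\frac{2M + 16M/D}{N}$ for a 70B model. The configuration is an assumption for illustration: 80 transformer layers with one block per layer ($D = 80$), on 1 or 8 GPUs. The estimate excludes activations, system overhead, and communication buffers.

```python
def badam_memory_per_gpu_gb(m_billion: float, num_blocks: int, num_gpus: int = 1) -> float:
    """Estimated per-GPU memory in GB: (2M + 16M/D) / N.

    Excludes activations, system overhead, and (for N > 1) communication buffers.
    """
    return (2 * m_billion + 16 * m_billion / num_blocks) / num_gpus


# Assumed setup: a 70B model with 80 transformer layers, one block per layer (D = 80)
print(badam_memory_per_gpu_gb(70, 80))              # 154.0 (single device)
print(badam_memory_per_gpu_gb(70, 80, num_gpus=8))  # 19.25 (per GPU across 8 GPUs)
```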

Change log

[24/06/16] We now support model parallelism using DeepSpeed ZeRO-3!

[24/04/16] Our algorithm has been added to LLaMA-Factory. We would like to express our gratitude for their efforts in integrating BAdam!

[24/04/12] Added LoRA module detection. Made BlockOptimizer compatible with learning rate schedulers.


Setup

To install BAdam from PyPI, run:

```bash
pip install badam
```

One may also build from source with the following steps:

```bash
git clone git@github.com:Ledzy/BAdam.git
cd BAdam
pip install -e .
```

For those interested in reproducing the results in the paper, please follow the steps below to set up the environment:

```bash
conda create -n badam python=3.10
conda activate badam
pip install -r requirements.txt
```

Usage of BAdam

Partition by Module (A Single GPU)

BAdam uses mixed-precision training; make sure that the model is loaded in float16 precision to save memory. To use BAdam, simply add one line of code that wraps the original optimizer.

```python
from badam import BlockOptimizer

# before training, add this line to wrap the original optimizer
optimizer = BlockOptimizer(
    base_optimizer=original_optimizer, # can be any torch.optim.Optimizer
    named_parameters_list=list(model.named_parameters()),
    # Switch to a new block every 100 updates (the K Adam steps in the paper).
    # K can be set adaptively as K = n / (B * D), where n is the number of training
    # data points, B is the batch size, and D is the number of blocks in BAdam; see the
    # "Hyperparameter Suggestion" section for a detailed explanation of this hyperparameter.
    switch_block_every=100,
    # Update order of blocks: "random" (random reshuffling update order), "ascending"
    # (from input layer to output layer), or "descending" (from output layer to input
    # layer). The default is "random".
    switch_mode="random",
    verbose=2 # information level; prints trainable parameters when set to 2
)
```
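
The adaptive rule $K = n/(BD)$ is simple arithmetic, sketched below with a hypothetical helper. The inputs are illustrative assumptions: about 52,000 training examples (roughly the size of Alpaca-GPT4), an effective batch size of 16 (per-device batch 2 × gradient accumulation 8, as in the paper command further down), and $D = 32$. Rounding down and keeping $K \ge 1$ are our own choices.

```python
def adaptive_switch_every(num_examples: int, batch_size: int, num_blocks: int) -> int:
    """K = n / (B * D), rounded down and kept at least 1."""
    return max(1, num_examples // (batch_size * num_blocks))


# Illustrative inputs: ~52,000 examples, effective batch 16, D = 32 blocks
print(adaptive_switch_every(52_000, 16, 32))  # 101
```

With this choice, one sweep over all $D$ blocks takes $D \times K$ steps of $B$ examples each, which consumes roughly one epoch of data; the result is close to the switch_block_every=100 used above.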

The above code automatically creates a block partition according to model.named_parameters(). Specifically, it treats each transformer layer module as a single block. For instance, for Llama 3-8B, the block partition ($D = 32$) will be:

```
block 1: model.layers.0.
block 2: model.layers.1.
...
block 32: model.layers.31.
```

By default, the embedding layer and the language modeling head are not included in the training blocks. One can add them as two additional blocks by setting include_embedding=True and include_lm_head=True.
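
To illustrate how a layer-wise partition can be derived from parameter names, here is a minimal, hypothetical sketch. It is not BAdam's implementation; it assumes Hugging Face Llama-style names (model.embed_tokens., model.layers.{i}., lm_head.), and its include_embedding/include_lm_head arguments only mirror the flag names above.

```python
import re
from collections import OrderedDict


def partition_by_layer(param_names, include_embedding=False, include_lm_head=False):
    """Hypothetical sketch: group parameter names into one block per transformer layer."""
    blocks = OrderedDict()
    for name in param_names:
        layer = re.match(r"model\.layers\.\d+\.", name)
        if layer:
            blocks.setdefault(layer.group(0), []).append(name)
        elif include_embedding and name.startswith("model.embed_tokens."):
            blocks.setdefault("model.embed_tokens.", []).append(name)
        elif include_lm_head and name.startswith("lm_head."):
            blocks.setdefault("lm_head.", []).append(name)
        # anything else (e.g. the final norm) is left out of the blocks in this sketch
    return blocks


# Toy parameter names for a 2-layer Llama-style model (illustrative only)
names = ["model.embed_tokens.weight"]
names += [f"model.layers.{i}.{m}.weight" for i in range(2) for m in ("self_attn.q_proj", "mlp.down_proj")]
names += ["model.norm.weight", "lm_head.weight"]

print(list(partition_by_layer(names)))
# ['model.layers.0.', 'model.layers.1.']
print(list(partition_by_layer(names, include_embedding=True, include_lm_head=True)))
# ['model.embed_tokens.', 'model.layers.0.', 'model.layers.1.', 'lm_head.']
```

The trailing dot in each prefix matters for prefix matching: model.layers.1. does not match parameters of model.layers.10.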

<details><summary>Click to see more partition strategies and example code</summary>

One can also specify their own block list for the block optimizer by adjusting the block_prefix_list argument. For instance, the following code snippets create block partitions by self_attn and mlp modules (i.e., $D = 32 \times 2 = 64$ for Llama 3-8B) and by matrix modules (i.e., $D = 32 \times 7 = 224$ for Llama 3-8B), respectively. These finer partitions further reduce the memory cost:

```python
# block partition by self_attn and mlp modules
block_prefix_list = []
for i in range(32):
    layer_prefix = [
        [f"model.layers.{i}.self_attn."],
        [f"model.layers.{i}.mlp."],
    ]
    block_prefix_list.extend(layer_prefix)

optimizer = BlockOptimizer(
    base_optimizer=original_optimizer,
    named_parameters_list=list(model.named_parameters()),
    switch_block_every=100,
    switch_mode="random",
    verbose=2,
    block_prefix_list=block_prefix_list # set the block list
)
```

```python
# block partition by matrix modules
block_prefix_list = []
for i in range(32):
    layer_prefix = [
        [f"model.layers.{i}.self_attn.q_proj."],
        [f"model.layers.{i}.self_attn.k_proj."],
        [f"model.layers.{i}.self_attn.v_proj."],
        [f"model.layers.{i}.self_attn.o_proj."],
        [f"model.layers.{i}.mlp.gate_proj."],
        [f"model.layers.{i}.mlp.up_proj."],
        [f"model.layers.{i}.mlp.down_proj."],
    ]
    block_prefix_list.extend(layer_prefix)

optimizer = BlockOptimizer(
    base_optimizer=original_optimizer,
    named_parameters_list=list(model.named_parameters()),
    switch_block_every=100,
    switch_mode="random",
    verbose=2,
    block_prefix_list=block_prefix_list # set the block list
)
```

In our tests, partitioning by self_attn and mlp modules achieves an MT-Bench score of 6.65 when finetuning Llama 3-8B. This matches the score (6.67) achieved by partitioning by transformer layer modules, while further reducing the memory cost.
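
The two loops above hard-code 32 layers. A hypothetical helper that builds the same block_prefix_list for any layer count is sketched below; the function name and the granularity labels are our own, and it assumes Llama-style module names (q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj).

```python
def build_block_prefix_list(num_layers: int, granularity: str = "layer") -> list:
    """Hypothetical helper: build a block_prefix_list for a Llama-style model.

    granularity: "layer" (one block per layer), "attn_mlp" (two blocks per layer),
    or "matrix" (one block per projection matrix, seven per layer).
    """
    attn = ["q_proj", "k_proj", "v_proj", "o_proj"]
    mlp = ["gate_proj", "up_proj", "down_proj"]
    blocks = []
    for i in range(num_layers):
        base = f"model.layers.{i}."
        if granularity == "layer":
            blocks.append([base])
        elif granularity == "attn_mlp":
            blocks.extend([[base + "self_attn."], [base + "mlp."]])
        elif granularity == "matrix":
            blocks.extend([[f"{base}self_attn.{name}."] for name in attn])
            blocks.extend([[f"{base}mlp.{name}."] for name in mlp])
        else:
            raise ValueError(f"unknown granularity: {granularity!r}")
    return blocks


for g in ("layer", "attn_mlp", "matrix"):
    print(g, len(build_block_prefix_list(32, g)))  # D = 32, 64, 224
```

For a Hugging Face model, num_layers could be taken from model.config.num_hidden_layers, assuming the model's config exposes that field.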

</details>

Important Notes:

Partition by Module (Model Parallel)

We support model parallelism via DeepSpeed ZeRO-3, which partitions the model parameters, gradients, and optimizer states across GPUs so that one can train large models (e.g., Llama 3-70B) that do not fit on a single GPU. Given $N$ GPUs, the per-GPU memory cost can be estimated as $\frac{2M + 16M/D}{N}$ (GB), plus the additional cost of communication buffers and the temporary parameter-gathering buffers that arise during the forward and backward passes. These buffer sizes can be configured manually and determine the efficiency of communication.
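
The buffer sizes mentioned above are set in the DeepSpeed configuration. Below is a minimal sketch of a ZeRO-3 config written as a Python dict. The keys are standard DeepSpeed options; every numeric value is an illustrative placeholder rather than a tuned recommendation, and the batch settings simply mirror the paper command.

```python
import json

# Illustrative ZeRO-3 settings; all numbers are placeholders, not tuned values.
ds_config = {
    "train_micro_batch_size_per_gpu": 2,
    "gradient_accumulation_steps": 8,
    "zero_optimization": {
        "stage": 3,                                     # partition params, grads, optimizer states
        "overlap_comm": True,                           # overlap communication with computation
        "reduce_bucket_size": 50_000_000,               # elements reduced per communication call
        "stage3_prefetch_bucket_size": 50_000_000,      # buffer for prefetching parameters
        "stage3_param_persistence_threshold": 100_000,  # params smaller than this stay unpartitioned
        "stage3_max_live_parameters": 1_000_000_000,    # cap on gathered params resident per GPU
        "stage3_max_reuse_distance": 1_000_000_000,     # keep params that will be reused soon
    },
}

config_text = json.dumps(ds_config, indent=2)
print(config_text)
```

Such a dict can be passed to deepspeed.initialize through its config argument. Per the DeepSpeed documentation, smaller prefetch and bucket sizes use less memory but can increase communication stalls.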

<details><summary>Click to see instructions for model parallelism</summary>

To use ZeRO-3, set ds_zero3_enabled=True when initializing the BlockOptimizer. Then, after calling deepspeed.initialize, set optimizer.ds_optimizer = ds_optimizer, where optimizer is the BlockOptimizer instance.

```python
import deepspeed
from badam import BlockOptimizer

optimizer = BlockOptimizer(
    ...,
    ds_zero3_enabled=True # set it to True
)

# deepspeed.initialize returns (engine, optimizer, dataloader, lr_scheduler)
model, ds_optimizer, _, _ = deepspeed.initialize(model=model, optimizer=optimizer, ...)

# keep a reference to the ds_optimizer so that BlockOptimizer can set up ZeRO-3's environment
optimizer.ds_optimizer = ds_optimizer
```

When using the Hugging Face Trainer to control the workflow, ds_optimizer is not directly accessible. Instead, add BAdamCallback, which automatically handles the reference to ds_optimizer:

```python
from badam.utils import BAdamCallback

# list.append() returns None, so build a new list rather than assigning its result
callbacks = original_callbacks + [BAdamCallback] # add the callback
trainer = YourTrainerClass(
    ...,
    callbacks=callbacks
)
```

Model parallelism incurs noticeable overhead due to communication costs. In particular, we empirically observe about 3× overhead when training Llama 3-8B on 4 RTX3090 GPUs (without NVLink) using ZeRO-3, compared with a single GPU at the same per_device_batch_size. Fortunately, since ZeRO-3 greatly reduces the per-GPU memory cost, one can use a larger per_device_batch_size to speed up training.

Make sure to run accelerate config to configure distributed training, and then launch your script with a distributed launcher such as accelerate launch or deepspeed.

</details>

Partition by Parameter Ratio

Instead of partitioning blocks by the model's modules, an alternative is to train a fixed ratio of every parameter simultaneously. For instance, we can train 5% of the parameters of every transformer layer; that is, each active block contains 5% of the parameters from every transformer layer. In this way, the feature extractors of all layers are trained jointly, which may be preferred in certain scenarios. However, since such a block contains parameters from all the transformer layers, backpropagation must run through every layer, so this scheme may partly lose BAdam's savings in backpropagation (BP) time.
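
The ratio-based idea can be sketched in pure Python on a single flattened parameter. This is not BlockOptimizerRatio's implementation: the random choice of entries, the helper names, and keeping Adam moments in dicts are assumptions made for illustration. It shows that only the selected fraction of entries is updated and that Adam state is kept only for those entries.

```python
import math
import random


def make_ratio_mask(num_elements, update_ratio, rng):
    """Hypothetical: pick a random subset of about update_ratio of the entries to train."""
    k = max(1, int(num_elements * update_ratio))
    return set(rng.sample(range(num_elements), k))


def masked_adam_step(params, grads, mask, state, lr=1e-6, betas=(0.9, 0.999), eps=1e-8):
    """One Adam step applied only to indices in mask; all other entries stay frozen.

    Adam moments are stored only for masked indices, so optimizer state scales with the ratio.
    """
    beta1, beta2 = betas
    state["step"] += 1
    t = state["step"]
    for i in mask:
        m = beta1 * state["m"].get(i, 0.0) + (1 - beta1) * grads[i]
        v = beta2 * state["v"].get(i, 0.0) + (1 - beta2) * grads[i] ** 2
        state["m"][i], state["v"][i] = m, v
        m_hat = m / (1 - beta1 ** t)  # bias-corrected first moment
        v_hat = v / (1 - beta2 ** t)  # bias-corrected second moment
        params[i] -= lr * m_hat / (math.sqrt(v_hat) + eps)


rng = random.Random(0)
params, grads = [1.0] * 100, [0.5] * 100
mask = make_ratio_mask(len(params), update_ratio=0.1, rng=rng)  # 10 of 100 entries
state = {"step": 0, "m": {}, "v": {}}
masked_adam_step(params, grads, mask, state, lr=1e-3)
print(sum(p != 1.0 for p in params), len(state["m"]))  # 10 10
```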

<details><summary>Click to see example code and instructions</summary>

```python
from badam import BlockOptimizerRatio

optimizer = BlockOptimizerRatio(
    param_groups=param_groups, # param_groups of torch.optim.Optimizer, the same as for the original optimizer
    named_parameters_list=list(model.named_parameters()),
    switch_every=100, # switch to a new block every 100 updates
    update_ratio=0.1, # ratio of trainable weights within each parameter
    mask_mode="adjacent", # choices: ["adjacent", "scatter"], see the Notes below for more explanation
    lr=1e-6,
    betas=(0.9, 0.999), # betas for the Adam update
    eps=1e-8, # eps for the Adam update
)
```

Currently, the BlockOptimizerRatio only supports the Adam update. The repository is still under active development.

Notes:

</details>

Hyperparameter Suggestion

Run Paper Experiment

<details><summary>Llama 3-8B and Llama 2-7B on Alpaca-GPT4</summary>

Our implementation of finetuning Llama 3 and Llama 2 is based on LLaMA-Factory. This repository mainly serves to reproduce our paper's results. For better support of advanced algorithmic features, we suggest using the latest version of LLaMA-Factory.

For the experiment of finetuning Llama 2-7B on the Alpaca-GPT4 dataset, change the working directory to llama-alpaca:

```bash
cd llama-alpaca
```

Here is a sample command for running the code:

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage sft \
    --model_name_or_path meta-llama/Llama-2-7b-hf \
    --do_train \
    --dataset alpaca_gpt4_en \
    --template default \
    --finetuning_type block \
    --output_dir ./outputs/llama2-7b \
    --overwrite_cache \
    --per_device_train_batch_size 2 \
    --per_device_eval_batch_size 2 \
    --gradient_accumulation_steps 8 \
    --lr_scheduler_type cosine \
    --logging_steps 1 \
    --save_steps 1000 \
    --val_size 500 \
    --eval_steps 20 \
    --evaluation_strategy steps \
    --learning_rate 1e-6 \
    --num_train_epochs 3 \
    --overwrite_output_dir \
    --plot_loss \
    --switch_block_every 100 \
    --switch_mode random
```

To finetune Llama 3-8B, set --model_name_or_path meta-llama/Meta-Llama-3-8B. We use learning rate 1e-6 for Llama 3-8B and learning rate 1e-5 for Llama 2-7B. Note that the best learning rate may vary across models and datasets.

Notes on arguments:

</details> <details><summary>RoBERTa-large on SuperGLUE</summary>

Our implementation for finetuning RoBERTa-large on SuperGLUE is based on jiant. To run the code, first go to the roberta-superglue directory:

```bash
cd roberta-superglue
```

Before training the model, download the dataset using the following script, adjusting the --tasks argument to the task you need:

```bash
EXP_DIR=./content/exp

python jiant/scripts/download_data/runscript.py \
    download \
    --tasks copa \
    --output_path ${EXP_DIR}/tasks
```

The finetuning command has the following form:

```bash
CUDA_VISIBLE_DEVICES=0 python badam_ft.py \
    --task_name boolq \
    --num_train_epochs 32 \
    --eval_every_steps 100 \
    --use_block_optim \
    --switch_every 100 \
    --switch_mode ascending \
    --train_batch_size 16 \
    --train_last_layer \
    --hf_pretrained_model_name FacebookAI/roberta-large
```

Notes on arguments:

</details>