OptimalShardedDataParallel

Description

Optimal Sharded Data Parallel (OSDP) is an automated parallel training system that combines the advantages of both data and model parallelism, with a number of advanced characteristics.

Feel free to contribute code, open issues, and submit pull requests.

Papers

Environment

The following command creates the conda environment:

$ conda env create -f environment.yml

Alternatively, prepare the environment with:

$ sh prepare_env.sh

Implementation

Example using OSDP:

import torch
from data_parallel.optimal_sharded_data_parallel import OptimalShardedDataParallel as OSDP
...
# Wrap the model: OSDP searches for the optimal sharding plan given the
# model description and the device information.
sharded_module = OSDP(my_module, model_description, device_information)
optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
for sample, label in dataload.next_batch:
    optim.zero_grad()
    out = sharded_module(x=sample, y=3, z=torch.Tensor([1]))
    loss = criterion(out, label)
    loss.backward()
    optim.step()
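
The model_description and device_information arguments describe the workload and the hardware that OSDP plans over. Their exact schema is defined by this repository; the snippet below is only a hypothetical sketch of the kind of information they might carry (every field name here is an assumption, loosely based on the GPT-2 setting used in the experiments below):

# Hypothetical sketch only: the real schema of model_description and
# device_information is defined by the OSDP code, not by this example.
model_description = {
    "num_layers": 48,              # assumed field: model depth
    "hidden_size": 2048,           # assumed field: hidden dimension
    "batch_size": 8,               # assumed field: per-device batch size
}
device_information = {
    "num_devices": 8,              # assumed field: number of GPUs
    "memory_limit": 8 * 1024**3,   # assumed field: per-GPU memory budget in bytes
}
sharded_module = OSDP(my_module, model_description, device_information)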

OSDP training for GPT models

Run the train_gpt2_...py file via the corresponding scripts/script_gpt2_...sh script, and deploy the OSDP experiment by setting fsdp_type to OSDP (set fsdp_type to FSDP to deploy the comparison experiment).

$ cd gpt
$ sh scripts/script_gpt2_osdp.sh
$ sh scripts/script_gpt2_fsdp.sh
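
For orientation, here is a hedged sketch of how such an fsdp_type switch could look inside a training script. It is not the repository's actual train_gpt2_...py: the flag name, the placeholder model, and the dictionary arguments are assumptions based on the usage example above.

# Hypothetical sketch of an fsdp_type switch; not the repository's train_gpt2_...py.
import argparse

import torch.nn as nn
from fairscale.nn import FullyShardedDataParallel as FSDP
from data_parallel.optimal_sharded_data_parallel import OptimalShardedDataParallel as OSDP

parser = argparse.ArgumentParser()
parser.add_argument("--fsdp_type", choices=["OSDP", "FSDP"], default="OSDP")
args = parser.parse_args()

model = nn.Linear(2048, 2048)  # placeholder; stands in for the real GPT-2 construction
model_description = {}         # placeholder; see the hypothetical sketch above
device_information = {}        # placeholder; see the hypothetical sketch above

if args.fsdp_type == "OSDP":
    # OSDP searches for an optimal sharding plan for this model and hardware.
    sharded_module = OSDP(model, model_description, device_information)
else:
    # Plain fully sharded data parallel baseline for the comparison experiment.
    sharded_module = FSDP(model)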

Experimental results

We report the system throughput and memory utilization of GPT-2 training (48 layers, hidden size 2048) in our environment (GPU memory limit: 8 GB).

OSDP training for OPT models

We add an OSDP implementation for OPT models. The following instructions deploy OPT-30B (8 layers) training on a single machine with 8 GPUs and a 16 GB memory limit.

$ cd opt
$ sh scripts/train_fsdp.sh
$ sh scripts/train_osdp.sh

Experimental results

Feature: Operator splitting

Description

Operator splitting enables OSDP to search for a finer-grained execution plan for the model and reduces memory surges during training, which allows OSDP to use a larger batch size and further improve system throughput.

Implementation

Example using operator splitting:

import torch.nn as nn

class Layer(nn.Module):
  def __init__(self, config):
    super().__init__()
    # Build the MLP from smaller, splitted linear operators.
    self.mlp = splitted_linear(config...)
    ...

  def forward(self, input):
    # Run the splitted operators through the dedicated forward helper.
    output = splitted_linear_forward(input, self.mlp, num_splits)
    ...
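
The helpers splitted_linear and splitted_linear_forward are provided by this repository. For intuition, here is a hedged sketch of what they could look like; it is an illustration of output-dimension splitting, not the repository's actual implementation.

# Hedged sketch of the splitting helpers; the repository's own splitted_linear
# and splitted_linear_forward are the authoritative implementations.
import torch
import torch.nn as nn

def splitted_linear(in_features, out_features, num_splits):
    # Represent one logical linear operator as num_splits smaller linears,
    # each producing a slice of the output dimension (assumed design).
    assert out_features % num_splits == 0
    return nn.ModuleList(
        nn.Linear(in_features, out_features // num_splits) for _ in range(num_splits)
    )

def splitted_linear_forward(x, splitted_mlp, num_splits):
    # Run the sub-operators one by one and concatenate their outputs. Because
    # each sub-operator is a separate module, a sharded data parallel wrapper
    # can materialize and release its parameters independently instead of
    # holding the whole operator at once (assumed rationale for the smaller
    # memory surge and the finer-grained execution plan).
    assert len(splitted_mlp) == num_splits
    outputs = [linear(x) for linear in splitted_mlp]
    return torch.cat(outputs, dim=-1)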

Feature: Group Sharding & Communication with groups

Group Sharding

Trades off memory consumption against system throughput to make more efficient use of inter-machine bandwidth.

Communication with groups

Increases system throughput by reducing the amount of parameters communicated between machines (inter-machine bandwidth is usually much lower than intra-machine bandwidth).

Implementation

Example of using Group Sharding to train bert-large on 2 machines with 2 GPUs each (we use the auto_wrap API provided by fairscale to complete the sharded data parallel deployment):

fsdp_args = gen_fsdp_args(nnodes=2, nproc_per_node=2, gsdp_type='group_sharding', model_type='bert-large')
with enable_wrap(wrapper_cls=FSDP, **fsdp_args):
    my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=1e6)
    model = auto_wrap(model, auto_wrap_policy=my_auto_wrap_policy)  # auto_wrap the model

Example of using Communication with groups to train bert-large on 2 machines with 2 GPUs each:

fsdp_args = gen_fsdp_args(nnodes=2, nproc_per_node=2, gsdp_type='communication_with_groups', model_type='bert-large')
with enable_wrap(wrapper_cls=FSDP, **fsdp_args):
    my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=1e6)
    model = auto_wrap(model, auto_wrap_policy=my_auto_wrap_policy)  # auto_wrap the model
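
Both modes come down to restricting collectives to process groups that follow machine boundaries. The snippet below is only an illustration of how such per-machine groups can be built with torch.distributed; it is not the repository's gen_fsdp_args, and the assumption is that the resulting group is what ends up being passed to fairscale's FSDP through its process_group argument.

# Illustration only, not the repository's gen_fsdp_args: building per-machine
# process groups with torch.distributed.
import torch.distributed as dist

def make_intra_node_group(nnodes, nproc_per_node):
    # Return the process group that contains only the ranks on this machine.
    rank = dist.get_rank()
    node_id = rank // nproc_per_node  # assumes ranks are laid out node by node
    my_group = None
    for n in range(nnodes):
        ranks = list(range(n * nproc_per_node, (n + 1) * nproc_per_node))
        group = dist.new_group(ranks=ranks)  # must be called by every rank for every group
        if n == node_id:
            my_group = group
    return my_group

Sharding parameters (or running gradient collectives) only within such a group keeps the heavy traffic on the fast intra-machine links, which is the idea behind both Group Sharding and Communication with groups.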

Running Group Sharding & Communication with groups

We provide an example of training bert-large with OSDP using Group Sharding & Communication with groups:

Run the train_bert_large...py file via the corresponding scripts/script_bert_large_...sh script, and deploy the Group Sharding or Communication with groups experiment by setting fsdp_args to 'group_sharding' or 'communication_with_groups' (set fsdp_args to 'none' to deploy the comparison experiment).

$ cd bert
$ sh scripts/script_bert_large_group_sharding.sh
$ sh scripts/script_bert_large_communication_with_groups.sh
$ sh scripts/script_bert_large_fsdp.sh

Cite

If you use OSDP in a scientific publication, we would appreciate citations to the following paper:

@misc{jiang2023osdp,
      title={OSDP: Optimal Sharded Data Parallel for Distributed Deep Learning}, 
      author={Youhe Jiang and Fangcheng Fu and Xupeng Miao and Xiaonan Nie and Bin Cui},
      year={2023},
      eprint={2209.13258},
      archivePrefix={arXiv},
      primaryClass={cs.DC}
}