Home

Awesome

GPU/TPU Jax implementation of RingAttention

This codebase provides the implementation of the Ring Attention with Blockwise Transformers. The model is described in the paper Ring Attention with Blockwise Transformers for Near-Infinite Context and Blockwise Parallel Transformer for Large Context Models.

Ring Attention with Blockwise Parallel Transformers enables training sequences up to a length of 'number of devices' times longer than those possible with BPT. This is achieved by distributing the attention and feedforward computation across multiple devices and overlapping the communication with computation. Thanks to the blockwise computing of the attention and feedforward network, it is possible to train with tens of millions of tokens in context size without adding any communication or computation overhead.

Example usage here with code snippets

First, install the package:

pip install ringattention

Then, ringattention and blockwise_feedforward can be imported and used as follows:

from ringattention import ringattention, blockwise_feedforward

You can use the ringattention function by wrapping it with shard_map to shard the computation across multiple devices. Here is an example of how to use the ringattention function with sharding:

ring_attention_sharded = shard_map(
    partial(
        ringattention,
        axis_name="sp",
        float32_logits=True,
        cache_idx=None,
        blockwise_kwargs=dict(
            causal_block_size=1,
            deterministic=True,
            dropout_rng=None,
            attn_pdrop=0.0,
            query_chunk_size=512,
            key_chunk_size=512,
            policy=jax.checkpoint_policies.nothing_saveable,
            dtype=jax.numpy.float32,
            precision=None,
            prevent_cse=True,
        )
    ),
    mesh=LLaMAConfig.get_jax_mesh(self.config.mesh_dim),
    in_specs=(
        PS(("dp", "fsdp"), "sp", "tp", None),
        PS(("dp", "fsdp"), "sp", "tp", None),
        PS(("dp", "fsdp"), "sp", "tp", None),
        PS(("dp", "fsdp"), None, None, None),
        PS(("dp", "fsdp"), None),
    ),
    out_specs=PS(("dp", "fsdp"), "sp", "tp", None),
    check_rep=False
)
attn_output = ring_attention_sharded(xq, xk, xv, attention_bias, segment_ids)

Explanation of the arguments:

RingAttention is utilized in Large World Model (LWM) for million-length vision-language training, where a complete example of using RingAttention and BlockwiseTransformers can be found: LWM codebase

Reference

If you find our work relevant to your research, please cite:

@article{liu2023blockwise,
    title={Blockwise Parallel Transformer for Large Context Models},
    author={Liu, Hao and Abbeel, Pieter},
    journal={Advances in neural information processing systems},
    year={2023}
}
@article{liu2023ring,
    title={Ring Attention with Blockwise Transformers for Near-Infinite Context},
    author={Liu, Hao and Zaharia, Matei and Abbeel, Pieter},
    journal={arXiv preprint arXiv:2310.01889},
    year={2023}
}