<img src="./ring.png" width="450px"/>

Ring Attention - Pytorch

Implementation of <a href="https://arxiv.org/abs/2310.01889">Ring Attention</a>, from <a href="https://www.haoliu.site/">Liu</a> et al. at Berkeley AI, in Pytorch.

It splits the data across the sequence dimension (instead of the batch dimension) and applies ring reduce to the processing of the tiles of the attention matrix, flash attention style.
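
To make the ring reduce concrete, here is a minimal single-process sketch of the blockwise accumulation (my own simplification: non-causal, with a plain loop over key / value chunks standing in for the communication hops around the ring):

import torch

def ring_attn_sketch(q, k, v, ring_size):
    # split the sequence into one chunk per hypothetical ring member
    scale = q.shape[-1] ** -0.5
    q_chunks = q.chunk(ring_size, dim = -2)
    k_chunks = k.chunk(ring_size, dim = -2)
    v_chunks = v.chunk(ring_size, dim = -2)

    outs = []
    for qc in q_chunks:
        num = torch.zeros_like(qc)
        den = torch.zeros_like(qc[..., :1])
        row_max = torch.full_like(den, float('-inf'))

        # each iteration plays the role of one hop around the ring
        for kc, vc in zip(k_chunks, v_chunks):
            sim = (qc * scale) @ kc.transpose(-1, -2)
            new_max = torch.maximum(row_max, sim.amax(dim = -1, keepdim = True))
            correction = (row_max - new_max).exp()   # rescale the running stats
            p = (sim - new_max).exp()
            num = num * correction + p @ vc
            den = den * correction + p.sum(dim = -1, keepdim = True)
            row_max = new_max

        outs.append(num / den)

    return torch.cat(outs, dim = -2)

# agrees with exact attention
q, k, v = torch.randn(3, 2, 1024, 64).unbind()
exact = torch.nn.functional.scaled_dot_product_attention(q, k, v)
assert torch.allclose(ring_attn_sketch(q, k, v, ring_size = 4), exact, atol = 1e-5)

Only the running numerator, denominator, and row max survive between hops, which is what keeps per-member memory constant in the number of ring members.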

I believe this is what is being used for the 1-10 million token context of the latest Gemini, or at least some form of it; the other possibility would be unpublished improvements on top of <a href="https://github.com/lucidrains/recurrent-memory-transformer-pytorch">RMT</a>.

In addition, the repository also contains the logic for <a href="https://arxiv.org/abs/2311.09431">Striped Attention</a>, a follow-up paper that permutes the sequence for better workload balancing in autoregressive transformers.
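
A hypothetical illustration of the permutation: with plain ring attention each member holds a contiguous chunk of the sequence, so under causal masking the members holding early positions sit mostly idle while the members holding late positions do nearly all the work. Striping deals positions out round-robin instead, so every member sees a similar mix of early and late positions:

import torch

seq_len, ring_size = 16, 4

contiguous = torch.arange(seq_len).reshape(ring_size, -1)
# member 0 gets positions [0, 1, 2, 3], member 3 gets [12, 13, 14, 15]

striped = torch.arange(seq_len).reshape(-1, ring_size).t()
# member 0 gets positions [0, 4, 8, 12], member 3 gets [3, 7, 11, 15]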

It also contains support for <a href="https://arxiv.org/abs/2305.13245">grouped query attention</a>, popularized by the Llama series of models. This further saves on communication costs during the ring reduce.
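
A quick sketch of why this helps (shapes here are hypothetical): with grouped query attention only the key / value heads travel around the ring, so 8 query heads sharing 2 key / value heads cut the per-hop communication volume by 4x:

import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 1024, 64)   # 8 query heads
k = torch.randn(1, 2, 1024, 64)   # only 2 kv heads need to be communicated
v = torch.randn(1, 2, 1024, 64)

# each group of 4 query heads attends to the same kv head
k, v = map(lambda t: t.repeat_interleave(4, dim = 1), (k, v))
out = F.scaled_dot_product_attention(q, k, v)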

Appreciation

Install

$ pip install ring-attention-pytorch

Usage

import torch
from ring_attention_pytorch import RingAttention

attn = RingAttention(
    dim = 512,
    dim_head = 64,
    heads = 8,
    causal = True,
    auto_shard_seq = True,  # automatically shard the incoming sequence across the ring
    ring_attn = True,       # turn on ring attention
    ring_seq_size = 512
)

tokens = torch.randn(1, 1024, 512)
attended = attn(tokens)

assert attended.shape == tokens.shape
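
The example above runs in a single process; the ring only pays off once the sequence is actually sharded across devices. A sketch of the launch setup (torchrun and torch.distributed are standard; that this is all the setup the module needs is my assumption):

import os
import torch
import torch.distributed as dist

# launch with e.g. `torchrun --nproc_per_node 8 script.py`, making each
# process one ring member; assumption: with auto_shard_seq = True the module
# shards the incoming sequence across the default process group on its own
dist.init_process_group(backend = 'nccl')
torch.cuda.set_device(int(os.environ['LOCAL_RANK']))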

This repository also contains an implementation of <a href="https://arxiv.org/abs/2408.04093">Tree Attention Decoding</a> from Shyam et al.

It can be imported and used as follows:

from ring_attention_pytorch import tree_attn_decode

out = tree_attn_decode(q, k, v) # where q, k, v exist across all machines
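
To show what is being computed, here is a single-process sketch of the math behind tree decoding (the function name and the explicit loop over shards are my own; in the paper the partial results are combined with an allreduce):

import torch

def tree_attn_decode_sketch(q, k_shards, v_shards):
    scale = q.shape[-1] ** -0.5
    nums, lses = [], []

    # each "machine" attends over its local kv shard and records the
    # log-sum-exp of its local attention scores
    for k, v in zip(k_shards, v_shards):
        sim = (q * scale) @ k.transpose(-1, -2)
        lse = sim.logsumexp(dim = -1, keepdim = True)
        nums.append((sim - lse).exp() @ v)
        lses.append(lse)

    # reweighting each locally normalized output by its shard's share of the
    # global softmax denominator recovers exact attention
    lses = torch.stack(lses)
    weights = (lses - lses.logsumexp(dim = 0)).exp()
    return (torch.stack(nums) * weights).sum(dim = 0)

# one decoding step: a single query token against 4 kv shards
q = torch.randn(2, 8, 1, 64)
ks = torch.randn(2, 8, 2048, 64).chunk(4, dim = -2)
vs = torch.randn(2, 8, 2048, 64).chunk(4, dim = -2)

exact = torch.nn.functional.scaled_dot_product_attention(q, torch.cat(ks, dim = -2), torch.cat(vs, dim = -2))
assert torch.allclose(tree_attn_decode_sketch(q, ks, vs), exact, atol = 1e-5)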

Test

First, install the requirements:

$ pip install -r requirements.txt

Then, to test autoregressive striped ring attention on CUDA:

$ python assert.py --use-cuda --causal --striped-ring-attn

To test tree attention decoding:

$ python assert_tree_attn.py --use-cuda --seq-len 8192

Todo

Citations

@article{Liu2023RingAW,
    title    = {Ring Attention with Blockwise Transformers for Near-Infinite Context},
    author   = {Hao Liu and Matei Zaharia and Pieter Abbeel},
    journal  = {ArXiv},
    year     = {2023},
    volume   = {abs/2310.01889},
    url      = {https://api.semanticscholar.org/CorpusID:263608461}
}
@article{Brandon2023StripedAF,
    title   = {Striped Attention: Faster Ring Attention for Causal Transformers},
    author  = {William Brandon and Aniruddha Nrusimha and Kevin Qian and Zachary Ankner and Tian Jin and Zhiye Song and Jonathan Ragan-Kelley},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2311.09431},
    url     = {https://api.semanticscholar.org/CorpusID:265220849}
}
@article{Dao2022FlashAttentionFA,
    title   = {FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness},
    author  = {Tri Dao and Daniel Y. Fu and Stefano Ermon and Atri Rudra and Christopher Ré},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2205.14135}
}
@article{dao2023flashattention2,
    title   = {Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},
    author  = {Dao, Tri},
    year    = {2023}
}
@article{Tillet2019TritonAI,
    title   = {Triton: an intermediate language and compiler for tiled neural network computations},
    author  = {Philippe Tillet and H. Kung and D. Cox},
    journal = {Proceedings of the 3rd ACM SIGPLAN International Workshop on Machine Learning and Programming Languages},
    year    = {2019}
}
@article{Ainslie2023GQATG,
    title   = {GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints},
    author  = {Joshua Ainslie and James Lee-Thorp and Michiel de Jong and Yury Zemlyanskiy and Federico Lebrón and Sumit K. Sanghai},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2305.13245},
    url     = {https://api.semanticscholar.org/CorpusID:258833177}
}
@inproceedings{Shyam2024TreeAT,
    title   = {Tree Attention: Topology-aware Decoding for Long-Context Attention on GPU clusters},
    author  = {Vasudev Shyam and Jonathan Pilault and Emily Shepperd and Quentin Anthony and Beren Millidge},
    year    = {2024},
    url     = {https://api.semanticscholar.org/CorpusID:271769078}
}

<a href="http://www.incompleteideas.net/IncIdeas/BitterLesson.html">The Bitter Lesson</a> - Richard Sutton