
Ring Flash Attention

This repo implements RingAttention using FlashAttention. The current implementation supports:

Note that

Performance Summary

The following table summarizes the performance of the implemented APIs:

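| batch api | GPU | theoretic flash_attn | ring_attn | zigzag_ring | stripe_attn |
| --- | --- | --- | --- | --- | --- |
| fwd only (iter/sec) | 8xH800 | 591.5 / 8 = 73.9 | 38.5 (52.1%) | 63.0 (85.2%) | 55.0 (74.4%) |
| fwd + bwd (iter/sec) | 8xH800 | 154.7 / 8 = 19.3 | 10.4 (53.9%) | 17.4 (90.2%) | 16.0 (82.9%) |
| fwd only (iter/sec) | 8xA100 | 373.4 / 8 = 46.7 | 24.0 (51.4%) | 38.2 (81.7%) | 32.5 (69.6%) |
| fwd + bwd (iter/sec) | 8xA100 | 94.7 / 8 = 11.8 | 6.2 (52.5%) | 10.6 (89.8%) | 9.75 (82.6%) |

| varlen api | GPU | theoretic flash_attn | ring_attn | zigzag_ring | llama3_attn |
| --- | --- | --- | --- | --- | --- |
| fwd only (iter/sec) | 8xH800 | 852.4 / 8 = 106.6 | 52.4 (49.1%) | 74.8 (70.2%) | 60.8 (57.0%) |
| fwd + bwd (iter/sec) | 8xH800 | 225.4 / 8 = 28.2 | 14.4 (51.1%) | 21.4 (75.9%) | 16.4 (58.1%) |
| fwd only (iter/sec) | 8xA100 | 532.3 / 8 = 66.5 | 33.1 (49.8%) | 47.9 (72.0%) | 34.3 (51.6%) |
| fwd + bwd (iter/sec) | 8xA100 | 133.8 / 8 = 16.7 | 8.7 (52.1%) | 13.4 (80.2%) | 9.7 (58.0%) |

Percentages in parentheses are relative to the theoretic flash_attn column.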
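As a quick check on how the percentages are derived, the sketch below divides the flash_attn throughput by the 8 GPUs to get the theoretic per-GPU figure and then expresses a ring variant's throughput as a fraction of it. It reads `591.5 / 8` as flash_attn's throughput split evenly across 8 GPUs, which is an interpretation of the table rather than something stated here; `utilization` is a hypothetical helper.

```python
def utilization(flash_attn_iters, n_gpus, ring_iters):
    """Return (theoretic iter/sec, ring throughput as a fraction of it)."""
    # Assumption: the theoretic figure is flash_attn's throughput divided by the GPU count.
    theoretic = flash_attn_iters / n_gpus
    return theoretic, ring_iters / theoretic


# batch api, fwd only, 8xH800, zigzag_ring column
theoretic, ratio = utilization(591.5, 8, 63.0)
print(f"theoretic: {theoretic:.1f} iter/sec, zigzag_ring: {ratio:.1%}")
```

Recomputed this way, the other cells agree with the table to within rounding in the last digit.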

Note that

Installation

pip install ring-flash-attn

Or use the following commands to build from source:

git clone https://github.com/zhuzilin/ring-flash-attention.git
cd ring-flash-attention
pip install .

TODOs

Test

torchrun --nproc_per_node 8 test/test_llama3_flash_attn_varlen_func.py
torchrun --nproc_per_node 8 test/test_ring_flash_attn_func.py
torchrun --nproc_per_node 8 test/test_ring_flash_attn_varlen_func.py
torchrun --nproc_per_node 8 test/test_zigzag_ring_flash_attn_func.py
torchrun --nproc_per_node 8 test/test_zigzag_ring_flash_attn_varlen_func.py
torchrun --nproc_per_node 8 test/test_stripe_flash_attn_func.py

Benchmark

torchrun --nproc_per_node 8 benchmark/benchmark_kvpacked_func.py
torchrun --nproc_per_node 8 benchmark/benchmark_varlen_kvpacked_func.py

Known Limitations

There are some arithmetic errors in the current implementation. The likely reason is that flash attention returns a bf16 output for each block, so the results have to be accumulated from these bf16-rounded block outputs rather than from the original fp32 values.
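The effect can be illustrated with a small single-process sketch of how ring attention combines per-block results. Each simulated rank visits the KV chunks in ring order and merges every block's partial attention output into a running result using that block's log-sum-exp. This is a standard online-softmax merge, not necessarily the exact code in this repo, and no flash_attn call is made: `block_attention`, `merge`, `ring_attention` and `to_bf16` are hypothetical helpers, and Python's 64-bit floats stand in for the fp32 accumulator. Rounding each block's output to bfloat16 before the merge emulates flash attention returning bf16 per block.

```python
import math
import random
import struct


def to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (round half to even), returned as a Python float."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]  # x as float32 bits
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000  # keep the top 16 bits
    return struct.unpack(">f", struct.pack(">I", bits))[0]


def block_attention(q, keys, values, scale):
    """Softmax attention of one query over one KV block; returns (output, log-sum-exp)."""
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]
    return out, m + math.log(total)


def merge(out, lse, block_out, block_lse):
    """Combine two normalized partial outputs, weighting each by its share of exp(lse)."""
    new_lse = max(lse, block_lse) + math.log1p(math.exp(-abs(lse - block_lse)))
    a, b = math.exp(lse - new_lse), math.exp(block_lse - new_lse)
    return [a * o + b * bo for o, bo in zip(out, block_out)], new_lse


def ring_attention(q, key_chunks, value_chunks, scale, rank, block_dtype=float):
    """Simulate one rank (hypothetical): at step s it holds the KV chunk from rank (rank - s)."""
    world_size = len(key_chunks)
    out, lse = [0.0] * len(value_chunks[0][0]), -math.inf  # empty running result
    for step in range(world_size):
        src = (rank - step) % world_size
        block_out, block_lse = block_attention(q, key_chunks[src], value_chunks[src], scale)
        block_out = [block_dtype(x) for x in block_out]  # emulate the per-block output dtype
        out, lse = merge(out, lse, block_out, block_lse)
    return out


random.seed(0)
world_size, chunk_len, dim = 4, 8, 16
keys = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(world_size * chunk_len)]
values = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(world_size * chunk_len)]
q = [random.gauss(0, 1) for _ in range(dim)]
scale = 1 / math.sqrt(dim)
key_chunks = [keys[i * chunk_len:(i + 1) * chunk_len] for i in range(world_size)]
value_chunks = [values[i * chunk_len:(i + 1) * chunk_len] for i in range(world_size)]

reference, _ = block_attention(q, keys, values, scale)  # attention over the full sequence
exact = ring_attention(q, key_chunks, value_chunks, scale, rank=1)
rounded = ring_attention(q, key_chunks, value_chunks, scale, rank=1, block_dtype=to_bf16)
err_exact = max(abs(a - b) for a, b in zip(exact, reference))
err_bf16 = max(abs(a - b) for a, b in zip(rounded, reference))
print(f"max error, full-precision block outputs: {err_exact:.1e}")
print(f"max error, bf16 block outputs:           {err_bf16:.1e}")
```

With full-precision block outputs the ring result matches full attention to about machine precision, while the bf16-rounded variant shows a visibly larger error even though the running accumulator itself never leaves high precision.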

Also, because extra fp32 buffers have to be kept during computation, memory usage is higher than the theoretical limit.
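To put the extra memory in perspective, here is a back-of-the-envelope estimate under stated assumptions: the shapes (batch 1, 4096 tokens per rank, 32 heads, head dim 128) are made up for illustration, `extra_fp32_bytes` is a hypothetical helper, and which fp32 buffers the implementation actually keeps is not specified in this section. The sketch assumes one fp32 accumulator shaped like the attention output plus one fp32 log-sum-exp value per batch, head and query.

```python
def extra_fp32_bytes(batch, seq_per_rank, heads, head_dim):
    """Hypothetical estimate of per-rank fp32 buffer size vs. the bf16 output size, in bytes."""
    out_elems = batch * seq_per_rank * heads * head_dim
    fp32_out_accumulator = out_elems * 4          # assumed fp32 buffer shaped like the output
    fp32_lse = batch * heads * seq_per_rank * 4   # assumed fp32 log-sum-exp per (batch, head, query)
    bf16_out = out_elems * 2                      # the bf16 output itself, for comparison
    return fp32_out_accumulator + fp32_lse, bf16_out


extra, out_bytes = extra_fp32_bytes(batch=1, seq_per_rank=4096, heads=32, head_dim=128)
print(f"{extra / 2**20:.1f} MiB of fp32 buffers vs {out_bytes / 2**20:.1f} MiB bf16 output")
# prints "64.5 MiB of fp32 buffers vs 32.0 MiB bf16 output"
```

Under these assumptions the fp32 buffers alone take roughly twice as much memory as the bf16 output they produce.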

Also,