Ring Flash Attention

This repo implements RingAttention with FlashAttention. Currently, the following variants are implemented, each with a corresponding test under test/:

- ring_flash_attn, plus a varlen version
- zigzag_ring_flash_attn, plus a varlen version
- stripe_flash_attn
- llama3_flash_attn_varlen
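As a rough usage sketch: run the script below with torchrun --nproc_per_node <num_gpus>, with each rank holding one contiguous chunk of the sequence. The call shown here assumes ring_flash_attn_func mirrors flash_attn_func(q, k, v, ..., causal=...) from flash_attn; check the source for the exact signature (e.g. how the process group is passed).

```python
# Hedged usage sketch; run with: torchrun --nproc_per_node 8 example.py
# Assumption: ring_flash_attn_func mirrors flash_attn_func(q, k, v, ..., causal=...).
import torch
import torch.distributed as dist
from ring_flash_attn import ring_flash_attn_func

dist.init_process_group(backend="nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)
device = torch.device("cuda", rank)

batch, local_seqlen, nheads, headdim = 1, 4096, 8, 64

# Each rank holds one contiguous chunk of the sequence, so the effective
# sequence length is world_size * local_seqlen.
q, k, v = (
    torch.randn(batch, local_seqlen, nheads, headdim,
                device=device, dtype=torch.bfloat16, requires_grad=True)
    for _ in range(3)
)

out = ring_flash_attn_func(q, k, v, causal=True)
out.sum().backward()
dist.destroy_process_group()
```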

Note that

The current performance on 8xH800 and 8xA100 is as follows (see benchmark/benchmark_qkvpacked_func.py):

| | GPU | theoretic flash_attn | ring_attn | zigzag_ring | stripe_attn |
|---|---|---|---|---|---|
| fwd only (iter/sec) | 8xH800 | 2418.4 / 8 = 302.3 | 208.0 | 283.0 | 259.6 |
| | | | 68.8% | 93.6% | 85.9% |
| fwd + bwd (iter/sec) | 8xH800 | 705.2 / 8 = 88.2 | 54.3 | 75.7 | 76.9 |
| | | | 61.5% | 85.9% | 87.2% |
| fwd only (iter/sec) | 8xA100 | 1545.9 / 8 = 193.2 | 124.4 | 179.0 | 163.9 |
| | | | 64.3% | 92.7% | 84.8% |
| fwd + bwd (iter/sec) | 8xA100 | 470.6 / 8 = 58.8 | 33.3 | 49.5 | 45.9 |
| | | | 56.6% | 84.1% | 78.1% |

The percentage rows give each implementation's throughput relative to the theoretic flash_attn throughput in the row above.

Note that

Installation

pip install ring-flash-attn

or use the following commands to build from source:

git clone https://github.com/zhuzilin/ring-flash-attention.git
cd ring-flash-attention
pip install .
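A quick sanity check after installing (the module name ring_flash_attn is assumed to match the package name):

```python
# Minimal import check; prints where the package was installed and which
# *_func entry points it exposes.
import ring_flash_attn

print(ring_flash_attn.__file__)
print([name for name in dir(ring_flash_attn) if name.endswith("_func")])
```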

Limits

There are some arithmetic errors in the current implementation. The likely reason is that flash attention returns bf16 values for each block, so the block results are already rounded before we can accumulate them with the original fp32 values.

Also, because we need to keep extra fp32 buffers during the computation, the memory usage is higher than the theoretical limit.
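A toy illustration of the precision issue (this is not the repo's kernel code): when each block's partial result is rounded to bf16 before accumulation, the total drifts away from a pure fp32 accumulation of the same blocks.

```python
# Compare fp32 accumulation of per-block results with accumulation of the
# same results after they have been rounded to bf16 (as flash attention
# returns bf16 block outputs).
import torch

torch.manual_seed(0)
num_blocks, block_size = 64, 1024
blocks = [torch.randn(block_size) for _ in range(num_blocks)]

ref = torch.zeros(())  # fp32 accumulator over fp32 block results
acc = torch.zeros(())  # fp32 accumulator over bf16-rounded block results
for b in blocks:
    partial = b.sum()
    ref += partial
    acc += partial.to(torch.bfloat16).float()  # per-block rounding to bf16

print(f"fp32 accumulation          : {ref.item():.6f}")
print(f"bf16 per-block accumulation: {acc.item():.6f}")
print(f"difference                 : {abs(ref.item() - acc.item()):.6f}")
```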

TODOs

Test

torchrun --nproc_per_node 8 test/test_llama3_flash_attn_varlen_func.py
torchrun --nproc_per_node 8 test/test_ring_flash_attn_func.py
torchrun --nproc_per_node 8 test/test_ring_flash_attn_varlen_func.py
torchrun --nproc_per_node 8 test/test_zigzag_ring_flash_attn_func.py
torchrun --nproc_per_node 8 test/test_zigzag_ring_flash_attn_varlen_func.py
torchrun --nproc_per_node 8 test/test_stripe_flash_attn_func.py

Benchmark

torchrun --nproc_per_node 8 benchmark/benchmark_qkvpacked_func.py
torchrun --nproc_per_node 8 benchmark/benchmark_varlen_qkvpacked_func.py
