FlashAttention in PyTorch

A simplified implementation of FlashAttention in PyTorch. I have implemented the forward and backward pass algorithms from the paper and shown that they are equivalent to the standard attention formulation used in Transformers. I also include code for benchmarking.

Note that this is for educational purposes only, as I haven't implemented any of the CUDA kernels or SRAM-level memory optimizations described in the paper.
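
For reference, the "normal" attention that the FlashAttention outputs are checked against is the standard scaled dot-product formulation. A minimal sketch of that reference (the scripts' actual code may differ in details such as how the mask is applied):

import math
import torch

def normal_attention(q, k, v, mask=None):
    # q, k, v: (batch, heads, seq_len, head_dim)
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    if mask is not None:
        # mask is a boolean tensor that is True where attention is allowed
        scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v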

Requirements

PyTorch

Files

flash_attention.py - FlashAttention forward pass with a random mask
flash_attention_causal.py - FlashAttention forward pass with a causal mask
bench.py - benchmarking with a random mask
bench_causal.py - benchmarking with a causal mask
check_backward.py - backward pass check with a random mask
check_backward_causal.py - backward pass check with a causal mask

To run

Forward pass

Causal mask
python flash_attention_causal.py

Random mask
python flash_attention.py
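
These scripts build the attention output block by block with the online-softmax recurrence from the paper and compare it against the reference above. A heavily simplified sketch of the core recurrence (no mask, no block-size tuning; not the exact code in flash_attention.py):

import math
import torch

def flash_attention_forward(q, k, v, block_size=256):
    # q, k, v: (batch, heads, seq_len, head_dim)
    scale = 1.0 / math.sqrt(q.shape[-1])
    out = torch.zeros_like(q)
    row_max = torch.full(q.shape[:-1], float("-inf"), dtype=q.dtype, device=q.device)  # running max (m)
    row_sum = torch.zeros(q.shape[:-1], dtype=q.dtype, device=q.device)                # running softmax denominator (l)

    for start in range(0, k.shape[-2], block_size):
        k_blk = k[..., start:start + block_size, :]
        v_blk = v[..., start:start + block_size, :]
        scores = q @ k_blk.transpose(-2, -1) * scale       # (batch, heads, q_len, block)

        new_max = torch.maximum(row_max, scores.amax(dim=-1))
        correction = torch.exp(row_max - new_max)          # rescale the running statistics
        p = torch.exp(scores - new_max.unsqueeze(-1))      # unnormalized probabilities for this block

        row_sum = row_sum * correction + p.sum(dim=-1)
        out = out * correction.unsqueeze(-1) + p @ v_blk
        row_max = new_max

    return out / row_sum.unsqueeze(-1)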

Benchmarking - Causal mask

FlashAttention
python bench_causal.py --b 1 --h 2 --q_len 16384 --kv_len 16384 --d 512 --type flash

Normal attention
python bench_causal.py --b 1 --h 2 --q_len 16384 --kv_len 16384 --d 512 --type normal

Add --profile to log additional details using PyTorch Profiler.
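
A benchmark along these lines typically warms up, times a number of forward passes, and optionally wraps them in the PyTorch Profiler. The sketch below is illustrative rather than the actual bench_causal.py code; flash_attention_forward and normal_attention refer to the sketches above:

import time
import torch
from torch.profiler import profile, ProfilerActivity

def benchmark(fn, q, k, v, iters=10, use_profiler=False):
    fn(q, k, v)  # warm-up
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    if use_profiler:
        activities = [ProfilerActivity.CPU] + ([ProfilerActivity.CUDA] if torch.cuda.is_available() else [])
        with profile(activities=activities) as prof:
            for _ in range(iters):
                fn(q, k, v)
        print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))
    else:
        for _ in range(iters):
            fn(q, k, v)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    print(f"{(time.perf_counter() - start) / iters * 1000:.2f} ms per forward pass")

# Shapes matching the flags above: --b 1 --h 2 --q_len 16384 --kv_len 16384 --d 512
device = "cuda" if torch.cuda.is_available() else "cpu"
q = torch.randn(1, 2, 16384, 512, device=device)
k = torch.randn(1, 2, 16384, 512, device=device)
v = torch.randn(1, 2, 16384, 512, device=device)
benchmark(flash_attention_forward, q, k, v)  # or normal_attention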

Benchmarking - Random mask

FlashAttention
python bench.py --b 1 --h 2 --q_len 16384 --kv_len 16384 --d 512 --type flash

Normal attention
python bench.py --b 1 --h 2 --q_len 16384 --kv_len 16384 --d 512 --type normal

Add --profile to log additional details using PyTorch Profiler.

Backward pass

Causal mask
python check_backward_causal.py

Random mask
python check_backward.py
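
These scripts verify the backward pass by comparing gradients between the two implementations. One simple way to do that kind of check is to push the same upstream gradient through both and compare the input gradients; the sketch below uses the illustrative functions from earlier and is not necessarily how the check scripts are written (with a hand-derived backward, the FlashAttention path would typically be wrapped in a torch.autograd.Function):

import torch

torch.manual_seed(0)
shape = (1, 2, 64, 32)  # small (batch, heads, seq_len, head_dim) for a quick check
q = torch.randn(shape, requires_grad=True)
k = torch.randn(shape, requires_grad=True)
v = torch.randn(shape, requires_grad=True)
dout = torch.randn(shape)  # shared upstream gradient

out_flash = flash_attention_forward(q, k, v)
grads_flash = torch.autograd.grad(out_flash, (q, k, v), dout)

out_normal = normal_attention(q, k, v)
grads_normal = torch.autograd.grad(out_normal, (q, k, v), dout)

assert torch.allclose(out_flash, out_normal, atol=1e-5)
for g_flash, g_normal in zip(grads_flash, grads_normal):
    assert torch.allclose(g_flash, g_normal, atol=1e-5)
print("forward outputs and gradients match")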