SageAttention

This repository provides the official implementation of SageAttention and SageAttention2.

SageAttention: Accurate 8-Bit Attention for Plug-and-play Inference Acceleration
Paper: https://arxiv.org/abs/2410.02367
Jintao Zhang, Jia Wei, Haofeng Huang, Pengle Zhang, Jun Zhu, Jianfei Chen

SageAttention2 Technical Report: Accurate 4-Bit Attention for Plug-and-play Inference Acceleration
Paper: https://arxiv.org/abs/2411.10958
Jintao Zhang, Haofeng Huang, Pengle Zhang, Jia Wei, Jun Zhu, Jianfei Chen

SageAttention on CogVideoX-2B (RTX4090)

SageAttention2 on Llama3.1-8B

Beta Version of SageAttention2

This is a beta release of SageAttention2. We welcome any feedback on accuracy, performance issues, bugs, feature requests, or suggestions. Please feel free to open an issue or a pull request!

Current Features:

For the stable version, please use SageAttention-1.

Project Updates

Base environment

We recommend installing the following (the kernel will be slightly faster):

Installation

For the stable version or the Triton-only version, refer to SageAttention-1 or install it using pip:

pip install sageattention==1.0.6

To use SageAttention 2.0.0, please compile from source:

git clone https://github.com/thu-ml/SageAttention.git
cd SageAttention
pip install -e . # or python setup.py install
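
As a quick sanity check of the build, you can confirm that the sageattn entry point is importable (a minimal sketch; it does not validate kernel output):

python -c "from sageattention import sageattn; print(sageattn)"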

Note: Currently, SageAttention is optimized for excellent performance on RTX4090, RTX3090, L20, and L40 GPUs. On A100, A800, and A6000 GPUs, performance is best with head_dim=128, while head_dim=64 is less optimal. Similarly, performance on the Hopper architecture is currently not optimal. We are actively working to enhance performance in these configurations.

How to Use

from sageattention import sageattn
attn_output = sageattn(q, k, v, tensor_layout="HND", is_causal=False)

Available APIs:

For optimal speed and accuracy on custom devices and models, we strongly recommend referring to this file for detailed guidance.

Note: head_dim values of 64, 96, and 128 are currently supported; extended support for other head_dim values is under development. Different sequence lengths between q and k, v, as well as grouped-query attention, are supported.
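
As a concrete illustration of the API above, here is a minimal sketch. The shapes are hypothetical, and it assumes tensor_layout="HND" corresponds to (batch, heads, seq_len, head_dim) tensors in FP16 on a CUDA device:

import torch
from sageattention import sageattn

# Hypothetical shapes for illustration: batch=2, heads=16, seq_len=1024, head_dim=128.
# tensor_layout="HND" is assumed to mean (batch, heads, seq_len, head_dim).
q = torch.randn(2, 16, 1024, 128, dtype=torch.float16, device="cuda")
k = torch.randn(2, 16, 1024, 128, dtype=torch.float16, device="cuda")
v = torch.randn(2, 16, 1024, 128, dtype=torch.float16, device="cuda")

attn_output = sageattn(q, k, v, tensor_layout="HND", is_causal=False)
print(attn_output.shape)  # expected: torch.Size([2, 16, 1024, 128])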

Plug-and-play Example

We can easily replace scaled_dot_product_attention.
Taking CogVideoX as an example:

Just add the following code and run!

from sageattention import sageattn
import torch.nn.functional as F

F.scaled_dot_product_attention = sageattn

Specifically,

cd example
python sageattn_cogvideo.py

You will get a video in ./example with no quality loss, generated faster than with python original_cogvideo.py.

Note: Not all models use F.scaled_dot_product_attention, so you may need to modify the attention class of the target model to call sageattn directly.
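
For such models, one option is to call sageattn inside the model's attention module. The snippet below is only an illustrative sketch: the class, layer names, and shapes are assumptions rather than any target model's actual code, and it assumes the same "HND" layout as above.

import torch
import torch.nn as nn
from sageattention import sageattn

class PatchedSelfAttention(nn.Module):
    """Illustrative self-attention block that routes the attention
    computation through sageattn. Names and shapes are hypothetical;
    adapt them to the target model's attention class."""

    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, n, _ = x.shape
        # Project to Q, K, V and reshape to the assumed "HND" layout:
        # (batch, heads, seq_len, head_dim).
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q = q.view(b, n, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(b, n, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(b, n, self.num_heads, self.head_dim).transpose(1, 2)

        # Replace the original attention computation with SageAttention.
        out = sageattn(q, k, v, tensor_layout="HND", is_causal=False)

        out = out.transpose(1, 2).reshape(b, n, -1)
        return self.proj(out)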

Performance

Speed of Kernels

Here, 8+8 denotes the kernel with INT8 quantization for $QK^\top$ and FP8 quantization for $PV$, while 8+16 uses FP16 for $PV$.


Note: The TOPS results refer only to the attention kernel, excluding quantization and smoothing. We use an FP16 accumulator for FP16 $PV$ and an FP32 accumulator for FP8 $PV$.

End-to-end Performance


The table below shows the end-to-end performance of various models using SageAttention 1.0. For more evaluations, please refer to our paper.

Citation

If you use this code or find our work valuable, please cite:

@misc{zhang2024sageattention,
      title={SageAttention: Accurate 8-Bit Attention for Plug-and-play Inference Acceleration}, 
      author={Jintao Zhang and Jia Wei and Haofeng Huang and Pengle Zhang and Jun Zhu and Jianfei Chen},
      year={2024},
      eprint={2410.02367},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2410.02367}, 
}

@misc{zhang2024sageattention2,
      title={SageAttention2 Technical Report: Accurate 4-Bit Attention for Plug-and-play Inference Acceleration}, 
      author={Jintao Zhang and Haofeng Huang and Pengle Zhang and Jia Wei and Jun Zhu and Jianfei Chen},
      year={2024},
      eprint={2411.10958},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2411.10958}, 
}