<div align="center">

# Flash Bi-directional Linear Attention

</div>

The aim of this repository is to implement bi-directional linear attention for non-causal modeling using Triton.

<div align="center"> <img width="600" alt="image" src="https://private-user-images.githubusercontent.com/74758580/387246938-cd89a618-5d54-41b7-9055-36ba28b29fbd.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3MzQ3OTEwODQsIm5iZiI6MTczNDc5MDc4NCwicGF0aCI6Ii83NDc1ODU4MC8zODcyNDY5MzgtY2Q4OWE2MTgtNWQ1NC00MWI3LTkwNTUtMzZiYTI4YjI5ZmJkLnBuZz9YLUFtei1BbGdvcml0aG09QVdTNC1ITUFDLVNIQTI1NiZYLUFtei1DcmVkZW50aWFsPUFLSUFWQ09EWUxTQTUzUFFLNFpBJTJGMjAyNDEyMjElMkZ1cy1lYXN0LTElMkZzMyUyRmF3czRfcmVxdWVzdCZYLUFtei1EYXRlPTIwMjQxMjIxVDE0MTk0NFomWC1BbXotRXhwaXJlcz0zMDAmWC1BbXotU2lnbmF0dXJlPTRkYjk5MDk5YTExNDZiZDZiNmMyNzlhYzk2ZmRiNjZiZjk4ZTdhNzhlMzRiNTA0MDU0NTRiYWI5NzYyYWU5ODQmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0In0.gCyr6rJJgGEbq9kJUZ70_SLI7-KNdmLS8A3tSfQatP4"> </div>

This project is currently maintained by an individual and remains a work in progress. As the maintainer is still in the early stages of learning Triton, many implementations may not be optimal. Contributions and suggestions are welcome!
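For readers new to the topic, the key identity behind linear attention is associativity: in the non-causal (bi-directional) setting, `φ(Q)(φ(K)ᵀV)` can be computed in O(N·d²) instead of the O(N²·d) of softmax attention, because no causal mask breaks the factorization. Below is a minimal NumPy reference sketch of this identity; the `elu(x) + 1` feature map is one common choice and not necessarily the one used by every model in this repo:

```python
import numpy as np

def bidirectional_linear_attention(q, k, v, eps=1e-6):
    """Non-causal linear attention over a full sequence.

    q, k: (seq_len, dim), v: (seq_len, dim_v).
    Computes phi(Q) @ (phi(K).T @ V), normalized per query,
    in O(N * d^2) rather than the O(N^2 * d) of softmax attention.
    """
    phi = lambda x: np.where(x > 0, x + 1.0, np.exp(x))  # elu(x) + 1, keeps scores positive
    q, k = phi(q), phi(k)
    kv = k.T @ v                  # (dim, dim_v): one global summary of keys/values
    z = q @ k.sum(axis=0) + eps   # per-query normalizer (row sums of the score matrix)
    return (q @ kv) / z[:, None]
```

This produces the same result as materializing the full `φ(Q)φ(K)ᵀ` score matrix, normalizing each row, and multiplying by `V`; the Triton kernels in this repo implement the same factorization in fused, tiled form.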

## Models

Models are listed roughly in the order they were added to FBi-LA.

| Date | Model | Title | Paper | Code | FBi-LA impl |
| :---: | :---: | --- | :---: | :---: | :---: |
| 2024-11 | LinFusion | LinFusion: 1 GPU, 1 Minute, 16K Image | arxiv | official | code |
| 2024-11 | MLLA | Demystify Mamba in Vision: A Linear Attention Perspective | arxiv | official | code |
| 2024-11 | Focused-LA | FLatten Transformer: Vision Transformer using Focused Linear Attention | arxiv | official | code |

More models will be implemented gradually.

P.S.: The current implementation of MLLA is relatively basic and will be updated soon.

## Usage

### Installation

```shell
git clone https://github.com/hp-l33/flash-bidirectional-linear-attention.git
pip install -e flash-bidirectional-linear-attention/.
```

### Integrated Models

This library integrates several models that can be used directly. Taking LinFusion as an example:

```python
import torch
from diffusers import AutoPipelineForText2Image
from fbi_la.models import LinFusion

sd_repo = "Lykon/dreamshaper-8"

pipeline = AutoPipelineForText2Image.from_pretrained(
    sd_repo, torch_dtype=torch.float16, variant="fp16"
).to(torch.device("cuda"))

linfusion = LinFusion.construct_for(pipeline)

image = pipeline(
    "An astronaut floating in space. Beautiful view of the stars and the universe in the background.",
    generator=torch.manual_seed(123)
).images[0]
```

## Benchmarks

Tested on an A800 80GB GPU.

```
B8-H16-D64:
         T  torch_fwd  triton_fwd  torch_bwd  triton_bwd
0    128.0   0.063488    0.049152   0.520192    0.651264
1    256.0   0.080896    0.056320   0.795648    0.599040
2    512.0   0.111616    0.070656   1.074176    1.065984
3   1024.0   0.169984    0.101376   1.014784    0.746496
4   2048.0   0.300032    0.165888   1.464320    1.364992
5   4096.0   0.532480    0.287744   2.741248    2.564096
6   8192.0   1.005568    0.521216   5.232128    4.940800
7  16384.0   1.924608    0.980992  10.235904    9.695744
```
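As a sanity check on the trend, the `triton_fwd` column grows roughly linearly with sequence length `T`, which is what linear attention predicts; softmax attention would grow quadratically. A quick check using the numbers from the table above:

```python
# triton_fwd times copied from the B8-H16-D64 benchmark table above
triton_fwd = {1024: 0.101376, 2048: 0.165888, 4096: 0.287744,
              8192: 0.521216, 16384: 0.980992}

# Going from T=1024 to T=16384 is a 16x increase in sequence length.
# A quadratic-cost kernel would slow down by ~256x; a linear one by <=16x.
speed_ratio = triton_fwd[16384] / triton_fwd[1024]
print(f"16x longer sequence -> {speed_ratio:.1f}x slower")  # ~9.7x, near-linear
```

The sub-16x ratio at the large end is consistent with fixed per-launch overheads amortizing away as `T` grows.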

## TODO

## Acknowledgments

Thanks to the following repositories for their inspiration: