Flash Bi-directional Linear Attention

The aim of this repository is to implement bi-directional linear attention for non-causal modeling using Triton.
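As a point of reference, non-causal linear attention can be written in a few lines of NumPy. The sketch below is not this repository's Triton kernel: the `elu(x) + 1` feature map, the `eps` value, and the function names are assumptions chosen for illustration. It shows that, because every query attends to every key, the key/value summary can be computed once and reused, and that the result matches the quadratic form.

```python
import numpy as np

def feature_map(x):
    # elu(x) + 1, a common positive feature map (an assumption for this sketch)
    return np.where(x > 0, x + 1.0, np.exp(x))

def bi_linear_attn(q, k, v, eps=1e-6):
    """q, k: (T, D); v: (T, Dv). Non-causal: every query sees every key."""
    q, k = feature_map(q), feature_map(k)
    kv = k.T @ v               # (D, Dv) summary of all keys and values
    z = k.sum(axis=0)          # (D,) sum of all keys, used by the norm term
    num = q @ kv               # (T, Dv)
    den = q @ z + eps          # (T,)
    return num / den[:, None]

def quadratic_reference(q, k, v, eps=1e-6):
    # Same computation with the explicit (T, T) score matrix
    q, k = feature_map(q), feature_map(k)
    scores = q @ k.T
    return (scores @ v) / (scores.sum(axis=1, keepdims=True) + eps)

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((128, 16)) for _ in range(3))
out = bi_linear_attn(q, k, v)
ref = quadratic_reference(q, k, v)
print("max abs difference:", np.abs(out - ref).max())
```

The linear form costs O(T * D * Dv) instead of O(T^2 * D), which is the motivation for fusing it into a single kernel.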


This project is currently maintained by an individual and remains a work in progress. As the maintainer is still in the early stages of learning Triton, many implementations may not be optimal. Contributions and suggestions are welcome!

Update

  • [2024-12-30] Sped up the backward pass of linear_attn.
  • [2024-12-28] Updated simple_la, which is a simple form of linear_attn without the norm term.
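To illustrate what "without the norm term" means, here is a minimal NumPy sketch of the unnormalized form. It is a reference under assumptions, not the repository's `simple_la` kernel: the function name `simple_la_reference`, the `(B, H, T, D)` layout, and the absence of any scaling or feature map are all hypothetical choices made for this example.

```python
import numpy as np

def simple_la_reference(q, k, v):
    """Unnormalized non-causal linear attention.

    q, k, v: arrays of shape (B, H, T, D). Hypothetical reference only.
    """
    kv = np.swapaxes(k, -1, -2) @ v   # (B, H, D, D), computed once per head
    return q @ kv                     # (B, H, T, D), no division by a norm term

rng = np.random.default_rng(0)
B, H, T, D = 2, 4, 64, 16
q, k, v = (rng.standard_normal((B, H, T, D)) for _ in range(3))

out = simple_la_reference(q, k, v)
# Same result via the explicit (T, T) score matrix, by associativity of matmul
quadratic = (q @ np.swapaxes(k, -1, -2)) @ v
print("max abs difference:", np.abs(out - quadratic).max())
```

Dropping the denominator removes one reduction over the keys, at the cost of outputs that are no longer a weighted average of the values.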

Models

Roughly sorted by the order in which support was added to FBi-LA.

Year  Model       Title                                                                   Paper  Code      fla impl
2024  LinFusion   LinFusion: 1 GPU, 1 Minute, 16K Image                                   arxiv  official  code
2024  MLLA        Demystify Mamba in Vision: A Linear Attention Perspective               arxiv  official  code
2023  Focused-LA  FLatten Transformer: Vision Transformer using Focused Linear Attention  arxiv  official  code

More models will be implemented gradually.

Usage

Installation

git clone https://github.com/fla-org/flash-bidirectional-linear-attention.git
pip install -e flash-bidirectional-linear-attention/.

Integrated Models

Some models are integrated into this library and can be used directly. Taking LinFusion as an example:

import torch
from diffusers import AutoPipelineForText2Image
from fbi_la.models import LinFusion

# Base Stable Diffusion checkpoint to load.
sd_repo = "Lykon/dreamshaper-8"

# Load the text-to-image pipeline in half precision and move it to the GPU.
pipeline = AutoPipelineForText2Image.from_pretrained(
    sd_repo, torch_dtype=torch.float16, variant="fp16"
).to(torch.device("cuda"))

# Attach LinFusion to the pipeline.
linfusion = LinFusion.construct_for(pipeline)

# Generate one image; the fixed seed makes the result reproducible.
image = pipeline(
    "An astronaut floating in space. Beautiful view of the stars and the universe in the background.",
    generator=torch.manual_seed(123)
).images[0]

Benchmarks

Tested on an A800 80G GPU.

B8-H16-D64 (batch size 8, 16 heads, head dimension 64; T is the sequence length):
         T  torch_fwd  triton_fwd  torch_bwd  triton_bwd
0    128.0   0.063488    0.049152   0.798720    0.651264
1    256.0   0.080896    0.056320   0.796672    0.625664
2    512.0   0.111616    0.058368   0.798720    0.630784
3   1024.0   0.169984    0.090112   0.864256    0.719872
4   2048.0   0.300032    0.151552   1.624064    0.702464
5   4096.0   0.532480    0.276480   3.058176    1.324032
6   8192.0   1.005568    0.521216   5.880320    2.556928
7  16384.0   1.924608    0.980992  11.540992    5.022208
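The speedup of the Triton kernels over the PyTorch baseline can be read off the table by dividing the corresponding columns. The short script below does this with the numbers copied from the table above; the table does not state its time unit, which does not matter here because the ratios are unit-free.

```python
# (T, torch_fwd, triton_fwd, torch_bwd, triton_bwd), copied from the table above
rows = [
    (128,   0.063488, 0.049152,  0.798720, 0.651264),
    (256,   0.080896, 0.056320,  0.796672, 0.625664),
    (512,   0.111616, 0.058368,  0.798720, 0.630784),
    (1024,  0.169984, 0.090112,  0.864256, 0.719872),
    (2048,  0.300032, 0.151552,  1.624064, 0.702464),
    (4096,  0.532480, 0.276480,  3.058176, 1.324032),
    (8192,  1.005568, 0.521216,  5.880320, 2.556928),
    (16384, 1.924608, 0.980992, 11.540992, 5.022208),
]

# Speedup = baseline time / Triton time, separately for forward and backward
speedups = [
    (t, torch_fwd / triton_fwd, torch_bwd / triton_bwd)
    for t, torch_fwd, triton_fwd, torch_bwd, triton_bwd in rows
]

for t, fwd, bwd in speedups:
    print(f"T={t:>5}  fwd x{fwd:.2f}  bwd x{bwd:.2f}")
```

In these measurements the Triton kernels are faster at every sequence length, and the gap widens for T of 2048 and above.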

TODO

  • improve memory efficiency during backpropagation
  • implement more models
    • VSSD
    • RALA

Acknowledgments

Thanks to the following repositories for their inspiration:
