
aqlaboratory/ring-flash-attention

 
 


Ring Flash Attention

This repo implements RingAttention on top of FlashAttention. It currently provides:

  • varlen api, corresponding to flash_attn_varlen_func:
    • llama3_flash_attn_varlen_func: the context parallelism used in the llama3 tech report, with extra design for varlen and low memory overhead.
      • Technically, this is not ring attention and it brings some memory overhead, but it is the recommended api for most use cases, as its communication pattern is friendlier to GPU clusters and its arithmetic errors are lower.
    • ring_flash_attn_varlen_func: naive ring attention.
    • zigzag_ring_flash_attn_varlen_func: a more compute-balanced version of ring attention, see issue#2.
  • batch api, corresponding to flash_attn_func:
    • ring_flash_attn_func: naive ring attention.
    • zigzag_ring_flash_attn_func: a more compute-balanced version of ring attention, see issue#2.
    • stripe_flash_attn_func: the striped attention version of ring_flash_attn_func; the block size is set to 1 so that the flash_attn api can be used, see: https://arxiv.org/abs/2311.09431
  • huggingface model adapter. Here is an example of using the adapter: OpenRLHF/OpenRLHF/pull#439.
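To illustrate why the zigzag variants are described as more compute balanced, here is a minimal sketch in plain Python. It assumes the usual zigzag layout (the sequence is cut into 2 * world_size chunks and rank r holds chunk r plus the mirrored chunk 2 * world_size - 1 - r) and counts the (query, key) pairs each rank computes under a causal mask. The helper names are hypothetical and are not part of this repo's API.

```python
def causal_work(positions):
    """Count (query, key) pairs under a causal mask.

    A query at global position p attends to keys 0..p, i.e. p + 1 keys.
    """
    return sum(p + 1 for p in positions)


def contiguous_shard(seq_len, world_size, rank):
    """Naive sharding: rank r holds one contiguous slice of the sequence."""
    chunk = seq_len // world_size
    return list(range(rank * chunk, (rank + 1) * chunk))


def zigzag_shard(seq_len, world_size, rank):
    """Zigzag sharding (assumed layout): chunk r plus its mirrored chunk."""
    chunk = seq_len // (2 * world_size)
    mirror = 2 * world_size - 1 - rank
    front = range(rank * chunk, (rank + 1) * chunk)
    back = range(mirror * chunk, (mirror + 1) * chunk)
    return list(front) + list(back)


seq_len, world_size = 16, 4
contiguous = [causal_work(contiguous_shard(seq_len, world_size, r))
              for r in range(world_size)]
zigzag = [causal_work(zigzag_shard(seq_len, world_size, r))
          for r in range(world_size)]
print(contiguous)  # [10, 26, 42, 58]
print(zigzag)      # [34, 34, 34, 34]
```

With contiguous shards the last rank does almost six times the work of the first, while the zigzag layout gives every rank the same amount of causal work in this toy setting.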

Note that

  • all functions have the *_func, *_kvpacked_func, and *_qkvpacked_func variants implemented.
  • the varlen versions only support passing one cu_seqlens.
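For readers unfamiliar with cu_seqlens, the sketch below shows how the cumulative sequence lengths used by varlen-style APIs can be built from per-sequence lengths. It uses plain Python lists for clarity; in real use this would typically be an int32 tensor on the GPU. The helper name build_cu_seqlens is hypothetical.

```python
from itertools import accumulate


def build_cu_seqlens(seq_lens):
    """Cumulative sequence lengths with a leading 0 (length = batch + 1)."""
    return [0] + list(accumulate(seq_lens))


seq_lens = [3, 5, 2]                      # three sequences packed back to back
cu_seqlens = build_cu_seqlens(seq_lens)   # [0, 3, 8, 10]

# Sequence i occupies packed[cu_seqlens[i]:cu_seqlens[i + 1]].
spans = list(zip(cu_seqlens[:-1], cu_seqlens[1:]))
max_seqlen = max(seq_lens)

print(cu_seqlens, spans, max_seqlen)
```

Passing one cu_seqlens means the same boundaries describe both the queries and the keys/values, so q and k must be packed identically.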

The current performance on 8xH800 and 8xA100 is (benchmark/benchmark_qkvpacked_func.py):

| | GPU | theoretic flash_attn | ring_attn | zigzag_ring | stripe_attn |
| --- | --- | --- | --- | --- | --- |
| fwd only (iter/sec) | 8xH800 | 2418.4 / 8 = 302.3 | 208.0 (68.8%) | 283.0 (93.6%) | 259.6 (85.9%) |
| fwd + bwd (iter/sec) | 8xH800 | 705.2 / 8 = 88.2 | 54.3 (61.5%) | 75.7 (85.9%) | 76.9 (87.2%) |
| fwd only (iter/sec) | 8xA100 | 1545.9 / 8 = 193.2 | 124.4 (64.3%) | 179.0 (92.7%) | 163.9 (84.8%) |
| fwd + bwd (iter/sec) | 8xA100 | 470.6 / 8 = 58.8 | 33.3 (56.6%) | 49.5 (84.1%) | 45.9 (78.1%) |

Note that

  • when running the benchmark with 8 GPUs, the flash attn code runs with 1/8 of the computation of ring attention, which is why the theoretic column divides the flash_attn number by 8.
  • NVLink between GPUs is required for high performance.
  • the varlen versions of the ring attention variants are slow at the moment; please use the non-varlen versions or the llama3 api if possible.
  • please remember to adapt the RoPE offset for each api.
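The RoPE note matters because each api places different global token positions on a given rank. The sketch below lists the position ids a rank would need under three layouts: contiguous chunks for naive ring attention, mirrored chunk pairs for zigzag, and round-robin tokens for stripe attention with block size 1. These layouts are assumptions based on the descriptions above, and the function names are hypothetical; check the sharding your own data pipeline actually uses.

```python
def ring_positions(seq_len, world_size, rank):
    """Contiguous chunk per rank (assumed layout for naive ring attention)."""
    chunk = seq_len // world_size
    return list(range(rank * chunk, (rank + 1) * chunk))


def zigzag_positions(seq_len, world_size, rank):
    """Chunk r plus mirrored chunk 2W - 1 - r (assumed zigzag layout)."""
    chunk = seq_len // (2 * world_size)
    mirror = 2 * world_size - 1 - rank
    return (list(range(rank * chunk, (rank + 1) * chunk))
            + list(range(mirror * chunk, (mirror + 1) * chunk)))


def stripe_positions(seq_len, world_size, rank):
    """Round-robin tokens (assumed stripe layout with block size 1)."""
    return list(range(rank, seq_len, world_size))


seq_len, world_size, rank = 16, 4, 1
print(ring_positions(seq_len, world_size, rank))    # [4, 5, 6, 7]
print(zigzag_positions(seq_len, world_size, rank))  # [2, 3, 12, 13]
print(stripe_positions(seq_len, world_size, rank))  # [1, 5, 9, 13]
```

Whatever the layout, the position ids passed to RoPE on each rank should be the global positions of the local tokens, not 0..local_len - 1.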

Installation

pip install ring-flash-attn

or use the following command to build from source:

git clone https://github.com/zhuzilin/ring-flash-attention.git
cd ring-flash-attention
pip install .

Limits

There are some arithmetic errors with the current implementation. The likely reason is that flash attention returns a bf16 value for each block, so we cannot accumulate the values with the original fp32 ones.

Also, because we need to save an extra fp32 buffer during computation, the memory usage is higher than the theoretic limit.
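To make the accumulation issue concrete, here is a minimal sketch of how per-block attention outputs can be merged using their log-sum-exp (lse) values. It is pure Python in double precision for a single query, omits the softmax scale, and is not this repo's kernel code. In the real setting each block output arrives already rounded to bf16 before this merge, which is the likely source of the extra error.

```python
import math


def block_attention(q, keys, values):
    """Softmax attention of one query over one block of keys.

    Returns the block output and the log-sum-exp of the scores.
    """
    scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    denom = sum(weights)
    dim = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) / denom
           for d in range(dim)]
    return out, m + math.log(denom)


def merge(out_a, lse_a, out_b, lse_b):
    """Combine two block results into the result over both key blocks."""
    m = max(lse_a, lse_b)
    lse = m + math.log(math.exp(lse_a - m) + math.exp(lse_b - m))
    w_a, w_b = math.exp(lse_a - lse), math.exp(lse_b - lse)
    return [w_a * a + w_b * b for a, b in zip(out_a, out_b)], lse


q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 2.0]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]

full, full_lse = block_attention(q, keys, values)
out_a, lse_a = block_attention(q, keys[:2], values[:2])  # first block
out_b, lse_b = block_attention(q, keys[2:], values[2:])  # second block
merged, merged_lse = merge(out_a, lse_a, out_b, lse_b)

max_diff = max(abs(x - y) for x, y in zip(full, merged))
print(max_diff < 1e-9)
```

In exact arithmetic the merged result equals attention over all keys at once; the extra fp32 buffer mentioned above plays the role of the running merged output and lse.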

TODOs

  • Implement ring_flash_attn_varlen_qkvpacked_func
  • Implement zigzag_ring_flash_attn_qkvpacked_func issue#2
  • Implement stripe_flash_attn_qkvpacked_func
  • Implement zigzag_ring_flash_attn_varlen_qkvpacked_func
  • Implement *_kvpacked_func and *_func variant for all APIs
  • Optimize *_varlen_func; implement llama3_flash_attn_varlen_func.
  • Add an example to train llama. Implement adapter for huggingface model.
  • Try to upstream to flash attention.

Test

torchrun --nproc_per_node 8 test/test_llama3_flash_attn_varlen_func.py
torchrun --nproc_per_node 8 test/test_ring_flash_attn_func.py
torchrun --nproc_per_node 8 test/test_ring_flash_attn_varlen_func.py
torchrun --nproc_per_node 8 test/test_zigzag_ring_flash_attn_func.py
torchrun --nproc_per_node 8 test/test_zigzag_ring_flash_attn_varlen_func.py
torchrun --nproc_per_node 8 test/test_stripe_flash_attn_func.py

Benchmark

torchrun --nproc_per_node 8 benchmark/benchmark_qkvpacked_func.py
torchrun --nproc_per_node 8 benchmark/benchmark_varlen_qkvpacked_func.py

Known Limits

  • dropout is not supported at the moment, because it's hard to save all the rng_states.
  • window_size is not supported, because it will be really tricky to implement a varlen version with window_size.
