[Issues]: The Gap between AOT and JIT Triton on Flash Attention kernel #34

Open
jinsong-mao opened this issue Jun 28, 2024 · 0 comments


Suggestion Description

Thanks for this great work.

There is a performance gap between AOT and JIT Triton for flash attention across most seqlen, n_heads, and head_dim settings.
We tried tuning the flash attention kernel and got some performance improvement at head_dim=128. However, it is still slower than the JIT Triton kernel.

The Triton kernel tuning spaces of the two differ, and this is the main difference we found:
Triton kernel tuning space for aotriton - https://github.com/ROCm/aotriton/blob/main/tritonsrc/attn_torch_function.py#L47
Triton kernel tuning space for JIT Triton - https://github.com/ROCm/triton/blob/triton-mlir/python/tutorials/06-fused-attention.py#L84-L92
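
To make the comparison of the two tuning spaces concrete, here is a minimal sketch of how they could be diffed once written out as parameter grids. The parameter names and values below (`BLOCK_M`, `BLOCK_N`, `num_warps` and their candidates) are hypothetical placeholders, not copied from either linked file, and `expand` is a helper made up for this illustration.

```python
from itertools import product


def expand(space):
    """Expand {parameter: [candidates]} into a set of hashable config tuples."""
    keys = sorted(space)
    return {tuple(zip(keys, combo)) for combo in product(*(space[k] for k in keys))}


# Hypothetical placeholder grids; substitute the real values from the two files.
aot_space = {"BLOCK_M": [64, 128], "BLOCK_N": [32, 64], "num_warps": [4]}
jit_space = {"BLOCK_M": [128, 256], "BLOCK_N": [64], "num_warps": [4, 8]}

aot_cfgs = expand(aot_space)
jit_cfgs = expand(jit_space)

# Configs that only one side can ever pick are candidates for explaining a gap.
only_jit = sorted(jit_cfgs - aot_cfgs)
only_aot = sorted(aot_cfgs - jit_cfgs)
shared = sorted(aot_cfgs & jit_cfgs)

for name, cfgs in (("only JIT", only_jit), ("only AOT", only_aot), ("shared", shared)):
    print(name)
    for cfg in cfgs:
        print("   ", dict(cfg))
```

If the config that JIT autotuning selects for a given shape falls in the "only JIT" group, the tuning space alone could account for part of the gap. If it falls in the shared group, the difference probably lies elsewhere.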

Are there any other major differences that make the JIT Triton FA kernel faster than the AOT Triton one?

Operating System

Ubuntu 22

GPU

MI300

ROCm Component

rocBLAS
