[Tutorial] Fix post IFU issues with FA (#398)
* [Tutorial] Fix post IFU issues with FA

* Remove redundant kernels in 06-fused-attention.py

* Added README for scripts in perf-kernels dir

* Fix bwd kernel

---------

Co-authored-by: Lixun Zhang <[email protected]>
binarman and zhanglx13 authored Nov 14, 2023
1 parent 9941ce7 commit 5b06b16
Showing 3 changed files with 60 additions and 300 deletions.
6 changes: 5 additions & 1 deletion python/perf-kernels/06-fused-attention-transV.py
@@ -17,6 +17,10 @@
import triton
import triton.language as tl

torch_dtype:tl.constexpr = torch.float16
TORCH_HAS_FP8E5 = hasattr(torch, 'float8_e5m2fnuz')
if TORCH_HAS_FP8E5:
    torch_dtype:tl.constexpr = torch.float8_e5m2fnuz

@triton.jit
def max_fn(x, y):
@@ -145,7 +149,7 @@ def _attn_fwd(
qk_scale = sm_scale * 1.44269504
# load q: it will stay in SRAM throughout on NV GPUs but in VGPRs on AMD GPUs
q = tl.load(Q_block_ptr)
q = (q * qk_scale).to(tl.float16)
q = (q * qk_scale).to(q.dtype)
# stage 1: off-band
# For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
# For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
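
The two hunks above make the tutorial fp8-aware: the first probes for the fnuz
fp8 type and falls back to fp16, and the second keeps `q` in its loaded dtype
instead of forcing `tl.float16`. A minimal host-side sketch of the same fallback
pattern follows; the tensor creation and cast are hypothetical usage, not code
from the script.

```python
import torch

# Prefer the ROCm-style fnuz fp8 type when this torch build exposes it,
# otherwise fall back to fp16 (same probe as the hunk above).
torch_dtype = torch.float16
if hasattr(torch, 'float8_e5m2fnuz'):
    torch_dtype = torch.float8_e5m2fnuz

# Hypothetical usage: build inputs in fp16, then cast to the selected dtype,
# so the kernel only ever sees fp8 when the build actually supports it.
q = torch.randn(1, 2, 1024, 64, dtype=torch.float16)
if torch_dtype is not torch.float16:
    q = q.to(torch_dtype)
```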
30 changes: 30 additions & 0 deletions python/perf-kernels/README.md
@@ -0,0 +1,30 @@
# AMD Perf Kernels

This directory contains customized, tuned, and experimental kernels for AMD MI-series GPUs.

## `06-fused-attention-transV.py`

This script is a copy of `tutorials/06-fused-attention.py` with the following
two changes:

- Tensor V is transposed so that the seqlen/N_CTX dimension becomes the
fastest-changing (a.k.a. leading or least-strided) dimension, as sketched
after this list.
This script performs better than `tutorials/06-fused-attention.py`
because it has better LDS access efficiency for tensor V.
Note that in the future we plan to improve the LDS access efficiency for
non-transposed tensor V, i.e. when the head dimension is the fastest-changing
dimension.
- Only the fwd kernel is benchmarked.
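
A minimal sketch of the transposed-V layout described above; the shapes and
tensor names are assumptions for illustration, not taken from the script:

```python
import torch

# Assumed shape (Z, H, N_CTX, D); a contiguous V has stride 1 along D (head dim).
Z, H, N_CTX, D = 1, 2, 1024, 64
v = torch.randn(Z, H, N_CTX, D, dtype=torch.float16)

# Make N_CTX the physically innermost dimension, then view back to the
# original logical shape: the result has stride 1 along N_CTX (seqlen).
v_trans = v.transpose(2, 3).contiguous().transpose(2, 3)
assert v_trans.shape == v.shape and v_trans.stride(2) == 1
```

With this layout, consecutive seqlen elements of V are adjacent in memory,
which is the access pattern the script exploits for better LDS efficiency.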

## `06-fused-attention-fwd-transV.py`

This script is used to produce the best performance for the fwd kernel.
It is a copy of `06-fused-attention-transV.py` with the following
changes:

- All bwd kernels are removed.
- Storing `m` at the end of the fwd kernel is removed.
- The autotuner is removed. All parameters for D=64 and D=128 are pre-tuned
on MI250X and hard-coded (see the sketch after this list).
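
The pattern of replacing the autotuner with fixed launch parameters looks
roughly like the sketch below; the block sizes and warp counts are placeholders,
not the actual MI250X tuning results.

```python
# Placeholder values for illustration only; the real pre-tuned numbers
# live in 06-fused-attention-fwd-transV.py.
PRETUNED_CONFIGS = {
    64:  {'BLOCK_M': 128, 'BLOCK_N': 64, 'num_warps': 4, 'num_stages': 1},
    128: {'BLOCK_M': 128, 'BLOCK_N': 64, 'num_warps': 8, 'num_stages': 1},
}

def pick_config(head_dim):
    # Hypothetical helper: look up the fixed config for a supported head dim
    # instead of running triton.autotune over a search space.
    if head_dim not in PRETUNED_CONFIGS:
        raise ValueError(f"no pre-tuned config for D={head_dim}")
    return PRETUNED_CONFIGS[head_dim]
```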

Note that this script is also used to benchmark FA performance with 2 GCDs.
Check the [2GCD benchmark script](https://github.com/ROCmSoftwarePlatform/triton/blob/triton-mlir/scripts/amd/benchmark_flash_attention.py) for more details.