From 481de73b69cba60ddf3b7d88c2095c5a00346a9e Mon Sep 17 00:00:00 2001 From: EdouardYvinec Date: Mon, 21 Oct 2024 14:42:47 +0200 Subject: [PATCH] fix: in newer versions of triton, tl.dot should take as input only q and tl.trans(k) --- flash_attn/flash_attn_triton.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash_attn/flash_attn_triton.py b/flash_attn/flash_attn_triton.py index 30420c057..5a4636aee 100644 --- a/flash_attn/flash_attn_triton.py +++ b/flash_attn/flash_attn_triton.py @@ -180,7 +180,7 @@ def _fwd_kernel( other=0.0, ) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k, trans_b=True) + qk += tl.dot(q, tl.trans(k)) # Trying to combine the two masks seem to make the result wrong if not EVEN_N: # Need to mask out otherwise the softmax is wrong qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))