diff --git a/python/perf-kernels/06-fused-attention-transV.py b/python/perf-kernels/06-fused-attention-transV.py
index b36e2f1268dc..1f9387f26777 100644
--- a/python/perf-kernels/06-fused-attention-transV.py
+++ b/python/perf-kernels/06-fused-attention-transV.py
@@ -17,6 +17,10 @@
 import triton
 import triton.language as tl

+torch_dtype:tl.constexpr = torch.float16
+TORCH_HAS_FP8E5 = hasattr(torch, 'float8_e5m2fnuz')
+if TORCH_HAS_FP8E5:
+    torch_dtype:tl.constexpr = torch.float8_e5m2fnuz

 @triton.jit
 def max_fn(x, y):
@@ -145,7 +149,7 @@ def _attn_fwd(
     qk_scale = sm_scale * 1.44269504
     # load q: it will stay in SRAM throughout on NV GPUs but in VGPRs on AMD GPUs
     q = tl.load(Q_block_ptr)
-    q = (q * qk_scale).to(tl.float16)
+    q = (q * qk_scale).to(q.dtype)
     # stage 1: off-band
     # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
     # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
diff --git a/python/perf-kernels/README.md b/python/perf-kernels/README.md
new file mode 100644
index 000000000000..603182e822b4
--- /dev/null
+++ b/python/perf-kernels/README.md
@@ -0,0 +1,30 @@
+# AMD Perf Kernels
+
+This directory contains customized/tuned/experimental kernels for AMD MI-series GPUs.
+
+## `06-fused-attention-transV.py`
+
+This script is a copy of `tutorials/06-fused-attention.py` with the following
+two changes:
+
+- Tensor V is transposed so that the seqlen/N_CTX dimension becomes the
+fastest-changing (a.k.a. leading or least-strided) dimension.
+This script produces better performance than `tutorials/06-fused-attention.py`
+since it has better LDS access efficiency for tensor V.
+Note that in the future, we'll improve the LDS access efficiency for
+non-transposed tensor V, i.e. when the head dimension is the fastest-changing dimension.
+- Only the fwd kernel is benchmarked.
+
+## `06-fused-attention-fwd-transV.py`
+
+This script is used to produce the best performance for the fwd kernel.
+It is a copy of `06-fused-attention-transV.py` with the following
+changes:
+
+- All bwd kernels are removed.
+- Storing `m` at the end of the fwd kernel is removed.
+- The autotuner is removed. All parameters for D=64 and D=128 are pre-tuned
+on MI250X and hard-coded.
+
+Note that this script is also used to benchmark FA performance with 2 GCDs.
+Check the [2GCD benchmark script](https://github.com/ROCmSoftwarePlatform/triton/blob/triton-mlir/scripts/amd/benchmark_flash_attention.py) for more details.
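The layout change described in the README above can be illustrated with a small PyTorch sketch (not part of the patch; shapes and variable names are made up for the example). It shows how making seqlen/N_CTX the fastest-changing dimension of V changes the strides the kernel sees:

```python
# Hypothetical sketch of the transposed-V layout described in the README above.
# Shapes and names are assumptions for illustration, not code from this patch.
import torch

Z, H, N_CTX, D_HEAD = 4, 16, 1024, 64

# Standard layout: the head dimension (D_HEAD) is the fastest-changing one.
v = torch.randn(Z, H, N_CTX, D_HEAD, dtype=torch.float16)
print(v.stride())   # (H*N_CTX*D_HEAD, N_CTX*D_HEAD, D_HEAD, 1): stride 1 on D_HEAD

# Transposed layout: physically store V as (Z, H, D_HEAD, N_CTX), then view it
# back, so seqlen/N_CTX becomes the fastest-changing (least-strided) dimension.
vt = v.transpose(2, 3).contiguous().transpose(2, 3)
print(vt.stride())  # stride 1 now sits on the N_CTX axis
```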
diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py
index 2b94edd60891..0622bf17bb0f 100644
--- a/python/tutorials/06-fused-attention.py
+++ b/python/tutorials/06-fused-attention.py
@@ -18,59 +18,9 @@
 import triton.language as tl

 torch_dtype:tl.constexpr = torch.float16
-# torch_dtype:tl.constexpr = torch.float8_e5m2fnuz
 TORCH_HAS_FP8E5 = hasattr(torch, 'float8_e5m2fnuz')
-
-@triton.jit
-def _attn_fwd_inner(
-    acc, l_i, m_i, q,
-    K_block_ptr, V_block_ptr,
-    start_m, qk_scale,
-    BLOCK_M: tl.constexpr,
-    BLOCK_DMODEL: tl.constexpr,
-    BLOCK_N: tl.constexpr,
-    STAGE: tl.constexpr,
-    offs_m: tl.constexpr,
-    offs_n: tl.constexpr,
-):
-    # range of values handled by this stage
-    if STAGE == 1:
-        lo, hi = 0, start_m * BLOCK_M
-    else:
-        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
-        lo = tl.multiple_of(lo, BLOCK_M)
-    K_block_ptr = tl.advance(K_block_ptr, (0, lo))
-    V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
-    # loop over k, v and update accumulator
-    for start_n in range(lo, hi, BLOCK_N):
-        start_n = tl.multiple_of(start_n, BLOCK_N)
-        # -- compute qk ----
-        k = tl.load(K_block_ptr)
-        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
-        qk += tl.dot(q, k)
-        if STAGE == 2:
-            mask = offs_m[:, None] >= (start_n + offs_n[None, :])
-            qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)
-            m_ij = tl.maximum(m_i, tl.max(qk, 1))
-            qk -= m_ij[:, None]
-        else:
-            m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
-            qk = qk * qk_scale - m_ij[:, None]
-        p = tl.math.exp2(qk)
-        l_ij = tl.sum(p, 1)
-        # -- update m_i and l_i
-        alpha = tl.math.exp2(m_i - m_ij)
-        l_i = l_i * alpha + l_ij
-        # -- update output accumulator --
-        acc = acc * alpha[:, None]
-        # update acc
-        v = tl.load(V_block_ptr)
-        acc += tl.dot(p.to(tl.float16), v)
-        # update m_i and l_i
-        m_i = m_ij
-        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
-        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
-    return acc, l_i, m_i
+if TORCH_HAS_FP8E5:
+    torch_dtype:tl.constexpr = torch.float8_e5m2fnuz

 @triton.jit
 def _attn_fwd_inner(
@@ -133,8 +83,8 @@ def _attn_fwd_inner(
         triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=8),
         triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=4),
         triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=8),
-        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': True}, num_stages=1, num_warps=4),  # d64-False
-        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': False}, num_stages=1, num_warps=4),  # d64-True
+        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': True}, num_stages=1, num_warps=4),
+        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': False}, num_stages=1, num_warps=4),
     ],
     key=['Z', 'H', 'N_CTX', 'STAGE', 'BLOCK_DMODEL'],
 )
@@ -157,20 +107,7 @@ def _attn_fwd(
 ):
     start_m = tl.program_id(0)
     off_hz = tl.program_id(1)
-    qkv_offset = off_hz * stride_qh
-    Q_block_ptr = tl.make_block_ptr(
-        base=Q + qkv_offset,
-    BLOCK_M: tl.constexpr,
-    BLOCK_DMODEL: tl.constexpr,
-    BLOCK_N: tl.constexpr,
-    STAGE: tl.constexpr,
-    pre_load_v: tl.constexpr,
-):
-    start_m = tl.program_id(0)
-    off_hz = tl.program_id(1)
-    off_z = off_hz // H
-    off_h = off_hz % H
-    qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
+    qvk_offset = off_hz * stride_qh
     # block pointers
     Q_block_ptr = tl.make_block_ptr(
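The dtype selection at the top of this file is a feature-detection pattern: use fp8 (`float8_e5m2fnuz`) when the PyTorch build exposes it, otherwise fall back to fp16. A minimal sketch of the same idea, with a hypothetical casting helper that is not part of the patch:

```python
# Sketch of the fp8 fallback pattern used in the hunk above; `maybe_cast_qkv`
# is a hypothetical helper for illustration, not code from this patch.
import torch

# Prefer the fnuz variant of fp8 e5m2 when this PyTorch build exposes it.
TORCH_HAS_FP8E5 = hasattr(torch, 'float8_e5m2fnuz')
torch_dtype = torch.float8_e5m2fnuz if TORCH_HAS_FP8E5 else torch.float16

def maybe_cast_qkv(q, k, v):
    # Cast the attention inputs to the detected dtype before the kernel launch.
    return q.to(torch_dtype), k.to(torch_dtype), v.to(torch_dtype)
```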
@@ -247,238 +184,25 @@ def _attn_fwd(
     acc = acc / l_i[:, None]
     m_ptrs = M + off_hz * N_CTX + offs_m
     tl.store(m_ptrs, m_i + tl.math.log2(l_i))
-    # write back O
-    O_block_ptr = tl.make_block_ptr(
-        base=Out + qkv_offset,
-        shape=(N_CTX, BLOCK_DMODEL),
-        strides=(stride_om, stride_on),
-        offsets=(start_m * BLOCK_M, 0),
-        block_shape=(BLOCK_M, BLOCK_DMODEL),
-        order=(1, 0)
-    )
     tl.store(O_block_ptr, acc.to(Out.type.element_ty))


 @triton.jit
-def _bwd_preprocess(
-    Out, DO,
-    NewDO, Delta,
-    BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
-):
+def _attn_bwd_preprocess(O, DO,  #
+                         NewDO, Delta,  #
+                         BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,  #
+                         ):
     off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
-    off_hz = tl.program_id(1)
     off_n = tl.arange(0, D_HEAD)
     # load
-    o = tl.load(O + off_hz * D_HEAD * N_CTX + off_m[:, None] * D_HEAD + off_n[None, :])
-    do = tl.load(DO + off_hz * D_HEAD * N_CTX + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
+    o = tl.load(O + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
+    do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
     delta = tl.sum(o * do, axis=1)
     # write-back
     tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
     tl.store(Delta + off_m, delta)


-# The main inner-loop logic for computing dK and dV.
-@triton.jit
-def _attn_bwd_dkdv(
-    dk, dv,
-    Q, k, v, sm_scale,
-    DO,
-    M, D,
-    # shared by Q/K/V/DO.
-    stride_tok, stride_d,
-    H, N_CTX,
-    BLOCK_M1: tl.constexpr,
-    BLOCK_N1: tl.constexpr,
-    BLOCK_DMODEL: tl.constexpr,
-    # Filled in by the wrapper.
-    start_n, start_m, num_steps,
-    MASK: tl.constexpr,
-):
-    offs_m = start_m + tl.arange(0, BLOCK_M1)
-    offs_n = start_n + tl.arange(0, BLOCK_N1)
-    offs_k = tl.arange(0, BLOCK_DMODEL)
-    qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d
-    do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
-    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
-    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
-    curr_m = start_m
-    step_m = BLOCK_M1
-    for blk_idx in range(num_steps):
-        qT = tl.load(qT_ptrs)
-        # Load m before computing qk to reduce pipeline stall.
-        offs_m = curr_m + tl.arange(0, BLOCK_M1)
-        m = tl.load(M + offs_m)
-        qkT = tl.dot(k, qT)
-        pT = tl.math.exp2(qkT - m[None, :])
-        # Autoregressive masking.
-        if MASK:
-            mask = (offs_m[None, :] >= offs_n[:, None])
-            pT = tl.where(mask, pT, 0.0)
-        do = tl.load(do_ptrs)
-        # Compute dV.
-        ppT = pT
-        ppT = ppT.to(tl.float16)
-        dv += tl.dot(ppT, do)
-        # D (= delta) is pre-divided by ds_scale.
-        Di = tl.load(D + offs_m)
-        # Compute dP and dS.
-        dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
-        dsT = pT * (dpT - Di[None, :])
-        dsT = dsT.to(tl.float16)
-        dk += tl.dot(dsT, tl.trans(qT))
-        # Increment pointers.
-        curr_m += step_m
-        qT_ptrs += step_m * stride_tok
-        do_ptrs += step_m * stride_tok
-    return dk, dv
-
-
-# the main inner-loop logic for computing dQ
-@triton.jit
-def _attn_bwd_dq(
-    dq, q, K, V,
-    do, m, D,
-    # shared by Q/K/V/DO.
-    stride_tok, stride_d,
-    H, N_CTX,
-    BLOCK_M2: tl.constexpr,
-    BLOCK_N2: tl.constexpr,
-    BLOCK_DMODEL: tl.constexpr,
-    # Filled in by the wrapper.
-    start_m, start_n, num_steps,
-    MASK: tl.constexpr,
-):
-    offs_m = start_m + tl.arange(0, BLOCK_M2)
-    offs_n = start_n + tl.arange(0, BLOCK_N2)
-    offs_k = tl.arange(0, BLOCK_DMODEL)
-    kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
-    vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
-    # D (= delta) is pre-divided by ds_scale.
-    Di = tl.load(D + offs_m)
-    # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
-    tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
-    curr_n = start_n
-    step_n = BLOCK_N2
-    for blk_idx in range(num_steps):
-        kT = tl.load(kT_ptrs)
-        vT = tl.load(vT_ptrs)
-        qk = tl.dot(q, kT)
-        p = tl.math.exp2(qk - m)
-        # Autoregressive masking.
-        if MASK:
-            offs_n = curr_n + tl.arange(0, BLOCK_N2)
-            mask = (offs_m[:, None] >= offs_n[None, :])
-            p = tl.where(mask, p, 0.0)
-        # Compute dP and dS.
-        dp = tl.dot(do, vT).to(tl.float32)
-        ds = p * (dp - Di[:, None])
-        ds = ds.to(tl.float16)
-        # Compute dQ.
-        # NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
-        dq += tl.dot(ds, tl.trans(kT))
-        # Increment pointers.
-        curr_n += step_n
-        kT_ptrs += step_n * stride_tok
-        vT_ptrs += step_n * stride_tok
-    return dq
-
-
-@triton.jit
-def _attn_bwd(
-    Q, K, V, sm_scale,
-    DO,
-    DQ, DK, DV,
-    M, D,
-    # shared by Q/K/V/DO.
-    stride_z, stride_h, stride_tok, stride_d,
-    H, N_CTX,
-    BLOCK_M1: tl.constexpr,
-    BLOCK_N1: tl.constexpr,
-    BLOCK_M2: tl.constexpr,
-    BLOCK_N2: tl.constexpr,
-    BLK_SLICE_FACTOR: tl.constexpr,
-    BLOCK_DMODEL: tl.constexpr,
-):
-    LN2: tl.constexpr = 0.6931471824645996  # = ln(2)
-
-    bhid = tl.program_id(2)
-    off_chz = (bhid * N_CTX).to(tl.int64)
-    adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64)
-    pid = tl.program_id(0)
-
-    # offset pointers for batch/head
-    Q += off_z * stride_qz + off_h * stride_qh
-    K += off_z * stride_kz + off_h * stride_kh
-    V += off_z * stride_vz + off_h * stride_vh
-    DO += off_z * stride_qz + off_h * stride_qh
-    DQ += off_z * stride_qz + off_h * stride_qh
-    DK += off_z * stride_kz + off_h * stride_kh
-    DV += off_z * stride_vz + off_h * stride_vh
-    # See fwd pass above for explanation.
-    qk_scale = sm_scale * 1.44269504
-    for start_n in range(0, num_block_kv):
-        if CAUSAL:
-            lo = tl.math.max(start_n * BLOCK_M - P_SEQ, 0)
-        else:
-            lo = 0
-        # initialize row/col offsets
-        offs_qm = lo + tl.arange(0, BLOCK_M)
-        offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
-        offs_m = tl.arange(0, BLOCK_N)
-        offs_k = tl.arange(0, BLOCK_DMODEL)
-        # initialize pointers to value-like data
-        q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
-        k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
-        v_ptrs = V + (offs_n[None, :] * stride_qm + offs_k[:, None] * stride_qk)
-        do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
-        dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
-        # pointer to row-wise quantities in value-like data
-        D_ptrs = D + off_hz * N_CTX
-        l_ptrs = L + off_hz * N_CTX
-        # initialize dk amd dv
-        dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
-        dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
-        # k and v stay in SRAM throughout
-        k = tl.load(k_ptrs)
-        v = tl.load(v_ptrs)
-        # loop over rows
-        for start_m in range(lo, num_block_q * BLOCK_M, BLOCK_M):
-            offs_m_curr = start_m + offs_m
-            # load q, k, v, do on-chip
-            q = tl.load(q_ptrs)
-            # recompute p = softmax(qk, dim=-1).T
-            if CAUSAL:
-                qk = tl.where(P_SEQ + offs_m_curr[:, None] >= (offs_n[None, :]), float(0.), float("-inf"))
-            else:
-                qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
-            qk += tl.dot(q, tl.trans(k))
-            l_i = tl.load(l_ptrs + offs_m_curr)
-            p = tl.math.exp2(qk * qk_scale - l_i[:, None])
-            # compute dv
-            do = tl.load(do_ptrs)
-            dv += tl.dot(tl.trans(p.to(do.dtype)), do)
-            # compute dp = dot(v, do)
-            Di = tl.load(D_ptrs + offs_m_curr)
-            dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
-            dp += tl.dot(do, v)
-            # compute ds = p * (dp - delta[:, None])
-            ds = p * dp * sm_scale
-            # compute dk = dot(ds.T, q)
-            dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q)
-            # compute dq
-            dq = tl.load(dq_ptrs)
-            dq += tl.dot(ds.to(Q.dtype.element_ty), k)
-            tl.store(dq_ptrs, dq)
-            # increment pointers
-            dq_ptrs += BLOCK_M * stride_qm
-            q_ptrs += BLOCK_M * stride_qm
-            do_ptrs += BLOCK_M * stride_qm
-        # write-back
-        dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
-        dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
-        tl.store(dk_ptrs, dk)
-        tl.store(dv_ptrs, dv)
-
 @triton.jit
 def _bwd_kernel_dk_dv(
     Q, K, V, sm_scale, Out, DO,
@@ -732,7 +456,7 @@ def forward(ctx, q, k, v, causal, sm_scale, split_kernel=False):
         best_config = _attn_fwd.get_best_config()
         block_m = int(best_config.__str__().split(",")[0].split("BLOCK_M:")[1])
         grid = (triton.cdiv(q.shape[2], block_m), q.shape[0] * q.shape[1], 1)
-
+
         ctx.save_for_backward(q, k, v, o, M)
         ctx.grid = grid
         ctx.sm_scale = sm_scale
@@ -750,11 +474,13 @@ def backward(ctx, do):
         else:
             BLOCK = 128
         q, k, v, o, L = ctx.saved_tensors
+        assert do.is_contiguous()
+        assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
         do = do.contiguous()
-        dq = torch.zeros_like(q, dtype=torch.float32)
-        # dk = torch.empty_like(k, dtype=torch_dtype)
+        dq = torch.zeros_like(q)
         dk = torch.empty_like(k)
         dv = torch.empty_like(v)
+        BATCH, N_HEAD, N_CTX = q.shape[:3]
         delta = torch.empty_like(L)
         do_scaled = torch.empty_like(do)
         # Figure out what BLOCK size fwd used and adjust num_blocks accordingly.
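The two asserts added to `backward` in the hunk above encode the layout assumption the backward kernels rely on: one shared set of strides is used to address Q, K, V, O and dO, so all of them must have the same contiguous layout. A small standalone sketch of that check (shapes are made up for the example):

```python
# Hypothetical illustration of the stride assumption behind the new asserts;
# shapes are made up for the example.
import torch

Z, H, N_CTX, D_HEAD = 2, 4, 128, 64
q = torch.randn(Z, H, N_CTX, D_HEAD, dtype=torch.float16)
k, v, o, do = (torch.randn_like(q) for _ in range(4))

# Contiguous (Z, H, N_CTX, D_HEAD) tensors all share the same strides, so one
# (stride_z, stride_h, stride_tok, stride_d) tuple can address every tensor.
assert do.is_contiguous()
assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
print(q.stride())  # (H * N_CTX * D_HEAD, N_CTX * D_HEAD, D_HEAD, 1)
```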
@@ -764,10 +490,10 @@ def backward(ctx, do):
         # Alternatively we could compute a new grid but this keeps it consistent
         # with fwd and easier to reason about.
         block_scale = (q.shape[2] // ctx.grid[0]) // BLOCK
-        _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
-            o, do,
-            do_scaled, delta,
-            BLOCK_M=block_scale * BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
+        _attn_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
+            o, do,  #
+            do_scaled, delta,  #
+            BLOCK_M=block_scale * BLOCK, D_HEAD=ctx.BLOCK_DMODEL,  #
         )
         if not ctx.split_kernel:
             _bwd_kernel[(ctx.grid[1],)](
@@ -887,15 +613,15 @@ def test_op_bwd(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     tri_dk, k.grad = k.grad.clone(), None
     tri_dq, q.grad = q.grad.clone(), None
     # compare
-    assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0)
+    torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=0)
     if torch.version.hip is None:
-        assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=0)
+        torch.testing.assert_close(ref_dv, tri_dv, atol=1e-2, rtol=0)
     # The current block size for MI200 series is 64x64. This results in
     # larger differences in float results due to rounding.
     else:
-        assert torch.allclose(ref_dv, tri_dv, atol=5e-2, rtol=0)
-        assert torch.allclose(ref_dk, tri_dk, atol=5e-2, rtol=0)
-        assert torch.allclose(ref_dq, tri_dq, atol=5e-2, rtol=0)
+        torch.testing.assert_close(ref_dv, tri_dv, atol=5e-2, rtol=0)
+        torch.testing.assert_close(ref_dk, tri_dk, atol=5e-2, rtol=1e-2)
+        torch.testing.assert_close(ref_dq, tri_dq, atol=5e-2, rtol=1e-2)


 try:
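For reference, `torch.testing.assert_close` applies the same elementwise criterion as `torch.allclose`, namely `|actual - expected| <= atol + rtol * |expected|`, but raises an `AssertionError` with a mismatch report instead of returning a bool. A minimal sketch of the tolerances used in the updated test (the example values are made up):

```python
# Minimal sketch of the tolerance semantics used by test_op_bwd above; the
# example values are made up.
import torch

ref = torch.tensor([1.000, 2.000, 3.000])
tri = torch.tensor([1.004, 2.015, 3.040])

# Each element must satisfy |tri - ref| <= atol + rtol * |ref|.
torch.testing.assert_close(tri, ref, atol=5e-2, rtol=1e-2)

# torch.allclose uses the same criterion but only returns a bool; assert_close
# additionally reports which elements mismatch and by how much.
print(torch.allclose(tri, ref, atol=5e-2, rtol=1e-2))  # True
```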