diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py
index 027cdc31e289..b69ff61689a1 100644
--- a/python/tutorials/06-fused-attention.py
+++ b/python/tutorials/06-fused-attention.py
@@ -22,27 +22,109 @@ def max_fn(x, y):
     return tl.math.max(x, y)


+@triton.jit
+def _attn_fwd_inner(
+    acc, l_i, m_i, q,
+    K_block_ptr, V_block_ptr,
+    start_m,
+    BLOCK_M: tl.constexpr,
+    BLOCK_DMODEL: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    STAGE: tl.constexpr,
+    offs_m: tl.constexpr,
+    offs_n: tl.constexpr,
+    N_CTX,
+    pre_load_v: tl.constexpr,
+):
+    # range of values handled by this stage
+    if STAGE == 1:
+        lo, hi = 0, start_m * BLOCK_M
+    elif STAGE == 2:
+        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
+        lo = tl.multiple_of(lo, BLOCK_M)
+        K_block_ptr = tl.advance(K_block_ptr, (0, lo))
+        V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
+    # causal = False
+    else:
+        lo, hi = 0, N_CTX
+    # loop over k, v and update accumulator
+    for start_n in range(lo, hi, BLOCK_N):
+        start_n = tl.multiple_of(start_n, BLOCK_N)
+        # -- compute qk ----
+        k = tl.load(K_block_ptr)
+        if pre_load_v:
+            v = tl.load(V_block_ptr)
+        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+        if STAGE == 2:
+            mask = offs_m[:, None] >= (start_n + offs_n[None, :])
+            qk = tl.where(mask, qk, float("-inf"))
+        qk += tl.dot(q, k)
+        m_ij = tl.maximum(m_i, tl.max(qk, 1))
+        qk = qk - m_ij[:, None]
+        p = tl.math.exp2(qk)
+        # -- update output accumulator --
+        alpha = tl.math.exp2(m_i - m_ij)
+        acc = acc * alpha[:, None]
+        if not pre_load_v:
+            v = tl.load(V_block_ptr)
+        acc += tl.dot(p.to(tl.float16), v)
+        # -- update m_i and l_i
+        l_ij = tl.sum(p, 1)
+        l_i = l_i * alpha + l_ij
+        # update m_i and l_i
+        m_i = m_ij
+        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
+        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
+    return acc, l_i, m_i
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 0, 'pre_load_v': True}, num_stages=1, num_warps=4),
+        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'pre_load_v': True}, num_stages=1, num_warps=4),
+        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': True}, num_stages=1, num_warps=4),
+        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': True}, num_stages=1, num_warps=4),
+        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 4, 'pre_load_v': True}, num_stages=1, num_warps=4),
+        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 0, 'pre_load_v': True}, num_stages=0, num_warps=4),
+        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'pre_load_v': True}, num_stages=0, num_warps=4),
+        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': True}, num_stages=0, num_warps=4),
+        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': True}, num_stages=0, num_warps=4),
+        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 4, 'pre_load_v': True}, num_stages=0, num_warps=4),
+        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 0, 'pre_load_v': False}, num_stages=1, num_warps=4),
+        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'pre_load_v': False}, num_stages=1, num_warps=4),
+        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=4),
+        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': False}, num_stages=1, num_warps=4),
+        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 4, 'pre_load_v': False}, num_stages=1, num_warps=4),
+        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 0, 'pre_load_v': False}, num_stages=0, num_warps=4),
+        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'pre_load_v': False}, num_stages=0, num_warps=4),
+        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=0, num_warps=4),
+        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'pre_load_v': False}, num_stages=0, num_warps=4),
+        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 4, 'pre_load_v': False}, num_stages=0, num_warps=4),
+    ],
+    key=['N_CTX', 'STAGE'],
+)
+
 @triton.jit
-def _fwd_kernel(
-    Q, K, V, sm_scale,
-    L,
-    Out,
+def _attn_fwd(
+    Q, K, V, sm_scale, M, Out,
     stride_qz, stride_qh, stride_qm, stride_qk,
     stride_kz, stride_kh, stride_kn, stride_kk,
     stride_vz, stride_vh, stride_vk, stride_vn,
     stride_oz, stride_oh, stride_om, stride_on,
-    Z, H, N_CTX, P_SEQ,
-    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
+    Z, H,
+    N_CTX,
+    STAGE: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
-    IS_CAUSAL: tl.constexpr,
+    pre_load_v: tl.constexpr,
 ):
     start_m = tl.program_id(0)
     off_hz = tl.program_id(1)
-    q_offset = off_hz * stride_qh
-    kv_offset = off_hz * stride_kh
+    qkv_offset = off_hz * stride_qh
     Q_block_ptr = tl.make_block_ptr(
-        base=Q + q_offset,
+        base=Q + qkv_offset,
         shape=(N_CTX, BLOCK_DMODEL),
         strides=(stride_qm, stride_qk),
         offsets=(start_m * BLOCK_M, 0),
@@ -50,16 +132,16 @@ def _fwd_kernel(
         order=(1, 0)
     )
     K_block_ptr = tl.make_block_ptr(
-        base=K + kv_offset,
-        shape=(BLOCK_DMODEL, N_CTX + P_SEQ),
+        base=K + qkv_offset,
+        shape=(BLOCK_DMODEL, N_CTX),
         strides=(stride_kk, stride_kn),
         offsets=(0, 0),
         block_shape=(BLOCK_DMODEL, BLOCK_N),
         order=(0, 1)
     )
     V_block_ptr = tl.make_block_ptr(
-        base=V + kv_offset,
-        shape=(N_CTX + P_SEQ, BLOCK_DMODEL),
+        base=V + qkv_offset,
+        shape=(N_CTX, BLOCK_DMODEL),
         strides=(stride_vk, stride_vn),
         offsets=(0, 0),
         block_shape=(BLOCK_N, BLOCK_DMODEL),
@@ -70,55 +152,53 @@ def _fwd_kernel(
     offs_n = tl.arange(0, BLOCK_N)
     # initialize pointer to m and l
     m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
-    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
+    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
     acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
     # scale sm_scale by log_2(e) and use
     # 2^x instead of exp in the loop because CSE and LICM
     # don't work as expected with `exp` in the loop
     qk_scale = sm_scale * 1.44269504
-    # load q: it will stay in SRAM throughout
+    # load q: it will stay in SRAM throughout on NV GPUs but in VGPRs on AMD GPUs
     q = tl.load(Q_block_ptr)
     q = (q * qk_scale).to(tl.float16)
-    # loop over k, v and update accumulator
-    lo = 0
-    hi = P_SEQ + (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX + P_SEQ
-    for start_n in range(lo, hi, BLOCK_N):
-        # -- load k, v --
-        k = tl.load(K_block_ptr)
-        v = tl.load(V_block_ptr)
-        # -- compute qk ---
-        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float16)
-        if IS_CAUSAL:
-            qk = tl.where(P_SEQ + offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
-        qk += tl.dot(q, k)
-        # -- compute scaling constant ---
-        m_i_new = tl.maximum(m_i, tl.max(qk, 1))
-        alpha = tl.math.exp2(m_i - m_i_new)
-        p = tl.math.exp2(qk - m_i_new[:, None])
-        # -- scale and update acc --
-        acc_scale = l_i * 0 + alpha  # workaround some compiler bug
-        acc *= acc_scale[:, None]
-        acc += tl.dot(p.to(tl.float16), v)
-        # -- update m_i and l_i --
-        l_i = l_i * alpha + tl.sum(p, 1)
-        m_i = m_i_new
-        # update pointers
-        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
-        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
-    # write back l and m
+    # stage 1: off-band
+    # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
+    # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
+    if STAGE & 1:
+        acc, l_i, m_i = _attn_fwd_inner(
+            acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
+            start_m,
+            BLOCK_M, BLOCK_DMODEL, BLOCK_N,
+            4 - STAGE, offs_m, offs_n,
+            N_CTX, pre_load_v,
+        )
+    # stage 2: on-band
+    if STAGE & 2:
+        # barrier makes it easier for compiler to schedule the
+        # two loops independently
+        tl.debug_barrier()
+        acc, l_i, m_i = _attn_fwd_inner(
+            acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
+            start_m,
+            BLOCK_M, BLOCK_DMODEL, BLOCK_N,
+            2, offs_m, offs_n,
+            N_CTX, pre_load_v,
+        )
+    # epilogue
+    # write back m
     acc = acc / l_i[:, None]
-    l_ptrs = L + off_hz * N_CTX + offs_m
-    tl.store(l_ptrs, m_i + tl.math.log2(l_i))
+    m_ptrs = M + off_hz * N_CTX + offs_m
+    tl.store(m_ptrs, m_i + tl.math.log2(l_i))
     # write back O
     O_block_ptr = tl.make_block_ptr(
-        base=Out + q_offset,
+        base=Out + qkv_offset,
         shape=(N_CTX, BLOCK_DMODEL),
         strides=(stride_om, stride_on),
         offsets=(start_m * BLOCK_M, 0),
         block_shape=(BLOCK_M, BLOCK_DMODEL),
         order=(1, 0)
     )
-    tl.store(O_block_ptr, acc.to(tl.float16))
+    tl.store(O_block_ptr, acc.to(Out.type.element_ty))


 @triton.jit
@@ -455,42 +535,43 @@ def forward(ctx, q, k, v, causal, sm_scale, split_kernel=False):
         assert Lq == Lk and Lk == Lv
         assert Lk in {16, 32, 64, 128}
         o = torch.empty_like(q)
-        BLOCK_M = 128
         if torch.version.hip is None:
+            BLOCK_M = 128
             BLOCK_N = 64 if Lk <= 64 else 32
             num_stages = 4 if Lk <= 64 else 3
             num_warps = 4 if Lk <= 64 else 8
-        else:
-            BLOCK_N = 64
-            num_warps = 4
-            num_stages = 1
-            waves_per_eu = 2 if causal else 3
-
-        grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
-        L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
-        P_SEQ = 0 if q.shape[-2] == k.shape[-2] else k.shape[-2] - q.shape[-2]
-
-        _fwd_kernel[grid](
-            q, k, v, sm_scale,
-            L,
-            o,
+
+        stage = 3 if causal else 1
+        grid = lambda META: (
+            triton.cdiv(q.shape[2], META['BLOCK_M']),
+            q.shape[0] * q.shape[1],
+            1
+        )
+        M = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
+
+        _attn_fwd[grid](
+            q, k, v, sm_scale, M, o,
             q.stride(0), q.stride(1), q.stride(2), q.stride(3),
             k.stride(0), k.stride(1), k.stride(2), k.stride(3),
             v.stride(0), v.stride(1), v.stride(2), v.stride(3),
             o.stride(0), o.stride(1), o.stride(2), o.stride(3),
-            q.shape[0], q.shape[1], q.shape[2], P_SEQ,
-            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk,
-            IS_CAUSAL=causal,
-            num_warps=num_warps,
-            num_stages=num_stages, waves_per_eu=waves_per_eu)
+            q.shape[0], q.shape[1],
+            N_CTX=q.shape[2],
+            BLOCK_DMODEL=Lk,
+            STAGE=stage,
+        )
+
+        ## restore the grid for bwd kernel
+        best_config = _attn_fwd.get_best_config(N_CTX = q.shape[2], STAGE = stage)
+        block_m = int(best_config.__str__().split(",")[0].split("BLOCK_M:")[1])
+        grid = (triton.cdiv(q.shape[2], block_m), q.shape[0] * q.shape[1], 1)

-        ctx.save_for_backward(q, k, v, o, L)
+        ctx.save_for_backward(q, k, v, o, M)
         ctx.grid = grid
         ctx.sm_scale = sm_scale
         ctx.BLOCK_DMODEL = Lk
         ctx.causal = causal
         ctx.split_kernel = split_kernel
-        ctx.P_SEQ = P_SEQ
         return o

     @staticmethod
@@ -570,23 +651,35 @@ def backward(ctx, do):
 attention = _attention.apply


-@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD, P_SEQ',
-                         [(4, 48, 1024, 64, 128),
-                          (4, 48, 2048, 64, 128),
-                          (4, 48, 4096, 64, 128),
-                          (4, 48, 8192, 64, 128),
-                          (4, 48, 16384, 64, 128)
-                          ])
+@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD',
+                         [(4, 48, 1024, 64),
+                          (4, 48, 2048, 64),
+                          (4, 48, 4096, 64),
+                          #(4, 48, 8192, 64),
+                          #(4, 48, 16384, 64)
+                          ])
 @pytest.mark.parametrize('causal', [False, True])
-def test_op_fwd(Z, H, N_CTX, D_HEAD, P_SEQ, causal, dtype=torch.float16):
+def test_op_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16):
     torch.manual_seed(20)
-    q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
-    k = torch.empty((Z, H, N_CTX + P_SEQ, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
-    v = torch.empty((Z, H, N_CTX + P_SEQ, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
-    sm_scale = q.shape[-1] ** (-0.5)
+    q = (
+        torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
+        .normal_(mean=0., std=0.5)
+        .requires_grad_()
+    )
+    k = (
+        torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
+        .normal_(mean=0., std=0.5)
+        .requires_grad_()
+    )
+    v = (
+        torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
+        .normal_(mean=0., std=0.5)
+        .requires_grad_()
+    )
+    sm_scale = 0.5
     dout = torch.randn_like(q)
     # reference implementation
-    M = torch.tril(torch.ones((N_CTX, N_CTX + P_SEQ), device="cuda"), diagonal=P_SEQ)
+    M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
     p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
     if causal:
         p[:, :, M == 0] = float("-inf")
@@ -598,23 +691,23 @@ def test_op_fwd(Z, H, N_CTX, D_HEAD, P_SEQ, causal, dtype=torch.float16):
     assert torch.allclose(ref_out, tri_out, atol=1e-2, rtol=0)


-@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD, P_SEQ',
-                         [(4, 48, 1024, 64, 0),
-                          (4, 48, 2048, 64, 0),
-                          (4, 48, 4096, 64, 0),
-                          (1, 16, 8192, 64, 0),
-                          ])
+@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD',
+                         [(4, 48, 1024, 64),
+                          (4, 48, 2048, 64),
+                          (4, 48, 4096, 64),
+                          (1, 16, 8192, 64),
+                          ])
-def test_op_bwd(Z, H, N_CTX, D_HEAD, P_SEQ, dtype=torch.float16):
+def test_op_bwd(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     torch.manual_seed(20)
     causal = True
     q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
-    k = torch.empty((Z, H, N_CTX + P_SEQ, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
-    v = torch.empty((Z, H, N_CTX + P_SEQ, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
-    sm_scale = q.shape[-1] ** (-0.5)
+    k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
+    v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
+    sm_scale = 0.5
     split_kernel = True
     dout = torch.randn_like(q)
     # reference implementation
-    M = torch.tril(torch.ones((N_CTX, N_CTX + P_SEQ), device="cuda"), diagonal=P_SEQ)
+    M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
     p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
     if causal:
         p[:, :, M == 0] = float("-inf")
@@ -656,17 +749,28 @@ def test_op_bwd(Z, H, N_CTX, D_HEAD, P_SEQ, dtype=torch.float16):
 BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
 # vary seq length for fixed head and batch=4
-configs = [triton.testing.Benchmark(
-    x_names=['N_CTX'],
-    x_vals=[2**i for i in range(10, 15)],
-    line_arg='provider',
-    line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
-    line_names=['Triton'] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []),
-    styles=[('red', '-'), ('blue', '-')],
-    ylabel='ms',
-    plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
-    args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode, 'causal': causal}
-) for mode in ['fwd', 'bwd'] for causal in [False, True]]
+configs = []
+for mode in ['fwd', 'bwd']:
+    for causal in [False, True]:
+        if mode == 'bwd' and causal == False:
+            continue
+        configs.append(triton.testing.Benchmark(
+            x_names=['N_CTX'],
+            x_vals=[2**i for i in range(10, 15)],
+            line_arg='provider',
+            line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
+            line_names=['Triton'] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []),
+            styles=[('red', '-'), ('blue', '-')],
+            ylabel='ms',
+            plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-causal={causal}',
+            args={
+                'H': N_HEADS,
+                'BATCH': BATCH,
+                'D_HEAD': D_HEAD,
+                'dtype': torch.float16,
+                'mode': mode,
+                'causal': causal})
+        )


 @triton.testing.perf_report(configs)
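# --- Editor's note: the snippet below is NOT part of the patch. It is a
# standalone NumPy sketch of the online-softmax (flash-attention) update that
# _attn_fwd_inner performs per BLOCK_N tile of K/V. It uses plain exp() rather
# than the kernel's exp2(x * log2(e)) rewrite; the two are equivalent because
# exp(x) == 2**(x * 1.44269504). All names here are local to the example.
import numpy as np

def online_softmax_attention(q, k, v, block_n):
    """q: (M, d), k/v: (N, d). Returns softmax(q @ k.T) @ v computed tile by tile."""
    m_rows = q.shape[0]
    m_i = np.full(m_rows, -np.inf)        # running row-wise max of the scores
    l_i = np.zeros(m_rows)                # running row-wise sum of exp(score - m_i)
    acc = np.zeros((m_rows, v.shape[1]))  # unnormalized output accumulator
    for start in range(0, k.shape[0], block_n):
        kb = k[start:start + block_n]
        vb = v[start:start + block_n]
        qk = q @ kb.T                                  # scores for this tile
        m_ij = np.maximum(m_i, qk.max(axis=1))         # new running max
        p = np.exp(qk - m_ij[:, None])                 # tile weights, shifted by new max
        alpha = np.exp(m_i - m_ij)                     # rescale factor for the old state
        acc = acc * alpha[:, None] + p @ vb
        l_i = l_i * alpha + p.sum(axis=1)
        m_i = m_ij
    return acc / l_i[:, None]                          # normalize once in the epilogue

# Quick check against the direct computation:
rng = np.random.default_rng(0)
q = rng.standard_normal((8, 16)); k = rng.standard_normal((32, 16)); v = rng.standard_normal((32, 16))
s = q @ k.T
e = np.exp(s - s.max(axis=1, keepdims=True))
assert np.allclose(online_softmax_attention(q, k, v, block_n=8), (e / e.sum(axis=1, keepdims=True)) @ v)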
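# --- Editor's note: also not part of the patch. A small illustration of how the
# STAGE values passed by _attn_fwd map to the key/value ranges that
# _attn_fwd_inner scans; the helper name below is local to this example.
def inner_loop_bounds(stage, start_m, BLOCK_M, N_CTX):
    """Mirror of the lo/hi selection at the top of _attn_fwd_inner."""
    if stage == 1:        # causal, off-band: key blocks strictly below the query block
        return 0, start_m * BLOCK_M
    elif stage == 2:      # causal, on-band: the diagonal block, masked element-wise
        return start_m * BLOCK_M, (start_m + 1) * BLOCK_M
    else:                 # stage == 3: non-causal, one unmasked pass over everything
        return 0, N_CTX

# _attn_fwd passes (4 - STAGE) for the STAGE & 1 pass and 2 for the STAGE & 2 pass:
#   causal=True  => STAGE=3: off-band pass with stage 1, then on-band pass with stage 2
#   causal=False => STAGE=1: a single unmasked pass with stage 3
for causal, STAGE in [(True, 3), (False, 1)]:
    passes = []
    if STAGE & 1:
        passes.append(inner_loop_bounds(4 - STAGE, start_m=2, BLOCK_M=128, N_CTX=1024))
    if STAGE & 2:
        passes.append(inner_loop_bounds(2, start_m=2, BLOCK_M=128, N_CTX=1024))
    print(causal, passes)   # True [(0, 256), (256, 384)]; False [(0, 1024)]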