From 26517a9c3aca71517d95f1bf33f4c7bea0dd2eef Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Wed, 14 Aug 2024 12:39:29 -0500 Subject: [PATCH 1/7] add bad config --- python/perf-kernels/flash-attention.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/perf-kernels/flash-attention.py b/python/perf-kernels/flash-attention.py index 988438340abe..3ad870316d44 100644 --- a/python/perf-kernels/flash-attention.py +++ b/python/perf-kernels/flash-attention.py @@ -1069,6 +1069,7 @@ def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, equal_seqlen @pytest.mark.parametrize('Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD', [ + (1, 1, 1, 512, 256, 160), (4, 48, 24, 1024, 1024, 64), (1, 24, 6, 8192, 8192, 64), (1, 4, 2, 16384, 16384, 128), From 1c1d5f6d95e3dbdb8f2aae830c099c0726356b8c Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Wed, 14 Aug 2024 14:40:48 -0500 Subject: [PATCH 2/7] save --- python/perf-kernels/flash-attention.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/perf-kernels/flash-attention.py b/python/perf-kernels/flash-attention.py index 3ad870316d44..8cbb5fa7a30c 100644 --- a/python/perf-kernels/flash-attention.py +++ b/python/perf-kernels/flash-attention.py @@ -1069,7 +1069,11 @@ def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, equal_seqlen @pytest.mark.parametrize('Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD', [ + (1, 1, 1, 512, 256, 128), + (1, 1, 1, 256, 128, 160), (1, 1, 1, 512, 256, 160), + (1, 1, 1, 1024, 512, 160), + (1, 1, 1, 512, 256, 256), (4, 48, 24, 1024, 1024, 64), (1, 24, 6, 8192, 8192, 64), (1, 4, 2, 16384, 16384, 128), From 55082af29898b644d11791730a73f7dcc83ec99b Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Mon, 19 Aug 2024 11:14:07 -0500 Subject: [PATCH 3/7] save stuff --- python/perf-kernels/flash-attention.py | 320 ++++++++++++++++++++----- scripts | 1 + 2 files changed, 263 insertions(+), 58 deletions(-) create mode 160000 scripts diff --git a/python/perf-kernels/flash-attention.py b/python/perf-kernels/flash-attention.py index 8cbb5fa7a30c..917f04f57794 100644 --- a/python/perf-kernels/flash-attention.py +++ b/python/perf-kernels/flash-attention.py @@ -21,6 +21,7 @@ """ import argparse +from typing import Tuple import pytest import sys import torch @@ -28,6 +29,8 @@ import triton import triton.language as tl +DEBUG = True + class MetaData(): cu_seqlens_q = None @@ -40,8 +43,37 @@ class MetaData(): num_contexts = 0 varlen = False layout = None + cache_seqlens = None + cache_batch_idx = None + new_kv = False + seqlen_new = None + k_new = None + v_new = None dropout_p, return_encoded_softmax = 0.0, False + def __repr__(self) -> str: + return (f"MetaData(\n" + f" sm_scale={self.sm_scale},\n" + f" cu_seqlens_q={self.cu_seqlens_q},\n" + f" cu_seqlens_k={self.cu_seqlens_k},\n" + f" max_seqlens_q={self.max_seqlens_q},\n" + f" max_seqlens_k={self.max_seqlens_k},\n" + f" bias={self.bias},\n" + f" alibi_slopes={self.alibi_slopes},\n" + f" causal={self.causal},\n" + f" num_contexts={self.num_contexts},\n" + f" varlen={self.varlen},\n" + f" layout={self.layout},\n" + f" cache_seqlens={self.cache_seqlens},\n" + f" cache_batch_idx={self.cache_batch_idx},\n" + f" new_kv={self.new_kv},\n" + f" seqlen_new={self.seqlen_new},\n" + f" k_new={self.k_new},\n" + f" v_new={self.v_new},\n" + f" dropout_p={self.dropout_p},\n" + f" return_encoded_softmax={self.return_encoded_softmax}\n" + f")") + def __init__(self, sm_scale=1.0): self.sm_scale = sm_scale @@ -158,13 +190,6 @@ def load_fn(ptrs, offset_first, offset_second, 
boundary_first, boundary_second): return tensor -@triton.jit -def print_gpu(prefix, val=None): - if (tl.program_id(0) == 0) and ((tl.program_id(1) == 0) and (tl.program_id(2) == 0)): - if val is not None: - tl.device_print(prefix, val) - else: - tl.device_print(prefix) @triton.jit @@ -243,12 +268,34 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri size_n = start_n + OFFS_N[None, :] mask = size_n < boundary_m[:, None] qk = tl.where(mask, qk, float("-inf")) + + if tl.program_id(0) == 4: + tl.device_print("q:", q) + # print("q:", q) + # tl.device_print("k:", k) + + # -- compute qk ---- + qk += tl.dot(q, k) + + # if tl.program_id(0) == 4: + # tl.device_print("qk:", qk) + if IS_CAUSAL: causal_boundary = start_n + offs_n_causal causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] + # if tl.program_id(0) == 4: + # tl.device_print("causal_mask:", causal_mask) + + # if tl.program_id(0) == 4: + # tl.device_print("qk max before causal:", tl.max(qk)) + # tl.device_print("qk min before causal:", tl.min(qk)) + qk = tl.where(causal_mask, qk, float("-inf")) - # -- compute qk ---- - qk += tl.dot(q, k) + + # if tl.program_id(0) == 4: + # tl.device_print("qk max after causal:", tl.max(qk)) + # tl.device_print("qk min after causal:", tl.min(qk)) + if bias_ptrs is not None: bias_offs_n = start_n + tl.arange(0, BLOCK_N) if MASK_STEPS else None bias = load_fn(bias_ptrs, OFFS_M, bias_offs_n, actual_seqlen_q, actual_seqlen_k) @@ -301,32 +348,49 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri @triton.autotune( configs=[ - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - # Fall-back config. - triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=1), + # triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + # num_warps=8), + # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + # num_warps=4), + # triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + # num_warps=8), + # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': True}, num_stages=1, + # num_warps=4), + # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1, + # num_warps=4), + # triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, + # num_warps=8), + # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, + # num_warps=4), + # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, + # num_warps=8), + # TODO: This config fails with head_size not pow2 with data mismatches. Check why. 
+ # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + # triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, + # num_warps=4), ], - key=['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'], + key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'], use_cuda_graph=True, ) @triton.jit -def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, L, Out, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, - stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, stride_oz, stride_oh, - stride_om, stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah, cu_seqlens_q, - cu_seqlens_k, dropout_p, philox_seed, philox_offset_base, encoded_softmax, alibi_slopes, HQ: tl.constexpr, +def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, + stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, stride_oz, stride_oh, stride_om, + stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah, cu_seqlens_q, cu_seqlens_k, + dropout_p, philox_seed, philox_offset_base, encoded_softmax, alibi_slopes, HQ: tl.constexpr, HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr, MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr, USE_ALIBI: tl.constexpr): + + # if tl.program_id(0) == 4: + # tl.device_print("tl.program_id(0):", tl.program_id(0)) + # if tl.program_id(0) == 4: + # tl.device_print("BLOCK_M:", BLOCK_M) + # tl.device_print("BLOCK_N:", BLOCK_N) + # tl.device_print("BLOCK_DMODEL:", BLOCK_DMODEL) + # tl.device_print("ACTUAL_BLOCK_DMODEL:", ACTUAL_BLOCK_DMODEL) + start_m = tl.program_id(0) off_h_q = tl.program_id(1) off_z = tl.program_id(2) @@ -400,6 +464,8 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, L, Out, stride_qz, stride_qh elif seqlen_k % BLOCK_N: n_extra_tokens = seqlen_k % BLOCK_N PADDED_HEAD: tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL) + # if tl.program_id(0) == 4: + # tl.device_print("PADDED_HEAD", PADDED_HEAD) # Compute pointers for all the tensors used in this kernel. q_offset = Q + off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm @@ -439,13 +505,13 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, L, Out, stride_qz, stride_qh acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) # scale sm_scale by log_2(e) and use 2^x in the loop as we do not # have native e^x support in HW. - QK_SCALE: tl.constexpr = SM_SCALE * 1.44269504089 + qk_scale = sm_scale * 1.44269504089 # Q is loaded once at the beginning and shared by all N blocks. q_ptrs_mask = offs_m[:, None] < seqlen_q if PADDED_HEAD: q_ptrs_mask = q_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0) - q = (q * QK_SCALE).to(q.type.element_ty) + q = (q * qk_scale).to(q.type.element_ty) # Here we compute how many full and masked blocks we have. 
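# NOTE: a minimal, runnable sketch (plain PyTorch, illustration only, not part
# of the patch) of the identity behind the qk_scale factor above: since
# exp(x) == 2**(x * log2(e)), pre-scaling q by sm_scale * 1.44269504089 lets
# the kernel use the hardware-friendly exp2 while still computing a base-e
# softmax.
import torch

x = torch.randn(8)
log2e = 1.4426950408889634
ref = torch.softmax(x, dim=0)
p = torch.exp2(x * log2e - (x * log2e).max())  # equals exp(x - x.max())
assert torch.allclose(ref, p / p.sum(), atol=1e-6)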
padded_block_k = n_extra_tokens != 0 @@ -502,10 +568,7 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, L, Out, stride_qz, stride_qh PRE_LOAD_V, True, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, PADDED_HEAD, ACTUAL_BLOCK_DMODEL) # epilogue - # This helps the compiler do Newton Raphson on l_i vs on acc which is much larger. - l_recip = 1 / l_i[:, None] - acc = acc * l_recip - + acc = acc / l_i[:, None] if ENABLE_DROPOUT: acc = acc / (1 - dropout_p) # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, @@ -518,6 +581,10 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, L, Out, stride_qz, stride_qh acc = acc.to(Out.type.element_ty) if IS_CAUSAL: if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: + # if tl.program_id(0) == 4: + # tl.device_print("causal_start_idx:", causal_start_idx) + # tl.device_print("start_m_idx:", start_m_idx) + # tl.device_print("end_m_idx:", end_m_idx) out_mask_boundary = tl.full((BLOCK_DMODEL, ), causal_start_idx, dtype=tl.int32) mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :] @@ -822,10 +889,6 @@ def _attn_bwd(Q, K, V, sm_scale, alibi_slopes, DO, DQ, DK, DV, M, D, dq *= LN2 tl.store(DQ_block_ptr, dq.to(q.dtype)) - -empty = torch.empty(128, device="cuda") - - def get_shape_from_layout(q, k, metadata): if metadata.layout == 'thd': nheads_q, nheads_k = q.shape[1], k.shape[1] @@ -864,10 +927,19 @@ def get_strides_from_layout(q, k, v, o, metadata): return q_strides, k_strides, v_strides, o_strides -class _attention(torch.autograd.Function): +class _attention_prefill(torch.autograd.Function): @staticmethod def forward(ctx, q, k, v, o, metadata): + if DEBUG: + print() + print("_attention_prefill.forward") + print("q:", q, q.shape) + print("k:", k, k.shape) + print("v:", v, v.shape) + print("o:", o, o.shape) + print("metadata:", metadata) + # NOTE: a large bias tensor leads to overflow during pointer arithmetic if (metadata.bias is not None): assert (metadata.bias.numel() < 2**31) @@ -884,6 +956,9 @@ def forward(ctx, q, k, v, o, metadata): # Smallest head_dim supported is 16. If smaller, the tile in the # kernel is padded - there is no padding in memory for any dims. 
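# NOTE: a small sketch (hypothetical helper, illustration only, not part of the
# patch) of the padding rule described above: the in-kernel tile is padded up
# to the next power of two with a floor of 16, while the tensors in memory keep
# their original head_dim.
import triton

def padded_head_dim(head_size: int) -> int:
    # e.g. 8 -> 16, 64 -> 64, 160 -> 256; only the compute tile is padded
    return max(triton.next_power_of_2(head_size), 16)

assert padded_head_dim(160) == 256 and padded_head_dim(8) == 16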
         padded_d_model = max(padded_d_model, 16)
+        if DEBUG:
+            print("head_size:", head_size)
+            print("padded_d_model:", padded_d_model)
 
         grid = lambda META: (triton.cdiv(metadata.max_seqlens_q, META['BLOCK_M']), nheads_q, batch)
 
@@ -913,7 +988,9 @@ def forward(ctx, q, k, v, o, metadata):
             alibi_strides = (metadata.alibi_slopes.stride(0), metadata.alibi_slopes.stride(1))
         else:
             alibi_strides = (0, 0)
-
+
+        # if DEBUG:
+        #     print("grid:", grid)
         attn_fwd[grid](q, k, v, metadata.bias, metadata.sm_scale, M, o, *q_strides, *k_strides, *v_strides,
                        *o_strides, *bias_strides, *alibi_strides, metadata.cu_seqlens_q, metadata.cu_seqlens_k,
                        dropout_p=metadata.dropout_p, philox_seed=philox_seed, philox_offset_base=philox_offset,
@@ -924,6 +1001,14 @@ def forward(ctx, q, k, v, o, metadata):
                        USE_ALIBI=False if metadata.alibi_slopes is None else True,
                        ENABLE_DROPOUT=metadata.dropout_p > 0.0,
                        RETURN_ENCODED_SOFTMAX=metadata.return_encoded_softmax)
+        if DEBUG:
+            print("Saving in ctx for backward")
+            print("q:", q, q.shape)
+            print("k:", k, k.shape)
+            print("v:", v, v.shape)
+            print("o:", o, o.shape)
+            print("M:", M, M.shape)
+
         ctx.save_for_backward(q, k, v, o, M)
         ctx.grid = grid
         ctx.sm_scale = metadata.sm_scale
@@ -938,18 +1023,34 @@ def forward(ctx, q, k, v, o, metadata):
         return o, encoded_softmax
 
     @staticmethod
-    def backward(ctx, do, _):
+    def backward(ctx, do, _):  # expects bhsd
+        if DEBUG:
+            print()
+            print("_attention_prefill.backward")
+            print("do:", do, do.shape, do.stride())
+
         if torch.version.hip is not None:
             BLOCK = 64
         else:
             BLOCK = 128
         q, k, v, o, M = ctx.saved_tensors
+        if DEBUG:
+            print("q:", q, q.shape, q.stride())
+            print("k:", k, k.shape, k.stride())
+            print("v:", v, v.shape, v.stride())
+            print("o:", o, o.shape, o.stride())
+            print("M:", M, M.shape, M.stride())
         assert do.is_contiguous()
         assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
         seqlen_q = q.shape[2]
-        dq = torch.empty_like(q)
-        dk = torch.empty_like(k)
-        dv = torch.empty_like(v)
+        if True:
+            dq = torch.zeros_like(q)
+            dk = torch.zeros_like(k)
+            dv = torch.zeros_like(v)
+        else:
+            dq = torch.empty_like(q)
+            dk = torch.empty_like(k)
+            dv = torch.empty_like(v)
         BATCH, N_HEAD, N_CTX = q.shape[:3]
         PRE_BLOCK = 128
         # NUM_WARPS, NUM_STAGES = 4, 1
@@ -958,11 +1059,21 @@ def backward(ctx, do, _):
         RCP_LN2 = 1.4426950408889634  # = 1.0 / ln(2)
         arg_k = k
         arg_k = arg_k * (ctx.sm_scale * RCP_LN2)
-        assert N_CTX % PRE_BLOCK == 0
-        delta = torch.empty_like(M)
+        # assert N_CTX % PRE_BLOCK == 0
+        if True:
+            delta = torch.zeros_like(M)
+        else:
+            delta = torch.empty_like(M)
         _, Lk, _ = q.shape[-1], k.shape[-1], v.shape[-1]
         # padded_head = (Lk != ctx.BLOCK_DMODEL)
         grid_preprocess = (triton.cdiv(do.shape[2], BLOCK), do.shape[1], do.shape[0])
+        if DEBUG:
+            print()
+            print("_attn_bwd_preprocess input")
+            print("o:", o, o.shape)
+            print("do:", do, do.shape)
+            print("delta:", delta, delta.shape)
+
         _attn_bwd_preprocess[grid_preprocess](
             o,
             do,
@@ -981,6 +1092,22 @@ def backward(ctx, do, _):
             D_HEAD=ctx.BLOCK_DMODEL,
         )
         grid = lambda META: (triton.cdiv(N_CTX, META['BLOCK_N1']), 1, BATCH * N_HEAD)
+        if DEBUG:
+            print()
+            print("_attn_bwd input")
+            print("q:", q, q.shape)
+            print("arg_k:", arg_k, arg_k.shape)
+            print("v:", v, v.shape)
+            print("ctx.sm_scale:", ctx.sm_scale)
+            print("ctx.alibi_slopes:", ctx.alibi_slopes)
+            print("do:", do, do.shape)
+            print("dq:", dq, dq.shape)
+            print("dk:", dk, dk.shape)
+            print("dv:", dv, dv.shape)
+            print("M:", M, M.shape)
+            print("delta:", delta, delta.shape)
+
+
         _attn_bwd[grid](
             q,
             arg_k,
@@ -1008,10 +1135,17 @@ def 
backward(ctx, do, _): USE_ALIBI=False if ctx.alibi_slopes is None else True, ) + if DEBUG: + print() + print("_attn_bwd output") + print("dq:", dq, dq.shape) + print("dk:", dk, dk.shape) + print("dv:", dv, dv.shape) + return dq, dk, dv, None, None -attention = _attention.apply +attention_prefill = _attention_prefill.apply def input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout): @@ -1068,6 +1202,70 @@ def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, equal_seqlen return q, k, v, input_metadata +def compare_tensors(out, ref, atol=2e-2, rtol=2e-2): + if out.shape != ref.shape: + print(f"Shapes mismatch: {out.shape} vs {ref.shape}") + return + + abs_diff = torch.abs(out - ref) + rel_diff = abs_diff / torch.abs(ref) + + abs_mask = abs_diff > atol + rel_mask = rel_diff > rtol + + mismatch_mask = torch.logical_or(abs_mask, rel_mask) + + if torch.any(mismatch_mask): + mismatch_indices = torch.nonzero(mismatch_mask) + print(f"Found {len(mismatch_indices)} mismatching elements:") + for idx in mismatch_indices: + idx_tuple = tuple(idx.tolist()) + print(f" Index {idx_tuple}:") + print(f" out: {out[idx_tuple].item()}") + print(f" ref: {ref[idx_tuple].item()}") + print(f" Absolute difference: {abs_diff[idx_tuple].item()}") + print(f" Relative difference: {rel_diff[idx_tuple].item()}") + else: + print("No mismatches found within the specified tolerance.") + + +def input_helper_increasing_seqlen(Z: int, HQ: int, HK: int, N_CTX_Q: int, N_CTX_K: int, D_HEAD: int, dtype: torch.dtype, layout: str) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, MetaData]: + torch.manual_seed(20) + + # Generate increasing values for the sequence dimension + q_seq_values = torch.arange(N_CTX_Q, dtype=dtype, device="cuda").unsqueeze(-1).expand(-1, D_HEAD) + k_seq_values = torch.arange(N_CTX_K, dtype=dtype, device="cuda").unsqueeze(-1).expand(-1, D_HEAD) + + # Initialize q, k, v with increasing sequence lengths + if layout == 'bhsd': + q = q_seq_values.unsqueeze(0).unsqueeze(0).expand(Z, HQ, -1, -1) + k = k_seq_values.unsqueeze(0).unsqueeze(0).expand(Z, HK, -1, -1) + elif layout == 'bshd': + q = q_seq_values.unsqueeze(0).unsqueeze(2).expand(Z, -1, HQ, -1) + k = k_seq_values.unsqueeze(0).unsqueeze(2).expand(Z, -1, HK, -1) + else: + assert False, 'Got unsupported tensor layout' + + v = k.clone() # v has the same shape as k + + # # Add some randomness to avoid exact equality + # q += torch.randn_like(q) * 0.1 + # k += torch.randn_like(k) * 0.1 + # v += torch.randn_like(v) * 0.1 + + q.requires_grad_(True) + k.requires_grad_(True) + v.requires_grad_(True) + + sm_scale = D_HEAD**-0.5 + input_metadata = MetaData(sm_scale=sm_scale) + input_metadata.max_seqlens_q = N_CTX_Q + input_metadata.max_seqlens_k = N_CTX_K + input_metadata.layout = layout + + return q, k, v, input_metadata + + @pytest.mark.parametrize('Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD', [ (1, 1, 1, 512, 256, 128), (1, 1, 1, 256, 128, 160), @@ -1096,7 +1294,10 @@ def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, equal_seqlen @pytest.mark.parametrize('layout', ['bshd', 'bhsd']) def test_op_fwd(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, layout, dtype=torch.float16): torch.manual_seed(20) - q, k, v, input_metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout) + if True: + q, k, v, input_metadata = input_helper_increasing_seqlen(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout) + else: + q, k, v, input_metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout) if causal: 
         input_metadata.need_causal()
 
@@ -1111,7 +1312,7 @@ def test_op_fwd(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, layout,
     o = torch.empty_like(q)
 
     # triton implementation
-    tri_out, _ = attention(q, k, v, o, input_metadata)
+    tri_out, _ = attention_prefill(q, k, v, o, input_metadata)
 
     # Transpose here if layout is bshd so we have same reference code for all layouts
     if layout == 'bshd':
@@ -1143,6 +1344,11 @@ def test_op_fwd(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, layout,
     # compare
     if layout == 'bshd':
         ref_out = ref_out.transpose(1, 2).clone()
+
+
+    print("tri_out:", tri_out)
+    print("ref_out:", ref_out)
+    # compare_tensors(tri_out, ref_out, atol=2e-2, rtol=2e-2)
     torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2)
 
 
@@ -1167,8 +1373,7 @@ def test_op_fwd(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, layout,
 ])
 @pytest.mark.parametrize('causal', [True, False])
 @pytest.mark.parametrize('use_bias', [True])
-@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
-def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype):
+def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype=torch.float16):
     torch.manual_seed(20)
     sm_scale = D_HEAD**-0.5
     input_metadata = MetaData(sm_scale=sm_scale)
@@ -1176,14 +1381,14 @@ def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype):
     if causal:
         input_metadata.need_causal()
     if use_bias:
-        bias = torch.randn((1, H, N_CTX_Q, N_CTX_K), dtype=dtype, device="cuda")
+        bias = torch.randn((1, H, N_CTX_Q, N_CTX_K), dtype=torch.float32, device="cuda")
        input_metadata.need_bias(bias, Z, H, N_CTX_Q, N_CTX_K)
     else:
         bias = None
     o = torch.empty_like(q)
 
     # triton implementation
-    tri_out, _ = attention(q, k, v, o, input_metadata)
+    tri_out, _ = attention_prefill(q, k, v, o, input_metadata)
 
     # reference implementation
     scores = torch.einsum('bhqd,bhkd->bhqk', q, k).float() * sm_scale
@@ -1199,8 +1404,7 @@ def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype):
     # this by converting the NaNs to 0s, which is what they should be out of the softmax.
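# NOTE: a tiny self-contained reproduction (plain PyTorch, not part of the
# patch) of the case handled below: a row whose scores are all -inf comes out
# of softmax as NaN, and zeroing those entries is exactly what a fully-masked
# row should contribute.
import torch

scores = torch.full((2, 4), float("-inf"))
p = torch.softmax(scores, dim=-1)  # every entry is nan
p[torch.isnan(p)] = 0.0            # fully-masked rows now contribute zero
assert (p == 0).all()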
nan_mask = torch.isnan(p) p[nan_mask == 1] = 0 - - ref_out = torch.einsum('bhqk,bhkd->bhqd', p.to(dtype), v) + ref_out = torch.einsum('bhqk,bhkd->bhqd', p.half(), v) # compare torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) @@ -1223,7 +1427,7 @@ def test_op_varlen_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16): scores = torch.einsum('qhd,khd->qhk', q[start_q:end_q], k[start_k:end_k]).float() p = torch.softmax(scores * input_metadata.sm_scale, dim=-1).half() ref_out[start_q:end_q] = torch.einsum('qhk,khd->qhd', p, v[start_k:end_k]) - attention(q, k, v, tri_out, input_metadata) + attention_prefill(q, k, v, tri_out, input_metadata) torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=1e-2) @@ -1251,7 +1455,7 @@ def test_op_varlen_mqa_fwd(Z, HQ, HK, N_CTX, D_HEAD, causal, dtype=torch.float16 scores = torch.einsum('qhd,khd->qhk', q[start_q:end_q], k_curr).float() p = torch.softmax(scores * input_metadata.sm_scale, dim=-1).half() ref_out[start_q:end_q] = torch.einsum('qhk,khd->qhd', p, v_curr) - attention(q, k, v, tri_out, input_metadata) + attention_prefill(q, k, v, tri_out, input_metadata) torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=1e-2) @@ -1460,7 +1664,7 @@ def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal if causal: input_metadata.need_causal() o = torch.empty_like(q) - fn = lambda: attention(q, k, v, o, input_metadata) + fn = lambda: attention_prefill(q, k, v, o, input_metadata) if mode == 'bwd': o, _ = fn() do = torch.randn_like(o) @@ -1532,4 +1736,4 @@ def main(): if __name__ == '__main__': - sys.exit(main()) + sys.exit(main()) \ No newline at end of file diff --git a/scripts b/scripts new file mode 160000 index 000000000000..dadae855cd6e --- /dev/null +++ b/scripts @@ -0,0 +1 @@ +Subproject commit dadae855cd6e97f13223ce0dae9f11e52c804360 From 9f68d12728680849b3aac9c68d223c3876016889 Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Mon, 19 Aug 2024 12:06:03 -0500 Subject: [PATCH 4/7] save --- python/perf-kernels/flash-attention.py | 46 ++++++++++++-------------- scripts | 2 +- 2 files changed, 22 insertions(+), 26 deletions(-) diff --git a/python/perf-kernels/flash-attention.py b/python/perf-kernels/flash-attention.py index 917f04f57794..990c6d2021d7 100644 --- a/python/perf-kernels/flash-attention.py +++ b/python/perf-kernels/flash-attention.py @@ -269,32 +269,26 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri mask = size_n < boundary_m[:, None] qk = tl.where(mask, qk, float("-inf")) - if tl.program_id(0) == 4: - tl.device_print("q:", q) - # print("q:", q) - # tl.device_print("k:", k) + # if tl.program_id(0) == 7: + # tl.device_print("q:", q) + # tl.device_print("k:", k) # -- compute qk ---- qk += tl.dot(q, k) - # if tl.program_id(0) == 4: - # tl.device_print("qk:", qk) + # if tl.program_id(0) == 7: + # tl.device_print("qk before causal:", qk) if IS_CAUSAL: causal_boundary = start_n + offs_n_causal causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] - # if tl.program_id(0) == 4: + # if tl.program_id(0) == 7: # tl.device_print("causal_mask:", causal_mask) - - # if tl.program_id(0) == 4: - # tl.device_print("qk max before causal:", tl.max(qk)) - # tl.device_print("qk min before causal:", tl.min(qk)) qk = tl.where(causal_mask, qk, float("-inf")) - # if tl.program_id(0) == 4: - # tl.device_print("qk max after causal:", tl.max(qk)) - # tl.device_print("qk min after causal:", tl.min(qk)) + if tl.program_id(0) == 7: + tl.device_print("qk after causal:", qk) if 
bias_ptrs is not None: bias_offs_n = start_n + tl.arange(0, BLOCK_N) if MASK_STEPS else None @@ -383,9 +377,9 @@ def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, s BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr, USE_ALIBI: tl.constexpr): - # if tl.program_id(0) == 4: + # if tl.program_id(0) == 7: # tl.device_print("tl.program_id(0):", tl.program_id(0)) - # if tl.program_id(0) == 4: + # if tl.program_id(0) == 7: # tl.device_print("BLOCK_M:", BLOCK_M) # tl.device_print("BLOCK_N:", BLOCK_N) # tl.device_print("BLOCK_DMODEL:", BLOCK_DMODEL) @@ -464,7 +458,7 @@ def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, s elif seqlen_k % BLOCK_N: n_extra_tokens = seqlen_k % BLOCK_N PADDED_HEAD: tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL) - # if tl.program_id(0) == 4: + # if tl.program_id(0) == 7: # tl.device_print("PADDED_HEAD", PADDED_HEAD) # Compute pointers for all the tensors used in this kernel. @@ -581,7 +575,7 @@ def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, s acc = acc.to(Out.type.element_ty) if IS_CAUSAL: if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: - # if tl.program_id(0) == 4: + # if tl.program_id(0) == 7: # tl.device_print("causal_start_idx:", causal_start_idx) # tl.device_print("start_m_idx:", start_m_idx) # tl.device_print("end_m_idx:", end_m_idx) @@ -1001,13 +995,13 @@ def forward(ctx, q, k, v, o, metadata): USE_ALIBI=False if metadata.alibi_slopes is None else True, ENABLE_DROPOUT=metadata.dropout_p > 0.0, RETURN_ENCODED_SOFTMAX=metadata.return_encoded_softmax) - if DEBUG: - print("Saving in ctx for backward") - print("q:", q, q.shape) - print("k:", k, k.shape) - print("v:", v, v.shape) - print("o:", o, o.shape) - print("M:", M, M.shape) + # if DEBUG: + # print("Saving in ctx for backward") + # print("q:", q, q.shape) + # print("k:", k, k.shape) + # print("v:", v, v.shape) + # print("o:", o, o.shape) + # print("M:", M, M.shape) ctx.save_for_backward(q, k, v, o, M) ctx.grid = grid @@ -1269,6 +1263,8 @@ def input_helper_increasing_seqlen(Z: int, HQ: int, HK: int, N_CTX_Q: int, N_CTX @pytest.mark.parametrize('Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD', [ (1, 1, 1, 512, 256, 128), (1, 1, 1, 256, 128, 160), + (1, 1, 1, 512, 2, 160), + (1, 1, 1, 512, 128, 160), (1, 1, 1, 512, 256, 160), (1, 1, 1, 1024, 512, 160), (1, 1, 1, 512, 256, 256), diff --git a/scripts b/scripts index dadae855cd6e..12fd5d44bc2c 160000 --- a/scripts +++ b/scripts @@ -1 +1 @@ -Subproject commit dadae855cd6e97f13223ce0dae9f11e52c804360 +Subproject commit 12fd5d44bc2c36547fdb5a649c212f994bad02af From 605ce29980611f113827a4c4e297029a0d005fe4 Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Mon, 19 Aug 2024 16:13:46 -0500 Subject: [PATCH 5/7] use rcp = 1 --- python/perf-kernels/flash-attention.py | 45 +++++++++++++++++--------- scripts | 2 +- 2 files changed, 31 insertions(+), 16 deletions(-) diff --git a/python/perf-kernels/flash-attention.py b/python/perf-kernels/flash-attention.py index 990c6d2021d7..f21905dd4600 100644 --- a/python/perf-kernels/flash-attention.py +++ b/python/perf-kernels/flash-attention.py @@ -30,6 +30,11 @@ import triton.language as tl DEBUG = True +if True: + RCP_LN2_FWD = triton.language.constexpr(1) +else: + RCP_LN2_FWD = triton.language.constexpr(1.4426950408889634) # = 1.0 / ln(2) + class MetaData(): @@ -269,25 +274,25 @@ def _attn_fwd_inner(acc, l_i, m_i, q, 
k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri mask = size_n < boundary_m[:, None] qk = tl.where(mask, qk, float("-inf")) - # if tl.program_id(0) == 7: + # if tl.program_id(0) == 0: # tl.device_print("q:", q) # tl.device_print("k:", k) # -- compute qk ---- qk += tl.dot(q, k) - # if tl.program_id(0) == 7: - # tl.device_print("qk before causal:", qk) + if tl.program_id(0) == 0: + tl.device_print("qk before causal:", qk) if IS_CAUSAL: causal_boundary = start_n + offs_n_causal causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] - # if tl.program_id(0) == 7: + # if tl.program_id(0) == 0: # tl.device_print("causal_mask:", causal_mask) qk = tl.where(causal_mask, qk, float("-inf")) - if tl.program_id(0) == 7: + if tl.program_id(0) == 0: tl.device_print("qk after causal:", qk) if bias_ptrs is not None: @@ -296,7 +301,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri # While bias is added after multiplying qk with sm_scale, # our optimization to use 2^x instead of e^x results in an additional # scale factor of log2(e) which we must also multiply the bias with. - qk += (bias * 1.44269504089) + qk += (bias * RCP_LN2_FWD) if alibi_slope is not None: # Compute the global position of each token within the sequence @@ -304,7 +309,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri global_n_positions = start_n + tl.arange(0, BLOCK_N) alibi_block = compute_alibi_block(alibi_slope, actual_seqlen_q, actual_seqlen_k, global_m_positions, global_n_positions) - qk += (alibi_block * 1.44269504089) # scale factor of log2(e) + qk += (alibi_block * RCP_LN2_FWD) # scale factor of log2(e) # softmax m_ij = tl.maximum(m_i, tl.max(qk, 1)) @@ -377,9 +382,9 @@ def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, s BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr, USE_ALIBI: tl.constexpr): - # if tl.program_id(0) == 7: + # if tl.program_id(0) == 0: # tl.device_print("tl.program_id(0):", tl.program_id(0)) - # if tl.program_id(0) == 7: + # if tl.program_id(0) == 0: # tl.device_print("BLOCK_M:", BLOCK_M) # tl.device_print("BLOCK_N:", BLOCK_N) # tl.device_print("BLOCK_DMODEL:", BLOCK_DMODEL) @@ -458,7 +463,7 @@ def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, s elif seqlen_k % BLOCK_N: n_extra_tokens = seqlen_k % BLOCK_N PADDED_HEAD: tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL) - # if tl.program_id(0) == 7: + # if tl.program_id(0) == 0: # tl.device_print("PADDED_HEAD", PADDED_HEAD) # Compute pointers for all the tensors used in this kernel. @@ -499,12 +504,14 @@ def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, s acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) # scale sm_scale by log_2(e) and use 2^x in the loop as we do not # have native e^x support in HW. - qk_scale = sm_scale * 1.44269504089 + qk_scale = sm_scale * RCP_LN2_FWD # Q is loaded once at the beginning and shared by all N blocks. q_ptrs_mask = offs_m[:, None] < seqlen_q if PADDED_HEAD: q_ptrs_mask = q_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0) + # if tl.program_id(0) == 0: + # tl.device_print("q:", q) q = (q * qk_scale).to(q.type.element_ty) # Here we compute how many full and masked blocks we have. 
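# NOTE: a short sketch (plain PyTorch, illustration only, not part of the
# patch) of what the RCP_LN2_FWD = 1 debugging switch above changes: with the
# log2(e) factor dropped, the exp2-based loop computes a base-2 softmax, which
# in general differs from the base-e reference softmax.
import torch

x = torch.randn(8)
base_e = torch.softmax(x, dim=0)
p2 = torch.exp2(x - x.max())
base_2 = p2 / p2.sum()
print(torch.allclose(base_e, base_2))  # False in general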
@@ -575,7 +582,7 @@ def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, s acc = acc.to(Out.type.element_ty) if IS_CAUSAL: if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: - # if tl.program_id(0) == 7: + # if tl.program_id(0) == 0: # tl.device_print("causal_start_idx:", causal_start_idx) # tl.device_print("start_m_idx:", start_m_idx) # tl.device_print("end_m_idx:", end_m_idx) @@ -683,7 +690,7 @@ def _bwd_kernel_dk_dv(dk, dv, Q, k, v, sm_scale, alibi_slope, DO, M, D, kqT = tl.dot(k, qT) if alibi_slope is not None: alibi_block = compute_alibi_block(alibi_slope, N_CTX, N_CTX, offs_m, offs_n, True) - kqT += alibi_block * 1.44269504089 + kqT += alibi_block * RCP_LN2_FWD pT = tl.math.exp2(kqT - m[None, :]) # Autoregressive masking. @@ -734,7 +741,7 @@ def _bwd_kernel_dq(dq, q, K, V, do, m, D, alibi_slope, qk = tl.dot(q, kT) if alibi_slope is not None: alibi_block = compute_alibi_block(alibi_slope, N_CTX, N_CTX, offs_m, offs_n) - qk += alibi_block * 1.44269504089 + qk += alibi_block * RCP_LN2_FWD p = tl.math.exp2(qk - m) # Autoregressive masking. @@ -1251,7 +1258,10 @@ def input_helper_increasing_seqlen(Z: int, HQ: int, HK: int, N_CTX_Q: int, N_CTX k.requires_grad_(True) v.requires_grad_(True) - sm_scale = D_HEAD**-0.5 + if True: + sm_scale = 1 + else: + sm_scale = D_HEAD**-0.5 input_metadata = MetaData(sm_scale=sm_scale) input_metadata.max_seqlens_q = N_CTX_Q input_metadata.max_seqlens_k = N_CTX_K @@ -1263,7 +1273,12 @@ def input_helper_increasing_seqlen(Z: int, HQ: int, HK: int, N_CTX_Q: int, N_CTX @pytest.mark.parametrize('Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD', [ (1, 1, 1, 512, 256, 128), (1, 1, 1, 256, 128, 160), + (1, 1, 1, 4, 2, 160), + (1, 1, 1, 16, 2, 160), + (1, 1, 1, 64, 2, 160), + (1, 1, 1, 256, 2, 160), (1, 1, 1, 512, 2, 160), + (1, 1, 1, 512, 4, 160), (1, 1, 1, 512, 128, 160), (1, 1, 1, 512, 256, 160), (1, 1, 1, 1024, 512, 160), diff --git a/scripts b/scripts index 12fd5d44bc2c..fa4c48522763 160000 --- a/scripts +++ b/scripts @@ -1 +1 @@ -Subproject commit 12fd5d44bc2c36547fdb5a649c212f994bad02af +Subproject commit fa4c485227637f59dc38a095657b4a9f5cf10c23 From 9bb8df61448afb7bb7e2891696491ab4a96cdcd7 Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Tue, 20 Aug 2024 11:29:11 -0500 Subject: [PATCH 6/7] match main --- python/perf-kernels/flash-attention.py | 103 ++++++++++++++----------- 1 file changed, 60 insertions(+), 43 deletions(-) diff --git a/python/perf-kernels/flash-attention.py b/python/perf-kernels/flash-attention.py index f21905dd4600..75f8bd5d0416 100644 --- a/python/perf-kernels/flash-attention.py +++ b/python/perf-kernels/flash-attention.py @@ -245,6 +245,13 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri OFFS_M: tl.constexpr, OFFS_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, MASK_STEPS: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr, PADDED_HEAD: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr): + # tl.device_print("actual_seqlen_k:", actual_seqlen_k) + # tl.device_print("ACTUAL_BLOCK_DMODEL:", ACTUAL_BLOCK_DMODEL) + # tl.device_print("BLOCK_N:", BLOCK_N) + # tl.device_print("BLOCK_M:", BLOCK_M) + # tl.device_print("block_min:", block_min) + # tl.device_print("block_max:", block_max) + # loop over k, v, and update accumulator for start_n in range(block_min, block_max, BLOCK_N): # For padded blocks, we will overrun the tensor size if @@ -253,7 +260,11 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri k_offs_n = start_n + 
tl.arange(0, BLOCK_N) else: k_offs_n = None - k_offs_k = None if not PADDED_HEAD else tl.arange(0, BLOCK_DMODEL) + + if PADDED_HEAD: + k_offs_k = tl.arange(0, BLOCK_DMODEL) + else: + k_offs_k = None k = load_fn(k_ptrs, k_offs_k, k_offs_n, ACTUAL_BLOCK_DMODEL, actual_seqlen_k) if PRE_LOAD_V: # We can use the same offsets as k, just with dims transposed. @@ -281,8 +292,8 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri # -- compute qk ---- qk += tl.dot(q, k) - if tl.program_id(0) == 0: - tl.device_print("qk before causal:", qk) + # if tl.program_id(0) == 0: + # tl.device_print("qk before causal:", qk) if IS_CAUSAL: causal_boundary = start_n + offs_n_causal @@ -292,8 +303,8 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri qk = tl.where(causal_mask, qk, float("-inf")) - if tl.program_id(0) == 0: - tl.device_print("qk after causal:", qk) + # if tl.program_id(0) == 0: + # tl.device_print("qk after causal:", qk) if bias_ptrs is not None: bias_offs_n = start_n + tl.arange(0, BLOCK_N) if MASK_STEPS else None @@ -310,11 +321,24 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri alibi_block = compute_alibi_block(alibi_slope, actual_seqlen_q, actual_seqlen_k, global_m_positions, global_n_positions) qk += (alibi_block * RCP_LN2_FWD) # scale factor of log2(e) + + + # tl.device_print("qk:", qk) # softmax m_ij = tl.maximum(m_i, tl.max(qk, 1)) - qk = qk - m_ij[:, None] + # tl.device_print("m_ij:", m_ij) + # this can lead to -inf - (-inf) which is nan + if IS_CAUSAL: + # tl.device_print("qk before clamp:", qk) + qk = tl.where(qk > float("-inf"), qk - m_ij[:, None], float("-inf")) + else: + qk = qk - m_ij[:, None] + # if tl.program_id(0) == 0: + # tl.device_print("qk before exp:", qk) p = tl.math.exp2(qk) + # if tl.program_id(0) == 0: + # tl.device_print("p:", p) # CAVEAT: Must update l_ij before applying dropout l_ij = tl.sum(p, 1) @@ -326,22 +350,39 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri p = tl.where(keep, p, 0.0) elif RETURN_ENCODED_SOFTMAX: tl.store(encoded_sm_ptrs, p.to(encoded_sm_ptrs.type.element_ty)) + # -- update output accumulator -- - alpha = tl.math.exp2(m_i - m_ij) + if IS_CAUSAL: # this can lead to -inf - (-inf) which is nan + alpha = tl.math.exp2(tl.where(m_i > float("-inf"), m_i - m_ij, float("-inf"))) + else: + alpha = tl.math.exp2(m_i - m_ij) + # tl.device_print("alpha:", alpha) + + acc = acc * alpha[:, None] + # tl.device_print("acc:", acc) if not PRE_LOAD_V: v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL) + # tl.device_print("v:", v) # -- update m_i and l_i + # tl.device_print("l_i before:", l_i) l_i = l_i * alpha + l_ij + # tl.device_print("l_i after:", l_i) # update m_i and l_i m_i = m_ij + # tl.device_print("acc before:", acc) + # tl.device_print("p:", p) acc += tl.dot(p.to(v.type.element_ty), v) + # tl.device_print("acc after:", acc) k_ptrs += BLOCK_N * stride_kn v_ptrs += BLOCK_N * stride_vk if bias_ptrs is not None: bias_ptrs += BLOCK_N * stride_bn if RETURN_ENCODED_SOFTMAX: encoded_sm_ptrs += BLOCK_N + + # tl.device_print("acc:", acc) + return acc, l_i, m_i @@ -372,11 +413,12 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'], use_cuda_graph=True, ) + @triton.jit -def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, - stride_kn, stride_kk, stride_vz, stride_vh, 
stride_vk, stride_vn, stride_oz, stride_oh, stride_om, - stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah, cu_seqlens_q, cu_seqlens_k, - dropout_p, philox_seed, philox_offset_base, encoded_softmax, alibi_slopes, HQ: tl.constexpr, +def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, L, Out, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, + stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, stride_oz, stride_oh, + stride_om, stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah, cu_seqlens_q, + cu_seqlens_k, dropout_p, philox_seed, philox_offset_base, encoded_softmax, alibi_slopes, HQ: tl.constexpr, HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr, MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr, @@ -504,7 +546,7 @@ def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, s acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) # scale sm_scale by log_2(e) and use 2^x in the loop as we do not # have native e^x support in HW. - qk_scale = sm_scale * RCP_LN2_FWD + QK_SCALE: tl.constexpr = SM_SCALE * RCP_LN2_FWD # Q is loaded once at the beginning and shared by all N blocks. q_ptrs_mask = offs_m[:, None] < seqlen_q if PADDED_HEAD: @@ -512,7 +554,7 @@ def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, s q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0) # if tl.program_id(0) == 0: # tl.device_print("q:", q) - q = (q * qk_scale).to(q.type.element_ty) + q = (q * QK_SCALE).to(q.type.element_ty) # Here we compute how many full and masked blocks we have. padded_block_k = n_extra_tokens != 0 @@ -569,7 +611,10 @@ def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, s PRE_LOAD_V, True, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, PADDED_HEAD, ACTUAL_BLOCK_DMODEL) # epilogue - acc = acc / l_i[:, None] + # This helps the compiler do Newton Raphson on l_i vs on acc which is much larger. 
+ l_recip = 1 / l_i[:, None] + acc = acc * l_recip + if ENABLE_DROPOUT: acc = acc / (1 - dropout_p) # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, @@ -1202,34 +1247,6 @@ def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, equal_seqlen input_metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k) return q, k, v, input_metadata - -def compare_tensors(out, ref, atol=2e-2, rtol=2e-2): - if out.shape != ref.shape: - print(f"Shapes mismatch: {out.shape} vs {ref.shape}") - return - - abs_diff = torch.abs(out - ref) - rel_diff = abs_diff / torch.abs(ref) - - abs_mask = abs_diff > atol - rel_mask = rel_diff > rtol - - mismatch_mask = torch.logical_or(abs_mask, rel_mask) - - if torch.any(mismatch_mask): - mismatch_indices = torch.nonzero(mismatch_mask) - print(f"Found {len(mismatch_indices)} mismatching elements:") - for idx in mismatch_indices: - idx_tuple = tuple(idx.tolist()) - print(f" Index {idx_tuple}:") - print(f" out: {out[idx_tuple].item()}") - print(f" ref: {ref[idx_tuple].item()}") - print(f" Absolute difference: {abs_diff[idx_tuple].item()}") - print(f" Relative difference: {rel_diff[idx_tuple].item()}") - else: - print("No mismatches found within the specified tolerance.") - - def input_helper_increasing_seqlen(Z: int, HQ: int, HK: int, N_CTX_Q: int, N_CTX_K: int, D_HEAD: int, dtype: torch.dtype, layout: str) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, MetaData]: torch.manual_seed(20) @@ -1747,4 +1764,4 @@ def main(): if __name__ == '__main__': - sys.exit(main()) \ No newline at end of file + sys.exit(main()) From ea93831cc86987f1c7c872a4370f9e6013155d75 Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Tue, 20 Aug 2024 12:31:33 -0500 Subject: [PATCH 7/7] set num_warps to 8 --- python/perf-kernels/flash-attention.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/python/perf-kernels/flash-attention.py b/python/perf-kernels/flash-attention.py index 75f8bd5d0416..b39db8e6ad36 100644 --- a/python/perf-kernels/flash-attention.py +++ b/python/perf-kernels/flash-attention.py @@ -388,7 +388,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri @triton.autotune( configs=[ - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=1), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), # triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, # num_warps=8), # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, @@ -1288,9 +1288,14 @@ def input_helper_increasing_seqlen(Z: int, HQ: int, HK: int, N_CTX_Q: int, N_CTX @pytest.mark.parametrize('Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD', [ + (1, 1, 1, 4, 2, 16), + (1, 1, 1, 4, 2, 32), + (1, 1, 1, 4, 2, 64), + (1, 1, 1, 4, 2, 128), + (1, 1, 1, 4, 2, 160), + (1, 1, 1, 4, 2, 256), (1, 1, 1, 512, 256, 128), (1, 1, 1, 256, 128, 160), - (1, 1, 1, 4, 2, 160), (1, 1, 1, 16, 2, 160), (1, 1, 1, 64, 2, 160), (1, 1, 1, 256, 2, 160),
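# NOTE: a minimal reproduction (plain PyTorch, illustration only) of the
# "-inf - (-inf) is nan" hazard that the IS_CAUSAL tl.where guards added in
# patch 6 work around in _attn_fwd_inner.
import torch

m = torch.tensor(float("-inf"))
print(m - m)                                     # tensor(nan)
print(torch.where(m > float("-inf"), m - m, m))  # tensor(-inf): row stays masked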