diff --git a/tritonsrc/bwd_kernel_common.py b/tritonsrc/bwd_kernel_common.py
index effca28..016753d 100644
--- a/tritonsrc/bwd_kernel_common.py
+++ b/tritonsrc/bwd_kernel_common.py
@@ -5,6 +5,7 @@
 import triton
 import triton.language as tl
 from dropout import dropout_mask, dropout_rng, dropout_offsets
+from masked_load_store import load_fn
 
 # Helper function, but not always usable due to compiler bugs (esp. when used with tl.trans)
 @triton.jit
@@ -16,9 +17,8 @@ def dot(BLOCK_M : tl.constexpr, QDIM : tl.constexpr, KDIM : tl.constexpr, q, k):
 
 @triton.jit
 def bwd_kernel_dk_dv_common(
-    Q_block_ptr, KT_block_ptr, VT_block_ptr, B_block_ptr,
-    sm_scale, DO_block_ptr,
-    DK_block_ptr, DV_block_ptr,
+    q_ptrs, q_stride, kt, vt, B_block_ptr,
+    sm_scale, do_ptrs, do_stride,
     l_ptrs,
     D_ptrs,
     seqlen_q,
@@ -40,25 +40,20 @@ def bwd_kernel_dk_dv_common(
     # initialize offsets
     offs_m = start_m + tl.arange(0, BLOCK_N)
     offs_n = tl.arange(0, BLOCK_M)
+    ld_offs_d = None if not PADDED_HEAD else tl.arange(0, BLOCK_DMODEL)
     lo = (start_m // BLOCK_M) * BLOCK_M if CAUSAL else 0
     hi = seqlen_q
-    Q_block_ptr = tl.advance(Q_block_ptr, (lo, 0))
-    DO_block_ptr = tl.advance(DO_block_ptr, (lo, 0))
+    # Q_block_ptr = tl.advance(Q_block_ptr, (lo, 0))
+    # DO_block_ptr = tl.advance(DO_block_ptr, (lo, 0))
+    q_ptrs += lo * q_stride
+    do_ptrs += lo * do_stride
     qk_scale = sm_scale * 1.44269504089
     # load k and v: they will stay in SRAM throughout
     # (BLOCK_DMODEL, BLOCK_N)
-    if PADDED_HEAD:
-        kt = tl.load(KT_block_ptr, boundary_check=(1,0), padding_option="zero")
-    else:
-        kt = tl.load(KT_block_ptr, boundary_check=(1,), padding_option="zero")
-    kt = (kt * qk_scale).to(KT_block_ptr.type.element_ty)
+    kt = (kt * qk_scale).to(kt.type.element_ty)
     # (BLOCK_DMODEL, BLOCK_N)
-    if PADDED_HEAD:
-        vt = tl.load(VT_block_ptr, boundary_check=(1,0), padding_option="zero")
-    else:
-        vt = tl.load(VT_block_ptr, boundary_check=(1,), padding_option="zero")
     dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
     dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
     '''
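For context, the `qk_scale = sm_scale * 1.44269504089` constant that the rewritten prologue keeps is log2(e): pre-multiplying `kt` by it folds both the softmax scale and a conversion to base-2 exponentials into the qk dot product, so the kernel can use the cheaper `exp2` instead of `exp`. A minimal sketch of the identity being exploited (plain Python, not part of the patch):

```python
import math

# exp(sm_scale * x) == 2 ** (sm_scale * log2(e) * x); exp2 is cheaper on GPUs.
sm_scale, x = 0.125, 3.7
qk_scale = sm_scale * 1.44269504089   # sm_scale * log2(e)
assert math.isclose(math.exp(sm_scale * x), 2.0 ** (qk_scale * x), rel_tol=1e-9)
```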
@@ -80,23 +75,30 @@ def bwd_kernel_dk_dv_common(
     '''
     # loop over q (seqlen_q, d_head), do (seqlen_q, d_head)
     for start_n in range(lo, hi, BLOCK_M):
-        if lo + BLOCK_M > seqlen_q:
+        if start_n + BLOCK_M > seqlen_q:
             q_padded = True
         else:
             q_padded = False
+        # TODO: Unify the names; the m/n usage here is very confusing
         offs_m_curr = offs_n[:, None] + start_n # (BLOCK_M, 1)
         # -- load q, do --
         # TODO: It would be more optimal to do the OOB check only in the last iteration.
         # (BLOCK_M, BLOCK_DMODEL), offs = (BLOCK_M * iter, 0) = (start_n, 0)
-        if PADDED_HEAD:
-            q = tl.load(Q_block_ptr, boundary_check=(0,1), padding_option="zero")
+        #
+        # This common function can be further split into regular and
+        # non-regular versions, determined by tl.constexpr, just like the fwd kernel.
+
+        # q = tl.load(Q_block_ptr)
+        if q_padded:
+            q = load_fn(q_ptrs, offs_n + start_n, ld_offs_d, seqlen_q, head_dim)
         else:
-            q = tl.load(Q_block_ptr, boundary_check=(0,), padding_option="zero")
-        # do: (BLOCK_M, BLOCK_DMODEL)
-        if PADDED_HEAD:
-            do = tl.load(DO_block_ptr, boundary_check=(0,1), padding_option="zero")
+            q = load_fn(q_ptrs, None, ld_offs_d, seqlen_q, head_dim)
+        # do = tl.load(DO_block_ptr)
+        # TODO: pre_load_do
+        if q_padded:
+            do = load_fn(do_ptrs, offs_n + start_n, ld_offs_d, seqlen_q, head_dim)
         else:
-            do = tl.load(DO_block_ptr, boundary_check=(0,), padding_option="zero")
+            do = load_fn(do_ptrs, None, ld_offs_d, seqlen_q, head_dim)
         # -- compute qk ----
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         # TODO: These two checks can be optimized to occur only on the last iteration.
@@ -111,6 +113,7 @@ def bwd_kernel_dk_dv_common(
             pass
         elif BIAS_TYPE == 1:
             # FIXME: do boundary_check correctly
+            # TODO: check that q_padded is calculated correctly; the condition should use start_n + BLOCK_M
             """
             if q_padded and k_padded: # CAVEAT: using "or" disables the partial boundary_check branches
                 bias = tl.load(B_block_ptr, boundary_check=(0,1), padding_option="zero")
@@ -144,12 +147,12 @@ def bwd_kernel_dk_dv_common(
             keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, max_seqlen_k)
             # CAVEAT: do NOT update p, ds needs the original p
             if BLOCK_M == 1:
-                dv += tl.where(keep, p / (1 - dropout_p), 0.0).to(Q_block_ptr.dtype.element_ty) * do
+                dv += tl.where(keep, p / (1 - dropout_p), 0.0).to(q_ptrs.dtype.element_ty) * do
             else:
-                dv += tl.dot(tl.trans(tl.where(keep, p / (1 - dropout_p), 0.0)).to(Q_block_ptr.dtype.element_ty), do)
+                dv += tl.dot(tl.trans(tl.where(keep, p / (1 - dropout_p), 0.0)).to(q_ptrs.dtype.element_ty), do)
         else:
             if BLOCK_M == 1:
-                dv += p.to(Q_block_ptr.dtype.element_ty) * do
+                dv += p.to(q_ptrs.dtype.element_ty) * do
             else:
                 dv += tl.dot(tl.trans(p.to(do.dtype)), do)
         dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
@@ -163,22 +166,24 @@ def bwd_kernel_dk_dv_common(
         ds = p * (dp - Di) # (BLOCK_M, BLOCK_N)
         # compute dk
         if BLOCK_M == 1:
-            dk += ds.to(Q_block_ptr.dtype.element_ty) * q
+            dk += ds.to(q_ptrs.dtype.element_ty) * q
         else:
             # ds.shape = (BLOCK_M, BLOCK_N), q.shape = (BLOCK_M, BLOCK_DMODEL)
-            dk += tl.dot(tl.trans(ds.to(Q_block_ptr.dtype.element_ty)), q) # (BLOCK_N, BLOCK_DMODEL)
-        # update pointers
-        Q_block_ptr = tl.advance(Q_block_ptr, (BLOCK_M, 0))
-        DO_block_ptr = tl.advance(DO_block_ptr, (BLOCK_M, 0)) # Debug DO accessing problems
+            dk += tl.dot(tl.trans(ds.to(q_ptrs.dtype.element_ty)), q) # (BLOCK_N, BLOCK_DMODEL)
+        # update pointers (the block_ptr code is intentionally kept as a comment)
+        # Q_block_ptr = tl.advance(Q_block_ptr, (BLOCK_M, 0))
+        # DO_block_ptr = tl.advance(DO_block_ptr, (BLOCK_M, 0)) # Debug DO accessing problems
+        q_ptrs += q_stride * BLOCK_M
+        do_ptrs += do_stride * BLOCK_M
         if BIAS_TYPE == 1:
             B_block_ptr = tl.advance(B_block_ptr, (BLOCK_M, 0))
-    return (dk * sm_scale).to(DK_block_ptr.type.element_ty), dv.to(DV_block_ptr.type.element_ty)
+    return (dk * sm_scale).to(kt.type.element_ty), dv.to(vt.type.element_ty)
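The dropout branch above rescales the kept probabilities by `1 / (1 - dropout_p)` (inverted dropout), so the expectation of the dropped matrix equals the original `p`; that is why `ds` can still be computed from the unmodified `p`. A quick Monte Carlo check of that identity (illustrative only, assuming numpy is available):

```python
import numpy as np

p, dropout_p = 0.3, 0.25
rng = np.random.default_rng(0)
keep = rng.random(1_000_000) >= dropout_p          # plays the role of dropout_mask(...)
dropped = np.where(keep, p / (1 - dropout_p), 0.0)
assert abs(dropped.mean() - p) < 1e-3              # E[dropped] == p
```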
 
 @triton.jit
 def bwd_kernel_dq_db_common(
-    Q_block_ptr, K_block_ptr, V_block_ptr, B_block_ptr,
-    sm_scale, DO_block_ptr,
-    DQ_block_ptr, DB_block_ptr, store_db,
+    q, kt_ptrs, k_stride, vt_ptrs, v_stride, B_block_ptr,
+    sm_scale, do,
+    dq, DB_block_ptr, store_db,
     l_ptrs,
     D_ptrs,
     seqlen_q,
@@ -197,37 +202,34 @@ def bwd_kernel_dq_db_common(
     PADDED_HEAD: tl.constexpr,
     BIAS_TYPE: tl.constexpr,
 ):
-    if start_m + BLOCK_N > seqlen_k:
-        k_padded = True
-    else:
-        k_padded = False
     # initialize offsets
     offs_m = start_m + tl.arange(0, BLOCK_M)
     offs_n = tl.arange(0, BLOCK_N)
+    offs_d = tl.arange(0, BLOCK_DMODEL)
+    ld_offs_d = None if not PADDED_HEAD else tl.arange(0, BLOCK_DMODEL)
+    # Check for OOB accesses on D and LSE
     overflow_size_q = start_m + BLOCK_M - seqlen_q
-    boundary = tl.full((BLOCK_M, ), BLOCK_M - overflow_size_q, dtype=tl.int32)
-    d_lse_ptrs_mask = boundary > tl.arange(0, BLOCK_M)
-    d_lse_padding = tl.full((BLOCK_M, ), 0, dtype=tl.float32)
-    Di = tl.load(D_ptrs + offs_m, mask=d_lse_ptrs_mask, other=d_lse_padding)
-    l_i = tl.load(l_ptrs + offs_m, mask=d_lse_ptrs_mask, other=d_lse_padding)
+    if overflow_size_q > 0:
+        boundary = tl.full((BLOCK_M, ), BLOCK_M - overflow_size_q, dtype=tl.int32)
+        d_lse_ptrs_mask = boundary > tl.arange(0, BLOCK_M)
+        d_lse_padding = tl.full((BLOCK_M, ), 0, dtype=tl.float32)
+        Di = tl.load(D_ptrs + offs_m, mask=d_lse_ptrs_mask, other=d_lse_padding)
+        l_i = tl.load(l_ptrs + offs_m, mask=d_lse_ptrs_mask, other=d_lse_padding)
+    else:
+        Di = tl.load(D_ptrs + offs_m)
+        l_i = tl.load(l_ptrs + offs_m)
+    dq = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
     # loop over k, v
     lo = 0
     hi = min(start_m + BLOCK_M, seqlen_k) if CAUSAL else seqlen_k
     if BIAS_TYPE == 1:
         B_block_ptr = tl.advance(B_block_ptr, (lo, 0))
+
     qk_scale = sm_scale * 1.44269504089
-    # load q and do: they will stay in SRAM throughout
-    if PADDED_HEAD:
-        q = tl.load(Q_block_ptr, boundary_check=(0,1), padding_option="zero")
-    else:
-        q = tl.load(Q_block_ptr, boundary_check=(0,), padding_option="zero")
-    q = (q * qk_scale).to(Q_block_ptr.type.element_ty)
-    if PADDED_HEAD:
-        do = tl.load(DO_block_ptr, boundary_check=(0,1), padding_option="zero")
-    else:
-        do = tl.load(DO_block_ptr, boundary_check=(0,), padding_option="zero")
+    q = (q * qk_scale).to(q.type.element_ty)
+
     '''
            K1   K2     (d)V      dO
     Q1    qk11 qk12    (d)v1     dO1
@@ -244,12 +246,18 @@ def bwd_kernel_dq_db_common(
             k_padded = False
         # -- load k, v --
         # shape = (BLOCK_DMODEL, BLOCK_N), offs = (0, BLOCK_N * iter) = (0, start_n)
-        if PADDED_HEAD:
-            kt = tl.load(K_block_ptr, boundary_check=(1,0), padding_option="zero")
-            vt = tl.load(V_block_ptr, boundary_check=(1,0), padding_option="zero")
+        # kt = tl.load(K_block_ptr)
+        # vt = tl.load(V_block_ptr)
+        if k_padded:
+            kt = load_fn(kt_ptrs, ld_offs_d, offs_n + start_n, head_dim, seqlen_k)
+        else:
+            kt = load_fn(kt_ptrs, ld_offs_d, None, head_dim, seqlen_k)
+
+        # TODO: pre_load_vt
+        if k_padded:
+            vt = load_fn(vt_ptrs, ld_offs_d, offs_n + start_n, head_dim, seqlen_k)
         else:
-            kt = tl.load(K_block_ptr, boundary_check=(1,), padding_option="zero")
-            vt = tl.load(V_block_ptr, boundary_check=(1,), padding_option="zero")
+            vt = load_fn(vt_ptrs, ld_offs_d, None, head_dim, seqlen_k)
         # -- compute qk ----
         # q.offs = (start_m, 0), k.offs = (0, start_n)
         qk = dot(BLOCK_M, BLOCK_DMODEL, BLOCK_DMODEL, q, kt)
@@ -292,18 +300,20 @@ def bwd_kernel_dq_db_common(
         # compute dq. Unfortunately we cannot avoid the transpose here, as this loop
         # uses k in both normal and transposed forms.
         if BLOCK_M == 1:
-            dq += tl.view(kt, [BLOCK_DMODEL]) * ds.to(Q_block_ptr.type.element_ty)
+            dq += tl.view(kt, [BLOCK_DMODEL]) * ds.to(q.type.element_ty)
         else:
             # ds.shape = (BLOCK_M, BLOCK_N), kt.shape = (BLOCK_DMODEL, BLOCK_N)
-            dq += tl.dot(ds.to(Q_block_ptr.type.element_ty), tl.trans(kt)) # (BLOCK_M, BLOCK_DMODEL)
+            dq += tl.dot(ds.to(q.type.element_ty), tl.trans(kt)) # (BLOCK_M, BLOCK_DMODEL)
         if BIAS_TYPE == 1:
             if store_db:
                 tl.store(DB_block_ptr, ds.to(DB_block_ptr.type.element_ty), boundary_check=(0,1))
         # update pointers
-        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
-        V_block_ptr = tl.advance(V_block_ptr, (0, BLOCK_N))
+        # The block_ptr code is intentionally kept as a comment
+        # K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
+        # V_block_ptr = tl.advance(V_block_ptr, (0, BLOCK_N))
+        kt_ptrs += BLOCK_N * k_stride
+        vt_ptrs += BLOCK_N * v_stride
         if BIAS_TYPE == 1:
             B_block_ptr = tl.advance(B_block_ptr, (0, BLOCK_N))
             DB_block_ptr = tl.advance(DB_block_ptr, (0, BLOCK_N))
-    return (dq * sm_scale).to(DQ_block_ptr.type.element_ty)
-    # tl.store(DQ_block_ptr, (dq * sm_scale).to(DQ_block_ptr.type.element_ty), boundary_check=(0,1))
+    return (dq * sm_scale).to(dq.type.element_ty)
diff --git a/tritonsrc/bwd_split_kernel.py b/tritonsrc/bwd_split_kernel.py
index 0bf8938..9f041d4 100644
--- a/tritonsrc/bwd_split_kernel.py
+++ b/tritonsrc/bwd_split_kernel.py
@@ -17,7 +17,7 @@
 import triton
 import triton.language as tl
 from bwd_kernel_common import bwd_kernel_dk_dv_common, bwd_kernel_dq_db_common
-from masked_load_store import mstore2d
+from masked_load_store import load_fn, mstore2d
 
 # TODO: Remove unused 'Out' argument from kernels below
 @triton.jit
@@ -50,12 +50,16 @@ def bwd_kernel_dk_dv(
     PADDED_HEAD: tl.constexpr,
     BIAS_TYPE: tl.constexpr,
 ):
-    start_m = tl.program_id(0) * BLOCK_N
+    start_m = tl.program_id(0) * BLOCK_N # start_m is a misnomer; for dk/dv it partitions seqlen_k
     off_h = tl.program_id(1) # head index
     off_z = tl.program_id(2) # batch index, for varlen it indicates index in cu_seqlens_q/k
     num_h = tl.num_programs(1)
     num_z = tl.num_programs(2)
     off_zh = off_z * num_h + off_h * 1
+    offs_m = tl.arange(0, BLOCK_M)
+    offs_n = start_m + tl.arange(0, BLOCK_N)
+    offs_d = tl.arange(0, BLOCK_DMODEL)
+    ld_offs_d = None if not PADDED_HEAD else tl.arange(0, BLOCK_DMODEL)
 
     cu_seqlens_q_start = 0
     cu_seqlens_k_start = 0
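The hunks that follow replace every `tl.make_block_ptr`/`tl.advance` pair with plain pointer arithmetic: a 2D block of pointers built from row and column offsets times the tensor strides, advanced by adding `BLOCK * stride`. A hypothetical stand-alone kernel showing just that pattern (the names and the copy task are illustrative, not from the patch; requires triton and a GPU):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def copy_rows(src, dst, n_rows, stride_row,
              BLOCK_M: tl.constexpr, BLOCK_D: tl.constexpr):
    offs_m = tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_D)
    # Equivalent of make_block_ptr(base=src, strides=(stride_row, 1), ...)
    ptrs = src + offs_m[:, None] * stride_row + offs_d[None, :]
    outs = dst + offs_m[:, None] * stride_row + offs_d[None, :]
    for start in range(0, n_rows, BLOCK_M):
        rows = start + offs_m
        mask = rows[:, None] < n_rows
        tile = tl.load(ptrs, mask=mask, other=0.0)
        tl.store(outs, tile, mask=mask)
        # Raw-pointer analogue of tl.advance(ptr, (BLOCK_M, 0))
        ptrs += BLOCK_M * stride_row
        outs += BLOCK_M * stride_row

x = torch.randn(100, 16, device='cuda')
y = torch.zeros_like(x)
copy_rows[(1,)](x, y, x.shape[0], x.stride(0), BLOCK_M=32, BLOCK_D=16)
assert torch.equal(x, y)
```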
@@ -92,41 +96,61 @@ def bwd_kernel_dk_dv(
     # Q is consumed depending on block ID. Every block starts at the
     # previous block's offset, advanced by BLOCK_M x D_HEAD.
     q_offset = off_h * stride_qh + batch_index * stride_qz + cu_seqlens_q_start * stride_qm
-    Q_block_ptr = tl.make_block_ptr(
-        base=Q + q_offset,
-        shape=(seqlen_q, head_dim),
-        strides=(stride_qm, stride_qk),
-        offsets=(0, 0),
-        block_shape=(BLOCK_M, BLOCK_DMODEL),
-        order=(1, 0)
-    )
+    Q += q_offset
+    q_ptrs = Q + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
+    q_advance = BLOCK_M * stride_qm
+    # Q_block_ptr = tl.make_block_ptr(
+    #     base=Q,
+    #     shape=(seqlen_q, head_dim),
+    #     strides=(stride_qm, stride_qk),
+    #     offsets=(0, 0),
+    #     block_shape=(BLOCK_M, BLOCK_DMODEL),
+    #     order=(1, 0)
+    # )
     k_offset = off_h * stride_kh + batch_index * stride_kz + cu_seqlens_k_start * stride_kn
-    KT_block_ptr = tl.make_block_ptr(
-        base=K + k_offset,
-        shape=(head_dim, seqlen_k),
-        strides=(stride_kk, stride_kn),
-        offsets=(0, start_m),
-        block_shape=(BLOCK_DMODEL, BLOCK_N),
-        order=(0, 1)
-    )
+    K += k_offset
+    kt_ptrs = K + offs_d[:, None] * stride_kk + offs_n[None, :] * stride_kn
+    # kt_offs_n = None if start_m + BLOCK_N <= seqlen_k else start_m + tl.arange(0, BLOCK_N)
+    if start_m + BLOCK_N <= seqlen_k:
+        kt = load_fn(kt_ptrs, ld_offs_d, None, head_dim, seqlen_k)
+    else:
+        kt = load_fn(kt_ptrs, ld_offs_d, offs_n, head_dim, seqlen_k)
+    # KT_block_ptr = tl.make_block_ptr(
+    #     base=K + k_offset,
+    #     shape=(head_dim, seqlen_k),
+    #     strides=(stride_kk, stride_kn),
+    #     offsets=(0, start_m),
+    #     block_shape=(BLOCK_DMODEL, BLOCK_N),
+    #     order=(0, 1)
+    # )
     v_offset = off_h * stride_vh + batch_index * stride_vz + cu_seqlens_k_start * stride_vk
-    VT_block_ptr = tl.make_block_ptr(
-        base=V + v_offset,
-        shape=(head_dim, seqlen_k),
-        strides=(stride_vn, stride_vk),
-        offsets=(0, start_m),
-        block_shape=(BLOCK_DMODEL, BLOCK_N),
-        order=(0, 1)
-    )
+    V += v_offset
+    # VT_block_ptr = tl.make_block_ptr(
+    #     base=V,
+    #     shape=(head_dim, seqlen_k),
+    #     strides=(stride_vn, stride_vk),
+    #     offsets=(0, start_m),
+    #     block_shape=(BLOCK_DMODEL, BLOCK_N),
+    #     order=(0, 1)
+    # )
+    # vt = tl.load(VT_block_ptr)
+    vt_ptrs = V + offs_d[:, None] * stride_vn + offs_n[None, :] * stride_vk
+    if start_m + BLOCK_N <= seqlen_k:
+        vt = load_fn(vt_ptrs, ld_offs_d, None, head_dim, seqlen_k)
+    else:
+        vt = load_fn(vt_ptrs, ld_offs_d, offs_n, head_dim, seqlen_k)
+    # tl.device_print('vt', vt)
     do_offset = off_h * stride_oh + batch_index * stride_oz + cu_seqlens_q_start * stride_om
-    DO_block_ptr = tl.make_block_ptr(
-        base=DO + do_offset,
-        shape=(seqlen_q, head_dim),
-        strides=(stride_om, stride_ok),
-        offsets=(0, 0),
-        block_shape=(BLOCK_M, BLOCK_DMODEL),
-        order=(1, 0)
-    )
+    DO += do_offset
+    do_ptrs = DO + offs_m[:, None] * stride_om + offs_d[None, :] * stride_ok
+    # DO_block_ptr = tl.make_block_ptr(
+    #     base=DO,
+    #     shape=(seqlen_q, head_dim),
+    #     strides=(stride_om, stride_ok),
+    #     offsets=(0, 0),
+    #     block_shape=(BLOCK_M, BLOCK_DMODEL),
+    #     order=(1, 0)
+    # )
     if BIAS_TYPE == 0:
         B_block_ptr = 0
     elif BIAS_TYPE == 1:
@@ -156,28 +180,13 @@ def bwd_kernel_dk_dv(
         batch_philox_offset = 0
     dk_offset = off_h * stride_dkh + batch_index * stride_dkz + cu_seqlens_k_start * stride_dkn
-    DK_block_ptr = tl.make_block_ptr(
-        base=DK + dk_offset,
-        shape=(seqlen_k, head_dim),
-        strides=(stride_dkn, stride_dkk),
-        offsets=(start_m, 0),
-        block_shape=(BLOCK_N, BLOCK_DMODEL),
-        order=(1, 0)
-    )
+    DK += dk_offset
     dv_offset = off_h * stride_dvh + batch_index * stride_dvz + cu_seqlens_k_start * stride_dvk
-    DV_block_ptr = tl.make_block_ptr(
-        base=DV + dv_offset,
-        shape=(seqlen_k, head_dim),
-        strides=(stride_dvk, stride_dvn),
-        offsets=(start_m, 0),
-        block_shape=(BLOCK_N, BLOCK_DMODEL),
-        order=(1, 0)
-    )
+    DV += dv_offset
     dk, dv = bwd_kernel_dk_dv_common(
-        Q_block_ptr, KT_block_ptr, VT_block_ptr, B_block_ptr,
-        sm_scale, DO_block_ptr,
-        DK_block_ptr, DV_block_ptr,
+        q_ptrs, stride_qm, kt, vt, B_block_ptr,
+        sm_scale, do_ptrs, stride_om,
         l_ptrs,
         D_ptrs,
         seqlen_q,
         seqlen_k,
@@ -197,7 +206,7 @@ def bwd_kernel_dk_dv(
     mstore2d(dk,
              BLOCK_N,
              BLOCK_DMODEL,
-             o_base=DK + dk_offset,
+             o_base=DK,
              o_start_row=start_m,
              o_start_col=0,
              o_rows=seqlen_k,
@@ -207,7 +216,7 @@ def bwd_kernel_dk_dv(
     mstore2d(dv,
              BLOCK_N,
              BLOCK_DMODEL,
-             o_base=DV + dv_offset,
+             o_base=DV,
              o_start_row=start_m,
              o_start_col=0,
              o_rows=seqlen_k,
@@ -250,6 +259,10 @@ def bwd_kernel_dq(
     num_h = tl.num_programs(1)
     num_z = tl.num_programs(2)
     off_zh = off_z * num_h + off_h * 1
+    offs_m = start_m + tl.arange(0, BLOCK_M)
+    offs_n = tl.arange(0, BLOCK_N)
+    offs_d = tl.arange(0, BLOCK_DMODEL)
+    ld_offs_d = None if not PADDED_HEAD else tl.arange(0, BLOCK_DMODEL)
 
     cu_seqlens_q_start = 0
     cu_seqlens_k_start = 0
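Note the recurring `start_m + BLOCK_M <= seqlen_q` (and `start_m + BLOCK_N <= seqlen_k`) dispatch in these pointer-setup hunks: only the final block along a sequence can run past the end, so every other program takes the unmasked fast path and the masked `load_fn` variant is paid for at most once per sequence. A tiny host-side illustration of the dispatch (hypothetical helper, not from the patch):

```python
def needs_mask(start: int, block: int, seqlen: int) -> bool:
    # Mirrors the kernel's `start + block <= seqlen` fast-path test.
    return start + block > seqlen

seqlen_q, BLOCK_M = 100, 32
flags = [needs_mask(s, BLOCK_M, seqlen_q) for s in range(0, seqlen_q, BLOCK_M)]
assert flags == [False, False, False, True]  # only the last block needs masking
```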
@@ -283,41 +296,57 @@ def bwd_kernel_dq(
         batch_index = off_z
     # Initialize pointers to Q, K, V
     q_offset = off_h * stride_qh + batch_index * stride_qz + cu_seqlens_q_start * stride_qm
-    Q_block_ptr = tl.make_block_ptr(
-        base=Q + q_offset,
-        shape=(seqlen_q, head_dim),
-        strides=(stride_qm, stride_qk),
-        offsets=(start_m, 0),
-        block_shape=(BLOCK_M, BLOCK_DMODEL),
-        order=(1, 0)
-    )
+    Q += q_offset
+    # Q_block_ptr = tl.make_block_ptr(
+    #     base=Q,
+    #     shape=(seqlen_q, head_dim),
+    #     strides=(stride_qm, stride_qk),
+    #     offsets=(start_m, 0),
+    #     block_shape=(BLOCK_M, BLOCK_DMODEL),
+    #     order=(1, 0)
+    # )
+    q_ptrs = Q + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
+    if start_m + BLOCK_M <= seqlen_q:
+        q = load_fn(q_ptrs, None, ld_offs_d, seqlen_q, head_dim)
+    else:
+        q = load_fn(q_ptrs, offs_m, ld_offs_d, seqlen_q, head_dim)
     k_offset = off_h * stride_kh + batch_index * stride_kz + cu_seqlens_k_start * stride_kn
-    K_block_ptr = tl.make_block_ptr(
-        base=K + k_offset,
-        shape=(head_dim, seqlen_k),
-        strides=(stride_kk, stride_kn),
-        offsets=(0, 0),
-        block_shape=(BLOCK_DMODEL, BLOCK_N),
-        order=(0, 1)
-    )
+    K += k_offset
+    kt_ptrs = K + offs_d[:, None] * stride_kk + offs_n[None, :] * stride_kn
+    # K_block_ptr = tl.make_block_ptr(
+    #     base=K,
+    #     shape=(head_dim, seqlen_k),
+    #     strides=(stride_kk, stride_kn),
+    #     offsets=(0, 0),
+    #     block_shape=(BLOCK_DMODEL, BLOCK_N),
+    #     order=(0, 1)
+    # )
     v_offset = off_h * stride_vh + batch_index * stride_vz + cu_seqlens_k_start * stride_vk
-    V_block_ptr = tl.make_block_ptr(
-        base=V + v_offset,
-        shape=(head_dim, seqlen_k),
-        strides=(stride_vn, stride_vk),
-        offsets=(0, 0),
-        block_shape=(BLOCK_DMODEL, BLOCK_N),
-        order=(0, 1)
-    )
+    V += v_offset
+    vt_ptrs = V + offs_d[:, None] * stride_vn + offs_n[None, :] * stride_vk
+    # V_block_ptr = tl.make_block_ptr(
+    #     base=V,
+    #     shape=(head_dim, seqlen_k),
+    #     strides=(stride_vn, stride_vk),
+    #     offsets=(0, 0),
+    #     block_shape=(BLOCK_DMODEL, BLOCK_N),
+    #     order=(0, 1)
+    # )
     do_offset = off_h * stride_oh + batch_index * stride_oz + cu_seqlens_q_start * stride_om
-    DO_block_ptr = tl.make_block_ptr(
-        base=DO + do_offset,
-        shape=(seqlen_q, head_dim),
-        strides=(stride_om, stride_ok),
-        offsets=(start_m, 0),
-        block_shape=(BLOCK_M, BLOCK_DMODEL),
-        order=(1, 0)
-    )
+    DO += do_offset
+    # DO_block_ptr = tl.make_block_ptr(
+    #     base=DO,
+    #     shape=(seqlen_q, head_dim),
+    #     strides=(stride_om, stride_ok),
+    #     offsets=(start_m, 0),
+    #     block_shape=(BLOCK_M, BLOCK_DMODEL),
+    #     order=(1, 0)
+    # )
+    do_ptrs = DO + offs_m[:, None] * stride_om + offs_d[None, :] * stride_ok
+    if start_m + BLOCK_M <= seqlen_q:
+        do = load_fn(do_ptrs, None, ld_offs_d, seqlen_q, head_dim)
+    else:
+        do = load_fn(do_ptrs, offs_m, ld_offs_d, seqlen_q, head_dim)
     # pointer to row-wise quantities in value-like data
     D_ptrs = D + off_zh * max_seqlen_q
     l_ptrs = L + off_zh * max_seqlen_q
@@ -328,17 +357,20 @@ def bwd_kernel_dq(
     # initialize pointers to output
     dq_offset = batch_index * stride_dqz + off_h * stride_dqh + cu_seqlens_q_start * stride_dqm
-    # tl.device_print('batch_index ', batch_index)
-    # tl.device_print('off_h ', off_h)
-    # tl.device_print('cu_seqlens_q_start ', cu_seqlens_q_start)
-    DQ_block_ptr = tl.make_block_ptr(
-        base=DQ + dq_offset,
-        shape=(seqlen_q, head_dim),
-        strides=(stride_dqm, stride_dqk),
-        offsets=(start_m, 0),
-        block_shape=(BLOCK_M, BLOCK_DMODEL),
-        order=(1, 0)
-    )
+    DQ += dq_offset
+    # DQ_block_ptr = tl.make_block_ptr(
+    #     base=DQ,
+    #     shape=(seqlen_q, head_dim),
+    #     strides=(stride_dqm, stride_dqk),
+    #     offsets=(start_m, 0),
+    #     block_shape=(BLOCK_M, BLOCK_DMODEL),
+    #     order=(1, 0)
+    # )
+    dq_ptrs = DQ + offs_m[:, None] * stride_dqm + offs_d[None, :] * stride_dqk
+    if start_m + BLOCK_M <= seqlen_q:
+        dq = load_fn(dq_ptrs, None, ld_offs_d, seqlen_q, head_dim)
+    else:
+        dq = load_fn(dq_ptrs, offs_m, ld_offs_d, seqlen_q, head_dim)
     store_db = True
     if BIAS_TYPE == 0:
         B_block_ptr = 0
@@ -368,9 +400,9 @@ def bwd_kernel_dq(
         tl.static_assert(False, f'Unsupported BIAS_TYPE {BIAS_TYPE}')
     dq = bwd_kernel_dq_db_common(
-        Q_block_ptr, K_block_ptr, V_block_ptr, B_block_ptr,
-        sm_scale, DO_block_ptr,
-        DQ_block_ptr, DB_block_ptr, store_db,
+        q, kt_ptrs, stride_kn, vt_ptrs, stride_vk, B_block_ptr,
+        sm_scale, do,
+        dq, DB_block_ptr, store_db,
         l_ptrs,
         D_ptrs,
         seqlen_q,
         seqlen_k,
@@ -387,10 +419,10 @@ def bwd_kernel_dq(
         ENABLE_DROPOUT,
         PADDED_HEAD,
         BIAS_TYPE)
-    dq_ptrs, dq_masks = mstore2d(dq,
+    mstore2d(dq,
              BLOCK_M,
              BLOCK_DMODEL,
-             o_base=DQ + dq_offset,
+             o_base=DQ,
              o_start_row=start_m,
              o_start_col=0,
              o_rows=seqlen_q,
diff --git a/tritonsrc/masked_load_store.py b/tritonsrc/masked_load_store.py
index 68804f8..7a2d787 100644
--- a/tritonsrc/masked_load_store.py
+++ b/tritonsrc/masked_load_store.py
@@ -5,6 +5,78 @@
 import triton
 import triton.language as tl
 
+# Convenience function to load with optional boundary checks.
+# "First" is the major dim, "second" is the minor dim.
+@triton.jit
+def load_fn(ptrs, offset_first, offset_second, _in_boundary_first, _in_boundary_second):
+    boundary_first = _in_boundary_first
+    boundary_second = _in_boundary_second
+    """
+    # Debugging GPU segfault
+    boundary_first = 0
+    boundary_second = 0
+    mask = (offset_first[:, None] < boundary_first) & \
+           (offset_second[None, :] < boundary_second)
+    return tl.load(ptrs, mask=mask, other=0.0)
+    """
+    if offset_first is not None and offset_second is not None:
+        mask = (offset_first[:, None] < boundary_first) & \
+               (offset_second[None, :] < boundary_second)
+        tensor = tl.load(ptrs, mask=mask, other=0.0)
+    elif offset_first is not None:
+        mask = offset_first[:, None] < boundary_first
+        tensor = tl.load(ptrs, mask=mask, other=0.0)
+    elif offset_second is not None:
+        mask = offset_second[None, :] < boundary_second
+        tensor = tl.load(ptrs, mask=mask, other=0.0)
+    else:
+        tensor = tl.load(ptrs)
+    return tensor
+
+@triton.jit
+def mload1d(
+    REGS : tl.constexpr,
+    i_base,
+    i_start,
+    i_nums,
+):
+    offs = tl.arange(0, REGS) + i_start
+    i_ptrs = i_base + offs
+    # return tl.load(i_base + offs)
+    overflow = i_start + REGS - i_nums
+    # if overflow <= 0:
+    #     return tl.load(i_ptrs)
+    i_ptrs_mask = tl.full([REGS], 1, dtype=tl.int1)
+    i_ptrs_mask = i_ptrs_mask & (offs < i_nums)
+    return tl.load(i_ptrs, mask=i_ptrs_mask, other=0.0)
+
+@triton.jit
+def mload2d(
+    REG_ROWS : tl.constexpr,
+    REG_COLS : tl.constexpr,
+    i_base,
+    i_start_row,
+    i_start_col,
+    i_rows,
+    i_cols,
+    stride_row,
+    stride_col,
+):
+    off_rows = tl.arange(0, REG_ROWS) + i_start_row
+    off_cols = tl.arange(0, REG_COLS) + i_start_col
+    i_ptrs = i_base + off_rows[:, None] * stride_row + off_cols[None, :] * stride_col
+    row_overflow = i_start_row + REG_ROWS - i_rows
+    col_overflow = i_start_col + REG_COLS - i_cols
+    # if row_overflow <= 0 and col_overflow <= 0:
+    # if NOCHECK:
+    #     return tl.load(i_ptrs)
+    i_ptrs_mask = tl.full([REG_ROWS, REG_COLS], 1, dtype=tl.int1)
+    if row_overflow > 0:
+        i_ptrs_mask = i_ptrs_mask & (off_rows[:, None] < i_rows)
+    if col_overflow > 0:
+        i_ptrs_mask = i_ptrs_mask & (off_cols[None, :] < i_cols)
+    return tl.load(i_ptrs, mask=i_ptrs_mask, other=0.0)
+
 @triton.jit
 def mstore2d(
     registers,
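As a sanity check of `load_fn`'s calling convention (an offset of `None` skips the bounds check on that dim, and the two boundary arguments follow the same first/second order as the offsets), something like the following hypothetical probe kernel could be used. It is illustrative only, and assumes triton, a GPU, and this diff applied so that `load_fn` is importable:

```python
import torch
import triton
import triton.language as tl
from masked_load_store import load_fn

@triton.jit
def probe(src, dst, n_rows, n_cols, stride_row,
          ROWS: tl.constexpr, COLS: tl.constexpr):
    offs_r = tl.arange(0, ROWS)
    offs_c = tl.arange(0, COLS)
    ptrs = src + offs_r[:, None] * stride_row + offs_c[None, :]
    # Bounds-check both dims; out-of-bounds lanes come back as the 0.0 padding.
    tile = load_fn(ptrs, offs_r, offs_c, n_rows, n_cols)
    tl.store(dst + offs_r[:, None] * COLS + offs_c[None, :], tile)

x = torch.arange(6, dtype=torch.float32, device='cuda').reshape(2, 3)
y = torch.empty(4, 4, device='cuda')
probe[(1,)](x, y, 2, 3, x.stride(0), ROWS=4, COLS=4)
assert torch.equal(y[:2, :3], x)
assert y[2:].eq(0).all() and y[:, 3:].eq(0).all()
```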