Preview 2 of 0.7b
Add the backward kernel improvements that were missed in Preview 1.
xinyazhang committed Aug 4, 2024
1 parent 98395ad commit 7aa3e14
Showing 3 changed files with 281 additions and 167 deletions.
138 changes: 74 additions & 64 deletions tritonsrc/bwd_kernel_common.py
@@ -5,6 +5,7 @@
import triton
import triton.language as tl
from dropout import dropout_mask, dropout_rng, dropout_offsets
from masked_load_store import load_fn
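`load_fn` comes from `masked_load_store.py`, which is not part of this diff. Judging purely from the call sites below — `load_fn(ptrs, offs_rows, offs_cols, rows_bound, cols_bound)`, with `None` meaning "no masking along that axis" — a minimal sketch of such a helper could look like this (an assumption, not the committed implementation):

```python
import triton
import triton.language as tl

@triton.jit
def load_fn(ptrs, offset_first, offset_second, boundary_first, boundary_second):
    # Masked 2D load: each non-None offset axis is bounded by its boundary;
    # out-of-bounds lanes are filled with zero, matching padding_option="zero".
    if offset_first is not None and offset_second is not None:
        mask = (offset_first[:, None] < boundary_first) & \
               (offset_second[None, :] < boundary_second)
        return tl.load(ptrs, mask=mask, other=0.0)
    elif offset_first is not None:
        return tl.load(ptrs, mask=offset_first[:, None] < boundary_first, other=0.0)
    elif offset_second is not None:
        return tl.load(ptrs, mask=offset_second[None, :] < boundary_second, other=0.0)
    return tl.load(ptrs)
```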

# Helper function, but not always usable due to compiler bugs (esp. when used with tl.trans)
@triton.jit
@@ -16,9 +17,8 @@ def dot(BLOCK_M : tl.constexpr, QDIM : tl.constexpr, KDIM : tl.constexpr, q, k):
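The body of `dot` is collapsed in this view. Given the `BLOCK_M == 1` special cases later in the file, it plausibly dispatches between `tl.dot` and a scalar reduction — a sketch for orientation only, not necessarily the committed code:

```python
@triton.jit
def dot(BLOCK_M : tl.constexpr, QDIM : tl.constexpr, KDIM : tl.constexpr, q, k):
    # tl.dot requires block dimensions >= 16; a single-row block falls back
    # to an elementwise product followed by a reduction.
    if BLOCK_M == 1:
        return tl.sum(tl.view(q, [QDIM]) * tl.view(k, [KDIM]))
    else:
        return tl.dot(q, k)
```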

@triton.jit
def bwd_kernel_dk_dv_common(
Q_block_ptr, KT_block_ptr, VT_block_ptr, B_block_ptr,
sm_scale, DO_block_ptr,
DK_block_ptr, DV_block_ptr,
q_ptrs, q_stride, kt, vt, B_block_ptr,
sm_scale, do_ptrs, do_stride,
l_ptrs,
D_ptrs,
seqlen_q,
@@ -40,25 +40,20 @@ def bwd_kernel_dk_dv_common(
# initialize offsets
offs_m = start_m + tl.arange(0, BLOCK_N)
offs_n = tl.arange(0, BLOCK_M)
ld_offs_d = None if not PADDED_HEAD else tl.arange(0, BLOCK_DMODEL)

lo = (start_m // BLOCK_M) * BLOCK_M if CAUSAL else 0
hi = seqlen_q
Q_block_ptr = tl.advance(Q_block_ptr, (lo, 0))
DO_block_ptr = tl.advance(DO_block_ptr, (lo, 0))
# Q_block_ptr = tl.advance(Q_block_ptr, (lo, 0))
# DO_block_ptr = tl.advance(DO_block_ptr, (lo, 0))
q_ptrs += lo * q_stride
do_ptrs += lo * do_stride
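With block pointers retired, `q_ptrs`/`do_ptrs` are plain grids of element pointers that advance by a row stride. A hypothetical caller-side construction (the helper name and the contiguous-head-dim assumption are mine; the actual caller is not shown in this diff):

```python
import triton
import triton.language as tl

@triton.jit
def make_row_ptrs(base, stride_row, BLOCK_ROWS: tl.constexpr, BLOCK_D: tl.constexpr):
    # One pointer per (sequence row, head-dim lane); assumes the head
    # dimension is contiguous (innermost stride of 1).
    rows = tl.arange(0, BLOCK_ROWS)
    cols = tl.arange(0, BLOCK_D)
    return base + rows[:, None] * stride_row + cols[None, :]
```

Adding `lo * q_stride` (or `BLOCK_M * q_stride` per iteration) then shifts the whole grid down by that many rows at once.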

qk_scale = sm_scale * 1.44269504089
# load k and v: they will stay in SRAM throughout
# (BLOCK_DMODEL, BLOCK_N)
if PADDED_HEAD:
kt = tl.load(KT_block_ptr, boundary_check=(1,0), padding_option="zero")
else:
kt = tl.load(KT_block_ptr, boundary_check=(1,), padding_option="zero")
kt = (kt * qk_scale).to(KT_block_ptr.type.element_ty)
kt = (kt * qk_scale).to(kt.type.element_ty)
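The constant 1.44269504089 is log2(e). Folding it into `kt` once — `kt` stays resident in SRAM — lets the softmax recomputation use the cheaper `exp2` instead of `exp`. A host-side check of the identity this relies on (illustrative only):

```python
# exp(x * sm_scale) == 2 ** (x * sm_scale * log2(e))
import math

x, sm_scale = 0.73, 0.125
assert math.isclose(math.exp(x * sm_scale),
                    2.0 ** (x * sm_scale * 1.44269504089), rel_tol=1e-9)
```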
# (BLOCK_DMODEL, BLOCK_N)
if PADDED_HEAD:
vt = tl.load(VT_block_ptr, boundary_check=(1,0), padding_option="zero")
else:
vt = tl.load(VT_block_ptr, boundary_check=(1,), padding_option="zero")
dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
'''
@@ -80,23 +75,30 @@ def bwd_kernel_dk_dv_common(
'''
# loop over q (seqlen_q, dhead), do (seqlen_q, d_head)
for start_n in range(lo, hi, BLOCK_M):
if lo + BLOCK_M > seqlen_q:
if start_n + BLOCK_M > seqlen_q:
q_padded = True
else:
q_padded = False
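Note the fix from `lo + BLOCK_M` to `start_n + BLOCK_M`: the old test was loop-invariant, whereas only the final tile along `seqlen_q` can actually be partial. A quick illustration:

```python
# Only the last tile can overrun seqlen_q, so the padding check must depend
# on start_n, not on the loop-invariant lo.
seqlen_q, BLOCK_M = 100, 64
padded = [start_n + BLOCK_M > seqlen_q
          for start_n in range(0, seqlen_q, BLOCK_M)]
assert padded == [False, True]  # tiles start at 0 and 64; only the second is partial
```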
# TODO: Unify the names; the usage of m/n is very confusing
offs_m_curr = offs_n[:, None] + start_n # (BLOCK_M, 1)
# -- load q, do --
# TODO: It is more optimal to do OOB check only in the last iter.
# (BLOCK_M, BLOCK_DMODEL), offs = (BLOCK_M * iter, 0) = (start_n, 0)
if PADDED_HEAD:
q = tl.load(Q_block_ptr, boundary_check=(0,1), padding_option="zero")
#
# This common function can be further split into regular and
# non-regular version, determined by tl.constexpr, just like the fwd kernel.

# q = tl.load(Q_block_ptr)
if q_padded:
q = load_fn(q_ptrs, offs_n + start_n, ld_offs_d, seqlen_q, head_dim)
else:
q = tl.load(Q_block_ptr, boundary_check=(0,), padding_option="zero")
# do: (BLOCK_M, BLOCK_DMODEL)
if PADDED_HEAD:
do = tl.load(DO_block_ptr, boundary_check=(0,1), padding_option="zero")
q = load_fn(q_ptrs, None, ld_offs_d, seqlen_q, head_dim)
# do = tl.load(DO_block_ptr)
# TODO: pre_load_do
if q_padded:
do = load_fn(do_ptrs, offs_n + start_n, ld_offs_d, seqlen_q, head_dim)
else:
do = tl.load(DO_block_ptr, boundary_check=(0,), padding_option="zero")
do = load_fn(do_ptrs, None, ld_offs_d, seqlen_q, head_dim)
# -- compute qk ----
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
# TODO: These two checks can be optimized to occur on the last iter.
@@ -111,6 +113,7 @@
pass
elif BIAS_TYPE == 1:
# FIXME: do boundary_check correctly
# TODO: check that q_padded is calculated correctly; the condition should be start_n + BLOCK_M
"""
if q_padded and k_padded: # CAVEAT: using "or" disables the partial boundary_check branches
bias = tl.load(B_block_ptr, boundary_check=(0,1), padding_option="zero")
@@ -144,12 +147,12 @@
keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, max_seqlen_k)
# CAVEAT: do NOT update p, ds needs the original p
if BLOCK_M == 1:
dv += tl.where(keep, p / (1 - dropout_p), 0.0).to(Q_block_ptr.dtype.element_ty) * do
dv += tl.where(keep, p / (1 - dropout_p), 0.0).to(q_ptrs.dtype.element_ty) * do
else:
dv += tl.dot(tl.trans(tl.where(keep, p / (1 - dropout_p), 0.0)).to(Q_block_ptr.dtype.element_ty), do)
dv += tl.dot(tl.trans(tl.where(keep, p / (1 - dropout_p), 0.0)).to(q_ptrs.dtype.element_ty), do)
else:
if BLOCK_M == 1:
dv += p.to(Q_block_ptr.dtype.element_ty) * do
dv += p.to(q_ptrs.dtype.element_ty) * do
else:
dv += tl.dot(tl.trans(p.to(do.dtype)), do)
dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
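The `dv` update applies the same inverted-dropout scaling as the forward pass: survivors are scaled by `1 / (1 - dropout_p)` so the expected value is preserved, and dropped lanes contribute nothing. A NumPy sketch of the assumed semantics (not kernel code):

```python
import numpy as np

rng = np.random.default_rng(0)
p_drop = 0.1
x = np.ones((1024, 1024))
keep = rng.random(x.shape) >= p_drop            # analogous to dropout_mask(...)
y = np.where(keep, x / (1.0 - p_drop), 0.0)     # inverted-dropout scaling
assert abs(y.mean() - 1.0) < 1e-2               # preserved in expectation
```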
@@ -163,22 +166,24 @@
ds = p * (dp - Di) # (BLOCK_M, BLOCK_N)
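`ds = p * (dp - Di)` is the standard softmax backward; the precomputed `Di = rowsum(dO ∘ O)` equals `rowsum(dp ∘ p)`, which is exactly the inner sum the softmax Jacobian-vector product needs. A small numeric check of the identity:

```python
# For y = softmax(x): dx = y * (dy - sum(dy * y)), i.e. the ds formula.
import numpy as np

x  = np.array([0.3, -1.2, 2.0])
dy = np.array([0.1,  0.4, -0.2])
y = np.exp(x) / np.exp(x).sum()
J = np.diag(y) - np.outer(y, y)                 # softmax Jacobian
assert np.allclose(J @ dy, y * (dy - (dy * y).sum()))
```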
# compute dk
if BLOCK_M == 1:
dk += ds.to(Q_block_ptr.dtype.element_ty) * q
dk += ds.to(q_ptrs.dtype.element_ty) * q
else:
# ds.shape = (BLOCK_M, BLOCK_N), q.shape = (BLOCK_M, BLOCK_DMODEL)
dk += tl.dot(tl.trans(ds.to(Q_block_ptr.dtype.element_ty)), q) # (BLOCK_N, BLOCK_DMODEL)
# update pointers
Q_block_ptr = tl.advance(Q_block_ptr, (BLOCK_M, 0))
DO_block_ptr = tl.advance(DO_block_ptr, (BLOCK_M, 0)) # Debug DO accessing problems
dk += tl.dot(tl.trans(ds.to(q_ptrs.dtype.element_ty)), q) # (BLOCK_N, BLOCK_DMODEL)
# update pointers (block_ptr code intentionally left as a comment)
# Q_block_ptr = tl.advance(Q_block_ptr, (BLOCK_M, 0))
# DO_block_ptr = tl.advance(DO_block_ptr, (BLOCK_M, 0)) # Debug DO accessing problems
q_ptrs += q_stride * BLOCK_M
do_ptrs += do_stride * BLOCK_M
if BIAS_TYPE == 1:
B_block_ptr = tl.advance(B_block_ptr, (BLOCK_M, 0))
return (dk * sm_scale).to(DK_block_ptr.type.element_ty), dv.to(DV_block_ptr.type.element_ty)
return (dk * sm_scale).to(kt.type.element_ty), dv.to(vt.type.element_ty)
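For orientation, the quantities returned here match the usual attention backward (standard notation; the diff itself does not spell this out):

```math
dV = P^{\top} dO, \qquad dP = dO\, V^{\top}, \qquad dS = P \circ (dP - D), \qquad dK = \mathrm{sm\_scale}\cdot dS^{\top} Q
```

with `D = rowsum(dO ∘ O)`. By linearity, multiplying `dk` by `sm_scale` once on exit is equivalent to, and cheaper than, scaling every per-tile update.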

@triton.jit
def bwd_kernel_dq_db_common(
Q_block_ptr, K_block_ptr, V_block_ptr, B_block_ptr,
sm_scale, DO_block_ptr,
DQ_block_ptr, DB_block_ptr, store_db,
q, kt_ptrs, k_stride, vt_ptrs, v_stride, B_block_ptr,
sm_scale, do,
dq, DB_block_ptr, store_db,
l_ptrs,
D_ptrs,
seqlen_q,
@@ -197,37 +202,34 @@
PADDED_HEAD: tl.constexpr,
BIAS_TYPE: tl.constexpr,
):
if start_m + BLOCK_N > seqlen_k:
k_padded = True
else:
k_padded = False
# initialize offsets
offs_m = start_m + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
ld_offs_d = None if not PADDED_HEAD else tl.arange(0, BLOCK_DMODEL)

# Check for OOB accesses on D and LSE
overflow_size_q = start_m + BLOCK_M - seqlen_q
boundary = tl.full((BLOCK_M, ), BLOCK_M - overflow_size_q, dtype=tl.int32)
d_lse_ptrs_mask = boundary > tl.arange(0, BLOCK_M)
d_lse_padding = tl.full((BLOCK_M, ), 0, dtype=tl.float32)
Di = tl.load(D_ptrs + offs_m, mask=d_lse_ptrs_mask, other=d_lse_padding)
l_i = tl.load(l_ptrs + offs_m, mask=d_lse_ptrs_mask, other=d_lse_padding)
if overflow_size_q > 0:
boundary = tl.full((BLOCK_M, ), BLOCK_M - overflow_size_q, dtype=tl.int32)
d_lse_ptrs_mask = boundary > tl.arange(0, BLOCK_M)
d_lse_padding = tl.full((BLOCK_M, ), 0, dtype=tl.float32)
Di = tl.load(D_ptrs + offs_m, mask=d_lse_ptrs_mask, other=d_lse_padding)
l_i = tl.load(l_ptrs + offs_m, mask=d_lse_ptrs_mask, other=d_lse_padding)
else:
Di = tl.load(D_ptrs + offs_m)
l_i = tl.load(l_ptrs + offs_m)
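The new branch builds the row mask only when the block actually overruns `seqlen_q`; full blocks take unmasked loads. Concretely:

```python
# Rows at or past the sequence boundary are masked out of the D/LSE loads.
BLOCK_M, start_m, seqlen_q = 16, 64, 70
overflow = start_m + BLOCK_M - seqlen_q          # 10 rows past the end
mask = [i < BLOCK_M - overflow for i in range(BLOCK_M)]
assert sum(mask) == seqlen_q - start_m           # only 6 valid rows remain
```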

dq = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# loop over k, v
lo = 0
hi = min(start_m + BLOCK_M, seqlen_k) if CAUSAL else seqlen_k
if BIAS_TYPE == 1:
B_block_ptr = tl.advance(B_block_ptr, (lo, 0))

qk_scale = sm_scale * 1.44269504089
# load q and do: they will stay in SRAM throughout
if PADDED_HEAD:
q = tl.load(Q_block_ptr, boundary_check=(0,1), padding_option="zero")
else:
q = tl.load(Q_block_ptr, boundary_check=(0,), padding_option="zero")
q = (q * qk_scale).to(Q_block_ptr.type.element_ty)
if PADDED_HEAD:
do = tl.load(DO_block_ptr, boundary_check=(0,1), padding_option="zero")
else:
do = tl.load(DO_block_ptr, boundary_check=(0,), padding_option="zero")
q = (q * qk_scale).to(q.type.element_ty)

'''
K1 K2 (d)V dO
Q1 qk11 qk12 (d)v1 dO1
@@ -244,12 +246,18 @@
k_padded = False
# -- load k, v --
# shape = (BLOCK_DMODEL, BLOCK_N), offs = (0, BLOCK_N * iter) = (0, start_n)
if PADDED_HEAD:
kt = tl.load(K_block_ptr, boundary_check=(1,0), padding_option="zero")
vt = tl.load(V_block_ptr, boundary_check=(1,0), padding_option="zero")
# kt = tl.load(K_block_ptr)
# vt = tl.load(V_block_ptr)
if k_padded:
kt = load_fn(kt_ptrs, ld_offs_d, offs_n + start_n, head_dim, seqlen_k)
else:
kt = load_fn(kt_ptrs, ld_offs_d, None, head_dim, seqlen_k)

# TODO: pre_load_vt
if k_padded:
vt = load_fn(vt_ptrs, ld_offs_d, offs_n + start_n, head_dim, seqlen_k)
else:
kt = tl.load(K_block_ptr, boundary_check=(1,), padding_option="zero")
vt = tl.load(V_block_ptr, boundary_check=(1,), padding_option="zero")
vt = load_fn(vt_ptrs, ld_offs_d, None, head_dim, seqlen_k)
# -- compute qk ----
# q.offs = (start_m, 0), k.offs = (0, start_n)
qk = dot(BLOCK_M, BLOCK_DMODEL, BLOCK_DMODEL, q, kt)
@@ -292,18 +300,20 @@
# compute dq. Unfortunately we cannot avoid transpose here as this loop
# uses k both normal and transpose.
if BLOCK_M == 1:
dq += tl.view(kt, [BLOCK_DMODEL]) * ds.to(Q_block_ptr.type.element_ty)
dq += tl.view(kt, [BLOCK_DMODEL]) * ds.to(q.type.element_ty)
else:
# ds.shape = (BLOCK_M, BLOCK_N), kt.shape = (BLOCK_DMODEL, BLOCK_N)
dq += tl.dot(ds.to(Q_block_ptr.type.element_ty), tl.trans(kt)) # (BLOCK_M, BLOCK_DMODEL)
dq += tl.dot(ds.to(q.type.element_ty), tl.trans(kt)) # (BLOCK_M, BLOCK_DMODEL)
if BIAS_TYPE == 1:
if store_db:
tl.store(DB_block_ptr, ds.to(DB_block_ptr.type.element_ty), boundary_check=(0,1))
# update pointers
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
V_block_ptr = tl.advance(V_block_ptr, (0, BLOCK_N))
# Keep the block_ptr code as a comment
# K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
# V_block_ptr = tl.advance(V_block_ptr, (0, BLOCK_N))
kt_ptrs += BLOCK_N * k_stride
vt_ptrs += BLOCK_N * v_stride
if BIAS_TYPE == 1:
B_block_ptr = tl.advance(B_block_ptr, (0, BLOCK_N))
DB_block_ptr = tl.advance(DB_block_ptr, (0, BLOCK_N))
return (dq * sm_scale).to(DQ_block_ptr.type.element_ty)
# tl.store(DQ_block_ptr, (dq * sm_scale).to(DQ_block_ptr.type.element_ty), boundary_check=(0,1))
return (dq * sm_scale).to(dq.type.element_ty)
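Both kernels now return their accumulators instead of storing through `DK/DV/DQ` block pointers, so the (unshown) callers presumably handle the masked writes. A hypothetical counterpart to `load_fn` — the name and signature are assumptions, since `masked_load_store.py` is not in this diff:

```python
import triton
import triton.language as tl

@triton.jit
def store_rows_fn(ptrs, acc, offs_rows, rows_bound):
    # Hypothetical masked 2D store: drop rows past the sequence boundary
    # when writing dk/dv/dq back to global memory.
    tl.store(ptrs, acc, mask=offs_rows[:, None] < rows_bound)
```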