diff --git a/tritonsrc/bwd_inner_dk_dv.py b/tritonsrc/bwd_inner_dk_dv.py
index 8d4b0bb..acc26f3 100644
--- a/tritonsrc/bwd_inner_dk_dv.py
+++ b/tritonsrc/bwd_inner_dk_dv.py
@@ -6,6 +6,7 @@
 import triton.language as tl
 from dropout import dropout_mask, dropout_rng, dropout_offsets
 from masked_load_store import load_fn
+from triton.language.extra import libdevice
 
 # Helper function, but not always usable due to compiler bugs (esp. used with tl.trans)
 @triton.jit
@@ -138,6 +139,10 @@ def bwd_inner_dk_dv(
                           mask=d_lse_ptrs_mask[:,None],
                           other=d_lse_padding[:, None])
         p = tl.math.exp2(qk_scale * qk - l_i) # (BLOCK_M, BLOCK_N)
+
+        if not FULL_BLOCKS or CAUSAL:
+            if qk_scale == 0.0:
+                p = tl.where(libdevice.isnan(p), 0.0, p)
         # -- compute dv ----
         if ENABLE_DROPOUT:
             philox_offset = batch_philox_offset + start_q * max_seqlen_k + start_k
diff --git a/tritonsrc/bwd_inner_dq.py b/tritonsrc/bwd_inner_dq.py
index 4553de0..845cce1 100644
--- a/tritonsrc/bwd_inner_dq.py
+++ b/tritonsrc/bwd_inner_dq.py
@@ -6,6 +6,7 @@
 import triton.language as tl
 from dropout import dropout_mask, dropout_rng, dropout_offsets
 from masked_load_store import load_fn
+from triton.language.extra import libdevice
 
 # Helper function, but not always usable due to compiler bugs (esp. used with tl.trans)
 @triton.jit
@@ -111,6 +112,11 @@ def bwd_inner_dq(
         else:
             tl.static_assert(False, f'Unsupported BIAS_TYPE {BIAS_TYPE}')
         p = tl.math.exp2(qk_scale * qk - l_i[:, None])
+
+        if not FULL_BLOCKS or CAUSAL:
+            if qk_scale == 0.0:
+                p = tl.where(libdevice.isnan(p), 0.0, p)
+
         # compute dp = dot(v, do)
         dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         dp += dot(BLOCK_M, BLOCK_DMODEL, BLOCK_DMODEL, do, vt)
diff --git a/tritonsrc/fwd_kernel_inner.py b/tritonsrc/fwd_kernel_inner.py
index d8632f4..100d008 100644
--- a/tritonsrc/fwd_kernel_inner.py
+++ b/tritonsrc/fwd_kernel_inner.py
@@ -5,6 +5,7 @@
 import triton.language as tl
 from dropout import dropout_mask, dropout_rng, dropout_offsets
 from masked_load_store import load_fn, mstore2d
+from triton.language.extra import libdevice
 
 @triton.jit
 def attn_fwd_inner(
@@ -94,8 +95,14 @@ def attn_fwd_inner(
 
         # softmax
         m_ij = tl.maximum(m_i, qk_scale * tl.max(qk, 1))
+        # FIXME: when sm_scale = 0.0 and MASK_STEPS/CAUSAL = True
+        #        qk * qk_scale = nan
         p = tl.math.exp2(qk * qk_scale - m_ij[:, None])
 
+        if MASK_STEPS or CAUSAL:
+            if qk_scale == 0.0:
+                p = tl.where(libdevice.isnan(p), 0.0, p)
+
         # CAVEAT: Must update l_ij before applying dropout
         l_ij = tl.sum(p, 1)
         if ENABLE_DROPOUT:
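
Commentary on the fix (not part of the patch itself): the new guards only run when qk_scale == 0.0 because that is the one case where masking turns p into NaN rather than 0. Assuming the masked-out positions of qk hold -inf, as the FIXME in fwd_kernel_inner.py suggests, 0.0 * -inf evaluates to NaN and tl.math.exp2 propagates it, so the NaNs would otherwise poison l_ij, dv, and dp downstream. The standalone NumPy sketch below (illustrative values, no Triton required) reproduces the effect and the tl.where/isnan remedy:

    import numpy as np

    qk = np.array([1.5, -0.25, -np.inf, -np.inf])   # last two positions masked out with -inf
    qk_scale = 0.0                                   # the sm_scale == 0.0 corner case

    with np.errstate(invalid="ignore"):
        scaled = qk_scale * qk                       # 0.0 * -inf -> nan at masked positions
    m = qk_scale * qk.max()                          # row maximum of the scaled scores (0.0 here)
    p = np.exp2(scaled - m)                          # nan propagates through exp2

    print(p)                                         # [ 1.  1. nan nan]

    # NumPy analogue of the added guard: p = tl.where(libdevice.isnan(p), 0.0, p)
    p = np.where(np.isnan(p), 0.0, p)
    print(p)                                         # [1. 1. 0. 0.]

With a non-zero positive sm_scale the scaled score stays -inf and exp2 already returns 0 for masked positions, which is why the cleanup is gated on qk_scale == 0.0 and on the branches (CAUSAL, MASK_STEPS, or partial blocks) where -inf masking can occur, keeping the extra tl.where off the common path.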