Convert NaN created by 0.0 (from sm_scale) * -inf (from masking) to zero
This fixes #47
xinyazhang committed Oct 7, 2024
1 parent 88e7d79 commit fb8cb14
Showing 3 changed files with 18 additions and 0 deletions.
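
Why the fix works: IEEE 754 defines 0.0 * ±inf as NaN, so when sm_scale (and hence qk_scale) is 0.0, the -inf scores written by masking become NaN instead of vanishing, and exp2 propagates the NaN into the softmax numerator. A minimal plain-Python sketch of the failure mode and of the scrub this commit adds (illustrative only, not kernel code):

import math

qk_scale = 0.0                     # sm_scale == 0.0 implies qk_scale == 0.0
qk = float('-inf')                 # score of a masked-out position

x = qk_scale * qk                  # IEEE 754: 0.0 * -inf == nan, not 0.0
p = 2.0 ** x                       # exp2(nan) == nan, poisoning the row sum

p = 0.0 if math.isnan(p) else p    # the scrub this commit adds via tl.where
print(p)                           # 0.0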
tritonsrc/bwd_inner_dk_dv.py (5 additions, 0 deletions)
@@ -6,6 +6,7 @@
 import triton.language as tl
 from dropout import dropout_mask, dropout_rng, dropout_offsets
 from masked_load_store import load_fn
+from triton.language.extra import libdevice
 
 # Helper function, but not always usable due to compiler bugs (esp. used with tl.trans)
 @triton.jit
@@ -138,6 +139,10 @@ def bwd_inner_dk_dv(
                      mask=d_lse_ptrs_mask[:,None],
                      other=d_lse_padding[:, None])
         p = tl.math.exp2(qk_scale * qk - l_i) # (BLOCK_M, BLOCK_N)
+
+        if not FULL_BLOCKS or CAUSAL:
+            if qk_scale == 0.0:
+                p = tl.where(libdevice.isnan(p), 0.0, p)
         # -- compute dv ----
         if ENABLE_DROPOUT:
             philox_offset = batch_philox_offset + start_q * max_seqlen_k + start_k
tritonsrc/bwd_inner_dq.py (6 additions, 0 deletions)
@@ -6,6 +6,7 @@
 import triton.language as tl
 from dropout import dropout_mask, dropout_rng, dropout_offsets
 from masked_load_store import load_fn
+from triton.language.extra import libdevice
 
 # Helper function, but not always usable due to compiler bugs (esp. used with tl.trans)
 @triton.jit
@@ -111,6 +112,11 @@ def bwd_inner_dq(
         else:
             tl.static_assert(False, f'Unsupported BIAS_TYPE {BIAS_TYPE}')
         p = tl.math.exp2(qk_scale * qk - l_i[:, None])
+
+        if not FULL_BLOCKS or CAUSAL:
+            if qk_scale == 0.0:
+                p = tl.where(libdevice.isnan(p), 0.0, p)
+
         # compute dp = dot(v, do)
         dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         dp += dot(BLOCK_M, BLOCK_DMODEL, BLOCK_DMODEL, do, vt)
tritonsrc/fwd_kernel_inner.py (7 additions, 0 deletions)
@@ -5,6 +5,7 @@
 import triton.language as tl
 from dropout import dropout_mask, dropout_rng, dropout_offsets
 from masked_load_store import load_fn, mstore2d
+from triton.language.extra import libdevice
 
 @triton.jit
 def attn_fwd_inner(
@@ -94,8 +95,14 @@ def attn_fwd_inner(
 
         # softmax
         m_ij = tl.maximum(m_i, qk_scale * tl.max(qk, 1))
+        # FIXME: when sm_scale = 0.0 and MASK_STEPS/CAUSAL = True
+        # qk * qk_scale = nan
         p = tl.math.exp2(qk * qk_scale - m_ij[:, None])
+
+        if MASK_STEPS or CAUSAL:
+            if qk_scale == 0.0:
+                p = tl.where(libdevice.isnan(p), 0.0, p)
 
         # CAVEAT: Must update l_ij before applying dropout
         l_ij = tl.sum(p, 1)
         if ENABLE_DROPOUT:
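
The same guarded scrub appears in all three kernels: it runs only when masking can introduce -inf (non-full or causal blocks) and qk_scale is exactly 0.0, so it is skipped for typical scales. A vectorized NumPy reference of the pattern, using hypothetical toy inputs (the real kernels operate on Triton tiles):

import numpy as np

qk = np.array([[1.0, -np.inf],
               [2.0, 3.0]], dtype=np.float32)  # -inf marks masked scores
qk_scale = np.float32(0.0)                     # degenerate sm_scale

with np.errstate(invalid='ignore'):
    p = np.exp2(qk_scale * qk)                 # nan wherever qk == -inf
if qk_scale == 0.0:                            # mirrors the kernels' guard
    p = np.where(np.isnan(p), 0.0, p)          # NaN -> 0.0, as tl.where does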
