Clearing only V is unsafe, because under IEEE 754 NaN propagates: NaN * 0 = NaN.
If shared memory (SMEM) happens to hold NaN values left over from earlier use, the output O will also contain NaN values.
I ran into this issue during my saturation test.
The K load in https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/flash_fwd_kernel.h#L267-L269
should be modified to also clear out-of-bounds rows:
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
    gmem_tiled_copy_QKV, tKgK(_, _, _, n_block), tKsK, tKVcKV, tKVpKV,
    binfo.actual_seqlen_k - n_block * kBlockN);