Support for var-len + paged KV was added in #831, and I confirmed that it works when `num_splits = 1`. But if I force the kernel to use `num_splits > 1`, I get incorrect results.

Looking at the code, I suspect var-len + paged KV + split-KV is not currently supported. First, the storage representation of the partial reductions is very wasteful in the var-len case:
https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/flash_api.cpp#L316
Second, the way the output pointer is indexed in `combine_attn_seqk_parallel` seems to assume the fixed query-length format (`batch_size x seqlen_q x num_heads x head_size`). Shouldn't we use `binfo.q_offset` there?

https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/flash_fwd_kernel.h#L1273-L1274
@tridao @sgrigory