Extend FlashAttention Prefill with KV cache #318
base: sycl-develop
Conversation
Great job, Min Jean!
Refer to sglang/test/srt/test_triton_attention_kernels.py.
@sunjiweiswift Thanks for the good suggestion. This feature takes non-contiguous input for the cached KV, and the major part of the code is the same.
Yes, I agree with you @pengzhao-intel. We need to keep the example as simple as possible and leave the rest in sglang. If there is any performance regression there, we would be able to offer help.
Co-authored-by: Mehdi Goli <[email protected]>
Force-pushed from 770669a to 72edfc0
@min-jean-cho I added the prefetch for the cached version, and I also fixed some indices/strides.
@mehdi-goli, looks good. Thanks for the update!
Hi @mehdi-goli, any further comments on this? Thanks.
@mehdi-goli @muhammad-tanvir-1211 please help approve and merge this PR; we have further work based on this :)
LGTM! Thanks for the contribution!
applications/flash_attention_v2/collective/xe_flash_attn_mma.hpp
Co-authored-by: Tadej Ciglarič <[email protected]>
int offset_k_cache = num_heads_kv * head_size_qk * seq_len_kv_cache;
int offset_v_cache = num_heads_kv * head_size_vo * seq_len_kv_cache;
Do we consider the cached key-value pairs to be the same across all batches? My understanding is that each batch would have its own seq_len for the cached keys and values, which would mean that seq_len_kv_cache would also be of Variable Length type (same as seq_len_qo and seq_len_kv). This code could potentially give out-of-bounds access because it is missing a multiplication with l_coord (if we want to keep seq_len_kv_cache fixed length), or a multiplication with kv_cache_cumulative_length[l_coord] (if we want to change the type to Variable Length). See the sketch below.
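For clarity, here is a minimal standalone sketch (plain C++, not the PR's kernel code) of the two offset schemes described in this comment. The names l_coord and kv_cache_cumulative_length follow the comment above; the concrete sizes are made up for illustration.

```cpp
#include <cstdio>
#include <vector>

int main() {
  int num_heads_kv = 8, head_size_qk = 64, head_size_vo = 64;
  int l_coord = 2;  // batch index of the current work-group (hypothetical value)

  // (a) Fixed-length cache: every batch holds seq_len_kv_cache cached tokens,
  //     so the batch offset scales the per-batch cache size by l_coord.
  int seq_len_kv_cache = 128;
  int offset_k_cache = l_coord * num_heads_kv * head_size_qk * seq_len_kv_cache;
  int offset_v_cache = l_coord * num_heads_kv * head_size_vo * seq_len_kv_cache;
  std::printf("fixed-length:    k=%d v=%d\n", offset_k_cache, offset_v_cache);

  // (b) Variable-length cache: each batch has its own cached length, so the
  //     offset uses the cumulative cached length up to this batch instead.
  std::vector<int> kv_cache_cumulative_length = {0, 100, 228, 356};
  int offset_k_cache_var =
      num_heads_kv * head_size_qk * kv_cache_cumulative_length[l_coord];
  int offset_v_cache_var =
      num_heads_kv * head_size_vo * kv_cache_cumulative_length[l_coord];
  std::printf("variable-length: k=%d v=%d\n", offset_k_cache_var, offset_v_cache_var);
  return 0;
}
```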
Moved to #331. As discussed offline, we are separating the without-KV-cache and with-KV-cache cases into separate pipelines. I created a new PR rather than updating this one to keep the difference clear.
This PR extends the FlashAttention prefill kernel with cached KV in addition to the current KV (blue box in the figure below). Both causal and non-causal attention are supported.
