Add flash_attn_varlen_func_with_kvcache
#685
Conversation
1. Add varlen support to the MHA layer
2. Add varlen support to ApplyRotaryEmbQKV_ / ApplyRotaryEmbKV_, with test code (see the sketch below)
3. Fix spelling mistakes
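To illustrate what varlen support for the rotary embedding (item 2) entails, here is a minimal, unfused PyTorch sketch, not the PR's actual kernel: with packed inputs, each token's rotary position has to be recovered from cu_seqlens rather than from a dense (batch, seqlen) layout. The name apply_rotary_packed and the full-head-dim cos/sin tables are illustrative assumptions.

```python
import torch

def rotate_half(x):
    # Rotate the two halves of the head dimension (GPT-NeoX-style rotary).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_packed(q, cu_seqlens, cos, sin):
    """Illustrative (unfused) rotary embedding for packed varlen inputs.

    q:          (total_tokens, nheads, headdim), sequences concatenated
    cu_seqlens: (batch + 1,) cumulative lengths, e.g. [0, 1, 2, 5, 6]
    cos, sin:   (max_seqlen, headdim) tables, tiled over both halves
    """
    seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
    # Each token's position within its own sequence, e.g. [0, 0, 0, 1, 2, 0].
    positions = torch.cat([torch.arange(n, device=q.device) for n in seqlens])
    cos_t = cos[positions].unsqueeze(1)  # (total_tokens, 1, headdim): broadcast over heads
    sin_t = sin[positions].unsqueeze(1)
    return q * cos_t + rotate_half(q) * sin_t
```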
@garrett4wade: @shcho1118 Hi cho, I don't think this PR can help, because I forcibly set num_splits=1 for the varlen func with kv cache, although this could potentially be fixed in the future.
@garrett4wade @shcho1118 FYI, I'm adding split-kv to
Implementing a new function called flash_attn_varlen_func_with_kvcache, which behaves similarly to flash_attn_with_kvcache but allows variable-length q/k/v inputs.

Motivation for This Feature
This enables inflight batching (#672): during generation, when one sequence finishes with the EOS token, we place a new prompt at that position, so that PAD tokens do not occupy computation bandwidth. This technique results in variable-length inputs at each generation step (e.g., query lengths [1, 1, 3, 1] if the new prompt has length 3), as sketched below.
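To make the bookkeeping concrete, here is a small sketch (illustrative, not code from this PR) of how such per-step lengths translate into the usual flash-attn varlen arguments:

```python
import torch
import torch.nn.functional as F

# One decoding step with inflight batching: three sequences are mid-generation
# (one new token each), and one finished slot was refilled with a 3-token prompt.
seqlens_q = torch.tensor([1, 1, 3, 1], dtype=torch.int32)

# flash-attn's varlen convention: cumulative lengths with a leading zero.
cu_seqlens_q = F.pad(torch.cumsum(seqlens_q, dim=0, dtype=torch.int32), (1, 0))
# cu_seqlens_q == tensor([0, 1, 2, 5, 6]); 6 packed query tokens in total
max_seqlen_q = int(seqlens_q.max())  # == 3
```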
Enabling inflight batching in flash-attn benefits LLM RLHF, which requires high generation throughput (no padding in this case) while updating model parameters at the same time. (We can't use inference libraries like vLLM because we would need to synchronize parameters after each training step, which is expensive.)
Usage
q/k/v are packed 1D tensors. q should be passed in together with cu_seqlens_q and max_seqlen_q, similar to flash_attn_varlen_func. k and v are optional arguments, which should be passed in together with cu_seqlens_k. max_seqlen_k is determined by the kv cache. The kv cache is updated in place, and only the kv cache is attended over if k and v are not passed in. k_cache and v_cache still have shape (batch_size_cache, seqlen_cache, n_heads, head_dim). A hypothetical call is sketched below.
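The following usage sketch is inferred from the description above and from the existing flash_attn_varlen_func / flash_attn_with_kvcache signatures; the exact argument names in the PR may differ.

```python
import torch
from flash_attn import flash_attn_varlen_func_with_kvcache  # added by this PR

batch, nheads, headdim, cache_len = 4, 8, 64, 512
device, dtype = "cuda", torch.float16

# Packed queries for per-step lengths [1, 1, 3, 1] -> 6 tokens in total.
q = torch.randn(6, nheads, headdim, device=device, dtype=dtype)
cu_seqlens_q = torch.tensor([0, 1, 2, 5, 6], dtype=torch.int32, device=device)

# New keys/values for this step, packed the same way (optional arguments).
k = torch.randn(6, nheads, headdim, device=device, dtype=dtype)
v = torch.randn(6, nheads, headdim, device=device, dtype=dtype)
cu_seqlens_k = cu_seqlens_q.clone()

# Per-sequence kv cache; updated in place with the new k/v.
k_cache = torch.zeros(batch, cache_len, nheads, headdim, device=device, dtype=dtype)
v_cache = torch.zeros(batch, cache_len, nheads, headdim, device=device, dtype=dtype)
cache_seqlens = torch.tensor([17, 33, 0, 9], dtype=torch.int32, device=device)

out = flash_attn_varlen_func_with_kvcache(  # hypothetical signature
    q, k_cache, v_cache,
    cu_seqlens_q=cu_seqlens_q, max_seqlen_q=3,
    k=k, v=v, cu_seqlens_k=cu_seqlens_k,
    cache_seqlens=cache_seqlens, causal=True,
)
# out: (6, nheads, headdim), packed like q.
```

Passing k/v together with the current cache lengths lets the kernel append the new tokens to the cache and attend over the whole valid prefix in one call.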
Major Changes:
1. Add mha_varlen_fwd_kvcache in csrc/flash_attn/flash_api.cpp. This function is similar to mha_fwd_kvcache but sets the forward params appropriately to deal with variable-length inputs. block_info.h is also changed accordingly.
2. Add flash_attn_varlen_func_with_kvcache in flash_attn/flash_attn_interface.py.
3. Add test_flash_attn_varlen_kvcache in tests/test_flash_attn.py (the semantics it checks are sketched below).
Minor Changes:
1. Add a make clean option to remove the compilation cache. I found weird bugs when compiling the code again without make clean.
2. mha_fwd_kvcache: set the cache length via the k_cache_seqlens attribute of the forward params instead of reusing cu_seqlens_k and is_seqlens_k_cumulative. Add a new attribute called cu_seqlens_knew to distinguish it from cu_seqlens_k. block_info.h and flash.h are changed accordingly.
Limitation
This new function sets num_splits=1 by default, which may hurt performance, but I currently don't understand how to fix this.

Edit 2023.11.23: appended test results.