
Add flash_attn_varlen_func_with_kvcache. #685

Open · wants to merge 12 commits into base: main

Conversation


@garrett4wade commented Nov 22, 2023

This PR implements a new function, flash_attn_varlen_func_with_kvcache, which behaves similarly to flash_attn_with_kvcache but accepts variable-length q/k/v inputs.

Motivation for This Feature

This enables in-flight batching (#672): during generation, when one sequence finishes with the EOS token, we place a new prompt at that position, so PAD tokens do not occupy computation bandwidth. This technique produces variable-length inputs at each generation step (e.g., [1, 1, 3, 1] if the new prompt has length 3).
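For concreteness, here is a minimal sketch (not part of this PR) of how such per-step lengths map onto the cumulative-sequence-length format that the varlen interface expects; the tensor names are illustrative only:

```python
import torch

# Per-sequence query lengths for one generation step: three sequences each
# decode one token, and one slot was just refilled with a 3-token prompt.
seqlens_q = torch.tensor([1, 1, 3, 1], dtype=torch.int32, device="cuda")

# The varlen interface takes cumulative lengths of shape (batch_size + 1,).
cu_seqlens_q = torch.nn.functional.pad(
    torch.cumsum(seqlens_q, dim=0, dtype=torch.int32), (1, 0)
)
# cu_seqlens_q == [0, 1, 2, 5, 6]
max_seqlen_q = int(seqlens_q.max())  # 3
```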

Enabling in-flight batching in flash-attn benefits LLM RLHF, which requires high generation throughput (no padding in this case) while model parameters are updated at the same time (we can't use inference libraries like vLLM because we need to synchronize parameters after each training step, which is expensive).

Usage

q/k/v are packed tensors with the batch and sequence dimensions flattened into one. q should be passed together with cu_seqlens_q and max_seqlen_q, as in flash_attn_varlen_func. k and v are optional arguments and should be passed together with cu_seqlens_k; max_seqlen_k is determined by the kv cache. The kv cache is updated in place. If k and v are not passed in, attention is computed over the kv cache only.

k_cache and v_cache still have shape (batch_size_cache, seqlen_cache, n_heads, head_dim).
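A rough usage sketch based on the description above; the exact argument names and order are my assumptions, inferred from flash_attn_with_kvcache and flash_attn_varlen_func, so check the diff for the actual signature:

```python
import torch
from flash_attn import flash_attn_varlen_func_with_kvcache  # added by this PR

device, dtype = "cuda", torch.float16
batch, nheads, headdim, cache_len = 4, 8, 64, 512

# Packed queries for the step above: lengths [1, 1, 3, 1] -> 6 tokens in total.
cu_seqlens_q = torch.tensor([0, 1, 2, 5, 6], dtype=torch.int32, device=device)
max_seqlen_q = 3
q = torch.randn(6, nheads, headdim, device=device, dtype=dtype)

# Optional new keys/values, packed the same way; if omitted, attention runs
# over the existing cache contents only.
k_new, v_new = torch.randn_like(q), torch.randn_like(q)
cu_seqlens_k = cu_seqlens_q.clone()

# The KV cache keeps the padded layout (batch_size_cache, seqlen_cache, nheads,
# headdim) and is updated in place with k_new / v_new.
k_cache = torch.zeros(batch, cache_len, nheads, headdim, device=device, dtype=dtype)
v_cache = torch.zeros_like(k_cache)
cache_seqlens = torch.tensor([100, 37, 0, 250], dtype=torch.int32, device=device)

out = flash_attn_varlen_func_with_kvcache(
    q, k_cache, v_cache,
    cu_seqlens_q=cu_seqlens_q, max_seqlen_q=max_seqlen_q,
    k=k_new, v=v_new, cu_seqlens_k=cu_seqlens_k,
    cache_seqlens=cache_seqlens, causal=True,
)
# out is packed like q: (total_q_tokens, nheads, headdim)
```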

Major Changes:

  • Adding mha_varlen_fwd_kvcache in csrc/flash_attn/flash_api.cpp. This function is similar to mha_fwd_kvcache but sets the forward params appropriately to handle variable-length inputs. block_info.h is also changed accordingly.
  • Adding flash_attn_varlen_func_with_kvcache in flash_attn/flash_attn_interface.py.
  • Adding test_flash_attn_varlen_kvcache in tests/test_flash_attn.py.

Minor Changes:

  • Add a make clean option to remove the compilation cache. I found weird bugs when recompiling the code without running make clean first.
  • Minor refactor of mha_fwd_kvcache. Set the cache length via the k_cache_seqlens attribute of the forward params instead of using cu_seqlens_k and is_seqlens_k_cumulative. Add a new attribute, cu_seqlens_knew, to distinguish it from cu_seqlens_k. block_info.h and flash.h are changed accordingly.

Limitation

This new function sets num_splits=1 by default, which may hurt performance; I currently don't know how to fix this.


Edit 2023.11.23: Appended test results.

[screenshot: test results]

@shcho1118 (Contributor)

@garrett4wade
If this PR is merged, could it be used in place of vLLM's paged attention v2 kernel?
The problem I have right now is that I can't use the split-kv kernel in the varlen funcs during the decoding phase.

@garrett4wade (Author)

@shcho1118 Hi cho, I don't think this PR can help, because I forcibly set num_splits=1 for the varlen func with kv cache, although this can potentially be fixed in the future.

@sgrigory (Contributor) commented Jan 6, 2024

> @shcho1118 Hi cho, I don't think this PR can help, because I forcibly set num_splits=1 for the varlen func with kv cache, although this can potentially be fixed in the future.

@garrett4wade @shcho1118 FYI, I'm adding split-kv to mha_varlen_fwd for decoding in #754
