
Commit

[Relax][KV Cache] Refactor _attention_sequence_prefill function to handle dynamic `batch_size` in TIR

This PR removes `batch_size` from the function signature; instead, it is declared as a dynamic size variable inside the TIR function body and bound from the buffer shapes via `T.match_buffer`.
mengshyu authored and MasterJH5574 committed Sep 11, 2024
1 parent e1e0dc2 commit e0ef1c9
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion python/tvm/relax/frontend/nn/llm/kv_cache.py
@@ -1237,7 +1237,7 @@ def merge_state_inplace(


 def _attention_sequence_prefill(
-    batch_size, h_kv, h_q, d, dtype, target: Target, causal=0, attn_score_scaling_factor=1.0
+    h_kv, h_q, d, dtype, target: Target, causal=0, attn_score_scaling_factor=1.0
 ): # pylint: disable=line-too-long
     LOAD_VEC = 8 // ((DataType(dtype).bits + 7) // 8) # 8 bytes
     group_size = h_q // h_kv
@@ -1264,6 +1264,7 @@ def batch_sequence_prefill_kv( # pylint: disable=too-many-branches
     var_output: T.handle, # [total_len, h_q, d]
     var_lse: T.handle # [total_len, h_q]
 ):
+    batch_size = T.int32(is_size_var=True)
     qo_len = T.int32(is_size_var=True)
     kv_len = T.int32(is_size_var=True)
     q = T.match_buffer(var_q, (batch_size, qo_len, h_q, d), dtype)
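For readers unfamiliar with the pattern, the sketch below is not part of this commit; the function name `copy_tokens` and the fixed feature size are hypothetical. It shows how a TIR PrimFunc can declare a dynamic dimension with `T.int32(is_size_var=True)` and let `T.match_buffer` bind it from the incoming buffer's shape, mirroring how `batch_size` is handled in the diff above.

# Minimal sketch, assuming TVMScript as used in this repo; `copy_tokens` and
# the fixed feature size 16 are illustrative, not taken from kv_cache.py.
from tvm.script import tir as T

@T.prim_func
def copy_tokens(var_q: T.handle, var_out: T.handle):
    # batch_size is not a Python argument of the builder function; it is a
    # TIR size variable bound from the buffer shape at call time.
    batch_size = T.int32(is_size_var=True)
    q = T.match_buffer(var_q, (batch_size, 16), "float16")
    out = T.match_buffer(var_out, (batch_size, 16), "float16")
    for b, i in T.grid(batch_size, 16):
        with T.block("copy"):
            vb, vi = T.axis.remap("SS", [b, i])
            out[vb, vi] = q[vb, vi]

Because the variable is created with `is_size_var=True`, TIR treats it as a non-negative extent, so a single compiled function can serve any runtime batch size rather than being specialized per value.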
