From a67c31591d4ff59df68839561a415a22095c4074 Mon Sep 17 00:00:00 2001 From: simon-mo Date: Sat, 1 Feb 2025 17:53:48 -0800 Subject: [PATCH] kv_splits heuristic Signed-off-by: simon-mo --- vllm/attention/backends/triton_mla.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/vllm/attention/backends/triton_mla.py b/vllm/attention/backends/triton_mla.py index 95dc119a47bb5..761b5cc82dc6c 100644 --- a/vllm/attention/backends/triton_mla.py +++ b/vllm/attention/backends/triton_mla.py @@ -255,7 +255,7 @@ class TritonMLAMetadata(MLACommonMetadata): num_prefill_tokens: int - num_kv_splits: int = 4 # TODO(lucas) add heuristic + num_kv_splits: int = 8 # TODO(lucas) add heuristic attn_logits: Optional[torch.Tensor] = None req_idx: Optional[torch.Tensor] = None @@ -319,7 +319,8 @@ def prefill_metadata(self) -> Optional["TritonMLAMetadata"]: context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, - head_dim=self.head_dim) + head_dim=self.head_dim, + num_kv_splits=self.num_kv_splits) return self._cached_prefill_metadata @property @@ -366,7 +367,8 @@ def decode_metadata(self) -> Optional["TritonMLAMetadata"]: block_tables=block_tables, use_cuda_graph=self.use_cuda_graph, input_positions=input_positions, - head_dim=self.head_dim) + head_dim=self.head_dim, + num_kv_splits=self.num_kv_splits) return self._cached_decode_metadata def advance_step(self, @@ -621,6 +623,17 @@ def build(self, seq_lens: List[int], query_lens: List[int], for modality, placeholder_map in self.multimodal_placeholder_maps.items() } + + # Rough heuristic for num_kv_splits + # not empirically validated + if max_decode_seq_len < 512: + num_kv_splits = 2 + elif max_decode_seq_len < 1024: + num_kv_splits = 4 + elif max_decode_seq_len < 4096: + num_kv_splits = 8 + else: + num_kv_splits = 16 return TritonMLAMetadata( num_prefills=self.num_prefills, @@ -641,7 +654,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], 
context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=use_captured_graph, - num_kv_splits=4, # TODO(lucas) add heuristic + num_kv_splits=num_kv_splits, head_dim=self.runner.model_config.get_head_size(), )