diff --git a/vllm/model_executor/layers/attention/backends/flash_attn.py b/vllm/model_executor/layers/attention/backends/flash_attn.py
index 4b23a99c8370e..587fe08476cca 100644
--- a/vllm/model_executor/layers/attention/backends/flash_attn.py
+++ b/vllm/model_executor/layers/attention/backends/flash_attn.py
@@ -45,6 +45,15 @@ def __init__(
         self.sliding_window = ((self.sliding_window, self.sliding_window) if
                                self.sliding_window is not None else (-1, -1))
 
+    def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
+        """torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
+        tokens, n_kv_heads, head_dim = x.shape
+        return (
+            x[:, :, None, :]
+            .expand(tokens, n_kv_heads, n_rep, head_dim)
+            .reshape(tokens, n_kv_heads * n_rep, head_dim)
+        )
+
     def forward(
         self,
         query: torch.Tensor,
@@ -86,6 +95,11 @@ def forward(
         # Prompt run.
        if (key_cache is None or value_cache is None
                or input_metadata.block_tables.numel() == 0):
+            if self.num_kv_heads != self.num_heads:
+                # Interleave for MQA
+                key = self.repeat_kv(key, self.num_queries_per_kv)
+                value = self.repeat_kv(value, self.num_queries_per_kv)
+
            # normal attention
            query = query.unflatten(0, (batch_size, seq_len))
            key = key.unflatten(0, (batch_size, seq_len))
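
For reference, a minimal standalone sketch (not part of the patch) showing that the expand/reshape trick in repeat_kv reproduces torch.repeat_interleave along the head dimension, which is why the docstring cites it. The shapes below are illustrative only; num_queries_per_kv maps to n_rep here.

    import torch

    # Illustrative shapes: 8 tokens, 4 KV heads, head_dim 128, each KV head
    # shared by 2 query heads (n_rep == num_queries_per_kv == 2).
    tokens, n_kv_heads, head_dim, n_rep = 8, 4, 128, 2
    x = torch.randn(tokens, n_kv_heads, head_dim)

    # Same expand/reshape used by repeat_kv in the patch above: insert a
    # size-1 axis, broadcast it to n_rep, then fold it back into the head dim.
    expanded = (x[:, :, None, :]
                .expand(tokens, n_kv_heads, n_rep, head_dim)
                .reshape(tokens, n_kv_heads * n_rep, head_dim))

    # The op the docstring says it mirrors.
    reference = torch.repeat_interleave(x, repeats=n_rep, dim=1)

    assert torch.equal(expanded, reference)

The expand-based version avoids materializing the repeated copies until the final reshape, which is the usual reason to prefer it over calling torch.repeat_interleave directly on the prompt-run KV tensors.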