Commit 1ec6554

Merge remote-tracking branch 'origin/jpvillam/v0.3.3_triton' into integration_no_fp8

gshtras committed Mar 22, 2024
2 parents 0a2309a + 1fff99a
Showing 1 changed file with 14 additions and 0 deletions.
vllm/model_executor/layers/attention/backends/flash_attn.py
@@ -45,6 +45,15 @@ def __init__(
self.sliding_window = ((self.sliding_window, self.sliding_window) if
self.sliding_window is not None else (-1, -1))

def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
tokens, n_kv_heads, head_dim = x.shape
return (
x[:, :, None, :]
.expand(tokens, n_kv_heads, n_rep, head_dim)
.reshape(tokens, n_kv_heads * n_rep, head_dim)
)

def forward(
self,
query: torch.Tensor,
@@ -86,6 +95,11 @@ def forward(
# Prompt run.
if (key_cache is None or value_cache is None
or input_metadata.block_tables.numel() == 0):
if self.num_kv_heads != self.num_heads:
# Interleave for MQA
key = self.repeat_kv(key, self.num_queries_per_kv)
value = self.repeat_kv(value, self.num_queries_per_kv)

# normal attention
query = query.unflatten(0, (batch_size, seq_len))
key = key.unflatten(0, (batch_size, seq_len))
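
For context, here is a minimal standalone sketch (not part of the commit) of what the new repeat_kv helper computes. It assumes the [tokens, n_kv_heads, head_dim] layout shown in the hunk above and checks the expand/reshape path against torch.repeat_interleave, which the docstring cites as the reference behavior; the example shapes are illustrative.

    import torch

    def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
        """Equivalent of torch.repeat_interleave(x, repeats=n_rep, dim=1)."""
        tokens, n_kv_heads, head_dim = x.shape
        return (
            x[:, :, None, :]                              # [tokens, n_kv_heads, 1, head_dim]
            .expand(tokens, n_kv_heads, n_rep, head_dim)  # broadcast, no copy yet
            .reshape(tokens, n_kv_heads * n_rep, head_dim)
        )

    x = torch.randn(6, 2, 4)      # 6 tokens, 2 KV heads, head_dim 4
    out = repeat_kv(x, n_rep=3)   # -> [6, 6, 4]: each KV head repeated 3 times
    assert torch.equal(out, torch.repeat_interleave(x, repeats=3, dim=1))

In the prompt branch, self.num_queries_per_kv is presumably num_heads // num_kv_heads, so after the repeat the key and value tensors carry num_heads heads each and the normal attention path runs unchanged for MQA/GQA models.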
