From d5b09d1a85ee2965c3503c8b96a855608caae110 Mon Sep 17 00:00:00 2001
From: PengWeixuan <145038191+PengWeixuan@users.noreply.github.com>
Date: Mon, 14 Oct 2024 19:56:30 +0800
Subject: [PATCH] Fix comment errors

After reading kv_cache, keys and values should have the dimensions of
(bs, cache_len + seqlen, n_local_heads, head_dim), rather than
(bs, seqlen, n_local_heads, head_dim).
---
 llama/model.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llama/model.py b/llama/model.py
index 90b535b6..a7bc0865 100644
--- a/llama/model.py
+++ b/llama/model.py
@@ -179,8 +179,8 @@ def forward(
         values = self.cache_v[:bsz, : start_pos + seqlen]
 
         # repeat k/v heads if n_kv_heads < n_heads
-        keys = repeat_kv(keys, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
-        values = repeat_kv(values, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
+        keys = repeat_kv(keys, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)
+        values = repeat_kv(values, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)
 
         xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         keys = keys.transpose(1, 2)
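
A quick way to sanity-check the corrected comment: after the cache is sliced with [:bsz, : start_pos + seqlen], the time axis of keys/values already covers the cached prefix plus the new tokens, and repeat_kv only expands the heads axis, leaving the sequence length unchanged. Below is a minimal, self-contained sketch; repeat_kv here mirrors the repeat-interleave shape behavior the patched comments describe, and the concrete sizes (bs, cache_len, seqlen, n_kv_heads, n_rep, head_dim) are made up for illustration, not taken from the patch.

import torch

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Equivalent to torch.repeat_interleave(x, repeats=n_rep, dim=2):
    # only the heads dimension grows; the sequence dimension is untouched.
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )

# Illustrative sizes (assumed for this sketch)
bs, cache_len, seqlen = 2, 16, 4
n_kv_heads, n_rep, head_dim = 8, 4, 64

# After reading the KV cache up to start_pos + seqlen, the time axis
# is cache_len + seqlen, not just seqlen.
keys = torch.randn(bs, cache_len + seqlen, n_kv_heads, head_dim)
keys = repeat_kv(keys, n_rep)
print(keys.shape)  # torch.Size([2, 20, 32, 64]) == (bs, cache_len + seqlen, n_local_heads, head_dim)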