From 1c09bcccfe4f7206919f15ba9b36743138310aee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Li=20Yudong=20=28=E6=9D=8E=E7=85=9C=E4=B8=9C=29?= Date: Wed, 11 Oct 2023 21:28:19 +0800 Subject: [PATCH] Fix a bug for gqa (#101) * test * test * test --- tencentpretrain/layers/multi_headed_attn.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tencentpretrain/layers/multi_headed_attn.py b/tencentpretrain/layers/multi_headed_attn.py index 128ff66..3a7d52a 100755 --- a/tencentpretrain/layers/multi_headed_attn.py +++ b/tencentpretrain/layers/multi_headed_attn.py @@ -94,9 +94,9 @@ def unshape(x): query, key, value = [linear_layer(x) for linear_layer, x in zip(self.linear_layers, [query, key, value])] - query = query.view(batch_size, seq_length, heads_num, per_head_size) - key = key.view(batch_size, seq_length, self.local_kv_heads_num, per_head_size) - value = value.view(batch_size, seq_length, self.local_kv_heads_num, per_head_size) + query = query.view(batch_size, -1, heads_num, per_head_size) + key = key.view(batch_size, -1, self.local_kv_heads_num, per_head_size) + value = value.view(batch_size, -1, self.local_kv_heads_num, per_head_size) query = query.transpose(1, 2) key = repeat_kv(key, self.repeat_num).transpose(1, 2)