diff --git a/tencentpretrain/layers/multi_headed_attn.py b/tencentpretrain/layers/multi_headed_attn.py index 128ff66..3a7d52a 100755 --- a/tencentpretrain/layers/multi_headed_attn.py +++ b/tencentpretrain/layers/multi_headed_attn.py @@ -94,9 +94,9 @@ def unshape(x): query, key, value = [linear_layer(x) for linear_layer, x in zip(self.linear_layers, [query, key, value])] - query = query.view(batch_size, seq_length, heads_num, per_head_size) - key = key.view(batch_size, seq_length, self.local_kv_heads_num, per_head_size) - value = value.view(batch_size, seq_length, self.local_kv_heads_num, per_head_size) + query = query.view(batch_size, -1, heads_num, per_head_size) + key = key.view(batch_size, -1, self.local_kv_heads_num, per_head_size) + value = value.view(batch_size, -1, self.local_kv_heads_num, per_head_size) query = query.transpose(1, 2) key = repeat_kv(key, self.repeat_num).transpose(1, 2)