Skip to content

Commit

Permalink
Bug fix: incorrect positional offset when using ROPE with KV cache
Browse files Browse the repository at this point in the history
  • Loading branch information
zhongkaifu committed Oct 21, 2024
1 parent dd969e9 commit 50a0e4a
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
8 changes: 4 additions & 4 deletions Seq2SeqSharp/Layers/GroupQueryAttention.cs
Original file line number Diff line number Diff line change
Expand Up @@ -125,29 +125,29 @@ public IWeightTensor Perform(IWeightTensor inputQ, IWeightTensor keyMask, int ba

string cachedKName = $"{m_name}_cached_K";
string cachedVName = $"{m_name}_cached_V";
int seqLenAll = seqLenQ;
int ropeOffset = 0;

if (contextTensors != null)
{
if (contextTensors.ContainsKey(cachedKName))
{
seqLenAll = (int)contextTensors[cachedKName].Sizes[2] + seqLenQ;
ropeOffset = (int)contextTensors[cachedKName].Sizes[2];
}
}

//Multi-head attentions
IWeightTensor Qs = g.View(g.AsContiguous(g.Transpose(allQ, 1, 2)), dims: new long[] { batchSize * m_num_heads, seqLenQ, m_head_dim });
if (m_PEType == PositionEmbeddingEnums.RoPE)
{
Qs = g.RoPE(Qs, seqLenAll, seqLenAll - seqLenQ);
Qs = g.RoPE(Qs, seqLenQ, ropeOffset);
}

int group_size = m_num_heads / m_num_kv_groups;
IWeightTensor Ks = null;
if (m_PEType == PositionEmbeddingEnums.RoPE)
{
Ks = g.View(g.AsContiguous(g.Transpose(allK, 1, 2)), dims: new long[] { batchSize * m_num_kv_groups, seqLenQ, m_head_dim });
Ks = g.RoPE(Ks, seqLenAll, seqLenAll - seqLenQ);
Ks = g.RoPE(Ks, seqLenQ, ropeOffset);
}
else
{
Expand Down
8 changes: 4 additions & 4 deletions Seq2SeqSharp/Layers/MultiHeadAttention.cs
Original file line number Diff line number Diff line change
Expand Up @@ -138,28 +138,28 @@ private IWeightTensor PerformClassic(IWeightTensor inputQ, IWeightTensor keyMask

string cachedKName = $"{m_name}_cached_K";
string cachedVName = $"{m_name}_cached_V";
int seqLenAll = seqLenQ;
int ropeOffset = 0;

if (contextTensors != null)
{
if(contextTensors.ContainsKey(cachedKName))
{
seqLenAll = (int)contextTensors[cachedKName].Sizes[2] + seqLenQ;
ropeOffset = (int)contextTensors[cachedKName].Sizes[2];
}
}

//Multi-head attentions
IWeightTensor Qs = g.View(g.AsContiguous(g.Transpose(allQ, 1, 2)), dims: new long[] { batchSize * m_multiHeadNum, seqLenQ, m_d });
if (m_PEType == PositionEmbeddingEnums.RoPE)
{
Qs = g.RoPE(Qs, seqLenAll, seqLenAll - seqLenQ);
Qs = g.RoPE(Qs, seqLenQ, ropeOffset);
}

IWeightTensor Ks = null;
if (m_PEType == PositionEmbeddingEnums.RoPE)
{
Ks = g.View(g.AsContiguous(g.Transpose(allK, 1, 2)), dims: new long[] { batchSize * m_multiHeadNum, -1, m_d });
Ks = g.RoPE(Ks, seqLenAll, seqLenAll - seqLenQ);
Ks = g.RoPE(Ks, seqLenQ, ropeOffset);
Ks = g.View(g.AsContiguous(g.Transpose(Ks, 1, 2)), dims: new long[] { batchSize * m_multiHeadNum, m_d, -1 });
}
else
Expand Down
2 changes: 1 addition & 1 deletion Seq2SeqSharp/Seq2SeqSharp.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
<PlatformTarget>AnyCPU</PlatformTarget>
<AppendTargetFrameworkToOutputPath>false</AppendTargetFrameworkToOutputPath>
<OutputPath>bin\</OutputPath>
<Version>2.8.15</Version>
<Version>2.8.16</Version>
<Description>Seq2SeqSharp is a tensor based fast &amp; flexible encoder-decoder deep neural network framework written by .NET (C#). It can be used for sequence-to-sequence task, sequence-labeling task and sequence-classification task and other NLP tasks. Seq2SeqSharp supports both CPUs (x86, x64 and ARM64) and GPUs. It's powered by .NET core, so Seq2SeqSharp can run on both Windows and Linux without any modification and recompilation.</Description>
<PackageReadmeFile>README.md</PackageReadmeFile>
<Title>Seq2SeqSharp</Title>
Expand Down

0 comments on commit 50a0e4a

Please sign in to comment.