From 5f1fdaea667813ac3577bfb1fb3ea7562990cdc0 Mon Sep 17 00:00:00 2001
From: Satya Jandhyala
Date: Tue, 12 Nov 2024 20:00:00 -0800
Subject: [PATCH] past_present_share_buffer related changes.

AttentionProbsProgram / VxAttentionScoreProgram shader generation:
- pastKeyOffset / pastValueOffset keep the past_sequence_length stride
  when (feed_past_* && has_present_*) holds; otherwise, when
  past_present_share_buffer_ is set, they are emitted with the
  present_sequence_length stride instead.
- The guards around the present_key / present_value writes now use
  kv_sequence_length + past_sequence_length as the upper bound instead
  of present_sequence_length.

ApplyAttention:
- total_sequence_length is taken from parameters.seqlen_present_kv_cache_
  when seqlen_k is non-null; otherwise it remains
  past_sequence_length + kv_sequence_length_.
---
NOTE(review): the non-GQA key branch originally added
`if (feed_past_key_ && has_present_key_ || past_present_share_buffer_)`,
which made the following `else if (past_present_share_buffer_)` unreachable;
it now matches the other three sites. The added total_sequence_length line
also regained its indentation.
NOTE(review): indentation and blank context lines were reconstructed from the
hunk line counts - verify against the target file before applying. The commit
id and `index` blob ids are carried over unchanged and no longer describe the
edited content; regenerate them with `git format-patch` - TODO confirm.

 .../contrib_ops/webgpu/bert/attention.cc | 26 ++++++++++++-------
 1 file changed, 17 insertions(+), 9 deletions(-)

diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc
index 2d552150c8284..8bd0a88c89e5d 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc
+++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc
@@ -116,16 +116,20 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
                               << "let kv_num_heads = uniforms.num_heads / uniforms.n_reps;\n"
                               << "let abs_kv_head_idx = batch_idx * kv_num_heads + kv_head_idx;\n"
                               << "let kOffset = abs_kv_head_idx * uniforms.kv_sequence_length * uniforms.K;\n";
-    if ((feed_past_key_ && has_present_key_) || past_present_share_buffer_) {
+    if (feed_past_key_ && has_present_key_) {
       shader.MainFunctionBody() << "let pastKeyOffset = abs_kv_head_idx * uniforms.past_sequence_length * uniforms.K;\n";
+    } else if (past_present_share_buffer_) {
+      shader.MainFunctionBody() << "let pastKeyOffset = abs_kv_head_idx * uniforms.present_sequence_length * uniforms.K;\n";
     }
     if (has_present_key_) {
       shader.MainFunctionBody() << "let presentKeyOffset = abs_kv_head_idx * uniforms.N * uniforms.K;\n";
     }
   } else {
     shader.MainFunctionBody() << "let kOffset = workgroup_id.z * uniforms.kv_sequence_length * uniforms.K;\n";
-    if ((feed_past_key_ && has_present_key_) || past_present_share_buffer_) {
+    if (feed_past_key_ && has_present_key_) {
       shader.MainFunctionBody() << "let pastKeyOffset = workgroup_id.z * uniforms.past_sequence_length * uniforms.K;\n";
+    } else if (past_present_share_buffer_) {
+      shader.MainFunctionBody() << "let pastKeyOffset = workgroup_id.z * uniforms.present_sequence_length * uniforms.K;\n";
     }
     if (has_present_key_) {
       shader.MainFunctionBody() << "let presentKeyOffset = workgroup_id.z * uniforms.N * uniforms.K;\n";
@@ -154,9 +158,9 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
 
   if (has_present_key_) {
     if (past_present_share_buffer_) {
-      shader.MainFunctionBody() << " if (n + local_id.y >= past_sequence_length && n + local_id.y < uniforms.present_sequence_length) {\n";
+      shader.MainFunctionBody() << " if (n + local_id.y >= past_sequence_length && n + local_id.y < uniforms.kv_sequence_length + past_sequence_length) {\n";
     } else {
-      shader.MainFunctionBody() << " if (n + local_id.y < uniforms.present_sequence_length) {\n";
+      shader.MainFunctionBody() << " if (n + local_id.y < uniforms.kv_sequence_length + past_sequence_length) {\n";
     }
     shader.MainFunctionBody() << " present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx];\n"
                               << " }\n";
@@ -355,8 +359,10 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
                               << "let kv_num_heads = uniforms.num_heads / uniforms.n_reps;\n"
                               << "let abs_kv_head_idx = batch_idx * kv_num_heads + kv_head_idx;\n"
                               << "let vOffset = abs_kv_head_idx * uniforms.N * uniforms.kv_sequence_length + n;\n";
-    if ((feed_past_value_ && has_present_value_) || past_present_share_buffer_) {
+    if (feed_past_value_ && has_present_value_) {
       shader.MainFunctionBody() << "let pastValueOffset = abs_kv_head_idx * uniforms.N * uniforms.past_sequence_length + n;\n";
+    } else if (past_present_share_buffer_) {
+      shader.MainFunctionBody() << "let pastValueOffset = abs_kv_head_idx * uniforms.N * uniforms.present_sequence_length + n;\n";
     }
 
     if (has_present_value_) {
@@ -364,8 +370,10 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
     }
   } else {
     shader.MainFunctionBody() << "let vOffset = workgroup_id.z * uniforms.N * uniforms.kv_sequence_length + n;\n";
-    if ((feed_past_value_ && has_present_value_) || past_present_share_buffer_) {
+    if (feed_past_value_ && has_present_value_) {
       shader.MainFunctionBody() << "let pastValueOffset = workgroup_id.z * uniforms.N * uniforms.past_sequence_length + n;\n";
+    } else if (past_present_share_buffer_) {
+      shader.MainFunctionBody() << "let pastValueOffset = workgroup_id.z * uniforms.N * uniforms.present_sequence_length + n;\n";
     }
 
     if (has_present_value_) {
@@ -395,9 +403,9 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
 
   if (has_present_value_) {
     if (past_present_share_buffer_) {
-      shader.MainFunctionBody() << " if (w + local_id.y >= past_sequence_length && w + local_id.y < uniforms.present_sequence_length) {\n";
+      shader.MainFunctionBody() << " if (w + local_id.y >= past_sequence_length && w + local_id.y < uniforms.kv_sequence_length + past_sequence_length) {\n";
     } else {
-      shader.MainFunctionBody() << " if (w + local_id.y < uniforms.present_sequence_length) {\n";
+      shader.MainFunctionBody() << " if (w + local_id.y < uniforms.kv_sequence_length + past_sequence_length) {\n";
     }
     shader.MainFunctionBody() << " present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileK[idx];\n"
                               << " }\n";
@@ -475,7 +483,7 @@ Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const T
                       WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) {
   const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0)});
   const int past_sequence_length = output_count > 1 ? parameters.past_sequence_length_ : 0;
-  const int total_sequence_length = past_sequence_length + parameters.kv_sequence_length_;
+  const int total_sequence_length = seqlen_k == nullptr ? (past_sequence_length + parameters.kv_sequence_length_) : parameters.seqlen_present_kv_cache_;
 
   const TensorShapeVector probs_dims({parameters.batch_size_, parameters.num_heads_,
                                       parameters.sequence_length_, total_sequence_length});