past_present_share_buffer related changes.
satyajandhyala committed Nov 13, 2024
1 parent 7502493 commit 5f1fdae
Showing 1 changed file with 17 additions and 9 deletions.
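Editor's note: across both shader generators touched below (AttentionProbsProgram and VxAttentionScoreProgram), the commit splits the old combined condition so that the past-KV read offset is strided by past_sequence_length when a separate past tensor is fed, and by present_sequence_length (the capacity of the shared cache) when past and present share one buffer; the value path does the same with uniforms.N in place of uniforms.K. A minimal host-side sketch of that selection, using illustrative names (OffsetParams, PastKeyOffset) that are not part of ONNX Runtime:

#include <cstdint>

// Illustrative sketch only: mirrors the pastKeyOffset strings emitted into WGSL below.
struct OffsetParams {
  uint32_t past_sequence_length;     // number of valid cached tokens
  uint32_t present_sequence_length;  // capacity of the shared past/present cache
  uint32_t K;                        // per-token key width (uniforms.K)
};

uint32_t PastKeyOffset(const OffsetParams& p, uint32_t abs_kv_head_idx,
                       bool feed_past_key, bool has_present_key,
                       bool past_present_share_buffer) {
  if (feed_past_key && has_present_key) {
    // Separate past tensor: each head's rows are packed with the past length as stride.
    return abs_kv_head_idx * p.past_sequence_length * p.K;
  }
  if (past_present_share_buffer) {
    // Shared cache: rows are packed with the present (maximum) length as stride.
    return abs_kv_head_idx * p.present_sequence_length * p.K;
  }
  return 0;  // no past keys to read
}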
26 changes: 17 additions & 9 deletions onnxruntime/contrib_ops/webgpu/bert/attention.cc
@@ -116,16 +116,20 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
<< "let kv_num_heads = uniforms.num_heads / uniforms.n_reps;\n"
<< "let abs_kv_head_idx = batch_idx * kv_num_heads + kv_head_idx;\n"
<< "let kOffset = abs_kv_head_idx * uniforms.kv_sequence_length * uniforms.K;\n";
- if ((feed_past_key_ && has_present_key_) || past_present_share_buffer_) {
+ if (feed_past_key_ && has_present_key_) {
shader.MainFunctionBody() << "let pastKeyOffset = abs_kv_head_idx * uniforms.past_sequence_length * uniforms.K;\n";
+ } else if (past_present_share_buffer_) {
+ shader.MainFunctionBody() << "let pastKeyOffset = abs_kv_head_idx * uniforms.present_sequence_length * uniforms.K;\n";
}
if (has_present_key_) {
shader.MainFunctionBody() << "let presentKeyOffset = abs_kv_head_idx * uniforms.N * uniforms.K;\n";
}
} else {
shader.MainFunctionBody() << "let kOffset = workgroup_id.z * uniforms.kv_sequence_length * uniforms.K;\n";
- if ((feed_past_key_ && has_present_key_) || past_present_share_buffer_) {
+ if (feed_past_key_ && has_present_key_ || past_present_share_buffer_) {
shader.MainFunctionBody() << "let pastKeyOffset = workgroup_id.z * uniforms.past_sequence_length * uniforms.K;\n";
+ } else if (past_present_share_buffer_) {
+ shader.MainFunctionBody() << "let pastKeyOffset = workgroup_id.z * uniforms.present_sequence_length * uniforms.K;\n";
}
if (has_present_key_) {
shader.MainFunctionBody() << "let presentKeyOffset = workgroup_id.z * uniforms.N * uniforms.K;\n";
@@ -154,9 +158,9 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {

if (has_present_key_) {
if (past_present_share_buffer_) {
- shader.MainFunctionBody() << " if (n + local_id.y >= past_sequence_length && n + local_id.y < uniforms.present_sequence_length) {\n";
+ shader.MainFunctionBody() << " if (n + local_id.y >= past_sequence_length && n + local_id.y < uniforms.kv_sequence_length + past_sequence_length) {\n";
} else {
- shader.MainFunctionBody() << " if (n + local_id.y < uniforms.present_sequence_length) {\n";
+ shader.MainFunctionBody() << " if (n + local_id.y < uniforms.kv_sequence_length + past_sequence_length) {\n";
}
shader.MainFunctionBody() << " present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx];\n"
<< " }\n";
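Editor's note: the bound change above (uniforms.present_sequence_length to uniforms.kv_sequence_length + past_sequence_length) limits the copy into present_key by the number of valid rows rather than by the cache capacity, and with a shared buffer the rows below past_sequence_length are skipped because they are already in place. A hedged sketch of that guard as a host-side predicate, with row standing in for n + local_id.y (names are illustrative, not ORT API):

#include <cstdint>

// Illustrative sketch only: the write guard emitted above for present_key.
bool ShouldWritePresentKeyRow(uint32_t row, uint32_t past_sequence_length,
                              uint32_t kv_sequence_length,
                              bool past_present_share_buffer) {
  const uint32_t total_valid = past_sequence_length + kv_sequence_length;
  if (past_present_share_buffer) {
    // Past rows already live in the shared buffer; only the new rows are appended.
    return row >= past_sequence_length && row < total_valid;
  }
  // Separate present tensor: write every valid row, bounded by past + new tokens
  // rather than by the buffer capacity (present_sequence_length).
  return row < total_valid;
}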
@@ -355,17 +359,21 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
<< "let kv_num_heads = uniforms.num_heads / uniforms.n_reps;\n"
<< "let abs_kv_head_idx = batch_idx * kv_num_heads + kv_head_idx;\n"
<< "let vOffset = abs_kv_head_idx * uniforms.N * uniforms.kv_sequence_length + n;\n";
- if ((feed_past_value_ && has_present_value_) || past_present_share_buffer_) {
+ if (feed_past_value_ && has_present_value_) {
shader.MainFunctionBody() << "let pastValueOffset = abs_kv_head_idx * uniforms.N * uniforms.past_sequence_length + n;\n";
+ } else if (past_present_share_buffer_) {
+ shader.MainFunctionBody() << "let pastValueOffset = abs_kv_head_idx * uniforms.N * uniforms.present_sequence_length + n;\n";
}

if (has_present_value_) {
shader.MainFunctionBody() << "let presentValueOffset = abs_kv_head_idx * uniforms.N * uniforms.K + n;\n";
}
} else {
shader.MainFunctionBody() << "let vOffset = workgroup_id.z * uniforms.N * uniforms.kv_sequence_length + n;\n";
- if ((feed_past_value_ && has_present_value_) || past_present_share_buffer_) {
+ if (feed_past_value_ && has_present_value_) {
shader.MainFunctionBody() << "let pastValueOffset = workgroup_id.z * uniforms.N * uniforms.past_sequence_length + n;\n";
+ } else if (past_present_share_buffer_) {
+ shader.MainFunctionBody() << "let pastValueOffset = workgroup_id.z * uniforms.N * uniforms.present_sequence_length + n;\n";
}

if (has_present_value_) {
@@ -395,9 +403,9 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {

if (has_present_value_) {
if (past_present_share_buffer_) {
- shader.MainFunctionBody() << " if (w + local_id.y >= past_sequence_length && w + local_id.y < uniforms.present_sequence_length) {\n";
+ shader.MainFunctionBody() << " if (w + local_id.y >= past_sequence_length && w + local_id.y < uniforms.kv_sequence_length + past_sequence_length) {\n";
} else {
- shader.MainFunctionBody() << " if (w + local_id.y < uniforms.present_sequence_length) {\n";
+ shader.MainFunctionBody() << " if (w + local_id.y < uniforms.kv_sequence_length + past_sequence_length) {\n";
}
shader.MainFunctionBody() << " present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileK[idx];\n"
<< " }\n";
@@ -475,7 +483,7 @@ Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const T
WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) {
const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0)});

[GitHub Actions / Optional Lint C++] cpplint warning on onnxruntime/contrib_ops/webgpu/bert/attention.cc:484: Add #include <algorithm> for min [build/include_what_you_use] [4]
const int past_sequence_length = output_count > 1 ? parameters.past_sequence_length_ : 0;
- const int total_sequence_length = past_sequence_length + parameters.kv_sequence_length_;
+ const int total_sequence_length = seqlen_k == nullptr ? (past_sequence_length + parameters.kv_sequence_length_) : parameters.seqlen_present_kv_cache_;

const TensorShapeVector probs_dims({parameters.batch_size_, parameters.num_heads_,
parameters.sequence_length_, total_sequence_length});
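Editor's note: with this change the width of the probs matrix depends on whether a seqlen_k tensor is supplied: without it, the total is past plus new KV length; with it (the shared-cache path), it spans the preallocated present KV cache. A minimal sketch of that selection (the function name is illustrative, not part of the source):

// Illustrative sketch only: mirrors the total_sequence_length selection above.
int ComputeTotalSequenceLength(bool has_seqlen_k, int past_sequence_length,
                               int kv_sequence_length, int seqlen_present_kv_cache) {
  return has_seqlen_k ? seqlen_present_kv_cache
                      : past_sequence_length + kv_sequence_length;
}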
