Commit 7502493

Added past_present_share_buffer to the cache hint; fixed a typo.

satyajandhyala committed Nov 13, 2024
1 parent 993140b commit 7502493
Showing 2 changed files with 4 additions and 4 deletions.
6 changes: 3 additions & 3 deletions onnxruntime/contrib_ops/webgpu/bert/attention.cc

@@ -221,7 +221,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
                                (parameters.sequence_length_ + tile_size - 1) / tile_size,
                                parameters.batch_size_ * parameters.num_heads_)
       .SetWorkgroupSize(tile_size, tile_size)
-      .CacheHint(std::to_string(tile_size), parameters.is_first_prompt_)
+      .CacheHint(std::to_string(tile_size), parameters.is_first_prompt_, parameters.past_present_share_buffer_)
       .AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length_)},
                             {static_cast<uint32_t>(vectorized_head_size)},
                             {static_cast<uint32_t>(total_sequence_length)},
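
Note: in the WebGPU EP, CacheHint feeds the key under which the compiled shader program is cached, and the shader generation code branches on past_present_share_buffer_ (the present_value/past_value ternary in the next hunk is one example). Two runs differing only in that flag therefore produce different WGSL and must not share a cached shader. A minimal sketch of how variadic hint arguments might be folded into a key string — MakeCacheHint is a hypothetical helper, not ORT's actual implementation:

// Minimal sketch, assuming a '|'-separated key string; MakeCacheHint is a
// hypothetical helper, not ORT's actual CacheHint implementation.
#include <iostream>
#include <sstream>
#include <string>

template <typename... Args>
std::string MakeCacheHint(Args&&... args) {
  std::ostringstream oss;
  ((oss << args << '|'), ...);  // C++17 fold: append each part plus a separator
  return oss.str();
}

int main() {
  // Same tile size and first-prompt flag, different share-buffer flag:
  // the keys must differ, or the wrong shader variant gets reused.
  std::cout << MakeCacheHint(std::to_string(16), true, false) << '\n';  // 16|1|0|
  std::cout << MakeCacheHint(std::to_string(16), true, true) << '\n';   // 16|1|1|
}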
@@ -381,7 +381,7 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
             << " if (n < uniforms.N && w + local_id.y < uniforms.K) {\n"
             << " var idx = TILE_SIZE * local_id.y + local_id.x;\n";
 
-  if ((feed_past_value_ && has_present_value_) && past_present_share_buffer_) {
+  if ((feed_past_value_ && has_present_value_) || past_present_share_buffer_) {
     shader.MainFunctionBody() << " if (w + local_id.y < past_sequence_length) {\n"
                               << " tileK[idx] = " << (past_present_share_buffer_ ? "present_value" : "past_value") << "[pastValueOffset + (w + local_id.y) * uniforms.N];\n"
                               << " } else if (w + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n"
@@ -452,7 +452,7 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int
   program.SetDispatchGroupSize((parameters.v_head_size_ + tile_size - 1) / tile_size,
                                (parameters.sequence_length_ + tile_size - 1) / tile_size,
                                parameters.batch_size_ * parameters.num_heads_)
-      .CacheHint(std::to_string(tile_size), parameters.is_first_prompt_)
+      .CacheHint(std::to_string(tile_size), parameters.is_first_prompt_, parameters.past_present_share_buffer_)
       .SetWorkgroupSize(tile_size, tile_size)
       .AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length_)},
                             {static_cast<uint32_t>(total_sequence_length)},
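
Note: this mirrors the ComputeAttentionProbs hunk above; the Vx score shader likewise branches on past_present_share_buffer_, so its cache hint must encode the flag as well.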

2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc

@@ -68,7 +68,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
   std::vector<int64_t> present_kv_shape({static_cast<int64_t>(parameters.batch_size_), static_cast<int64_t>(kv_num_heads_), static_cast<int64_t>(present_kv_seqlen), static_cast<int64_t>(parameters.head_size_)});
   Tensor* present_key = context.Output(1, present_kv_shape);
   Tensor* present_value = context.Output(2, present_kv_shape);
-  parameters.past_present_share_buffer_ = present_key != nullptr && present_value != nullptr && past_key->DataRaw() == present_key->DataRaw() && past_value->DataRaw() == present_value->DataRaw();
+  parameters.past_present_share_buffer_ = present_key != nullptr && present_value != nullptr && past_key != nullptr && past_value != nullptr && past_key->DataRaw() == present_key->DataRaw() && past_value->DataRaw() == present_value->DataRaw();
 
   TensorShapeVector q_new_dims({parameters.batch_size_, parameters.num_heads_,
                                 parameters.sequence_length_, parameters.head_size_});
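
Note: past_key/past_value are optional inputs and present_key/present_value are optional outputs, so any of the four pointers may be null; the old line dereferenced past_key and past_value unconditionally. Factored out, the guarded aliasing check reads as below — SharesKvBuffer is a hypothetical helper wrapping the same logic as the '+' line above:

// Sketch of the guarded check; SharesKvBuffer is a hypothetical helper, and
// Tensor/DataRaw() are onnxruntime's types (DataRaw() returns the raw buffer
// pointer). '&&' short-circuits, so each pointer is proven non-null before
// DataRaw() is called on it.
bool SharesKvBuffer(const Tensor* past_key, const Tensor* past_value,
                    const Tensor* present_key, const Tensor* present_value) {
  return present_key != nullptr && present_value != nullptr &&
         past_key != nullptr && past_value != nullptr &&
         past_key->DataRaw() == present_key->DataRaw() &&
         past_value->DataRaw() == present_value->DataRaw();
}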