Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WebGPU EP] Support GroupQueryAttention #22658

Open
wants to merge 32 commits into
base: main
Choose a base branch
from

Conversation

satyajandhyala
Copy link
Contributor

Description

Motivation and Context

const bool feed_past_key = present_key != nullptr && past_key != nullptr && past_key->SizeInBytes() > 0;
const bool has_present_key = output_count > 1 && past_key;
const bool has_attention_bias = attention_bias != nullptr;
const int tile_size = 12;

Check warning

Code scanning / PREfast

The const variable 'tile_size' can be computed at compile-time. Consider using constexpr (con.5). Warning

The const variable 'tile_size' can be computed at compile-time. Consider using constexpr (con.5).
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can commit the suggested changes from lintrunner.

onnxruntime/contrib_ops/webgpu/bert/attention.cc Outdated Show resolved Hide resolved
@satyajandhyala satyajandhyala marked this pull request as ready for review November 1, 2024 19:28
present_value, parameters, context, seqlen_k, total_seqlen_tensor);
}
TensorShape k_new_shape(k_new_dims);
Tensor K = context.CreateGPUTensor(key->DataType(), k_new_shape);
Copy link
Contributor

@skottmckay skottmckay Nov 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line causes a segfault with these models: https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/tree/main/cpu_and_mobile as they have GQA nodes that do not have the optional key and value inputs so the Tensor* is a nullptr.

.Input(1,
"key",
"Key with shape (batch_size, kv_sequence_length, kv_hidden_size) ",
"T",
OpSchema::Optional)
.Input(2,
"value",
"Value with shape (batch_size, kv_sequence_length, kv_hidden_size)",
"T",
OpSchema::Optional)
.Input(3,
"past_key",
"past state key with support for format BNSH. When past_key uses same tensor as present_key"
"(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.",
"T",
OpSchema::Optional)
.Input(4,
"past_value",
"past state value with support for format BNSH. When past_value uses same tensor as present_value"
"(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.",
"T",
OpSchema::Optional)

WebgpuAttentionParameters is not copying the value of is_packed_qkv_ from GroupQueryAttentionParameters

Copy link
Contributor Author

@satyajandhyala satyajandhyala Nov 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to correctly initialize is_packed_qkv_

Comment on lines +17 to +19
struct WebgpuAttentionParameters {
WebgpuAttentionParameters(AttentionParameters parameters) : is_gqa_parameters_(false),
batch_size_(parameters.batch_size),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the reason WebGPU needs a parameters struct that combines AttentionParameters and GroupQueryAttentionParameters? Feels a little confusing to merge those and wondering why it's necessary if we don't need to do that for other EPs that implement these operators.

Copy link
Contributor Author

@satyajandhyala satyajandhyala Nov 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am trying to avoid code duplication. I refactored code into attention used by both GQA and MHA. The CPU version has GQA separate implementation. group_query_attention_helper::CheckInputs() and AttentionBase::CheckInputs output different structs, GroupQueryAttentionParameters and AttentionParameters respectively. WebGPU parameters is a union of these to structs.

@guschmue guschmue added the ep:WebGPU ort-web webgpu provider label Nov 6, 2024
Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can commit the suggested changes from lintrunner.

Comment on lines 126 to 130
shader.MainFunctionBody() << "let kOffset = workgroup_id.z * uniforms.kv_sequence_length * uniforms.K;\n";
if ((feed_past_key_ && has_present_key_) || past_present_share_buffer_) {
shader.MainFunctionBody() << "let pastKeyOffset = workgroup_id.z * uniforms.past_sequence_length * uniforms.K;\n";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
shader.MainFunctionBody() << "let kOffset = workgroup_id.z * uniforms.kv_sequence_length * uniforms.K;\n";
if ((feed_past_key_ && has_present_key_) || past_present_share_buffer_) {
shader.MainFunctionBody() << "let pastKeyOffset = workgroup_id.z * uniforms.past_sequence_length * uniforms.K;\n";
shader.MainFunctionBody() << "let kOffset = workgroup_id.z * uniforms.kv_sequence_length * uniforms.K;\n";
if ((feed_past_key_ && has_present_key_) || past_present_share_buffer_) {
shader.MainFunctionBody() << "let pastKeyOffset = workgroup_id.z * uniforms.past_sequence_length * uniforms.K;\n";

Comment on lines 144 to 153
shader.MainFunctionBody() << " if (n + local_id.y < past_sequence_length) {\n"
" tileK[idx] = " << (past_present_share_buffer_ ? "present_key" : "past_key") << "[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n"
" } else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n"
" tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];\n"
" }\n";
} else {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
shader.MainFunctionBody() << " if (n + local_id.y < past_sequence_length) {\n"
" tileK[idx] = " << (past_present_share_buffer_ ? "present_key" : "past_key") << "[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n"
" } else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n"
" tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];\n"
" }\n";
} else {
shader.MainFunctionBody() << " if (n + local_id.y < past_sequence_length) {\n"
" tileK[idx] = "
<< (past_present_share_buffer_ ? "present_key" : "past_key") << "[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n"
" } else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n"
" tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];\n"
" }\n";
} else {

Comment on lines 433 to 443
const Tensor* seqlen_k) {
const bool feed_past_value = present_value != nullptr && past_value != nullptr && past_value->SizeInBytes() > 0 && !parameters.past_present_share_buffer_;
const bool has_present_value = output_count > 1 && past_value != nullptr;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
const Tensor* seqlen_k) {
const bool feed_past_value = present_value != nullptr && past_value != nullptr && past_value->SizeInBytes() > 0 && !parameters.past_present_share_buffer_;
const bool has_present_value = output_count > 1 && past_value != nullptr;
const Tensor* seqlen_k) {
const bool feed_past_value = present_value != nullptr && past_value != nullptr && past_value->SizeInBytes() > 0 && !parameters.past_present_share_buffer_;
const bool has_present_value = output_count > 1 && past_value != nullptr;

Comment on lines 26 to 30
.TypeConstraint("T", WebGpuSupportedFloatTypes())
.MayInplace(3, 1)
.MayInplace(4, 2)
.InputMemoryType(OrtMemTypeCPUInput, 6),
GroupQueryAttention);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
.TypeConstraint("T", WebGpuSupportedFloatTypes())
.MayInplace(3, 1)
.MayInplace(4, 2)
.InputMemoryType(OrtMemTypeCPUInput, 6),
GroupQueryAttention);
.TypeConstraint("T", WebGpuSupportedFloatTypes())
.MayInplace(3, 1)
.MayInplace(4, 2)
.InputMemoryType(OrtMemTypeCPUInput, 6),
GroupQueryAttention);

Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can commit the suggested changes from lintrunner.

Comment on lines 145 to 153
<< " tileK[idx] = " << (past_present_share_buffer_ ? "present_key" : "past_key") << "[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n"
" } else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n"
" tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];\n"
" }\n";
} else {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
<< " tileK[idx] = " << (past_present_share_buffer_ ? "present_key" : "past_key") << "[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n"
" } else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n"
" tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];\n"
" }\n";
} else {
<< " tileK[idx] = " << (past_present_share_buffer_ ? "present_key" : "past_key") << "[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n"
" } else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n"
" tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];\n"
" }\n";
} else {

Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can commit the suggested changes from lintrunner.

Comment on lines 485 to 487
const int past_sequence_length = output_count > 1 ? parameters.past_sequence_length_ : 0;
const int total_sequence_length = seqlen_k == nullptr ? (past_sequence_length + parameters.kv_sequence_length_) : parameters.seqlen_present_kv_cache_;

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
const int past_sequence_length = output_count > 1 ? parameters.past_sequence_length_ : 0;
const int total_sequence_length = seqlen_k == nullptr ? (past_sequence_length + parameters.kv_sequence_length_) : parameters.seqlen_present_kv_cache_;
const int past_sequence_length = output_count > 1 ? parameters.past_sequence_length_ : 0;
const int total_sequence_length = seqlen_k == nullptr ? (past_sequence_length + parameters.kv_sequence_length_) : parameters.seqlen_present_kv_cache_;

Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can commit the suggested changes from lintrunner.

Comment on lines 222 to 224

const uint32_t vectorized_head_size = (parameters.head_size_ + components - 1) / components;
program.SetDispatchGroupSize((total_sequence_length + tile_size - 1) / tile_size,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
const uint32_t vectorized_head_size = (parameters.head_size_ + components - 1) / components;
program.SetDispatchGroupSize((total_sequence_length + tile_size - 1) / tile_size,
const uint32_t vectorized_head_size = (parameters.head_size_ + components - 1) / components;
program.SetDispatchGroupSize((total_sequence_length + tile_size - 1) / tile_size,

Comment on lines 305 to 307
int work_group_size = 64;
const int total_sequence_length_comp = (total_sequence_length + components -1) / components;
if (total_sequence_length_comp < work_group_size) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
int work_group_size = 64;
const int total_sequence_length_comp = (total_sequence_length + components -1) / components;
if (total_sequence_length_comp < work_group_size) {
int work_group_size = 64;
const int total_sequence_length_comp = (total_sequence_length + components - 1) / components;
if (total_sequence_length_comp < work_group_size) {

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ep:WebGPU ort-web webgpu provider
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants