-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WebGPU EP] Support GroupQueryAttention #22658
base: main
Are you sure you want to change the base?
Changes from all commits
0a5d212
5bfa070
e6615e9
449afb4
8d10472
4ea58d1
4bcf257
5c5c934
e716546
aba59e5
067ecd1
53f1c78
f4dc9fc
3d1af1c
2eaeebc
9c828cc
26caa06
64b093f
a8bd38b
d613df4
0fedb9f
993140b
7502493
5f1fdae
6d2bd68
82a005d
fd9409f
15c96b3
72601d1
65495b6
63f20ed
0102206
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#pragma once | ||
|
||
#include "core/providers/webgpu/compute_context.h" | ||
#include "core/providers/webgpu/program.h" | ||
#include "core/providers/webgpu/shader_helper.h" | ||
#include "core/providers/webgpu/webgpu_kernel.h" | ||
#include "contrib_ops/webgpu/bert/attention_common.h" | ||
|
||
namespace onnxruntime { | ||
namespace contrib { | ||
namespace webgpu { | ||
|
||
using namespace onnxruntime::webgpu; | ||
Check warning on line 16 in onnxruntime/contrib_ops/webgpu/bert/attention.h GitHub Actions / Optional Lint C++
|
||
|
||
class TransferBSDToBNSHProgram final : public Program<TransferBSDToBNSHProgram> { | ||
public: | ||
TransferBSDToBNSHProgram(bool has_bias) : Program{"TransferBSDToBNSH"}, has_bias_(has_bias) {} | ||
|
||
Status GenerateShaderCode(ShaderHelper& sh) const override; | ||
|
||
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"data_size", ProgramUniformVariableDataType::Uint32}, | ||
{"batch_offset", ProgramUniformVariableDataType::Uint32}, | ||
{"sequence_offset", ProgramUniformVariableDataType::Uint32}, | ||
{"head_offset", ProgramUniformVariableDataType::Uint32}, | ||
{"bias_offset", ProgramUniformVariableDataType::Uint32}); | ||
|
||
private: | ||
bool has_bias_; | ||
}; | ||
|
||
class AttentionProbsProgram final : public Program<AttentionProbsProgram> { | ||
public: | ||
AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key, | ||
bool has_attention_bias, int tile_size, int components, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false) | ||
: Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer) { | ||
} | ||
|
||
Status GenerateShaderCode(ShaderHelper& sh) const override; | ||
|
||
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"M", ProgramUniformVariableDataType::Uint32}, | ||
{"K", ProgramUniformVariableDataType::Uint32}, | ||
{"N", ProgramUniformVariableDataType::Uint32}, | ||
{"num_heads", ProgramUniformVariableDataType::Uint32}, | ||
{"head_size", ProgramUniformVariableDataType::Uint32}, | ||
{"alpha", ProgramUniformVariableDataType::Float32}, | ||
{"past_sequence_length", ProgramUniformVariableDataType::Uint32}, | ||
{"kv_sequence_length", ProgramUniformVariableDataType::Uint32}, | ||
{"present_sequence_length", ProgramUniformVariableDataType::Uint32}, | ||
{"n_reps", ProgramUniformVariableDataType::Uint32}, | ||
{"is_first_prompt", ProgramUniformVariableDataType::Uint32}); | ||
|
||
WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32}); | ||
|
||
private: | ||
bool feed_past_key_; | ||
bool has_present_key_; | ||
bool has_attention_bias_; | ||
int tile_size_; | ||
int components_; | ||
int n_reps_; | ||
const Tensor* seqlen_k_; | ||
bool past_present_share_buffer_; | ||
}; | ||
|
||
class InPlaceSoftmaxProgram final : public Program<InPlaceSoftmaxProgram> { | ||
public: | ||
InPlaceSoftmaxProgram(const std::string& kernel_name, int work_group_size, int components, const Tensor* seqlen_k = nullptr) | ||
: Program{kernel_name}, work_group_size_(work_group_size), components_(components), seqlen_k_(seqlen_k) { | ||
} | ||
|
||
Status GenerateShaderCode(ShaderHelper& sh) const override; | ||
|
||
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"batch_size", ProgramUniformVariableDataType::Uint32}, | ||
{"num_heads", ProgramUniformVariableDataType::Uint32}, | ||
{"past_sequence_length", ProgramUniformVariableDataType::Uint32}, | ||
{"sequence_length", ProgramUniformVariableDataType::Uint32}, | ||
{"total_sequence_length_comp", ProgramUniformVariableDataType::Uint32}, | ||
{"elements_per_thread", ProgramUniformVariableDataType::Uint32}, | ||
{"is_first_prompt", ProgramUniformVariableDataType::Uint32}); | ||
|
||
private: | ||
int work_group_size_; | ||
int components_; | ||
const Tensor* seqlen_k_; | ||
}; | ||
|
||
class VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> { | ||
public: | ||
VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false) | ||
Check warning on line 92 in onnxruntime/contrib_ops/webgpu/bert/attention.h GitHub Actions / Optional Lint C++
|
||
: Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer) { | ||
} | ||
|
||
Status GenerateShaderCode(ShaderHelper& sh) const override; | ||
|
||
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"M", ProgramUniformVariableDataType::Uint32}, | ||
{"K", ProgramUniformVariableDataType::Uint32}, | ||
{"N", ProgramUniformVariableDataType::Uint32}, | ||
{"num_heads", ProgramUniformVariableDataType::Uint32}, | ||
{"head_size", ProgramUniformVariableDataType::Uint32}, | ||
{"v_hidden_size", ProgramUniformVariableDataType::Uint32}, | ||
{"past_sequence_length", ProgramUniformVariableDataType::Uint32}, | ||
{"kv_sequence_length", ProgramUniformVariableDataType::Uint32}, | ||
{"present_sequence_length", ProgramUniformVariableDataType::Uint32}, | ||
{"n_reps", ProgramUniformVariableDataType::Uint32}, | ||
{"is_first_prompt", ProgramUniformVariableDataType::Uint32}); | ||
|
||
WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32}); | ||
|
||
private: | ||
bool feed_past_value_; | ||
bool has_present_value_; | ||
int tile_size_; | ||
int n_reps_; | ||
const Tensor* seqlen_k_; | ||
bool past_present_share_buffer_; | ||
}; | ||
|
||
} // namespace webgpu | ||
} // namespace contrib | ||
} // namespace onnxruntime |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#pragma once | ||
|
||
#include "core/providers/webgpu/compute_context.h" | ||
#include "core/providers/webgpu/program.h" | ||
#include "core/providers/webgpu/shader_helper.h" | ||
#include "core/providers/webgpu/webgpu_kernel.h" | ||
#include "contrib_ops/webgpu/bert/attention_common.h" | ||
|
||
#include "contrib_ops/cpu/bert/attention_common.h" | ||
namespace onnxruntime { | ||
namespace contrib { | ||
namespace webgpu { | ||
|
||
struct WebgpuAttentionParameters { | ||
WebgpuAttentionParameters(AttentionParameters parameters) : is_gqa_parameters_(false), | ||
Check warning on line 18 in onnxruntime/contrib_ops/webgpu/bert/attention_common.h GitHub Actions / Optional Lint C++
|
||
batch_size_(parameters.batch_size), | ||
sequence_length_(parameters.sequence_length), | ||
kv_sequence_length_(parameters.kv_sequence_length), | ||
past_sequence_length_(parameters.past_sequence_length), | ||
total_sequence_length_(parameters.total_sequence_length), | ||
max_sequence_length_(parameters.max_sequence_length), | ||
input_hidden_size_(parameters.input_hidden_size), | ||
hidden_size_(parameters.hidden_size), | ||
head_size_(parameters.head_size), | ||
v_hidden_size_(parameters.v_hidden_size), | ||
v_head_size_(parameters.v_head_size), | ||
num_heads_(parameters.num_heads), | ||
is_unidirectional_(parameters.is_unidirectional), | ||
past_present_share_buffer_(parameters.past_present_share_buffer), | ||
do_rotary_(parameters.do_rotary), | ||
broadcast_attn_bias_dim_0_(parameters.broadcast_attn_bias_dim_0), | ||
broadcast_attn_bias_dim_1_(parameters.broadcast_attn_bias_dim_1), | ||
mask_filter_value_(parameters.mask_filter_value), | ||
scale_(parameters.scale), | ||
mask_type_(parameters.mask_type), | ||
qkv_format_(parameters.qkv_format) { | ||
} | ||
|
||
WebgpuAttentionParameters(onnxruntime::contrib::GroupQueryAttentionParameters parameters) : is_gqa_parameters_(true), | ||
Check warning on line 42 in onnxruntime/contrib_ops/webgpu/bert/attention_common.h GitHub Actions / Optional Lint C++
|
||
batch_size_(parameters.batch_size), | ||
sequence_length_(parameters.sequence_length), | ||
kv_sequence_length_(parameters.sequence_length), | ||
past_sequence_length_(parameters.seqlen_past_kv_cache), | ||
total_sequence_length_(parameters.total_sequence_length), | ||
hidden_size_(parameters.hidden_size), | ||
head_size_(parameters.head_size), | ||
v_hidden_size_(parameters.kv_hidden_size), | ||
v_head_size_(parameters.kv_hidden_size / parameters.kv_num_heads), | ||
num_heads_(parameters.num_heads), | ||
do_rotary_(parameters.do_rotary), | ||
seqlen_past_kv_cache_(parameters.seqlen_past_kv_cache), | ||
seqlen_present_kv_cache_(parameters.seqlen_present_kv_cache), | ||
kv_hidden_size_(parameters.kv_hidden_size), | ||
kv_num_heads_(parameters.kv_num_heads), | ||
num_splits_(parameters.num_splits), | ||
rotary_dim_(parameters.rotary_dim), | ||
is_packed_qkv_(parameters.is_packed_qkv), | ||
is_subsequent_prompt_(parameters.is_subsequent_prompt), | ||
is_first_prompt_(parameters.is_first_prompt), | ||
rotary_interleaved_(parameters.rotary_interleaved), | ||
use_smooth_softmax_(parameters.use_smooth_softmax), | ||
softcap_(parameters.softcap), | ||
zeros_count_(parameters.zeros_count), | ||
zero_ptr_(parameters.zero_ptr), | ||
n_reps(parameters.num_heads / parameters.kv_num_heads), | ||
qkv_format_(parameters.qkv_format) { | ||
} | ||
|
||
bool is_gqa_parameters_; | ||
int batch_size_ = 0; | ||
int sequence_length_ = 0; | ||
int kv_sequence_length_ = 0; // input sequence length of K or V | ||
int past_sequence_length_ = 0; // sequence length in past state of K or V | ||
int total_sequence_length_ = 0; // total sequence length of K or V | ||
int max_sequence_length_ = 0; // max sequence length from 4D mask | ||
int input_hidden_size_ = 0; // first dimension of weights for input projection | ||
int hidden_size_ = 0; // hidden size of Q or K | ||
int head_size_ = 0; // hidden size per head of Q or K | ||
int v_hidden_size_ = 0; // hidden size of V | ||
int v_head_size_ = 0; // hidden size per head of V | ||
int num_heads_ = 0; | ||
int rotary_embedding_ = 0; | ||
bool is_unidirectional_ = false; | ||
bool past_present_share_buffer_ = false; | ||
bool do_rotary_ = false; | ||
bool broadcast_attn_bias_dim_0_ = false; | ||
bool broadcast_attn_bias_dim_1_ = false; | ||
float mask_filter_value_ = -10000.0f; | ||
float scale_ = 0.0f; | ||
bool use_tf32_ = false; | ||
; | ||
// The following members are in onnxruntime::contrib::GroupQueryAttentionParameters | ||
// and not in onnxruntime::contrib::AttentionParameters | ||
int seqlen_past_kv_cache_ = 0; // sequence length of past kv tensor | ||
int seqlen_present_kv_cache_ = 0; // sequence length of present kv tensor | ||
int kv_hidden_size_ = 0; | ||
int kv_num_heads_ = 0; | ||
int num_splits_ = 0; // number of splits for splitkv | ||
int rotary_dim_ = 0; // rotary embedding dimension | ||
int local_window_size_ = 0; | ||
bool kv_share_buffer_ = false; | ||
bool is_packed_qkv_ = false; | ||
bool is_subsequent_prompt_ = false; // indicates whether we have past context and seqlen > 1 | ||
bool is_first_prompt_ = false; // indicates whether this is first decoding step | ||
bool rotary_interleaved_ = false; | ||
bool use_smooth_softmax_ = false; | ||
float softcap_ = 0.0; | ||
int zeros_count_ = 0; | ||
; | ||
int* zero_ptr_ = nullptr; | ||
// Computed values | ||
int n_reps = 1; | ||
AttentionMaskType mask_type_ = MASK_NONE; | ||
AttentionQkvFormat qkv_format_ = UNKNOWN; | ||
}; | ||
|
||
Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_heads, int sequence_length, | ||
int head_size, const Tensor* input_tensor, const Tensor* bias, int bias_offset, Tensor* output_tensor); | ||
|
||
Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias, | ||
const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value, | ||
WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k = nullptr); | ||
|
||
} // namespace webgpu | ||
} // namespace contrib | ||
} // namespace onnxruntime |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's the reason WebGPU needs a parameters struct that combines AttentionParameters and GroupQueryAttentionParameters? Feels a little confusing to merge those and wondering why it's necessary if we don't need to do that for other EPs that implement these operators.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am trying to avoid code duplication. I refactored code into attention used by both GQA and MHA. The CPU version has GQA separate implementation. group_query_attention_helper::CheckInputs() and AttentionBase::CheckInputs output different structs, GroupQueryAttentionParameters and AttentionParameters respectively. WebGPU parameters is a union of these to structs.