Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WebGPU EP] Support GroupQueryAttention #22658

Open
wants to merge 32 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
0a5d212
Added attention_common.h
satyajandhyala Oct 27, 2024
5bfa070
wip
satyajandhyala Oct 28, 2024
e6615e9
Fix compilation errors
satyajandhyala Oct 28, 2024
449afb4
lint
satyajandhyala Oct 28, 2024
8d10472
Modified MultiHeadAttention to not derive from AttentionBase class
satyajandhyala Oct 29, 2024
4ea58d1
Uncomment GQA registration
satyajandhyala Oct 29, 2024
4bcf257
Moved TransferBSToBNSH and ApplyAttention declaration to attention_co…
satyajandhyala Oct 29, 2024
5c5c934
Revert "Modified MultiHeadAttention to not derive from AttentionBase …
satyajandhyala Oct 29, 2024
e716546
Converted CheckInput function to template to fix compiler/linker mult…
satyajandhyala Oct 30, 2024
aba59e5
lint
satyajandhyala Oct 30, 2024
067ecd1
Fixed conflicts.
satyajandhyala Oct 30, 2024
53f1c78
copying errors
satyajandhyala Oct 30, 2024
f4dc9fc
Fixed inplacesoftmax dispatch
satyajandhyala Oct 31, 2024
3d1af1c
Initialize required parameter data
satyajandhyala Nov 1, 2024
2eaeebc
Map total_seqlen_tensor input to CPU
satyajandhyala Nov 1, 2024
9c828cc
Use uniforms variable name consistently to avoid confusion.
satyajandhyala Oct 31, 2024
26caa06
Keep InplaceSoftmax dispatch 3-dim.
satyajandhyala Oct 31, 2024
64b093f
Formatting changes.
satyajandhyala Oct 31, 2024
a8bd38b
Use total_seqlen_tensor input only to determin is_first_prompt.
satyajandhyala Nov 4, 2024
d613df4
initialize is_packed_qkv_
satyajandhyala Nov 4, 2024
0fedb9f
Handle past key/value and present key/value buffer sharing.
satyajandhyala Nov 6, 2024
993140b
lint
satyajandhyala Nov 6, 2024
7502493
Added past_present_share_buffer to the hint. typo
satyajandhyala Nov 7, 2024
5f1fdae
past_present_share_buffer related changes.
satyajandhyala Nov 13, 2024
6d2bd68
lint
satyajandhyala Nov 13, 2024
82a005d
Fix integer division
satyajandhyala Nov 13, 2024
fd9409f
Updated hints
satyajandhyala Nov 13, 2024
15c96b3
match jsep code
satyajandhyala Nov 13, 2024
72601d1
Fixed a minor issue
satyajandhyala Nov 14, 2024
65495b6
lint
satyajandhyala Nov 14, 2024
63f20ed
Fix a bug using total_sequence_length instead of uniform.total_sequen…
satyajandhyala Nov 16, 2024
0102206
Revert "match jsep code"
satyajandhyala Nov 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 20 additions & 18 deletions onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,19 @@ namespace onnxruntime {
namespace contrib {
namespace group_query_attention_helper {

Status CheckInputs(const Tensor* query,
const Tensor* key,
const Tensor* value,
const Tensor* past_key,
const Tensor* past_value,
const Tensor* cos_cache,
const Tensor* sin_cache,
template <typename T = Tensor>
Status CheckInputs(const T* query,
const T* key,
const T* value,
const T* past_key,
const T* past_value,
const T* cos_cache,
const T* sin_cache,
void* parameters,
int num_heads,
int kv_num_heads,
const Tensor* seqlens_k,
const Tensor* total_seqlen,
const T* seqlens_k,
const T* total_seqlen,
float scale,
float softcap) {
// Note: Here S* is seqlen_past_kv_cache, S+ is seqlen_present_kv_cache
Expand Down Expand Up @@ -265,18 +266,19 @@ Status CheckInputs(const Tensor* query,
return Status::OK();
}

Status CheckInputs(const Tensor* query,
const Tensor* key,
const Tensor* value,
const Tensor* past_key,
const Tensor* past_value,
const Tensor* cos_cache,
const Tensor* sin_cache,
template <typename T = Tensor>
Status CheckInputs(const T* query,
const T* key,
const T* value,
const T* past_key,
const T* past_value,
const T* cos_cache,
const T* sin_cache,
void* parameters,
int num_heads,
int kv_num_heads,
const Tensor* seqlens_k,
const Tensor* total_seqlen,
const T* seqlens_k,
const T* total_seqlen,
float scale,
float softcap,
int max_threads_per_block) {
Expand Down
506 changes: 506 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/attention.cc

Large diffs are not rendered by default.

123 changes: 123 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/attention.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/webgpu/compute_context.h"
#include "core/providers/webgpu/program.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_kernel.h"
#include "contrib_ops/webgpu/bert/attention_common.h"

namespace onnxruntime {
namespace contrib {
namespace webgpu {

using namespace onnxruntime::webgpu;

class TransferBSDToBNSHProgram final : public Program<TransferBSDToBNSHProgram> {
public:
TransferBSDToBNSHProgram(bool has_bias) : Program{"TransferBSDToBNSH"}, has_bias_(has_bias) {}

Status GenerateShaderCode(ShaderHelper& sh) const override;

WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"data_size", ProgramUniformVariableDataType::Uint32},
{"batch_offset", ProgramUniformVariableDataType::Uint32},
{"sequence_offset", ProgramUniformVariableDataType::Uint32},
{"head_offset", ProgramUniformVariableDataType::Uint32},
{"bias_offset", ProgramUniformVariableDataType::Uint32});

private:
bool has_bias_;
};

class AttentionProbsProgram final : public Program<AttentionProbsProgram> {
public:
AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key,
bool has_attention_bias, int tile_size, int components, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)
: Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer) {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;

WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"M", ProgramUniformVariableDataType::Uint32},
{"K", ProgramUniformVariableDataType::Uint32},
{"N", ProgramUniformVariableDataType::Uint32},
{"num_heads", ProgramUniformVariableDataType::Uint32},
{"head_size", ProgramUniformVariableDataType::Uint32},
{"alpha", ProgramUniformVariableDataType::Float32},
{"past_sequence_length", ProgramUniformVariableDataType::Uint32},
{"kv_sequence_length", ProgramUniformVariableDataType::Uint32},
{"present_sequence_length", ProgramUniformVariableDataType::Uint32},
{"n_reps", ProgramUniformVariableDataType::Uint32},
{"is_first_prompt", ProgramUniformVariableDataType::Uint32});

WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32});

private:
bool feed_past_key_;
bool has_present_key_;
bool has_attention_bias_;
int tile_size_;
int components_;
int n_reps_;
const Tensor* seqlen_k_;
bool past_present_share_buffer_;
};

class InPlaceSoftmaxProgram final : public Program<InPlaceSoftmaxProgram> {
public:
InPlaceSoftmaxProgram(const std::string& kernel_name, int work_group_size, int components, const Tensor* seqlen_k = nullptr)
: Program{kernel_name}, work_group_size_(work_group_size), components_(components), seqlen_k_(seqlen_k) {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;

WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"batch_size", ProgramUniformVariableDataType::Uint32},
{"num_heads", ProgramUniformVariableDataType::Uint32},
{"past_sequence_length", ProgramUniformVariableDataType::Uint32},
{"sequence_length", ProgramUniformVariableDataType::Uint32},
{"total_sequence_length_comp", ProgramUniformVariableDataType::Uint32},
{"elements_per_thread", ProgramUniformVariableDataType::Uint32},
{"is_first_prompt", ProgramUniformVariableDataType::Uint32});

private:
int work_group_size_;
int components_;
const Tensor* seqlen_k_;
};

class VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> {
public:
VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)
: Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer) {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;

WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"M", ProgramUniformVariableDataType::Uint32},
{"K", ProgramUniformVariableDataType::Uint32},
{"N", ProgramUniformVariableDataType::Uint32},
{"num_heads", ProgramUniformVariableDataType::Uint32},
{"head_size", ProgramUniformVariableDataType::Uint32},
{"v_hidden_size", ProgramUniformVariableDataType::Uint32},
{"past_sequence_length", ProgramUniformVariableDataType::Uint32},
{"kv_sequence_length", ProgramUniformVariableDataType::Uint32},
{"present_sequence_length", ProgramUniformVariableDataType::Uint32},
{"n_reps", ProgramUniformVariableDataType::Uint32},
{"is_first_prompt", ProgramUniformVariableDataType::Uint32});

WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32});

private:
bool feed_past_value_;
bool has_present_value_;
int tile_size_;
int n_reps_;
const Tensor* seqlen_k_;
bool past_present_share_buffer_;
};

} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
129 changes: 129 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/attention_common.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/webgpu/compute_context.h"
#include "core/providers/webgpu/program.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_kernel.h"
#include "contrib_ops/webgpu/bert/attention_common.h"

#include "contrib_ops/cpu/bert/attention_common.h"
namespace onnxruntime {
namespace contrib {
namespace webgpu {

struct WebgpuAttentionParameters {
WebgpuAttentionParameters(AttentionParameters parameters) : is_gqa_parameters_(false),
batch_size_(parameters.batch_size),
Comment on lines +17 to +19
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the reason WebGPU needs a parameters struct that combines AttentionParameters and GroupQueryAttentionParameters? Feels a little confusing to merge those and wondering why it's necessary if we don't need to do that for other EPs that implement these operators.

Copy link
Contributor Author

@satyajandhyala satyajandhyala Nov 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am trying to avoid code duplication. I refactored code into attention used by both GQA and MHA. The CPU version has GQA separate implementation. group_query_attention_helper::CheckInputs() and AttentionBase::CheckInputs output different structs, GroupQueryAttentionParameters and AttentionParameters respectively. WebGPU parameters is a union of these to structs.

sequence_length_(parameters.sequence_length),
kv_sequence_length_(parameters.kv_sequence_length),
past_sequence_length_(parameters.past_sequence_length),
total_sequence_length_(parameters.total_sequence_length),
max_sequence_length_(parameters.max_sequence_length),
input_hidden_size_(parameters.input_hidden_size),
hidden_size_(parameters.hidden_size),
head_size_(parameters.head_size),
v_hidden_size_(parameters.v_hidden_size),
v_head_size_(parameters.v_head_size),
num_heads_(parameters.num_heads),
is_unidirectional_(parameters.is_unidirectional),
past_present_share_buffer_(parameters.past_present_share_buffer),
do_rotary_(parameters.do_rotary),
broadcast_attn_bias_dim_0_(parameters.broadcast_attn_bias_dim_0),
broadcast_attn_bias_dim_1_(parameters.broadcast_attn_bias_dim_1),
mask_filter_value_(parameters.mask_filter_value),
scale_(parameters.scale),
mask_type_(parameters.mask_type),
qkv_format_(parameters.qkv_format) {
}

WebgpuAttentionParameters(onnxruntime::contrib::GroupQueryAttentionParameters parameters) : is_gqa_parameters_(true),
batch_size_(parameters.batch_size),
sequence_length_(parameters.sequence_length),
kv_sequence_length_(parameters.sequence_length),
past_sequence_length_(parameters.seqlen_past_kv_cache),
total_sequence_length_(parameters.total_sequence_length),
hidden_size_(parameters.hidden_size),
head_size_(parameters.head_size),
v_hidden_size_(parameters.kv_hidden_size),
v_head_size_(parameters.kv_hidden_size / parameters.kv_num_heads),
num_heads_(parameters.num_heads),
do_rotary_(parameters.do_rotary),
seqlen_past_kv_cache_(parameters.seqlen_past_kv_cache),
seqlen_present_kv_cache_(parameters.seqlen_present_kv_cache),
kv_hidden_size_(parameters.kv_hidden_size),
kv_num_heads_(parameters.kv_num_heads),
num_splits_(parameters.num_splits),
rotary_dim_(parameters.rotary_dim),
is_packed_qkv_(parameters.is_packed_qkv),
is_subsequent_prompt_(parameters.is_subsequent_prompt),
is_first_prompt_(parameters.is_first_prompt),
rotary_interleaved_(parameters.rotary_interleaved),
use_smooth_softmax_(parameters.use_smooth_softmax),
softcap_(parameters.softcap),
zeros_count_(parameters.zeros_count),
zero_ptr_(parameters.zero_ptr),
n_reps(parameters.num_heads / parameters.kv_num_heads),
qkv_format_(parameters.qkv_format) {
}

bool is_gqa_parameters_;
int batch_size_ = 0;
int sequence_length_ = 0;
int kv_sequence_length_ = 0; // input sequence length of K or V
int past_sequence_length_ = 0; // sequence length in past state of K or V
int total_sequence_length_ = 0; // total sequence length of K or V
int max_sequence_length_ = 0; // max sequence length from 4D mask
int input_hidden_size_ = 0; // first dimension of weights for input projection
int hidden_size_ = 0; // hidden size of Q or K
int head_size_ = 0; // hidden size per head of Q or K
int v_hidden_size_ = 0; // hidden size of V
int v_head_size_ = 0; // hidden size per head of V
int num_heads_ = 0;
int rotary_embedding_ = 0;
bool is_unidirectional_ = false;
bool past_present_share_buffer_ = false;
bool do_rotary_ = false;
bool broadcast_attn_bias_dim_0_ = false;
bool broadcast_attn_bias_dim_1_ = false;
float mask_filter_value_ = -10000.0f;
float scale_ = 0.0f;
bool use_tf32_ = false;
;
// The following members are in onnxruntime::contrib::GroupQueryAttentionParameters
// and not in onnxruntime::contrib::AttentionParameters
int seqlen_past_kv_cache_ = 0; // sequence length of past kv tensor
int seqlen_present_kv_cache_ = 0; // sequence length of present kv tensor
int kv_hidden_size_ = 0;
int kv_num_heads_ = 0;
int num_splits_ = 0; // number of splits for splitkv
int rotary_dim_ = 0; // rotary embedding dimension
int local_window_size_ = 0;
bool kv_share_buffer_ = false;
bool is_packed_qkv_ = false;
bool is_subsequent_prompt_ = false; // indicates whether we have past context and seqlen > 1
bool is_first_prompt_ = false; // indicates whether this is first decoding step
bool rotary_interleaved_ = false;
bool use_smooth_softmax_ = false;
float softcap_ = 0.0;
int zeros_count_ = 0;
;
int* zero_ptr_ = nullptr;
// Computed values
int n_reps = 1;
AttentionMaskType mask_type_ = MASK_NONE;
AttentionQkvFormat qkv_format_ = UNKNOWN;
};

Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_heads, int sequence_length,
int head_size, const Tensor* input_tensor, const Tensor* bias, int bias_offset, Tensor* output_tensor);

Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value,
WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k = nullptr);

} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
Loading
Loading