[WebGPU EP] SoftMax Implementation #23538
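
For context, the kernel computes the numerically stable softmax of each row along the normalization axis,

$$\operatorname{softmax}(x)_i = \frac{e^{x_i - \max_j x_j}}{\sum_k e^{x_k - \max_j x_j}}$$

which maps directly onto the three phases of the generated shader: a workgroup reduction for the row max, a second reduction for the sum of exponentials, and a final normalization pass.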
onnxruntime/core/providers/webgpu/math/softmax.cc (new file)
@@ -0,0 +1,231 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/common/inlined_containers.h"
#include "core/providers/webgpu/math/softmax.h"
#include "core/providers/webgpu/tensor/transpose.h"
#include "core/providers/cpu/tensor/utils.h"
#include "core/providers/webgpu/shader_variable.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"

namespace onnxruntime {
namespace webgpu {
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
    Softmax,
    kOnnxDomain,
    1, 10,
    kWebGpuExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", WebGpuSupportedNumberTypes()),
    Softmax);

ONNX_OPERATOR_VERSIONED_KERNEL_EX(
    Softmax,
    kOnnxDomain,
    11, 12,
    kWebGpuExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", WebGpuSupportedNumberTypes()),
    Softmax);

ONNX_OPERATOR_KERNEL_EX(
    Softmax,
    kOnnxDomain,
    13,
    kWebGpuExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", WebGpuSupportedNumberTypes()),
    Softmax);
static std::string MaxVector(std::string name, int components) {
  switch (components) {
    case 1:
      return name;
    case 2:
      return "max(" + name + ".x, " + name + ".y)";
    case 3:
      return "max(max(" + name + ".x, " + name + ".y), " + name + ".z)";
    case 4:
      return "max(max(" + name + ".x, " + name + ".y), max(" + name + ".z, " + name + ".w))";
    default:
      ORT_THROW("Unsupported number of components: ", components);
  }
}
static std::string SumVector(std::string x, int components) {
  switch (components) {
    case 1:
      return x;
    case 2:
      return "(" + x + ".x + " + x + ".y" + ")";
    case 4:
      return "(" + x + ".x + " + x + ".y + " + x + ".w + " + x + ".z" + ")";
    default:
      ORT_THROW("Unsupported number of components: ", components);
  }
}
static int GetMaxComponents(int64_t size) {
  if (size % 4 == 0) {
    return 4;
  } else if (size % 2 == 0) {
    return 2;
  }
  return 1;
}
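
Worked example: cols = 12 is packed as vec4 with packedCols = 3, cols = 6 as vec2, and cols = 7 stays scalar. Since GetMaxComponents never returns 3, the 3-component case in MaxVector above is unreachable from this kernel, which is presumably why SumVector omits it.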

Status SoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
  // Add input and output variables
  const auto& input = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
  shader.AddOutput("result", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
  int components = input.NumComponents();

  std::string threadMaxDecl = input.ElementType() == "f32"
                                  ? "var threadMax = x_value_t(-3.402823e+38f);\n"
                                  : "var threadMax = x_value_t(-65504.0h);\n";

  // Define shared memory for row max and row sum
  shader.AdditionalImplementation()
      << "var<workgroup> rowMaxShared : x_value_t;\n"
      << "var<workgroup> rowSumShared : x_value_t;\n"
      << "var<workgroup> threadShared : array<x_value_t, " << WG << ">;\n";

Reviewer comment: use snake_case for variables and user-defined functions in WGSL shader code.

  // Define helper functions to get and set values
  shader.AdditionalImplementation()
      << "fn getValue(row: i32, col: i32, row_stride: i32) -> x_value_t {\n"
      << "  let index = row * row_stride + col;\n"
      << "  return x[index];\n"
      << "}\n"
      << "fn setValue(row: i32, col: i32, row_stride: i32, value: x_value_t) {\n"
      << "  let index = row * row_stride + col;\n"
      << "  result[index] = value;\n"
      << "}\n";

  // Main function body
  shader.MainFunctionBody()
      << "  let gindex = i32(global_idx);\n"
      << "  let lindex = i32(local_idx);\n"
      << "  const wg = " << WG << ";\n"
      << "  let row = gindex / wg;\n"
      << "  let cols = uniforms.packedCols;\n"
      << "  let row_stride : i32 = uniforms.packedCols;\n"

      // Find the row's max value
      << threadMaxDecl
      << "  for (var col = lindex; col < cols; col += wg) {\n"
      << "    let value = getValue(row, col, row_stride);\n"
      << "    threadMax = max(threadMax, value);\n"
      << "  }\n"
      << "  if (lindex < cols) {\n"
      << "    threadShared[lindex] = threadMax;\n"
      << "  }\n"
      << "  workgroupBarrier();\n"

      // Reduce to find the max value
      << "  var reduceSize = min(cols, wg);\n"
      << "  for (var currSize = reduceSize >> 1; currSize > 0; currSize = reduceSize >> 1) {\n"
      << "    reduceSize = currSize + (reduceSize & 1);\n"
      << "    if (lindex < currSize) {\n"
      << "      threadShared[lindex] = max(threadShared[lindex], threadShared[lindex + reduceSize]);\n"
      << "    }\n"
      << "    workgroupBarrier();\n"
      << "  }\n"
      << "  if (lindex == 0) {\n"
      << "    rowMaxShared = x_value_t(" << MaxVector("threadShared[0]", components) << ");\n"
      << "  }\n"
      << "  workgroupBarrier();\n"

      // Find the row's sum of exponentials
      << "  var threadSum = x_value_t(0.0);\n"
      << "  for (var col = lindex; col < cols; col += wg) {\n"
      << "    let subExp = exp(getValue(row, col, row_stride) - rowMaxShared);\n"
      << "    threadSum += subExp;\n"
      << "  }\n"
      << "  threadShared[lindex] = threadSum;\n"
      << "  workgroupBarrier();\n"

      // Reduce to find the sum of exponentials
      << "  for (var currSize = wg >> 1; currSize > 0; currSize = currSize >> 1) {\n"
      << "    if (lindex < currSize) {\n"
      << "      threadShared[lindex] = threadShared[lindex] + threadShared[lindex + currSize];\n"
      << "    }\n"
      << "    workgroupBarrier();\n"
      << "  }\n"
      << "  if (lindex == 0) {\n"
      << "    rowSumShared = x_value_t(" << SumVector("threadShared[0]", components) << ");\n"
      << "  }\n"
      << "  workgroupBarrier();\n"

      // Calculate the final value for each element in the row
      << "  for (var col = lindex; col < cols; col += wg) {\n"
      << "    let value = exp(getValue(row, col, row_stride) - rowMaxShared) / rowSumShared;\n"
      << "    setValue(row, col, row_stride, value);\n"
      << "  }\n";

  return Status::OK();
}
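
Because the shader is assembled from string fragments, it can be hard to review as WGSL. The sketch below is a hand-expanded approximation of the generated code for components = 4 and WG = 64 (so x_value_t becomes vec4<f32>); the x/result bindings and main-function scaffolding emitted by ShaderHelper are omitted, and the sum phase, which mirrors the max phase, is abbreviated:

// Phase 1: each thread scans its strided slice of the row for a running max.
var<workgroup> rowMaxShared : vec4<f32>;
var<workgroup> rowSumShared : vec4<f32>;
var<workgroup> threadShared : array<vec4<f32>, 64>;

var threadMax = vec4<f32>(-3.402823e+38f);
for (var col = lindex; col < cols; col += wg) {
  threadMax = max(threadMax, getValue(row, col, row_stride));
}
if (lindex < cols) {
  threadShared[lindex] = threadMax;
}
workgroupBarrier();
// ... the per-thread maxima are then tree-reduced in shared memory, and the
// four lanes of the surviving vec4 are collapsed into a scalar broadcast:
if (lindex == 0) {
  rowMaxShared = vec4<f32>(max(max(threadShared[0].x, threadShared[0].y),
                               max(threadShared[0].z, threadShared[0].w)));
}
workgroupBarrier();

// Phase 2 (abbreviated): the same scan/reduce pattern accumulates
// exp(value - rowMaxShared) into rowSumShared.

// Phase 3: normalize every element of the row.
for (var col = lindex; col < cols; col += wg) {
  setValue(row, col, row_stride, exp(getValue(row, col, row_stride) - rowMaxShared) / rowSumShared);
}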

Status Softmax::ComputeInternal(ComputeContext& context) const {
  const auto* input_tensor = context.Input(0);
  const TensorShape& input_shape = input_tensor->Shape();
  int64_t input_rank = input_shape.NumDimensions();
  auto* output_tensor = context.Output(0, input_shape);

  // normalize axis
  int64_t axis = axis_ < 0 ? axis_ + input_rank : axis_;
  bool is_transpose_required = axis < input_rank - 1;

  TensorShape transposed_input_shape;
  Tensor transposed_input_tensor;
  Tensor intermediate_output;
  InlinedVector<size_t> perm(input_rank);

  if (is_transpose_required) {
    std::iota(std::begin(perm), std::end(perm), 0);
    perm[axis] = input_rank - 1;
    perm[input_rank - 1] = axis;

    std::vector<int64_t> transposed_input_dims;
    for (auto e : perm) {
      transposed_input_dims.push_back(input_shape[e]);
    }

    transposed_input_shape = TensorShape(transposed_input_dims);
    transposed_input_tensor = context.CreateGPUTensor(input_tensor->DataType(), transposed_input_shape);
    ORT_RETURN_IF_ERROR(Transpose::DoTranspose(context, perm, *input_tensor, transposed_input_tensor));
    intermediate_output = context.CreateGPUTensor(output_tensor->DataType(), transposed_input_shape);
  }
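
  // Note: the shader only reduces along the innermost dimension, so when the
  // softmax axis is not the last axis, the input is transposed above to move
  // it to the back, and the result is transposed back after RunProgram below.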

  const int64_t cols = is_transpose_required ? transposed_input_shape[input_rank - 1] : input_shape[input_rank - 1];
  const int64_t rows = input_shape.Size() / cols;
  const int64_t components = GetMaxComponents(cols);
  const auto packedCols = cols / components;
  uint32_t WG = rows == 1 ? 256 : 64;

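  // Note: one workgroup is dispatched per row (SetDispatchGroupSize(rows)
  // below), so the larger workgroup size when rows == 1 presumably keeps a
  // single-row input from underutilizing the GPU.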
  SoftmaxProgram program{WG};
  if (is_transpose_required) {
    program
        .AddInputs({{&transposed_input_tensor, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(components)}})
        .AddOutputs({{&intermediate_output, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(components)}});
  } else {
    program
        .AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(components)}})
        .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(components)}});
  }

  program
      .CacheHint(std::to_string(components), std::to_string(WG))
      .SetWorkgroupSize(WG)
      .SetDispatchGroupSize(rows)
      .AddUniformVariables({{static_cast<int32_t>(packedCols)}});

  ORT_RETURN_IF_ERROR(context.RunProgram(program));

  // If transpose was required, transpose the result back
  if (is_transpose_required) {
    ORT_RETURN_IF_ERROR(Transpose::DoTranspose(context, perm, intermediate_output, *output_tensor));
  }

  return Status::OK();
}

}  // namespace webgpu
}  // namespace onnxruntime

onnxruntime/core/providers/webgpu/math/softmax.h (new file)
@@ -0,0 +1,52 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/webgpu/webgpu_supported_types.h"
#include "core/providers/webgpu/webgpu_kernel.h"
#include "core/providers/webgpu/program.h"
#include "core/framework/op_kernel.h"

namespace onnxruntime {
namespace webgpu {

class Softmax final : public WebGpuKernel {
 public:
  Softmax(const OpKernelInfo& info) : WebGpuKernel{info} {
    int opset_ = info.node().SinceVersion();
    int64_t axis;
    Status status = info.GetAttr<int64_t>("axis", &axis);

    if (status.IsOK()) {
      axis_ = axis;
    } else {
      if (opset_ < 13) {
        axis_ = 1;  // for opset-12 and below, the default axis value is 1
      } else {
        axis_ = -1;  // for opset-13 and above, the default axis value is -1
      }
    }
  }
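  // Note: these defaults follow the ONNX spec: through opset 12, Softmax is
  // defined on the input coerced to 2-D at `axis` (default 1); from opset 13
  // it is computed along a single axis (default -1).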

  Status ComputeInternal(ComputeContext& context) const override;

 private:
  int64_t axis_;
};

class SoftmaxProgram final : public Program<SoftmaxProgram> {
 public:
  SoftmaxProgram(uint32_t wg) : Program{"Softmax"}, WG{wg} {
  }

  Status GenerateShaderCode(ShaderHelper& sh) const override;

  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"packedCols", ProgramUniformVariableDataType::Int32});

 private:
  uint32_t WG;
};

}  // namespace webgpu
}  // namespace onnxruntime

onnxruntime/core/providers/webgpu/shader_variable.h
@@ -176,16 +176,17 @@ class ShaderVariableHelper : public ShaderIndicesHelper {
   template <typename TOffset>
   inline std::string GetByOffset(TOffset&& offset) const;
 
+  std::string_view StorageType() const;
+  std::string_view ValueType() const;
+  std::string_view ElementType() const;
+
  private:
   ORT_DISALLOW_COPY_AND_ASSIGNMENT(ShaderVariableHelper);
 
   void Impl(std::ostream& ss) const;
 
   std::string GetByOffsetImpl(std::string_view offset) const;
   std::string SetByOffsetImpl(std::string_view offset, std::string_view value) const;
-  std::string_view StorageType() const;
-  std::string_view ValueType() const;
-  std::string_view ElementType() const;
 
   friend class ShaderHelper;
 };

Author comment: I found making these methods public the easiest way to get tensor data type info when generating shader code. I am not sure if this is the best way to do it.

onnxruntime/core/providers/webgpu/tensor/transpose.cc
@@ -97,6 +97,59 @@ Status TransposeProgram::GenerateShaderCode(ShaderHelper& shader) const {
  return Status::OK();
}

Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContext& context, const gsl::span<const size_t>& permutations, const Tensor& input, Tensor& output) {

Reviewer comment: always pass …

  const auto& input_shape = input.Shape();
  const auto& input_dims = input_shape.GetDims();
  int32_t rank = gsl::narrow_cast<int32_t>(input_shape.NumDimensions());

  TensorShapeVector output_dims(rank);

  for (int32_t i = 0; i < rank; i++) {
    output_dims[i] = input_dims[permutations[i]];
  }

  TensorShape output_shape(output_dims);

  InlinedVector<int64_t> new_shape{};
  InlinedVector<int64_t> new_perm{};
  SqueezeShape(input_shape.GetDims(), permutations, new_shape, new_perm);
  const bool channels_last = new_perm == InlinedVector<int64_t>({2, 3, 1});
  const bool channels_first = new_perm == InlinedVector<int64_t>({3, 1, 2});
  const bool use_shared = (new_shape.size() == 2 && new_perm[0] > new_perm[1]) || channels_last || channels_first;
  auto new_input_shape = input_shape;
  TensorShape new_output_shape(output_dims);

  if (use_shared) {
    new_input_shape = channels_last
                          ? TensorShape({new_shape[0], new_shape[1] * new_shape[2]})
                      : channels_first
                          ? TensorShape({new_shape[0] * new_shape[1], new_shape[2]})
                          : new_shape;
    new_output_shape = TensorShape({new_input_shape[1], new_input_shape[0]});
  }

  uint32_t output_size = gsl::narrow_cast<uint32_t>(input_shape.Size());
  TransposeProgram program{permutations, use_shared};

  if (use_shared) {
    program.SetWorkgroupSize(TILE_SIZE, TILE_SIZE, 1);
  }

  program
      .CacheHint(absl::StrJoin(permutations, "-"))
      .AddInputs({{&input, ProgramTensorMetadataDependency::TypeAndRank, new_input_shape, 1}})
      .AddOutputs({{&output, ProgramTensorMetadataDependency::None, new_output_shape, 1}})
      .AddUniformVariables({
          {static_cast<uint32_t>(output_size)},
      });

  if (use_shared) {
    program.SetDispatchGroupSize(static_cast<uint32_t>((new_output_shape[1] + TILE_SIZE - 1) / TILE_SIZE),
                                 static_cast<uint32_t>((new_output_shape[0] + TILE_SIZE - 1) / TILE_SIZE));
  } else {
    program.SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE);
  }

  return context.RunProgram(program);
}
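
In the shared-memory path, each TILE_SIZE × TILE_SIZE workgroup transposes one tile of the squeezed 2-D view, so the dispatch grid covers the output in tiles; the naive path instead launches one thread per element of the flattened output.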

Status Transpose::ComputeInternal(ComputeContext& context) const {

Reviewer comment: This function should call …

  const auto* input_tensor = context.Input(0);
  const TensorShape& input_shape = input_tensor->Shape();

Reviewer comment: It is not a good idea to rely on the return value of ShaderVariableHelper::ElementType. The design of the shader helper classes (ShaderHelper, ShaderVariableHelper and ShaderIndicesHelper) uses an internal variable usage_ to track whether certain flags are activated for a variable/indices, which determines the final generated shader code. Functions like ShaderVariableHelper::ElementType are designed as internal methods that are only for use while generating shader code. Making them public breaks that design assumption and is error-prone. If you want to get the data type of a specific input, you can simply check Inputs()[0].var_type.

Reviewer comment: use const instead of var.