[WebGPU EP] SoftMax Implementation #23538

Open · wants to merge 4 commits into base: main
231 changes: 231 additions & 0 deletions onnxruntime/core/providers/webgpu/math/softmax.cc
@@ -0,0 +1,231 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/common/inlined_containers.h"
#include "core/providers/webgpu/math/softmax.h"
#include "core/providers/webgpu/tensor/transpose.h"
#include "core/providers/cpu/tensor/utils.h"
#include "core/providers/webgpu/shader_variable.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"
namespace onnxruntime {
namespace webgpu {

ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Softmax,
kOnnxDomain,
1, 10,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", WebGpuSupportedNumberTypes()),
Softmax);

ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Softmax,
kOnnxDomain,
11, 12,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", WebGpuSupportedNumberTypes()),
Softmax);

ONNX_OPERATOR_KERNEL_EX(
Softmax,
kOnnxDomain,
13,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", WebGpuSupportedNumberTypes()),
Softmax);

static std::string MaxVector(const std::string& name, int components) {
switch (components) {
case 1:
return name;
case 2:
return "max(" + name + ".x, " + name + ".y)";
case 3:
return "max(max(" + name + ".x, " + name + ".y), " + name + ".z)";
case 4:
return "max(max(" + name + ".x, " + name + ".y), max(" + name + ".z, " + name + ".w))";
default:
ORT_THROW("Unsupported number of components: ", components);
}
}

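// Note: GetMaxComponents() below returns only 1, 2, or 4, so SumVector needs no case 3
// (MaxVector's case 3 is defensive and currently unreachable).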
static std::string SumVector(const std::string& x, int components) {
switch (components) {
case 1:
return x;
case 2:
return "(" + x + ".x + " + x + ".y)";
case 4:
return "(" + x + ".x + " + x + ".y + " + x + ".z + " + x + ".w)";
default:
ORT_THROW("Unsupported number of components: ", components);
}
}

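// Choose the widest vector width (4, 2, else 1) that evenly divides the row length.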
static int GetMaxComponents(int64_t size) {
if (size % 4 == 0) {
return 4;
} else if (size % 2 == 0) {
return 2;
}
return 1;
}

Status SoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
// Add input and output variables
const auto& input = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
shader.AddOutput("result", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
int components = input.NumComponents();

std::string threadMaxDecl = input.ElementType() == "f32" ? "var threadMax = x_value_t(-3.402823e+38f);\n" : "var threadMax = x_value_t(-65504.0h);\n";

[GitHub Actions / Optional Lint C++] cpplint via reviewdog, softmax.cc line 84: Add #include <string> for string [build/include_what_you_use] [4]

Contributor:
It is not a good idea to rely on the return value of ShaderVariableHelper::ElementType.

The design of the shader helper classes (ShaderHelper, ShaderVariableHelper and ShaderIndicesHelper) uses an internal variable usage_ to track whether one or more certain flags are activated for a variable/indices to determine the final generated shader code. Functions like ShaderVariableHelper::ElementType are designed as internal methods that are only for the usage of generating shader code. Making them public will break the design assumption and is error-prone.

If you want to get the data type of a specific input, you can simply check Inputs()[0].var_type.
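
A minimal sketch of that suggestion. Only Inputs()[0].var_type comes from the reviewer; the enum member name and the surrounding code are assumptions for illustration:

  // Instead of exposing ShaderVariableHelper::ElementType(), read the program
  // metadata the SoftmaxProgram already owns (enum member name assumed here).
  const bool is_fp32 = Inputs()[0].var_type == ProgramVariableDataType::Float32;
  std::string thread_max_decl = is_fp32
                                    ? "var threadMax = x_value_t(-3.402823e+38f);\n"
                                    : "var threadMax = x_value_t(-65504.0h);\n";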

Contributor:
use const instead of var


// Define shared memory for row max and row sum
shader.AdditionalImplementation()
<< "var<workgroup> rowMaxShared : x_value_t;\n"
Contributor:
use snake_case for variables and user defined functions in WGSL shader code

<< "var<workgroup> rowSumShared : x_value_t;\n"
<< "var<workgroup> threadShared : array<x_value_t, " << WG << ">;\n";

// Define helper functions to get and set values
shader.AdditionalImplementation()
<< "fn getValue(row: i32, col: i32, row_stride: i32) -> x_value_t {\n"
<< " let index = row * row_stride + col;\n"
<< " return x[index];\n"
<< "}\n"
<< "fn setValue(row: i32, col: i32, row_stride: i32, value: x_value_t) {\n"
<< " let index = row * row_stride + col;\n"
<< " result[index] = value;\n"
<< "}\n";

// Main function body
shader.MainFunctionBody()
<< " let gindex = i32(global_idx);\n"
<< " let lindex = i32(local_idx);\n"
<< " const wg = " << WG << ";\n"
<< " let row = gindex / wg;\n"
<< " let cols = uniforms.packedCols;\n"
<< " let row_stride : i32 = uniforms.packedCols;\n"

// Find the row's max value
<< threadMaxDecl
<< " for (var col = lindex; col < cols; col += wg) {\n"
<< " let value = getValue(row, col, row_stride);\n"
<< " threadMax = max(threadMax, value);\n"
<< " }\n"
<< " if (lindex < cols) {\n"
<< " threadShared[lindex] = threadMax;\n"
<< " }\n"
<< " workgroupBarrier();\n"

// Reduce to find the max value
<< " var reduceSize = min(cols, wg);\n"
<< " for (var currSize = reduceSize >> 1; currSize > 0; currSize = reduceSize >> 1) {\n"
<< " reduceSize = currSize + (reduceSize & 1);\n"
<< " if (lindex < currSize) {\n"
<< " threadShared[lindex] = max(threadShared[lindex], threadShared[lindex + reduceSize]);\n"
<< " }\n"
<< " workgroupBarrier();\n"
<< " }\n"
<< " if (lindex == 0) {\n"
<< " rowMaxShared = x_value_t(" << MaxVector("threadShared[0]", components) << ");\n"
<< " }\n"
<< " workgroupBarrier();\n"

// Find the row's sum of exponentials
<< " var threadSum = x_value_t(0.0);\n"
<< " for (var col = lindex; col < cols; col += wg) {\n"
<< " let subExp = exp(getValue(row, col, row_stride) - rowMaxShared);\n"
<< " threadSum += subExp;\n"
<< " }\n"
<< " threadShared[lindex] = threadSum;\n"
<< " workgroupBarrier();\n"

// Reduce to find the sum of exponentials
<< " for (var currSize = wg >> 1; currSize > 0; currSize = currSize >> 1) {\n"
<< " if (lindex < currSize) {\n"
<< " threadShared[lindex] = threadShared[lindex] + threadShared[lindex + currSize];\n"
<< " }\n"
<< " workgroupBarrier();\n"
<< " }\n"
<< " if (lindex == 0) {\n"
<< " rowSumShared = x_value_t(" << SumVector("threadShared[0]", components) << ");\n"
<< " }\n"
<< " workgroupBarrier();\n"

// Calculate the final value for each element in the row
<< " for (var col = lindex; col < cols; col += wg) {\n"
<< " let value = exp(getValue(row, col, row_stride) - rowMaxShared) / rowSumShared;\n"
<< " setValue(row, col, row_stride, value);\n"
<< " }\n";

return Status::OK();
}

Status Softmax::ComputeInternal(ComputeContext& context) const {
const auto* input_tensor = context.Input(0);
const TensorShape& input_shape = input_tensor->Shape();
int64_t input_rank = input_shape.NumDimensions();
auto* output_tensor = context.Output(0, input_shape);

// normalize axis
int64_t axis = axis_ < 0 ? axis_ + input_rank : axis_;
bool is_transpose_required = axis < input_rank - 1;

TensorShape transposed_input_shape;
Tensor transposed_input_tensor;
Tensor intermediate_output;
InlinedVector<size_t> perm(input_rank);

if (is_transpose_required) {
std::iota(std::begin(perm), std::end(perm), 0);
perm[axis] = input_rank - 1;
perm[input_rank - 1] = axis;
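// Example: rank 4, axis 1 -> perm = {0, 3, 2, 1}, swapping the softmax axis with the last dim.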

std::vector<int64_t> transposed_input_dims;

[GitHub Actions / Optional Lint C++] cpplint via reviewdog, softmax.cc line 187: Add #include <vector> for vector<> [build/include_what_you_use] [4]
for (auto e : perm) {
transposed_input_dims.push_back(input_shape[e]);
}

transposed_input_shape = TensorShape(transposed_input_dims);
transposed_input_tensor = context.CreateGPUTensor(input_tensor->DataType(), transposed_input_shape);
ORT_RETURN_IF_ERROR(Transpose::DoTranspose(context, perm, *input_tensor, transposed_input_tensor));
intermediate_output = context.CreateGPUTensor(output_tensor->DataType(), transposed_input_shape);
}

const int64_t cols = is_transpose_required ? transposed_input_shape[input_rank - 1] : input_shape[input_rank - 1];
const int64_t rows = input_shape.Size() / cols;
const int64_t components = GetMaxComponents(cols);
const auto packedCols = cols / components;
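// One workgroup handles one row; with a single row, use a larger workgroup
// so more threads share the reduction.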
uint32_t WG = rows == 1 ? 256 : 64;

SoftmaxProgram program{WG};
if (is_transpose_required) {
program
.AddInputs({{&transposed_input_tensor, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(components)}})
.AddOutputs({{&intermediate_output, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(components)}});
} else {
program
.AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(components)}})
.AddOutputs({{output_tensor, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(components)}});
}

program
.CacheHint(std::to_string(components), std::to_string(WG))
.SetWorkgroupSize(WG)
.SetDispatchGroupSize(static_cast<uint32_t>(rows))
.AddUniformVariables({{static_cast<int32_t>(packedCols)}});
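// Must line up with the single Int32 "packedCols" uniform declared in softmax.h.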

ORT_RETURN_IF_ERROR(context.RunProgram(program));

// If transpose was required, transpose the result back
if (is_transpose_required) {
ORT_RETURN_IF_ERROR(Transpose::DoTranspose(context, perm, intermediate_output, *output_tensor));
}

return Status::OK();
}
} // namespace webgpu
} // namespace onnxruntime
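
For reference, a minimal numeric check could use ORT's standard OpTester harness; this is a sketch, not part of the PR, and the test name and values are illustrative:

  #include "test/providers/provider_test_utils.h"

  namespace onnxruntime {
  namespace test {

  TEST(SoftmaxOperator, WebGpu_LastAxis) {
    OpTester test("Softmax", /*opset_version=*/13);
    test.AddAttribute<int64_t>("axis", -1);
    test.AddInput<float>("X", {2, 3}, {1.0f, 2.0f, 3.0f, 1.0f, 2.0f, 3.0f});
    // Expected rows: exp(x - max) / sum(exp(x - max))
    test.AddOutput<float>("Y", {2, 3},
                          {0.09003057f, 0.24472847f, 0.66524096f,
                           0.09003057f, 0.24472847f, 0.66524096f});
    test.Run();
  }

  }  // namespace test
  }  // namespace onnxruntime
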
52 changes: 52 additions & 0 deletions onnxruntime/core/providers/webgpu/math/softmax.h
@@ -0,0 +1,52 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/webgpu/webgpu_supported_types.h"
#include "core/providers/webgpu/webgpu_kernel.h"
#include "core/providers/webgpu/program.h"
#include "core/framework/op_kernel.h"

namespace onnxruntime {
namespace webgpu {

class Softmax final : public WebGpuKernel {
public:
explicit Softmax(const OpKernelInfo& info) : WebGpuKernel{info} {
int opset = info.node().SinceVersion();
int64_t axis;
Status status = info.GetAttr<int64_t>("axis", &axis);

if (status.IsOK()) {
axis_ = axis;
} else {
if (opset_ < 13) {
axis_ = 1; // opset-12 and below, the default axis value is 1
} else {
axis_ = -1; // opset-13, the default axis value is -1
}
}
}

Status ComputeInternal(ComputeContext& context) const override;

private:
int64_t axis_;
};

class SoftmaxProgram final : public Program<SoftmaxProgram> {
public:
explicit SoftmaxProgram(uint32_t wg) : Program{"Softmax"}, WG{wg} {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;

WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"packedCols", ProgramUniformVariableDataType::Int32});

private:
uint32_t WG;
};

} // namespace webgpu
} // namespace onnxruntime
7 changes: 4 additions & 3 deletions onnxruntime/core/providers/webgpu/shader_variable.h
@@ -176,16 +176,17 @@ class ShaderVariableHelper : public ShaderIndicesHelper {
template <typename TOffset>
inline std::string GetByOffset(TOffset&& offset) const;

std::string_view StorageType() const;
Contributor (author):

I found making these methods public the easiest way to get tensor data type info when generating shader code. I am not sure this is the best way to do it.

std::string_view ValueType() const;
std::string_view ElementType() const;

private:
ORT_DISALLOW_COPY_AND_ASSIGNMENT(ShaderVariableHelper);

void Impl(std::ostream& ss) const;

std::string GetByOffsetImpl(std::string_view offset) const;
std::string SetByOffsetImpl(std::string_view offset, std::string_view value) const;
std::string_view StorageType() const;
std::string_view ValueType() const;
std::string_view ElementType() const;

friend class ShaderHelper;
};
53 changes: 53 additions & 0 deletions onnxruntime/core/providers/webgpu/tensor/transpose.cc
@@ -97,6 +97,59 @@ Status TransposeProgram::GenerateShaderCode(ShaderHelper& shader) const {
return Status::OK();
}

Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContext& context, const gsl::span<const size_t>& permutations, const Tensor& input, Tensor& output) {
Contributor:
always pass span as value instead of const reference.
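
Applied to this declaration, the suggestion would read (sketch):

  static Status DoTranspose(onnxruntime::webgpu::ComputeContext& context,
                            gsl::span<const size_t> permutations,
                            const Tensor& input, Tensor& output);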

const auto& input_shape = input.Shape();
const auto& input_dims = input_shape.GetDims();
int32_t rank = gsl::narrow_cast<int32_t>(input_shape.NumDimensions());

TensorShapeVector output_dims(rank);

for (int32_t i = 0; i < rank; i++) {
output_dims[i] = input_dims[permutations[i]];
}

TensorShape output_shape(output_dims);

InlinedVector<int64_t> new_shape{};
InlinedVector<int64_t> new_perm{};
SqueezeShape(input_shape.GetDims(), permutations, new_shape, new_perm);
const bool channels_last = new_perm == InlinedVector<int64_t>({2, 3, 1});
const bool channels_first = new_perm == InlinedVector<int64_t>({3, 1, 2});
const bool use_shared = (new_shape.size() == 2 && new_perm[0] > new_perm[1]) || channels_last || channels_first;
auto new_input_shape = input_shape;
TensorShape new_output_shape(output_dims);

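// Collapse the squeezed shape to 2-D so the tiled shared-memory transpose sees a plain row/column swap.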
if (use_shared) {
new_input_shape = channels_last
? TensorShape({new_shape[0], new_shape[1] * new_shape[2]})
: channels_first
? TensorShape({new_shape[0] * new_shape[1], new_shape[2]})
: new_shape;
new_output_shape = TensorShape({new_input_shape[1], new_input_shape[0]});
}

uint32_t output_size = gsl::narrow_cast<uint32_t>(input_shape.Size());
TransposeProgram program{permutations, use_shared};

if (use_shared) {
program.SetWorkgroupSize(TILE_SIZE, TILE_SIZE, 1);
}
program
.CacheHint(absl::StrJoin(permutations, "-"))
.AddInputs({{&input, ProgramTensorMetadataDependency::TypeAndRank, new_input_shape, 1}})
.AddOutputs({{&output, ProgramTensorMetadataDependency::None, new_output_shape, 1}})
.AddUniformVariables({
{output_size},
});

if (use_shared) {
program.SetDispatchGroupSize(static_cast<uint32_t>((new_output_shape[1] + TILE_SIZE - 1) / TILE_SIZE),
static_cast<uint32_t>((new_output_shape[0] + TILE_SIZE - 1) / TILE_SIZE));
} else {
program.SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE);
}
return context.RunProgram(program);
}

Status Transpose::ComputeInternal(ComputeContext& context) const {
Contributor:
This function should call DoTranspose() instead of duplicating the code.

const auto* input_tensor = context.Input(0);
const TensorShape& input_shape = input_tensor->Shape();
// ...
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/webgpu/tensor/transpose.h
@@ -16,6 +16,8 @@ class Transpose final : public WebGpuKernel, public TransposeBase {
Transpose(const OpKernelInfo& info) : WebGpuKernel{info}, TransposeBase{info} {
}
Status ComputeInternal(ComputeContext& context) const override;
static Status DoTranspose(onnxruntime::webgpu::ComputeContext& context, const gsl::span<const size_t>& permutations, const Tensor& input, Tensor& output);

constexpr static uint32_t TILE_SIZE = 16;
};

6 changes: 3 additions & 3 deletions onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc
@@ -625,9 +625,9 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, float, ArgMin)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, float, ArgMin)>,

- // BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, Softmax)>,
- // BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Softmax)>,
- // BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Softmax)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, Softmax)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Softmax)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Softmax)>,

BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 3, Concat)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 4, 10, Concat)>,
// ...