
[WebGPU EP] SoftMax Implementation #23538

Open · wants to merge 4 commits into main

Conversation

Contributor

@vraspar vraspar commented Jan 30, 2025

Increase coverage for WebGPU Op

@@ -176,16 +176,18 @@ class ShaderVariableHelper : public ShaderIndicesHelper {
template <typename TOffset>
inline std::string GetByOffset(TOffset&& offset) const;

std::string_view StorageType() const;
Contributor Author

I found making these methods public the easiest way to get tensor data type info when generating shader code. I am not sure if this is the best way to do it.

@vraspar vraspar added the ep:WebGPU ort-web webgpu provider label Jan 30, 2025
@github-actions github-actions bot left a comment

You can commit the suggested changes from lintrunner.

onnxruntime/core/providers/webgpu/shader_variable.h (outdated, resolved)
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@vraspar vraspar requested review from skottmckay and fs-eire January 30, 2025 19:50
kNnapiExecutionProvider, // NNAPI softmax does not support empty input
kQnnExecutionProvider} // QNN doesn't support dim 0
kNnapiExecutionProvider, // NNAPI softmax does not support empty input
kWebGpuExecutionProvider, // WebGPU does not dim 0
Contributor

Suggested change:
- kWebGpuExecutionProvider,  // WebGPU does not dim 0
+ kWebGpuExecutionProvider,  // WebGPU does not support dim 0

: program.SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE);
return context.RunProgram(program);
}

Status Transpose::ComputeInternal(ComputeContext& context) const {
Contributor

This function should call DoTranspose() instead of duplicating the code.
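A minimal sketch of the suggested refactor, assuming the output shape and permutation are resolved exactly as the current code does (the local variable names here are placeholders):

// Sketch only: ComputeInternal forwards to DoTranspose so the dispatch logic lives in one place.
Status Transpose::ComputeInternal(ComputeContext& context) const {
  const auto* input_tensor = context.Input(0);
  // ... compute output_shape and the permutation `perm` as the existing code does ...
  auto* output_tensor = context.Output(0, output_shape);
  return DoTranspose(context, perm, *input_tensor, *output_tensor);
}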

@@ -97,6 +97,59 @@ Status TransposeProgram::GenerateShaderCode(ShaderHelper& shader) const {
return Status::OK();
}

Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContext& context, const gsl::span<const size_t>& permutations, const Tensor& input, Tensor& output) {
Contributor

Always pass span by value instead of by const reference.
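For reference, a hedged sketch of the signature change this asks for; gsl::span is a cheap, non-owning view, so copying it costs no more than passing a reference:

// Before (as in the diff): span taken by const reference
Status DoTranspose(onnxruntime::webgpu::ComputeContext& context,
                   const gsl::span<const size_t>& permutations,
                   const Tensor& input, Tensor& output);

// After (suggested): span passed by value
Status DoTranspose(onnxruntime::webgpu::ComputeContext& context,
                   gsl::span<const size_t> permutations,
                   const Tensor& input, Tensor& output);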

shader.AddOutput("result", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
int components = input.NumComponents();

std::string threadMaxDecl = input.ElementType() == "f32" ? "var threadMax = x_value_t(-3.402823e+38f);\n" : "var threadMax = x_value_t(-65504.0h);\n";
Contributor

It is not a good idea to rely on the return value of ShaderVariableHelper::ElementType.

The shader helper classes (ShaderHelper, ShaderVariableHelper and ShaderIndicesHelper) use an internal member usage_ to track which flags are set for a variable/indices object and to determine the final generated shader code. Functions like ShaderVariableHelper::ElementType are designed as internal methods intended only for use while generating that code. Making them public breaks this design assumption and is error-prone.

If you want to get the data type of a specific input, you can simply check Inputs()[0].var_type.
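A hedged sketch of how the softmax program could pick the initial max value from Inputs()[0].var_type instead of exposing ShaderVariableHelper internals; the ProgramVariableDataType enumerator names below are assumptions, not verified against the codebase:

// Sketch: choose the lowest finite value for the element type from the program input's variable type.
const bool is_fp32 = Inputs()[0].var_type == ProgramVariableDataType::Float32 ||
                     Inputs()[0].var_type == ProgramVariableDataType::Float32x4;
const std::string thread_max_decl = is_fp32
    ? "var thread_max = x_value_t(-3.402823e+38f);\n"  // lowest finite f32
    : "var thread_max = x_value_t(-65504.0h);\n";      // lowest finite f16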

Contributor

use const instead of var


// Define shared memory for row max and row sum
shader.AdditionalImplementation()
<< "var<workgroup> rowMaxShared : x_value_t;\n"
Contributor

Use snake_case for variables and user-defined functions in WGSL shader code.
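A hedged sketch of the naming change applied to the shared-memory declarations shown above; the row_sum_shared declaration is assumed from the "row max and row sum" comment in the diff:

// Sketch: workgroup-shared declarations renamed to snake_case as the reviewer asks.
shader.AdditionalImplementation()
    << "var<workgroup> row_max_shared : x_value_t;\n"
    << "var<workgroup> row_sum_shared : x_value_t;\n";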

Labels: ep:WebGPU ort-web webgpu provider