-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WebGPU EP] SoftMax Implementation #23538
base: main
Are you sure you want to change the base?
Conversation
@@ -176,16 +176,18 @@ class ShaderVariableHelper : public ShaderIndicesHelper { | |||
template <typename TOffset> | |||
inline std::string GetByOffset(TOffset&& offset) const; | |||
|
|||
std::string_view StorageType() const; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I found them making these methods public easiest way to get tensor data types info when generating shader code. I am not sure if this is the best way to do this
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can commit the suggested changes from lintrunner.
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
kNnapiExecutionProvider, // NNAPI softmax does not support empty input | ||
kQnnExecutionProvider} // QNN doesn't support dim 0 | ||
kNnapiExecutionProvider, // NNAPI softmax does not support empty input | ||
kWebGpuExecutionProvider, // WebGPU does not dim 0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
kWebGpuExecutionProvider, // WebGPU does not dim 0 | |
kWebGpuExecutionProvider, // WebGPU does not support dim 0 |
: program.SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE); | ||
return context.RunProgram(program); | ||
} | ||
|
||
Status Transpose::ComputeInternal(ComputeContext& context) const { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function should call DoTranspose()
instead of duplicating the code.
@@ -97,6 +97,59 @@ Status TransposeProgram::GenerateShaderCode(ShaderHelper& shader) const { | |||
return Status::OK(); | |||
} | |||
|
|||
Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContext& context, const gsl::span<const size_t>& permutations, const Tensor& input, Tensor& output) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
always pass span
as value instead of const reference.
shader.AddOutput("result", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); | ||
int components = input.NumComponents(); | ||
|
||
std::string threadMaxDecl = input.ElementType() == "f32" ? "var threadMax = x_value_t(-3.402823e+38f);\n" : "var threadMax = x_value_t(-65504.0h);\n"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is not a good idea to rely on the return value of ShaderVariableHelper::ElementType
.
The design of the shader helper classes (ShaderHelper
, ShaderVariableHelper
and ShaderIndicesHelper
) uses an internal variable usage_
to track whether one or more certain flags are activated for a variable/indices to determine the final generated shader code. Functions like ShaderVariableHelper::ElementType
are designed as internal methods that are only for the usage of generating shader code. Making them public will break the design assumption and is error-prone.
If you want to get the data type of a specific input, you can simply check Inputs()[0].var_type
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use const instead of var
|
||
// Define shared memory for row max and row sum | ||
shader.AdditionalImplementation() | ||
<< "var<workgroup> rowMaxShared : x_value_t;\n" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use snake_case for variables and user defined functions in WGSL shader code
Increase coverage for WebGPU Op