Skip to content
This repository was archived by the owner on May 7, 2025. It is now read-only.

Commit d594a93

Browse files
committed
feat: add support for the Neg operator (elementwise -x)
1 parent e46271c commit d594a93

File tree

4 files changed

+8
-2
lines changed

4 files changed

+8
-2
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ fn test_matmul_square_matrix() {
291291
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#Mod">Mod</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Mod-13">13</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Mod-10">10</a>|||
292292
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#Mul">Mul</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Mul-14">14</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Mul-13">13</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Mul-7">7</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Mul-6">6</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Mul-1">1</a>|||
293293
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#Multinomial">Multinomial</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Multinomial-7">7</a>|
294-
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#Neg">Neg</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Neg-13">13</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Neg-6">6</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Neg-1">1</a>|
294+
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#Neg">Neg</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Neg-13">13</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Neg-6">6</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Neg-1">1</a>|||
295295
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#NonMaxSuppression">NonMaxSuppression</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#NonMaxSuppression-11">11</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#NonMaxSuppression-10">10</a>|
296296
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#NonZero">NonZero</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#NonZero-13">13</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#NonZero-9">9</a>|
297297
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#Not">Not</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Not-1">1</a>||

wonnx-py/tests/test_onnx_backend.py

+2
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,8 @@ def do_enforce_test_coverage_safelist(model): # type: (ModelProto) -> bool
155155
backend_test.include(f"test_sub_bcast_[a-z,_]*")
156156
backend_test.include(f"test_pow_bcast_[a-z,_]*")
157157
backend_test.include(f"test_transpose[a-z,_]*")
158+
backend_test.include(f"test_neg_[a-z,_]*")
159+
backend_test.include(f"test_reciprocal_[a-z,_]*")
158160

159161
# Don't support 'bool' type
160162
# backend_test.include(f"test_and_bcast[a-z0-9,_]*")

wonnx/src/compiler.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ pub fn compile(
288288
// Map simple function
289289
"Abs" | "Acos" | "Asin" | "Atan" | "Ceil" | "Cos" | "Cosh" | "Exp" | "Floor" | "Log"
290290
| "Round" | "Sign" | "Sin" | "Sinh" | "Sqrt" | "Tan" | "Tanh" | "Reciprocal" | "Acosh"
291-
| "Asinh" | "Atanh" => {
291+
| "Asinh" | "Atanh" | "Neg" => {
292292
let (x_threads, workgroup_size_x) = workgroup_size(
293293
ceil(output_lengths[0], 4),
294294
MAX_COMPUTE_WORKGROUPS_PER_DIMENSION,

wonnx/templates/endomorphism/map.wgsl

+4
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@ fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
1313
let one_scalar = {{scalar_type}}(1);
1414
let one = Vec4(one_scalar, one_scalar, one_scalar, one_scalar);
1515
output_0.data[gidx] = one / (input_0.data[gidx]);
16+
{% elif op_type == "Neg" %}
17+
let zero_scalar = {{scalar_type}}(0);
18+
let zeroes = Vec4(zero_scalar, zero_scalar, zero_scalar, zero_scalar);
19+
output_0.data[gidx] = zeroes - (input_0.data[gidx]);
1620
{% elif op_type == "Tanh" %}
1721
{# Tanh will produce NaNs when fed with inputs that are much larger than +10.0 or smaller than -10.0. As the output
1822
for these inputs converges to 1.0 and -1.0 respectively, we clamp the inputs first. #}

0 commit comments

Comments
 (0)