Skip to content
This repository was archived by the owner on May 7, 2025. It is now read-only.

Commit 2b72b51

Browse files
authored
Merge pull request #185 from ariaghora/feature/hardsigmoid
Integration of HardSigmoid Operation
2 parents 0c56190 + 95e7b38 commit 2b72b51

File tree

6 files changed

+121
-15
lines changed

6 files changed

+121
-15
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ fn test_matmul_square_matrix() {
265265
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#GlobalMaxPool">GlobalMaxPool</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#GlobalMaxPool-1">1</a>|
266266
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#Greater">Greater</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Greater-13">13</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Greater-9">9</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Greater-7">7</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Greater-1">1</a>||
267267
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#GridSample">GridSample</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#GridSample-16">16</a>|
268-
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#HardSigmoid">HardSigmoid</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#HardSigmoid-6">6</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#HardSigmoid-1">1</a>|
268+
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#HardSigmoid">HardSigmoid</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#HardSigmoid-6">6</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#HardSigmoid-1">1</a>|||
269269
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#Hardmax">Hardmax</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Hardmax-13">13</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Hardmax-11">11</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Hardmax-1">1</a>|
270270
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#Identity">Identity</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-16">16</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-14">14</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-13">13</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-1">1</a>|||
271271
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#If">If</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#If-16">16</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#If-13">13</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#If-11">11</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#If-1">1</a>|

wonnx-py/tests/test_onnx_backend.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,6 @@ def do_enforce_test_coverage_safelist(model): # type: (ModelProto) -> bool
124124
backend_test = onnx.backend.test.BackendTest(DummyBackend, __name__)
125125

126126

127-
128127
backend_test.include(f"test_constant_cpu")
129128
backend_test.include(f"test_conv_[a-z,_]*")
130129
backend_test.include(f"test_Conv2d[a-z,_]*")
@@ -147,6 +146,9 @@ def do_enforce_test_coverage_safelist(model): # type: (ModelProto) -> bool
147146
backend_test.include(f"test_size_[a-z,_]*")
148147
backend_test.include(f"test_celu_[a-z,_]*")
149148

149+
# Disabled until CastLike is implemented
150+
# backend_test.include(f"test_hardsigmoid_[a-z,_]*")
151+
150152
# For these we only test the default version, as we don't support the bool type
151153
backend_test.include(f"test_prelu_broadcast_cpu$")
152154
backend_test.include(f"test_elu_cpu$")
@@ -162,15 +164,15 @@ def do_enforce_test_coverage_safelist(model): # type: (ModelProto) -> bool
162164
# Disable tests for ReduceSum because ReduceSum accepts the 'axes' list as input instead of as an attribute, and the test
163165
# case sets the 'axes' input dynamically, which we don't support (yet?).
164166
# backend_test.include(f"test_reduce_sum_[a-z,_]*")
165-
#backend_test.include(f"test_reduce_mean_[a-z,_]*")
166-
#backend_test.include(f"test_reduce_l1_[a-z,_]*")
167-
#backend_test.include(f"test_reduce_l2_[a-z,_]*")
168-
#backend_test.include(f"test_reduce_min_[a-z,_]*")
169-
#backend_test.include(f"test_reduce_prod_[a-z,_]*")
170-
#backend_test.include(f"test_reduce_sum_square_[a-z,_]*")
171-
#backend_test.include(f"test_reduce_max_[a-z,_]*")
172-
#backend_test.include(f"test_reduce_log_sum_[a-z,_]*")
173-
#backend_test.include(f"test_reduce_log_sum_exp_[a-z,_]*")
167+
# backend_test.include(f"test_reduce_mean_[a-z,_]*")
168+
# backend_test.include(f"test_reduce_l1_[a-z,_]*")
169+
# backend_test.include(f"test_reduce_l2_[a-z,_]*")
170+
# backend_test.include(f"test_reduce_min_[a-z,_]*")
171+
# backend_test.include(f"test_reduce_prod_[a-z,_]*")
172+
# backend_test.include(f"test_reduce_sum_square_[a-z,_]*")
173+
# backend_test.include(f"test_reduce_max_[a-z,_]*")
174+
# backend_test.include(f"test_reduce_log_sum_[a-z,_]*")
175+
# backend_test.include(f"test_reduce_log_sum_exp_[a-z,_]*")
174176

175177
# Takes dynamic input, we don't support that yet
176178
# backend_test.include(f"test_constantofshape_[a-z,_]*")

wonnx/src/compiler.rs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -743,13 +743,21 @@ pub fn compile(
743743
}
744744
}
745745
op @ ("Relu" | "Sigmoid" | "Softsign" | "Softplus" | "Clip" | "Celu" | "Elu"
746-
| "LeakyRelu") => {
747-
let alpha = if op == "LeakyRelu" {
748-
node.get_attribute_value("alpha", Some(0.01))?
746+
| "LeakyRelu" | "HardSigmoid") => {
747+
let alpha = match op {
748+
"LeakyRelu" => node.get_attribute_value("alpha", Some(0.01))?,
749+
"HardSigmoid" => node.get_attribute_value("alpha", Some(0.2))?,
750+
_ => node.get_attribute_value("alpha", Some(1.0))?,
751+
};
752+
753+
let beta = if op == "HardSigmoid" {
754+
node.get_attribute_value("beta", Some(0.5))?
749755
} else {
750-
node.get_attribute_value("alpha", Some(1.0))?
756+
node.get_attribute_value("beta", Some(1.0))?
751757
};
758+
752759
context.insert("alpha", &alpha);
760+
context.insert("beta", &beta);
753761

754762
if op == "Clip" {
755763
let min: Vec<f32> =

wonnx/templates/snippets/activation_scalar.wgsl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,15 @@
3838
{{ scalar_type }}({{ alpha }}) * (exp(input_vec) - {{ scalar_type }}(1))
3939
);
4040

41+
{%- elif activation_type == "HardSigmoid" -%}
42+
{{ activation_output }} = max(
43+
{{ scalar_type }}(0),
44+
min(
45+
{{ scalar_type }}(1),
46+
{{ scalar_type }}({{ alpha }}) * {{ activation_input }} + {{ scalar_type }}({{ beta }})
47+
)
48+
);
49+
4150
{%- elif activation_output != activation_input -%}
4251
{{ activation_output }} = {{ activation_input }};
4352

wonnx/templates/snippets/activation_vec.wgsl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,15 @@
4646
{{ activation_output }} = max({{ activation_input }}, Vec4(Scalar(), Scalar(), Scalar(), Scalar()))
4747
+ min({{ scalar_type }}({{ alpha }}) * {{ activation_input }}, Vec4(Scalar(), Scalar(), Scalar(), Scalar()));
4848

49+
{%- elif activation_type == "HardSigmoid" -%}
50+
{{ activation_output }} = max(
51+
Vec4(Scalar(), Scalar(), Scalar(), Scalar()),
52+
min(
53+
Vec4({{ scalar_type }}(1), {{ scalar_type }}(1), {{ scalar_type }}(1), {{ scalar_type }}(1)),
54+
{{ scalar_type }}({{ alpha }}) * {{ activation_input }} + {{ scalar_type }}({{ beta }})
55+
)
56+
);
57+
4958
{%- elif activation_output != activation_input -%}
5059
{{ activation_output }} = {{ activation_input }};
5160

wonnx/tests/hardsigmoid.rs

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
use std::{collections::HashMap, convert::TryInto};
2+
use wonnx::utils::{attribute, graph, model, node, tensor};
3+
mod common;
4+
5+
/// Test HardSigmoid node with default alpha and beta
6+
/// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-68
7+
#[test]
8+
fn test_hardsigmoid_default() {
9+
let input_data = [-2.0, -1.0, 1.0, 2.0];
10+
let shape = vec![2, 2];
11+
12+
let (default_alpha, default_beta) = (0.2, 0.5);
13+
let expected_output_data: Vec<f32> = input_data
14+
.iter()
15+
.map(|x| x * default_alpha + default_beta)
16+
.collect();
17+
18+
let mut model_input = HashMap::new();
19+
model_input.insert("X".to_string(), input_data.as_slice().into());
20+
21+
let node = node(vec!["X"], vec!["Y"], "hard_sigmoid", "HardSigmoid", vec![]);
22+
23+
let model = model(graph(
24+
vec![tensor("X", &shape)],
25+
vec![tensor("Y", &shape)],
26+
vec![],
27+
vec![],
28+
vec![node],
29+
));
30+
31+
let session =
32+
pollster::block_on(wonnx::Session::from_model(model)).expect("Session did not create");
33+
34+
let output = pollster::block_on(session.run(&model_input)).unwrap();
35+
let output_data: &[f32] = (&output["Y"]).try_into().unwrap();
36+
37+
common::assert_eq_vector(output_data, expected_output_data.as_slice());
38+
}
39+
40+
/// Test HardSigmoid node with predefined alpha and beta
41+
/// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-68
42+
#[test]
43+
fn test_hardsigmoid() {
44+
let input_data: Vec<f32> = vec![-1.0, 0.0, 1.0];
45+
let shape = vec![1, 3];
46+
47+
let mut model_input = HashMap::new();
48+
model_input.insert("X".to_string(), input_data.as_slice().into());
49+
50+
let alpha = attribute("alpha", 0.5);
51+
let beta = attribute("beta", 0.6);
52+
53+
let node = node(
54+
vec!["X"],
55+
vec!["Y"],
56+
"hard_sigmoid",
57+
"HardSigmoid",
58+
vec![alpha, beta],
59+
);
60+
61+
let model = model(graph(
62+
vec![tensor("X", &shape)],
63+
vec![tensor("Y", &shape)],
64+
vec![],
65+
vec![],
66+
vec![node],
67+
));
68+
69+
let session =
70+
pollster::block_on(wonnx::Session::from_model(model)).expect("Session did not create");
71+
72+
let output = pollster::block_on(session.run(&model_input)).unwrap();
73+
println!("{:?}", output);
74+
75+
let expected_output = &[0.1, 0.6, 1.0];
76+
let output_data: &[f32] = (&output["Y"]).try_into().unwrap();
77+
common::assert_eq_vector(output_data, expected_output);
78+
}

0 commit comments

Comments
 (0)