From 2943648054afc984be5a490a65baaa3cf1df8c2f Mon Sep 17 00:00:00 2001 From: Julien Date: Tue, 18 Jul 2023 16:25:38 +0200 Subject: [PATCH 1/5] Add argmax --- wonnx/src/compiler.rs | 2 +- wonnx/templates/pool/reduce.wgsl | 9 ++++++++- wonnx/tests/reduce.rs | 11 +++++++++++ 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/wonnx/src/compiler.rs b/wonnx/src/compiler.rs index bed2b865..56f2b09d 100644 --- a/wonnx/src/compiler.rs +++ b/wonnx/src/compiler.rs @@ -299,7 +299,7 @@ pub fn compile( op @ ("ReduceMean" | "ReduceSum" | "ReduceMax" | "ReduceMin" | "ReduceProd" | "ReduceL1" | "ReduceL2" | "ReduceLogSum" | "ReduceLogSumExp" - | "ReduceSumSquare") => { + | "ReduceSumSquare" | "ArgMax") => { let all_axes: Vec = (0..(i_dims[0].len() as i64)).collect(); let axes: Vec = node .get_attribute_value("axes", Some(all_axes))? diff --git a/wonnx/templates/pool/reduce.wgsl b/wonnx/templates/pool/reduce.wgsl index d4edfa1e..93403763 100644 --- a/wonnx/templates/pool/reduce.wgsl +++ b/wonnx/templates/pool/reduce.wgsl @@ -40,7 +40,9 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { Now for each reduced axis, iterate all values and reduce. Note, starting value may not always be zero. For ReduceMin/Max we should initialize as NaN and keep a flag to check if we have seen at least one element -#} - var accumulator = {% if op_type == "ReduceProd" %} {{ scalar_type }}(1) {% else %} Scalar() {% endif %}; + var accumulator = {% if op_type == "ReduceProd" %} {{ scalar_type }}(1) {% else %} Scalar() {% endif %}; + var max_element: Scalar = log(Scalar()); + var count = 0u; {% for reducing_axis in axes %} @@ -79,6 +81,11 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { else if(accumulator < input_val) { accumulator = input_val; } + {% elif op_type == "ArgMax" %} + if(input_val > max_element) { + max_element = input_val; + accumulator = f32(count); + } {% endif %} count = count + 1u; diff --git a/wonnx/tests/reduce.rs b/wonnx/tests/reduce.rs index e43308c2..9342ecb5 100644 --- a/wonnx/tests/reduce.rs +++ b/wonnx/tests/reduce.rs @@ -223,6 +223,17 @@ fn reduce() { &[3, 2], ); + // ONNX test case: do_not_keepdims with ArgMax + test_reduce( + &data, + &[3, 2, 2], + Some(vec![1]), + "ArgMax", + false, + &[1., 1., 1., 1., 1., 1.], + &[3, 2], + ); + // ONNX test case for ReduceSumSquare (https://github.com/onnx/onnx/blob/94e2f64551ded652df53a7e9111031e8aabddaee/onnx/backend/test/case/node/reducesumsquare.py#L27) test_reduce( &[1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.], From e69f75f859fefd32142b30f532c06c0627fcd278 Mon Sep 17 00:00:00 2001 From: riccardo Date: Tue, 18 Jul 2023 15:27:17 +0000 Subject: [PATCH 2/5] 18 Jul 2023, 17:27 --- wonnx/tests/reduce.rs | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/wonnx/tests/reduce.rs b/wonnx/tests/reduce.rs index 9342ecb5..59e399db 100644 --- a/wonnx/tests/reduce.rs +++ b/wonnx/tests/reduce.rs @@ -62,6 +62,17 @@ fn reduce() { 60.0, 2.0, ]; + let data2 = [ + 5.0, 1.0, + 2.0, 20.0 + + 30.0, 1.0, + 40.0, 2.0, + + 1.0, 55.0, + 60.0, 2.0, + ]; + // ReduceSum: sum all test_reduce( &data, @@ -230,7 +241,7 @@ fn reduce() { Some(vec![1]), "ArgMax", false, - &[1., 1., 1., 1., 1., 1.], + &[1., 2., 1., 1., 2., 1.], &[3, 2], ); From 64e984e3b180dbc96c97851fe54a3fcb102cd6af Mon Sep 17 00:00:00 2001 From: riccardo Date: Tue, 18 Jul 2023 15:28:22 +0000 Subject: [PATCH 3/5] 18 Jul 2023, 17:28 --- wonnx/tests/reduce.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wonnx/tests/reduce.rs b/wonnx/tests/reduce.rs index 59e399db..34760e59 100644 --- a/wonnx/tests/reduce.rs +++ b/wonnx/tests/reduce.rs @@ -241,7 +241,7 @@ fn reduce() { Some(vec![1]), "ArgMax", false, - &[1., 2., 1., 1., 2., 1.], + &[1., 1., 1., 1., 1., 1.], &[3, 2], ); From 330aa258246853ac7e6211f722c82e385682cbe1 Mon Sep 17 00:00:00 2001 From: riccardo Date: Tue, 18 Jul 2023 15:32:38 +0000 Subject: [PATCH 4/5] 18 Jul 2023, 17:32 --- wonnx/tests/reduce.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/wonnx/tests/reduce.rs b/wonnx/tests/reduce.rs index 34760e59..5a8de22f 100644 --- a/wonnx/tests/reduce.rs +++ b/wonnx/tests/reduce.rs @@ -64,7 +64,7 @@ fn reduce() { let data2 = [ 5.0, 1.0, - 2.0, 20.0 + 2.0, 20.0, 30.0, 1.0, 40.0, 2.0, @@ -236,12 +236,12 @@ fn reduce() { // ONNX test case: do_not_keepdims with ArgMax test_reduce( - &data, + &data2, &[3, 2, 2], Some(vec![1]), "ArgMax", false, - &[1., 1., 1., 1., 1., 1.], + &[1., 2., 1., 1., 2., 1.], &[3, 2], ); From fe12107ed9ba0224b9e809f500c2d0bbac7a8094 Mon Sep 17 00:00:00 2001 From: riccardo Date: Tue, 18 Jul 2023 17:46:27 +0200 Subject: [PATCH 5/5] more complicated test for argmax --- wonnx/tests/reduce.rs | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/wonnx/tests/reduce.rs b/wonnx/tests/reduce.rs index 5a8de22f..b26bb6d1 100644 --- a/wonnx/tests/reduce.rs +++ b/wonnx/tests/reduce.rs @@ -62,15 +62,16 @@ fn reduce() { 60.0, 2.0, ]; - let data2 = [ - 5.0, 1.0, - 2.0, 20.0, + #[rustfmt::skip] + let data_two = [ + 20.0, 1.0, + 5.0, 2.0, 30.0, 1.0, 40.0, 2.0, - 1.0, 55.0, - 60.0, 2.0, + 60.0, 1.0, + 55.0, 2.0, ]; // ReduceSum: sum all @@ -236,12 +237,12 @@ fn reduce() { // ONNX test case: do_not_keepdims with ArgMax test_reduce( - &data2, + &data_two, &[3, 2, 2], Some(vec![1]), "ArgMax", false, - &[1., 2., 1., 1., 2., 1.], + &[0., 1., 1., 1., 0., 1.], &[3, 2], );