diff --git a/wonnx/src/compiler.rs b/wonnx/src/compiler.rs index bed2b865..56f2b09d 100644 --- a/wonnx/src/compiler.rs +++ b/wonnx/src/compiler.rs @@ -299,7 +299,7 @@ pub fn compile( op @ ("ReduceMean" | "ReduceSum" | "ReduceMax" | "ReduceMin" | "ReduceProd" | "ReduceL1" | "ReduceL2" | "ReduceLogSum" | "ReduceLogSumExp" - | "ReduceSumSquare") => { + | "ReduceSumSquare" | "ArgMax") => { let all_axes: Vec = (0..(i_dims[0].len() as i64)).collect(); let axes: Vec = node .get_attribute_value("axes", Some(all_axes))? diff --git a/wonnx/templates/pool/reduce.wgsl b/wonnx/templates/pool/reduce.wgsl index d4edfa1e..93403763 100644 --- a/wonnx/templates/pool/reduce.wgsl +++ b/wonnx/templates/pool/reduce.wgsl @@ -40,7 +40,9 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { Now for each reduced axis, iterate all values and reduce. Note, starting value may not always be zero. For ReduceMin/Max we should initialize as NaN and keep a flag to check if we have seen at least one element -#} - var accumulator = {% if op_type == "ReduceProd" %} {{ scalar_type }}(1) {% else %} Scalar() {% endif %}; + var accumulator = {% if op_type == "ReduceProd" %} {{ scalar_type }}(1) {% else %} Scalar() {% endif %}; + var max_element: Scalar = log(Scalar()); + var count = 0u; {% for reducing_axis in axes %} @@ -79,6 +81,11 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { else if(accumulator < input_val) { accumulator = input_val; } + {% elif op_type == "ArgMax" %} + if(input_val > max_element) { + max_element = input_val; + accumulator = f32(count); + } {% endif %} count = count + 1u; diff --git a/wonnx/tests/reduce.rs b/wonnx/tests/reduce.rs index e43308c2..b26bb6d1 100644 --- a/wonnx/tests/reduce.rs +++ b/wonnx/tests/reduce.rs @@ -62,6 +62,18 @@ fn reduce() { 60.0, 2.0, ]; + #[rustfmt::skip] + let data_two = [ + 20.0, 1.0, + 5.0, 2.0, + + 30.0, 1.0, + 40.0, 2.0, + + 60.0, 1.0, + 55.0, 2.0, + ]; + // ReduceSum: sum all test_reduce( &data, @@ -223,6 +235,17 @@ fn reduce() { &[3, 2], ); + // ONNX test case: do_not_keepdims with ArgMax + test_reduce( + &data_two, + &[3, 2, 2], + Some(vec![1]), + "ArgMax", + false, + &[0., 1., 1., 1., 0., 1.], + &[3, 2], + ); + // ONNX test case for ReduceSumSquare (https://github.com/onnx/onnx/blob/94e2f64551ded652df53a7e9111031e8aabddaee/onnx/backend/test/case/node/reducesumsquare.py#L27) test_reduce( &[1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.],