From 5caf15b83611118368cce7a451314a2f8d13e623 Mon Sep 17 00:00:00 2001 From: Johannes May Date: Wed, 2 Aug 2023 22:37:41 +0200 Subject: [PATCH 1/2] Implement the ConvTranspose operation --- README.md | 2 +- wonnx/src/compiler.rs | 144 ++++++++++++++++++++++ wonnx/templates/unpool/convtranspose.wgsl | 88 +++++++++++++ wonnx/tests/convtranspose.rs | 133 ++++++++++++++++++++ 4 files changed, 366 insertions(+), 1 deletion(-) create mode 100644 wonnx/templates/unpool/convtranspose.wgsl create mode 100644 wonnx/tests/convtranspose.rs diff --git a/README.md b/README.md index 5c9b430d..2330ea99 100644 --- a/README.md +++ b/README.md @@ -237,7 +237,7 @@ fn test_matmul_square_matrix() { |ConstantOfShape|9|✅|✅| |Conv|11, 1|✅| |ConvInteger|10| -|ConvTranspose|11, 1| +|ConvTranspose|11, 1|✅| |Cos|7|✅|✅| |Cosh|9|✅|✅| |CumSum|14, 11| diff --git a/wonnx/src/compiler.rs b/wonnx/src/compiler.rs index ba757c20..c6870b4b 100644 --- a/wonnx/src/compiler.rs +++ b/wonnx/src/compiler.rs @@ -135,6 +135,11 @@ lazy_static! 
{ include_str!("../templates/endomorphism/broadcast.wgsl"), ) .unwrap(); + tera.add_raw_template( + "unpool/convtranspose.wgsl", + include_str!("../templates/unpool/convtranspose.wgsl"), + ) + .unwrap(); tera }; } @@ -197,6 +202,12 @@ pub enum CompileError { input_shapes: Vec, output_shape: Shape, }, + + #[error("output shape mismatch")] + UnexpectedOutputShape { + output_shape: Shape, + expected_shape: Vec, + }, } struct NodeTemplate { @@ -1384,6 +1395,139 @@ pub fn compile( threads: (ceil(output_lengths[0], 256) as _, 1, 1), } } + "ConvTranspose" => { + // FIXME: Add support for dilations, group, output_padding, pads, auto_pad, + // 1d and 3d inputs + + // Only support 2D input at the moment + if input_shapes[0].rank() != 4 { + return Err(CompileError::InvalidInputShape { + input_index: 0, + input_shape: input_shapes[0].clone(), + }); + } + + let input_height = input_shapes[0].dims[2] as i64; + let input_width = input_shapes[0].dims[3] as i64; + + let dilations = node.get_attribute_value("dilations", Some(vec![1, 1]))?; + if dilations != vec![1, 1] { + return Err(CompileError::InvalidAttributeValue { + attribute: "dilations".into(), + value: format!("{:?}", dilations), + opset_version, + }); + } + + let group = node.get_attribute_value("group", Some(1))?; + if group != 1 { + return Err(CompileError::InvalidAttributeValue { + attribute: "group".into(), + value: group.to_string(), + opset_version, + }); + } + + let inferred_kernel_shape = input_shapes[1] + .dims + .iter() + .skip(2) + .map(|&x| x as i64) + .collect::>(); + + let kernel_shape = + node.get_attribute_value("kernel_shape", Some(inferred_kernel_shape.clone()))?; + if inferred_kernel_shape != kernel_shape { + log::error!("Inferred kernel shape: {:?}", inferred_kernel_shape); + return Err(CompileError::InvalidAttributeValue { + attribute: "kernel_shape".to_string(), + value: format!("{:?}", kernel_shape), + opset_version, + }); + } + + let output_padding = node.get_attribute_value("output_padding", 
Some(vec![0, 0]))?; + if output_padding != vec![0, 0] { + return Err(CompileError::InvalidAttributeValue { + attribute: "output_padding".into(), + value: format!("{:?}", output_padding), + opset_version, + }); + } + + let auto_pad = node.get_attribute_value("auto_pad", Some("NOTSET".to_string()))?; + if auto_pad != "NOTSET" { + return Err(CompileError::InvalidAttributeValue { + attribute: "auto_pad".into(), + value: auto_pad, + opset_version, + }); + } + + let pads = node.get_attribute_value("pads", Some(vec![0, 0, 0, 0]))?; + if pads.iter().any(|pad| *pad < 0) { + return Err(CompileError::InvalidAttributeValue { + attribute: "pads".into(), + value: format!("{:?}", pads), + opset_version, + }); + } + + context.insert("pads", &pads); + + let strides = node.get_attribute_value("strides", Some(vec![1, 1]))?; + if strides.iter().any(|stride| *stride <= 0) { + return Err(CompileError::InvalidAttributeValue { + attribute: "strides".into(), + value: format!("{:?}", strides), + opset_version, + }); + } + + context.insert("stride", &strides); + + let output_height = strides[0] * (input_height - 1) + + output_padding[0] + + ((kernel_shape[0] - 1) * dilations[0] + 1) + - pads[0] + - pads[2]; + let output_width = strides[1] * (input_width - 1) + + output_padding[1] + + ((kernel_shape[1] - 1) * dilations[1] + 1) + - pads[1] + - pads[3]; + + if output_shapes[0].dim(2) as i64 != output_height + || output_shapes[0].dim(3) as i64 != output_width + { + return Err(CompileError::UnexpectedOutputShape { + output_shape: output_shapes[0].clone(), + expected_shape: vec![output_height, output_width], + }); + } + + let (x_threads, workgroup_size_x) = workgroup_size( + output_lengths[0], + MAX_COMPUTE_WORKGROUPS_PER_DIMENSION, + MAX_WORKGROUP_SIZE_X, + )?; + context.insert("workgroup_size_x", &workgroup_size_x); + + let scalar_type = agreed_type(input_shapes, output_shapes)?; + + if scalar_type.is_float() { + NodeTemplate { + scalar_type, + template: "unpool/convtranspose.wgsl", + threads: 
(x_threads, 1, 1), + } + } else { + return Err(CompileError::UnimplementedVariant { + variant: "Non-Float".into(), + op: "ConvTranspose".into(), + }); + } + } op => return Err(CompileError::UnimplementedOp(op.to_string())), }; diff --git a/wonnx/templates/unpool/convtranspose.wgsl b/wonnx/templates/unpool/convtranspose.wgsl new file mode 100644 index 00000000..13d15f4b --- /dev/null +++ b/wonnx/templates/unpool/convtranspose.wgsl @@ -0,0 +1,88 @@ +{%- include "structs.wgsl" -%} + +// Input tensor, shape NxCxHxW +@group(0) @binding(0) +var input_tensor: Array; + +// Kernel weight tensor, shape CxM/groupxkHxkW +@group(0) @binding(1) +var input_kernel_weights: Array; + +{% if i_lens | length == 3 -%} + @group(0) @binding(2) + var input_bias: Array; + + @group(0) @binding(3) + var output_0: Array; +{%- else -%} + @group(0) @binding(2) + var output_0: Array; +{%- endif %} + +{% set input_shape = i_shape[0] %} +{% set input_chunks = i_chunks[0] %} +{% set kernel_shape = i_shape[1] %} +{% set kernel_chunks = i_chunks[1] %} + +@compute @workgroup_size({{ workgroup_size_x }}, 1, 1) +fn main(@builtin(global_invocation_id) global_id: vec3) { + let output_idx = global_id.x; + + if (output_idx < {{ o_lens[0] }}u) { + // Calculate the output coordinates we are responsible for + let batch = output_idx / {{ o_chunks[0][0] }}u; + var rest = output_idx % {{ o_chunks[0][0] }}u; + + let channel = rest / {{ o_chunks[0][1] }}u; + rest = rest % {{ o_chunks[0][1] }}u; + + let y = rest / {{ o_chunks[0][2] }}u; + let x = rest % {{ o_chunks[0][2] }}u; + + // The actual output is a slice of the full output, + // where the given padding values are removed on each end. + // We don't need to worry about this at the upper coordinate end, + // but we need to consider it on the lower end and calculate + // virtual output coordinates to calculate the input coordinate range later. 
+ let unpadded_y = y + {{ pads[0] }}u; + let unpadded_x = x + {{ pads[1] }}u; + + let sample_root_index = batch * {{ input_chunks[0] }}u; + + // Calculate the input coordinate range for our output coordinate + let min_in_y = select(0u, (unpadded_y - {{ kernel_shape[2] }}u) / {{ stride[0] }}u, unpadded_y > {{ kernel_shape[2] }}u); + let max_in_y = select({{ input_shape[2] }}u - 1u, unpadded_y / {{ stride[0] }}u, unpadded_y / {{ stride[0] }}u < {{ input_shape[3] }}u); + let min_in_x = select(0u, (unpadded_x - {{ kernel_shape[3] }}u) / {{ stride[1] }}u, unpadded_x > {{ kernel_shape[3] }}u); + let max_in_x = select({{ input_shape[3] }}u - 1u, unpadded_x / {{ stride[1] }}u, unpadded_x / {{ stride[1] }}u < {{ input_shape[3] }}u); + + var result: Scalar = Scalar(); + + // Now, go over each input channel and apply the corresponding kernel for that channel + // to calculate the output piece by piece. + for(var ichannel: u32 = 0u; ichannel < {{ input_shape[1] }}u; ichannel = ichannel + 1u) { + // Base index for the 2D data in the input data + let base_index = sample_root_index + ichannel * {{ input_chunks[1] }}u; + // Get the starting position of the kernel for the given input and output channel + let base_kernel_index = ichannel *{{ kernel_chunks[0] }}u + channel * {{ kernel_chunks[1] }}u; + + // Iterate over all potential input values + for(var in_y: u32 = min_in_y; in_y <= max_in_y; in_y = in_y + 1u) { + for(var in_x: u32 = min_in_x; in_x <= max_in_x; in_x = in_x + 1u) { + let kernel_y = unpadded_y - (in_y * {{ stride[0] }}u); + let kernel_x = unpadded_x - (in_x * {{ stride[1] }}u); + + if(kernel_y < {{ kernel_shape[2] }}u && kernel_x < {{ kernel_shape[3] }}u) { + result = result + (input_tensor.data[base_index + (in_y * {{ input_chunks[2] }}u) + in_x] + * input_kernel_weights.data[base_kernel_index + kernel_y * {{ kernel_chunks[2] }}u + kernel_x]); + } + } + } + } + {% if i_lens | length == 3 -%} + // Apply Bias if specified + result = result + input_bias.data[channel]; + 
{%- endif %} + + output_0.data[output_idx] = result; + } +} diff --git a/wonnx/tests/convtranspose.rs b/wonnx/tests/convtranspose.rs new file mode 100644 index 00000000..eb9bcc77 --- /dev/null +++ b/wonnx/tests/convtranspose.rs @@ -0,0 +1,133 @@ +use std::collections::HashMap; +use wonnx::utils::{attribute, graph, initializer, model, node, tensor, OutputTensor}; +mod common; + +#[test] +fn convtranspose_default() { + let data: Vec = vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; + let input_shape = vec![1, 1, 3, 3]; + + let data_w = vec![ + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + ]; + let kernel_shape = vec![1, 2, 3, 3]; + + let output_shape = vec![1, 2, 5, 5]; + + let input_data = HashMap::from([("X".to_string(), data.as_slice().into())]); + + let convtranpose_model = model(graph( + vec![tensor("X", &input_shape)], + vec![tensor("Y", &output_shape)], + vec![], + vec![initializer("W", data_w, kernel_shape)], + vec![node( + vec!["X", "W"], + vec!["Y"], + "convtranspose", + "ConvTranspose", + vec![attribute("kernel_shape", vec![3, 3])], + )], + )); + + let session = pollster::block_on(wonnx::Session::from_model(convtranpose_model)) + .expect("Session did not create"); + let result = pollster::block_on(session.run(&input_data)).unwrap(); + + assert_eq!( + result["Y"], + OutputTensor::F32(vec![ + 0.0, 1.0, 3.0, 3.0, 2.0, 3.0, 8.0, 15.0, 12.0, 7.0, 9.0, 21.0, 36.0, 27.0, 15.0, 9.0, + 20.0, 33.0, 24.0, 13.0, 6.0, 13.0, 21.0, 15.0, 8.0, 0.0, 1.0, 3.0, 3.0, 2.0, 3.0, 8.0, + 15.0, 12.0, 7.0, 9.0, 21.0, 36.0, 27.0, 15.0, 9.0, 20.0, 33.0, 24.0, 13.0, 6.0, 13.0, + 21.0, 15.0, 8.0, + ]) + ); +} + +#[test] +fn convtranspose_strides() { + let data = vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; // (1, 1, 3, 3) + let input_shape = vec![1, 1, 3, 3]; + + let data_w = vec![ + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + ]; + let kernel_shape = vec![1, 2, 3, 3]; + + let 
output_data = vec![ + 0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 0.0, 0.0, 1.0, 1.0, + 3.0, 2.0, 2.0, 3.0, 3.0, 7.0, 4.0, 9.0, 5.0, 5.0, 3.0, 3.0, 7.0, 4.0, 9.0, 5.0, 5.0, 3.0, + 3.0, 7.0, 4.0, 9.0, 5.0, 5.0, 6.0, 6.0, 13.0, 7.0, 15.0, 8.0, 8.0, 6.0, 6.0, 13.0, 7.0, + 15.0, 8.0, 8.0, 6.0, 6.0, 13.0, 7.0, 15.0, 8.0, 8.0, 0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, + 0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 3.0, 3.0, 7.0, 4.0, + 9.0, 5.0, 5.0, 3.0, 3.0, 7.0, 4.0, 9.0, 5.0, 5.0, 3.0, 3.0, 7.0, 4.0, 9.0, 5.0, 5.0, 6.0, + 6.0, 13.0, 7.0, 15.0, 8.0, 8.0, 6.0, 6.0, 13.0, 7.0, 15.0, 8.0, 8.0, 6.0, 6.0, 13.0, 7.0, + 15.0, 8.0, 8.0, + ]; + let output_shape = vec![1, 2, 9, 7]; + + let convtranpose_model = model(graph( + vec![tensor("X", &input_shape)], + vec![tensor("Y", &output_shape)], + vec![], + vec![initializer("W", data_w, kernel_shape)], + vec![node( + vec!["X", "W"], + vec!["Y"], + "convtranspose", + "ConvTranspose", + vec![ + attribute("kernel_shape", vec![3, 3]), + attribute("strides", vec![3, 2]), + ], + )], + )); + + let input_data = HashMap::from([("X".to_string(), data.as_slice().into())]); + let session = pollster::block_on(wonnx::Session::from_model(convtranpose_model)) + .expect("Session did not create"); + let result = pollster::block_on(session.run(&input_data)).unwrap(); + assert_eq!(result["Y"], OutputTensor::F32(output_data)); +} + +#[test] +fn convtranspose_pads() { + let data = vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; + let input_shape = vec![1, 1, 3, 3]; + + let data_w = vec![ + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + ]; + let kernel_shape = vec![1, 2, 3, 3]; + + let output_data = vec![ + 1.0, 1.0, 3.0, 1.0, 1.0, 3.0, 7.0, 4.0, 9.0, 7.0, 4.0, 9.0, 7.0, 4.0, 9.0, 13.0, 7.0, 15.0, + 13.0, 7.0, 15.0, 1.0, 1.0, 3.0, 1.0, 1.0, 3.0, 7.0, 4.0, 9.0, 7.0, 4.0, 9.0, 7.0, 4.0, 9.0, + 13.0, 7.0, 15.0, 13.0, 7.0, 15.0, + ]; + let output_shape = vec![1, 2, 
7, 3]; + + let convtranpose_model = model(graph( + vec![tensor("X", &input_shape)], + vec![tensor("Y", &output_shape)], + vec![], + vec![initializer("W", data_w, kernel_shape)], + vec![node( + vec!["X", "W"], + vec!["Y"], + "convtranspose", + "ConvTranspose", + vec![ + attribute("strides", vec![3, 2]), + attribute("pads", vec![1, 2, 1, 2]), + ], + )], + )); + + let input_data = HashMap::from([("X".to_string(), data.as_slice().into())]); + let session = pollster::block_on(wonnx::Session::from_model(convtranpose_model)) + .expect("Session did not create"); + let result = pollster::block_on(session.run(&input_data)).unwrap(); + assert_eq!(result["Y"], OutputTensor::F32(output_data)); +} From cc0461624fd64168cc885df3a4751642eb245d91 Mon Sep 17 00:00:00 2001 From: Johannes May Date: Sat, 12 Aug 2023 21:15:39 +0200 Subject: [PATCH 2/2] Enable some ConvTranspose backend tests --- wonnx-py/tests/test_onnx_backend.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/wonnx-py/tests/test_onnx_backend.py b/wonnx-py/tests/test_onnx_backend.py index a0dd935a..832b49a6 100644 --- a/wonnx-py/tests/test_onnx_backend.py +++ b/wonnx-py/tests/test_onnx_backend.py @@ -197,6 +197,11 @@ def do_enforce_test_coverage_safelist(model): # type: (ModelProto) -> bool backend_test.include(f"test_softmax_negative_axis_cpu$") backend_test.include(f"test_softmax_default_axis_cpu$") +# ConvTranspose +# We only have partial attribute support right now, so we hand select a few test cases limited to the supported ones +backend_test.include(f"test_convtranspose$") +backend_test.include(f"test_convtranspose_pads$") + globals().update(backend_test.enable_report().test_cases) if __name__ == "__main__":