diff --git a/README.md b/README.md
index 5c9b430d..2330ea99 100644
--- a/README.md
+++ b/README.md
@@ -237,7 +237,7 @@ fn test_matmul_square_matrix() {
|ConstantOfShape|9|✅|✅|
|Conv|11, 1|✅|
|ConvInteger|10|
-|ConvTranspose|11, 1|
+|ConvTranspose|11, 1|✅|
|Cos|7|✅|✅|
|Cosh|9|✅|✅|
|CumSum|14, 11|
diff --git a/wonnx-py/tests/test_onnx_backend.py b/wonnx-py/tests/test_onnx_backend.py
index a0dd935a..832b49a6 100644
--- a/wonnx-py/tests/test_onnx_backend.py
+++ b/wonnx-py/tests/test_onnx_backend.py
@@ -197,6 +197,11 @@ def do_enforce_test_coverage_safelist(model): # type: (ModelProto) -> bool
backend_test.include(f"test_softmax_negative_axis_cpu$")
backend_test.include(f"test_softmax_default_axis_cpu$")
+# ConvTranspose
+# We only have partial attribute support right now, so we hand select a few test cases limited to the supported ones
+backend_test.include(f"test_convtranspose$")
+backend_test.include(f"test_convtranspose_pads$")
+
globals().update(backend_test.enable_report().test_cases)
if __name__ == "__main__":
diff --git a/wonnx/src/compiler.rs b/wonnx/src/compiler.rs
index ba757c20..c6870b4b 100644
--- a/wonnx/src/compiler.rs
+++ b/wonnx/src/compiler.rs
@@ -135,6 +135,11 @@ lazy_static! {
include_str!("../templates/endomorphism/broadcast.wgsl"),
)
.unwrap();
+ tera.add_raw_template(
+ "unpool/convtranspose.wgsl",
+ include_str!("../templates/unpool/convtranspose.wgsl"),
+ )
+ .unwrap();
tera
};
}
@@ -197,6 +202,12 @@ pub enum CompileError {
        input_shapes: Vec<Shape>,
output_shape: Shape,
},
+
+ #[error("output shape mismatch")]
+ UnexpectedOutputShape {
+ output_shape: Shape,
+        expected_shape: Vec<i64>,
+ },
}
struct NodeTemplate {
@@ -1384,6 +1395,139 @@ pub fn compile(
threads: (ceil(output_lengths[0], 256) as _, 1, 1),
}
}
+ "ConvTranspose" => {
+ // FIXME: Add support for dilations, group, output_padding, pads, auto_pad,
+ // 1d and 3d inputs
+
+ // Only support 2D input at the moment
+ if input_shapes[0].rank() != 4 {
+ return Err(CompileError::InvalidInputShape {
+ input_index: 0,
+ input_shape: input_shapes[0].clone(),
+ });
+ }
+
+ let input_height = input_shapes[0].dims[2] as i64;
+ let input_width = input_shapes[0].dims[3] as i64;
+
+ let dilations = node.get_attribute_value("dilations", Some(vec![1, 1]))?;
+ if dilations != vec![1, 1] {
+ return Err(CompileError::InvalidAttributeValue {
+ attribute: "dilations".into(),
+ value: format!("{:?}", dilations),
+ opset_version,
+ });
+ }
+
+ let group = node.get_attribute_value("group", Some(1))?;
+ if group != 1 {
+ return Err(CompileError::InvalidAttributeValue {
+ attribute: "group".into(),
+ value: group.to_string(),
+ opset_version,
+ });
+ }
+
+ let inferred_kernel_shape = input_shapes[1]
+ .dims
+ .iter()
+ .skip(2)
+ .map(|&x| x as i64)
+                    .collect::<Vec<i64>>();
+
+ let kernel_shape =
+ node.get_attribute_value("kernel_shape", Some(inferred_kernel_shape.clone()))?;
+ if inferred_kernel_shape != kernel_shape {
+ log::error!("Inferred kernel shape: {:?}", inferred_kernel_shape);
+ return Err(CompileError::InvalidAttributeValue {
+ attribute: "kernel_shape".to_string(),
+ value: format!("{:?}", kernel_shape),
+ opset_version,
+ });
+ }
+
+ let output_padding = node.get_attribute_value("output_padding", Some(vec![0, 0]))?;
+ if output_padding != vec![0, 0] {
+ return Err(CompileError::InvalidAttributeValue {
+ attribute: "output_padding".into(),
+ value: format!("{:?}", output_padding),
+ opset_version,
+ });
+ }
+
+ let auto_pad = node.get_attribute_value("auto_pad", Some("NOTSET".to_string()))?;
+ if auto_pad != "NOTSET" {
+ return Err(CompileError::InvalidAttributeValue {
+ attribute: "auto_pad".into(),
+ value: auto_pad,
+ opset_version,
+ });
+ }
+
+ let pads = node.get_attribute_value("pads", Some(vec![0, 0, 0, 0]))?;
+ if pads.iter().any(|pad| *pad < 0) {
+ return Err(CompileError::InvalidAttributeValue {
+ attribute: "pads".into(),
+ value: format!("{:?}", pads),
+ opset_version,
+ });
+ }
+
+ context.insert("pads", &pads);
+
+ let strides = node.get_attribute_value("strides", Some(vec![1, 1]))?;
+ if strides.iter().any(|stride| *stride <= 0) {
+ return Err(CompileError::InvalidAttributeValue {
+ attribute: "strides".into(),
+ value: format!("{:?}", strides),
+ opset_version,
+ });
+ }
+
+ context.insert("stride", &strides);
+
+ let output_height = strides[0] * (input_height - 1)
+ + output_padding[0]
+ + ((kernel_shape[0] - 1) * dilations[0] + 1)
+ - pads[0]
+ - pads[2];
+ let output_width = strides[1] * (input_width - 1)
+ + output_padding[1]
+ + ((kernel_shape[1] - 1) * dilations[1] + 1)
+ - pads[1]
+ - pads[3];
+
+ if output_shapes[0].dim(2) as i64 != output_height
+ || output_shapes[0].dim(3) as i64 != output_width
+ {
+ return Err(CompileError::UnexpectedOutputShape {
+ output_shape: output_shapes[0].clone(),
+ expected_shape: vec![output_height, output_width],
+ });
+ }
+
+ let (x_threads, workgroup_size_x) = workgroup_size(
+ output_lengths[0],
+ MAX_COMPUTE_WORKGROUPS_PER_DIMENSION,
+ MAX_WORKGROUP_SIZE_X,
+ )?;
+ context.insert("workgroup_size_x", &workgroup_size_x);
+
+ let scalar_type = agreed_type(input_shapes, output_shapes)?;
+
+ if scalar_type.is_float() {
+ NodeTemplate {
+ scalar_type,
+ template: "unpool/convtranspose.wgsl",
+ threads: (x_threads, 1, 1),
+ }
+ } else {
+ return Err(CompileError::UnimplementedVariant {
+ variant: "Non-Float".into(),
+ op: "ConvTranspose".into(),
+ });
+ }
+ }
op => return Err(CompileError::UnimplementedOp(op.to_string())),
};
diff --git a/wonnx/templates/unpool/convtranspose.wgsl b/wonnx/templates/unpool/convtranspose.wgsl
new file mode 100644
index 00000000..13d15f4b
--- /dev/null
+++ b/wonnx/templates/unpool/convtranspose.wgsl
@@ -0,0 +1,88 @@
+{%- include "structs.wgsl" -%}
+
+// Input tensor, shape NxCxHxW
+@group(0) @binding(0)
+var<storage, read> input_tensor: Array;
+
+// Kernel weight tensor, shape CxM/groupxkHxkW
+@group(0) @binding(1)
+var<storage, read> input_kernel_weights: Array;
+
+{% if i_lens | length == 3 -%}
+ @group(0) @binding(2)
+  var<storage, read> input_bias: Array;
+
+ @group(0) @binding(3)
+  var<storage, read_write> output_0: Array;
+{%- else -%}
+ @group(0) @binding(2)
+  var<storage, read_write> output_0: Array;
+{%- endif %}
+
+{% set input_shape = i_shape[0] %}
+{% set input_chunks = i_chunks[0] %}
+{% set kernel_shape = i_shape[1] %}
+{% set kernel_chunks = i_chunks[1] %}
+
+@compute @workgroup_size({{ workgroup_size_x }}, 1, 1)
+fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
+    let output_idx = global_id.x;
+
+    if (output_idx < {{ o_lens[0] }}u) {
+        // Calculate the output coordinates we are responsible for
+        let batch = output_idx / {{ o_chunks[0][0] }}u;
+        var rest = output_idx % {{ o_chunks[0][0] }}u;
+
+        let channel = rest / {{ o_chunks[0][1] }}u;
+        rest = rest % {{ o_chunks[0][1] }}u;
+
+        let y = rest / {{ o_chunks[0][2] }}u;
+        let x = rest % {{ o_chunks[0][2] }}u;
+
+        // The actual output is a slice of the full output,
+        // where the given padding values are removed on each end.
+        // We don't need to worry about this at the upper coordinate end,
+        // but we need to consider it on the lower end and calculate
+        // virtual output coordinates to calculate the input coordinate range later.
+        let unpadded_y = y + {{ pads[0] }}u;
+        let unpadded_x = x + {{ pads[1] }}u;
+
+        let sample_root_index = batch * {{ input_chunks[0] }}u;
+
+        // Calculate the input coordinate range for our output coordinate (clamped to the input extent)
+        let min_in_y = select(0u, (unpadded_y - {{ kernel_shape[2] }}u) / {{ stride[0] }}u, unpadded_y > {{ kernel_shape[2] }}u);
+        let max_in_y = select({{ input_shape[2] }}u - 1u, unpadded_y / {{ stride[0] }}u, unpadded_y / {{ stride[0] }}u < {{ input_shape[2] }}u);
+        let min_in_x = select(0u, (unpadded_x - {{ kernel_shape[3] }}u) / {{ stride[1] }}u, unpadded_x > {{ kernel_shape[3] }}u);
+        let max_in_x = select({{ input_shape[3] }}u - 1u, unpadded_x / {{ stride[1] }}u, unpadded_x / {{ stride[1] }}u < {{ input_shape[3] }}u);
+
+        var result: Scalar = Scalar();
+
+        // Now, go over each input channel and apply the corresponding kernel for that channel
+        // to calculate the output piece by piece.
+        for(var ichannel: u32 = 0u; ichannel < {{ input_shape[1] }}u; ichannel = ichannel + 1u) {
+            // Base index for the 2D data in the input data
+            let base_index = sample_root_index + ichannel * {{ input_chunks[1] }}u;
+            // Get the starting position of the kernel for the given input and output channel
+            let base_kernel_index = ichannel * {{ kernel_chunks[0] }}u + channel * {{ kernel_chunks[1] }}u;
+
+            // Iterate over all potential input values
+            for(var in_y: u32 = min_in_y; in_y <= max_in_y; in_y = in_y + 1u) {
+                for(var in_x: u32 = min_in_x; in_x <= max_in_x; in_x = in_x + 1u) {
+                    let kernel_y = unpadded_y - (in_y * {{ stride[0] }}u);
+                    let kernel_x = unpadded_x - (in_x * {{ stride[1] }}u);
+
+                    if(kernel_y < {{ kernel_shape[2] }}u && kernel_x < {{ kernel_shape[3] }}u) {
+                        result = result + (input_tensor.data[base_index + (in_y * {{ input_chunks[2] }}u) + in_x]
+                            * input_kernel_weights.data[base_kernel_index + kernel_y * {{ kernel_chunks[2] }}u + kernel_x]);
+                    }
+                }
+            }
+        }
+        {% if i_lens | length == 3 -%}
+        // Apply Bias if specified
+        result = result + input_bias.data[channel];
+        {%- endif %}
+
+        output_0.data[output_idx] = result;
+    }
+}
diff --git a/wonnx/tests/convtranspose.rs b/wonnx/tests/convtranspose.rs
new file mode 100644
index 00000000..eb9bcc77
--- /dev/null
+++ b/wonnx/tests/convtranspose.rs
@@ -0,0 +1,133 @@
+use std::collections::HashMap;
+use wonnx::utils::{attribute, graph, initializer, model, node, tensor, OutputTensor};
+mod common;
+
+#[test]
+fn convtranspose_default() {
+    let data: Vec<f32> = vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
+ let input_shape = vec![1, 1, 3, 3];
+
+ let data_w = vec![
+ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+ ];
+ let kernel_shape = vec![1, 2, 3, 3];
+
+ let output_shape = vec![1, 2, 5, 5];
+
+ let input_data = HashMap::from([("X".to_string(), data.as_slice().into())]);
+
+ let convtranpose_model = model(graph(
+ vec![tensor("X", &input_shape)],
+ vec![tensor("Y", &output_shape)],
+ vec![],
+ vec![initializer("W", data_w, kernel_shape)],
+ vec![node(
+ vec!["X", "W"],
+ vec!["Y"],
+ "convtranspose",
+ "ConvTranspose",
+ vec![attribute("kernel_shape", vec![3, 3])],
+ )],
+ ));
+
+ let session = pollster::block_on(wonnx::Session::from_model(convtranpose_model))
+ .expect("Session did not create");
+ let result = pollster::block_on(session.run(&input_data)).unwrap();
+
+ assert_eq!(
+ result["Y"],
+ OutputTensor::F32(vec![
+ 0.0, 1.0, 3.0, 3.0, 2.0, 3.0, 8.0, 15.0, 12.0, 7.0, 9.0, 21.0, 36.0, 27.0, 15.0, 9.0,
+ 20.0, 33.0, 24.0, 13.0, 6.0, 13.0, 21.0, 15.0, 8.0, 0.0, 1.0, 3.0, 3.0, 2.0, 3.0, 8.0,
+ 15.0, 12.0, 7.0, 9.0, 21.0, 36.0, 27.0, 15.0, 9.0, 20.0, 33.0, 24.0, 13.0, 6.0, 13.0,
+ 21.0, 15.0, 8.0,
+ ])
+ );
+}
+
+#[test]
+fn convtranspose_strides() {
+ let data = vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; // (1, 1, 3, 3)
+ let input_shape = vec![1, 1, 3, 3];
+
+ let data_w = vec![
+ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+ ];
+ let kernel_shape = vec![1, 2, 3, 3];
+
+ let output_data = vec![
+ 0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 0.0, 0.0, 1.0, 1.0,
+ 3.0, 2.0, 2.0, 3.0, 3.0, 7.0, 4.0, 9.0, 5.0, 5.0, 3.0, 3.0, 7.0, 4.0, 9.0, 5.0, 5.0, 3.0,
+ 3.0, 7.0, 4.0, 9.0, 5.0, 5.0, 6.0, 6.0, 13.0, 7.0, 15.0, 8.0, 8.0, 6.0, 6.0, 13.0, 7.0,
+ 15.0, 8.0, 8.0, 6.0, 6.0, 13.0, 7.0, 15.0, 8.0, 8.0, 0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0,
+ 0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 3.0, 3.0, 7.0, 4.0,
+ 9.0, 5.0, 5.0, 3.0, 3.0, 7.0, 4.0, 9.0, 5.0, 5.0, 3.0, 3.0, 7.0, 4.0, 9.0, 5.0, 5.0, 6.0,
+ 6.0, 13.0, 7.0, 15.0, 8.0, 8.0, 6.0, 6.0, 13.0, 7.0, 15.0, 8.0, 8.0, 6.0, 6.0, 13.0, 7.0,
+ 15.0, 8.0, 8.0,
+ ];
+ let output_shape = vec![1, 2, 9, 7];
+
+ let convtranpose_model = model(graph(
+ vec![tensor("X", &input_shape)],
+ vec![tensor("Y", &output_shape)],
+ vec![],
+ vec![initializer("W", data_w, kernel_shape)],
+ vec![node(
+ vec!["X", "W"],
+ vec!["Y"],
+ "convtranspose",
+ "ConvTranspose",
+ vec![
+ attribute("kernel_shape", vec![3, 3]),
+ attribute("strides", vec![3, 2]),
+ ],
+ )],
+ ));
+
+ let input_data = HashMap::from([("X".to_string(), data.as_slice().into())]);
+ let session = pollster::block_on(wonnx::Session::from_model(convtranpose_model))
+ .expect("Session did not create");
+ let result = pollster::block_on(session.run(&input_data)).unwrap();
+ assert_eq!(result["Y"], OutputTensor::F32(output_data));
+}
+
+#[test]
+fn convtranspose_pads() {
+ let data = vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
+ let input_shape = vec![1, 1, 3, 3];
+
+ let data_w = vec![
+ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+ ];
+ let kernel_shape = vec![1, 2, 3, 3];
+
+ let output_data = vec![
+ 1.0, 1.0, 3.0, 1.0, 1.0, 3.0, 7.0, 4.0, 9.0, 7.0, 4.0, 9.0, 7.0, 4.0, 9.0, 13.0, 7.0, 15.0,
+ 13.0, 7.0, 15.0, 1.0, 1.0, 3.0, 1.0, 1.0, 3.0, 7.0, 4.0, 9.0, 7.0, 4.0, 9.0, 7.0, 4.0, 9.0,
+ 13.0, 7.0, 15.0, 13.0, 7.0, 15.0,
+ ];
+ let output_shape = vec![1, 2, 7, 3];
+
+ let convtranpose_model = model(graph(
+ vec![tensor("X", &input_shape)],
+ vec![tensor("Y", &output_shape)],
+ vec![],
+ vec![initializer("W", data_w, kernel_shape)],
+ vec![node(
+ vec!["X", "W"],
+ vec!["Y"],
+ "convtranspose",
+ "ConvTranspose",
+ vec![
+ attribute("strides", vec![3, 2]),
+ attribute("pads", vec![1, 2, 1, 2]),
+ ],
+ )],
+ ));
+
+ let input_data = HashMap::from([("X".to_string(), data.as_slice().into())]);
+ let session = pollster::block_on(wonnx::Session::from_model(convtranpose_model))
+ .expect("Session did not create");
+ let result = pollster::block_on(session.run(&input_data)).unwrap();
+ assert_eq!(result["Y"], OutputTensor::F32(output_data));
+}