This repository was archived by the owner on May 7, 2025. It is now read-only.

Implement the ConvTranspose operation #182

Open · wants to merge 2 commits into base: master

2 changes: 1 addition & 1 deletion README.md
@@ -237,7 +237,7 @@ fn test_matmul_square_matrix() {
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#ConstantOfShape">ConstantOfShape</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ConstantOfShape-9">9</a>|✅|✅|
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#Conv">Conv</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Conv-11">11</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Conv-1">1</a>|✅|
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#ConvInteger">ConvInteger</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ConvInteger-10">10</a>|
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#ConvTranspose">ConvTranspose</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ConvTranspose-11">11</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ConvTranspose-1">1</a>|
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#ConvTranspose">ConvTranspose</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ConvTranspose-11">11</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ConvTranspose-1">1</a>|✅|
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cos">Cos</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Cos-7">7</a>|✅|✅|
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cosh">Cosh</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Cosh-9">9</a>|✅|✅|
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#CumSum">CumSum</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#CumSum-14">14</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#CumSum-11">11</a>|
5 changes: 5 additions & 0 deletions wonnx-py/tests/test_onnx_backend.py
@@ -197,6 +197,11 @@ def do_enforce_test_coverage_safelist(model): # type: (ModelProto) -> bool
backend_test.include(f"test_softmax_negative_axis_cpu$")
backend_test.include(f"test_softmax_default_axis_cpu$")

# ConvTranspose
# We only have partial attribute support right now, so we hand-select a few test cases limited to the supported attributes
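# For reference: test_convtranspose is the basic 2D case from the ONNX backend suite, while
# test_convtranspose_pads additionally exercises the explicit "pads" attribute.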
backend_test.include(f"test_convtranspose_cpu$")
backend_test.include(f"test_convtranspose_pads_cpu$")

globals().update(backend_test.enable_report().test_cases)

if __name__ == "__main__":
144 changes: 144 additions & 0 deletions wonnx/src/compiler.rs
@@ -135,6 +135,11 @@ lazy_static! {
include_str!("../templates/endomorphism/broadcast.wgsl"),
)
.unwrap();
tera.add_raw_template(
"unpool/convtranspose.wgsl",
include_str!("../templates/unpool/convtranspose.wgsl"),
)
.unwrap();
tera
};
}
@@ -197,6 +202,12 @@ pub enum CompileError {
input_shapes: Vec<Shape>,
output_shape: Shape,
},

#[error("output shape mismatch")]
UnexpectedOutputShape {
output_shape: Shape,
expected_shape: Vec<i64>,
},
}

struct NodeTemplate {
@@ -1384,6 +1395,139 @@ pub fn compile(
threads: (ceil(output_lengths[0], 256) as _, 1, 1),
}
}
"ConvTranspose" => {
// FIXME: Add support for dilations, group, output_padding, pads, auto_pad,
// 1d and 3d inputs
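// Currently supported: float 2D (NCHW) inputs with group == 1, unit dilations, zero output_padding,
// auto_pad == NOTSET, non-negative explicit pads and positive strides.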

// Only support 2D input at the moment
if input_shapes[0].rank() != 4 {
return Err(CompileError::InvalidInputShape {
input_index: 0,
input_shape: input_shapes[0].clone(),
});
}

let input_height = input_shapes[0].dims[2] as i64;
let input_width = input_shapes[0].dims[3] as i64;

let dilations = node.get_attribute_value("dilations", Some(vec![1, 1]))?;
if dilations != vec![1, 1] {
return Err(CompileError::InvalidAttributeValue {
attribute: "dilations".into(),
value: format!("{:?}", dilations),
opset_version,
});
}

let group = node.get_attribute_value("group", Some(1))?;
if group != 1 {
return Err(CompileError::InvalidAttributeValue {
attribute: "group".into(),
value: group.to_string(),
opset_version,
});
}

// The ConvTranspose weight tensor has shape C x (M/group) x kH x kW; its trailing dims give the kernel's spatial shape
let inferred_kernel_shape = input_shapes[1]
.dims
.iter()
.skip(2)
.map(|&x| x as i64)
.collect::<Vec<_>>();

let kernel_shape =
node.get_attribute_value("kernel_shape", Some(inferred_kernel_shape.clone()))?;
if inferred_kernel_shape != kernel_shape {
log::error!("Inferred kernel shape: {:?}", inferred_kernel_shape);
return Err(CompileError::InvalidAttributeValue {
attribute: "kernel_shape".to_string(),
value: format!("{:?}", kernel_shape),
opset_version,
});
}

let output_padding = node.get_attribute_value("output_padding", Some(vec![0, 0]))?;
if output_padding != vec![0, 0] {
return Err(CompileError::InvalidAttributeValue {
attribute: "output_padding".into(),
value: format!("{:?}", output_padding),
opset_version,
});
}

let auto_pad = node.get_attribute_value("auto_pad", Some("NOTSET".to_string()))?;
if auto_pad != "NOTSET" {
return Err(CompileError::InvalidAttributeValue {
attribute: "auto_pad".into(),
value: auto_pad,
opset_version,
});
}

let pads = node.get_attribute_value("pads", Some(vec![0, 0, 0, 0]))?;
if pads.iter().any(|pad| *pad < 0) {
return Err(CompileError::InvalidAttributeValue {
attribute: "pads".into(),
value: format!("{:?}", pads),
opset_version,
});
}

context.insert("pads", &pads);

let strides = node.get_attribute_value("strides", Some(vec![1, 1]))?;
if strides.iter().any(|stride| *stride <= 0) {
return Err(CompileError::InvalidAttributeValue {
attribute: "strides".into(),
value: format!("{:?}", strides),
opset_version,
});
}

context.insert("stride", &strides);

// Expected output size per the ONNX ConvTranspose spec (auto_pad == NOTSET), with pads
// given in ONNX order [top, left, bottom, right]:
//   out = stride * (in - 1) + output_padding + ((kernel - 1) * dilation + 1) - pad_begin - pad_end
let output_height = strides[0] * (input_height - 1)
+ output_padding[0]
+ ((kernel_shape[0] - 1) * dilations[0] + 1)
- pads[0]
- pads[2];
let output_width = strides[1] * (input_width - 1)
+ output_padding[1]
+ ((kernel_shape[1] - 1) * dilations[1] + 1)
- pads[1]
- pads[3];

if output_shapes[0].dim(2) as i64 != output_height
|| output_shapes[0].dim(3) as i64 != output_width
{
return Err(CompileError::UnexpectedOutputShape {
output_shape: output_shapes[0].clone(),
expected_shape: vec![output_height, output_width],
});
}

let (x_threads, workgroup_size_x) = workgroup_size(
output_lengths[0],
MAX_COMPUTE_WORKGROUPS_PER_DIMENSION,
MAX_WORKGROUP_SIZE_X,
)?;
context.insert("workgroup_size_x", &workgroup_size_x);

let scalar_type = agreed_type(input_shapes, output_shapes)?;

if scalar_type.is_float() {
NodeTemplate {
scalar_type,
template: "unpool/convtranspose.wgsl",
threads: (x_threads, 1, 1),
}
} else {
return Err(CompileError::UnimplementedVariant {
variant: "Non-Float".into(),
op: "ConvTranspose".into(),
});
}
}
op => return Err(CompileError::UnimplementedOp(op.to_string())),
};

88 changes: 88 additions & 0 deletions wonnx/templates/unpool/convtranspose.wgsl
@@ -0,0 +1,88 @@
{%- include "structs.wgsl" -%}

// Input tensor, shape NxCxHxW
@group(0) @binding(0)
var<storage, read> input_tensor: Array;

// Kernel weight tensor, shape CxM/groupxkHxkW
@group(0) @binding(1)
var<storage, read> input_kernel_weights: Array;

{% if i_lens | length == 3 -%}
@group(0) @binding(2)
var<storage, read> input_bias: Array;

@group(0) @binding(3)
var<storage, read_write> output_0: Array;
{%- else -%}
@group(0) @binding(2)
var<storage, read_write> output_0: Array;
{%- endif %}

{% set input_shape = i_shape[0] %}
{% set input_chunks = i_chunks[0] %}
{% set kernel_shape = i_shape[1] %}
{% set kernel_chunks = i_chunks[1] %}
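// (In this template, *_chunks[d] is the flattened-index stride of dimension d, i.e. the product of the trailing dimensions.)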

@compute @workgroup_size({{ workgroup_size_x }}, 1, 1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
let output_idx = global_id.x;

if (output_idx < {{ o_lens[0] }}u) {
// Calculate the output coordinates we are responsible for
let batch = output_idx / {{ o_chunks[0][0] }}u;
var rest = output_idx % {{ o_chunks[0][0] }}u;

let channel = rest / {{ o_chunks[0][1] }}u;
rest = rest % {{ o_chunks[0][1] }}u;

let y = rest / {{ o_chunks[0][2] }}u;
let x = rest % {{ o_chunks[0][2] }}u;

// The actual output is a slice of the full (virtual) output,
// with the given padding removed at each end.
// The upper coordinate end needs no special handling, but at the lower end
// we have to shift back into virtual output coordinates, from which the
// input coordinate range is derived below.
let unpadded_y = y + {{ pads[0] }}u;
let unpadded_x = x + {{ pads[1] }}u;

let sample_root_index = batch * {{ input_chunks[0] }}u;

// Calculate the input coordinate range for our output coordinate
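// A (virtual) output position can only receive contributions from input positions i with
// i * stride <= position <= i * stride + kernel_size - 1; the bounds below clamp that range
// (conservatively), and the kernel-offset check inside the loop filters out the rest.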
let min_in_y = select(0u, (unpadded_y - {{ kernel_shape[2] }}u) / {{ stride[0] }}u, unpadded_y > {{ kernel_shape[2] }}u);
let max_in_y = select({{ input_shape[2] }}u - 1u, unpadded_y / {{ stride[0] }}u, unpadded_y / {{ stride[0] }}u < {{ input_shape[2] }}u);
let min_in_x = select(0u, (unpadded_x - {{ kernel_shape[3] }}u) / {{ stride[1] }}u, unpadded_x > {{ kernel_shape[3] }}u);
let max_in_x = select({{ input_shape[3] }}u - 1u, unpadded_x / {{ stride[1] }}u, unpadded_x / {{ stride[1] }}u < {{ input_shape[3] }}u);

var result: Scalar = Scalar();

// Now, go over each input channel and apply the corresponding kernel for that channel
// to calculate the output piece by piece.
for(var ichannel: u32 = 0u; ichannel < {{ input_shape[1] }}u; ichannel = ichannel + 1u) {
// Base index for the 2D data in the input data
let base_index = sample_root_index + ichannel * {{ input_chunks[1] }}u;
// Get the starting position of the kernel for the given input and output channel
let base_kernel_index = ichannel * {{ kernel_chunks[0] }}u + channel * {{ kernel_chunks[1] }}u;

// Iterate over all potential input values
for(var in_y: u32 = min_in_y; in_y <= max_in_y; in_y = in_y + 1u) {
for(var in_x: u32 = min_in_x; in_x <= max_in_x; in_x = in_x + 1u) {
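// Offset of this input position's contribution within the kernel window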
let kernel_y = unpadded_y - (in_y * {{ stride[0] }}u);
let kernel_x = unpadded_x - (in_x * {{ stride[1] }}u);

if(kernel_y < {{ kernel_shape[2] }}u && kernel_x < {{ kernel_shape[3] }}u) {
result = result + (input_tensor.data[base_index + (in_y * {{ input_chunks[2] }}u) + in_x]
* input_kernel_weights.data[base_kernel_index + kernel_y * {{ kernel_chunks[2] }}u + kernel_x]);
}
}
}
}
{% if i_lens | length == 3 -%}
// Apply Bias if specified
result = result + input_bias.data[channel];
{%- endif %}

output_0.data[output_idx] = result;
}
}