Skip to content
This repository was archived by the owner on May 7, 2025. It is now read-only.

Commit 5caf15b

Browse files
committed
Implement the ConvTranspose operation
1 parent fbb7ab1 commit 5caf15b

File tree

4 files changed

+366
-1
lines changed

4 files changed

+366
-1
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ fn test_matmul_square_matrix() {
237237
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#ConstantOfShape">ConstantOfShape</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ConstantOfShape-9">9</a>|||
238238
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#Conv">Conv</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Conv-11">11</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Conv-1">1</a>||
239239
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#ConvInteger">ConvInteger</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ConvInteger-10">10</a>|
240-
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#ConvTranspose">ConvTranspose</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ConvTranspose-11">11</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ConvTranspose-1">1</a>|
240+
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#ConvTranspose">ConvTranspose</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ConvTranspose-11">11</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ConvTranspose-1">1</a>||
241241
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cos">Cos</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Cos-7">7</a>|||
242242
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cosh">Cosh</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Cosh-9">9</a>|||
243243
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#CumSum">CumSum</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#CumSum-14">14</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#CumSum-11">11</a>|

wonnx/src/compiler.rs

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,11 @@ lazy_static! {
135135
include_str!("../templates/endomorphism/broadcast.wgsl"),
136136
)
137137
.unwrap();
138+
tera.add_raw_template(
139+
"unpool/convtranspose.wgsl",
140+
include_str!("../templates/unpool/convtranspose.wgsl"),
141+
)
142+
.unwrap();
138143
tera
139144
};
140145
}
@@ -197,6 +202,12 @@ pub enum CompileError {
197202
input_shapes: Vec<Shape>,
198203
output_shape: Shape,
199204
},
205+
206+
#[error("output shape mismatch")]
207+
UnexpectedOutputShape {
208+
output_shape: Shape,
209+
expected_shape: Vec<i64>,
210+
},
200211
}
201212

202213
struct NodeTemplate {
@@ -1384,6 +1395,139 @@ pub fn compile(
13841395
threads: (ceil(output_lengths[0], 256) as _, 1, 1),
13851396
}
13861397
}
1398+
"ConvTranspose" => {
1399+
// FIXME: Add support for dilations, group, output_padding, pads, auto_pad,
1400+
// 1d and 3d inputs
1401+
1402+
// Only support 2D input at the moment
1403+
if input_shapes[0].rank() != 4 {
1404+
return Err(CompileError::InvalidInputShape {
1405+
input_index: 0,
1406+
input_shape: input_shapes[0].clone(),
1407+
});
1408+
}
1409+
1410+
let input_height = input_shapes[0].dims[2] as i64;
1411+
let input_width = input_shapes[0].dims[3] as i64;
1412+
1413+
let dilations = node.get_attribute_value("dilations", Some(vec![1, 1]))?;
1414+
if dilations != vec![1, 1] {
1415+
return Err(CompileError::InvalidAttributeValue {
1416+
attribute: "dilations".into(),
1417+
value: format!("{:?}", dilations),
1418+
opset_version,
1419+
});
1420+
}
1421+
1422+
let group = node.get_attribute_value("group", Some(1))?;
1423+
if group != 1 {
1424+
return Err(CompileError::InvalidAttributeValue {
1425+
attribute: "group".into(),
1426+
value: group.to_string(),
1427+
opset_version,
1428+
});
1429+
}
1430+
1431+
let inferred_kernel_shape = input_shapes[1]
1432+
.dims
1433+
.iter()
1434+
.skip(2)
1435+
.map(|&x| x as i64)
1436+
.collect::<Vec<_>>();
1437+
1438+
let kernel_shape =
1439+
node.get_attribute_value("kernel_shape", Some(inferred_kernel_shape.clone()))?;
1440+
if inferred_kernel_shape != kernel_shape {
1441+
log::error!("Inferred kernel shape: {:?}", inferred_kernel_shape);
1442+
return Err(CompileError::InvalidAttributeValue {
1443+
attribute: "kernel_shape".to_string(),
1444+
value: format!("{:?}", kernel_shape),
1445+
opset_version,
1446+
});
1447+
}
1448+
1449+
let output_padding = node.get_attribute_value("output_padding", Some(vec![0, 0]))?;
1450+
if output_padding != vec![0, 0] {
1451+
return Err(CompileError::InvalidAttributeValue {
1452+
attribute: "output_padding".into(),
1453+
value: format!("{:?}", output_padding),
1454+
opset_version,
1455+
});
1456+
}
1457+
1458+
let auto_pad = node.get_attribute_value("auto_pad", Some("NOTSET".to_string()))?;
1459+
if auto_pad != "NOTSET" {
1460+
return Err(CompileError::InvalidAttributeValue {
1461+
attribute: "auto_pad".into(),
1462+
value: auto_pad,
1463+
opset_version,
1464+
});
1465+
}
1466+
1467+
let pads = node.get_attribute_value("pads", Some(vec![0, 0, 0, 0]))?;
1468+
if pads.iter().any(|pad| *pad < 0) {
1469+
return Err(CompileError::InvalidAttributeValue {
1470+
attribute: "pads".into(),
1471+
value: format!("{:?}", pads),
1472+
opset_version,
1473+
});
1474+
}
1475+
1476+
context.insert("pads", &pads);
1477+
1478+
let strides = node.get_attribute_value("strides", Some(vec![1, 1]))?;
1479+
if strides.iter().any(|stride| *stride <= 0) {
1480+
return Err(CompileError::InvalidAttributeValue {
1481+
attribute: "strides".into(),
1482+
value: format!("{:?}", strides),
1483+
opset_version,
1484+
});
1485+
}
1486+
1487+
context.insert("stride", &strides);
1488+
1489+
let output_height = strides[0] * (input_height - 1)
1490+
+ output_padding[0]
1491+
+ ((kernel_shape[0] - 1) * dilations[0] + 1)
1492+
- pads[0]
1493+
- pads[2];
1494+
let output_width = strides[1] * (input_width - 1)
1495+
+ output_padding[1]
1496+
+ ((kernel_shape[1] - 1) * dilations[1] + 1)
1497+
- pads[1]
1498+
- pads[3];
1499+
1500+
if output_shapes[0].dim(2) as i64 != output_height
1501+
|| output_shapes[0].dim(3) as i64 != output_width
1502+
{
1503+
return Err(CompileError::UnexpectedOutputShape {
1504+
output_shape: output_shapes[0].clone(),
1505+
expected_shape: vec![output_height, output_width],
1506+
});
1507+
}
1508+
1509+
let (x_threads, workgroup_size_x) = workgroup_size(
1510+
output_lengths[0],
1511+
MAX_COMPUTE_WORKGROUPS_PER_DIMENSION,
1512+
MAX_WORKGROUP_SIZE_X,
1513+
)?;
1514+
context.insert("workgroup_size_x", &workgroup_size_x);
1515+
1516+
let scalar_type = agreed_type(input_shapes, output_shapes)?;
1517+
1518+
if scalar_type.is_float() {
1519+
NodeTemplate {
1520+
scalar_type,
1521+
template: "unpool/convtranspose.wgsl",
1522+
threads: (x_threads, 1, 1),
1523+
}
1524+
} else {
1525+
return Err(CompileError::UnimplementedVariant {
1526+
variant: "Non-Float".into(),
1527+
op: "ConvTranspose".into(),
1528+
});
1529+
}
1530+
}
13871531
op => return Err(CompileError::UnimplementedOp(op.to_string())),
13881532
};
13891533

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
{%- include "structs.wgsl" -%}
2+
3+
// Input tensor, shape NxCxHxW
4+
@group(0) @binding(0)
5+
var<storage, read> input_tensor: Array;
6+
7+
// Kernel weight tensor, shape CxM/groupxkHxkW
8+
@group(0) @binding(1)
9+
var<storage, read> input_kernel_weights: Array;
10+
11+
{% if i_lens | length == 3 -%}
12+
@group(0) @binding(2)
13+
var<storage, read> input_bias: Array;
14+
15+
@group(0) @binding(3)
16+
var<storage, read_write> output_0: Array;
17+
{%- else -%}
18+
@group(0) @binding(2)
19+
var<storage, read_write> output_0: Array;
20+
{%- endif %}
21+
22+
{% set input_shape = i_shape[0] %}
23+
{% set input_chunks = i_chunks[0] %}
24+
{% set kernel_shape = i_shape[1] %}
25+
{% set kernel_chunks = i_chunks[1] %}
26+
27+
@compute @workgroup_size({{ workgroup_size_x }}, 1, 1)
28+
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
29+
let output_idx = global_id.x;
30+
31+
if (output_idx < {{ o_lens[0] }}u) {
32+
// Calculate the output coordinates we are responsible for
33+
let batch = output_idx / {{ o_chunks[0][0] }}u;
34+
var rest = output_idx % {{ o_chunks[0][0] }}u;
35+
36+
let channel = rest / {{ o_chunks[0][1] }}u;
37+
rest = rest % {{ o_chunks[0][1] }}u;
38+
39+
let y = rest / {{ o_chunks[0][2] }}u;
40+
let x = rest % {{ o_chunks[0][2] }}u;
41+
42+
// The actual output is a slice of the full output,
43+
// where the given padding values are removed on each end.
44+
// We don't need to worry about this at the upper coordinate end,
45+
// but we need to consider it on the lower end and calculate
46+
// virtual output coordinates to calculate the input coordinate range later.
47+
let unpadded_y = y + {{ pads[0] }}u;
48+
let unpadded_x = x + {{ pads[1] }}u;
49+
50+
let sample_root_index = batch * {{ input_chunks[0] }}u;
51+
52+
// Calculate the input coordinate range for our output coordinate
53+
let min_in_y = select(0u, (unpadded_y - {{ kernel_shape[2] }}u) / {{ stride[0] }}u, unpadded_y > {{ kernel_shape[2] }}u);
54+
let max_in_y = select({{ input_shape[2] }}u - 1u, unpadded_y / {{ stride[0] }}u, unpadded_y / {{ stride[0] }}u < {{ input_shape[3] }}u);
55+
let min_in_x = select(0u, (unpadded_x - {{ kernel_shape[3] }}u) / {{ stride[1] }}u, unpadded_x > {{ kernel_shape[3] }}u);
56+
let max_in_x = select({{ input_shape[3] }}u - 1u, unpadded_x / {{ stride[1] }}u, unpadded_x / {{ stride[1] }}u < {{ input_shape[3] }}u);
57+
58+
var result: Scalar = Scalar();
59+
60+
// Now, go over each input channel and apply the corresponing kernel for that channel
61+
// to calculate the output piece by piece.
62+
for(var ichannel: u32 = 0u; ichannel < {{ input_shape[1] }}u; ichannel = ichannel + 1u) {
63+
// Base index for the 2D data in the input data
64+
let base_index = sample_root_index + ichannel * {{ input_chunks[1] }}u;
65+
// Get the starting position of the kernel for the given input and output channel
66+
let base_kernel_index = ichannel *{{ kernel_chunks[0] }}u + channel * {{ kernel_chunks[1] }}u;
67+
68+
// Iterate of all potential input values
69+
for(var in_y: u32 = min_in_y; in_y <= max_in_y; in_y = in_y + 1u) {
70+
for(var in_x: u32 = min_in_x; in_x <= max_in_x; in_x = in_x + 1u) {
71+
let kernel_y = unpadded_y - (in_y * {{ stride[0] }}u);
72+
let kernel_x = unpadded_x - (in_x * {{ stride[1] }}u);
73+
74+
if(kernel_y < {{ kernel_shape[2] }}u && kernel_x < {{ kernel_shape[3] }}u) {
75+
result = result + (input_tensor.data[base_index + (in_y * {{ input_chunks[2] }}u) + in_x]
76+
* input_kernel_weights.data[base_kernel_index + kernel_y * {{ kernel_chunks[2] }}u + kernel_x]);
77+
}
78+
}
79+
}
80+
}
81+
{% if i_lens | length == 3 -%}
82+
// Apply Bias if specified
83+
result = result + input_bias.data[channel];
84+
{%- endif %}
85+
86+
output_0.data[output_idx] = result;
87+
}
88+
}

wonnx/tests/convtranspose.rs

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
use std::collections::HashMap;
2+
use wonnx::utils::{attribute, graph, initializer, model, node, tensor, OutputTensor};
3+
mod common;
4+
5+
#[test]
fn convtranspose_default() {
    // A single 3x3 input channel holding the values 0..9.
    let input_shape = vec![1, 1, 3, 3];
    let data: Vec<f32> = (0..9).map(|i| i as f32).collect();

    // One input channel, two output channels, 3x3 all-ones kernels: every
    // output position sums its overlapping input window, and the two output
    // channels come out identical.
    let kernel_shape = vec![1, 2, 3, 3];
    let data_w: Vec<f32> = vec![1.0; 18];

    let output_shape = vec![1, 2, 5, 5];

    let convtranspose_model = model(graph(
        vec![tensor("X", &input_shape)],
        vec![tensor("Y", &output_shape)],
        vec![],
        vec![initializer("W", data_w, kernel_shape)],
        vec![node(
            vec!["X", "W"],
            vec!["Y"],
            "convtranspose",
            "ConvTranspose",
            vec![attribute("kernel_shape", vec![3, 3])],
        )],
    ));

    let session = pollster::block_on(wonnx::Session::from_model(convtranspose_model))
        .expect("Session did not create");

    let input_data = HashMap::from([("X".to_string(), data.as_slice().into())]);
    let result = pollster::block_on(session.run(&input_data)).unwrap();

    // Expected values match the ONNX ConvTranspose reference output
    // (the same 5x5 map, repeated once per output channel).
    assert_eq!(
        result["Y"],
        OutputTensor::F32(vec![
            0.0, 1.0, 3.0, 3.0, 2.0, 3.0, 8.0, 15.0, 12.0, 7.0, 9.0, 21.0, 36.0, 27.0, 15.0, 9.0,
            20.0, 33.0, 24.0, 13.0, 6.0, 13.0, 21.0, 15.0, 8.0, 0.0, 1.0, 3.0, 3.0, 2.0, 3.0, 8.0,
            15.0, 12.0, 7.0, 9.0, 21.0, 36.0, 27.0, 15.0, 9.0, 20.0, 33.0, 24.0, 13.0, 6.0, 13.0,
            21.0, 15.0, 8.0,
        ])
    );
}
47+
48+
#[test]
fn convtranspose_strides() {
    // A single 3x3 input channel holding the values 0..9.
    let input_shape = vec![1, 1, 3, 3];
    let data: Vec<f32> = (0..9).map(|i| i as f32).collect();

    // One input channel, two output channels, 3x3 all-ones kernels.
    let kernel_shape = vec![1, 2, 3, 3];
    let data_w: Vec<f32> = vec![1.0; 18];

    // With strides (3, 2) the input values spread out and overlap only along
    // the x axis; expected values match the ONNX reference output.
    let output_shape = vec![1, 2, 9, 7];
    let output_data = vec![
        0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 0.0, 0.0, 1.0, 1.0,
        3.0, 2.0, 2.0, 3.0, 3.0, 7.0, 4.0, 9.0, 5.0, 5.0, 3.0, 3.0, 7.0, 4.0, 9.0, 5.0, 5.0, 3.0,
        3.0, 7.0, 4.0, 9.0, 5.0, 5.0, 6.0, 6.0, 13.0, 7.0, 15.0, 8.0, 8.0, 6.0, 6.0, 13.0, 7.0,
        15.0, 8.0, 8.0, 6.0, 6.0, 13.0, 7.0, 15.0, 8.0, 8.0, 0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0,
        0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 3.0, 3.0, 7.0, 4.0,
        9.0, 5.0, 5.0, 3.0, 3.0, 7.0, 4.0, 9.0, 5.0, 5.0, 3.0, 3.0, 7.0, 4.0, 9.0, 5.0, 5.0, 6.0,
        6.0, 13.0, 7.0, 15.0, 8.0, 8.0, 6.0, 6.0, 13.0, 7.0, 15.0, 8.0, 8.0, 6.0, 6.0, 13.0, 7.0,
        15.0, 8.0, 8.0,
    ];

    let convtranspose_model = model(graph(
        vec![tensor("X", &input_shape)],
        vec![tensor("Y", &output_shape)],
        vec![],
        vec![initializer("W", data_w, kernel_shape)],
        vec![node(
            vec!["X", "W"],
            vec!["Y"],
            "convtranspose",
            "ConvTranspose",
            vec![
                attribute("kernel_shape", vec![3, 3]),
                attribute("strides", vec![3, 2]),
            ],
        )],
    ));

    let session = pollster::block_on(wonnx::Session::from_model(convtranspose_model))
        .expect("Session did not create");

    let input_data = HashMap::from([("X".to_string(), data.as_slice().into())]);
    let result = pollster::block_on(session.run(&input_data)).unwrap();
    assert_eq!(result["Y"], OutputTensor::F32(output_data));
}
93+
94+
#[test]
fn convtranspose_pads() {
    // A single 3x3 input channel holding the values 0..9.
    let input_shape = vec![1, 1, 3, 3];
    let data: Vec<f32> = (0..9).map(|i| i as f32).collect();

    // One input channel, two output channels, 3x3 all-ones kernels.
    let kernel_shape = vec![1, 2, 3, 3];
    let data_w: Vec<f32> = vec![1.0; 18];

    // pads = (1, 2, 1, 2) crops one row from each vertical end and two
    // columns from each horizontal end of the strided full output; expected
    // values match the ONNX reference output.
    let output_shape = vec![1, 2, 7, 3];
    let output_data = vec![
        1.0, 1.0, 3.0, 1.0, 1.0, 3.0, 7.0, 4.0, 9.0, 7.0, 4.0, 9.0, 7.0, 4.0, 9.0, 13.0, 7.0, 15.0,
        13.0, 7.0, 15.0, 1.0, 1.0, 3.0, 1.0, 1.0, 3.0, 7.0, 4.0, 9.0, 7.0, 4.0, 9.0, 7.0, 4.0, 9.0,
        13.0, 7.0, 15.0, 13.0, 7.0, 15.0,
    ];

    let convtranspose_model = model(graph(
        vec![tensor("X", &input_shape)],
        vec![tensor("Y", &output_shape)],
        vec![],
        vec![initializer("W", data_w, kernel_shape)],
        vec![node(
            vec!["X", "W"],
            vec!["Y"],
            "convtranspose",
            "ConvTranspose",
            vec![
                attribute("strides", vec![3, 2]),
                attribute("pads", vec![1, 2, 1, 2]),
            ],
        )],
    ));

    let session = pollster::block_on(wonnx::Session::from_model(convtranspose_model))
        .expect("Session did not create");

    let input_data = HashMap::from([("X".to_string(), data.as_slice().into())]);
    let result = pollster::block_on(session.run(&input_data)).unwrap();
    assert_eq!(result["Y"], OutputTensor::F32(output_data));
}

0 commit comments

Comments
 (0)