This repository was archived by the owner on May 7, 2025. It is now read-only.

Commit 63ce8e5

Implement the ConvTranspose operation
1 parent 07b7a9e commit 63ce8e5

File tree

3 files changed

+266 -0 lines changed


wonnx/src/compiler.rs

Lines changed: 103 additions & 0 deletions
@@ -135,6 +135,11 @@ lazy_static! {
            include_str!("../templates/endomorphism/broadcast.wgsl"),
        )
        .unwrap();
        tera.add_raw_template(
            "unpool/convtranspose.wgsl",
            include_str!("../templates/unpool/convtranspose.wgsl"),
        )
        .unwrap();
        tera
    };
}

@@ -1384,6 +1389,104 @@ pub fn compile(
                threads: (ceil(output_lengths[0], 256) as _, 1, 1),
            }
        }
        "ConvTranspose" => {
            /* Inputs:
             * 1. X Input (N x C x H x W; Batch Size x Channels x Height x Width)
             * 2. Kernel
             * 3. Bias
             */
            log::debug!("{:?}", input_shapes);

            if input_shapes[0].rank() != 4 {
                /* FIXME: We don't handle non-2D input for now */
                return Err(CompileError::InvalidInputShape {
                    input_index: 0,
                    input_shape: input_shapes[0].clone(),
                });
            }

            /* Step 1: Get the input dimensions */
            let input_height = input_shapes[0].dims[2] as i64;
            let input_width = input_shapes[0].dims[3] as i64;

            /* Step 2: Read attributes */
            /* TODO: auto_pad */
            let dilations = node.get_attribute_value("dilations", Some(vec![1, 1]))?;

            let group = node.get_attribute_value("group", Some(1))?;
            if group != 1 {
                return Err(CompileError::InvalidAttributeValue {
                    attribute: "group".into(),
                    value: group.to_string(),
                    opset_version,
                });
            }

            let inferred_kernel_shape = input_shapes[1]
                .dims
                .iter()
                .skip(2)
                .map(|&x| x as i64)
                .collect::<Vec<_>>();

            let kernel_shape =
                node.get_attribute_value("kernel_shape", Some(inferred_kernel_shape.clone()))?;
            if inferred_kernel_shape != kernel_shape {
                log::error!("Inferred kernel shape: {:?}", inferred_kernel_shape);
                return Err(CompileError::InvalidAttributeValue {
                    attribute: "kernel_shape".to_string(),
                    value: format!("{:?}", kernel_shape),
                    opset_version,
                });
            }

            let output_padding = node.get_attribute_value("output_padding", Some(vec![0, 0]))?;

            let pads = node.get_attribute_value("pads", Some(vec![0, 0, 0, 0]))?;

            let strides = node.get_attribute_value("strides", Some(vec![1, 1]))?;

            context.insert("stride", &strides);

            let output_height = strides[0] * (input_height - 1)
                + output_padding[0]
                + ((kernel_shape[0] - 1) * dilations[0] + 1)
                - pads[0]
                - pads[2];
            let output_width = strides[1] * (input_width - 1)
                + output_padding[1]
                + ((kernel_shape[1] - 1) * dilations[1] + 1)
                - pads[1]
                - pads[3];

            log::debug!(
                "Calculated output size: {:?}x{:?}",
                output_width,
                output_height
            );

            let (x_threads, workgroup_size_x) = workgroup_size(
                output_lengths[0],
                MAX_COMPUTE_WORKGROUPS_PER_DIMENSION,
                MAX_WORKGROUP_SIZE_X,
            )?;
            context.insert("workgroup_size_x", &workgroup_size_x);

            let scalar_type = agreed_type(input_shapes, output_shapes)?;

            if scalar_type.is_float() {
                NodeTemplate {
                    scalar_type,
                    template: "unpool/convtranspose.wgsl",
                    threads: (x_threads, 1, 1),
                }
            } else {
                return Err(CompileError::UnimplementedVariant {
                    variant: "Non-Float".into(),
                    op: "ConvTranspose".into(),
                });
            }
        }
        op => return Err(CompileError::UnimplementedOp(op.to_string())),
    };
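For reference, the output spatial size computed above is the standard ONNX ConvTranspose shape formula. A minimal standalone sketch (plain Rust, not part of this commit; the helper name is made up) reproduces the calculation the compiler performs:

// Standalone sketch of the ConvTranspose output-size formula used above (not part of wonnx).
fn convtranspose_output_dim(
    input: i64,
    kernel: i64,
    stride: i64,
    dilation: i64,
    output_padding: i64,
    pad_begin: i64,
    pad_end: i64,
) -> i64 {
    stride * (input - 1) + output_padding + ((kernel - 1) * dilation + 1) - pad_begin - pad_end
}

fn main() {
    // Matches the convtranspose_default test below: 3x3 input, 3x3 kernel, stride 1 -> 5x5 output.
    assert_eq!(convtranspose_output_dim(3, 3, 1, 1, 0, 0, 0), 5);
    // The same input with stride 2 would yield a 7x7 output.
    assert_eq!(convtranspose_output_dim(3, 3, 2, 1, 0, 0, 0), 7);
}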

wonnx/templates/unpool/convtranspose.wgsl

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
{%- include "structs.wgsl" -%}

@group(0) @binding(0)
var<storage, read> input_tensor: Array;

@group(0) @binding(1)
var<storage, read> input_kernel_weights: Array;

{% if i_lens | length == 3 -%}
@group(0) @binding(2)
var<storage, read> input_bias: Array;

@group(0) @binding(3)
var<storage, read_write> output_0: Array;
{%- else -%}
@group(0) @binding(2)
var<storage, read_write> output_0: Array;
{%- endif %}

@compute @workgroup_size({{ workgroup_size_x }}, 1, 1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let output_idx = global_id.x;

    if (output_idx < {{ o_lens[0] }}u) {
        /* Decompose the linear output index into (batch, channel, y, x) */
        let batch = output_idx / {{ o_chunks[0][0] }}u;
        var rest = output_idx % {{ o_chunks[0][0] }}u;

        let channel = rest / {{ o_chunks[0][1] }}u;
        rest = rest % {{ o_chunks[0][1] }}u;

        let y = rest / {{ o_chunks[0][2] }}u;
        let x = rest % {{ o_chunks[0][2] }}u;

        let sample_root_index = batch * {{ i_chunks[0][0] }}u;

        /* Kernel Size: C x M/groups x kH x kW */
        /* -> For each output channel, we need to calculate the convolution over each input channel and sum all of that up! */
        /* C = In Channels */
        /* M = Out Channels */

        /* Kernel base coordinate limits */
        let min_in_y = select(0u, (y - {{ i_shape[1][2] }}u) / {{ stride[0] }}u, y > {{ i_shape[1][2] }}u);
        let max_in_y = select({{ i_shape[0][2] }}u - 1u, y / {{ stride[0] }}u, y / {{ stride[0] }}u < {{ i_shape[0][2] }}u);
        let min_in_x = select(0u, (x - {{ i_shape[1][3] }}u) / {{ stride[1] }}u, x > {{ i_shape[1][3] }}u);
        let max_in_x = select({{ i_shape[0][3] }}u - 1u, x / {{ stride[1] }}u, x / {{ stride[1] }}u < {{ i_shape[0][3] }}u);

        var result: Scalar = Scalar();

        for(var ichannel: u32 = 0u; ichannel < {{ i_shape[0][1] }}u; ichannel = ichannel + 1u) {
            let base_index = sample_root_index + ichannel * {{ i_chunks[0][1] }}u;
            let base_kernel_index = {{ i_chunks[1][0] }}u * ichannel + channel * {{ i_chunks[1][1] }}u; /* Base kernel address to apply kH and kW to */

            for(var in_y: u32 = min_in_y; in_y <= max_in_y; in_y = in_y + 1u) {
                for(var in_x: u32 = min_in_x; in_x <= max_in_x; in_x = in_x + 1u) {
                    let kernel_y = y - (in_y * {{ stride[0] }}u);
                    let kernel_x = x - (in_x * {{ stride[1] }}u);

                    if(kernel_y < {{ i_shape[1][2] }}u && kernel_x < {{ i_shape[1][3] }}u) {
                        result = result + (input_tensor.data[base_index + (in_y * {{ i_chunks[0][2] }}u) + in_x]
                            * input_kernel_weights.data[base_kernel_index + kernel_y * {{ i_chunks[1][2] }}u + kernel_x]);
                    }
                }
            }
        }
        {% if i_lens | length == 3 -%}
        result = result + input_bias.data[channel];
        {%- endif %}
        output_0.data[output_idx] = result;
    }
}
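The shader computes each output element by gathering contributions from the input: for every output pixel it visits the input positions whose stride-scattered kernel window covers that pixel, instead of scattering each input pixel into the output. A rough CPU reference of the same gather loop, included here only as an illustration (plain Rust, not part of this commit; the function name and the layout assumptions of single batch, group = 1, no padding, no dilation, no bias are mine):

// Illustrative CPU reference of the gather loop in the shader above (not part of wonnx).
// Layouts: input (C, H, W), weight (C, M, kH, kW), output (M, OH, OW).
fn convtranspose_2d_ref(
    input: &[f32],
    (c, h, w): (usize, usize, usize),
    weight: &[f32],
    (m, kh, kw): (usize, usize, usize),
    (sh, sw): (usize, usize),
) -> Vec<f32> {
    let (oh, ow) = (sh * (h - 1) + kh, sw * (w - 1) + kw);
    let mut out = vec![0.0f32; m * oh * ow];
    for oc in 0..m {
        for y in 0..oh {
            for x in 0..ow {
                let mut acc = 0.0f32;
                for ic in 0..c {
                    for in_y in 0..h {
                        for in_x in 0..w {
                            // Only input positions whose kernel window, placed at
                            // (in_y * sh, in_x * sw), covers (y, x) contribute.
                            let ky = y as isize - (in_y * sh) as isize;
                            let kx = x as isize - (in_x * sw) as isize;
                            if ky >= 0 && (ky as usize) < kh && kx >= 0 && (kx as usize) < kw {
                                acc += input[(ic * h + in_y) * w + in_x]
                                    * weight[((ic * m + oc) * kh + ky as usize) * kw + kx as usize];
                            }
                        }
                    }
                }
                out[(oc * oh + y) * ow + x] = acc;
            }
        }
    }
    out
}

With the all-ones 1x2x3x3 kernel used by the tests below, this reproduces the 5x5 per-channel values asserted in convtranspose_default.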

wonnx/tests/convtranspose.rs

Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
use std::collections::HashMap;
use wonnx::utils::{attribute, graph, initializer, model, node, tensor, OutputTensor};
mod common;

#[test]
fn convtranspose_default() {
    let data: Vec<f32> = vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let input_shape = vec![1, 1, 3, 3];

    let data_w = vec![
        1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
    ];
    let kernel_shape = vec![1, 2, 3, 3];

    let output_shape = vec![1, 2, 5, 5];

    let input_data = HashMap::from([("X".to_string(), data.as_slice().into())]);

    let convtranpose_model = model(graph(
        vec![tensor("X", &input_shape)],
        vec![tensor("Y", &output_shape)],
        vec![],
        vec![initializer("W", data_w, kernel_shape)],
        vec![node(
            vec!["X", "W"],
            vec!["Y"],
            "convtranspose",
            "ConvTranspose",
            vec![attribute("kernel_shape", vec![3, 3])],
        )],
    ));

    let session = pollster::block_on(wonnx::Session::from_model(convtranpose_model))
        .expect("Session did not create");
    let result = pollster::block_on(session.run(&input_data)).unwrap();

    assert_eq!(
        result["Y"],
        OutputTensor::F32(vec![
            0.0, 1.0, 3.0, 3.0, 2.0, 3.0, 8.0, 15.0, 12.0, 7.0, 9.0, 21.0, 36.0, 27.0, 15.0, 9.0,
            20.0, 33.0, 24.0, 13.0, 6.0, 13.0, 21.0, 15.0, 8.0, 0.0, 1.0, 3.0, 3.0, 2.0, 3.0, 8.0,
            15.0, 12.0, 7.0, 9.0, 21.0, 36.0, 27.0, 15.0, 9.0, 20.0, 33.0, 24.0, 13.0, 6.0, 13.0,
            21.0, 15.0, 8.0,
        ])
    );
}

#[test]
fn convtranspose_strides() {
    let data = vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; // (1, 1, 3, 3)
    let input_shape = vec![1, 1, 3, 3];

    let data_w = vec![
        1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
    ];
    let kernel_shape = vec![1, 2, 3, 3];

    let output_data = vec![
        0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 0.0, 0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 0.0, 0.0, 0.0,
        1.0, 1.0, 3.0, 2.0, 2.0, 0.0, 3.0, 3.0, 7.0, 4.0, 9.0, 5.0, 5.0, 0.0, 3.0, 3.0, 7.0, 4.0,
        9.0, 5.0, 5.0, 0.0, 3.0, 3.0, 7.0, 4.0, 9.0, 5.0, 5.0, 0.0, 6.0, 6.0, 13.0, 7.0, 15.0, 8.0,
        8.0, 0.0, 6.0, 6.0, 13.0, 7.0, 15.0, 8.0, 8.0, 0.0, 6.0, 6.0, 13.0, 7.0, 15.0, 8.0, 8.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 0.0, 0.0,
        0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 0.0, 0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 0.0, 3.0, 3.0, 7.0,
        4.0, 9.0, 5.0, 5.0, 0.0, 3.0, 3.0, 7.0, 4.0, 9.0, 5.0, 5.0, 0.0, 3.0, 3.0, 7.0, 4.0, 9.0,
        5.0, 5.0, 0.0, 6.0, 6.0, 13.0, 7.0, 15.0, 8.0, 8.0, 0.0, 6.0, 6.0, 13.0, 7.0, 15.0, 8.0,
        8.0, 0.0, 6.0, 6.0, 13.0, 7.0, 15.0, 8.0, 8.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
    ];
    let output_shape = vec![1, 2, 10, 8];

    let convtranpose_model = model(graph(
        vec![tensor("X", &input_shape)],
        vec![tensor("Y", &output_shape)],
        vec![],
        vec![initializer("W", data_w, kernel_shape)],
        vec![node(
            vec!["X", "W"],
            vec!["Y"],
            "convtranspose",
            "ConvTranspose",
            vec![
                attribute("kernel_shape", vec![3, 3]),
                attribute("strides", vec![3, 2]),
            ],
        )],
    ));

    let input_data = HashMap::from([("X".to_string(), data.as_slice().into())]);
    let session = pollster::block_on(wonnx::Session::from_model(convtranpose_model))
        .expect("Session did not create");
    let result = pollster::block_on(session.run(&input_data)).unwrap();
    assert_eq!(result["Y"], OutputTensor::F32(output_data));
}
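Both tests exercise only the two-input (X, W) form. The shader also declares a bias binding for the three-input case; a hypothetical test skeleton for that path (not part of this commit, expected values left unverified, helper signatures assumed to match the tests above) could look like this:

// Hypothetical sketch (not part of this commit): exercising the three-input (X, W, B) path.
#[test]
#[ignore] // sketch only; expected values are not verified here
fn convtranspose_bias_sketch() {
    let data: Vec<f32> = (0..9).map(|i| i as f32).collect(); // (1, 1, 3, 3)
    let input_shape = vec![1, 1, 3, 3];
    let data_w = vec![1.0; 18]; // (1, 2, 3, 3), all ones
    let kernel_shape = vec![1, 2, 3, 3];
    let bias = vec![1.0, 2.0]; // one bias value per output channel
    let output_shape = vec![1, 2, 5, 5];

    let model_with_bias = model(graph(
        vec![tensor("X", &input_shape)],
        vec![tensor("Y", &output_shape)],
        vec![],
        vec![
            initializer("W", data_w, kernel_shape),
            initializer("B", bias, vec![2]),
        ],
        vec![node(
            vec!["X", "W", "B"],
            vec!["Y"],
            "convtranspose",
            "ConvTranspose",
            vec![attribute("kernel_shape", vec![3, 3])],
        )],
    ));

    let input_data = HashMap::from([("X".to_string(), data.as_slice().into())]);
    let session = pollster::block_on(wonnx::Session::from_model(model_with_bias))
        .expect("Session did not create");
    let result = pollster::block_on(session.run(&input_data)).unwrap();
    // Relative to convtranspose_default, every value in output channel 0 should be
    // offset by 1.0 and every value in channel 1 by 2.0.
    let _ = result;
}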
