Skip to content
This repository was archived by the owner on May 7, 2025. It is now read-only.

Commit 07b7a9e

Browse files
authored
Merge pull request #180 from mayjs/larger_concats
Allow larger output sizes for Concat
2 parents 78958f2 + 4806f31 commit 07b7a9e

File tree

3 files changed

+96
-8
lines changed

3 files changed

+96
-8
lines changed

wonnx/src/compiler.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
use crate::utils::{
33
ceil, AttributeNotFoundError, DataTypeError, MultiType, NodeAttributes, ScalarType, Shape,
44
};
5-
use num::integer::gcd;
5+
use num::integer::{gcd, Roots};
66
use tera::{Context, Tera};
77
use thiserror::Error;
88

@@ -792,10 +792,13 @@ pub fn compile(
792792
}
793793
context.insert("cum_len", &input_cumulative_len);
794794

795+
let root = output_lengths[0].sqrt() + 1;
796+
let per_dim = ceil(root, 16) + 1;
797+
795798
NodeTemplate {
796799
scalar_type: agreed_type(input_shapes, output_shapes)?,
797800
template: "matrix/concat.wgsl",
798-
threads: (ceil(output_lengths[0], 256) as u32, 1, 1),
801+
threads: (per_dim as u32, per_dim as u32, 1),
799802
}
800803
}
801804
op @ ("MaxPool" | "AveragePool" | "Conv" | "ConvRelu" | "ConvLeakyRelu" | "ConvMish"

wonnx/templates/matrix/concat.wgsl

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,24 @@ var<storage, read> input_{{ loop.index0 }}: Array;
1212
@group({{ binding_len / 4 | int }}) @binding({{ binding_len % 4 }})
1313
var<storage, read_write> output_0: Array;
1414

15-
@compute @workgroup_size(256, 1, 1)
16-
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
15+
@compute @workgroup_size(16, 16, 1)
16+
fn main(@builtin(global_invocation_id) global_id: vec3<u32>, @builtin(num_workgroups) num_workgroups: vec3<u32>) {
1717
let gidx = global_id.x;
18+
let gidy = global_id.y;
19+
20+
let nx = num_workgroups.x;
21+
22+
let actual_idx = gidx + gidy * nx;
1823

1924
{% for input in i_lens %}
2025
{% if loop.first %}
21-
if (gidx < {{ i_lens[0] }}u) {
22-
output_0.data[gidx] = input_0.data[gidx];
26+
if (actual_idx < {{ i_lens[0] }}u) {
27+
output_0.data[actual_idx] = input_0.data[actual_idx];
2328
}
2429

2530
{% else %}
26-
if ((gidx >= {{ cum_len | nth(n=loop.index0 -1) }}u) && (gidx < {{ cum_len | nth(n=loop.index0)}}u)) {
27-
output_0.data[gidx] = input_{{ loop.index0 }}.data[gidx - {{ cum_len | nth(n=loop.index0 -1) }}u];
31+
if ((actual_idx >= {{ cum_len | nth(n=loop.index0 -1) }}u) && (actual_idx < {{ cum_len | nth(n=loop.index0)}}u)) {
32+
output_0.data[actual_idx] = input_{{ loop.index0 }}.data[actual_idx - {{ cum_len | nth(n=loop.index0 -1) }}u];
2833
}
2934

3035
{% endif %}

wonnx/tests/concat.rs

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
use std::{collections::HashMap, convert::TryInto};
2+
use wonnx::utils::{graph, model, node, tensor};
3+
mod common;
4+
5+
#[test]
6+
fn test_concat() {
7+
let n: usize = 16;
8+
9+
let xdata: Vec<f32> = (0..n).map(|x| x as f32).collect();
10+
let mut ydata: Vec<f32> = (n..2 * n).map(|x| x as f32).collect();
11+
let input_dims = vec![n as i64];
12+
let output_dims = vec![(n * 2) as i64];
13+
14+
let input_data = HashMap::from([
15+
("X".into(), xdata.as_slice().into()),
16+
("Y".into(), ydata.as_slice().into()),
17+
]);
18+
19+
let model = model(graph(
20+
vec![tensor("X", &input_dims), tensor("Y", &input_dims)],
21+
vec![tensor("Z", &output_dims)],
22+
vec![],
23+
vec![],
24+
vec![node(vec!["X", "Y"], vec!["Z"], "a", "Concat", vec![])],
25+
));
26+
27+
let session =
28+
pollster::block_on(wonnx::Session::from_model(model)).expect("Session did not create");
29+
30+
let result = pollster::block_on(session.run(&input_data)).unwrap();
31+
32+
let mut expected_result = xdata.clone();
33+
expected_result.append(&mut ydata);
34+
35+
common::assert_eq_vector((&result["Z"]).try_into().unwrap(), &expected_result);
36+
}
37+
38+
#[test]
39+
fn test_concat4() {
40+
let n: usize = 13;
41+
42+
let xdata: Vec<f32> = (0..n).map(|x| x as f32).collect();
43+
let mut ydata: Vec<f32> = (n..2 * n).map(|x| x as f32).collect();
44+
let mut zdata: Vec<f32> = (n * 2..3 * n).map(|x| x as f32).collect();
45+
let mut wdata: Vec<f32> = (n * 3..4 * n).map(|x| x as f32).collect();
46+
let input_dims = vec![n as i64];
47+
let output_dims = vec![(n * 4) as i64];
48+
49+
let input_data = HashMap::from([
50+
("X".into(), xdata.as_slice().into()),
51+
("Y".into(), ydata.as_slice().into()),
52+
("Z".into(), zdata.as_slice().into()),
53+
("W".into(), wdata.as_slice().into()),
54+
]);
55+
56+
let model = model(graph(
57+
vec![
58+
tensor("X", &input_dims),
59+
tensor("Y", &input_dims),
60+
tensor("Z", &input_dims),
61+
tensor("W", &input_dims),
62+
],
63+
vec![tensor("O", &output_dims)],
64+
vec![],
65+
vec![],
66+
vec![node(vec!["X", "Y", "Z", "W"], vec!["O"], "a", "Concat", vec![])],
67+
));
68+
69+
let session =
70+
pollster::block_on(wonnx::Session::from_model(model)).expect("Session did not create");
71+
72+
let result = pollster::block_on(session.run(&input_data)).unwrap();
73+
74+
let mut expected_result = xdata.clone();
75+
expected_result.append(&mut ydata);
76+
expected_result.append(&mut zdata);
77+
expected_result.append(&mut wdata);
78+
79+
common::assert_eq_vector((&result["O"]).try_into().unwrap(), &expected_result);
80+
}

0 commit comments

Comments
 (0)