Skip to content
This repository was archived by the owner on May 7, 2025. It is now read-only.

Commit fbb7ab1

Browse files
authored
Merge pull request #183 from mayjs/fix_concat
Fix Concat for larger inputs and add an additional test
2 parents 07b7a9e + c04d389 commit fbb7ab1

File tree

2 files changed

+35
-2
lines changed

2 files changed

+35
-2
lines changed

wonnx/templates/matrix/concat.wgsl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ fn main(@builtin(global_invocation_id) global_id: vec3<u32>, @builtin(num_workgr
1717
let gidx = global_id.x;
1818
let gidy = global_id.y;
1919

20-
let nx = num_workgroups.x;
20+
let x_executions = num_workgroups.x * 16u;
2121

22-
let actual_idx = gidx + gidy * nx;
22+
let actual_idx = gidx + gidy * x_executions;
2323

2424
{% for input in i_lens %}
2525
{% if loop.first %}

wonnx/tests/concat.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,39 @@ fn test_concat() {
3535
common::assert_eq_vector((&result["Z"]).try_into().unwrap(), &expected_result);
3636
}
3737

38+
#[test]
39+
fn test_concat_long() {
40+
let n: usize = 100000;
41+
42+
let xdata: Vec<f32> = (0..n).map(|x| x as f32).collect();
43+
let mut ydata: Vec<f32> = (n..2 * n).map(|x| x as f32).collect();
44+
let input_dims = vec![n as i64];
45+
let output_dims = vec![(n * 2) as i64];
46+
47+
let input_data = HashMap::from([
48+
("X".into(), xdata.as_slice().into()),
49+
("Y".into(), ydata.as_slice().into()),
50+
]);
51+
52+
let model = model(graph(
53+
vec![tensor("X", &input_dims), tensor("Y", &input_dims)],
54+
vec![tensor("Z", &output_dims)],
55+
vec![],
56+
vec![],
57+
vec![node(vec!["X", "Y"], vec!["Z"], "a", "Concat", vec![])],
58+
));
59+
60+
let session =
61+
pollster::block_on(wonnx::Session::from_model(model)).expect("Session did not create");
62+
63+
let result = pollster::block_on(session.run(&input_data)).unwrap();
64+
65+
let mut expected_result = xdata.clone();
66+
expected_result.append(&mut ydata);
67+
68+
common::assert_eq_vector((&result["Z"]).try_into().unwrap(), &expected_result);
69+
}
70+
3871
#[test]
3972
fn test_concat4() {
4073
let n: usize = 13;

0 commit comments

Comments
 (0)