Skip to content
This repository was archived by the owner on May 7, 2025. It is now read-only.

Commit 78958f2

Browse files
authored
Merge pull request #162 from SludgePhD/fix-resize
2 parents b043613 + 97d00ee commit 78958f2

File tree

2 files changed

+14
-7
lines changed

2 files changed

+14
-7
lines changed

wonnx/templates/matrix/resize.wgsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
2929
+
3030
{%- endif -%}
3131
u32(floor(
32-
({{ scalar_type }}(d_{{ loop.index0 }}) + {{ scalar_type }}(0.5)) / {{ scale }} - {{ scalar_type }}(0.5)
32+
(f32(d_{{ loop.index0 }}) + 0.5) / {{ scale }}
3333
)) * {{ chunks }}u
3434
{%- endfor -%}
3535
;

wonnx/tests/matrix.rs

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,8 @@ fn test_pad_complex() {
341341
#[test]
342342
fn test_resize() {
343343
let _ = env_logger::builder().is_test(true).try_init();
344+
345+
// --- resize_downsample_scales_nearest ---
344346
let mut input_data = HashMap::new();
345347
let data = (1..=2 * 4).map(|x| x as f32).collect::<Vec<f32>>();
346348
input_data.insert("X".to_string(), data.as_slice().into());
@@ -366,6 +368,7 @@ fn test_resize() {
366368
let test_y = vec![1., 3.];
367369
common::assert_eq_vector((&result["Y"]).try_into().unwrap(), &test_y);
368370

371+
// --- resize_upsample_scales_nearest ---
369372
let mut input_data = HashMap::new();
370373
let data = (1..=4).map(|x| x as f32).collect::<Vec<f32>>();
371374
input_data.insert("X".to_string(), data.as_slice().into());
@@ -386,13 +389,17 @@ fn test_resize() {
386389

387390
let session = pollster::block_on(wonnx::Session::from_model(upsampling_model))
388391
.expect("session did not create");
389-
let _result = pollster::block_on(session.run(&input_data)).unwrap();
392+
let result = pollster::block_on(session.run(&input_data)).unwrap();
390393

391-
//let test_y = vec![
392-
// 1., 1., 1., 2., 2., 2., 1., 1., 1., 2., 2., 2., 3., 3., 3., 4., 4., 4., 3., 3., 3., 4., 4.,
393-
// 4.,
394-
//];
395-
//assert_eq!(result["Y"], test_y);
394+
#[rustfmt::skip]
395+
let test_y = vec![
396+
1., 1., 1., 2., 2., 2.,
397+
1., 1., 1., 2., 2., 2.,
398+
3., 3., 3., 4., 4., 4.,
399+
3., 3., 3., 4., 4., 4.,
400+
];
401+
let output: &[f32] = (&result["Y"]).try_into().unwrap();
402+
assert_eq!(output, &test_y);
396403
}
397404

398405
// Multiply a 2x2 matrix with an identity matrix of size 2x2.

0 commit comments

Comments
 (0)