Skip to content

Commit 265fe43

Browse files
committed
improve test
1 parent 5e3e1e0 commit 265fe43

File tree

2 files changed

+74
-48
lines changed

2 files changed

+74
-48
lines changed

wgpu/tests/shader/zero_init_workgroup_mem.rs

Lines changed: 57 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
use std::num::NonZeroU64;
2+
13
use wgpu::{
24
include_wgsl, BindGroupDescriptor, BindGroupEntry, BindGroupLayoutDescriptor,
3-
BindGroupLayoutEntry, BindingType, BufferBindingType, BufferDescriptor, BufferUsages,
4-
CommandEncoderDescriptor, ComputePassDescriptor, ComputePipelineDescriptor, DownlevelFlags,
5-
Limits, Maintain, MapMode, PipelineLayoutDescriptor, ShaderStages,
5+
BindGroupLayoutEntry, BindingResource, BindingType, BufferBinding, BufferBindingType,
6+
BufferDescriptor, BufferUsages, CommandEncoderDescriptor, ComputePassDescriptor,
7+
ComputePipelineDescriptor, DownlevelFlags, Limits, Maintain, MapMode, PipelineLayoutDescriptor,
8+
ShaderStages,
69
};
710

811
use crate::common::{initialize_test, TestParameters, TestingContext};
@@ -17,12 +20,19 @@ fn zero_init_workgroup_mem() {
1720
);
1821
}
1922

20-
/// Increases iterations and writes random data to workgroup memory before reading it each iteration.
21-
const TRY_TO_FAIL: bool = false;
23+
const DISPATCH_SIZE: (u32, u32, u32) = (64, 64, 64);
24+
const TOTAL_WORK_GROUPS: u32 = DISPATCH_SIZE.0 * DISPATCH_SIZE.1 * DISPATCH_SIZE.2;
25+
26+
/// nr of bytes we use in the shader
27+
const SHADER_WORKGROUP_MEMORY: u32 = 512 * 4 + 4;
28+
// assume we have this much workgroup memory (2GB)
29+
const MAX_DEVICE_WORKGROUP_MEMORY: u32 = i32::MAX as u32;
30+
const NR_OF_DISPATCHES: u32 =
31+
MAX_DEVICE_WORKGROUP_MEMORY / (SHADER_WORKGROUP_MEMORY * TOTAL_WORK_GROUPS) + 1; // TODO: use div_ceil once stabilized
2232

23-
const ARR_SIZE: usize = 512;
24-
const BUFFER_SIZE: u64 = 4 * (ARR_SIZE as u64);
25-
const ITERATIONS: u32 = if TRY_TO_FAIL { 100 } else { 1 };
33+
const OUTPUT_ARRAY_SIZE: u32 = TOTAL_WORK_GROUPS * NR_OF_DISPATCHES;
34+
const BUFFER_SIZE: u64 = OUTPUT_ARRAY_SIZE as u64 * 4;
35+
const BUFFER_BINDING_SIZE: u32 = TOTAL_WORK_GROUPS * 4;
2636

2737
fn zero_init_workgroup_mem_impl(ctx: TestingContext) {
2838
let bgl = ctx
@@ -34,7 +44,7 @@ fn zero_init_workgroup_mem_impl(ctx: TestingContext) {
3444
visibility: ShaderStages::COMPUTE,
3545
ty: BindingType::Buffer {
3646
ty: BufferBindingType::Storage { read_only: false },
37-
has_dynamic_offset: false,
47+
has_dynamic_offset: true,
3848
min_binding_size: None,
3949
},
4050
count: None,
@@ -60,7 +70,11 @@ fn zero_init_workgroup_mem_impl(ctx: TestingContext) {
6070
layout: &bgl,
6171
entries: &[BindGroupEntry {
6272
binding: 0,
63-
resource: output_buffer.as_entire_binding(),
73+
resource: BindingResource::Buffer(BufferBinding {
74+
buffer: &output_buffer,
75+
offset: 0,
76+
size: Some(NonZeroU64::new(BUFFER_BINDING_SIZE as u64).unwrap()),
77+
}),
6478
}],
6579
});
6680

@@ -96,7 +110,7 @@ fn zero_init_workgroup_mem_impl(ctx: TestingContext) {
96110

97111
// -- Initializing data --
98112

99-
let output_pre_init_data = [1; ARR_SIZE];
113+
let output_pre_init_data = vec![1; OUTPUT_ARRAY_SIZE as usize];
100114
ctx.queue.write_buffer(
101115
&output_buffer,
102116
0,
@@ -105,46 +119,48 @@ fn zero_init_workgroup_mem_impl(ctx: TestingContext) {
105119

106120
// -- Run test --
107121

108-
for i in 0..ITERATIONS {
109-
let mut encoder = ctx
110-
.device
111-
.create_command_encoder(&CommandEncoderDescriptor::default());
122+
let mut encoder = ctx
123+
.device
124+
.create_command_encoder(&CommandEncoderDescriptor::default());
125+
126+
let mut cpass = encoder.begin_compute_pass(&ComputePassDescriptor::default());
112127

113-
let mut cpass = encoder.begin_compute_pass(&ComputePassDescriptor::default());
114-
if TRY_TO_FAIL {
115-
cpass.set_pipeline(&pipeline_write);
116-
cpass.dispatch_workgroups(64, 64, 64);
117-
}
128+
cpass.set_pipeline(&pipeline_write);
129+
for _ in 0..NR_OF_DISPATCHES {
130+
cpass.dispatch_workgroups(DISPATCH_SIZE.0, DISPATCH_SIZE.1, DISPATCH_SIZE.2);
131+
}
118132

119-
cpass.set_pipeline(&pipeline_read);
120-
cpass.set_bind_group(0, &bg, &[]);
121-
cpass.dispatch_workgroups(1, 1, 1);
122-
drop(cpass);
133+
cpass.set_pipeline(&pipeline_read);
134+
for i in 0..NR_OF_DISPATCHES {
135+
cpass.set_bind_group(0, &bg, &[i * BUFFER_BINDING_SIZE]);
136+
cpass.dispatch_workgroups(DISPATCH_SIZE.0, DISPATCH_SIZE.1, DISPATCH_SIZE.2);
137+
}
138+
drop(cpass);
123139

124-
// -- Pulldown data --
140+
// -- Pulldown data --
125141

126-
encoder.copy_buffer_to_buffer(&output_buffer, 0, &mapping_buffer, 0, BUFFER_SIZE);
142+
encoder.copy_buffer_to_buffer(&output_buffer, 0, &mapping_buffer, 0, BUFFER_SIZE);
127143

128-
ctx.queue.submit(Some(encoder.finish()));
144+
ctx.queue.submit(Some(encoder.finish()));
129145

130-
mapping_buffer.slice(..).map_async(MapMode::Read, |_| ());
131-
ctx.device.poll(Maintain::Wait);
146+
mapping_buffer.slice(..).map_async(MapMode::Read, |_| ());
147+
ctx.device.poll(Maintain::Wait);
132148

133-
let mapped = mapping_buffer.slice(..).get_mapped_range();
149+
let mapped = mapping_buffer.slice(..).get_mapped_range();
134150

135-
let typed: &[u32] = bytemuck::cast_slice(&*mapped);
151+
let typed: &[u32] = bytemuck::cast_slice(&*mapped);
136152

137-
// -- Check results --
153+
// -- Check results --
138154

139-
let expected = [0; ARR_SIZE];
155+
let num_disptaches_failed = typed.iter().filter(|&&res| res != 0).count();
156+
let ratio = (num_disptaches_failed as f32 / OUTPUT_ARRAY_SIZE as f32) * 100.;
140157

141-
assert!(
142-
typed == expected,
143-
"Zero-initialization of workgroup memory failed (in iteration: {}).",
144-
i
145-
);
158+
assert!(
159+
num_disptaches_failed == 0,
160+
"Zero-initialization of workgroup memory failed ({:.0}% of disptaches failed).",
161+
ratio
162+
);
146163

147-
drop(mapped);
148-
mapping_buffer.unmap();
149-
}
164+
drop(mapped);
165+
mapping_buffer.unmap();
150166
}
Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,31 @@
1+
let array_size = 512u;
2+
13
struct WStruct {
2-
arr: array<i32, 512>,
3-
atom: atomic<i32>
4+
arr: array<u32, array_size>,
5+
atom: atomic<u32>
46
}
57

68
var<workgroup> w_mem: WStruct;
79

810
@group(0) @binding(0)
9-
var<storage, read_write> output: array<i32, 512>;
11+
var<storage, read_write> output: array<u32>;
1012

1113
@compute @workgroup_size(1)
12-
fn read() {
13-
output = w_mem.arr;
14+
fn read(@builtin(workgroup_id) wgid: vec3<u32>, @builtin(num_workgroups) num_workgroups: vec3<u32>) {
15+
var is_zero = true;
16+
for(var i = 0u; i < array_size; i++) {
17+
is_zero &= w_mem.arr[i] == 0u;
18+
}
19+
is_zero &= atomicLoad(&w_mem.atom) == 0u;
20+
21+
let idx = wgid.x + (wgid.y * num_workgroups.x) + (wgid.z * num_workgroups.x * num_workgroups.y);
22+
output[idx] = u32(!is_zero);
1423
}
1524

16-
@compute @workgroup_size(64)
25+
@compute @workgroup_size(1)
1726
fn write() {
18-
for(var i: i32 = 0; i < 512; i++) {
27+
for(var i = 0u; i < array_size; i++) {
1928
w_mem.arr[i] = i;
2029
}
30+
atomicStore(&w_mem.atom, 3u);
2131
}

0 commit comments

Comments
 (0)