1
+ use std:: num:: NonZeroU64 ;
2
+
1
3
use wgpu:: {
2
4
include_wgsl, BindGroupDescriptor , BindGroupEntry , BindGroupLayoutDescriptor ,
3
- BindGroupLayoutEntry , BindingType , BufferBindingType , BufferDescriptor , BufferUsages ,
4
- CommandEncoderDescriptor , ComputePassDescriptor , ComputePipelineDescriptor , DownlevelFlags ,
5
- Limits , Maintain , MapMode , PipelineLayoutDescriptor , ShaderStages ,
5
+ BindGroupLayoutEntry , BindingResource , BindingType , BufferBinding , BufferBindingType ,
6
+ BufferDescriptor , BufferUsages , CommandEncoderDescriptor , ComputePassDescriptor ,
7
+ ComputePipelineDescriptor , DownlevelFlags , Limits , Maintain , MapMode , PipelineLayoutDescriptor ,
8
+ ShaderStages ,
6
9
} ;
7
10
8
11
use crate :: common:: { initialize_test, TestParameters , TestingContext } ;
@@ -17,12 +20,19 @@ fn zero_init_workgroup_mem() {
17
20
) ;
18
21
}
19
22
20
- /// Increases iterations and writes random data to workgroup memory before reading it each iteration.
21
- const TRY_TO_FAIL : bool = false ;
23
+ const DISPATCH_SIZE : ( u32 , u32 , u32 ) = ( 64 , 64 , 64 ) ;
24
+ const TOTAL_WORK_GROUPS : u32 = DISPATCH_SIZE . 0 * DISPATCH_SIZE . 1 * DISPATCH_SIZE . 2 ;
25
+
26
+ /// Number of bytes of workgroup memory used in the shader.
27
+ const SHADER_WORKGROUP_MEMORY : u32 = 512 * 4 + 4 ;
28
+ // Assume the device has this much workgroup memory (`i32::MAX` bytes, just under 2 GiB).
29
+ const MAX_DEVICE_WORKGROUP_MEMORY : u32 = i32:: MAX as u32 ;
30
+ const NR_OF_DISPATCHES : u32 =
31
+ MAX_DEVICE_WORKGROUP_MEMORY / ( SHADER_WORKGROUP_MEMORY * TOTAL_WORK_GROUPS ) + 1 ; // TODO: use div_ceil once stabilized
22
32
23
- const ARR_SIZE : usize = 512 ;
24
- const BUFFER_SIZE : u64 = 4 * ( ARR_SIZE as u64 ) ;
25
- const ITERATIONS : u32 = if TRY_TO_FAIL { 100 } else { 1 } ;
33
+ const OUTPUT_ARRAY_SIZE : u32 = TOTAL_WORK_GROUPS * NR_OF_DISPATCHES ;
34
+ const BUFFER_SIZE : u64 = OUTPUT_ARRAY_SIZE as u64 * 4 ;
35
+ const BUFFER_BINDING_SIZE : u32 = TOTAL_WORK_GROUPS * 4 ;
26
36
27
37
fn zero_init_workgroup_mem_impl ( ctx : TestingContext ) {
28
38
let bgl = ctx
@@ -34,7 +44,7 @@ fn zero_init_workgroup_mem_impl(ctx: TestingContext) {
34
44
visibility : ShaderStages :: COMPUTE ,
35
45
ty : BindingType :: Buffer {
36
46
ty : BufferBindingType :: Storage { read_only : false } ,
37
- has_dynamic_offset : false ,
47
+ has_dynamic_offset : true ,
38
48
min_binding_size : None ,
39
49
} ,
40
50
count : None ,
@@ -60,7 +70,11 @@ fn zero_init_workgroup_mem_impl(ctx: TestingContext) {
60
70
layout : & bgl,
61
71
entries : & [ BindGroupEntry {
62
72
binding : 0 ,
63
- resource : output_buffer. as_entire_binding ( ) ,
73
+ resource : BindingResource :: Buffer ( BufferBinding {
74
+ buffer : & output_buffer,
75
+ offset : 0 ,
76
+ size : Some ( NonZeroU64 :: new ( BUFFER_BINDING_SIZE as u64 ) . unwrap ( ) ) ,
77
+ } ) ,
64
78
} ] ,
65
79
} ) ;
66
80
@@ -96,7 +110,7 @@ fn zero_init_workgroup_mem_impl(ctx: TestingContext) {
96
110
97
111
// -- Initializing data --
98
112
99
- let output_pre_init_data = [ 1 ; ARR_SIZE ] ;
113
+ let output_pre_init_data = vec ! [ 1 ; OUTPUT_ARRAY_SIZE as usize ] ;
100
114
ctx. queue . write_buffer (
101
115
& output_buffer,
102
116
0 ,
@@ -105,46 +119,48 @@ fn zero_init_workgroup_mem_impl(ctx: TestingContext) {
105
119
106
120
// -- Run test --
107
121
108
- for i in 0 ..ITERATIONS {
109
- let mut encoder = ctx
110
- . device
111
- . create_command_encoder ( & CommandEncoderDescriptor :: default ( ) ) ;
122
+ let mut encoder = ctx
123
+ . device
124
+ . create_command_encoder ( & CommandEncoderDescriptor :: default ( ) ) ;
125
+
126
+ let mut cpass = encoder. begin_compute_pass ( & ComputePassDescriptor :: default ( ) ) ;
112
127
113
- let mut cpass = encoder. begin_compute_pass ( & ComputePassDescriptor :: default ( ) ) ;
114
- if TRY_TO_FAIL {
115
- cpass. set_pipeline ( & pipeline_write) ;
116
- cpass. dispatch_workgroups ( 64 , 64 , 64 ) ;
117
- }
128
+ cpass. set_pipeline ( & pipeline_write) ;
129
+ for _ in 0 ..NR_OF_DISPATCHES {
130
+ cpass. dispatch_workgroups ( DISPATCH_SIZE . 0 , DISPATCH_SIZE . 1 , DISPATCH_SIZE . 2 ) ;
131
+ }
118
132
119
- cpass. set_pipeline ( & pipeline_read) ;
120
- cpass. set_bind_group ( 0 , & bg, & [ ] ) ;
121
- cpass. dispatch_workgroups ( 1 , 1 , 1 ) ;
122
- drop ( cpass) ;
133
+ cpass. set_pipeline ( & pipeline_read) ;
134
+ for i in 0 ..NR_OF_DISPATCHES {
135
+ cpass. set_bind_group ( 0 , & bg, & [ i * BUFFER_BINDING_SIZE ] ) ;
136
+ cpass. dispatch_workgroups ( DISPATCH_SIZE . 0 , DISPATCH_SIZE . 1 , DISPATCH_SIZE . 2 ) ;
137
+ }
138
+ drop ( cpass) ;
123
139
124
- // -- Pulldown data --
140
+ // -- Pulldown data --
125
141
126
- encoder. copy_buffer_to_buffer ( & output_buffer, 0 , & mapping_buffer, 0 , BUFFER_SIZE ) ;
142
+ encoder. copy_buffer_to_buffer ( & output_buffer, 0 , & mapping_buffer, 0 , BUFFER_SIZE ) ;
127
143
128
- ctx. queue . submit ( Some ( encoder. finish ( ) ) ) ;
144
+ ctx. queue . submit ( Some ( encoder. finish ( ) ) ) ;
129
145
130
- mapping_buffer. slice ( ..) . map_async ( MapMode :: Read , |_| ( ) ) ;
131
- ctx. device . poll ( Maintain :: Wait ) ;
146
+ mapping_buffer. slice ( ..) . map_async ( MapMode :: Read , |_| ( ) ) ;
147
+ ctx. device . poll ( Maintain :: Wait ) ;
132
148
133
- let mapped = mapping_buffer. slice ( ..) . get_mapped_range ( ) ;
149
+ let mapped = mapping_buffer. slice ( ..) . get_mapped_range ( ) ;
134
150
135
- let typed: & [ u32 ] = bytemuck:: cast_slice ( & * mapped) ;
151
+ let typed: & [ u32 ] = bytemuck:: cast_slice ( & * mapped) ;
136
152
137
- // -- Check results --
153
+ // -- Check results --
138
154
139
- let expected = [ 0 ; ARR_SIZE ] ;
155
+ let num_disptaches_failed = typed. iter ( ) . filter ( |& & res| res != 0 ) . count ( ) ;
156
+ let ratio = ( num_disptaches_failed as f32 / OUTPUT_ARRAY_SIZE as f32 ) * 100. ;
140
157
141
- assert ! (
142
- typed == expected ,
143
- "Zero-initialization of workgroup memory failed (in iteration: {} )." ,
144
- i
145
- ) ;
158
+ assert ! (
159
+ num_disptaches_failed == 0 ,
160
+ "Zero-initialization of workgroup memory failed ({:.0}% of disptaches failed )." ,
161
+ ratio
162
+ ) ;
146
163
147
- drop ( mapped) ;
148
- mapping_buffer. unmap ( ) ;
149
- }
164
+ drop ( mapped) ;
165
+ mapping_buffer. unmap ( ) ;
150
166
}
0 commit comments