Skip to content

Commit d280913

Browse files
authored
Zero-initialize workgroup memory (#3174)
fixes #2430
1 parent f309095 commit d280913

File tree

13 files changed

+267
-6
lines changed

13 files changed

+267
-6
lines changed

CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,8 @@ surface.configure(&device, &config);
225225
- Implemented correleation between user timestamps and platform specific presentation timestamps via [`Adapter::get_presentation_timestamp`]. By @cwfitzgerald in [#3240](https://github.com/gfx-rs/wgpu/pull/3240)
226226
- Added support for `Features::SHADER_PRIMITIVE_INDEX` on all backends. By @cwfitzgerald in [#3272](https://github.com/gfx-rs/wgpu/pull/3272)
227227
- Implemented `TextureFormat::Stencil8`, allowing for stencil testing without depth components. By @Dinnerbone in [#3343](https://github.com/gfx-rs/wgpu/pull/3343)
228-
- Implemented `add_srgb_suffix()` for `TextureFormat` for converting linear formats to sRGB. By @Elabajaba in [#3419](https://github.com/gfx-rs/wgpu/pull/3419)
228+
- Implemented `add_srgb_suffix()` for `TextureFormat` for converting linear formats to sRGB. By @Elabajaba in [#3419](https://github.com/gfx-rs/wgpu/pull/3419)
229+
- Zero-initialize workgroup memory. By @teoxoy in [#3174](https://github.com/gfx-rs/wgpu/pull/3174)
229230

230231
#### GLES
231232

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ path = "./wgpu-hal"
3939

4040
[workspace.dependencies.naga]
4141
git = "https://github.com/gfx-rs/naga"
42-
rev = "1be8024"
42+
rev = "c7d02151f08d6285683795289b5725b827d836d1"
4343
version = "0.10"
4444

4545
[workspace.dependencies]

wgpu-core/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ thiserror = "1"
6767

6868
[dependencies.naga]
6969
git = "https://github.com/gfx-rs/naga"
70-
rev = "1be8024"
70+
rev = "c7d02151f08d6285683795289b5725b827d836d1"
7171
version = "0.10"
7272
features = ["clone", "span", "validate"]
7373

wgpu-hal/Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,14 +113,14 @@ android_system_properties = "0.1.1"
113113

114114
[dependencies.naga]
115115
git = "https://github.com/gfx-rs/naga"
116-
rev = "1be8024"
116+
rev = "c7d02151f08d6285683795289b5725b827d836d1"
117117
version = "0.10"
118118
features = ["clone"]
119119

120120
# DEV dependencies
121121
[dev-dependencies.naga]
122122
git = "https://github.com/gfx-rs/naga"
123-
rev = "1be8024"
123+
rev = "c7d02151f08d6285683795289b5725b827d836d1"
124124
version = "0.10"
125125
features = ["wgsl-in"]
126126

wgpu-hal/src/dx12/device.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,6 +1070,7 @@ impl crate::Device<super::Api> for super::Device {
10701070
fake_missing_bindings: false,
10711071
special_constants_binding,
10721072
push_constants_target,
1073+
zero_initialize_workgroup_memory: true,
10731074
},
10741075
})
10751076
}

wgpu-hal/src/gles/device.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1032,6 +1032,7 @@ impl crate::Device<super::Api> for super::Device {
10321032
version: self.shared.shading_language_version,
10331033
writer_flags,
10341034
binding_map,
1035+
zero_initialize_workgroup_memory: true,
10351036
},
10361037
})
10371038
}

wgpu-hal/src/metal/device.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -699,6 +699,7 @@ impl crate::Device<super::Api> for super::Device {
699699
// TODO: support bounds checks on binding arrays
700700
binding_array: naga::proc::BoundsCheckPolicy::Unchecked,
701701
},
702+
zero_initialize_workgroup_memory: true,
702703
},
703704
total_push_constants,
704705
})

wgpu-hal/src/vulkan/adapter.rs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ pub struct PhysicalDeviceFeatures {
3131
vk::PhysicalDeviceShaderFloat16Int8Features,
3232
vk::PhysicalDevice16BitStorageFeatures,
3333
)>,
34+
zero_initialize_workgroup_memory:
35+
Option<vk::PhysicalDeviceZeroInitializeWorkgroupMemoryFeatures>,
3436
}
3537

3638
// This is safe because the structs have `p_next: *mut c_void`, which we null out/never read.
@@ -69,6 +71,9 @@ impl PhysicalDeviceFeatures {
6971
info = info.push_next(f16_i8_feature);
7072
info = info.push_next(_16bit_feature);
7173
}
74+
if let Some(ref mut feature) = self.zero_initialize_workgroup_memory {
75+
info = info.push_next(feature);
76+
}
7277
info
7378
}
7479

@@ -286,6 +291,19 @@ impl PhysicalDeviceFeatures {
286291
} else {
287292
None
288293
},
294+
zero_initialize_workgroup_memory: if effective_api_version >= vk::API_VERSION_1_3
295+
|| enabled_extensions.contains(&vk::KhrZeroInitializeWorkgroupMemoryFn::name())
296+
{
297+
Some(
298+
vk::PhysicalDeviceZeroInitializeWorkgroupMemoryFeatures::builder()
299+
.shader_zero_initialize_workgroup_memory(
300+
private_caps.zero_initialize_workgroup_memory,
301+
)
302+
.build(),
303+
)
304+
} else {
305+
None
306+
},
289307
}
290308
}
291309

@@ -885,6 +903,16 @@ impl super::InstanceShared {
885903
builder = builder.push_next(&mut next.1);
886904
}
887905

906+
// `VK_KHR_zero_initialize_workgroup_memory` is promoted to 1.3
907+
if capabilities.effective_api_version >= vk::API_VERSION_1_3
908+
|| capabilities.supports_extension(vk::KhrZeroInitializeWorkgroupMemoryFn::name())
909+
{
910+
let next = features
911+
.zero_initialize_workgroup_memory
912+
.insert(vk::PhysicalDeviceZeroInitializeWorkgroupMemoryFeatures::default());
913+
builder = builder.push_next(next);
914+
}
915+
888916
let mut features2 = builder.build();
889917
unsafe {
890918
get_device_properties.get_physical_device_features2(phd, &mut features2);
@@ -1044,6 +1072,11 @@ impl super::Instance {
10441072
.image_robustness
10451073
.map_or(false, |ext| ext.robust_image_access != 0),
10461074
},
1075+
zero_initialize_workgroup_memory: phd_features
1076+
.zero_initialize_workgroup_memory
1077+
.map_or(false, |ext| {
1078+
ext.shader_zero_initialize_workgroup_memory == vk::TRUE
1079+
}),
10471080
};
10481081
let capabilities = crate::Capabilities {
10491082
limits: phd_capabilities.to_wgpu_limits(),
@@ -1246,6 +1279,14 @@ impl super::Adapter {
12461279
// TODO: support bounds checks on binding arrays
12471280
binding_array: naga::proc::BoundsCheckPolicy::Unchecked,
12481281
},
1282+
zero_initialize_workgroup_memory: if self
1283+
.private_caps
1284+
.zero_initialize_workgroup_memory
1285+
{
1286+
spv::ZeroInitializeWorkgroupMemoryMode::Native
1287+
} else {
1288+
spv::ZeroInitializeWorkgroupMemoryMode::Polyfill
1289+
},
12491290
// We need to build this separately for each invocation, so just default it out here
12501291
binding_map: BTreeMap::default(),
12511292
}

wgpu-hal/src/vulkan/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ struct PrivateCapabilities {
166166
non_coherent_map_mask: wgt::BufferAddress,
167167
robust_buffer_access: bool,
168168
robust_image_access: bool,
169+
zero_initialize_workgroup_memory: bool,
169170
}
170171

171172
bitflags::bitflags!(

wgpu/tests/shader/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ use crate::common::TestingContext;
1717

1818
mod numeric_builtins;
1919
mod struct_layout;
20+
mod zero_init_workgroup_mem;
2021

2122
#[derive(Clone, Copy, PartialEq)]
2223
enum InputStorageType {
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
use std::num::NonZeroU64;
2+
3+
use wgpu::{
4+
include_wgsl, Backends, BindGroupDescriptor, BindGroupEntry, BindGroupLayoutDescriptor,
5+
BindGroupLayoutEntry, BindingResource, BindingType, BufferBinding, BufferBindingType,
6+
BufferDescriptor, BufferUsages, CommandEncoderDescriptor, ComputePassDescriptor,
7+
ComputePipelineDescriptor, DownlevelFlags, Limits, Maintain, MapMode, PipelineLayoutDescriptor,
8+
ShaderStages,
9+
};
10+
11+
use crate::common::{initialize_test, TestParameters, TestingContext};
12+
13+
#[test]
14+
fn zero_init_workgroup_mem() {
15+
initialize_test(
16+
TestParameters::default()
17+
.downlevel_flags(DownlevelFlags::COMPUTE_SHADERS)
18+
.limits(Limits::downlevel_defaults())
19+
// remove once we get to https://github.com/gfx-rs/wgpu/issues/3193 or
20+
// https://github.com/gfx-rs/wgpu/issues/3160
21+
.specific_failure(
22+
Some(Backends::DX12),
23+
Some(5140),
24+
Some("Microsoft Basic Render Driver"),
25+
true,
26+
)
27+
// this one is flakey
28+
.specific_failure(
29+
Some(Backends::VULKAN),
30+
Some(6880),
31+
Some("SwiftShader"),
32+
true,
33+
)
34+
// TODO: investigate why it fails
35+
.specific_failure(Some(Backends::GL), Some(65541), Some("llvmpipe"), false),
36+
zero_init_workgroup_mem_impl,
37+
);
38+
}
39+
40+
const DISPATCH_SIZE: (u32, u32, u32) = (64, 64, 64);
41+
const TOTAL_WORK_GROUPS: u32 = DISPATCH_SIZE.0 * DISPATCH_SIZE.1 * DISPATCH_SIZE.2;
42+
43+
/// nr of bytes we use in the shader
44+
const SHADER_WORKGROUP_MEMORY: u32 = 512 * 4 + 4;
45+
// assume we have this much workgroup memory (2GB)
46+
const MAX_DEVICE_WORKGROUP_MEMORY: u32 = i32::MAX as u32;
47+
const NR_OF_DISPATCHES: u32 =
48+
MAX_DEVICE_WORKGROUP_MEMORY / (SHADER_WORKGROUP_MEMORY * TOTAL_WORK_GROUPS) + 1; // TODO: use div_ceil once stabilized
49+
50+
const OUTPUT_ARRAY_SIZE: u32 = TOTAL_WORK_GROUPS * NR_OF_DISPATCHES;
51+
const BUFFER_SIZE: u64 = OUTPUT_ARRAY_SIZE as u64 * 4;
52+
const BUFFER_BINDING_SIZE: u32 = TOTAL_WORK_GROUPS * 4;
53+
54+
fn zero_init_workgroup_mem_impl(ctx: TestingContext) {
55+
let bgl = ctx
56+
.device
57+
.create_bind_group_layout(&BindGroupLayoutDescriptor {
58+
label: None,
59+
entries: &[BindGroupLayoutEntry {
60+
binding: 0,
61+
visibility: ShaderStages::COMPUTE,
62+
ty: BindingType::Buffer {
63+
ty: BufferBindingType::Storage { read_only: false },
64+
has_dynamic_offset: true,
65+
min_binding_size: None,
66+
},
67+
count: None,
68+
}],
69+
});
70+
71+
let output_buffer = ctx.device.create_buffer(&BufferDescriptor {
72+
label: Some("output buffer"),
73+
size: BUFFER_SIZE,
74+
usage: BufferUsages::COPY_DST | BufferUsages::COPY_SRC | BufferUsages::STORAGE,
75+
mapped_at_creation: false,
76+
});
77+
78+
let mapping_buffer = ctx.device.create_buffer(&BufferDescriptor {
79+
label: Some("mapping buffer"),
80+
size: BUFFER_SIZE,
81+
usage: BufferUsages::COPY_DST | BufferUsages::MAP_READ,
82+
mapped_at_creation: false,
83+
});
84+
85+
let bg = ctx.device.create_bind_group(&BindGroupDescriptor {
86+
label: None,
87+
layout: &bgl,
88+
entries: &[BindGroupEntry {
89+
binding: 0,
90+
resource: BindingResource::Buffer(BufferBinding {
91+
buffer: &output_buffer,
92+
offset: 0,
93+
size: Some(NonZeroU64::new(BUFFER_BINDING_SIZE as u64).unwrap()),
94+
}),
95+
}],
96+
});
97+
98+
let pll = ctx
99+
.device
100+
.create_pipeline_layout(&PipelineLayoutDescriptor {
101+
label: None,
102+
bind_group_layouts: &[&bgl],
103+
push_constant_ranges: &[],
104+
});
105+
106+
let sm = ctx
107+
.device
108+
.create_shader_module(include_wgsl!("zero_init_workgroup_mem.wgsl"));
109+
110+
let pipeline_read = ctx
111+
.device
112+
.create_compute_pipeline(&ComputePipelineDescriptor {
113+
label: Some("pipeline read"),
114+
layout: Some(&pll),
115+
module: &sm,
116+
entry_point: "read",
117+
});
118+
119+
let pipeline_write = ctx
120+
.device
121+
.create_compute_pipeline(&ComputePipelineDescriptor {
122+
label: Some("pipeline write"),
123+
layout: None,
124+
module: &sm,
125+
entry_point: "write",
126+
});
127+
128+
// -- Initializing data --
129+
130+
let output_pre_init_data = vec![1; OUTPUT_ARRAY_SIZE as usize];
131+
ctx.queue.write_buffer(
132+
&output_buffer,
133+
0,
134+
bytemuck::cast_slice(&output_pre_init_data),
135+
);
136+
137+
// -- Run test --
138+
139+
let mut encoder = ctx
140+
.device
141+
.create_command_encoder(&CommandEncoderDescriptor::default());
142+
143+
let mut cpass = encoder.begin_compute_pass(&ComputePassDescriptor::default());
144+
145+
cpass.set_pipeline(&pipeline_write);
146+
for _ in 0..NR_OF_DISPATCHES {
147+
cpass.dispatch_workgroups(DISPATCH_SIZE.0, DISPATCH_SIZE.1, DISPATCH_SIZE.2);
148+
}
149+
150+
cpass.set_pipeline(&pipeline_read);
151+
for i in 0..NR_OF_DISPATCHES {
152+
cpass.set_bind_group(0, &bg, &[i * BUFFER_BINDING_SIZE]);
153+
cpass.dispatch_workgroups(DISPATCH_SIZE.0, DISPATCH_SIZE.1, DISPATCH_SIZE.2);
154+
}
155+
drop(cpass);
156+
157+
// -- Pulldown data --
158+
159+
encoder.copy_buffer_to_buffer(&output_buffer, 0, &mapping_buffer, 0, BUFFER_SIZE);
160+
161+
ctx.queue.submit(Some(encoder.finish()));
162+
163+
mapping_buffer.slice(..).map_async(MapMode::Read, |_| ());
164+
ctx.device.poll(Maintain::Wait);
165+
166+
let mapped = mapping_buffer.slice(..).get_mapped_range();
167+
168+
let typed: &[u32] = bytemuck::cast_slice(&*mapped);
169+
170+
// -- Check results --
171+
172+
let num_disptaches_failed = typed.iter().filter(|&&res| res != 0).count();
173+
let ratio = (num_disptaches_failed as f32 / OUTPUT_ARRAY_SIZE as f32) * 100.;
174+
175+
assert!(
176+
num_disptaches_failed == 0,
177+
"Zero-initialization of workgroup memory failed ({:.0}% of disptaches failed).",
178+
ratio
179+
);
180+
181+
drop(mapped);
182+
mapping_buffer.unmap();
183+
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
const array_size = 512u;
2+
3+
struct WStruct {
4+
arr: array<u32, array_size>,
5+
atom: atomic<u32>
6+
}
7+
8+
var<workgroup> w_mem: WStruct;
9+
10+
@group(0) @binding(0)
11+
var<storage, read_write> output: array<u32>;
12+
13+
@compute @workgroup_size(1)
14+
fn read(@builtin(workgroup_id) wgid: vec3<u32>, @builtin(num_workgroups) num_workgroups: vec3<u32>) {
15+
var is_zero = true;
16+
for(var i = 0u; i < array_size; i++) {
17+
is_zero &= w_mem.arr[i] == 0u;
18+
}
19+
is_zero &= atomicLoad(&w_mem.atom) == 0u;
20+
21+
let idx = wgid.x + (wgid.y * num_workgroups.x) + (wgid.z * num_workgroups.x * num_workgroups.y);
22+
output[idx] = u32(!is_zero);
23+
}
24+
25+
@compute @workgroup_size(1)
26+
fn write() {
27+
for(var i = 0u; i < array_size; i++) {
28+
w_mem.arr[i] = i;
29+
}
30+
atomicStore(&w_mem.atom, 3u);
31+
}

0 commit comments

Comments
 (0)