Skip to content

fix: array stride validation errors #273

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
624 changes: 624 additions & 0 deletions crates/rustc_codegen_spirv/src/linker/array_stride_fixer.rs

Large diffs are not rendered by default.

54 changes: 52 additions & 2 deletions crates/rustc_codegen_spirv/src/linker/duplicates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,14 @@ fn gather_names(debug_names: &[Instruction]) -> FxHashMap<Word, String> {
.collect()
}

fn make_dedupe_key(
fn make_dedupe_key_with_array_context(
inst: &Instruction,
unresolved_forward_pointers: &FxHashSet<Word>,
annotations: &FxHashMap<Word, Vec<u32>>,
names: &FxHashMap<Word, String>,
array_contexts: Option<
&FxHashMap<Word, crate::linker::array_stride_fixer::ArrayStorageContext>,
>,
) -> Vec<u32> {
let mut data = vec![inst.class.opcode as u32];

Expand Down Expand Up @@ -169,6 +172,38 @@ fn make_dedupe_key(
}
}

// For array types, include storage class context in the key to prevent
// inappropriate deduplication between different storage class contexts
if let Some(result_id) = inst.result_id {
if matches!(inst.class.opcode, Op::TypeArray | Op::TypeRuntimeArray) {
if let Some(contexts) = array_contexts {
if let Some(context) = contexts.get(&result_id) {
// Include usage pattern in the key so arrays with different contexts won't deduplicate
let usage_pattern_discriminant = match context.usage_pattern {
crate::linker::array_stride_fixer::ArrayUsagePattern::LayoutRequired => {
1u32
}
crate::linker::array_stride_fixer::ArrayUsagePattern::LayoutForbidden => {
2u32
}
crate::linker::array_stride_fixer::ArrayUsagePattern::MixedUsage => 3u32,
crate::linker::array_stride_fixer::ArrayUsagePattern::Unused => 4u32,
};
data.push(usage_pattern_discriminant);

// Also include the specific storage classes for fine-grained differentiation
let mut storage_classes: Vec<u32> = context
.storage_classes
.iter()
.map(|sc| *sc as u32)
.collect();
storage_classes.sort(); // Ensure deterministic ordering
data.extend(storage_classes);
}
}
}
}

data
}

Expand All @@ -185,6 +220,15 @@ fn rewrite_inst_with_rules(inst: &mut Instruction, rules: &FxHashMap<u32, u32>)
}

pub fn remove_duplicate_types(module: &mut Module) {
remove_duplicate_types_with_array_context(module, None);
}

pub fn remove_duplicate_types_with_array_context(
module: &mut Module,
array_contexts: Option<
&FxHashMap<Word, crate::linker::array_stride_fixer::ArrayStorageContext>,
>,
) {
// Keep in mind, this algorithm requires forward type references to not exist - i.e. it's a valid spir-v module.

// When a duplicate type is encountered, then this is a map from the deleted ID, to the new, deduplicated ID.
Expand Down Expand Up @@ -222,7 +266,13 @@ pub fn remove_duplicate_types(module: &mut Module) {
// all_inst_iter_mut pass below. However, the code is a lil bit cleaner this way I guess.
rewrite_inst_with_rules(inst, &rewrite_rules);

let key = make_dedupe_key(inst, &unresolved_forward_pointers, &annotations, &names);
let key = make_dedupe_key_with_array_context(
inst,
&unresolved_forward_pointers,
&annotations,
&names,
array_contexts,
);

match key_to_result_id.entry(key) {
hash_map::Entry::Vacant(entry) => {
Expand Down
7 changes: 7 additions & 0 deletions crates/rustc_codegen_spirv/src/linker/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#[cfg(test)]
mod test;

mod array_stride_fixer;
mod dce;
mod destructure_composites;
mod duplicates;
Expand Down Expand Up @@ -355,6 +356,12 @@ pub fn link(
});
}

// Fix ArrayStride decorations (after storage classes are resolved to avoid conflicts)
{
let _timer = sess.timer("fix_array_stride_decorations");
array_stride_fixer::fix_array_stride_decorations_with_deduplication(&mut output, false);
}

// NOTE(eddyb) with SPIR-T, we can do `mem2reg` before inlining, too!
{
if opts.dce {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Test that ArrayStride decorations are kept for function storage in SPIR-V 1.3

// build-pass
// compile-flags: -C llvm-args=--disassemble-globals
// normalize-stderr-test "OpLine .*\n" -> ""
// normalize-stderr-test "OpSource .*\n" -> ""
// normalize-stderr-test "\S*/lib/rustlib/" -> "$SYSROOT/lib/rustlib/"
// only-spv1.3
use spirv_std::spirv;

#[spirv(compute(threads(1)))]
pub fn main(
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] output: &mut [u32; 1],
) {
// Function storage in SPIR-V 1.3 should keep ArrayStride decorations
let mut function_var: [u32; 256] = [0; 256];
function_var[0] = 42;
function_var[1] = function_var[0] + 1;
// Force the array to be used by writing to output
output[0] = function_var[1];
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
OpCapability Shader
OpCapability Float64
OpCapability Int64
OpCapability Int16
OpCapability Int8
OpCapability ShaderClockKHR
OpExtension "SPV_KHR_shader_clock"
OpMemoryModel Logical Simple
OpEntryPoint GLCompute %1 "main"
OpExecutionMode %1 LocalSize 1 1 1
%2 = OpString "$OPSTRING_FILENAME/function_storage_spirv13_kept.rs"
OpDecorate %4 ArrayStride 4
OpDecorate %5 Block
OpMemberDecorate %5 0 Offset 0
OpDecorate %3 Binding 0
OpDecorate %3 DescriptorSet 0
%6 = OpTypeVoid
%7 = OpTypeFunction %6
%8 = OpTypeInt 32 0
%9 = OpConstant %8 1
%4 = OpTypeArray %8 %9
%10 = OpTypePointer StorageBuffer %4
%5 = OpTypeStruct %4
%11 = OpTypePointer StorageBuffer %5
%3 = OpVariable %11 StorageBuffer
%12 = OpConstant %8 0
%13 = OpTypeBool
%14 = OpConstant %8 256
%15 = OpConstant %8 42
%16 = OpTypePointer StorageBuffer %8
21 changes: 21 additions & 0 deletions tests/ui/linker/array_stride_fixer/mixed_storage_classes.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Test that mixed storage class usage results in proper ArrayStride handling

// compile-flags: -C llvm-args=--disassemble-globals
// only-vulkan1.1
// normalize-stderr-test "OpLine .*\n" -> ""
// normalize-stderr-test "OpSource .*\n" -> ""
// normalize-stderr-test "\S*/lib/rustlib/" -> "$SYSROOT/lib/rustlib/"
use spirv_std::spirv;

#[spirv(compute(threads(64)))]
pub fn main(
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] storage_data: &mut [u32; 256],
#[spirv(workgroup)] workgroup_data: &mut [u32; 256],
) {
// Both variables use the same array type [u32; 256] but in different storage classes:
// - storage_data is in StorageBuffer (requires ArrayStride)
// - workgroup_data is in Workgroup (forbids ArrayStride in SPIR-V 1.4+)

storage_data[0] = 42;
workgroup_data[0] = storage_data[0];
}
36 changes: 36 additions & 0 deletions tests/ui/linker/array_stride_fixer/mixed_storage_classes.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
OpCapability Shader
OpCapability Float64
OpCapability Int64
OpCapability Int16
OpCapability Int8
OpCapability ShaderClockKHR
OpCapability VulkanMemoryModel
OpExtension "SPV_KHR_shader_clock"
OpExtension "SPV_KHR_vulkan_memory_model"
OpMemoryModel Logical Vulkan
OpEntryPoint GLCompute %1 "main"
OpExecutionMode %1 LocalSize 64 1 1
%2 = OpString "$OPSTRING_FILENAME/mixed_storage_classes.rs"
OpName %4 "workgroup_data"
OpDecorate %5 ArrayStride 4
OpDecorate %6 Block
OpMemberDecorate %6 0 Offset 0
OpDecorate %3 Binding 0
OpDecorate %3 DescriptorSet 0
%7 = OpTypeVoid
%8 = OpTypeFunction %7
%9 = OpTypeInt 32 0
%10 = OpConstant %9 256
%5 = OpTypeArray %9 %10
%11 = OpTypePointer StorageBuffer %5
%6 = OpTypeStruct %5
%12 = OpTypePointer StorageBuffer %6
%3 = OpVariable %12 StorageBuffer
%13 = OpConstant %9 0
%14 = OpTypeBool
%15 = OpTypePointer StorageBuffer %9
%16 = OpConstant %9 42
%17 = OpTypePointer Workgroup %9
%18 = OpTypeArray %9 %10
%19 = OpTypePointer Workgroup %18
%4 = OpVariable %19 Workgroup
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Test that ArrayStride decorations are removed from nested structs with arrays in Function storage class

// build-pass
// compile-flags: -C llvm-args=--disassemble-globals
// only-vulkan1.2
// normalize-stderr-test "OpLine .*\n" -> ""
// normalize-stderr-test "OpSource .*\n" -> ""
// normalize-stderr-test "\S*/lib/rustlib/" -> "$SYSROOT/lib/rustlib/"
use spirv_std::spirv;

#[derive(Copy, Clone)]
struct InnerStruct {
data: [f32; 4],
}

#[derive(Copy, Clone)]
struct OuterStruct {
inner: InnerStruct,
}

#[spirv(compute(threads(1)))]
pub fn main(
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] output: &mut [f32; 1],
) {
// Function-local variables with nested structs containing arrays
// Should have ArrayStride removed in SPIR-V 1.4+
let mut function_var = OuterStruct {
inner: InnerStruct { data: [0.0; 4] },
};
function_var.inner.data[0] = 42.0;
function_var.inner.data[1] = function_var.inner.data[0] + 1.0;
// Force usage to prevent optimization
output[0] = function_var.inner.data[1];
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
OpCapability Shader
OpCapability Float64
OpCapability Int64
OpCapability Int16
OpCapability Int8
OpCapability ShaderClockKHR
OpCapability VulkanMemoryModel
OpExtension "SPV_KHR_shader_clock"
OpMemoryModel Logical Vulkan
OpEntryPoint GLCompute %1 "main" %2
OpExecutionMode %1 LocalSize 1 1 1
%3 = OpString "$OPSTRING_FILENAME/nested_structs_function_storage.rs"
OpName %4 "InnerStruct"
OpMemberName %4 0 "data"
OpName %5 "OuterStruct"
OpMemberName %5 0 "inner"
OpDecorate %6 ArrayStride 4
OpDecorate %7 Block
OpMemberDecorate %7 0 Offset 0
OpDecorate %2 Binding 0
OpDecorate %2 DescriptorSet 0
OpMemberDecorate %4 0 Offset 0
OpMemberDecorate %5 0 Offset 0
%8 = OpTypeFloat 32
%9 = OpTypeInt 32 0
%10 = OpConstant %9 1
%6 = OpTypeArray %8 %10
%7 = OpTypeStruct %6
%11 = OpTypePointer StorageBuffer %7
%12 = OpTypeVoid
%13 = OpTypeFunction %12
%14 = OpTypePointer StorageBuffer %6
%2 = OpVariable %11 StorageBuffer
%15 = OpConstant %9 0
%16 = OpConstant %9 4
%17 = OpTypeArray %8 %16
%4 = OpTypeStruct %17
%5 = OpTypeStruct %4
%18 = OpConstant %8 0
%19 = OpConstantComposite %17 %18 %18 %18 %18
%20 = OpUndef %5
%21 = OpTypeBool
%22 = OpConstant %8 1109917696
%23 = OpConstant %8 1065353216
%24 = OpTypePointer StorageBuffer %8
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Test that ArrayStride decorations are removed from private storage in SPIR-V 1.4

// build-pass
// compile-flags: -C llvm-args=--disassemble-globals
// normalize-stderr-test "OpLine .*\n" -> ""
// normalize-stderr-test "OpSource .*\n" -> ""
// normalize-stderr-test "\S*/lib/rustlib/" -> "$SYSROOT/lib/rustlib/"
// only-spv1.4
use spirv_std::spirv;

// Helper function to create an array in private storage
fn create_private_array() -> [u32; 4] {
[0, 1, 2, 3]
}

#[spirv(compute(threads(1)))]
pub fn main() {
// This creates a private storage array in SPIR-V 1.4+
// ArrayStride decorations should be removed
let mut private_array = create_private_array();
private_array[0] = 42;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
OpCapability Shader
OpCapability Float64
OpCapability Int64
OpCapability Int16
OpCapability Int8
OpCapability ShaderClockKHR
OpExtension "SPV_KHR_shader_clock"
OpMemoryModel Logical Simple
OpEntryPoint GLCompute %1 "main"
OpExecutionMode %1 LocalSize 1 1 1
%2 = OpString "$OPSTRING_FILENAME/private_storage_spirv14_removed.rs"
%4 = OpTypeVoid
%5 = OpTypeFunction %4
%6 = OpTypeInt 32 0
%7 = OpConstant %6 4
%8 = OpTypeArray %6 %7
%9 = OpTypeFunction %8
%10 = OpConstant %6 0
%11 = OpConstant %6 1
%12 = OpConstant %6 2
%13 = OpConstant %6 3
%14 = OpTypeBool
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Test that ArrayStride decorations are removed from runtime arrays in Workgroup storage class

// build-pass
// compile-flags: -C llvm-args=--disassemble-globals
// only-vulkan1.1
// normalize-stderr-test "OpLine .*\n" -> ""
// normalize-stderr-test "OpSource .*\n" -> ""
// normalize-stderr-test "\S*/lib/rustlib/" -> "$SYSROOT/lib/rustlib/"
use spirv_std::RuntimeArray;
use spirv_std::spirv;

#[spirv(compute(threads(64)))]
pub fn main(
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] output: &mut [u32; 1],
#[spirv(workgroup)] shared_array: &mut [u32; 256],
) {
// Workgroup arrays should have ArrayStride removed
shared_array[0] = 42;
shared_array[1] = shared_array[0] + 1;
// Force usage to prevent optimization
output[0] = shared_array[1];
}
Loading
Loading