Skip to content

Commit 31a2f58

Browse files
authored
[SPIR-V] Fix usage of indices in subfunctions (#7242)
The parameters tagged with `indices` are linked to a builtin. Because their layout is different between HLSL and SPIR-V, there is a common mechanism to handle those 'stage I/O variables'. Usually, a local variable with the correct HLSL layout is created, and when required, the value is copied in and copied out in the entrypoint wrapper. Then, a function-scoped pointer is passed to sub-functions. The issue is that `indices` marks an array which is also shared across invocations, meaning we cannot simply copy-in/copy-out: we are only allowed to write to the indices touched by the shader. This required pushing the handling into the assignment expression handling: when a value is assigned to such a builtin, the layout transformation is done, and the builtin is written to. The remaining issue was how to find the builtin again from an assignment: the code assumed the ParmDecl of the entrypoint was the only way to access this variable, but nothing prevents the user from passing this indices array to another function. The simple solution is to move this out of the generic map, and have a new field which stores the SpirvVariable we created, and allow any HLSL function to access it as soon as the HLSLIndices attribute is found. Fixes #7009 --------- Signed-off-by: Nathan Gauër <[email protected]>
1 parent d8aad78 commit 31a2f58

File tree

4 files changed

+108
-12
lines changed

4 files changed

+108
-12
lines changed

tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -860,7 +860,7 @@ bool DeclResultIdMapper::createStageOutputVar(const DeclaratorDecl *decl,
860860
QualType arrayType = astContext.getConstantArrayType(
861861
type, llvm::APInt(32, arraySize), clang::ArrayType::Normal, 0);
862862

863-
stageVarInstructions[cast<DeclaratorDecl>(decl)] =
863+
msOutIndicesBuiltin =
864864
getBuiltinVar(builtinID, arrayType, decl->getLocation());
865865
} else {
866866
// For NV_mesh_shader, the built type is PrimitiveIndicesNV
@@ -871,7 +871,7 @@ bool DeclResultIdMapper::createStageOutputVar(const DeclaratorDecl *decl,
871871
astContext.UnsignedIntTy, llvm::APInt(32, arraySize),
872872
clang::ArrayType::Normal, 0);
873873

874-
stageVarInstructions[cast<DeclaratorDecl>(decl)] =
874+
msOutIndicesBuiltin =
875875
getBuiltinVar(builtinID, arrayType, decl->getLocation());
876876
}
877877

tools/clang/lib/SPIRV/DeclResultIdMapper.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,11 @@ class DeclResultIdMapper {
559559
return value;
560560
}
561561

562+
SpirvVariable *getMSOutIndicesBuiltin() {
563+
assert(msOutIndicesBuiltin && "Variable usage before decl parsing.");
564+
return msOutIndicesBuiltin;
565+
}
566+
562567
/// Decorate with spirv intrinsic attributes with lambda function variable
563568
/// check
564569
void decorateWithIntrinsicAttrs(
@@ -1014,6 +1019,25 @@ class DeclResultIdMapper {
10141019
/// creating that stage variable, so that we don't need to query them again
10151020
/// for reading and writing.
10161021
llvm::DenseMap<const ValueDecl *, SpirvVariable *> stageVarInstructions;
1022+
1023+
/// Special case for the Indices builtin:
1024+
/// - this builtin has a different layout in HLSL & SPIR-V, meaning it
1025+
/// requires
1026+
/// the same kind of handling as classic stageVarInstructions:
1027+
/// -> load into a HLSL compatible tmp
1028+
/// -> write back into the SPIR-V compatible layout.
1029+
/// - but the builtin is shared across invocations (not only lanes).
1030+
/// -> we must only write/read from the indices requested by the user.
1031+
/// - the variable can be passed to other functions as an out param
1032+
/// -> we cannot copy-in/copy-out because shared across invocations.
1033+
/// -> we cannot pass a simple pointer: layout differences between
1034+
/// HLSL/SPIR-V.
1035+
///
1036+
/// All this means we must keep track of the builtin, and each assignment to
1037+
/// this will have to handle the layout differences. The easiest solution is
1038+
/// to keep this builtin global to the module if present.
1039+
SpirvVariable *msOutIndicesBuiltin = nullptr;
1040+
10171041
/// Vector of all defined resource variables.
10181042
llvm::SmallVector<ResourceVar, 8> resourceVars;
10191043
/// Mapping from {RW|Append|Consume}StructuredBuffers to their

tools/clang/lib/SPIRV/SpirvEmitter.cpp

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8133,17 +8133,21 @@ void SpirvEmitter::assignToMSOutIndices(
81338133
if (indices.size() > 1) {
81348134
vecComponent = indices.back();
81358135
}
8136-
auto *var = declIdMapper.getStageVarInstruction(decl);
8137-
const auto *varTypeDecl = astContext.getAsConstantArrayType(decl->getType());
8138-
QualType varType = varTypeDecl->getElementType();
8136+
SpirvVariable *var = declIdMapper.getMSOutIndicesBuiltin();
8137+
81398138
uint32_t numVertices = 1;
8140-
if (!isVectorType(varType, nullptr, &numVertices)) {
8141-
assert(isScalarType(varType));
8142-
}
8143-
QualType valueType = value->getAstResultType();
81448139
uint32_t numValues = 1;
8145-
if (!isVectorType(valueType, nullptr, &numValues)) {
8146-
assert(isScalarType(valueType));
8140+
{
8141+
const auto *varTypeDecl =
8142+
astContext.getAsConstantArrayType(decl->getType());
8143+
QualType varType = varTypeDecl->getElementType();
8144+
if (!isVectorType(varType, nullptr, &numVertices)) {
8145+
assert(isScalarType(varType));
8146+
}
8147+
QualType valueType = value->getAstResultType();
8148+
if (!isVectorType(valueType, nullptr, &numValues)) {
8149+
assert(isScalarType(valueType));
8150+
}
81478151
}
81488152

81498153
const auto loc = decl->getLocation();
@@ -8190,7 +8194,10 @@ void SpirvEmitter::assignToMSOutIndices(
81908194
assert(numValues == numVertices);
81918195
if (extMesh) {
81928196
// create accesschain for Primitive*IndicesEXT[vertIndex].
8193-
auto *ptr = spvBuilder.createAccessChain(varType, var, vertIndex, loc);
8197+
const ConstantArrayType *CAT =
8198+
astContext.getAsConstantArrayType(var->getAstResultType());
8199+
auto *ptr = spvBuilder.createAccessChain(CAT->getElementType(), var,
8200+
vertIndex, loc);
81948201
// finally create store for Primitive*IndicesEXT[vertIndex] = value.
81958202
spvBuilder.createStore(ptr, value, loc);
81968203
} else {
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
// RUN: %dxc -T ms_6_5 -E outie -fcgl %s -spirv | FileCheck %s
2+
// RUN: %dxc -T ms_6_5 -E innie -fcgl %s -spirv | FileCheck %s
3+
4+
// CHECK-DAG: [[v4_n05_05_0_1:%[0-9]+]] = OpConstantComposite %v4float %float_n0_5 %float_0_5 %float_0 %float_1
5+
// CHECK-DAG: [[v4_05_05_0_1:%[0-9]+]] = OpConstantComposite %v4float %float_0_5 %float_0_5 %float_0 %float_1
6+
// CHECK-DAG: [[v4_0_n05_0_1:%[0-9]+]] = OpConstantComposite %v4float %float_0 %float_n0_5 %float_0 %float_1
7+
// CHECK-DAG: [[v3_1_0_0:%[0-9]+]] = OpConstantComposite %v3float %float_1 %float_0 %float_0
8+
// CHECK-DAG: [[v3_0_1_0:%[0-9]+]] = OpConstantComposite %v3float %float_0 %float_1 %float_0
9+
// CHECK-DAG: [[v3_0_0_1:%[0-9]+]] = OpConstantComposite %v3float %float_0 %float_0 %float_1
10+
// CHECK-DAG: [[u3_0_1_2:%[0-9]+]] = OpConstantComposite %v3uint %uint_0 %uint_1 %uint_2
11+
12+
// CHECK-DAG: OpDecorate [[indices:%[0-9]+]] BuiltIn PrimitiveIndicesNV
13+
14+
struct MeshOutput {
15+
float4 position : SV_Position;
16+
float3 color : COLOR0;
17+
};
18+
19+
[outputtopology("triangle")]
20+
[numthreads(1, 1, 1)]
21+
void innie(out indices uint3 triangles[1], out vertices MeshOutput verts[3]) {
22+
SetMeshOutputCounts(3, 2);
23+
24+
triangles[0] = uint3(0, 1, 2);
25+
// CHECK: [[off:%[0-9]+]] = OpIMul %uint %uint_0 %uint_3
26+
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Output_uint [[indices]] [[off]]
27+
// CHECK: [[tmp:%[0-9]+]] = OpCompositeExtract %uint [[u3_0_1_2]] 0
28+
// CHECK: OpStore [[ptr]] [[tmp]]
29+
// CHECK: [[idx:%[0-9]+]] = OpIAdd %uint [[off]] %uint_1
30+
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Output_uint [[indices]] [[idx]]
31+
// CHECK: [[tmp:%[0-9]+]] = OpCompositeExtract %uint [[u3_0_1_2]] 1
32+
// CHECK: OpStore [[ptr]] [[tmp]]
33+
// CHECK: [[idx:%[0-9]+]] = OpIAdd %uint [[off]] %uint_2
34+
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Output_uint [[indices]] [[idx]]
35+
// CHECK: [[tmp:%[0-9]+]] = OpCompositeExtract %uint [[u3_0_1_2]] 2
36+
// CHECK: OpStore [[ptr]] [[tmp]]
37+
38+
verts[0].position = float4(-0.5, 0.5, 0.0, 1.0);
39+
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Output_v4float %gl_Position %int_0
40+
// CHECK: OpStore [[ptr]] [[v4_n05_05_0_1]]
41+
verts[0].color = float3(1.0, 0.0, 0.0);
42+
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Output_v3float %out_var_COLOR0 %int_0
43+
// CHECK: OpStore [[ptr]] [[v3_1_0_0]]
44+
45+
verts[1].position = float4(0.5, 0.5, 0.0, 1.0);
46+
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Output_v4float %gl_Position %int_1
47+
// CHECK: OpStore [[ptr]] [[v4_05_05_0_1]]
48+
verts[1].color = float3(0.0, 1.0, 0.0);
49+
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Output_v3float %out_var_COLOR0 %int_1
50+
// CHECK: OpStore [[ptr]] [[v3_0_1_0]]
51+
52+
verts[2].position = float4(0.0, -0.5, 0.0, 1.0);
53+
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Output_v4float %gl_Position %int_2
54+
// CHECK: OpStore [[ptr]] [[v4_0_n05_0_1]]
55+
verts[2].color = float3(0.0, 0.0, 1.0);
56+
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Output_v3float %out_var_COLOR0 %int_2
57+
// CHECK: OpStore [[ptr]] [[v3_0_0_1]]
58+
59+
}
60+
61+
[outputtopology("triangle")]
62+
[numthreads(1, 1, 1)]
63+
void outie(out indices uint3 triangles[1], out vertices MeshOutput verts[3]) {
64+
innie(triangles, verts);
65+
}

0 commit comments

Comments
 (0)