Skip to content

Commit 8988e47

Browse files
authored
[SER] Diagnose payload in HitObject::TraceRay|Invoke (#7356)
- Generalize raypayload validation to HitObject::TraceRay|Invoke - Reject non-numeric payload types in [HitObject::]TraceRay|Invoke Specification: https://github.com/microsoft/hlsl-specs/blob/main/proposals/0027-shader-execution-reordering.md Bug: #7234 [SER] Diagnose and validate illegal use of HitObject in unsupported contexts
1 parent b4a3076 commit 8988e47

File tree

3 files changed

+133
-57
lines changed

3 files changed

+133
-57
lines changed

tools/clang/lib/Sema/SemaDXR.cpp

Lines changed: 84 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
#include "dxc/DXIL/DxilConstants.h"
3030
#include "dxc/DXIL/DxilShaderModel.h"
31+
#include "dxc/HlslIntrinsicOp.h"
3132

3233
using namespace clang;
3334
using namespace sema;
@@ -49,9 +50,9 @@ struct PayloadUse {
4950
const MemberExpr *Member = nullptr;
5051
};
5152

52-
struct TraceRayCall {
53-
TraceRayCall() = default;
54-
TraceRayCall(const CallExpr *Call, const CFGBlock *Parent)
53+
struct PayloadBuiltinCall {
54+
PayloadBuiltinCall() = default;
55+
PayloadBuiltinCall(const CallExpr *Call, const CFGBlock *Parent)
5556
: Call(Call), Parent(Parent) {}
5657
const CallExpr *Call = nullptr;
5758
const CFGBlock *Parent = nullptr;
@@ -71,7 +72,7 @@ struct DxrShaderDiagnoseInfo {
7172
const FunctionDecl *funcDecl;
7273
const VarDecl *Payload;
7374
DXIL::PayloadAccessShaderStage Stage;
74-
std::vector<TraceRayCall> TraceCalls;
75+
std::vector<PayloadBuiltinCall> PayloadBuiltinCalls;
7576
std::map<const FieldDecl *, std::vector<PayloadUse>> WritesPerField;
7677
std::map<const FieldDecl *, std::vector<PayloadUse>> ReadsPerField;
7778
std::vector<PayloadUse> PayloadAsCallArg;
@@ -121,24 +122,42 @@ GetPayloadQualifierForStage(FieldDecl *Field,
121122
return DXIL::PayloadAccessQualifier::NoAccess;
122123
}
123124

124-
// Returns the declaration of the payload used in a TraceRay call
125-
const VarDecl *GetPayloadParameterForTraceCall(const CallExpr *Trace) {
126-
const Decl *callee = Trace->getCalleeDecl();
127-
if (!callee)
125+
static int GetPayloadParamIdxForIntrinsic(const FunctionDecl *FD) {
126+
HLSLIntrinsicAttr *IntrinAttr = FD->getAttr<HLSLIntrinsicAttr>();
127+
if (!IntrinAttr)
128+
return -1;
129+
switch ((IntrinsicOp)IntrinAttr->getOpcode()) {
130+
default:
131+
return -1;
132+
case IntrinsicOp::IOP_TraceRay:
133+
case IntrinsicOp::MOP_DxHitObject_TraceRay:
134+
case IntrinsicOp::MOP_DxHitObject_Invoke:
135+
return FD->getNumParams() - 1;
136+
}
137+
}
138+
139+
static bool IsBuiltinWithPayload(const FunctionDecl *FD) {
140+
return GetPayloadParamIdxForIntrinsic(FD) >= 0;
141+
}
142+
143+
// Returns the declaration of the payload used in a call to TraceRay,
144+
// HitObject::TraceRay or HitObject::Invoke.
145+
const VarDecl *GetPayloadParameterForBuiltinCall(const CallExpr *Call) {
146+
const Decl *Callee = Call->getCalleeDecl();
147+
if (!Callee)
128148
return nullptr;
129149

130-
if (!isa<FunctionDecl>(callee))
150+
if (!isa<FunctionDecl>(Callee))
131151
return nullptr;
132152

133-
const FunctionDecl *FD = cast<FunctionDecl>(callee);
153+
int PldParamIdx = GetPayloadParamIdxForIntrinsic(cast<FunctionDecl>(Callee));
154+
if (PldParamIdx < 0)
155+
return nullptr;
134156

135-
if (FD->isImplicit() && FD->getName() == "TraceRay") {
136-
const Stmt *Param = IgnoreParensAndDecay(Trace->getArg(7));
137-
if (const DeclRefExpr *ParamRef = dyn_cast<DeclRefExpr>(Param)) {
138-
if (const VarDecl *Decl = dyn_cast<VarDecl>(ParamRef->getDecl()))
139-
return Decl;
140-
}
141-
}
157+
const Stmt *Param = IgnoreParensAndDecay(Call->getArg(PldParamIdx));
158+
if (const DeclRefExpr *ParamRef = dyn_cast<DeclRefExpr>(Param))
159+
if (const VarDecl *Decl = dyn_cast<VarDecl>(ParamRef->getDecl()))
160+
return Decl;
142161
return nullptr;
143162
}
144163

@@ -190,12 +209,9 @@ void CollectReadsWritesAndCallsForPayload(const Stmt *S,
190209
}
191210
}
192211

193-
// Collects all TraceRay calls.
194-
void CollectTraceRayCalls(const Stmt *S, DxrShaderDiagnoseInfo &Info,
195-
const CFGBlock *Block) {
196-
// TraceRay has void as return type so it should never be something else
197-
// than a plain CallExpr.
198-
212+
// Collects all calls to TraceRay, HitObject::TraceRay and HitObject::Invoke.
213+
void CollectBuiltinCallsWithPayload(const Stmt *S, DxrShaderDiagnoseInfo &Info,
214+
const CFGBlock *Block) {
199215
if (const CallExpr *Call = dyn_cast<CallExpr>(S)) {
200216

201217
const Decl *Callee = Call->getCalleeDecl();
@@ -204,11 +220,8 @@ void CollectTraceRayCalls(const Stmt *S, DxrShaderDiagnoseInfo &Info,
204220

205221
const FunctionDecl *CalledFunction = cast<FunctionDecl>(Callee);
206222

207-
// Ignore trace calls here.
208-
if (CalledFunction->isImplicit() &&
209-
CalledFunction->getName() == "TraceRay") {
210-
Info.TraceCalls.push_back({Call, Block});
211-
}
223+
if (IsBuiltinWithPayload(CalledFunction))
224+
Info.PayloadBuiltinCalls.push_back({Call, Block});
212225
}
213226
}
214227

@@ -528,13 +541,14 @@ void TraverseCFG(const CFGBlock &Block, Action PerElementAction,
528541
}
529542
}
530543

531-
// Forward traverse the CFG and collect calls to TraceRay.
532-
void ForwardTraverseCFGAndCollectTraceCalls(
544+
// Forward traverse the CFG and collect calls to TraceRay, HitObject::TraceRay
545+
// and HitObject::Invoke.
546+
void ForwardTraverseCFGAndCollectBuiltinCallsWithPayload(
533547
const CFGBlock &Block, DxrShaderDiagnoseInfo &Info,
534548
std::set<const CFGBlock *> &Visited) {
535549
auto Action = [&Info](const CFGBlock &Block, const CFGElement &Element) {
536550
if (Optional<CFGStmt> S = Element.getAs<CFGStmt>()) {
537-
CollectTraceRayCalls(S->getStmt(), Info, &Block);
551+
CollectBuiltinCallsWithPayload(S->getStmt(), Info, &Block);
538552
}
539553
};
540554

@@ -664,9 +678,9 @@ DiagnosePayloadAsFunctionArg(
664678
const FunctionDecl *CalledFunction = cast<FunctionDecl>(Callee);
665679

666680
// Ignore trace calls here.
667-
if (CalledFunction->isImplicit() &&
668-
CalledFunction->getName() == "TraceRay") {
669-
Info.TraceCalls.push_back(TraceRayCall{Call, Use.Parent});
681+
if (IsBuiltinWithPayload(CalledFunction)) {
682+
Info.PayloadBuiltinCalls.push_back(
683+
PayloadBuiltinCall{Call, Use.Parent});
670684
continue;
671685
}
672686

@@ -789,10 +803,12 @@ void HandlePayloadInitializer(DxrShaderDiagnoseInfo &Info) {
789803
}
790804
}
791805

792-
// Emit diagnostics for a TraceRay call.
793-
void DiagnoseTraceCall(Sema &S, const VarDecl *Payload,
794-
const TraceRayCall &Trace, DominatorTree &DT) {
795-
// For each TraceRay call check if write(caller) fields are written.
806+
// Emit diagnostics for this call to either TraceRay, HitObject::TraceRay or
807+
// HitObject::Invoke.
808+
void DiagnoseBuiltinCallWithPayload(Sema &S, const VarDecl *Payload,
809+
const PayloadBuiltinCall &PldCall,
810+
DominatorTree &DT) {
811+
// For each call check if write(caller) fields are written.
796812
const DXIL::PayloadAccessShaderStage CallerStage =
797813
DXIL::PayloadAccessShaderStage::Caller;
798814

@@ -810,6 +826,13 @@ void DiagnoseTraceCall(Sema &S, const VarDecl *Payload,
810826
return;
811827
}
812828

829+
// Verify that the payload type is legal
830+
if (!hlsl::IsHLSLCopyableAnnotatableRecord(Payload->getType())) {
831+
S.Diag(Payload->getLocation(), diag::err_payload_attrs_must_be_udt)
832+
<< /*payload|attributes|callable*/ 0 << Payload;
833+
return;
834+
}
835+
813836
if (ContainsLongVector(Payload->getType())) {
814837
const unsigned PayloadParametersIdx = 10;
815838
S.Diag(Payload->getLocation(), diag::err_hlsl_unsupported_long_vector)
@@ -832,12 +855,12 @@ void DiagnoseTraceCall(Sema &S, const VarDecl *Payload,
832855

833856
std::set<const CFGBlock *> Visited;
834857

835-
const CFGBlock *Parent = Trace.Parent;
858+
const CFGBlock *Parent = PldCall.Parent;
836859
Visited.insert(Parent);
837-
// Collect payload accesses in the same block until we reach the TraceRay call
860+
// Collect payload accesses in the same block until we reach the call
838861
for (auto Element : *Parent) {
839862
if (Optional<CFGStmt> S = Element.getAs<CFGStmt>()) {
840-
if (S->getStmt() == Trace.Call)
863+
if (S->getStmt() == PldCall.Call)
841864
break;
842865
CollectReadsWritesAndCallsForPayload(S->getStmt(), TraceInfo, Parent);
843866
}
@@ -850,10 +873,12 @@ void DiagnoseTraceCall(Sema &S, const VarDecl *Payload,
850873
BackwardTraverseCFGAndCollectReadsWrites(*Pred, TraceInfo, Visited);
851874
}
852875

876+
int PldArgIdx = PldCall.Call->getNumArgs() - 1;
877+
853878
// Warn if a writeable field has not been written.
854879
for (const FieldDecl *Field : WriteableFields) {
855880
if (!TraceInfo.WritesPerField.count(Field)) {
856-
S.Diag(Trace.Call->getArg(7)->getExprLoc(),
881+
S.Diag(PldCall.Call->getArg(PldArgIdx)->getExprLoc(),
857882
diag::warn_hlsl_payload_access_no_write_for_trace_payload)
858883
<< Field->getName();
859884
}
@@ -862,7 +887,7 @@ void DiagnoseTraceCall(Sema &S, const VarDecl *Payload,
862887
for (const FieldDecl *Field : NonWriteableFields) {
863888
if (TraceInfo.WritesPerField.count(Field)) {
864889
S.Diag(
865-
Trace.Call->getArg(7)->getExprLoc(),
890+
PldCall.Call->getArg(PldArgIdx)->getExprLoc(),
866891
diag::warn_hlsl_payload_access_write_but_no_write_for_trace_payload)
867892
<< Field->getName();
868893
}
@@ -878,7 +903,7 @@ void DiagnoseTraceCall(Sema &S, const VarDecl *Payload,
878903
bool CallFound = false;
879904
for (auto Element : *Parent) { // TODO: reverse iterate?
880905
if (Optional<CFGStmt> S = Element.getAs<CFGStmt>()) {
881-
if (S->getStmt() == Trace.Call) {
906+
if (S->getStmt() == PldCall.Call) {
882907
CallFound = true;
883908
continue;
884909
}
@@ -895,7 +920,7 @@ void DiagnoseTraceCall(Sema &S, const VarDecl *Payload,
895920

896921
for (const FieldDecl *Field : ReadableFields) {
897922
if (!TraceInfo.ReadsPerField.count(Field)) {
898-
S.Diag(Trace.Call->getArg(7)->getExprLoc(),
923+
S.Diag(PldCall.Call->getArg(PldArgIdx)->getExprLoc(),
899924
diag::warn_hlsl_payload_access_read_but_no_read_after_trace)
900925
<< Field->getName();
901926
}
@@ -928,27 +953,29 @@ void DiagnoseTraceCall(Sema &S, const VarDecl *Payload,
928953
}
929954
}
930955

931-
// Emit diagnostics for all TraceRay calls.
932-
void DiagnoseTraceCalls(Sema &S, CFG &ShaderCFG, DominatorTree &DT,
933-
DxrShaderDiagnoseInfo &Info) {
934-
// Collect TraceRay calls in the shader.
956+
// Emit diagnostics for all calls to TraceRay, HitObject::TraceRay or
957+
// HitObject::Invoke.
958+
void DiagnoseBuiltinCallsWithPayload(Sema &S, CFG &ShaderCFG, DominatorTree &DT,
959+
DxrShaderDiagnoseInfo &Info) {
960+
// Collect calls with payload in the shader.
935961
std::set<const CFGBlock *> Visited;
936-
ForwardTraverseCFGAndCollectTraceCalls(ShaderCFG.getEntry(), Info, Visited);
962+
ForwardTraverseCFGAndCollectBuiltinCallsWithPayload(ShaderCFG.getEntry(),
963+
Info, Visited);
937964

938965
std::set<const CallExpr *> Diagnosed;
939966

940-
for (const TraceRayCall &TraceCall : Info.TraceCalls) {
941-
if (Diagnosed.count(TraceCall.Call))
967+
for (const PayloadBuiltinCall &PldCall : Info.PayloadBuiltinCalls) {
968+
if (Diagnosed.count(PldCall.Call))
942969
continue;
943-
Diagnosed.insert(TraceCall.Call);
970+
Diagnosed.insert(PldCall.Call);
944971

945-
const VarDecl *Payload = GetPayloadParameterForTraceCall(TraceCall.Call);
946-
DiagnoseTraceCall(S, Payload, TraceCall, DT);
972+
const VarDecl *Payload = GetPayloadParameterForBuiltinCall(PldCall.Call);
973+
DiagnoseBuiltinCallWithPayload(S, Payload, PldCall, DT);
947974
}
948975
}
949976

950977
// Emit diagnostics for all access to the payload of a shader,
951-
// and the input to TraceRay calls.
978+
// and the input to TraceRay, HitObject::TraceRay or HitObject::Invoke calls.
952979
std::vector<const FieldDecl *>
953980
DiagnosePayloadAccess(Sema &S, DxrShaderDiagnoseInfo &Info,
954981
const std::set<const FieldDecl *> &FieldsToIgnoreRead,
@@ -1012,7 +1039,7 @@ DiagnosePayloadAccess(Sema &S, DxrShaderDiagnoseInfo &Info,
10121039
DiagnosePayloadReads(S, TheCFG, DT, Info, NonReadableFields);
10131040
}
10141041

1015-
DiagnoseTraceCalls(S, TheCFG, DT, Info);
1042+
DiagnoseBuiltinCallsWithPayload(S, TheCFG, DT, Info);
10161043

10171044
return WrittenFields;
10181045
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// RUN: %dxc -T lib_6_9 %s -D TEST_NUM=0 %s -verify
2+
// RUN: %dxc -T lib_6_9 %s -D TEST_NUM=1 %s -verify
3+
4+
RaytracingAccelerationStructure scene : register(t0);
5+
6+
struct Payload
7+
{
8+
int a : read (caller, closesthit, miss) : write(caller, closesthit, miss);
9+
};
10+
11+
struct Attribs
12+
{
13+
float2 barys;
14+
};
15+
16+
[shader("raygeneration")]
17+
void RayGen()
18+
{
19+
// expected-error@+1{{type 'Payload' used as payload requires that it is annotated with the [raypayload] attribute}}
20+
Payload payload_in_rg;
21+
RayDesc ray;
22+
#if TEST_NUM == 0
23+
dx::HitObject::TraceRay( scene, RAY_FLAG_NONE, 0xff, 0, 1, 0, ray, payload_in_rg );
24+
#else
25+
dx::HitObject::Invoke( dx::HitObject(), payload_in_rg );
26+
#endif
27+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// RUN: %dxc -T lib_6_9 %s -verify
2+
3+
struct
4+
[raypayload]
5+
Payload
6+
{
7+
int a : read(caller, closesthit, miss) : write(caller, closesthit, miss);
8+
dx::HitObject hit;
9+
};
10+
11+
struct Attribs
12+
{
13+
float2 barys;
14+
};
15+
16+
[shader("raygeneration")]
17+
void RayGen()
18+
{
19+
// expected-error@+1{{payload parameter 'payload_in_rg' must be a user-defined type composed of only numeric types}}
20+
Payload payload_in_rg;
21+
dx::HitObject::Invoke( dx::HitObject(), payload_in_rg );
22+
}

0 commit comments

Comments
 (0)