28
28
29
29
#include " dxc/DXIL/DxilConstants.h"
30
30
#include " dxc/DXIL/DxilShaderModel.h"
31
+ #include " dxc/HlslIntrinsicOp.h"
31
32
32
33
using namespace clang ;
33
34
using namespace sema ;
@@ -49,9 +50,9 @@ struct PayloadUse {
49
50
const MemberExpr *Member = nullptr ;
50
51
};
51
52
52
- struct TraceRayCall {
53
- TraceRayCall () = default ;
54
- TraceRayCall (const CallExpr *Call, const CFGBlock *Parent)
53
+ struct PayloadBuiltinCall {
54
+ PayloadBuiltinCall () = default ;
55
+ PayloadBuiltinCall (const CallExpr *Call, const CFGBlock *Parent)
55
56
: Call(Call), Parent(Parent) {}
56
57
const CallExpr *Call = nullptr ;
57
58
const CFGBlock *Parent = nullptr ;
@@ -71,7 +72,7 @@ struct DxrShaderDiagnoseInfo {
71
72
const FunctionDecl *funcDecl;
72
73
const VarDecl *Payload;
73
74
DXIL::PayloadAccessShaderStage Stage;
74
- std::vector<TraceRayCall> TraceCalls ;
75
+ std::vector<PayloadBuiltinCall> PayloadBuiltinCalls ;
75
76
std::map<const FieldDecl *, std::vector<PayloadUse>> WritesPerField;
76
77
std::map<const FieldDecl *, std::vector<PayloadUse>> ReadsPerField;
77
78
std::vector<PayloadUse> PayloadAsCallArg;
@@ -121,24 +122,42 @@ GetPayloadQualifierForStage(FieldDecl *Field,
121
122
return DXIL::PayloadAccessQualifier::NoAccess;
122
123
}
123
124
124
- // Returns the declaration of the payload used in a TraceRay call
125
- const VarDecl *GetPayloadParameterForTraceCall (const CallExpr *Trace) {
126
- const Decl *callee = Trace->getCalleeDecl ();
127
- if (!callee)
125
+ static int GetPayloadParamIdxForIntrinsic (const FunctionDecl *FD) {
126
+ HLSLIntrinsicAttr *IntrinAttr = FD->getAttr <HLSLIntrinsicAttr>();
127
+ if (!IntrinAttr)
128
+ return -1 ;
129
+ switch ((IntrinsicOp)IntrinAttr->getOpcode ()) {
130
+ default :
131
+ return -1 ;
132
+ case IntrinsicOp::IOP_TraceRay:
133
+ case IntrinsicOp::MOP_DxHitObject_TraceRay:
134
+ case IntrinsicOp::MOP_DxHitObject_Invoke:
135
+ return FD->getNumParams () - 1 ;
136
+ }
137
+ }
138
+
139
+ static bool IsBuiltinWithPayload (const FunctionDecl *FD) {
140
+ return GetPayloadParamIdxForIntrinsic (FD) >= 0 ;
141
+ }
142
+
143
+ // Returns the declaration of the payload used in a call to TraceRay,
144
+ // HitObject::TraceRay or HitObject::Invoke.
145
+ const VarDecl *GetPayloadParameterForBuiltinCall (const CallExpr *Call) {
146
+ const Decl *Callee = Call->getCalleeDecl ();
147
+ if (!Callee)
128
148
return nullptr ;
129
149
130
- if (!isa<FunctionDecl>(callee ))
150
+ if (!isa<FunctionDecl>(Callee ))
131
151
return nullptr ;
132
152
133
- const FunctionDecl *FD = cast<FunctionDecl>(callee);
153
+ int PldParamIdx = GetPayloadParamIdxForIntrinsic (cast<FunctionDecl>(Callee));
154
+ if (PldParamIdx < 0 )
155
+ return nullptr ;
134
156
135
- if (FD->isImplicit () && FD->getName () == " TraceRay" ) {
136
- const Stmt *Param = IgnoreParensAndDecay (Trace->getArg (7 ));
137
- if (const DeclRefExpr *ParamRef = dyn_cast<DeclRefExpr>(Param)) {
138
- if (const VarDecl *Decl = dyn_cast<VarDecl>(ParamRef->getDecl ()))
139
- return Decl;
140
- }
141
- }
157
+ const Stmt *Param = IgnoreParensAndDecay (Call->getArg (PldParamIdx));
158
+ if (const DeclRefExpr *ParamRef = dyn_cast<DeclRefExpr>(Param))
159
+ if (const VarDecl *Decl = dyn_cast<VarDecl>(ParamRef->getDecl ()))
160
+ return Decl;
142
161
return nullptr ;
143
162
}
144
163
@@ -190,12 +209,9 @@ void CollectReadsWritesAndCallsForPayload(const Stmt *S,
190
209
}
191
210
}
192
211
193
- // Collects all TraceRay calls.
194
- void CollectTraceRayCalls (const Stmt *S, DxrShaderDiagnoseInfo &Info,
195
- const CFGBlock *Block) {
196
- // TraceRay has void as return type so it should never be something else
197
- // than a plain CallExpr.
198
-
212
+ // Collects all calls to TraceRay, HitObject::TraceRay and HitObject::Invoke.
213
+ void CollectBuiltinCallsWithPayload (const Stmt *S, DxrShaderDiagnoseInfo &Info,
214
+ const CFGBlock *Block) {
199
215
if (const CallExpr *Call = dyn_cast<CallExpr>(S)) {
200
216
201
217
const Decl *Callee = Call->getCalleeDecl ();
@@ -204,11 +220,8 @@ void CollectTraceRayCalls(const Stmt *S, DxrShaderDiagnoseInfo &Info,
204
220
205
221
const FunctionDecl *CalledFunction = cast<FunctionDecl>(Callee);
206
222
207
- // Ignore trace calls here.
208
- if (CalledFunction->isImplicit () &&
209
- CalledFunction->getName () == " TraceRay" ) {
210
- Info.TraceCalls .push_back ({Call, Block});
211
- }
223
+ if (IsBuiltinWithPayload (CalledFunction))
224
+ Info.PayloadBuiltinCalls .push_back ({Call, Block});
212
225
}
213
226
}
214
227
@@ -528,13 +541,14 @@ void TraverseCFG(const CFGBlock &Block, Action PerElementAction,
528
541
}
529
542
}
530
543
531
- // Forward traverse the CFG and collect calls to TraceRay.
532
- void ForwardTraverseCFGAndCollectTraceCalls (
544
+ // Forward traverse the CFG and collect calls to TraceRay, HitObject::TraceRay
545
+ // and HitObject::Invoke.
546
+ void ForwardTraverseCFGAndCollectBuiltinCallsWithPayload (
533
547
const CFGBlock &Block, DxrShaderDiagnoseInfo &Info,
534
548
std::set<const CFGBlock *> &Visited) {
535
549
auto Action = [&Info](const CFGBlock &Block, const CFGElement &Element) {
536
550
if (Optional<CFGStmt> S = Element.getAs <CFGStmt>()) {
537
- CollectTraceRayCalls (S->getStmt (), Info, &Block);
551
+ CollectBuiltinCallsWithPayload (S->getStmt (), Info, &Block);
538
552
}
539
553
};
540
554
@@ -664,9 +678,9 @@ DiagnosePayloadAsFunctionArg(
664
678
const FunctionDecl *CalledFunction = cast<FunctionDecl>(Callee);
665
679
666
680
// Ignore trace calls here.
667
- if (CalledFunction-> isImplicit () &&
668
- CalledFunction-> getName () == " TraceRay " ) {
669
- Info. TraceCalls . push_back (TraceRayCall {Call, Use.Parent });
681
+ if (IsBuiltinWithPayload (CalledFunction)) {
682
+ Info. PayloadBuiltinCalls . push_back (
683
+ PayloadBuiltinCall {Call, Use.Parent });
670
684
continue ;
671
685
}
672
686
@@ -789,10 +803,12 @@ void HandlePayloadInitializer(DxrShaderDiagnoseInfo &Info) {
789
803
}
790
804
}
791
805
792
- // Emit diagnostics for a TraceRay call.
793
- void DiagnoseTraceCall (Sema &S, const VarDecl *Payload,
794
- const TraceRayCall &Trace, DominatorTree &DT) {
795
- // For each TraceRay call check if write(caller) fields are written.
806
+ // Emit diagnostics for this call to either TraceRay, HitObject::TraceRay or
807
+ // HitObject::Invoke.
808
+ void DiagnoseBuiltinCallWithPayload (Sema &S, const VarDecl *Payload,
809
+ const PayloadBuiltinCall &PldCall,
810
+ DominatorTree &DT) {
811
+ // For each call check if write(caller) fields are written.
796
812
const DXIL::PayloadAccessShaderStage CallerStage =
797
813
DXIL::PayloadAccessShaderStage::Caller;
798
814
@@ -810,6 +826,13 @@ void DiagnoseTraceCall(Sema &S, const VarDecl *Payload,
810
826
return ;
811
827
}
812
828
829
+ // Verify that the payload type is legal
830
+ if (!hlsl::IsHLSLCopyableAnnotatableRecord (Payload->getType ())) {
831
+ S.Diag (Payload->getLocation (), diag::err_payload_attrs_must_be_udt)
832
+ << /* payload|attributes|callable*/ 0 << Payload;
833
+ return ;
834
+ }
835
+
813
836
if (ContainsLongVector (Payload->getType ())) {
814
837
const unsigned PayloadParametersIdx = 10 ;
815
838
S.Diag (Payload->getLocation (), diag::err_hlsl_unsupported_long_vector)
@@ -832,12 +855,12 @@ void DiagnoseTraceCall(Sema &S, const VarDecl *Payload,
832
855
833
856
std::set<const CFGBlock *> Visited;
834
857
835
- const CFGBlock *Parent = Trace .Parent ;
858
+ const CFGBlock *Parent = PldCall .Parent ;
836
859
Visited.insert (Parent);
837
- // Collect payload accesses in the same block until we reach the TraceRay call
860
+ // Collect payload accesses in the same block until we reach the call
838
861
for (auto Element : *Parent) {
839
862
if (Optional<CFGStmt> S = Element.getAs <CFGStmt>()) {
840
- if (S->getStmt () == Trace .Call )
863
+ if (S->getStmt () == PldCall .Call )
841
864
break ;
842
865
CollectReadsWritesAndCallsForPayload (S->getStmt (), TraceInfo, Parent);
843
866
}
@@ -850,10 +873,12 @@ void DiagnoseTraceCall(Sema &S, const VarDecl *Payload,
850
873
BackwardTraverseCFGAndCollectReadsWrites (*Pred, TraceInfo, Visited);
851
874
}
852
875
876
+ int PldArgIdx = PldCall.Call ->getNumArgs () - 1 ;
877
+
853
878
// Warn if a writeable field has not been written.
854
879
for (const FieldDecl *Field : WriteableFields) {
855
880
if (!TraceInfo.WritesPerField .count (Field)) {
856
- S.Diag (Trace .Call ->getArg (7 )->getExprLoc (),
881
+ S.Diag (PldCall .Call ->getArg (PldArgIdx )->getExprLoc (),
857
882
diag::warn_hlsl_payload_access_no_write_for_trace_payload)
858
883
<< Field->getName ();
859
884
}
@@ -862,7 +887,7 @@ void DiagnoseTraceCall(Sema &S, const VarDecl *Payload,
862
887
for (const FieldDecl *Field : NonWriteableFields) {
863
888
if (TraceInfo.WritesPerField .count (Field)) {
864
889
S.Diag (
865
- Trace .Call ->getArg (7 )->getExprLoc (),
890
+ PldCall .Call ->getArg (PldArgIdx )->getExprLoc (),
866
891
diag::warn_hlsl_payload_access_write_but_no_write_for_trace_payload)
867
892
<< Field->getName ();
868
893
}
@@ -878,7 +903,7 @@ void DiagnoseTraceCall(Sema &S, const VarDecl *Payload,
878
903
bool CallFound = false ;
879
904
for (auto Element : *Parent) { // TODO: reverse iterate?
880
905
if (Optional<CFGStmt> S = Element.getAs <CFGStmt>()) {
881
- if (S->getStmt () == Trace .Call ) {
906
+ if (S->getStmt () == PldCall .Call ) {
882
907
CallFound = true ;
883
908
continue ;
884
909
}
@@ -895,7 +920,7 @@ void DiagnoseTraceCall(Sema &S, const VarDecl *Payload,
895
920
896
921
for (const FieldDecl *Field : ReadableFields) {
897
922
if (!TraceInfo.ReadsPerField .count (Field)) {
898
- S.Diag (Trace .Call ->getArg (7 )->getExprLoc (),
923
+ S.Diag (PldCall .Call ->getArg (PldArgIdx )->getExprLoc (),
899
924
diag::warn_hlsl_payload_access_read_but_no_read_after_trace)
900
925
<< Field->getName ();
901
926
}
@@ -928,27 +953,29 @@ void DiagnoseTraceCall(Sema &S, const VarDecl *Payload,
928
953
}
929
954
}
930
955
931
- // Emit diagnostics for all TraceRay calls.
932
- void DiagnoseTraceCalls (Sema &S, CFG &ShaderCFG, DominatorTree &DT,
933
- DxrShaderDiagnoseInfo &Info) {
934
- // Collect TraceRay calls in the shader.
956
+ // Emit diagnostics for all calls to TraceRay, HitObject::TraceRay or
957
+ // HitObject::Invoke.
958
+ void DiagnoseBuiltinCallsWithPayload (Sema &S, CFG &ShaderCFG, DominatorTree &DT,
959
+ DxrShaderDiagnoseInfo &Info) {
960
+ // Collect calls with payload in the shader.
935
961
std::set<const CFGBlock *> Visited;
936
- ForwardTraverseCFGAndCollectTraceCalls (ShaderCFG.getEntry (), Info, Visited);
962
+ ForwardTraverseCFGAndCollectBuiltinCallsWithPayload (ShaderCFG.getEntry (),
963
+ Info, Visited);
937
964
938
965
std::set<const CallExpr *> Diagnosed;
939
966
940
- for (const TraceRayCall &TraceCall : Info.TraceCalls ) {
941
- if (Diagnosed.count (TraceCall .Call ))
967
+ for (const PayloadBuiltinCall &PldCall : Info.PayloadBuiltinCalls ) {
968
+ if (Diagnosed.count (PldCall .Call ))
942
969
continue ;
943
- Diagnosed.insert (TraceCall .Call );
970
+ Diagnosed.insert (PldCall .Call );
944
971
945
- const VarDecl *Payload = GetPayloadParameterForTraceCall (TraceCall .Call );
946
- DiagnoseTraceCall (S, Payload, TraceCall , DT);
972
+ const VarDecl *Payload = GetPayloadParameterForBuiltinCall (PldCall .Call );
973
+ DiagnoseBuiltinCallWithPayload (S, Payload, PldCall , DT);
947
974
}
948
975
}
949
976
950
977
// Emit diagnostics for all access to the payload of a shader,
951
- // and the input to TraceRay calls.
978
+ // and the input to TraceRay, HitObject::TraceRay or HitObject::Invoke calls.
952
979
std::vector<const FieldDecl *>
953
980
DiagnosePayloadAccess (Sema &S, DxrShaderDiagnoseInfo &Info,
954
981
const std::set<const FieldDecl *> &FieldsToIgnoreRead,
@@ -1012,7 +1039,7 @@ DiagnosePayloadAccess(Sema &S, DxrShaderDiagnoseInfo &Info,
1012
1039
DiagnosePayloadReads (S, TheCFG, DT, Info, NonReadableFields);
1013
1040
}
1014
1041
1015
- DiagnoseTraceCalls (S, TheCFG, DT, Info);
1042
+ DiagnoseBuiltinCallsWithPayload (S, TheCFG, DT, Info);
1016
1043
1017
1044
return WrittenFields;
1018
1045
}
0 commit comments