@@ -325,6 +325,19 @@ class CGMSHLSLRuntime : public CGHLSLRuntime {
325
325
};
326
326
} // namespace
327
327
328
+ static uint32_t GetIntConstAttrArg (ASTContext &astContext, const Expr *expr,
329
+ uint32_t defaultVal = 0 ) {
330
+ if (expr) {
331
+ llvm::APSInt apsInt;
332
+ APValue apValue;
333
+ if (expr->isIntegerConstantExpr (apsInt, astContext))
334
+ return (uint32_t )apsInt.getSExtValue ();
335
+ if (expr->isVulkanSpecConstantExpr (astContext, &apValue) && apValue.isInt ())
336
+ return (uint32_t )apValue.getInt ().getSExtValue ();
337
+ }
338
+ return defaultVal;
339
+ }
340
+
328
341
// ------------------------------------------------------------------------------
329
342
//
330
343
// CGMSHLSLRuntime methods.
@@ -1419,6 +1432,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
1419
1432
}
1420
1433
1421
1434
DiagnosticsEngine &Diags = CGM.getDiags ();
1435
+ ASTContext &astContext = CGM.getTypes ().getContext ();
1422
1436
1423
1437
std::unique_ptr<DxilFunctionProps> funcProps =
1424
1438
llvm::make_unique<DxilFunctionProps>();
@@ -1629,10 +1643,9 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
1629
1643
1630
1644
// Populate numThreads
1631
1645
if (const HLSLNumThreadsAttr *Attr = FD->getAttr <HLSLNumThreadsAttr>()) {
1632
-
1633
- funcProps->numThreads [0 ] = Attr->getX ();
1634
- funcProps->numThreads [1 ] = Attr->getY ();
1635
- funcProps->numThreads [2 ] = Attr->getZ ();
1646
+ funcProps->numThreads [0 ] = GetIntConstAttrArg (astContext, Attr->getX (), 1 );
1647
+ funcProps->numThreads [1 ] = GetIntConstAttrArg (astContext, Attr->getY (), 1 );
1648
+ funcProps->numThreads [2 ] = GetIntConstAttrArg (astContext, Attr->getZ (), 1 );
1636
1649
1637
1650
if (isEntry && !SM->IsCS () && !SM->IsMS () && !SM->IsAS ()) {
1638
1651
unsigned DiagID = Diags.getCustomDiagID (
@@ -1805,7 +1818,8 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
1805
1818
1806
1819
if (const auto *pAttr = FD->getAttr <HLSLNodeIdAttr>()) {
1807
1820
funcProps->NodeShaderID .Name = pAttr->getName ().str ();
1808
- funcProps->NodeShaderID .Index = pAttr->getArrayIndex ();
1821
+ funcProps->NodeShaderID .Index =
1822
+ GetIntConstAttrArg (astContext, pAttr->getArrayIndex (), 0 );
1809
1823
} else {
1810
1824
funcProps->NodeShaderID .Name = FD->getName ().str ();
1811
1825
funcProps->NodeShaderID .Index = 0 ;
@@ -1816,20 +1830,28 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
1816
1830
}
1817
1831
if (const auto *pAttr = FD->getAttr <HLSLNodeShareInputOfAttr>()) {
1818
1832
funcProps->NodeShaderSharedInput .Name = pAttr->getName ().str ();
1819
- funcProps->NodeShaderSharedInput .Index = pAttr->getArrayIndex ();
1833
+ funcProps->NodeShaderSharedInput .Index =
1834
+ GetIntConstAttrArg (astContext, pAttr->getArrayIndex (), 0 );
1820
1835
}
1821
1836
if (const auto *pAttr = FD->getAttr <HLSLNodeDispatchGridAttr>()) {
1822
- funcProps->Node .DispatchGrid [0 ] = pAttr->getX ();
1823
- funcProps->Node .DispatchGrid [1 ] = pAttr->getY ();
1824
- funcProps->Node .DispatchGrid [2 ] = pAttr->getZ ();
1837
+ funcProps->Node .DispatchGrid [0 ] =
1838
+ GetIntConstAttrArg (astContext, pAttr->getX (), 1 );
1839
+ funcProps->Node .DispatchGrid [1 ] =
1840
+ GetIntConstAttrArg (astContext, pAttr->getY (), 1 );
1841
+ funcProps->Node .DispatchGrid [2 ] =
1842
+ GetIntConstAttrArg (astContext, pAttr->getZ (), 1 );
1825
1843
}
1826
1844
if (const auto *pAttr = FD->getAttr <HLSLNodeMaxDispatchGridAttr>()) {
1827
- funcProps->Node .MaxDispatchGrid [0 ] = pAttr->getX ();
1828
- funcProps->Node .MaxDispatchGrid [1 ] = pAttr->getY ();
1829
- funcProps->Node .MaxDispatchGrid [2 ] = pAttr->getZ ();
1845
+ funcProps->Node .MaxDispatchGrid [0 ] =
1846
+ GetIntConstAttrArg (astContext, pAttr->getX (), 1 );
1847
+ funcProps->Node .MaxDispatchGrid [1 ] =
1848
+ GetIntConstAttrArg (astContext, pAttr->getY (), 1 );
1849
+ funcProps->Node .MaxDispatchGrid [2 ] =
1850
+ GetIntConstAttrArg (astContext, pAttr->getZ (), 1 );
1830
1851
}
1831
1852
if (const auto *pAttr = FD->getAttr <HLSLNodeMaxRecursionDepthAttr>()) {
1832
- funcProps->Node .MaxRecursionDepth = pAttr->getCount ();
1853
+ funcProps->Node .MaxRecursionDepth =
1854
+ GetIntConstAttrArg (astContext, pAttr->getCount (), 0 );
1833
1855
}
1834
1856
if (!FD->getAttr <HLSLNumThreadsAttr>()) {
1835
1857
// NumThreads wasn't specified.
@@ -2343,8 +2365,9 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
2343
2365
NodeInputRecordParams[ArgIt].MetadataIdx = NodeInputParamIdx++;
2344
2366
2345
2367
if (parmDecl->hasAttr <HLSLMaxRecordsAttr>()) {
2346
- node.MaxRecords =
2347
- parmDecl->getAttr <HLSLMaxRecordsAttr>()->getMaxCount ();
2368
+ node.MaxRecords = GetIntConstAttrArg (
2369
+ astContext,
2370
+ parmDecl->getAttr <HLSLMaxRecordsAttr>()->getMaxCount (), 1 );
2348
2371
}
2349
2372
if (parmDecl->hasAttr <HLSLGloballyCoherentAttr>())
2350
2373
node.Flags .SetGloballyCoherent ();
@@ -2375,7 +2398,8 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
2375
2398
// OutputID from attribute
2376
2399
if (const auto *Attr = parmDecl->getAttr <HLSLNodeIdAttr>()) {
2377
2400
node.OutputID .Name = Attr->getName ().str ();
2378
- node.OutputID .Index = Attr->getArrayIndex ();
2401
+ node.OutputID .Index =
2402
+ GetIntConstAttrArg (astContext, Attr->getArrayIndex (), 0 );
2379
2403
} else {
2380
2404
node.OutputID .Name = parmDecl->getName ().str ();
2381
2405
node.OutputID .Index = 0 ;
@@ -2434,7 +2458,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
2434
2458
node.MaxRecordsSharedWith = ix;
2435
2459
}
2436
2460
if (const auto *Attr = parmDecl->getAttr <HLSLMaxRecordsAttr>())
2437
- node.MaxRecords = Attr->getMaxCount ();
2461
+ node.MaxRecords = GetIntConstAttrArg (astContext, Attr->getMaxCount (), 0 );
2438
2462
}
2439
2463
2440
2464
if (inputPatchCount > 1 ) {
0 commit comments