Skip to content

Commit a12ce71

Browse files
committed
Allows defined constants as arguments of (some) attributes.
1 parent ffc174c commit a12ce71

File tree

14 files changed

+385
-152
lines changed

14 files changed

+385
-152
lines changed

tools/clang/include/clang/AST/Expr.h

+3
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,9 @@ class Expr : public Stmt {
531531
bool isConstantInitializer(ASTContext &Ctx, bool ForRef,
532532
const Expr **Culprit = nullptr) const;
533533

534+
bool isVulkanSpecConstantExpr(const ASTContext &Ctx,
535+
APValue *Result = nullptr) const;
536+
534537
/// EvalStatus is a struct with detailed info about an evaluation in progress.
535538
struct EvalStatus {
536539
/// HasSideEffects - Whether the evaluated expression has side effects.

tools/clang/include/clang/Basic/Attr.td

+7-7
Original file line numberDiff line numberDiff line change
@@ -671,7 +671,7 @@ def HLSLMaxTessFactor: InheritableAttr {
671671
}
672672
def HLSLNumThreads: InheritableAttr {
673673
let Spellings = [CXX11<"", "numthreads", 2015>];
674-
let Args = [IntArgument<"X">, IntArgument<"Y">, IntArgument<"Z">];
674+
let Args = [ExprArgument<"X">, ExprArgument<"Y">, ExprArgument<"Z">];
675675
let Documentation = [Undocumented];
676676
}
677677
def HLSLRootSignature: InheritableAttr {
@@ -1007,7 +1007,7 @@ def HLSLNodeIsProgramEntry : InheritableAttr {
10071007

10081008
def HLSLNodeId : InheritableAttr {
10091009
let Spellings = [CXX11<"", "nodeid", 2017>];
1010-
let Args = [StringArgument<"Name">,DefaultIntArgument<"ArrayIndex", 0>];
1010+
let Args = [StringArgument<"Name">, ExprArgument<"ArrayIndex", 1>];
10111011
let Documentation = [Undocumented];
10121012
}
10131013

@@ -1019,25 +1019,25 @@ def HLSLNodeLocalRootArgumentsTableIndex : InheritableAttr {
10191019

10201020
def HLSLNodeShareInputOf : InheritableAttr {
10211021
let Spellings = [CXX11<"", "nodeshareinputof", 2017>];
1022-
let Args = [StringArgument<"Name">,UnsignedArgument<"ArrayIndex", 1>];
1022+
let Args = [StringArgument<"Name">,ExprArgument<"ArrayIndex", 1>];
10231023
let Documentation = [Undocumented];
10241024
}
10251025

10261026
def HLSLNodeDispatchGrid: InheritableAttr {
10271027
let Spellings = [CXX11<"", "nodedispatchgrid", 2015>];
1028-
let Args = [UnsignedArgument<"X">, UnsignedArgument<"Y">, UnsignedArgument<"Z">];
1028+
let Args = [ExprArgument<"X">, ExprArgument<"Y">, ExprArgument<"Z">];
10291029
let Documentation = [Undocumented];
10301030
}
10311031

10321032
def HLSLNodeMaxDispatchGrid: InheritableAttr {
10331033
let Spellings = [CXX11<"", "nodemaxdispatchgrid", 2015>];
1034-
let Args = [UnsignedArgument<"X">, UnsignedArgument<"Y">, UnsignedArgument<"Z">];
1034+
let Args = [ExprArgument<"X">, ExprArgument<"Y">, ExprArgument<"Z">];
10351035
let Documentation = [Undocumented];
10361036
}
10371037

10381038
def HLSLNodeMaxRecursionDepth : InheritableAttr {
10391039
let Spellings = [CXX11<"", "nodemaxrecursiondepth", 2017>];
1040-
let Args = [UnsignedArgument<"Count">];
1040+
let Args = [ExprArgument<"Count">];
10411041
let Documentation = [Undocumented];
10421042
}
10431043

@@ -1185,7 +1185,7 @@ def HLSLHitObject : InheritableAttr {
11851185

11861186
def HLSLMaxRecords : InheritableAttr {
11871187
let Spellings = [CXX11<"", "MaxRecords", 2015>];
1188-
let Args = [IntArgument<"maxCount">];
1188+
let Args = [ExprArgument<"maxCount">];
11891189
let Documentation = [Undocumented];
11901190
}
11911191

tools/clang/lib/AST/ExprConstant.cpp

+15
Original file line numberDiff line numberDiff line change
@@ -9451,6 +9451,21 @@ bool Expr::isIntegerConstantExpr(llvm::APSInt &Value, const ASTContext &Ctx,
94519451
return true;
94529452
}
94539453

9454+
bool Expr::isVulkanSpecConstantExpr(const ASTContext &Ctx, APValue *Result) const {
9455+
if (auto *D = dyn_cast<DeclRefExpr>(this)) {
9456+
if (auto *V = dyn_cast<VarDecl>(D->getDecl())) {
9457+
if (V->hasAttr<VKConstantIdAttr>()) {
9458+
if (const Expr *I = V->getAnyInitializer()) {
9459+
if (!I->isCXX11ConstantExpr(Ctx, Result))
9460+
return false;
9461+
}
9462+
return true;
9463+
}
9464+
}
9465+
}
9466+
return false;
9467+
}
9468+
94549469
bool Expr::isCXX98IntegralConstantExpr(const ASTContext &Ctx) const {
94559470
return CheckICE(this, Ctx).Kind == IK_ICE;
94569471
}

tools/clang/lib/CodeGen/CGHLSLMS.cpp

+41-17
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,19 @@ class CGMSHLSLRuntime : public CGHLSLRuntime {
325325
};
326326
} // namespace
327327

328+
static uint32_t GetIntConstAttrArg(ASTContext &astContext, const Expr *expr,
329+
uint32_t defaultVal = 0) {
330+
if (expr) {
331+
llvm::APSInt apsInt;
332+
APValue apValue;
333+
if (expr->isIntegerConstantExpr(apsInt, astContext))
334+
return (uint32_t)apsInt.getSExtValue();
335+
if (expr->isVulkanSpecConstantExpr(astContext, &apValue) && apValue.isInt())
336+
return (uint32_t)apValue.getInt().getSExtValue();
337+
}
338+
return defaultVal;
339+
}
340+
328341
//------------------------------------------------------------------------------
329342
//
330343
// CGMSHLSLRuntime methods.
@@ -1419,6 +1432,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
14191432
}
14201433

14211434
DiagnosticsEngine &Diags = CGM.getDiags();
1435+
ASTContext &astContext = CGM.getTypes().getContext();
14221436

14231437
std::unique_ptr<DxilFunctionProps> funcProps =
14241438
llvm::make_unique<DxilFunctionProps>();
@@ -1629,10 +1643,9 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
16291643

16301644
// Populate numThreads
16311645
if (const HLSLNumThreadsAttr *Attr = FD->getAttr<HLSLNumThreadsAttr>()) {
1632-
1633-
funcProps->numThreads[0] = Attr->getX();
1634-
funcProps->numThreads[1] = Attr->getY();
1635-
funcProps->numThreads[2] = Attr->getZ();
1646+
funcProps->numThreads[0] = GetIntConstAttrArg(astContext, Attr->getX(), 1);
1647+
funcProps->numThreads[1] = GetIntConstAttrArg(astContext, Attr->getY(), 1);
1648+
funcProps->numThreads[2] = GetIntConstAttrArg(astContext, Attr->getZ(), 1);
16361649

16371650
if (isEntry && !SM->IsCS() && !SM->IsMS() && !SM->IsAS()) {
16381651
unsigned DiagID = Diags.getCustomDiagID(
@@ -1805,7 +1818,8 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
18051818

18061819
if (const auto *pAttr = FD->getAttr<HLSLNodeIdAttr>()) {
18071820
funcProps->NodeShaderID.Name = pAttr->getName().str();
1808-
funcProps->NodeShaderID.Index = pAttr->getArrayIndex();
1821+
funcProps->NodeShaderID.Index =
1822+
GetIntConstAttrArg(astContext, pAttr->getArrayIndex(), 0);
18091823
} else {
18101824
funcProps->NodeShaderID.Name = FD->getName().str();
18111825
funcProps->NodeShaderID.Index = 0;
@@ -1816,20 +1830,28 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
18161830
}
18171831
if (const auto *pAttr = FD->getAttr<HLSLNodeShareInputOfAttr>()) {
18181832
funcProps->NodeShaderSharedInput.Name = pAttr->getName().str();
1819-
funcProps->NodeShaderSharedInput.Index = pAttr->getArrayIndex();
1833+
funcProps->NodeShaderSharedInput.Index =
1834+
GetIntConstAttrArg(astContext, pAttr->getArrayIndex(), 0);
18201835
}
18211836
if (const auto *pAttr = FD->getAttr<HLSLNodeDispatchGridAttr>()) {
1822-
funcProps->Node.DispatchGrid[0] = pAttr->getX();
1823-
funcProps->Node.DispatchGrid[1] = pAttr->getY();
1824-
funcProps->Node.DispatchGrid[2] = pAttr->getZ();
1837+
funcProps->Node.DispatchGrid[0] =
1838+
GetIntConstAttrArg(astContext, pAttr->getX(), 1);
1839+
funcProps->Node.DispatchGrid[1] =
1840+
GetIntConstAttrArg(astContext, pAttr->getY(), 1);
1841+
funcProps->Node.DispatchGrid[2] =
1842+
GetIntConstAttrArg(astContext, pAttr->getZ(), 1);
18251843
}
18261844
if (const auto *pAttr = FD->getAttr<HLSLNodeMaxDispatchGridAttr>()) {
1827-
funcProps->Node.MaxDispatchGrid[0] = pAttr->getX();
1828-
funcProps->Node.MaxDispatchGrid[1] = pAttr->getY();
1829-
funcProps->Node.MaxDispatchGrid[2] = pAttr->getZ();
1845+
funcProps->Node.MaxDispatchGrid[0] =
1846+
GetIntConstAttrArg(astContext, pAttr->getX(), 1);
1847+
funcProps->Node.MaxDispatchGrid[1] =
1848+
GetIntConstAttrArg(astContext, pAttr->getY(), 1);
1849+
funcProps->Node.MaxDispatchGrid[2] =
1850+
GetIntConstAttrArg(astContext, pAttr->getZ(), 1);
18301851
}
18311852
if (const auto *pAttr = FD->getAttr<HLSLNodeMaxRecursionDepthAttr>()) {
1832-
funcProps->Node.MaxRecursionDepth = pAttr->getCount();
1853+
funcProps->Node.MaxRecursionDepth =
1854+
GetIntConstAttrArg(astContext, pAttr->getCount(), 0);
18331855
}
18341856
if (!FD->getAttr<HLSLNumThreadsAttr>()) {
18351857
// NumThreads wasn't specified.
@@ -2343,8 +2365,9 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
23432365
NodeInputRecordParams[ArgIt].MetadataIdx = NodeInputParamIdx++;
23442366

23452367
if (parmDecl->hasAttr<HLSLMaxRecordsAttr>()) {
2346-
node.MaxRecords =
2347-
parmDecl->getAttr<HLSLMaxRecordsAttr>()->getMaxCount();
2368+
node.MaxRecords = GetIntConstAttrArg(
2369+
astContext,
2370+
parmDecl->getAttr<HLSLMaxRecordsAttr>()->getMaxCount(), 1);
23482371
}
23492372
if (parmDecl->hasAttr<HLSLGloballyCoherentAttr>())
23502373
node.Flags.SetGloballyCoherent();
@@ -2375,7 +2398,8 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
23752398
// OutputID from attribute
23762399
if (const auto *Attr = parmDecl->getAttr<HLSLNodeIdAttr>()) {
23772400
node.OutputID.Name = Attr->getName().str();
2378-
node.OutputID.Index = Attr->getArrayIndex();
2401+
node.OutputID.Index =
2402+
GetIntConstAttrArg(astContext, Attr->getArrayIndex(), 0);
23792403
} else {
23802404
node.OutputID.Name = parmDecl->getName().str();
23812405
node.OutputID.Index = 0;
@@ -2434,7 +2458,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
24342458
node.MaxRecordsSharedWith = ix;
24352459
}
24362460
if (const auto *Attr = parmDecl->getAttr<HLSLMaxRecordsAttr>())
2437-
node.MaxRecords = Attr->getMaxCount();
2461+
node.MaxRecords = GetIntConstAttrArg(astContext, Attr->getMaxCount(), 0);
24382462
}
24392463

24402464
if (inputPatchCount > 1) {

tools/clang/lib/SPIRV/SpirvEmitter.cpp

+29-17
Original file line numberDiff line numberDiff line change
@@ -13246,6 +13246,31 @@ void SpirvEmitter::processInlineSpirvAttributes(const FunctionDecl *decl) {
1324613246
}
1324713247
}
1324813248

13249+
bool SpirvEmitter::processNumThreadsAttr(const FunctionDecl *decl) {
13250+
auto *numThreadsAttr = decl->getAttr<HLSLNumThreadsAttr>();
13251+
if (!numThreadsAttr)
13252+
return false;
13253+
13254+
auto f = [](ASTContext &Ctx, Expr *E) {
13255+
if (E) {
13256+
llvm::APSInt apsInt;
13257+
APValue apValue;
13258+
if (E->isIntegerConstantExpr(apsInt, Ctx))
13259+
return (uint32_t)apsInt.getSExtValue();
13260+
if (E->isVulkanSpecConstantExpr(Ctx, &apValue) && apValue.isInt())
13261+
return (uint32_t)apValue.getInt().getSExtValue();
13262+
}
13263+
return 1U;
13264+
};
13265+
13266+
spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize,
13267+
{f(astContext, numThreadsAttr->getX()),
13268+
f(astContext, numThreadsAttr->getY()),
13269+
f(astContext, numThreadsAttr->getZ())},
13270+
decl->getLocation());
13271+
return true;
13272+
}
13273+
1324913274
bool SpirvEmitter::processGeometryShaderAttributes(const FunctionDecl *decl,
1325013275
uint32_t *arraySize) {
1325113276
bool success = true;
@@ -13422,15 +13447,9 @@ void SpirvEmitter::processPixelShaderAttributes(const FunctionDecl *decl) {
1342213447
}
1342313448

1342413449
void SpirvEmitter::processComputeShaderAttributes(const FunctionDecl *decl) {
13425-
auto *numThreadsAttr = decl->getAttr<HLSLNumThreadsAttr>();
13426-
assert(numThreadsAttr && "thread group size missing from entry-point");
13427-
13428-
uint32_t x = static_cast<uint32_t>(numThreadsAttr->getX());
13429-
uint32_t y = static_cast<uint32_t>(numThreadsAttr->getY());
13430-
uint32_t z = static_cast<uint32_t>(numThreadsAttr->getZ());
13431-
13432-
spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize,
13433-
{x, y, z}, decl->getLocation());
13450+
if (!processNumThreadsAttr(decl)) {
13451+
assert(false && "thread group size missing from entry-point");
13452+
}
1343413453

1343513454
auto *waveSizeAttr = decl->getAttr<HLSLWaveSizeAttr>();
1343613455
if (waveSizeAttr) {
@@ -13651,14 +13670,7 @@ bool SpirvEmitter::emitEntryFunctionWrapperForRayTracing(
1365113670

1365213671
bool SpirvEmitter::processMeshOrAmplificationShaderAttributes(
1365313672
const FunctionDecl *decl, uint32_t *outVerticesArraySize) {
13654-
if (auto *numThreadsAttr = decl->getAttr<HLSLNumThreadsAttr>()) {
13655-
uint32_t x, y, z;
13656-
x = static_cast<uint32_t>(numThreadsAttr->getX());
13657-
y = static_cast<uint32_t>(numThreadsAttr->getY());
13658-
z = static_cast<uint32_t>(numThreadsAttr->getZ());
13659-
spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize,
13660-
{x, y, z}, decl->getLocation());
13661-
}
13673+
processNumThreadsAttr(decl);
1366213674

1366313675
// Early return for amplification shaders as they only take the 'numthreads'
1366413676
// attribute.

tools/clang/lib/SPIRV/SpirvEmitter.h

+2
Original file line numberDiff line numberDiff line change
@@ -838,6 +838,8 @@ class SpirvEmitter : public ASTConsumer {
838838
/// \brief Handle inline SPIR-V attributes for the entry function.
839839
void processInlineSpirvAttributes(const FunctionDecl *entryFunction);
840840

841+
bool processNumThreadsAttr(const FunctionDecl *decl);
842+
841843
/// \brief Adds necessary execution modes for the hull/domain shaders based on
842844
/// the HLSL attributes of the entry point function.
843845
/// In the case of hull shaders, also writes the number of output control

0 commit comments

Comments
 (0)