Skip to content

Allows defined constants as attribute arguments. #7439

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions tools/clang/include/clang/AST/Expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,9 @@ class Expr : public Stmt {
bool isConstantInitializer(ASTContext &Ctx, bool ForRef,
const Expr **Culprit = nullptr) const;

bool isVulkanSpecConstantExpr(const ASTContext &Ctx,
APValue *Result = nullptr) const;

/// EvalStatus is a struct with detailed info about an evaluation in progress.
struct EvalStatus {
/// HasSideEffects - Whether the evaluated expression has side effects.
Expand Down
14 changes: 7 additions & 7 deletions tools/clang/include/clang/Basic/Attr.td
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,7 @@ def HLSLMaxTessFactor: InheritableAttr {
}
def HLSLNumThreads: InheritableAttr {
let Spellings = [CXX11<"", "numthreads", 2015>];
let Args = [IntArgument<"X">, IntArgument<"Y">, IntArgument<"Z">];
let Args = [ExprArgument<"X">, ExprArgument<"Y">, ExprArgument<"Z">];
let Documentation = [Undocumented];
}
def HLSLRootSignature: InheritableAttr {
Expand Down Expand Up @@ -1007,7 +1007,7 @@ def HLSLNodeIsProgramEntry : InheritableAttr {

def HLSLNodeId : InheritableAttr {
let Spellings = [CXX11<"", "nodeid", 2017>];
let Args = [StringArgument<"Name">,DefaultIntArgument<"ArrayIndex", 0>];
let Args = [StringArgument<"Name">, ExprArgument<"ArrayIndex", 1>];
let Documentation = [Undocumented];
}

Expand All @@ -1019,25 +1019,25 @@ def HLSLNodeLocalRootArgumentsTableIndex : InheritableAttr {

def HLSLNodeShareInputOf : InheritableAttr {
let Spellings = [CXX11<"", "nodeshareinputof", 2017>];
let Args = [StringArgument<"Name">,UnsignedArgument<"ArrayIndex", 1>];
let Args = [StringArgument<"Name">,ExprArgument<"ArrayIndex", 1>];
let Documentation = [Undocumented];
}

def HLSLNodeDispatchGrid: InheritableAttr {
let Spellings = [CXX11<"", "nodedispatchgrid", 2015>];
let Args = [UnsignedArgument<"X">, UnsignedArgument<"Y">, UnsignedArgument<"Z">];
let Args = [ExprArgument<"X">, ExprArgument<"Y">, ExprArgument<"Z">];
let Documentation = [Undocumented];
}

def HLSLNodeMaxDispatchGrid: InheritableAttr {
let Spellings = [CXX11<"", "nodemaxdispatchgrid", 2015>];
let Args = [UnsignedArgument<"X">, UnsignedArgument<"Y">, UnsignedArgument<"Z">];
let Args = [ExprArgument<"X">, ExprArgument<"Y">, ExprArgument<"Z">];
let Documentation = [Undocumented];
}

def HLSLNodeMaxRecursionDepth : InheritableAttr {
let Spellings = [CXX11<"", "nodemaxrecursiondepth", 2017>];
let Args = [UnsignedArgument<"Count">];
let Args = [ExprArgument<"Count">];
let Documentation = [Undocumented];
}

Expand Down Expand Up @@ -1185,7 +1185,7 @@ def HLSLHitObject : InheritableAttr {

def HLSLMaxRecords : InheritableAttr {
let Spellings = [CXX11<"", "MaxRecords", 2015>];
let Args = [IntArgument<"maxCount">];
let Args = [ExprArgument<"maxCount">];
let Documentation = [Undocumented];
}

Expand Down
16 changes: 16 additions & 0 deletions tools/clang/lib/AST/ExprConstant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9451,6 +9451,22 @@ bool Expr::isIntegerConstantExpr(llvm::APSInt &Value, const ASTContext &Ctx,
return true;
}

bool Expr::isVulkanSpecConstantExpr(const ASTContext &Ctx,
APValue *Result) const {
if (auto *D = dyn_cast<DeclRefExpr>(this)) {
if (auto *V = dyn_cast<VarDecl>(D->getDecl())) {
if (V->hasAttr<VKConstantIdAttr>()) {
if (const Expr *I = V->getAnyInitializer()) {
if (!I->isCXX11ConstantExpr(Ctx, Result))
return false;
}
return true;
}
}
}
return false;
}

Comment on lines +9454 to +9469
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I might be clearer to require an individual if-statements to reduce the nesting.

auto *D = dyn_cast<DeclRefExpr>(this)
if (!D) return false;
auto *V = dyn_cast<VarDecl>(D->getDecl())
if (!V || !V->hasAttr<VkConstantIdAttr>()) return false;
...

bool Expr::isCXX98IntegralConstantExpr(const ASTContext &Ctx) const {
return CheckICE(this, Ctx).Kind == IK_ICE;
}
Expand Down
58 changes: 41 additions & 17 deletions tools/clang/lib/CodeGen/CGHLSLMS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,19 @@ class CGMSHLSLRuntime : public CGHLSLRuntime {
};
} // namespace

static uint32_t GetIntConstAttrArg(ASTContext &astContext, const Expr *expr,
uint32_t defaultVal = 0) {
if (expr) {
llvm::APSInt apsInt;
APValue apValue;
if (expr->isIntegerConstantExpr(apsInt, astContext))
return (uint32_t)apsInt.getSExtValue();
if (expr->isVulkanSpecConstantExpr(astContext, &apValue) && apValue.isInt())
return (uint32_t)apValue.getInt().getSExtValue();
}
return defaultVal;
}

//------------------------------------------------------------------------------
//
// CGMSHLSLRuntime methods.
Expand Down Expand Up @@ -1419,6 +1432,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
}

DiagnosticsEngine &Diags = CGM.getDiags();
ASTContext &astContext = CGM.getTypes().getContext();

std::unique_ptr<DxilFunctionProps> funcProps =
llvm::make_unique<DxilFunctionProps>();
Expand Down Expand Up @@ -1629,10 +1643,9 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {

// Populate numThreads
if (const HLSLNumThreadsAttr *Attr = FD->getAttr<HLSLNumThreadsAttr>()) {

funcProps->numThreads[0] = Attr->getX();
funcProps->numThreads[1] = Attr->getY();
funcProps->numThreads[2] = Attr->getZ();
funcProps->numThreads[0] = GetIntConstAttrArg(astContext, Attr->getX(), 1);
funcProps->numThreads[1] = GetIntConstAttrArg(astContext, Attr->getY(), 1);
funcProps->numThreads[2] = GetIntConstAttrArg(astContext, Attr->getZ(), 1);

if (isEntry && !SM->IsCS() && !SM->IsMS() && !SM->IsAS()) {
unsigned DiagID = Diags.getCustomDiagID(
Expand Down Expand Up @@ -1805,7 +1818,8 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {

if (const auto *pAttr = FD->getAttr<HLSLNodeIdAttr>()) {
funcProps->NodeShaderID.Name = pAttr->getName().str();
funcProps->NodeShaderID.Index = pAttr->getArrayIndex();
funcProps->NodeShaderID.Index =
GetIntConstAttrArg(astContext, pAttr->getArrayIndex(), 0);
} else {
funcProps->NodeShaderID.Name = FD->getName().str();
funcProps->NodeShaderID.Index = 0;
Expand All @@ -1816,20 +1830,28 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
}
if (const auto *pAttr = FD->getAttr<HLSLNodeShareInputOfAttr>()) {
funcProps->NodeShaderSharedInput.Name = pAttr->getName().str();
funcProps->NodeShaderSharedInput.Index = pAttr->getArrayIndex();
funcProps->NodeShaderSharedInput.Index =
GetIntConstAttrArg(astContext, pAttr->getArrayIndex(), 0);
}
if (const auto *pAttr = FD->getAttr<HLSLNodeDispatchGridAttr>()) {
funcProps->Node.DispatchGrid[0] = pAttr->getX();
funcProps->Node.DispatchGrid[1] = pAttr->getY();
funcProps->Node.DispatchGrid[2] = pAttr->getZ();
funcProps->Node.DispatchGrid[0] =
GetIntConstAttrArg(astContext, pAttr->getX(), 1);
funcProps->Node.DispatchGrid[1] =
GetIntConstAttrArg(astContext, pAttr->getY(), 1);
funcProps->Node.DispatchGrid[2] =
GetIntConstAttrArg(astContext, pAttr->getZ(), 1);
}
if (const auto *pAttr = FD->getAttr<HLSLNodeMaxDispatchGridAttr>()) {
funcProps->Node.MaxDispatchGrid[0] = pAttr->getX();
funcProps->Node.MaxDispatchGrid[1] = pAttr->getY();
funcProps->Node.MaxDispatchGrid[2] = pAttr->getZ();
funcProps->Node.MaxDispatchGrid[0] =
GetIntConstAttrArg(astContext, pAttr->getX(), 1);
funcProps->Node.MaxDispatchGrid[1] =
GetIntConstAttrArg(astContext, pAttr->getY(), 1);
funcProps->Node.MaxDispatchGrid[2] =
GetIntConstAttrArg(astContext, pAttr->getZ(), 1);
}
if (const auto *pAttr = FD->getAttr<HLSLNodeMaxRecursionDepthAttr>()) {
funcProps->Node.MaxRecursionDepth = pAttr->getCount();
funcProps->Node.MaxRecursionDepth =
GetIntConstAttrArg(astContext, pAttr->getCount(), 0);
}
if (!FD->getAttr<HLSLNumThreadsAttr>()) {
// NumThreads wasn't specified.
Expand Down Expand Up @@ -2343,8 +2365,9 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
NodeInputRecordParams[ArgIt].MetadataIdx = NodeInputParamIdx++;

if (parmDecl->hasAttr<HLSLMaxRecordsAttr>()) {
node.MaxRecords =
parmDecl->getAttr<HLSLMaxRecordsAttr>()->getMaxCount();
node.MaxRecords = GetIntConstAttrArg(
astContext,
parmDecl->getAttr<HLSLMaxRecordsAttr>()->getMaxCount(), 1);
}
if (parmDecl->hasAttr<HLSLGloballyCoherentAttr>())
node.Flags.SetGloballyCoherent();
Expand Down Expand Up @@ -2375,7 +2398,8 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
// OutputID from attribute
if (const auto *Attr = parmDecl->getAttr<HLSLNodeIdAttr>()) {
node.OutputID.Name = Attr->getName().str();
node.OutputID.Index = Attr->getArrayIndex();
node.OutputID.Index =
GetIntConstAttrArg(astContext, Attr->getArrayIndex(), 0);
} else {
node.OutputID.Name = parmDecl->getName().str();
node.OutputID.Index = 0;
Expand Down Expand Up @@ -2434,7 +2458,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
node.MaxRecordsSharedWith = ix;
}
if (const auto *Attr = parmDecl->getAttr<HLSLMaxRecordsAttr>())
node.MaxRecords = Attr->getMaxCount();
node.MaxRecords = GetIntConstAttrArg(astContext, Attr->getMaxCount(), 0);
}

if (inputPatchCount > 1) {
Expand Down
46 changes: 29 additions & 17 deletions tools/clang/lib/SPIRV/SpirvEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13246,6 +13246,31 @@ void SpirvEmitter::processInlineSpirvAttributes(const FunctionDecl *decl) {
}
}

bool SpirvEmitter::processNumThreadsAttr(const FunctionDecl *decl) {
auto *numThreadsAttr = decl->getAttr<HLSLNumThreadsAttr>();
if (!numThreadsAttr)
return false;

auto f = [](ASTContext &Ctx, Expr *E) {
if (E) {
llvm::APSInt apsInt;
APValue apValue;
if (E->isIntegerConstantExpr(apsInt, Ctx))
return (uint32_t)apsInt.getSExtValue();
if (E->isVulkanSpecConstantExpr(Ctx, &apValue) && apValue.isInt())
return (uint32_t)apValue.getInt().getSExtValue();
}
return 1U;
};

spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize,
{f(astContext, numThreadsAttr->getX()),
f(astContext, numThreadsAttr->getY()),
f(astContext, numThreadsAttr->getZ())},
decl->getLocation());
return true;
}
Comment on lines +13249 to +13272
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you have a spec constant, you need to use the LocalSizeId execution mode, with the id of the spec constant right?

We could implement this by always using LocalSizeId, and make the parameters the id of the expression in the attribute.


bool SpirvEmitter::processGeometryShaderAttributes(const FunctionDecl *decl,
uint32_t *arraySize) {
bool success = true;
Expand Down Expand Up @@ -13422,15 +13447,9 @@ void SpirvEmitter::processPixelShaderAttributes(const FunctionDecl *decl) {
}

void SpirvEmitter::processComputeShaderAttributes(const FunctionDecl *decl) {
auto *numThreadsAttr = decl->getAttr<HLSLNumThreadsAttr>();
assert(numThreadsAttr && "thread group size missing from entry-point");

uint32_t x = static_cast<uint32_t>(numThreadsAttr->getX());
uint32_t y = static_cast<uint32_t>(numThreadsAttr->getY());
uint32_t z = static_cast<uint32_t>(numThreadsAttr->getZ());

spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize,
{x, y, z}, decl->getLocation());
if (!processNumThreadsAttr(decl)) {
assert(false && "thread group size missing from entry-point");
}

auto *waveSizeAttr = decl->getAttr<HLSLWaveSizeAttr>();
if (waveSizeAttr) {
Expand Down Expand Up @@ -13651,14 +13670,7 @@ bool SpirvEmitter::emitEntryFunctionWrapperForRayTracing(

bool SpirvEmitter::processMeshOrAmplificationShaderAttributes(
const FunctionDecl *decl, uint32_t *outVerticesArraySize) {
if (auto *numThreadsAttr = decl->getAttr<HLSLNumThreadsAttr>()) {
uint32_t x, y, z;
x = static_cast<uint32_t>(numThreadsAttr->getX());
y = static_cast<uint32_t>(numThreadsAttr->getY());
z = static_cast<uint32_t>(numThreadsAttr->getZ());
spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize,
{x, y, z}, decl->getLocation());
}
processNumThreadsAttr(decl);

// Early return for amplification shaders as they only take the 'numthreads'
// attribute.
Expand Down
2 changes: 2 additions & 0 deletions tools/clang/lib/SPIRV/SpirvEmitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -838,6 +838,8 @@ class SpirvEmitter : public ASTConsumer {
/// \brief Handle inline SPIR-V attributes for the entry function.
void processInlineSpirvAttributes(const FunctionDecl *entryFunction);

bool processNumThreadsAttr(const FunctionDecl *decl);

/// \brief Adds necessary execution modes for the hull/domain shaders based on
/// the HLSL attributes of the entry point function.
/// In the case of hull shaders, also writes the number of output control
Expand Down
Loading