diff --git a/tools/clang/include/clang/AST/Expr.h b/tools/clang/include/clang/AST/Expr.h index 26eff309f7..368d24d03d 100644 --- a/tools/clang/include/clang/AST/Expr.h +++ b/tools/clang/include/clang/AST/Expr.h @@ -531,6 +531,9 @@ class Expr : public Stmt { bool isConstantInitializer(ASTContext &Ctx, bool ForRef, const Expr **Culprit = nullptr) const; + bool isVulkanSpecConstantExpr(const ASTContext &Ctx, + APValue *Result = nullptr) const; + /// EvalStatus is a struct with detailed info about an evaluation in progress. struct EvalStatus { /// HasSideEffects - Whether the evaluated expression has side effects. diff --git a/tools/clang/include/clang/Basic/Attr.td b/tools/clang/include/clang/Basic/Attr.td index 2518423565..eaabc23591 100644 --- a/tools/clang/include/clang/Basic/Attr.td +++ b/tools/clang/include/clang/Basic/Attr.td @@ -671,7 +671,7 @@ def HLSLMaxTessFactor: InheritableAttr { } def HLSLNumThreads: InheritableAttr { let Spellings = [CXX11<"", "numthreads", 2015>]; - let Args = [IntArgument<"X">, IntArgument<"Y">, IntArgument<"Z">]; + let Args = [ExprArgument<"X">, ExprArgument<"Y">, ExprArgument<"Z">]; let Documentation = [Undocumented]; } def HLSLRootSignature: InheritableAttr { @@ -1007,7 +1007,7 @@ def HLSLNodeIsProgramEntry : InheritableAttr { def HLSLNodeId : InheritableAttr { let Spellings = [CXX11<"", "nodeid", 2017>]; - let Args = [StringArgument<"Name">,DefaultIntArgument<"ArrayIndex", 0>]; + let Args = [StringArgument<"Name">, ExprArgument<"ArrayIndex", 1>]; let Documentation = [Undocumented]; } @@ -1019,25 +1019,25 @@ def HLSLNodeLocalRootArgumentsTableIndex : InheritableAttr { def HLSLNodeShareInputOf : InheritableAttr { let Spellings = [CXX11<"", "nodeshareinputof", 2017>]; - let Args = [StringArgument<"Name">,UnsignedArgument<"ArrayIndex", 1>]; + let Args = [StringArgument<"Name">,ExprArgument<"ArrayIndex", 1>]; let Documentation = [Undocumented]; } def HLSLNodeDispatchGrid: InheritableAttr { let Spellings = [CXX11<"", "nodedispatchgrid", 2015>]; - let Args = [UnsignedArgument<"X">, UnsignedArgument<"Y">, UnsignedArgument<"Z">]; + let Args = [ExprArgument<"X">, ExprArgument<"Y">, ExprArgument<"Z">]; let Documentation = [Undocumented]; } def HLSLNodeMaxDispatchGrid: InheritableAttr { let Spellings = [CXX11<"", "nodemaxdispatchgrid", 2015>]; - let Args = [UnsignedArgument<"X">, UnsignedArgument<"Y">, UnsignedArgument<"Z">]; + let Args = [ExprArgument<"X">, ExprArgument<"Y">, ExprArgument<"Z">]; let Documentation = [Undocumented]; } def HLSLNodeMaxRecursionDepth : InheritableAttr { let Spellings = [CXX11<"", "nodemaxrecursiondepth", 2017>]; - let Args = [UnsignedArgument<"Count">]; + let Args = [ExprArgument<"Count">]; let Documentation = [Undocumented]; } @@ -1185,7 +1185,7 @@ def HLSLHitObject : InheritableAttr { def HLSLMaxRecords : InheritableAttr { let Spellings = [CXX11<"", "MaxRecords", 2015>]; - let Args = [IntArgument<"maxCount">]; + let Args = [ExprArgument<"maxCount">]; let Documentation = [Undocumented]; } diff --git a/tools/clang/lib/AST/ExprConstant.cpp b/tools/clang/lib/AST/ExprConstant.cpp index 69e0760bce..283a1bae4a 100644 --- a/tools/clang/lib/AST/ExprConstant.cpp +++ b/tools/clang/lib/AST/ExprConstant.cpp @@ -9451,6 +9451,22 @@ bool Expr::isIntegerConstantExpr(llvm::APSInt &Value, const ASTContext &Ctx, return true; } +bool Expr::isVulkanSpecConstantExpr(const ASTContext &Ctx, + APValue *Result) const { + if (auto *D = dyn_cast(this)) { + if (auto *V = dyn_cast(D->getDecl())) { + if (V->hasAttr()) { + if (const Expr *I = V->getAnyInitializer()) { + if (!I->isCXX11ConstantExpr(Ctx, Result)) + return false; + } + return true; + } + } + } + return false; +} + bool Expr::isCXX98IntegralConstantExpr(const ASTContext &Ctx) const { return CheckICE(this, Ctx).Kind == IK_ICE; } diff --git a/tools/clang/lib/CodeGen/CGHLSLMS.cpp b/tools/clang/lib/CodeGen/CGHLSLMS.cpp index 16ddeaec60..8785b9c416 100644 --- a/tools/clang/lib/CodeGen/CGHLSLMS.cpp +++ b/tools/clang/lib/CodeGen/CGHLSLMS.cpp @@ -325,6 +325,19 @@ class CGMSHLSLRuntime : public CGHLSLRuntime { }; } // namespace +static uint32_t GetIntConstAttrArg(ASTContext &astContext, const Expr *expr, + uint32_t defaultVal = 0) { + if (expr) { + llvm::APSInt apsInt; + APValue apValue; + if (expr->isIntegerConstantExpr(apsInt, astContext)) + return (uint32_t)apsInt.getSExtValue(); + if (expr->isVulkanSpecConstantExpr(astContext, &apValue) && apValue.isInt()) + return (uint32_t)apValue.getInt().getSExtValue(); + } + return defaultVal; +} + //------------------------------------------------------------------------------ // // CGMSHLSLRuntime methods. @@ -1419,6 +1432,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) { } DiagnosticsEngine &Diags = CGM.getDiags(); + ASTContext &astContext = CGM.getTypes().getContext(); std::unique_ptr funcProps = llvm::make_unique(); @@ -1629,10 +1643,9 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) { // Populate numThreads if (const HLSLNumThreadsAttr *Attr = FD->getAttr()) { - - funcProps->numThreads[0] = Attr->getX(); - funcProps->numThreads[1] = Attr->getY(); - funcProps->numThreads[2] = Attr->getZ(); + funcProps->numThreads[0] = GetIntConstAttrArg(astContext, Attr->getX(), 1); + funcProps->numThreads[1] = GetIntConstAttrArg(astContext, Attr->getY(), 1); + funcProps->numThreads[2] = GetIntConstAttrArg(astContext, Attr->getZ(), 1); if (isEntry && !SM->IsCS() && !SM->IsMS() && !SM->IsAS()) { unsigned DiagID = Diags.getCustomDiagID( @@ -1805,7 +1818,8 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) { if (const auto *pAttr = FD->getAttr()) { funcProps->NodeShaderID.Name = pAttr->getName().str(); - funcProps->NodeShaderID.Index = pAttr->getArrayIndex(); + funcProps->NodeShaderID.Index = + GetIntConstAttrArg(astContext, pAttr->getArrayIndex(), 0); } else { funcProps->NodeShaderID.Name = FD->getName().str(); funcProps->NodeShaderID.Index = 0; @@ -1816,20 +1830,28 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) { } if (const auto *pAttr = FD->getAttr()) { funcProps->NodeShaderSharedInput.Name = pAttr->getName().str(); - funcProps->NodeShaderSharedInput.Index = pAttr->getArrayIndex(); + funcProps->NodeShaderSharedInput.Index = + GetIntConstAttrArg(astContext, pAttr->getArrayIndex(), 0); } if (const auto *pAttr = FD->getAttr()) { - funcProps->Node.DispatchGrid[0] = pAttr->getX(); - funcProps->Node.DispatchGrid[1] = pAttr->getY(); - funcProps->Node.DispatchGrid[2] = pAttr->getZ(); + funcProps->Node.DispatchGrid[0] = + GetIntConstAttrArg(astContext, pAttr->getX(), 1); + funcProps->Node.DispatchGrid[1] = + GetIntConstAttrArg(astContext, pAttr->getY(), 1); + funcProps->Node.DispatchGrid[2] = + GetIntConstAttrArg(astContext, pAttr->getZ(), 1); } if (const auto *pAttr = FD->getAttr()) { - funcProps->Node.MaxDispatchGrid[0] = pAttr->getX(); - funcProps->Node.MaxDispatchGrid[1] = pAttr->getY(); - funcProps->Node.MaxDispatchGrid[2] = pAttr->getZ(); + funcProps->Node.MaxDispatchGrid[0] = + GetIntConstAttrArg(astContext, pAttr->getX(), 1); + funcProps->Node.MaxDispatchGrid[1] = + GetIntConstAttrArg(astContext, pAttr->getY(), 1); + funcProps->Node.MaxDispatchGrid[2] = + GetIntConstAttrArg(astContext, pAttr->getZ(), 1); } if (const auto *pAttr = FD->getAttr()) { - funcProps->Node.MaxRecursionDepth = pAttr->getCount(); + funcProps->Node.MaxRecursionDepth = + GetIntConstAttrArg(astContext, pAttr->getCount(), 0); } if (!FD->getAttr()) { // NumThreads wasn't specified. @@ -2343,8 +2365,9 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) { NodeInputRecordParams[ArgIt].MetadataIdx = NodeInputParamIdx++; if (parmDecl->hasAttr()) { - node.MaxRecords = - parmDecl->getAttr()->getMaxCount(); + node.MaxRecords = GetIntConstAttrArg( + astContext, + parmDecl->getAttr()->getMaxCount(), 1); } if (parmDecl->hasAttr()) node.Flags.SetGloballyCoherent(); @@ -2375,7 +2398,8 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) { // OutputID from attribute if (const auto *Attr = parmDecl->getAttr()) { node.OutputID.Name = Attr->getName().str(); - node.OutputID.Index = Attr->getArrayIndex(); + node.OutputID.Index = + GetIntConstAttrArg(astContext, Attr->getArrayIndex(), 0); } else { node.OutputID.Name = parmDecl->getName().str(); node.OutputID.Index = 0; @@ -2434,7 +2458,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) { node.MaxRecordsSharedWith = ix; } if (const auto *Attr = parmDecl->getAttr()) - node.MaxRecords = Attr->getMaxCount(); + node.MaxRecords = GetIntConstAttrArg(astContext, Attr->getMaxCount(), 0); } if (inputPatchCount > 1) { diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.cpp b/tools/clang/lib/SPIRV/SpirvEmitter.cpp index 7337a33b01..3a526182ad 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.cpp +++ b/tools/clang/lib/SPIRV/SpirvEmitter.cpp @@ -13246,6 +13246,31 @@ void SpirvEmitter::processInlineSpirvAttributes(const FunctionDecl *decl) { } } +bool SpirvEmitter::processNumThreadsAttr(const FunctionDecl *decl) { + auto *numThreadsAttr = decl->getAttr(); + if (!numThreadsAttr) + return false; + + auto f = [](ASTContext &Ctx, Expr *E) { + if (E) { + llvm::APSInt apsInt; + APValue apValue; + if (E->isIntegerConstantExpr(apsInt, Ctx)) + return (uint32_t)apsInt.getSExtValue(); + if (E->isVulkanSpecConstantExpr(Ctx, &apValue) && apValue.isInt()) + return (uint32_t)apValue.getInt().getSExtValue(); + } + return 1U; + }; + + spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize, + {f(astContext, numThreadsAttr->getX()), + f(astContext, numThreadsAttr->getY()), + f(astContext, numThreadsAttr->getZ())}, + decl->getLocation()); + return true; +} + bool SpirvEmitter::processGeometryShaderAttributes(const FunctionDecl *decl, uint32_t *arraySize) { bool success = true; @@ -13422,15 +13447,9 @@ void SpirvEmitter::processPixelShaderAttributes(const FunctionDecl *decl) { } void SpirvEmitter::processComputeShaderAttributes(const FunctionDecl *decl) { - auto *numThreadsAttr = decl->getAttr(); - assert(numThreadsAttr && "thread group size missing from entry-point"); - - uint32_t x = static_cast(numThreadsAttr->getX()); - uint32_t y = static_cast(numThreadsAttr->getY()); - uint32_t z = static_cast(numThreadsAttr->getZ()); - - spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize, - {x, y, z}, decl->getLocation()); + if (!processNumThreadsAttr(decl)) { + assert(false && "thread group size missing from entry-point"); + } auto *waveSizeAttr = decl->getAttr(); if (waveSizeAttr) { @@ -13651,14 +13670,7 @@ bool SpirvEmitter::emitEntryFunctionWrapperForRayTracing( bool SpirvEmitter::processMeshOrAmplificationShaderAttributes( const FunctionDecl *decl, uint32_t *outVerticesArraySize) { - if (auto *numThreadsAttr = decl->getAttr()) { - uint32_t x, y, z; - x = static_cast(numThreadsAttr->getX()); - y = static_cast(numThreadsAttr->getY()); - z = static_cast(numThreadsAttr->getZ()); - spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize, - {x, y, z}, decl->getLocation()); - } + processNumThreadsAttr(decl); // Early return for amplification shaders as they only take the 'numthreads' // attribute. diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.h b/tools/clang/lib/SPIRV/SpirvEmitter.h index 6c1e12989c..d4a7e4de2a 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.h +++ b/tools/clang/lib/SPIRV/SpirvEmitter.h @@ -838,6 +838,8 @@ class SpirvEmitter : public ASTConsumer { /// \brief Handle inline SPIR-V attributes for the entry function. void processInlineSpirvAttributes(const FunctionDecl *entryFunction); + bool processNumThreadsAttr(const FunctionDecl *decl); + /// \brief Adds necessary execution modes for the hull/domain shaders based on /// the HLSL attributes of the entry point function. /// In the case of hull shaders, also writes the number of output control diff --git a/tools/clang/lib/Sema/SemaHLSL.cpp b/tools/clang/lib/Sema/SemaHLSL.cpp index b15068638d..bbcfa53153 100644 --- a/tools/clang/lib/Sema/SemaHLSL.cpp +++ b/tools/clang/lib/Sema/SemaHLSL.cpp @@ -11723,21 +11723,37 @@ void Sema::CheckHLSLFunctionCall(FunctionDecl *FDecl, CallExpr *TheCall, } } +static uint32_t GetIntConstAttrArg(ASTContext &astContext, const Expr *expr, + uint32_t defaultVal = 0) { + if (expr) { + llvm::APSInt apsInt; + APValue apValue; + if (expr->isIntegerConstantExpr(apsInt, astContext)) + return (uint32_t)apsInt.getSExtValue(); + if (expr->isVulkanSpecConstantExpr(astContext, &apValue) && apValue.isInt()) + return (uint32_t)apValue.getInt().getSExtValue(); + } + return defaultVal; +} + ///////////////////////////////////////////////////////////////////////////// // Check HLSL intrinsic calls reachable from entry/export functions. -static void DiagnoseNumThreadsForDerivativeOp(const HLSLNumThreadsAttr *Attr, - SourceLocation LocDeriv, - FunctionDecl *FD, - const FunctionDecl *EntryDecl, - DiagnosticsEngine &Diags) { +static void DiagnoseNumThreadsForDerivativeOp( + Sema &S, const HLSLNumThreadsAttr *Attr, SourceLocation LocDeriv, + FunctionDecl *FD, const FunctionDecl *EntryDecl, DiagnosticsEngine &Diags) { bool invalidNumThreads = false; - if (Attr->getY() != 1) { + ASTContext &astContext = S.getASTContext(); + uint32_t x = GetIntConstAttrArg(astContext, Attr->getX(), 1); + uint32_t y = GetIntConstAttrArg(astContext, Attr->getY(), 1); + uint32_t z = GetIntConstAttrArg(astContext, Attr->getZ(), 1); + + if (y != 1) { // 2D mode requires x and y to be multiple of 2. - invalidNumThreads = !((Attr->getX() % 2) == 0 && (Attr->getY() % 2) == 0); + invalidNumThreads = !((x % 2) == 0 && (y % 2) == 0); } else { // 1D mode requires x to be multiple of 4 and y and z to be 1. - invalidNumThreads = (Attr->getX() % 4) != 0 || (Attr->getZ() != 1); + invalidNumThreads = (x % 4) != 0 || (z != 1); } if (invalidNumThreads) { Diags.Report(LocDeriv, diag::warn_hlsl_derivatives_wrong_numthreads) @@ -11783,7 +11799,7 @@ static void DiagnoseDerivativeOp(Sema &S, FunctionDecl *FD, SourceLocation Loc, if (const HLSLNumThreadsAttr *Attr = EntryDecl->getAttr()) { - DiagnoseNumThreadsForDerivativeOp(Attr, Loc, FD, EntryDecl, Diags); + DiagnoseNumThreadsForDerivativeOp(S, Attr, Loc, FD, EntryDecl, Diags); } } @@ -13180,12 +13196,12 @@ FlattenedTypeIterator::CompareTypesForInit(HLSLExternalSource &source, //////////////////////////////////////////////////////////////////////////////// // Attribute processing support. // -static int ValidateAttributeIntArg(Sema &S, const AttributeList &Attr, - unsigned index = 0) { - int64_t value = 0; +static Expr *ValidateAttributeIntArgExpr(Sema &S, const AttributeList &Attr, + unsigned index, int64_t *value, + bool allowDefinedConstant = false) { + Expr *E = nullptr; if (Attr.getNumArgs() > index) { - Expr *E = nullptr; if (!Attr.isArgExpr(index)) { // For case arg is constant variable. IdentifierLoc *loc = Attr.getArgAsIdent(index); @@ -13196,13 +13212,13 @@ static int ValidateAttributeIntArg(Sema &S, const AttributeList &Attr, if (!decl) { S.Diag(Attr.getLoc(), diag::warn_hlsl_attribute_expects_uint_literal) << Attr.getName(); - return value; + return nullptr; } Expr *init = decl->getInit(); if (!init) { S.Diag(Attr.getLoc(), diag::warn_hlsl_attribute_expects_uint_literal) << Attr.getName(); - return value; + return nullptr; } E = init; } else @@ -13212,11 +13228,13 @@ static int ValidateAttributeIntArg(Sema &S, const AttributeList &Attr, bool displayError = false; if (E->isTypeDependent() || E->isValueDependent() || !E->isCXX11ConstantExpr(S.Context, &ArgNum)) { - displayError = true; + displayError = + !allowDefinedConstant || + !(E->isVulkanSpecConstantExpr(S.Context, &ArgNum) && ArgNum.isInt()); } else { if (ArgNum.isInt()) { - value = ArgNum.getInt().getSExtValue(); - if (!(E->getType()->isIntegralOrEnumerationType()) || value < 0) { + *value = ArgNum.getInt().getSExtValue(); + if (!(E->getType()->isIntegralOrEnumerationType()) || *value < 0) { S.Diag(Attr.getLoc(), diag::warn_hlsl_attribute_expects_uint_literal) << Attr.getName(); } @@ -13226,8 +13244,8 @@ static int ValidateAttributeIntArg(Sema &S, const AttributeList &Attr, if (ArgNum.getFloat().convertToInteger( floatInt, llvm::APFloat::rmTowardZero, &isPrecise) == llvm::APFloat::opStatus::opOK) { - value = floatInt.getSExtValue(); - if (value < 0) { + *value = floatInt.getSExtValue(); + if (*value < 0) { S.Diag(Attr.getLoc(), diag::warn_hlsl_attribute_expects_uint_literal) << Attr.getName(); @@ -13245,9 +13263,23 @@ static int ValidateAttributeIntArg(Sema &S, const AttributeList &Attr, S.Diag(Attr.getLoc(), diag::err_attribute_argument_type) << Attr.getName() << AANT_ArgumentIntegerConstant << E->getSourceRange(); + return nullptr; } } + return E; +} + +static Expr *ValidateAttributeIntArgExpr(Sema &S, const AttributeList &Attr, + unsigned index = 0) { + int64_t value = 0; + return ValidateAttributeIntArgExpr(S, Attr, index, &value, true); +} + +static int ValidateAttributeIntArg(Sema &S, const AttributeList &Attr, + unsigned index = 0) { + int64_t value = 0; + ValidateAttributeIntArgExpr(S, Attr, index, &value); return (int)value; } @@ -13593,19 +13625,27 @@ HLSLMaxRecordsAttr *ValidateMaxRecordsAttributes(Sema &S, Decl *D, Expr *ArgExpr = A.getArgAsExpr(0); IntegerLiteral *LiteralInt = dyn_cast(ArgExpr->IgnoreParenCasts()); + clang::SourceLocation Loc = {}; - if (ExistingMRSWA || ExistingMRA->getMaxCount() != LiteralInt->getValue()) { - clang::SourceLocation Loc = ExistingMRA ? ExistingMRA->getLocation() - : ExistingMRSWA->getLocation(); + if (ExistingMRSWA) { + Loc = ExistingMRSWA->getLocation(); + } else if (ExistingMRA) { + uint32_t maxCount = + GetIntConstAttrArg(S.getASTContext(), ExistingMRA->getMaxCount(), 0); + if (LiteralInt->getValue().getLimitedValue() != maxCount) + Loc = ExistingMRA->getLocation(); + } + + if (Loc.isValid()) { S.Diag(A.getLoc(), diag::err_hlsl_maxrecord_attrs_on_same_arg); S.Diag(Loc, diag::note_conflicting_attribute); return nullptr; } } - return ::new (S.Context) - HLSLMaxRecordsAttr(A.getRange(), S.Context, ValidateAttributeIntArg(S, A), - A.getAttributeSpellingListIndex()); + return ::new (S.Context) HLSLMaxRecordsAttr( + A.getRange(), S.Context, ValidateAttributeIntArgExpr(S, A), + A.getAttributeSpellingListIndex()); } // This function validates the wave size attribute in a stand-alone way, @@ -13814,19 +13854,20 @@ void Sema::DiagnoseCoherenceMismatch(const Expr *SrcExpr, QualType TargetType, } } -void ValidateDispatchGridValues(DiagnosticsEngine &Diags, - const AttributeList &A, Attr *declAttr) { +void ValidateDispatchGridValues(Sema &S, const AttributeList &A, + Attr *declAttr) { unsigned x = 1, y = 1, z = 1; + ASTContext &astContext = S.getASTContext(); if (HLSLNodeDispatchGridAttr *pA = dyn_cast(declAttr)) { - x = pA->getX(); - y = pA->getY(); - z = pA->getZ(); + x = GetIntConstAttrArg(astContext, pA->getX(), 1); + y = GetIntConstAttrArg(astContext, pA->getY(), 1); + z = GetIntConstAttrArg(astContext, pA->getZ(), 1); } else if (HLSLNodeMaxDispatchGridAttr *pA = dyn_cast(declAttr)) { - x = pA->getX(); - y = pA->getY(); - z = pA->getZ(); + x = GetIntConstAttrArg(astContext, pA->getX(), 1); + y = GetIntConstAttrArg(astContext, pA->getY(), 1); + z = GetIntConstAttrArg(astContext, pA->getZ(), 1); } else { llvm_unreachable("ValidateDispatchGridValues() called for wrong attribute"); } @@ -13835,26 +13876,26 @@ void ValidateDispatchGridValues(DiagnosticsEngine &Diags, // If a component is out of range, we reset it to 0 to avoid also generating // a secondary error if the product would be out of range if (x < 1 || x > MaxComponentValue) { - Diags.Report(A.getArgAsExpr(0)->getExprLoc(), - diag::err_hlsl_dispatchgrid_component) + S.Diags.Report(A.getArgAsExpr(0)->getExprLoc(), + diag::err_hlsl_dispatchgrid_component) << A.getName() << "X" << A.getRange(); x = 0; } if (y < 1 || y > MaxComponentValue) { - Diags.Report(A.getArgAsExpr(1)->getExprLoc(), - diag::err_hlsl_dispatchgrid_component) + S.Diags.Report(A.getArgAsExpr(1)->getExprLoc(), + diag::err_hlsl_dispatchgrid_component) << A.getName() << "Y" << A.getRange(); y = 0; } if (z < 1 || z > MaxComponentValue) { - Diags.Report(A.getArgAsExpr(2)->getExprLoc(), - diag::err_hlsl_dispatchgrid_component) + S.Diags.Report(A.getArgAsExpr(2)->getExprLoc(), + diag::err_hlsl_dispatchgrid_component) << A.getName() << "Z" << A.getRange(); z = 0; } uint64_t product = (uint64_t)x * (uint64_t)y * (uint64_t)z; if (product > MaxProductValue) - Diags.Report(A.getLoc(), diag::err_hlsl_dispatchgrid_product) + S.Diags.Report(A.getLoc(), diag::err_hlsl_dispatchgrid_product) << A.getName() << A.getRange(); } @@ -14014,7 +14055,8 @@ void hlsl::HandleDeclAttributeForHLSL(Sema &S, Decl *D, const AttributeList &A, case AttributeList::AT_HLSLNodeId: declAttr = ::new (S.Context) HLSLNodeIdAttr( A.getRange(), S.Context, ValidateAttributeStringArg(S, A, nullptr, 0), - ValidateAttributeIntArg(S, A, 1), A.getAttributeSpellingListIndex()); + ValidateAttributeIntArgExpr(S, A, 1), + A.getAttributeSpellingListIndex()); break; case AttributeList::AT_HLSLNodeTrackRWInputSharing: declAttr = ::new (S.Context) HLSLNodeTrackRWInputSharingAttr( @@ -14117,18 +14159,20 @@ void hlsl::HandleDeclAttributeForHLSL(Sema &S, Decl *D, const AttributeList &A, A.getAttributeSpellingListIndex()); break; case AttributeList::AT_HLSLNumThreads: { - int X = ValidateAttributeIntArg(S, A, 0); - int Y = ValidateAttributeIntArg(S, A, 1); - int Z = ValidateAttributeIntArg(S, A, 2); - int N = X * Y * Z; + int64_t X = 1, Y = 1, Z = 1; + auto *arg0 = ValidateAttributeIntArgExpr(S, A, 0, &X, true); + auto *arg1 = ValidateAttributeIntArgExpr(S, A, 1, &Y, true); + auto *arg2 = ValidateAttributeIntArgExpr(S, A, 2, &Z, true); + int64_t N = X * Y * Z; if (N > 0 && N <= 1024) { - auto numThreads = ::new (S.Context) HLSLNumThreadsAttr( - A.getRange(), S.Context, X, Y, Z, A.getAttributeSpellingListIndex()); + auto numThreads = ::new (S.Context) + HLSLNumThreadsAttr(A.getRange(), S.Context, arg0, arg1, arg2, + A.getAttributeSpellingListIndex()); declAttr = numThreads; } else { // If the number of threads is invalid, diagnose and drop the attribute. S.Diags.Report(A.getLoc(), diag::warn_hlsl_numthreads_group_size) - << N << X << Y << Z << A.getRange(); + << (int)N << (int)X << (int)Y << (int)Z << A.getRange(); return; } break; @@ -14230,31 +14274,37 @@ void hlsl::HandleDeclAttributeForHLSL(Sema &S, Decl *D, const AttributeList &A, case AttributeList::AT_HLSLNodeShareInputOf: declAttr = ::new (S.Context) HLSLNodeShareInputOfAttr( A.getRange(), S.Context, ValidateAttributeStringArg(S, A, nullptr, 0), - ValidateAttributeIntArg(S, A, 1), A.getAttributeSpellingListIndex()); + ValidateAttributeIntArgExpr(S, A, 1), + A.getAttributeSpellingListIndex()); break; case AttributeList::AT_HLSLNodeDispatchGrid: declAttr = ::new (S.Context) HLSLNodeDispatchGridAttr( - A.getRange(), S.Context, ValidateAttributeIntArg(S, A), - ValidateAttributeIntArg(S, A, 1), ValidateAttributeIntArg(S, A, 2), + A.getRange(), S.Context, ValidateAttributeIntArgExpr(S, A), + ValidateAttributeIntArgExpr(S, A, 1), + ValidateAttributeIntArgExpr(S, A, 2), A.getAttributeSpellingListIndex()); - ValidateDispatchGridValues(S.Diags, A, declAttr); + ValidateDispatchGridValues(S, A, declAttr); break; case AttributeList::AT_HLSLNodeMaxDispatchGrid: declAttr = ::new (S.Context) HLSLNodeMaxDispatchGridAttr( - A.getRange(), S.Context, ValidateAttributeIntArg(S, A), - ValidateAttributeIntArg(S, A, 1), ValidateAttributeIntArg(S, A, 2), + A.getRange(), S.Context, ValidateAttributeIntArgExpr(S, A), + ValidateAttributeIntArgExpr(S, A, 1), + ValidateAttributeIntArgExpr(S, A, 2), A.getAttributeSpellingListIndex()); - ValidateDispatchGridValues(S.Diags, A, declAttr); + ValidateDispatchGridValues(S, A, declAttr); break; - case AttributeList::AT_HLSLNodeMaxRecursionDepth: + case AttributeList::AT_HLSLNodeMaxRecursionDepth: { + int64_t maxRecursionDepth = 0; declAttr = ::new (S.Context) HLSLNodeMaxRecursionDepthAttr( - A.getRange(), S.Context, ValidateAttributeIntArg(S, A), + A.getRange(), S.Context, + ValidateAttributeIntArgExpr(S, A, 0, &maxRecursionDepth, true), A.getAttributeSpellingListIndex()); - if (cast(declAttr)->getCount() > 32) + if (maxRecursionDepth > 32) S.Diags.Report(declAttr->getLocation(), diag::err_hlsl_maxrecursiondepth_exceeded) << declAttr->getRange(); break; + } default: Handled = false; break; // SPIRV Change: was return; @@ -15646,8 +15696,13 @@ void hlsl::CustomPrintHLSLAttr(const clang::Attr *A, llvm::raw_ostream &Out, Attr *noconst = const_cast(A); HLSLNumThreadsAttr *ACast = static_cast(noconst); Indent(Indentation, Out); - Out << "[numthreads(" << ACast->getX() << ", " << ACast->getY() << ", " - << ACast->getZ() << ")]\n"; + Out << "[numthreads("; + ACast->getX()->printPretty(Out, nullptr, Policy); + Out << ", "; + ACast->getY()->printPretty(Out, nullptr, Policy); + Out << ", "; + ACast->getZ()->printPretty(Out, nullptr, Policy); + Out << ")]\n"; break; } @@ -15879,11 +15934,16 @@ void hlsl::CustomPrintHLSLAttr(const clang::Attr *A, llvm::raw_ostream &Out, Attr *noconst = const_cast(A); HLSLNodeIdAttr *ACast = static_cast(noconst); Indent(Indentation, Out); - if (ACast->getArrayIndex() > 0) - Out << "[NodeId(\"" << ACast->getName() << "\"," << ACast->getArrayIndex() - << ")]\n"; - else - Out << "[NodeId(\"" << ACast->getName() << "\")]\n"; + Out << "[NodeId(\"" << ACast->getName(); + if (auto *lit = dyn_cast(ACast->getArrayIndex())) { + if (!lit->getValue().isStrictlyPositive()) { + Out << "\")]\n"; + break; + } + } + Out << "\","; + ACast->getArrayIndex()->printPretty(Out, nullptr, Policy); + Out << ")]\n"; break; } @@ -15901,11 +15961,16 @@ void hlsl::CustomPrintHLSLAttr(const clang::Attr *A, llvm::raw_ostream &Out, HLSLNodeShareInputOfAttr *ACast = static_cast(noconst); Indent(Indentation, Out); - if (ACast->getArrayIndex() > 0) - Out << "[NodeShareInputOf(\"" << ACast->getName() << "\"," - << ACast->getArrayIndex() << ")]\n"; - else - Out << "[NodeShareInputOf(\"" << ACast->getName() << "\")]\n"; + Out << "[NodeShareInputOf(\"" << ACast->getName(); + if (auto *lit = dyn_cast(ACast->getArrayIndex())) { + if (!lit->getValue().isStrictlyPositive()) { + Out << "\")]\n"; + break; + } + } + Out << "\","; + ACast->getArrayIndex()->printPretty(Out, nullptr, Policy); + Out << ")]\n"; break; } @@ -16549,8 +16614,11 @@ void DiagnoseNodeEntry(Sema &S, FunctionDecl *FD, llvm::StringRef StageName, // thread group size is (1,1,1) if (NodeLaunchTy == DXIL::NodeLaunchType::Thread) { if (auto NumThreads = FD->getAttr()) { - if (NumThreads->getX() != 1 || NumThreads->getY() != 1 || - NumThreads->getZ() != 1) { + ASTContext &astContext = S.getASTContext(); + uint32_t x = GetIntConstAttrArg(astContext, NumThreads->getX(), 1); + uint32_t y = GetIntConstAttrArg(astContext, NumThreads->getY(), 1); + uint32_t z = GetIntConstAttrArg(astContext, NumThreads->getZ(), 1); + if (x != 1 || y != 1 || z != 1) { S.Diags.Report(NumThreads->getLocation(), diag::err_hlsl_wg_thread_launch_group_size) << NumThreads->getRange(); diff --git a/tools/clang/test/HLSLFileCheck/hlsl/workgraph/sparsenodes.hlsl b/tools/clang/test/HLSLFileCheck/hlsl/workgraph/sparsenodes.hlsl index 10335ee864..d5645e2849 100644 --- a/tools/clang/test/HLSLFileCheck/hlsl/workgraph/sparsenodes.hlsl +++ b/tools/clang/test/HLSLFileCheck/hlsl/workgraph/sparsenodes.hlsl @@ -10,34 +10,40 @@ // AST: FunctionDecl {{.*}} node_1_0 'void (NodeOutputArray)' // AST-NEXT: ParmVarDecl {{.*}} used OutputArray_1_0 'NodeOutputArray':'NodeOutputArray' -// AST-NEXT: HLSLMaxRecordsAttr {{.*}} 31 +// AST-NEXT: HLSLMaxRecordsAttr {{.*}} +// AST-NEXT: IntegerLiteral {{.*}} 31 // AST-NEXT: HLSLNodeArraySizeAttr {{.*}} 129 // AST-NEXT: HLSLAllowSparseNodesAttr // AST: FunctionDecl {{.*}} node_1_1 'void (NodeOutputArray)' // AST-NEXT: ParmVarDecl {{.*}} used OutputArray_1_1 'NodeOutputArray':'NodeOutputArray' -// AST-NEXT: HLSLMaxRecordsAttr {{.*}} 37 +// AST-NEXT: HLSLMaxRecordsAttr {{.*}} +// AST-NEXT: IntegerLiteral {{.*}} 37 // AST-NEXT: HLSLUnboundedSparseNodesAttr // AST: FunctionDecl {{.*}} node_1_2 'void (NodeOutput)' // AST-NEXT: ParmVarDecl {{.*}} used Output_1_2 'NodeOutput':'NodeOutput' -// AST-NEXT: HLSLMaxRecordsAttr {{.*}} 47 +// AST-NEXT: HLSLMaxRecordsAttr {{.*}} +// AST-NEXT: IntegerLiteral {{.*}} 47 // AST-NEXT: HLSLAllowSparseNodesAttr // AST: FunctionDecl {{.*}} node_2_0 'void (EmptyNodeOutputArray)' // AST-NEXT: ParmVarDecl {{.*}} used OutputArray_2_0 'EmptyNodeOutputArray' -// AST-NEXT: HLSLMaxRecordsAttr {{.*}} 41 +// AST-NEXT: HLSLMaxRecordsAttr {{.*}} +// AST-NEXT: IntegerLiteral {{.*}} 41 // AST-NEXT: HLSLNodeArraySizeAttr {{.*}} 131 // AST-NEXT: HLSLAllowSparseNodesAttr // AST: FunctionDecl {{.*}} node_2_1 'void (EmptyNodeOutputArray)' // AST-NEXT: ParmVarDecl {{.*}} used OutputArray_2_1 'EmptyNodeOutputArray' -// AST-NEXT: HLSLMaxRecordsAttr {{.*}} 43 +// AST-NEXT: HLSLMaxRecordsAttr {{.*}} +// AST-NEXT: IntegerLiteral {{.*}} 43 // AST-NEXT: HLSLUnboundedSparseNodesAttr // AST: FunctionDecl {{.*}} node_2_2 'void (EmptyNodeOutput)' // AST-NEXT: ParmVarDecl {{.*}} used Output_2_2 'EmptyNodeOutput' -// AST-NEXT: HLSLMaxRecordsAttr {{.*}} 53 +// AST-NEXT: HLSLMaxRecordsAttr {{.*}} +// AST-NEXT: IntegerLiteral {{.*}} 53 // AST-NEXT: HLSLAllowSparseNodesAttr // ==== -fcgl Metadata Checks ==== diff --git a/tools/clang/test/HLSLFileCheck/shader_targets/nodes/NodeOutput.hlsl b/tools/clang/test/HLSLFileCheck/shader_targets/nodes/NodeOutput.hlsl index 9343ad9831..6c69441468 100644 --- a/tools/clang/test/HLSLFileCheck/shader_targets/nodes/NodeOutput.hlsl +++ b/tools/clang/test/HLSLFileCheck/shader_targets/nodes/NodeOutput.hlsl @@ -23,20 +23,28 @@ struct MY_INPUT_RECORD // CHECK:FunctionDecl 0x{{.*}} myFancyNode 'void (GroupNodeInputRecords, NodeOutput, NodeOutput, NodeOutputArray, EmptyNodeOutput)' // CHECK-NEXT:ParmVarDecl 0x{{.*}} myInput 'GroupNodeInputRecords':'GroupNodeInputRecords' -// CHECK-NEXT: HLSLMaxRecordsAttr 0x{{.*}} 4 +// CHECK-NEXT: HLSLMaxRecordsAttr 0x{{.*}} +// CHECK-NEXT: IntegerLiteral 0x{{.*}} 4 // CHECK-NEXT: ParmVarDecl 0x{{.*}} myFascinatingNode 'NodeOutput':'NodeOutput' -// CHECK-NEXT: HLSLMaxRecordsAttr 0x{{.*}} 4 +// CHECK-NEXT: HLSLMaxRecordsAttr 0x{{.*}} +// CHECK-NEXT: IntegerLiteral 0x{{.*}} 4 // CHECK-NEXT: ParmVarDecl 0x{{.*}} myRecords 'NodeOutput':'NodeOutput' -// CHECK-NEXT: HLSLMaxRecordsAttr 0x{{.*}} 4 -// CHECK-NEXT: HLSLNodeIdAttr 0x{{.*}} "myNiftyNode" 3 +// CHECK-NEXT: HLSLMaxRecordsAttr 0x{{.*}} +// CHECK-NEXT: IntegerLiteral 0x{{.*}} 4 +// CHECK-NEXT: HLSLNodeIdAttr 0x{{.*}} "myNiftyNode" +// CHECK-NEXT: IntegerLiteral 0x{{.*}} 3 // CHECK-NEXT: ParmVarDecl 0x{{.*}} col:65 myMaterials 'NodeOutputArray' // CHECK-NEXT:HLSLNodeArraySizeAttr 0x{{.*}} 63 // CHECK-NEXT:HLSLAllowSparseNodesAttr 0x{{.*}} // CHECK-NEXT:HLSLMaxRecordsSharedWithAttr 0x{{.*}} myRecords // CHECK-NEXT:ParmVarDecl 0x{{.*}} myProgressCounter 'EmptyNodeOutput' -// CHECK-NEXT: HLSLMaxRecordsAttr 0x{{.*}} 20 +// CHECK-NEXT: HLSLMaxRecordsAttr 0x{{.*}} +// CHECK-NEXT: IntegerLiteral 0x{{.*}} 20 // CHECK-NEXT: CompoundStmt 0x -// CHECK-NEXT: HLSLNumThreadsAttr 0x{{.*}} 4 5 6 +// CHECK-NEXT: HLSLNumThreadsAttr 0x{{.*}} +// CHECK-NEXT: IntegerLiteral 0x{{.*}} 4 +// CHECK-NEXT: IntegerLiteral 0x{{.*}} 5 +// CHECK-NEXT: IntegerLiteral 0x{{.*}} 6 // CHECK-NEXT: HLSLNodeLaunchAttr 0x{{.*}} "coalescing" // CHECK-NEXT: HLSLShaderAttr 0x{{.*}} "node" [Shader("node")] diff --git a/tools/clang/test/SemaHLSL/hlsl/objects/NodeObjects/node-objects.hlsl b/tools/clang/test/SemaHLSL/hlsl/objects/NodeObjects/node-objects.hlsl index 2fd316bba6..5d2a684178 100644 --- a/tools/clang/test/SemaHLSL/hlsl/objects/NodeObjects/node-objects.hlsl +++ b/tools/clang/test/SemaHLSL/hlsl/objects/NodeObjects/node-objects.hlsl @@ -40,8 +40,14 @@ struct RECORD // AST: `-DeclRefExpr 0x{{.+}} 'DispatchNodeInputRecord':'DispatchNodeInputRecord' lvalue ParmVar 0x[[Param]] 'input' 'DispatchNodeInputRecord':'DispatchNodeInputRecord' // attributes. // AST: |-HLSLNodeLaunchAttr 0x{{.+}} "broadcasting" -// AST: |-HLSLNodeDispatchGridAttr 0x{{.+}} 64 1 1 -// AST: |-HLSLNumThreadsAttr 0x{{.+}} 1024 1 1 +// AST: |-HLSLNodeDispatchGridAttr 0x{{.+}} +// AST: |-IntegerLiteral 0x{{.+}} 64 +// AST: |-IntegerLiteral 0x{{.+}} 1 +// AST: `-IntegerLiteral 0x{{.+}} 1 +// AST: |-HLSLNumThreadsAttr 0x{{.+}} +// AST: |-IntegerLiteral 0x{{.+}} 1024 +// AST: |-IntegerLiteral 0x{{.+}} 1 +// AST: `-IntegerLiteral 0x{{.+}} 1 // AST: `-HLSLShaderAttr 0x{{.+}} "node" @@ -67,8 +73,14 @@ void node_DispatchNodeInputRecord(DispatchNodeInputRecord input) // AST: `-DeclRefExpr 0x{{.+}} 'RWDispatchNodeInputRecord':'RWDispatchNodeInputRecord' lvalue ParmVar 0x[[Param]] 'input' 'RWDispatchNodeInputRecord':'RWDispatchNodeInputRecord' // attributes. // AST: |-HLSLNodeLaunchAttr 0x{{.+}} "broadcasting" -// AST: |-HLSLNodeDispatchGridAttr 0x{{.+}} 16 1 1 -// AST: |-HLSLNumThreadsAttr 0x{{.+}} 1024 1 1 +// AST: |-HLSLNodeDispatchGridAttr 0x{{.+}} +// AST: |-IntegerLiteral 0x{{.+}} 16 +// AST: |-IntegerLiteral 0x{{.+}} 1 +// AST: `-IntegerLiteral 0x{{.+}} 1 +// AST: |-HLSLNumThreadsAttr 0x{{.+}} +// AST: |-IntegerLiteral 0x{{.+}} 1024 +// AST: |-IntegerLiteral 0x{{.+}} 1 +// AST: `-IntegerLiteral 0x{{.+}} 1 // AST: `-HLSLShaderAttr 0x{{.+}} "node" [Shader("node")] @@ -85,7 +97,8 @@ void node_RWDispatchNodeInputRecord(RWDispatchNodeInputRecord input) // AST: FunctionDecl 0x{{.+}} node_GroupNodeInputRecords 'void (GroupNodeInputRecords)' // AST: |-ParmVarDecl 0x[[Param:[0-9a-f]+]] col:81 used inputs 'GroupNodeInputRecords':'GroupNodeInputRecords' -// AST: | `-HLSLMaxRecordsAttr 0x{{.+}} 256 +// AST: | `-HLSLMaxRecordsAttr 0x{{.+}} +// AST: `-IntegerLiteral 0x{{.+}} 256 // call to wrapper // AST: `-CallExpr 0x{{.+}} 'GroupNodeInputRecords':'GroupNodeInputRecords' // AST: |-ImplicitCastExpr 0x{{.+}} 'GroupNodeInputRecords (*)(GroupNodeInputRecords)' @@ -94,7 +107,10 @@ void node_RWDispatchNodeInputRecord(RWDispatchNodeInputRecord input) // AST: `-DeclRefExpr 0x{{.+}} 'GroupNodeInputRecords':'GroupNodeInputRecords' lvalue ParmVar 0x[[Param]] 'inputs' 'GroupNodeInputRecords':'GroupNodeInputRecords' // attributes. // AST: |-HLSLNodeIsProgramEntryAttr 0x{{.+}} -// AST: |-HLSLNumThreadsAttr 0x{{.+}} 1024 1 1 +// AST: |-HLSLNumThreadsAttr 0x{{.+}} +// AST: |-IntegerLiteral 0x{{.+}} 1024 +// AST: |-IntegerLiteral 0x{{.+}} 1 +// AST: `-IntegerLiteral 0x{{.+}} 1 // AST: |-HLSLNodeLaunchAttr 0x{{.+}} "coalescing" // AST: `-HLSLShaderAttr 0x{{.+}} "node" @@ -112,7 +128,8 @@ void node_GroupNodeInputRecords([MaxRecords(256)] GroupNodeInputRecords // AST: FunctionDecl 0x{{.+}} node_RWGroupNodeInputRecords 'void (RWGroupNodeInputRecords)' // AST: |-ParmVarDecl 0x[[Param:[0-9a-f]+]] col:84 used input2 'RWGroupNodeInputRecords':'RWGroupNodeInputRecords' -// AST: | `-HLSLMaxRecordsAttr 0x{{.+}} 4 +// AST: | `-HLSLMaxRecordsAttr 0x{{.+}} +// AST: | `-IntegerLiteral 0x{{.+}} 4 // call to wrapper // AST: CallExpr 0x{{.+}} 'RWGroupNodeInputRecords':'RWGroupNodeInputRecords' // AST: |-ImplicitCastExpr 0x{{.+}} 'RWGroupNodeInputRecords (*)(RWGroupNodeInputRecords)' @@ -121,7 +138,10 @@ void node_GroupNodeInputRecords([MaxRecords(256)] GroupNodeInputRecords // AST: | `-DeclRefExpr 0x{{.+}} 'RWGroupNodeInputRecords':'RWGroupNodeInputRecords' lvalue ParmVar 0x[[Param]] 'input2' 'RWGroupNodeInputRecords':'RWGroupNodeInputRecords' // attributes. // AST: |-HLSLNodeLaunchAttr 0x{{.+}} "coalescing" -// AST: |-HLSLNumThreadsAttr 0x{{.+}} 1 1 1 +// AST: |-HLSLNumThreadsAttr 0x{{.+}} +// AST: |-IntegerLiteral 0x{{.+}} 1 +// AST: |-IntegerLiteral 0x{{.+}} 1 +// AST: `-IntegerLiteral 0x{{.+}} 1 // AST: `-HLSLShaderAttr 0x{{.+}} "node" struct RECORD2 @@ -192,7 +212,10 @@ void node_RWThreadNodeInputRecord(RWThreadNodeInputRecord input) // AST: `-ImplicitCastExpr 0x{{.+}} 'EmptyNodeInput' // AST: `-DeclRefExpr 0x{{.+}} 'EmptyNodeInput' lvalue ParmVar 0x[[Param]] 'input' 'EmptyNodeInput' // attributes. -// AST: | |-HLSLNumThreadsAttr 0x{{.+}} 2 1 1 +// AST: | |-HLSLNumThreadsAttr 0x{{.+}} +// AST: |-IntegerLiteral 0x{{.+}} 2 +// AST: |-IntegerLiteral 0x{{.+}} 1 +// AST: `-IntegerLiteral 0x{{.+}} 1 // AST: | |-HLSLNodeIsProgramEntryAttr // AST: | |-HLSLNodeLaunchAttr 0x{{.+}} "coalescing" // AST: | `-HLSLShaderAttr 0x{{.+}}> "node" @@ -219,8 +242,14 @@ void node_EmptyNodeInput(EmptyNodeInput input) // AST: `-DeclRefExpr 0x{{.+}} 'NodeOutput':'NodeOutput' lvalue ParmVar 0x[[Param]] 'output3' 'NodeOutput':'NodeOutput' // attributes. // AST: | |-HLSLNodeLaunchAttr 0x{{.+}} "broadcasting" -// AST: | |-HLSLNodeDispatchGridAttr 0x{{.+}} 32 1 1 -// AST: | |-HLSLNumThreadsAttr 0x{{.+}} 1024 1 1 +// AST: | |-HLSLNodeDispatchGridAttr 0x{{.+}} +// AST: | | |-IntegerLiteral 0x{{.+}} 32 +// AST: | | |-IntegerLiteral 0x{{.+}} 1 +// AST: | | `-IntegerLiteral 0x{{.+}} 1 +// AST: | |-HLSLNumThreadsAttr 0x{{.+}} +// AST: | | |-IntegerLiteral 0x{{.+}} 1024 +// AST: | | |-IntegerLiteral 0x{{.+}} 1 +// AST: | | `-IntegerLiteral 0x{{.+}} 1 // AST: | `-HLSLShaderAttr 0x{{.+}} "node" [Shader("node")] @@ -239,7 +268,8 @@ void node_NodeOutput(NodeOutput output3) // EmptyNodeOutput // AST: FunctionDecl 0x{{.+}} node_EmptyNodeOutput 'void (EmptyNodeOutput)' // AST: | |-ParmVarDecl 0x[[Param:[0-9a-f]+]] {{.+}} col:35 used loadStressChild 'EmptyNodeOutput' -// AST: | | `-HLSLMaxRecordsAttr 0x{{.+}} 12 +// AST: | | `-HLSLMaxRecordsAttr 0x{{.+}} +// AST: | | `-IntegerLiteral 0x{{.+}} 12 // call to wrapper // AST: CallExpr 0x{{.+}} 'EmptyNodeOutput':'EmptyNodeOutput' // AST: |-ImplicitCastExpr 0x{{.+}} 'EmptyNodeOutput (*)(EmptyNodeOutput)' @@ -247,8 +277,14 @@ void node_NodeOutput(NodeOutput output3) // AST: `-ImplicitCastExpr 0x{{.+}} 'EmptyNodeOutput' // AST: `-DeclRefExpr 0x{{.+}} 'EmptyNodeOutput' lvalue ParmVar 0x[[Param]] 'loadStressChild' 'EmptyNodeOutput' // attributes. -// AST: | |-HLSLNumThreadsAttr 0x{{.+}} 1 1 1 -// AST: | |-HLSLNodeDispatchGridAttr 0x{{.+}} 1 1 1 +// AST: | |-HLSLNumThreadsAttr 0x{{.+}} +// AST: | | |-IntegerLiteral 0x{{.+}} 1 +// AST: | | |-IntegerLiteral 0x{{.+}} 1 +// AST: | | `-IntegerLiteral 0x{{.+}} 1 +// AST: | |-HLSLNodeDispatchGridAttr 0x{{.+}} +// AST: | | |-IntegerLiteral 0x{{.+}} 1 +// AST: | | |-IntegerLiteral 0x{{.+}} 1 +// AST: | | `-IntegerLiteral 0x{{.+}} 1 // AST: | `-HLSLShaderAttr 0x{{.+}} "node" void loadStressEmptyRecWorker( EmptyNodeOutput outputNode) @@ -269,7 +305,8 @@ void node_EmptyNodeOutput( // NodeOutputArray // AST: FunctionDecl 0x{{.+}} node_NodeOutputArray 'void (NodeOutputArray)' // AST: | |-ParmVarDecl 0x[[Param:[0-9a-f]+]] {{.+}} col:30 used OutputArray_1_0 'NodeOutputArray':'NodeOutputArray' -// AST: | | |-HLSLMaxRecordsAttr 0x{{.+}} 31 +// AST: | | |-HLSLMaxRecordsAttr 0x{{.+}} +// AST: | | | `-IntegerLiteral 0x{{.+}} 31 // AST: | | |-HLSLNodeArraySizeAttr 0x{{.+}} 129 // AST: | | `-HLSLAllowSparseNodesAttr 0x{{.+}} // call to wrapper @@ -279,8 +316,14 @@ void node_EmptyNodeOutput( // AST: `-ImplicitCastExpr 0x{{.+}} 'NodeOutputArray':'NodeOutputArray' // AST: `-DeclRefExpr 0x{{.+}} 'NodeOutputArray':'NodeOutputArray' lvalue ParmVar 0x[[Param]] 'OutputArray_1_0' 'NodeOutputArray':'NodeOutputArray' // attributes. -// AST: | |-HLSLNumThreadsAttr 0x{{.+}} 1 1 1 -// AST: | |-HLSLNodeDispatchGridAttr 0x{{.+}} 1 1 1 +// AST: | |-HLSLNumThreadsAttr 0x{{.+}} +// AST: | | |-IntegerLiteral 0x{{.+}} 1 +// AST: | | |-IntegerLiteral 0x{{.+}} 1 +// AST: | | `-IntegerLiteral 0x{{.+}} 1 +// AST: | |-HLSLNodeDispatchGridAttr 0x{{.+}} +// AST: | | |-IntegerLiteral 0x{{.+}} 1 +// AST: | | |-IntegerLiteral 0x{{.+}} 1 +// AST: | | `-IntegerLiteral 0x{{.+}} 1 // AST: | |-HLSLNodeLaunchAttr 0x{{.+}} "broadcasting" // AST: | `-HLSLShaderAttr 0x{{.+}} "node" @@ -304,7 +347,8 @@ void node_NodeOutputArray( // EmptyNodeOutputArray // AST: FunctionDecl 0x{{.+}} node_EmptyNodeOutputArray 'void (EmptyNodeOutputArray)' // AST: | |-ParmVarDecl 0x[[Param:[0-9a-f]+]] {{.+}} col:64 used EmptyOutputArray 'EmptyNodeOutputArray' -// AST: | | |-HLSLMaxRecordsAttr 0x{{.+}} 64 +// AST: | | |-HLSLMaxRecordsAttr 0x{{.+}} +// AST: | | | `-IntegerLiteral 0x{{.+}} 64 // AST: | | `-HLSLNodeArraySizeAttr 0x{{.+}} 128 // call to wrapper // AST: CallExpr 0x{{.+}} 'EmptyNodeOutputArray':'EmptyNodeOutputArray' @@ -313,8 +357,14 @@ void node_NodeOutputArray( // AST: `-ImplicitCastExpr 0x{{.+}} 'EmptyNodeOutputArray' // AST: `-DeclRefExpr 0x{{.+}} 'EmptyNodeOutputArray' lvalue ParmVar 0x[[Param]] 'EmptyOutputArray' 'EmptyNodeOutputArray' // attributes. -// AST: | |-HLSLNumThreadsAttr 0x{{.+}} 128 1 1 -// AST: | |-HLSLNodeDispatchGridAttr 0x{{.+}} 1 1 1 +// AST: | |-HLSLNumThreadsAttr 0x{{.+}} +// AST: | | |-IntegerLiteral 0x{{.+}} 128 +// AST: | | |-IntegerLiteral 0x{{.+}} 1 +// AST: | | `-IntegerLiteral 0x{{.+}} 1 +// AST: | |-HLSLNodeDispatchGridAttr 0x{{.+}} +// AST: | | |-IntegerLiteral 0x{{.+}} 1 +// AST: | | |-IntegerLiteral 0x{{.+}} 1 +// AST: | | `-IntegerLiteral 0x{{.+}} 1 // AST: | |-HLSLNodeLaunchAttr 0x{{.+}} "broadcasting" // AST: | `-HLSLShaderAttr 0x{{.+}} "node" @@ -335,7 +385,8 @@ void node_EmptyNodeOutputArray( // GroupNodeOutputRecords // AST: FunctionDecl 0x{{.+}} node_GroupNodeOutputRecords 'void (NodeOutputArray)' // AST: | |-ParmVarDecl 0x[[Param:[0-9a-f]+]] {{.+}} col:68 used OutputArray 'NodeOutputArray':'NodeOutputArray' -// AST: | | |-HLSLMaxRecordsAttr 0x{{.+}} 64 +// AST: | | |-HLSLMaxRecordsAttr 0x{{.+}} +// AST: | | | `-IntegerLiteral 0x{{.+}} 64 // AST: | | `-HLSLNodeArraySizeAttr 0x{{.+}} 128 // call to wrapper // AST: CallExpr 0x{{.+}} 'GroupNodeOutputRecords':'GroupNodeOutputRecords' @@ -344,8 +395,14 @@ void node_EmptyNodeOutputArray( // AST: `-ImplicitCastExpr 0x{{.+}} 'GroupNodeOutputRecords':'GroupNodeOutputRecords' // AST: `-DeclRefExpr 0x{{.+}} 'GroupNodeOutputRecords':'GroupNodeOutputRecords' lvalue Var 0x{{.+}} 'outRec' 'GroupNodeOutputRecords':'GroupNodeOutputRecords' // attributes. -// AST: | |-HLSLNumThreadsAttr 0x{{.+}} 128 1 1 -// AST: | |-HLSLNodeDispatchGridAttr 0x{{.+}} 1 1 1 +// AST: | |-HLSLNumThreadsAttr 0x{{.+}} +// AST: | | |-IntegerLiteral 0x{{.+}} 128 +// AST: | | |-IntegerLiteral 0x{{.+}} 1 +// AST: | | `-IntegerLiteral 0x{{.+}} 1 +// AST: | |-HLSLNodeDispatchGridAttr 0x{{.+}} +// AST: | | |-IntegerLiteral 0x{{.+}} 1 +// AST: | | |-IntegerLiteral 0x{{.+}} 1 +// AST: | | `-IntegerLiteral 0x{{.+}} 1 // AST: | |-HLSLNodeLaunchAttr 0x{{.+}} "broadcasting" // AST: | `-HLSLShaderAttr 0x{{.+}} "node" [Shader("node")] @@ -366,7 +423,8 @@ void node_GroupNodeOutputRecords( // ThreadNodeOutputRecords // AST: FunctionDecl 0x{{.+}} node_ThreadNodeOutputRecords 'void (NodeOutputArray)' // AST: |-ParmVarDecl 0x[[Param:[0-9a-f]+]] {{.+}} used OutputArray_1_0 'NodeOutputArray':'NodeOutputArray' -// AST: | |-HLSLMaxRecordsAttr 0x{{.+}} 31 +// AST: | |-HLSLMaxRecordsAttr 0x{{.+}} +// AST: | | `-IntegerLiteral 0x{{.+}} 31 // AST: | |-HLSLNodeArraySizeAttr 0x{{.+}} 129 // AST: | `-HLSLAllowSparseNodesAttr 0x{{.+}} // call to wrapper @@ -376,8 +434,14 @@ void node_GroupNodeOutputRecords( // AST: `-ImplicitCastExpr 0x{{.+}} 'ThreadNodeOutputRecords':'ThreadNodeOutputRecords' // AST: `-DeclRefExpr 0x{{.+}} 'ThreadNodeOutputRecords':'ThreadNodeOutputRecords' lvalue Var 0x{{.+}} 'outRec' 'ThreadNodeOutputRecords':'ThreadNodeOutputRecords' // attributes. -// AST: |-HLSLNumThreadsAttr 0x{{.+}} 1 1 1 -// AST: |-HLSLNodeDispatchGridAttr 0x{{.+}} 1 1 1 +// AST: |-HLSLNumThreadsAttr 0x{{.+}} +// AST: |-IntegerLiteral 0x{{.+}} 1 +// AST: |-IntegerLiteral 0x{{.+}} 1 +// AST: `-IntegerLiteral 0x{{.+}} 1 +// AST: |-HLSLNodeDispatchGridAttr 0x{{.+}} +// AST: |-IntegerLiteral 0x{{.+}} 1 +// AST: |-IntegerLiteral 0x{{.+}} 1 +// AST: `-IntegerLiteral 0x{{.+}} 1 // AST: |-HLSLNodeLaunchAttr 0x{{.+}} "broadcasting" // AST: `-HLSLShaderAttr 0x{{.+}} "node" diff --git a/tools/clang/test/SemaHLSL/hlsl/workgraph/ast-EmptyNodeOutputArrayTypes.hlsl b/tools/clang/test/SemaHLSL/hlsl/workgraph/ast-EmptyNodeOutputArrayTypes.hlsl index aa5a7123f6..278161d65d 100644 --- a/tools/clang/test/SemaHLSL/hlsl/workgraph/ast-EmptyNodeOutputArrayTypes.hlsl +++ b/tools/clang/test/SemaHLSL/hlsl/workgraph/ast-EmptyNodeOutputArrayTypes.hlsl @@ -54,7 +54,8 @@ void node_2_0( // CHECK: `-FunctionDecl 0x{{.+}} node_2_0 'void (EmptyNodeOutputArray)' // CHECK-NEXT: |-ParmVarDecl 0x[[Param:[0-9a-f]+]] {{.+}} used OutputArray_2_0 'EmptyNodeOutputArray' -// CHECK-NEXT: | |-HLSLMaxRecordsAttr 0x{{.+}} 41 +// CHECK-NEXT: | |-HLSLMaxRecordsAttr 0x{{.+}} +// CHECK-NEXT: | | `-IntegerLiteral 0x{{.+}} 'literal int' 41 // CHECK-NEXT: | |-HLSLNodeArraySizeAttr 0x{{.+}} 131 // CHECK-NEXT: | `-HLSLAllowSparseNodesAttr 0x{{.+}} // CHECK-NEXT: |-CompoundStmt 0x{{.+}} @@ -68,7 +69,13 @@ void node_2_0( // CHECK-NEXT: | | `-IntegerLiteral 0x{{.+}}{{.+}} 'literal int' 1 // CHECK-NEXT: | `-ImplicitCastExpr 0x{{.+}} 'unsigned int' // CHECK-NEXT: | `-IntegerLiteral 0x{{.+}} 'literal int' 10 -// CHECK-NEXT: |-HLSLNumThreadsAttr 0x{{.+}} 1 1 1 -// CHECK-NEXT: |-HLSLNodeDispatchGridAttr 0x{{.+}} 1 1 1 +// CHECK-NEXT: |-HLSLNumThreadsAttr 0x{{.+}} +// CHECK-NEXT: | |-IntegerLiteral 0x{{.+}} 'literal int' 1 +// CHECK-NEXT: | |-IntegerLiteral 0x{{.+}} 'literal int' 1 +// CHECK-NEXT: | `-IntegerLiteral 0x{{.+}} 'literal int' 1 +// CHECK-NEXT: |-HLSLNodeDispatchGridAttr 0x{{.+}} +// CHECK-NEXT: | |-IntegerLiteral 0x{{.+}} 'literal int' 1 +// CHECK-NEXT: | |-IntegerLiteral 0x{{.+}} 'literal int' 1 +// CHECK-NEXT: | `-IntegerLiteral 0x{{.+}} 'literal int' 1 // CHECK-NEXT: |-HLSLNodeLaunchAttr 0x{{.+}} "broadcasting" // CHECK-NEXT: `-HLSLShaderAttr 0x{{.+}} "node" diff --git a/tools/clang/test/SemaHLSL/hlsl/workgraph/ast-NodeOutputArrayTypes.hlsl b/tools/clang/test/SemaHLSL/hlsl/workgraph/ast-NodeOutputArrayTypes.hlsl index 6d381c5953..5bf163edbf 100644 --- a/tools/clang/test/SemaHLSL/hlsl/workgraph/ast-NodeOutputArrayTypes.hlsl +++ b/tools/clang/test/SemaHLSL/hlsl/workgraph/ast-NodeOutputArrayTypes.hlsl @@ -152,7 +152,8 @@ void node_1_1( // CHECK:`-FunctionDecl 0x{{.+}} line:16:6 node_1_1 'void (NodeOutputArray)' // CHECK-NEXT: |-ParmVarDecl 0x[[ParmVar:[0-9a-f]+]] col:30 used OutputArray_1_1 'NodeOutputArray':'NodeOutputArray' -// CHECK-NEXT: | |-HLSLMaxRecordsAttr 0x{{.+}} 37 +// CHECK-NEXT: | |-HLSLMaxRecordsAttr 0x{{.+}} +// CHECK-NEXT: | | `-IntegerLiteral 0x{{.+}} 'literal int' 37 // CHECK-NEXT: | `-HLSLUnboundedSparseNodesAttr 0x{{.+}} // CHECK-NEXT: |-CompoundStmt 0x{{.+}} // CHECK-NEXT: | |-DeclStmt 0x{{.+}} @@ -170,7 +171,13 @@ void node_1_1( // CHECK-NEXT: | `-CXXMemberCallExpr 0x{{.+}} 'void' // CHECK-NEXT: | `-MemberExpr 0x{{.+}} '' .OutputComplete 0x[[OutComplete]] // CHECK-NEXT: | `-DeclRefExpr 0x{{.+}} 'ThreadNodeOutputRecords':'ThreadNodeOutputRecords' lvalue Var 0x[[OutRec]] 'outRec' 'ThreadNodeOutputRecords':'ThreadNodeOutputRecords' -// CHECK-NEXT: |-HLSLNumThreadsAttr 0x{{.+}} 1 1 1 -// CHECK-NEXT: |-HLSLNodeDispatchGridAttr 0x{{.+}} 1 1 1 +// CHECK-NEXT: |-HLSLNumThreadsAttr 0x{{.+}} +// CHECK-NEXT: | |-IntegerLiteral 0x{{.+}} 'literal int' 1 +// CHECK-NEXT: | |-IntegerLiteral 0x{{.+}} 'literal int' 1 +// CHECK-NEXT: | `-IntegerLiteral 0x{{.+}} 'literal int' 1 +// CHECK-NEXT: |-HLSLNodeDispatchGridAttr 0x{{.+}} +// CHECK-NEXT: | |-IntegerLiteral 0x{{.+}} 'literal int' 1 +// CHECK-NEXT: | |-IntegerLiteral 0x{{.+}} 'literal int' 1 +// CHECK-NEXT: | `-IntegerLiteral 0x{{.+}} 'literal int' 1 // CHECK-NEXT: |-HLSLNodeLaunchAttr 0x{{.+}} "broadcasting" // CHECK-NEXT: `-HLSLShaderAttr 0x{{.+}} "node" diff --git a/tools/clang/unittests/HLSL/RewriterTest.cpp b/tools/clang/unittests/HLSL/RewriterTest.cpp index 613c8561a3..8e1bb30002 100644 --- a/tools/clang/unittests/HLSL/RewriterTest.cpp +++ b/tools/clang/unittests/HLSL/RewriterTest.cpp @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #ifdef _WIN32