diff --git a/tools/clang/include/clang/AST/Expr.h b/tools/clang/include/clang/AST/Expr.h index 26eff309f7..368d24d03d 100644 --- a/tools/clang/include/clang/AST/Expr.h +++ b/tools/clang/include/clang/AST/Expr.h @@ -531,6 +531,9 @@ class Expr : public Stmt { bool isConstantInitializer(ASTContext &Ctx, bool ForRef, const Expr **Culprit = nullptr) const; + bool isVulkanSpecConstantExpr(const ASTContext &Ctx, + APValue *Result = nullptr) const; + /// EvalStatus is a struct with detailed info about an evaluation in progress. struct EvalStatus { /// HasSideEffects - Whether the evaluated expression has side effects. diff --git a/tools/clang/include/clang/Basic/Attr.td b/tools/clang/include/clang/Basic/Attr.td index 1797597d17..89c5e22788 100644 --- a/tools/clang/include/clang/Basic/Attr.td +++ b/tools/clang/include/clang/Basic/Attr.td @@ -668,7 +668,7 @@ def HLSLMaxTessFactor: InheritableAttr { } def HLSLNumThreads: InheritableAttr { let Spellings = [CXX11<"", "numthreads", 2015>]; - let Args = [IntArgument<"X">, IntArgument<"Y">, IntArgument<"Z">]; + let Args = [ExprArgument<"X">, ExprArgument<"Y">, ExprArgument<"Z">]; let Documentation = [Undocumented]; } def HLSLRootSignature: InheritableAttr { @@ -1004,7 +1004,7 @@ def HLSLNodeIsProgramEntry : InheritableAttr { def HLSLNodeId : InheritableAttr { let Spellings = [CXX11<"", "nodeid", 2017>]; - let Args = [StringArgument<"Name">,DefaultIntArgument<"ArrayIndex", 0>]; + let Args = [StringArgument<"Name">, ExprArgument<"ArrayIndex", 1>]; let Documentation = [Undocumented]; } @@ -1016,25 +1016,25 @@ def HLSLNodeLocalRootArgumentsTableIndex : InheritableAttr { def HLSLNodeShareInputOf : InheritableAttr { let Spellings = [CXX11<"", "nodeshareinputof", 2017>]; - let Args = [StringArgument<"Name">,UnsignedArgument<"ArrayIndex", 1>]; + let Args = [StringArgument<"Name">,ExprArgument<"ArrayIndex", 1>]; let Documentation = [Undocumented]; } def HLSLNodeDispatchGrid: InheritableAttr { let Spellings = [CXX11<"", "nodedispatchgrid", 2015>]; - let Args = [UnsignedArgument<"X">, UnsignedArgument<"Y">, UnsignedArgument<"Z">]; + let Args = [ExprArgument<"X">, ExprArgument<"Y">, ExprArgument<"Z">]; let Documentation = [Undocumented]; } def HLSLNodeMaxDispatchGrid: InheritableAttr { let Spellings = [CXX11<"", "nodemaxdispatchgrid", 2015>]; - let Args = [UnsignedArgument<"X">, UnsignedArgument<"Y">, UnsignedArgument<"Z">]; + let Args = [ExprArgument<"X">, ExprArgument<"Y">, ExprArgument<"Z">]; let Documentation = [Undocumented]; } def HLSLNodeMaxRecursionDepth : InheritableAttr { let Spellings = [CXX11<"", "nodemaxrecursiondepth", 2017>]; - let Args = [UnsignedArgument<"Count">]; + let Args = [ExprArgument<"Count">]; let Documentation = [Undocumented]; } @@ -1182,7 +1182,7 @@ def HLSLHitObject : InheritableAttr { def HLSLMaxRecords : InheritableAttr { let Spellings = [CXX11<"", "MaxRecords", 2015>]; - let Args = [IntArgument<"maxCount">]; + let Args = [ExprArgument<"maxCount">]; let Documentation = [Undocumented]; } diff --git a/tools/clang/include/clang/SPIRV/SpirvContext.h b/tools/clang/include/clang/SPIRV/SpirvContext.h index 8e0458e731..8321da9e63 100644 --- a/tools/clang/include/clang/SPIRV/SpirvContext.h +++ b/tools/clang/include/clang/SPIRV/SpirvContext.h @@ -456,6 +456,15 @@ class SpirvContext { instructionsWithLoweredType.end(); } + SpirvInstruction *getSpecConstant(const VarDecl *decl) { + return specConstants[decl]; + } + + void registerSpecConstant(const VarDecl *decl, + SpirvInstruction *specConstant) { + specConstants[decl] = specConstant; + } + void registerDispatchGridIndex(const RecordDecl *decl, unsigned index) { auto iter = dispatchGridIndices.find(decl); if (iter == dispatchGridIndices.end()) { @@ -536,6 +545,7 @@ class SpirvContext { llvm::DenseSet functionTypes; llvm::DenseMap spirvIntrinsicTypesById; llvm::SmallVector spirvIntrinsicTypes; + llvm::MapVector specConstants; const AccelerationStructureTypeNV *accelerationStructureTypeNV; const RayQueryTypeKHR *rayQueryTypeKHR; diff --git a/tools/clang/include/clang/Sema/SemaHLSL.h b/tools/clang/include/clang/Sema/SemaHLSL.h index 80ce8ddd7d..2dafbf7734 100644 --- a/tools/clang/include/clang/Sema/SemaHLSL.h +++ b/tools/clang/include/clang/Sema/SemaHLSL.h @@ -160,8 +160,6 @@ unsigned CaculateInitListArraySizeForHLSL(clang::Sema *sema, const clang::InitListExpr *InitList, const clang::QualType EltTy); -bool ContainsLongVector(clang::QualType); - bool IsConversionToLessOrEqualElements(clang::Sema *self, const clang::ExprResult &sourceExpr, const clang::QualType &targetType, diff --git a/tools/clang/lib/AST/ExprConstant.cpp b/tools/clang/lib/AST/ExprConstant.cpp index baa0349cfe..c5313c4ec8 100644 --- a/tools/clang/lib/AST/ExprConstant.cpp +++ b/tools/clang/lib/AST/ExprConstant.cpp @@ -9448,6 +9448,19 @@ bool Expr::isIntegerConstantExpr(llvm::APSInt &Value, const ASTContext &Ctx, return true; } +bool Expr::isVulkanSpecConstantExpr(const ASTContext &Ctx, + APValue *Result) const { + auto *D = dyn_cast(this); + if (!D) + return false; + auto *V = dyn_cast(D->getDecl()); + if (!V || !V->hasAttr()) + return false; + if (const Expr *I = V->getAnyInitializer()) + return I->IgnoreParenCasts()->isCXX11ConstantExpr(Ctx, Result); + return true; +} + bool Expr::isCXX98IntegralConstantExpr(const ASTContext &Ctx) const { return CheckICE(this, Ctx).Kind == IK_ICE; } diff --git a/tools/clang/lib/CodeGen/CGHLSLMS.cpp b/tools/clang/lib/CodeGen/CGHLSLMS.cpp index b5add521a6..c3f728e11e 100644 --- a/tools/clang/lib/CodeGen/CGHLSLMS.cpp +++ b/tools/clang/lib/CodeGen/CGHLSLMS.cpp @@ -328,6 +328,19 @@ class CGMSHLSLRuntime : public CGHLSLRuntime { }; } // namespace +static uint32_t GetIntConstAttrArg(ASTContext &astContext, const Expr *expr, + uint32_t defaultVal = 0) { + if (expr) { + llvm::APSInt apsInt; + APValue apValue; + if (expr->isIntegerConstantExpr(apsInt, astContext)) + return (uint32_t)apsInt.getSExtValue(); + if (expr->isVulkanSpecConstantExpr(astContext, &apValue) && apValue.isInt()) + return (uint32_t)apValue.getInt().getSExtValue(); + } + return defaultVal; +} + //------------------------------------------------------------------------------ // // CGMSHLSLRuntime methods. @@ -1422,6 +1435,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) { } DiagnosticsEngine &Diags = CGM.getDiags(); + ASTContext &astContext = CGM.getTypes().getContext(); std::unique_ptr funcProps = llvm::make_unique(); @@ -1632,10 +1646,9 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) { // Populate numThreads if (const HLSLNumThreadsAttr *Attr = FD->getAttr()) { - - funcProps->numThreads[0] = Attr->getX(); - funcProps->numThreads[1] = Attr->getY(); - funcProps->numThreads[2] = Attr->getZ(); + funcProps->numThreads[0] = GetIntConstAttrArg(astContext, Attr->getX(), 1); + funcProps->numThreads[1] = GetIntConstAttrArg(astContext, Attr->getY(), 1); + funcProps->numThreads[2] = GetIntConstAttrArg(astContext, Attr->getZ(), 1); if (isEntry && !SM->IsCS() && !SM->IsMS() && !SM->IsAS()) { unsigned DiagID = Diags.getCustomDiagID( @@ -1808,7 +1821,8 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) { if (const auto *pAttr = FD->getAttr()) { funcProps->NodeShaderID.Name = pAttr->getName().str(); - funcProps->NodeShaderID.Index = pAttr->getArrayIndex(); + funcProps->NodeShaderID.Index = + GetIntConstAttrArg(astContext, pAttr->getArrayIndex(), 0); } else { funcProps->NodeShaderID.Name = FD->getName().str(); funcProps->NodeShaderID.Index = 0; @@ -1819,20 +1833,28 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) { } if (const auto *pAttr = FD->getAttr()) { funcProps->NodeShaderSharedInput.Name = pAttr->getName().str(); - funcProps->NodeShaderSharedInput.Index = pAttr->getArrayIndex(); + funcProps->NodeShaderSharedInput.Index = + GetIntConstAttrArg(astContext, pAttr->getArrayIndex(), 0); } if (const auto *pAttr = FD->getAttr()) { - funcProps->Node.DispatchGrid[0] = pAttr->getX(); - funcProps->Node.DispatchGrid[1] = pAttr->getY(); - funcProps->Node.DispatchGrid[2] = pAttr->getZ(); + funcProps->Node.DispatchGrid[0] = + GetIntConstAttrArg(astContext, pAttr->getX(), 1); + funcProps->Node.DispatchGrid[1] = + GetIntConstAttrArg(astContext, pAttr->getY(), 1); + funcProps->Node.DispatchGrid[2] = + GetIntConstAttrArg(astContext, pAttr->getZ(), 1); } if (const auto *pAttr = FD->getAttr()) { - funcProps->Node.MaxDispatchGrid[0] = pAttr->getX(); - funcProps->Node.MaxDispatchGrid[1] = pAttr->getY(); - funcProps->Node.MaxDispatchGrid[2] = pAttr->getZ(); + funcProps->Node.MaxDispatchGrid[0] = + GetIntConstAttrArg(astContext, pAttr->getX(), 1); + funcProps->Node.MaxDispatchGrid[1] = + GetIntConstAttrArg(astContext, pAttr->getY(), 1); + funcProps->Node.MaxDispatchGrid[2] = + GetIntConstAttrArg(astContext, pAttr->getZ(), 1); } if (const auto *pAttr = FD->getAttr()) { - funcProps->Node.MaxRecursionDepth = pAttr->getCount(); + funcProps->Node.MaxRecursionDepth = + GetIntConstAttrArg(astContext, pAttr->getCount(), 0); } if (!FD->getAttr()) { // NumThreads wasn't specified. @@ -2346,8 +2368,9 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) { NodeInputRecordParams[ArgIt].MetadataIdx = NodeInputParamIdx++; if (parmDecl->hasAttr()) { - node.MaxRecords = - parmDecl->getAttr()->getMaxCount(); + node.MaxRecords = GetIntConstAttrArg( + astContext, + parmDecl->getAttr()->getMaxCount(), 1); } if (parmDecl->hasAttr()) node.Flags.SetGloballyCoherent(); @@ -2378,7 +2401,8 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) { // OutputID from attribute if (const auto *Attr = parmDecl->getAttr()) { node.OutputID.Name = Attr->getName().str(); - node.OutputID.Index = Attr->getArrayIndex(); + node.OutputID.Index = + GetIntConstAttrArg(astContext, Attr->getArrayIndex(), 0); } else { node.OutputID.Name = parmDecl->getName().str(); node.OutputID.Index = 0; @@ -2437,7 +2461,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) { node.MaxRecordsSharedWith = ix; } if (const auto *Attr = parmDecl->getAttr()) - node.MaxRecords = Attr->getMaxCount(); + node.MaxRecords = GetIntConstAttrArg(astContext, Attr->getMaxCount(), 0); } if (inputPatchCount > 1) { diff --git a/tools/clang/lib/SPIRV/DeclResultIdMapper.cpp b/tools/clang/lib/SPIRV/DeclResultIdMapper.cpp index 9d0d8f51a3..b01992208f 100644 --- a/tools/clang/lib/SPIRV/DeclResultIdMapper.cpp +++ b/tools/clang/lib/SPIRV/DeclResultIdMapper.cpp @@ -1815,6 +1815,7 @@ DeclResultIdMapper::getCounterVarFields(const DeclaratorDecl *decl) { void DeclResultIdMapper::registerSpecConstant(const VarDecl *decl, SpirvInstruction *specConstant) { specConstant->setRValue(); + spvContext.registerSpecConstant(decl, specConstant); registerVariableForDecl(decl, createDeclSpirvInfo(specConstant)); } diff --git a/tools/clang/lib/SPIRV/EmitVisitor.cpp b/tools/clang/lib/SPIRV/EmitVisitor.cpp index 8de0262ae6..6118b505e9 100644 --- a/tools/clang/lib/SPIRV/EmitVisitor.cpp +++ b/tools/clang/lib/SPIRV/EmitVisitor.cpp @@ -2523,6 +2523,24 @@ isFieldMergeWithPrevious(const StructType::FieldInfo &previous, return previous.fieldIndex == field.fieldIndex; } +uint32_t EmitTypeHandler::getAttrArgInstr(ASTContext &astContext, + const Expr *expr, + uint32_t defaultVal) { + if (expr) { + llvm::APSInt apsInt; + APValue apValue; + if (expr->isIntegerConstantExpr(apsInt, astContext)) + return getOrCreateConstantInt(apsInt, context.getUIntType(32), false); + if (expr->isVulkanSpecConstantExpr(astContext, &apValue) && + apValue.isInt()) { + auto *declRefExpr = dyn_cast(expr); + auto *decl = dyn_cast(declRefExpr->getDecl()); + return getOrAssignResultId(context.getSpecConstant(decl)); + } + } + return defaultVal; +} + uint32_t EmitTypeHandler::emitType(const SpirvType *type) { // First get the decorations that would apply to this type. bool alreadyExists = false; @@ -2655,9 +2673,9 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type) { if (hlsl::IsHLSLNodeOutputType(nodeDecl->getType())) { StringRef name = nodeDecl->getName(); unsigned index = 0; - if (auto nodeID = nodeDecl->getAttr()) { + if (auto *nodeID = nodeDecl->getAttr()) { name = nodeID->getName(); - index = nodeID->getArrayIndex(); + index = getAttrArgInstr(astContext, nodeID->getArrayIndex()); } auto *str = new (context) SpirvConstantString(name); @@ -2665,17 +2683,14 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type) { emitDecoration(id, spv::Decoration::PayloadNodeNameAMDX, {nodeName}, llvm::None, true); if (index) { - uint32_t baseIndex = getOrCreateConstantInt( - llvm::APInt(32, index), context.getUIntType(32), false); - emitDecoration(id, spv::Decoration::PayloadNodeBaseIndexAMDX, - {baseIndex}, llvm::None, true); + emitDecoration(id, spv::Decoration::PayloadNodeBaseIndexAMDX, {index}, + llvm::None, true); } } uint32_t maxRecords; if (const auto *attr = nodeDecl->getAttr()) { - maxRecords = getOrCreateConstantInt(llvm::APInt(32, attr->getMaxCount()), - context.getUIntType(32), false); + maxRecords = getAttrArgInstr(astContext, attr->getMaxCount(), 1); } else { maxRecords = getOrCreateConstantInt(llvm::APInt(32, 1), context.getUIntType(32), false); diff --git a/tools/clang/lib/SPIRV/EmitVisitor.h b/tools/clang/lib/SPIRV/EmitVisitor.h index fb4b22e52b..4e46a53179 100644 --- a/tools/clang/lib/SPIRV/EmitVisitor.h +++ b/tools/clang/lib/SPIRV/EmitVisitor.h @@ -67,6 +67,9 @@ class EmitTypeHandler { EmitTypeHandler(const EmitTypeHandler &) = delete; EmitTypeHandler &operator=(const EmitTypeHandler &) = delete; + uint32_t getAttrArgInstr(ASTContext &astContext, const Expr *expr, + uint32_t defaultVal = 0); + // Emits the instruction for the given type into the typeConstantBinary and // returns the result-id for the type. If the type has already been emitted, // it only returns its result-id. diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.cpp b/tools/clang/lib/SPIRV/SpirvEmitter.cpp index cc7016b594..50a4072beb 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.cpp +++ b/tools/clang/lib/SPIRV/SpirvEmitter.cpp @@ -11262,15 +11262,11 @@ SpirvEmitter::processIntrinsicIsValid(const CXXMemberCallExpr *callExpr) { const auto *declRefExpr = dyn_cast(baseExpr->IgnoreImpCasts()); const auto *paramDecl = dyn_cast(declRefExpr->getDecl()); - int nodeIndex = 0; - if (HLSLNodeIdAttr *nodeId = paramDecl->getAttr()) { - nodeIndex = nodeId->getArrayIndex(); - } SpirvInstruction *payload = doExpr(baseExpr); if (!shaderIndex) { - shaderIndex = spvBuilder.getConstantInt(astContext.UnsignedIntTy, - llvm::APInt(32, nodeIndex)); + shaderIndex = + spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0)); } return spvBuilder.createIsNodePayloadValid(payload, shaderIndex, loc); @@ -13682,6 +13678,88 @@ void SpirvEmitter::processInlineSpirvAttributes(const FunctionDecl *decl) { } } +SpirvInstruction * +SpirvEmitter::evalIntConstAttrArg(const Expr *expr, + llvm::Optional defaultVal) { + if (expr) { + QualType type = expr->getType(); + assert(type->isIntegerType()); + SpirvInstruction *ret = doExpr(expr); + assert(ret->getopcode() == spv::Op::OpConstant || + ret->getopcode() == spv::Op::OpSpecConstant); + if (type->isSignedIntegerType()) + ret->setAstResultType(astContext.UnsignedIntTy); + return ret; + } + if (defaultVal.hasValue()) + return spvBuilder.getConstantInt(astContext.UnsignedIntTy, + llvm::APInt(32, defaultVal.getValue())); + return nullptr; +} + +bool SpirvEmitter::processNumThreadsAttr(const FunctionDecl *decl) { + auto *numThreadsAttr = decl->getAttr(); + if (!numThreadsAttr) + return false; + + bool localSizeId = false; + Expr *x = numThreadsAttr->getX(), *y = numThreadsAttr->getY(), + *z = numThreadsAttr->getZ(); + + // SPIR-V spec says LocalSizeId missing "before version 1.2" but SPIRV-Tools + // validation excludes 1.2 as well. + switch (featureManager.getTargetEnv()) { + case SPV_ENV_VULKAN_1_0: + case SPV_ENV_VULKAN_1_1: + case SPV_ENV_VULKAN_1_1_SPIRV_1_4: + case SPV_ENV_VULKAN_1_2: + break; + default: + if (x->isVulkanSpecConstantExpr(astContext) || + y->isVulkanSpecConstantExpr(astContext) || + z->isVulkanSpecConstantExpr(astContext)) { + auto f = [this](Expr *E) -> SpirvInstruction * { + if (E) { + llvm::APSInt apsInt; + APValue apValue; + if (E->isIntegerConstantExpr(apsInt, astContext)) + return spvBuilder.getConstantInt(astContext.UnsignedIntTy, apsInt); + if (E->isVulkanSpecConstantExpr(astContext, &apValue) && + apValue.isInt()) { + auto *declRefExpr = dyn_cast(E); + auto *varDecl = dyn_cast(declRefExpr->getDecl()); + return declIdMapper.getDeclEvalInfo(varDecl, + declRefExpr->getExprLoc()); + } + } + return spvBuilder.getConstantInt(astContext.UnsignedIntTy, + llvm::APInt(32, 1)); + }; + spvBuilder.addExecutionModeId(entryFunction, + spv::ExecutionMode::LocalSizeId, + {f(x), f(y), f(z)}, decl->getLocation()); + return true; + } + } + + auto f = [this](Expr *E) -> unsigned { + if (E) { + llvm::APSInt apsInt; + APValue apValue; + if (E->isIntegerConstantExpr(apsInt, astContext)) + return (unsigned)apsInt.getZExtValue(); + if (E->isVulkanSpecConstantExpr(astContext, &apValue) && + apValue.isInt()) { + return apValue.getInt().getZExtValue(); + } + } + return 1U; + }; + spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize, + {f(x), f(y), f(z)}, decl->getLocation()); + return true; +} + bool SpirvEmitter::processGeometryShaderAttributes(const FunctionDecl *decl, uint32_t *arraySize) { bool success = true; @@ -13873,28 +13951,18 @@ void SpirvEmitter::checkForWaveSizeAttr(const FunctionDecl *decl) { } void SpirvEmitter::processComputeShaderAttributes(const FunctionDecl *decl) { - auto *numThreadsAttr = decl->getAttr(); - assert(numThreadsAttr && "thread group size missing from entry-point"); - - uint32_t x = static_cast(numThreadsAttr->getX()); - uint32_t y = static_cast(numThreadsAttr->getY()); - uint32_t z = static_cast(numThreadsAttr->getZ()); - - spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize, - {x, y, z}, decl->getLocation()); + if (!processNumThreadsAttr(decl)) { + assert(false && "thread group size missing from entry-point"); + } checkForWaveSizeAttr(decl); } void SpirvEmitter::processNodeShaderAttributes(const FunctionDecl *decl) { - uint32_t x = 1, y = 1, z = 1; - if (auto *numThreadsAttr = decl->getAttr()) { - x = static_cast(numThreadsAttr->getX()); - y = static_cast(numThreadsAttr->getY()); - z = static_cast(numThreadsAttr->getZ()); + if (!processNumThreadsAttr(decl)) { + spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize, + {1, 1, 1}, decl->getLocation()); } - spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize, - {x, y, z}, decl->getLocation()); auto *nodeLaunchAttr = decl->getAttr(); StringRef launchType = nodeLaunchAttr ? nodeLaunchAttr->getLaunchType() : ""; @@ -13904,20 +13972,20 @@ void SpirvEmitter::processNodeShaderAttributes(const FunctionDecl *decl) { decl->getLocation()); } - uint64_t nodeId = 0; - if (const auto nodeIdAttr = decl->getAttr()) - nodeId = static_cast(nodeIdAttr->getArrayIndex()); - spvBuilder.addExecutionModeId( - entryFunction, spv::ExecutionMode::ShaderIndexAMDX, - {spvBuilder.getConstantInt(astContext.UnsignedIntTy, - llvm::APInt(32, nodeId))}, - decl->getLocation()); + SpirvInstruction *nodeId = nullptr; + if (const auto *nodeIdAttr = decl->getAttr()) + nodeId = evalIntConstAttrArg(nodeIdAttr->getArrayIndex(), 0); + else + nodeId = + spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0)); + spvBuilder.addExecutionModeId(entryFunction, + spv::ExecutionMode::ShaderIndexAMDX, {nodeId}, + decl->getLocation()); if (const auto *nodeMaxRecursionDepthAttr = decl->getAttr()) { - SpirvInstruction *count = spvBuilder.getConstantInt( - astContext.UnsignedIntTy, - llvm::APInt(32, nodeMaxRecursionDepthAttr->getCount())); + SpirvInstruction *count = + evalIntConstAttrArg(nodeMaxRecursionDepthAttr->getCount()); spvBuilder.addExecutionModeId(entryFunction, spv::ExecutionMode::MaxNodeRecursionAMDX, {count}, decl->getLocation()); @@ -13927,32 +13995,25 @@ void SpirvEmitter::processNodeShaderAttributes(const FunctionDecl *decl) { decl->getAttr()) { SpirvInstruction *name = spvBuilder.getConstantString(nodeShareInputOfAttr->getName()); - SpirvInstruction *index = spvBuilder.getConstantInt( - astContext.UnsignedIntTy, - llvm::APInt(32, nodeShareInputOfAttr->getArrayIndex())); + SpirvInstruction *index = + evalIntConstAttrArg(nodeShareInputOfAttr->getArrayIndex(), 0); spvBuilder.addExecutionModeId(entryFunction, spv::ExecutionMode::SharesInputWithAMDX, {name, index}, decl->getLocation()); } if (const auto *dispatchGrid = decl->getAttr()) { - SpirvInstruction *gridX = spvBuilder.getConstantInt( - astContext.UnsignedIntTy, llvm::APInt(32, dispatchGrid->getX())); - SpirvInstruction *gridY = spvBuilder.getConstantInt( - astContext.UnsignedIntTy, llvm::APInt(32, dispatchGrid->getY())); - SpirvInstruction *gridZ = spvBuilder.getConstantInt( - astContext.UnsignedIntTy, llvm::APInt(32, dispatchGrid->getZ())); + SpirvInstruction *gridX = evalIntConstAttrArg(dispatchGrid->getX(), 1); + SpirvInstruction *gridY = evalIntConstAttrArg(dispatchGrid->getY(), 1); + SpirvInstruction *gridZ = evalIntConstAttrArg(dispatchGrid->getZ(), 1); spvBuilder.addExecutionModeId(entryFunction, spv::ExecutionMode::StaticNumWorkgroupsAMDX, {gridX, gridY, gridZ}, decl->getLocation()); } else if (const auto *maxDispatchGrid = decl->getAttr()) { - SpirvInstruction *gridX = spvBuilder.getConstantInt( - astContext.UnsignedIntTy, llvm::APInt(32, maxDispatchGrid->getX())); - SpirvInstruction *gridY = spvBuilder.getConstantInt( - astContext.UnsignedIntTy, llvm::APInt(32, maxDispatchGrid->getY())); - SpirvInstruction *gridZ = spvBuilder.getConstantInt( - astContext.UnsignedIntTy, llvm::APInt(32, maxDispatchGrid->getZ())); + SpirvInstruction *gridX = evalIntConstAttrArg(maxDispatchGrid->getX(), 1); + SpirvInstruction *gridY = evalIntConstAttrArg(maxDispatchGrid->getY(), 1); + SpirvInstruction *gridZ = evalIntConstAttrArg(maxDispatchGrid->getZ(), 1); spvBuilder.addExecutionModeId(entryFunction, spv::ExecutionMode::MaxNumWorkgroupsAMDX, {gridX, gridY, gridZ}, decl->getLocation()); @@ -14165,14 +14226,7 @@ bool SpirvEmitter::emitEntryFunctionWrapperForRayTracing( bool SpirvEmitter::processMeshOrAmplificationShaderAttributes( const FunctionDecl *decl, uint32_t *outVerticesArraySize) { - if (auto *numThreadsAttr = decl->getAttr()) { - uint32_t x, y, z; - x = static_cast(numThreadsAttr->getX()); - y = static_cast(numThreadsAttr->getY()); - z = static_cast(numThreadsAttr->getZ()); - spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize, - {x, y, z}, decl->getLocation()); - } + processNumThreadsAttr(decl); // Early return for amplification shaders as they only take the 'numthreads' // attribute. @@ -15781,13 +15835,16 @@ bool SpirvEmitter::spirvToolsValidate(std::vector *mod, void SpirvEmitter::addDerivativeGroupExecutionMode() { assert(spvContext.isCS()); - SpirvExecutionMode *numThreadsEm = - cast(spvBuilder.getModule()->findExecutionMode( - entryFunction, spv::ExecutionMode::LocalSize)); + SpirvExecutionMode *numThreadsEm = dyn_cast_or_null( + spvBuilder.getModule()->findExecutionMode(entryFunction, + spv::ExecutionMode::LocalSize)); + // If there is no LocalSize, there must be LocalSizeId. + if (!numThreadsEm) + return addDerivativeGroupExecutionModeId(); auto numThreads = numThreadsEm->getParams(); // The layout of the quad is determined by the numer of threads in each - // dimention. From the HLSL spec + // dimension. From the HLSL spec // (https://microsoft.github.io/DirectX-Specs/d3d/HLSL_SM_6_6_Derivatives.html): // // Where numthreads has an X value divisible by 4 and Y and Z are both 1, the @@ -15810,6 +15867,47 @@ void SpirvEmitter::addDerivativeGroupExecutionMode() { spvBuilder.addExecutionMode(entryFunction, em, {}, SourceLocation()); } +void SpirvEmitter::addDerivativeGroupExecutionModeId() { + assert(spvContext.isCS()); + + SpirvExecutionModeId *numThreadsEm = + dyn_cast(spvBuilder.getModule()->findExecutionMode( + entryFunction, spv::ExecutionMode::LocalSizeId)); + auto numThreads = numThreadsEm->getParams(); + auto f = [this](SpirvInstruction *arg) -> llvm::Optional { + if (auto con = dyn_cast(arg)) { + return (unsigned)con->getValue().getZExtValue(); + } + return llvm::None; + }; + + // The layout of the quad is determined by the numer of threads in each + // dimension. From the HLSL spec + // (https://microsoft.github.io/DirectX-Specs/d3d/HLSL_SM_6_6_Derivatives.html): + // + // Where numthreads has an X value divisible by 4 and Y and Z are both 1, the + // quad layouts are determined according to 1D quad rules. Where numthreads X + // and Y values are divisible by 2, the quad layouts are determined according + // to 2D quad rules. Using derivative operations in any numthreads + // configuration not matching either of these is invalid and will produce an + // error. + static_assert(spv::ExecutionMode::DerivativeGroupQuadsNV == + spv::ExecutionMode::DerivativeGroupQuadsKHR); + static_assert(spv::ExecutionMode::DerivativeGroupLinearNV == + spv::ExecutionMode::DerivativeGroupLinearKHR); + spv::ExecutionMode em = spv::ExecutionMode::DerivativeGroupQuadsNV; + auto x = f(numThreads[0]), y = f(numThreads[1]), z = f(numThreads[2]); + if (x.hasValue() && x.getValue() % 4 == 0 && y.hasValue() && + y.getValue() == 1 && z.hasValue() && z.getValue() == 1) { + em = spv::ExecutionMode::DerivativeGroupLinearNV; + } else { + assert((!x.hasValue() || x.getValue() % 2 == 0) && + (!y.hasValue() || y.getValue() % 2 == 0)); + } + + spvBuilder.addExecutionMode(entryFunction, em, {}, SourceLocation()); +} + SpirvVariable *SpirvEmitter::createPCFParmVarAndInitFromStageInputVar( const ParmVarDecl *param) { const QualType type = param->getType(); diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.h b/tools/clang/lib/SPIRV/SpirvEmitter.h index 14401c6418..744e114e3a 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.h +++ b/tools/clang/lib/SPIRV/SpirvEmitter.h @@ -896,6 +896,12 @@ class SpirvEmitter : public ASTConsumer { /// \brief Handle inline SPIR-V attributes for the entry function. void processInlineSpirvAttributes(const FunctionDecl *entryFunction); + SpirvInstruction * + evalIntConstAttrArg(const Expr *expr, + llvm::Optional defaultVal = llvm::None); + + bool processNumThreadsAttr(const FunctionDecl *decl); + /// \brief Adds necessary execution modes for the hull/domain shaders based on /// the HLSL attributes of the entry point function. /// In the case of hull shaders, also writes the number of output control @@ -1365,6 +1371,7 @@ class SpirvEmitter : public ASTConsumer { /// This decision is made according to the rules in /// https://microsoft.github.io/DirectX-Specs/d3d/HLSL_SM_6_6_Derivatives.html. void addDerivativeGroupExecutionMode(); + void addDerivativeGroupExecutionModeId(); /// Creates an input variable for `param` that will be used by the patch /// constant function. The parameter is also added to the patch constant diff --git a/tools/clang/lib/Sema/SemaHLSL.cpp b/tools/clang/lib/Sema/SemaHLSL.cpp index 3d9de1804d..76b9a79ba6 100644 --- a/tools/clang/lib/Sema/SemaHLSL.cpp +++ b/tools/clang/lib/Sema/SemaHLSL.cpp @@ -12283,21 +12283,37 @@ void Sema::CheckHLSLFunctionCall(FunctionDecl *FDecl, CallExpr *TheCall, } } +static uint32_t GetIntConstAttrArg(ASTContext &astContext, const Expr *expr, + uint32_t defaultVal = 0) { + if (expr) { + llvm::APSInt apsInt; + APValue apValue; + if (expr->isIntegerConstantExpr(apsInt, astContext)) + return (uint32_t)apsInt.getSExtValue(); + if (expr->isVulkanSpecConstantExpr(astContext, &apValue) && apValue.isInt()) + return (uint32_t)apValue.getInt().getSExtValue(); + } + return defaultVal; +} + ///////////////////////////////////////////////////////////////////////////// // Check HLSL intrinsic calls reachable from entry/export functions. -static void DiagnoseNumThreadsForDerivativeOp(const HLSLNumThreadsAttr *Attr, - SourceLocation LocDeriv, - FunctionDecl *FD, - const FunctionDecl *EntryDecl, - DiagnosticsEngine &Diags) { +static void DiagnoseNumThreadsForDerivativeOp( + Sema &S, const HLSLNumThreadsAttr *Attr, SourceLocation LocDeriv, + FunctionDecl *FD, const FunctionDecl *EntryDecl, DiagnosticsEngine &Diags) { bool invalidNumThreads = false; - if (Attr->getY() != 1) { + ASTContext &astContext = S.getASTContext(); + uint32_t x = GetIntConstAttrArg(astContext, Attr->getX(), 1); + uint32_t y = GetIntConstAttrArg(astContext, Attr->getY(), 1); + uint32_t z = GetIntConstAttrArg(astContext, Attr->getZ(), 1); + + if (y != 1) { // 2D mode requires x and y to be multiple of 2. - invalidNumThreads = !((Attr->getX() % 2) == 0 && (Attr->getY() % 2) == 0); + invalidNumThreads = !((x % 2) == 0 && (y % 2) == 0); } else { // 1D mode requires x to be multiple of 4 and y and z to be 1. - invalidNumThreads = (Attr->getX() % 4) != 0 || (Attr->getZ() != 1); + invalidNumThreads = (x % 4) != 0 || (z != 1); } if (invalidNumThreads) { Diags.Report(LocDeriv, diag::warn_hlsl_derivatives_wrong_numthreads) @@ -12343,7 +12359,7 @@ static void DiagnoseDerivativeOp(Sema &S, FunctionDecl *FD, SourceLocation Loc, if (const HLSLNumThreadsAttr *Attr = EntryDecl->getAttr()) { - DiagnoseNumThreadsForDerivativeOp(Attr, Loc, FD, EntryDecl, Diags); + DiagnoseNumThreadsForDerivativeOp(S, Attr, Loc, FD, EntryDecl, Diags); } } @@ -13782,12 +13798,12 @@ FlattenedTypeIterator::CompareTypesForInit(HLSLExternalSource &source, //////////////////////////////////////////////////////////////////////////////// // Attribute processing support. // -static int ValidateAttributeIntArg(Sema &S, const AttributeList &Attr, - unsigned index = 0) { - int64_t value = 0; +static Expr *ValidateAttributeIntArgExpr(Sema &S, const AttributeList &Attr, + unsigned index, int64_t *value, + bool allowDefinedConstant = false) { + Expr *E = nullptr; if (Attr.getNumArgs() > index) { - Expr *E = nullptr; if (!Attr.isArgExpr(index)) { // For case arg is constant variable. IdentifierLoc *loc = Attr.getArgAsIdent(index); @@ -13798,13 +13814,13 @@ static int ValidateAttributeIntArg(Sema &S, const AttributeList &Attr, if (!decl) { S.Diag(Attr.getLoc(), diag::warn_hlsl_attribute_expects_uint_literal) << Attr.getName(); - return value; + return nullptr; } Expr *init = decl->getInit(); if (!init) { S.Diag(Attr.getLoc(), diag::warn_hlsl_attribute_expects_uint_literal) << Attr.getName(); - return value; + return nullptr; } E = init; } else @@ -13814,11 +13830,13 @@ static int ValidateAttributeIntArg(Sema &S, const AttributeList &Attr, bool displayError = false; if (E->isTypeDependent() || E->isValueDependent() || !E->isCXX11ConstantExpr(S.Context, &ArgNum)) { - displayError = true; + displayError = + !allowDefinedConstant || + !(E->isVulkanSpecConstantExpr(S.Context, &ArgNum) && ArgNum.isInt()); } else { if (ArgNum.isInt()) { - value = ArgNum.getInt().getSExtValue(); - if (!(E->getType()->isIntegralOrEnumerationType()) || value < 0) { + *value = ArgNum.getInt().getSExtValue(); + if (!(E->getType()->isIntegralOrEnumerationType()) || *value < 0) { S.Diag(Attr.getLoc(), diag::warn_hlsl_attribute_expects_uint_literal) << Attr.getName(); } @@ -13828,8 +13846,8 @@ static int ValidateAttributeIntArg(Sema &S, const AttributeList &Attr, if (ArgNum.getFloat().convertToInteger( floatInt, llvm::APFloat::rmTowardZero, &isPrecise) == llvm::APFloat::opStatus::opOK) { - value = floatInt.getSExtValue(); - if (value < 0) { + *value = floatInt.getSExtValue(); + if (*value < 0) { S.Diag(Attr.getLoc(), diag::warn_hlsl_attribute_expects_uint_literal) << Attr.getName(); @@ -13847,9 +13865,23 @@ static int ValidateAttributeIntArg(Sema &S, const AttributeList &Attr, S.Diag(Attr.getLoc(), diag::err_attribute_argument_type) << Attr.getName() << AANT_ArgumentIntegerConstant << E->getSourceRange(); + return nullptr; } } + return E; +} + +static Expr *ValidateAttributeIntArgExpr(Sema &S, const AttributeList &Attr, + unsigned index = 0) { + int64_t value = 0; + return ValidateAttributeIntArgExpr(S, Attr, index, &value, true); +} + +static int ValidateAttributeIntArg(Sema &S, const AttributeList &Attr, + unsigned index = 0) { + int64_t value = 0; + ValidateAttributeIntArgExpr(S, Attr, index, &value); return (int)value; } @@ -14195,19 +14227,27 @@ HLSLMaxRecordsAttr *ValidateMaxRecordsAttributes(Sema &S, Decl *D, Expr *ArgExpr = A.getArgAsExpr(0); IntegerLiteral *LiteralInt = dyn_cast(ArgExpr->IgnoreParenCasts()); + clang::SourceLocation Loc = {}; - if (ExistingMRSWA || ExistingMRA->getMaxCount() != LiteralInt->getValue()) { - clang::SourceLocation Loc = ExistingMRA ? ExistingMRA->getLocation() - : ExistingMRSWA->getLocation(); + if (ExistingMRSWA) { + Loc = ExistingMRSWA->getLocation(); + } else if (ExistingMRA) { + uint32_t maxCount = + GetIntConstAttrArg(S.getASTContext(), ExistingMRA->getMaxCount(), 0); + if (LiteralInt->getValue().getLimitedValue() != maxCount) + Loc = ExistingMRA->getLocation(); + } + + if (Loc.isValid()) { S.Diag(A.getLoc(), diag::err_hlsl_maxrecord_attrs_on_same_arg); S.Diag(Loc, diag::note_conflicting_attribute); return nullptr; } } - return ::new (S.Context) - HLSLMaxRecordsAttr(A.getRange(), S.Context, ValidateAttributeIntArg(S, A), - A.getAttributeSpellingListIndex()); + return ::new (S.Context) HLSLMaxRecordsAttr( + A.getRange(), S.Context, ValidateAttributeIntArgExpr(S, A), + A.getAttributeSpellingListIndex()); } // This function validates the wave size attribute in a stand-alone way, @@ -14416,19 +14456,20 @@ void Sema::DiagnoseCoherenceMismatch(const Expr *SrcExpr, QualType TargetType, } } -void ValidateDispatchGridValues(DiagnosticsEngine &Diags, - const AttributeList &A, Attr *declAttr) { +void ValidateDispatchGridValues(Sema &S, const AttributeList &A, + Attr *declAttr) { unsigned x = 1, y = 1, z = 1; + ASTContext &astContext = S.getASTContext(); if (HLSLNodeDispatchGridAttr *pA = dyn_cast(declAttr)) { - x = pA->getX(); - y = pA->getY(); - z = pA->getZ(); + x = GetIntConstAttrArg(astContext, pA->getX(), 1); + y = GetIntConstAttrArg(astContext, pA->getY(), 1); + z = GetIntConstAttrArg(astContext, pA->getZ(), 1); } else if (HLSLNodeMaxDispatchGridAttr *pA = dyn_cast(declAttr)) { - x = pA->getX(); - y = pA->getY(); - z = pA->getZ(); + x = GetIntConstAttrArg(astContext, pA->getX(), 1); + y = GetIntConstAttrArg(astContext, pA->getY(), 1); + z = GetIntConstAttrArg(astContext, pA->getZ(), 1); } else { llvm_unreachable("ValidateDispatchGridValues() called for wrong attribute"); } @@ -14437,26 +14478,26 @@ void ValidateDispatchGridValues(DiagnosticsEngine &Diags, // If a component is out of range, we reset it to 0 to avoid also generating // a secondary error if the product would be out of range if (x < 1 || x > MaxComponentValue) { - Diags.Report(A.getArgAsExpr(0)->getExprLoc(), - diag::err_hlsl_dispatchgrid_component) + S.Diags.Report(A.getArgAsExpr(0)->getExprLoc(), + diag::err_hlsl_dispatchgrid_component) << A.getName() << "X" << A.getRange(); x = 0; } if (y < 1 || y > MaxComponentValue) { - Diags.Report(A.getArgAsExpr(1)->getExprLoc(), - diag::err_hlsl_dispatchgrid_component) + S.Diags.Report(A.getArgAsExpr(1)->getExprLoc(), + diag::err_hlsl_dispatchgrid_component) << A.getName() << "Y" << A.getRange(); y = 0; } if (z < 1 || z > MaxComponentValue) { - Diags.Report(A.getArgAsExpr(2)->getExprLoc(), - diag::err_hlsl_dispatchgrid_component) + S.Diags.Report(A.getArgAsExpr(2)->getExprLoc(), + diag::err_hlsl_dispatchgrid_component) << A.getName() << "Z" << A.getRange(); z = 0; } uint64_t product = (uint64_t)x * (uint64_t)y * (uint64_t)z; if (product > MaxProductValue) - Diags.Report(A.getLoc(), diag::err_hlsl_dispatchgrid_product) + S.Diags.Report(A.getLoc(), diag::err_hlsl_dispatchgrid_product) << A.getName() << A.getRange(); } @@ -14616,7 +14657,8 @@ void hlsl::HandleDeclAttributeForHLSL(Sema &S, Decl *D, const AttributeList &A, case AttributeList::AT_HLSLNodeId: declAttr = ::new (S.Context) HLSLNodeIdAttr( A.getRange(), S.Context, ValidateAttributeStringArg(S, A, nullptr, 0), - ValidateAttributeIntArg(S, A, 1), A.getAttributeSpellingListIndex()); + ValidateAttributeIntArgExpr(S, A, 1), + A.getAttributeSpellingListIndex()); break; case AttributeList::AT_HLSLNodeTrackRWInputSharing: declAttr = ::new (S.Context) HLSLNodeTrackRWInputSharingAttr( @@ -14719,18 +14761,20 @@ void hlsl::HandleDeclAttributeForHLSL(Sema &S, Decl *D, const AttributeList &A, A.getAttributeSpellingListIndex()); break; case AttributeList::AT_HLSLNumThreads: { - int X = ValidateAttributeIntArg(S, A, 0); - int Y = ValidateAttributeIntArg(S, A, 1); - int Z = ValidateAttributeIntArg(S, A, 2); - int N = X * Y * Z; + int64_t X = 1, Y = 1, Z = 1; + auto *arg0 = ValidateAttributeIntArgExpr(S, A, 0, &X, true); + auto *arg1 = ValidateAttributeIntArgExpr(S, A, 1, &Y, true); + auto *arg2 = ValidateAttributeIntArgExpr(S, A, 2, &Z, true); + int64_t N = X * Y * Z; if (N > 0 && N <= 1024) { - auto numThreads = ::new (S.Context) HLSLNumThreadsAttr( - A.getRange(), S.Context, X, Y, Z, A.getAttributeSpellingListIndex()); + auto numThreads = ::new (S.Context) + HLSLNumThreadsAttr(A.getRange(), S.Context, arg0, arg1, arg2, + A.getAttributeSpellingListIndex()); declAttr = numThreads; } else { // If the number of threads is invalid, diagnose and drop the attribute. S.Diags.Report(A.getLoc(), diag::warn_hlsl_numthreads_group_size) - << N << X << Y << Z << A.getRange(); + << (int)N << (int)X << (int)Y << (int)Z << A.getRange(); return; } break; @@ -14832,31 +14876,37 @@ void hlsl::HandleDeclAttributeForHLSL(Sema &S, Decl *D, const AttributeList &A, case AttributeList::AT_HLSLNodeShareInputOf: declAttr = ::new (S.Context) HLSLNodeShareInputOfAttr( A.getRange(), S.Context, ValidateAttributeStringArg(S, A, nullptr, 0), - ValidateAttributeIntArg(S, A, 1), A.getAttributeSpellingListIndex()); + ValidateAttributeIntArgExpr(S, A, 1), + A.getAttributeSpellingListIndex()); break; case AttributeList::AT_HLSLNodeDispatchGrid: declAttr = ::new (S.Context) HLSLNodeDispatchGridAttr( - A.getRange(), S.Context, ValidateAttributeIntArg(S, A), - ValidateAttributeIntArg(S, A, 1), ValidateAttributeIntArg(S, A, 2), + A.getRange(), S.Context, ValidateAttributeIntArgExpr(S, A), + ValidateAttributeIntArgExpr(S, A, 1), + ValidateAttributeIntArgExpr(S, A, 2), A.getAttributeSpellingListIndex()); - ValidateDispatchGridValues(S.Diags, A, declAttr); + ValidateDispatchGridValues(S, A, declAttr); break; case AttributeList::AT_HLSLNodeMaxDispatchGrid: declAttr = ::new (S.Context) HLSLNodeMaxDispatchGridAttr( - A.getRange(), S.Context, ValidateAttributeIntArg(S, A), - ValidateAttributeIntArg(S, A, 1), ValidateAttributeIntArg(S, A, 2), + A.getRange(), S.Context, ValidateAttributeIntArgExpr(S, A), + ValidateAttributeIntArgExpr(S, A, 1), + ValidateAttributeIntArgExpr(S, A, 2), A.getAttributeSpellingListIndex()); - ValidateDispatchGridValues(S.Diags, A, declAttr); + ValidateDispatchGridValues(S, A, declAttr); break; - case AttributeList::AT_HLSLNodeMaxRecursionDepth: + case AttributeList::AT_HLSLNodeMaxRecursionDepth: { + int64_t maxRecursionDepth = 0; declAttr = ::new (S.Context) HLSLNodeMaxRecursionDepthAttr( - A.getRange(), S.Context, ValidateAttributeIntArg(S, A), + A.getRange(), S.Context, + ValidateAttributeIntArgExpr(S, A, 0, &maxRecursionDepth, true), A.getAttributeSpellingListIndex()); - if (cast(declAttr)->getCount() > 32) + if (maxRecursionDepth > 32) S.Diags.Report(declAttr->getLocation(), diag::err_hlsl_maxrecursiondepth_exceeded) << declAttr->getRange(); break; + } default: Handled = false; break; // SPIRV Change: was return; @@ -16258,8 +16308,13 @@ void hlsl::CustomPrintHLSLAttr(const clang::Attr *A, llvm::raw_ostream &Out, Attr *noconst = const_cast(A); HLSLNumThreadsAttr *ACast = static_cast(noconst); Indent(Indentation, Out); - Out << "[numthreads(" << ACast->getX() << ", " << ACast->getY() << ", " - << ACast->getZ() << ")]\n"; + Out << "[numthreads("; + ACast->getX()->printPretty(Out, nullptr, Policy); + Out << ", "; + ACast->getY()->printPretty(Out, nullptr, Policy); + Out << ", "; + ACast->getZ()->printPretty(Out, nullptr, Policy); + Out << ")]\n"; break; } @@ -16491,11 +16546,16 @@ void hlsl::CustomPrintHLSLAttr(const clang::Attr *A, llvm::raw_ostream &Out, Attr *noconst = const_cast(A); HLSLNodeIdAttr *ACast = static_cast(noconst); Indent(Indentation, Out); - if (ACast->getArrayIndex() > 0) - Out << "[NodeId(\"" << ACast->getName() << "\"," << ACast->getArrayIndex() - << ")]\n"; - else - Out << "[NodeId(\"" << ACast->getName() << "\")]\n"; + Out << "[NodeId(\"" << ACast->getName(); + if (auto *lit = dyn_cast(ACast->getArrayIndex())) { + if (!lit->getValue().isStrictlyPositive()) { + Out << "\")]\n"; + break; + } + } + Out << "\","; + ACast->getArrayIndex()->printPretty(Out, nullptr, Policy); + Out << ")]\n"; break; } @@ -16513,11 +16573,16 @@ void hlsl::CustomPrintHLSLAttr(const clang::Attr *A, llvm::raw_ostream &Out, HLSLNodeShareInputOfAttr *ACast = static_cast(noconst); Indent(Indentation, Out); - if (ACast->getArrayIndex() > 0) - Out << "[NodeShareInputOf(\"" << ACast->getName() << "\"," - << ACast->getArrayIndex() << ")]\n"; - else - Out << "[NodeShareInputOf(\"" << ACast->getName() << "\")]\n"; + Out << "[NodeShareInputOf(\"" << ACast->getName(); + if (auto *lit = dyn_cast(ACast->getArrayIndex())) { + if (!lit->getValue().isStrictlyPositive()) { + Out << "\")]\n"; + break; + } + } + Out << "\","; + ACast->getArrayIndex()->printPretty(Out, nullptr, Policy); + Out << ")]\n"; break; } @@ -17162,8 +17227,11 @@ void DiagnoseNodeEntry(Sema &S, FunctionDecl *FD, llvm::StringRef StageName, // thread group size is (1,1,1) if (NodeLaunchTy == DXIL::NodeLaunchType::Thread) { if (auto NumThreads = FD->getAttr()) { - if (NumThreads->getX() != 1 || NumThreads->getY() != 1 || - NumThreads->getZ() != 1) { + ASTContext &astContext = S.getASTContext(); + uint32_t x = GetIntConstAttrArg(astContext, NumThreads->getX(), 1); + uint32_t y = GetIntConstAttrArg(astContext, NumThreads->getY(), 1); + uint32_t z = GetIntConstAttrArg(astContext, NumThreads->getZ(), 1); + if (x != 1 || y != 1 || z != 1) { S.Diags.Report(NumThreads->getLocation(), diag::err_hlsl_wg_thread_launch_group_size) << NumThreads->getRange(); diff --git a/tools/clang/test/CodeGenSPIRV/vk.spec-constant.attributes.derivative-group.hlsl b/tools/clang/test/CodeGenSPIRV/vk.spec-constant.attributes.derivative-group.hlsl new file mode 100644 index 0000000000..cde93df5b6 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/vk.spec-constant.attributes.derivative-group.hlsl @@ -0,0 +1,35 @@ +// RUN: %dxc -T cs_6_6 -E main -fspv-extension=SPV_NV_compute_shader_derivatives -fcgl %s -spirv 2>&1 | FileCheck -check-prefix=CHECK-LINEAR %s +// RUN: %dxc -T cs_6_6 -E main -fspv-extension=SPV_NV_compute_shader_derivatives -fcgl -DQUADS %s -spirv 2>&1 | FileCheck -check-prefix=CHECK-QUADS %s + +// Note: validation disabled until NodePayloadAMDX pointers are allowed +// as function arguments + +// CHECK-LINEAR: OpCapability ComputeDerivativeGroupLinearKHR +// CHECK-LINEAR: OpExecutionMode %{{[^ ]*}} DerivativeGroupLinearKHR +// CHECK-QUADS: OpCapability ComputeDerivativeGroupQuadsKHR +// CHECK-QUADS: OpExecutionMode %{{[^ ]*}} DerivativeGroupQuadsKHR + +SamplerState ss : register(s2); +SamplerComparisonState scs; + +RWStructuredBuffer o; +Texture1D t1; + +#ifdef QUADS +[[vk::constant_id(0)]] +const uint NumThreadsX = 2; +[[vk::constant_id(1)]] +const uint NumThreadsY = 2; +#else +[[vk::constant_id(0)]] +const uint NumThreadsX = 24; +[[vk::constant_id(1)]] +const uint NumThreadsY = 1; +#endif + +[numthreads(NumThreadsX,NumThreadsY,1)] +void main(uint3 id : SV_GroupThreadID) +{ + o[0] = t1.Sample(ss, 1); +} + diff --git a/tools/clang/test/CodeGenSPIRV/vk.spec-constant.attributes.hlsl b/tools/clang/test/CodeGenSPIRV/vk.spec-constant.attributes.hlsl new file mode 100644 index 0000000000..d74fa554e5 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/vk.spec-constant.attributes.hlsl @@ -0,0 +1,53 @@ +// RUN: %dxc -spirv -Vd -Od -T lib_6_8 -fspv-target-env=vulkan1.3 %s | FileCheck %s + +// Note: validation disabled until NodePayloadAMDX pointers are allowed +// as function arguments + +struct InputPayload { + uint grid : SV_DispatchGrid; +}; + +struct OutputPayload { + uint foo; +}; + +[[vk::constant_id(0)]] +const uint MaxPayloads = 1; +[[vk::constant_id(1)]] +const uint WorkgroupSizeX = 1; +[[vk::constant_id(2)]] +const uint ShaderIndex = 0; +[[vk::constant_id(3)]] +const uint NumThreadsX = 512; + +[Shader("node")] +[NodeLaunch("broadcasting")] +[NumThreads(NumThreadsX, 1, 1)] +[NodeDispatchGrid(WorkgroupSizeX, 1, 1)] + +void main(const uint svGroupIndex : SV_GroupIndex, + DispatchNodeInputRecord inputRecord, + [NodeID("main", ShaderIndex)] + [MaxRecords(MaxPayloads)] + NodeOutput nodeOutput) { + ThreadNodeOutputRecords outRec = nodeOutput.GetThreadNodeOutputRecords(1); + outRec.OutputComplete(); +} + +// CHECK: OpExecutionModeId %{{[_0-9A-Za-z]*}} LocalSizeId [[NUMTHREADSX:%[_0-9A-Za-z]*]] [[U1:%[_0-9A-Za-z]*]] [[U1]] +// CHECK: OpExecutionModeId %{{[_0-9A-Za-z]*}} StaticNumWorkgroupsAMDX [[WGSIZEX:%[_0-9A-Za-z]*]] [[U1]] [[U1]] +// CHECK: OpDecorate [[MAXPAYLOADS:%[_0-9A-Za-z]*]] SpecId 0 +// CHECK: OpDecorate [[WGSIZEX]] SpecId 1 +// CHECK: OpDecorate [[SHADERINDEX:%[_0-9A-Za-z]*]] SpecId 2 +// CHECK: OpDecorate [[NUMTHREADSX]] SpecId 3 +// CHECK: OpDecorateId %{{[_0-9A-Za-z]*}} NodeMaxPayloadsAMDX [[U1:%[_0-9A-Za-z]*]] +// CHECK-DAG: OpDecorateId %{{[_0-9A-Za-z]*}} PayloadNodeBaseIndexAMDX [[SHADERINDEX]] +// CHECK-DAG: OpDecorateId %{{[_0-9A-Za-z]*}} NodeMaxPayloadsAMDX [[MAXPAYLOADS]] +// CHECK: [[UINT:%[_0-9A-Za-z]*]] = OpTypeInt 32 0 +// CHECK-DAG: [[U0:%[_0-9A-Za-z]*]] = OpConstant [[UINT]] 0 +// CHECK-DAG: [[U1]] = OpConstant [[UINT]] 1 +// CHECK-DAG: [[MAXPAYLOADS:%[_0-9A-Za-z]*]] = OpSpecConstant [[UINT]] 1 +// CHECK-DAG: [[WGSIZEX:%[_0-9A-Za-z]*]] = OpSpecConstant [[UINT]] 1 +// CHECK-DAG: [[SHADERINDEX:%[_0-9A-Za-z]*]] = OpSpecConstant [[UINT]] 0 +// CHECK-DAG: [[NUMTHREADSX:%[_0-9A-Za-z]*]] = OpSpecConstant [[UINT]] 512 + diff --git a/tools/clang/test/HLSLFileCheck/hlsl/workgraph/sparsenodes.hlsl b/tools/clang/test/HLSLFileCheck/hlsl/workgraph/sparsenodes.hlsl index 10335ee864..d5645e2849 100644 --- a/tools/clang/test/HLSLFileCheck/hlsl/workgraph/sparsenodes.hlsl +++ b/tools/clang/test/HLSLFileCheck/hlsl/workgraph/sparsenodes.hlsl @@ -10,34 +10,40 @@ // AST: FunctionDecl {{.*}} node_1_0 'void (NodeOutputArray)' // AST-NEXT: ParmVarDecl {{.*}} used OutputArray_1_0 'NodeOutputArray':'NodeOutputArray' -// AST-NEXT: HLSLMaxRecordsAttr {{.*}} 31 +// AST-NEXT: HLSLMaxRecordsAttr {{.*}} +// AST-NEXT: IntegerLiteral {{.*}} 31 // AST-NEXT: HLSLNodeArraySizeAttr {{.*}} 129 // AST-NEXT: HLSLAllowSparseNodesAttr // AST: FunctionDecl {{.*}} node_1_1 'void (NodeOutputArray)' // AST-NEXT: ParmVarDecl {{.*}} used OutputArray_1_1 'NodeOutputArray':'NodeOutputArray' -// AST-NEXT: HLSLMaxRecordsAttr {{.*}} 37 +// AST-NEXT: HLSLMaxRecordsAttr {{.*}} +// AST-NEXT: IntegerLiteral {{.*}} 37 // AST-NEXT: HLSLUnboundedSparseNodesAttr // AST: FunctionDecl {{.*}} node_1_2 'void (NodeOutput)' // AST-NEXT: ParmVarDecl {{.*}} used Output_1_2 'NodeOutput':'NodeOutput' -// AST-NEXT: HLSLMaxRecordsAttr {{.*}} 47 +// AST-NEXT: HLSLMaxRecordsAttr {{.*}} +// AST-NEXT: IntegerLiteral {{.*}} 47 // AST-NEXT: HLSLAllowSparseNodesAttr // AST: FunctionDecl {{.*}} node_2_0 'void (EmptyNodeOutputArray)' // AST-NEXT: ParmVarDecl {{.*}} used OutputArray_2_0 'EmptyNodeOutputArray' -// AST-NEXT: HLSLMaxRecordsAttr {{.*}} 41 +// AST-NEXT: HLSLMaxRecordsAttr {{.*}} +// AST-NEXT: IntegerLiteral {{.*}} 41 // AST-NEXT: HLSLNodeArraySizeAttr {{.*}} 131 // AST-NEXT: HLSLAllowSparseNodesAttr // AST: FunctionDecl {{.*}} node_2_1 'void (EmptyNodeOutputArray)' // AST-NEXT: ParmVarDecl {{.*}} used OutputArray_2_1 'EmptyNodeOutputArray' -// AST-NEXT: HLSLMaxRecordsAttr {{.*}} 43 +// AST-NEXT: HLSLMaxRecordsAttr {{.*}} +// AST-NEXT: IntegerLiteral {{.*}} 43 // AST-NEXT: HLSLUnboundedSparseNodesAttr // AST: FunctionDecl {{.*}} node_2_2 'void (EmptyNodeOutput)' // AST-NEXT: ParmVarDecl {{.*}} used Output_2_2 'EmptyNodeOutput' -// AST-NEXT: HLSLMaxRecordsAttr {{.*}} 53 +// AST-NEXT: HLSLMaxRecordsAttr {{.*}} +// AST-NEXT: IntegerLiteral {{.*}} 53 // AST-NEXT: HLSLAllowSparseNodesAttr // ==== -fcgl Metadata Checks ==== diff --git a/tools/clang/test/HLSLFileCheck/shader_targets/nodes/NodeOutput.hlsl b/tools/clang/test/HLSLFileCheck/shader_targets/nodes/NodeOutput.hlsl index 9343ad9831..6c69441468 100644 --- a/tools/clang/test/HLSLFileCheck/shader_targets/nodes/NodeOutput.hlsl +++ b/tools/clang/test/HLSLFileCheck/shader_targets/nodes/NodeOutput.hlsl @@ -23,20 +23,28 @@ struct MY_INPUT_RECORD // CHECK:FunctionDecl 0x{{.*}} myFancyNode 'void (GroupNodeInputRecords, NodeOutput, NodeOutput, NodeOutputArray, EmptyNodeOutput)' // CHECK-NEXT:ParmVarDecl 0x{{.*}} myInput 'GroupNodeInputRecords':'GroupNodeInputRecords' -// CHECK-NEXT: HLSLMaxRecordsAttr 0x{{.*}} 4 +// CHECK-NEXT: HLSLMaxRecordsAttr 0x{{.*}} +// CHECK-NEXT: IntegerLiteral 0x{{.*}} 4 // CHECK-NEXT: ParmVarDecl 0x{{.*}} myFascinatingNode 'NodeOutput':'NodeOutput' -// CHECK-NEXT: HLSLMaxRecordsAttr 0x{{.*}} 4 +// CHECK-NEXT: HLSLMaxRecordsAttr 0x{{.*}} +// CHECK-NEXT: IntegerLiteral 0x{{.*}} 4 // CHECK-NEXT: ParmVarDecl 0x{{.*}} myRecords 'NodeOutput':'NodeOutput' -// CHECK-NEXT: HLSLMaxRecordsAttr 0x{{.*}} 4 -// CHECK-NEXT: HLSLNodeIdAttr 0x{{.*}} "myNiftyNode" 3 +// CHECK-NEXT: HLSLMaxRecordsAttr 0x{{.*}} +// CHECK-NEXT: IntegerLiteral 0x{{.*}} 4 +// CHECK-NEXT: HLSLNodeIdAttr 0x{{.*}} "myNiftyNode" +// CHECK-NEXT: IntegerLiteral 0x{{.*}} 3 // CHECK-NEXT: ParmVarDecl 0x{{.*}} col:65 myMaterials 'NodeOutputArray' // CHECK-NEXT:HLSLNodeArraySizeAttr 0x{{.*}} 63 // CHECK-NEXT:HLSLAllowSparseNodesAttr 0x{{.*}} // CHECK-NEXT:HLSLMaxRecordsSharedWithAttr 0x{{.*}} myRecords // CHECK-NEXT:ParmVarDecl 0x{{.*}} myProgressCounter 'EmptyNodeOutput' -// CHECK-NEXT: HLSLMaxRecordsAttr 0x{{.*}} 20 +// CHECK-NEXT: HLSLMaxRecordsAttr 0x{{.*}} +// CHECK-NEXT: IntegerLiteral 0x{{.*}} 20 // CHECK-NEXT: CompoundStmt 0x -// CHECK-NEXT: HLSLNumThreadsAttr 0x{{.*}} 4 5 6 +// CHECK-NEXT: HLSLNumThreadsAttr 0x{{.*}} +// CHECK-NEXT: IntegerLiteral 0x{{.*}} 4 +// CHECK-NEXT: IntegerLiteral 0x{{.*}} 5 +// CHECK-NEXT: IntegerLiteral 0x{{.*}} 6 // CHECK-NEXT: HLSLNodeLaunchAttr 0x{{.*}} "coalescing" // CHECK-NEXT: HLSLShaderAttr 0x{{.*}} "node" [Shader("node")] diff --git a/tools/clang/test/SemaHLSL/hlsl/objects/NodeObjects/node-objects.hlsl b/tools/clang/test/SemaHLSL/hlsl/objects/NodeObjects/node-objects.hlsl index 2fd316bba6..5d2a684178 100644 --- a/tools/clang/test/SemaHLSL/hlsl/objects/NodeObjects/node-objects.hlsl +++ b/tools/clang/test/SemaHLSL/hlsl/objects/NodeObjects/node-objects.hlsl @@ -40,8 +40,14 @@ struct RECORD // AST: `-DeclRefExpr 0x{{.+}} 'DispatchNodeInputRecord':'DispatchNodeInputRecord' lvalue ParmVar 0x[[Param]] 'input' 'DispatchNodeInputRecord':'DispatchNodeInputRecord' // attributes. // AST: |-HLSLNodeLaunchAttr 0x{{.+}} "broadcasting" -// AST: |-HLSLNodeDispatchGridAttr 0x{{.+}} 64 1 1 -// AST: |-HLSLNumThreadsAttr 0x{{.+}} 1024 1 1 +// AST: |-HLSLNodeDispatchGridAttr 0x{{.+}} +// AST: |-IntegerLiteral 0x{{.+}} 64 +// AST: |-IntegerLiteral 0x{{.+}} 1 +// AST: `-IntegerLiteral 0x{{.+}} 1 +// AST: |-HLSLNumThreadsAttr 0x{{.+}} +// AST: |-IntegerLiteral 0x{{.+}} 1024 +// AST: |-IntegerLiteral 0x{{.+}} 1 +// AST: `-IntegerLiteral 0x{{.+}} 1 // AST: `-HLSLShaderAttr 0x{{.+}} "node" @@ -67,8 +73,14 @@ void node_DispatchNodeInputRecord(DispatchNodeInputRecord input) // AST: `-DeclRefExpr 0x{{.+}} 'RWDispatchNodeInputRecord':'RWDispatchNodeInputRecord' lvalue ParmVar 0x[[Param]] 'input' 'RWDispatchNodeInputRecord':'RWDispatchNodeInputRecord' // attributes. // AST: |-HLSLNodeLaunchAttr 0x{{.+}} "broadcasting" -// AST: |-HLSLNodeDispatchGridAttr 0x{{.+}} 16 1 1 -// AST: |-HLSLNumThreadsAttr 0x{{.+}} 1024 1 1 +// AST: |-HLSLNodeDispatchGridAttr 0x{{.+}} +// AST: |-IntegerLiteral 0x{{.+}} 16 +// AST: |-IntegerLiteral 0x{{.+}} 1 +// AST: `-IntegerLiteral 0x{{.+}} 1 +// AST: |-HLSLNumThreadsAttr 0x{{.+}} +// AST: |-IntegerLiteral 0x{{.+}} 1024 +// AST: |-IntegerLiteral 0x{{.+}} 1 +// AST: `-IntegerLiteral 0x{{.+}} 1 // AST: `-HLSLShaderAttr 0x{{.+}} "node" [Shader("node")] @@ -85,7 +97,8 @@ void node_RWDispatchNodeInputRecord(RWDispatchNodeInputRecord input) // AST: FunctionDecl 0x{{.+}} node_GroupNodeInputRecords 'void (GroupNodeInputRecords)' // AST: |-ParmVarDecl 0x[[Param:[0-9a-f]+]] col:81 used inputs 'GroupNodeInputRecords':'GroupNodeInputRecords' -// AST: | `-HLSLMaxRecordsAttr 0x{{.+}} 256 +// AST: | `-HLSLMaxRecordsAttr 0x{{.+}} +// AST: `-IntegerLiteral 0x{{.+}} 256 // call to wrapper // AST: `-CallExpr 0x{{.+}} 'GroupNodeInputRecords':'GroupNodeInputRecords' // AST: |-ImplicitCastExpr 0x{{.+}} 'GroupNodeInputRecords (*)(GroupNodeInputRecords)' @@ -94,7 +107,10 @@ void node_RWDispatchNodeInputRecord(RWDispatchNodeInputRecord input) // AST: `-DeclRefExpr 0x{{.+}} 'GroupNodeInputRecords':'GroupNodeInputRecords' lvalue ParmVar 0x[[Param]] 'inputs' 'GroupNodeInputRecords':'GroupNodeInputRecords' // attributes. // AST: |-HLSLNodeIsProgramEntryAttr 0x{{.+}} -// AST: |-HLSLNumThreadsAttr 0x{{.+}} 1024 1 1 +// AST: |-HLSLNumThreadsAttr 0x{{.+}} +// AST: |-IntegerLiteral 0x{{.+}} 1024 +// AST: |-IntegerLiteral 0x{{.+}} 1 +// AST: `-IntegerLiteral 0x{{.+}} 1 // AST: |-HLSLNodeLaunchAttr 0x{{.+}} "coalescing" // AST: `-HLSLShaderAttr 0x{{.+}} "node" @@ -112,7 +128,8 @@ void node_GroupNodeInputRecords([MaxRecords(256)] GroupNodeInputRecords // AST: FunctionDecl 0x{{.+}} node_RWGroupNodeInputRecords 'void (RWGroupNodeInputRecords)' // AST: |-ParmVarDecl 0x[[Param:[0-9a-f]+]] col:84 used input2 'RWGroupNodeInputRecords':'RWGroupNodeInputRecords' -// AST: | `-HLSLMaxRecordsAttr 0x{{.+}} 4 +// AST: | `-HLSLMaxRecordsAttr 0x{{.+}} +// AST: | `-IntegerLiteral 0x{{.+}} 4 // call to wrapper // AST: CallExpr 0x{{.+}} 'RWGroupNodeInputRecords':'RWGroupNodeInputRecords' // AST: |-ImplicitCastExpr 0x{{.+}} 'RWGroupNodeInputRecords (*)(RWGroupNodeInputRecords)' @@ -121,7 +138,10 @@ void node_GroupNodeInputRecords([MaxRecords(256)] GroupNodeInputRecords // AST: | `-DeclRefExpr 0x{{.+}} 'RWGroupNodeInputRecords':'RWGroupNodeInputRecords' lvalue ParmVar 0x[[Param]] 'input2' 'RWGroupNodeInputRecords':'RWGroupNodeInputRecords' // attributes. // AST: |-HLSLNodeLaunchAttr 0x{{.+}} "coalescing" -// AST: |-HLSLNumThreadsAttr 0x{{.+}} 1 1 1 +// AST: |-HLSLNumThreadsAttr 0x{{.+}} +// AST: |-IntegerLiteral 0x{{.+}} 1 +// AST: |-IntegerLiteral 0x{{.+}} 1 +// AST: `-IntegerLiteral 0x{{.+}} 1 // AST: `-HLSLShaderAttr 0x{{.+}} "node" struct RECORD2 @@ -192,7 +212,10 @@ void node_RWThreadNodeInputRecord(RWThreadNodeInputRecord input) // AST: `-ImplicitCastExpr 0x{{.+}} 'EmptyNodeInput' // AST: `-DeclRefExpr 0x{{.+}} 'EmptyNodeInput' lvalue ParmVar 0x[[Param]] 'input' 'EmptyNodeInput' // attributes. -// AST: | |-HLSLNumThreadsAttr 0x{{.+}} 2 1 1 +// AST: | |-HLSLNumThreadsAttr 0x{{.+}} +// AST: |-IntegerLiteral 0x{{.+}} 2 +// AST: |-IntegerLiteral 0x{{.+}} 1 +// AST: `-IntegerLiteral 0x{{.+}} 1 // AST: | |-HLSLNodeIsProgramEntryAttr // AST: | |-HLSLNodeLaunchAttr 0x{{.+}} "coalescing" // AST: | `-HLSLShaderAttr 0x{{.+}}> "node" @@ -219,8 +242,14 @@ void node_EmptyNodeInput(EmptyNodeInput input) // AST: `-DeclRefExpr 0x{{.+}} 'NodeOutput':'NodeOutput' lvalue ParmVar 0x[[Param]] 'output3' 'NodeOutput':'NodeOutput' // attributes. // AST: | |-HLSLNodeLaunchAttr 0x{{.+}} "broadcasting" -// AST: | |-HLSLNodeDispatchGridAttr 0x{{.+}} 32 1 1 -// AST: | |-HLSLNumThreadsAttr 0x{{.+}} 1024 1 1 +// AST: | |-HLSLNodeDispatchGridAttr 0x{{.+}} +// AST: | | |-IntegerLiteral 0x{{.+}} 32 +// AST: | | |-IntegerLiteral 0x{{.+}} 1 +// AST: | | `-IntegerLiteral 0x{{.+}} 1 +// AST: | |-HLSLNumThreadsAttr 0x{{.+}} +// AST: | | |-IntegerLiteral 0x{{.+}} 1024 +// AST: | | |-IntegerLiteral 0x{{.+}} 1 +// AST: | | `-IntegerLiteral 0x{{.+}} 1 // AST: | `-HLSLShaderAttr 0x{{.+}} "node" [Shader("node")] @@ -239,7 +268,8 @@ void node_NodeOutput(NodeOutput output3) // EmptyNodeOutput // AST: FunctionDecl 0x{{.+}} node_EmptyNodeOutput 'void (EmptyNodeOutput)' // AST: | |-ParmVarDecl 0x[[Param:[0-9a-f]+]] {{.+}} col:35 used loadStressChild 'EmptyNodeOutput' -// AST: | | `-HLSLMaxRecordsAttr 0x{{.+}} 12 +// AST: | | `-HLSLMaxRecordsAttr 0x{{.+}} +// AST: | | `-IntegerLiteral 0x{{.+}} 12 // call to wrapper // AST: CallExpr 0x{{.+}} 'EmptyNodeOutput':'EmptyNodeOutput' // AST: |-ImplicitCastExpr 0x{{.+}} 'EmptyNodeOutput (*)(EmptyNodeOutput)' @@ -247,8 +277,14 @@ void node_NodeOutput(NodeOutput output3) // AST: `-ImplicitCastExpr 0x{{.+}} 'EmptyNodeOutput' // AST: `-DeclRefExpr 0x{{.+}} 'EmptyNodeOutput' lvalue ParmVar 0x[[Param]] 'loadStressChild' 'EmptyNodeOutput' // attributes. -// AST: | |-HLSLNumThreadsAttr 0x{{.+}} 1 1 1 -// AST: | |-HLSLNodeDispatchGridAttr 0x{{.+}} 1 1 1 +// AST: | |-HLSLNumThreadsAttr 0x{{.+}} +// AST: | | |-IntegerLiteral 0x{{.+}} 1 +// AST: | | |-IntegerLiteral 0x{{.+}} 1 +// AST: | | `-IntegerLiteral 0x{{.+}} 1 +// AST: | |-HLSLNodeDispatchGridAttr 0x{{.+}} +// AST: | | |-IntegerLiteral 0x{{.+}} 1 +// AST: | | |-IntegerLiteral 0x{{.+}} 1 +// AST: | | `-IntegerLiteral 0x{{.+}} 1 // AST: | `-HLSLShaderAttr 0x{{.+}} "node" void loadStressEmptyRecWorker( EmptyNodeOutput outputNode) @@ -269,7 +305,8 @@ void node_EmptyNodeOutput( // NodeOutputArray // AST: FunctionDecl 0x{{.+}} node_NodeOutputArray 'void (NodeOutputArray)' // AST: | |-ParmVarDecl 0x[[Param:[0-9a-f]+]] {{.+}} col:30 used OutputArray_1_0 'NodeOutputArray':'NodeOutputArray' -// AST: | | |-HLSLMaxRecordsAttr 0x{{.+}} 31 +// AST: | | |-HLSLMaxRecordsAttr 0x{{.+}} +// AST: | | | `-IntegerLiteral 0x{{.+}} 31 // AST: | | |-HLSLNodeArraySizeAttr 0x{{.+}} 129 // AST: | | `-HLSLAllowSparseNodesAttr 0x{{.+}} // call to wrapper @@ -279,8 +316,14 @@ void node_EmptyNodeOutput( // AST: `-ImplicitCastExpr 0x{{.+}} 'NodeOutputArray':'NodeOutputArray' // AST: `-DeclRefExpr 0x{{.+}} 'NodeOutputArray':'NodeOutputArray' lvalue ParmVar 0x[[Param]] 'OutputArray_1_0' 'NodeOutputArray':'NodeOutputArray' // attributes. -// AST: | |-HLSLNumThreadsAttr 0x{{.+}} 1 1 1 -// AST: | |-HLSLNodeDispatchGridAttr 0x{{.+}} 1 1 1 +// AST: | |-HLSLNumThreadsAttr 0x{{.+}} +// AST: | | |-IntegerLiteral 0x{{.+}} 1 +// AST: | | |-IntegerLiteral 0x{{.+}} 1 +// AST: | | `-IntegerLiteral 0x{{.+}} 1 +// AST: | |-HLSLNodeDispatchGridAttr 0x{{.+}} +// AST: | | |-IntegerLiteral 0x{{.+}} 1 +// AST: | | |-IntegerLiteral 0x{{.+}} 1 +// AST: | | `-IntegerLiteral 0x{{.+}} 1 // AST: | |-HLSLNodeLaunchAttr 0x{{.+}} "broadcasting" // AST: | `-HLSLShaderAttr 0x{{.+}} "node" @@ -304,7 +347,8 @@ void node_NodeOutputArray( // EmptyNodeOutputArray // AST: FunctionDecl 0x{{.+}} node_EmptyNodeOutputArray 'void (EmptyNodeOutputArray)' // AST: | |-ParmVarDecl 0x[[Param:[0-9a-f]+]] {{.+}} col:64 used EmptyOutputArray 'EmptyNodeOutputArray' -// AST: | | |-HLSLMaxRecordsAttr 0x{{.+}} 64 +// AST: | | |-HLSLMaxRecordsAttr 0x{{.+}} +// AST: | | | `-IntegerLiteral 0x{{.+}} 64 // AST: | | `-HLSLNodeArraySizeAttr 0x{{.+}} 128 // call to wrapper // AST: CallExpr 0x{{.+}} 'EmptyNodeOutputArray':'EmptyNodeOutputArray' @@ -313,8 +357,14 @@ void node_NodeOutputArray( // AST: `-ImplicitCastExpr 0x{{.+}} 'EmptyNodeOutputArray' // AST: `-DeclRefExpr 0x{{.+}} 'EmptyNodeOutputArray' lvalue ParmVar 0x[[Param]] 'EmptyOutputArray' 'EmptyNodeOutputArray' // attributes. -// AST: | |-HLSLNumThreadsAttr 0x{{.+}} 128 1 1 -// AST: | |-HLSLNodeDispatchGridAttr 0x{{.+}} 1 1 1 +// AST: | |-HLSLNumThreadsAttr 0x{{.+}} +// AST: | | |-IntegerLiteral 0x{{.+}} 128 +// AST: | | |-IntegerLiteral 0x{{.+}} 1 +// AST: | | `-IntegerLiteral 0x{{.+}} 1 +// AST: | |-HLSLNodeDispatchGridAttr 0x{{.+}} +// AST: | | |-IntegerLiteral 0x{{.+}} 1 +// AST: | | |-IntegerLiteral 0x{{.+}} 1 +// AST: | | `-IntegerLiteral 0x{{.+}} 1 // AST: | |-HLSLNodeLaunchAttr 0x{{.+}} "broadcasting" // AST: | `-HLSLShaderAttr 0x{{.+}} "node" @@ -335,7 +385,8 @@ void node_EmptyNodeOutputArray( // GroupNodeOutputRecords // AST: FunctionDecl 0x{{.+}} node_GroupNodeOutputRecords 'void (NodeOutputArray)' // AST: | |-ParmVarDecl 0x[[Param:[0-9a-f]+]] {{.+}} col:68 used OutputArray 'NodeOutputArray':'NodeOutputArray' -// AST: | | |-HLSLMaxRecordsAttr 0x{{.+}} 64 +// AST: | | |-HLSLMaxRecordsAttr 0x{{.+}} +// AST: | | | `-IntegerLiteral 0x{{.+}} 64 // AST: | | `-HLSLNodeArraySizeAttr 0x{{.+}} 128 // call to wrapper // AST: CallExpr 0x{{.+}} 'GroupNodeOutputRecords':'GroupNodeOutputRecords' @@ -344,8 +395,14 @@ void node_EmptyNodeOutputArray( // AST: `-ImplicitCastExpr 0x{{.+}} 'GroupNodeOutputRecords':'GroupNodeOutputRecords' // AST: `-DeclRefExpr 0x{{.+}} 'GroupNodeOutputRecords':'GroupNodeOutputRecords' lvalue Var 0x{{.+}} 'outRec' 'GroupNodeOutputRecords':'GroupNodeOutputRecords' // attributes. -// AST: | |-HLSLNumThreadsAttr 0x{{.+}} 128 1 1 -// AST: | |-HLSLNodeDispatchGridAttr 0x{{.+}} 1 1 1 +// AST: | |-HLSLNumThreadsAttr 0x{{.+}} +// AST: | | |-IntegerLiteral 0x{{.+}} 128 +// AST: | | |-IntegerLiteral 0x{{.+}} 1 +// AST: | | `-IntegerLiteral 0x{{.+}} 1 +// AST: | |-HLSLNodeDispatchGridAttr 0x{{.+}} +// AST: | | |-IntegerLiteral 0x{{.+}} 1 +// AST: | | |-IntegerLiteral 0x{{.+}} 1 +// AST: | | `-IntegerLiteral 0x{{.+}} 1 // AST: | |-HLSLNodeLaunchAttr 0x{{.+}} "broadcasting" // AST: | `-HLSLShaderAttr 0x{{.+}} "node" [Shader("node")] @@ -366,7 +423,8 @@ void node_GroupNodeOutputRecords( // ThreadNodeOutputRecords // AST: FunctionDecl 0x{{.+}} node_ThreadNodeOutputRecords 'void (NodeOutputArray)' // AST: |-ParmVarDecl 0x[[Param:[0-9a-f]+]] {{.+}} used OutputArray_1_0 'NodeOutputArray':'NodeOutputArray' -// AST: | |-HLSLMaxRecordsAttr 0x{{.+}} 31 +// AST: | |-HLSLMaxRecordsAttr 0x{{.+}} +// AST: | | `-IntegerLiteral 0x{{.+}} 31 // AST: | |-HLSLNodeArraySizeAttr 0x{{.+}} 129 // AST: | `-HLSLAllowSparseNodesAttr 0x{{.+}} // call to wrapper @@ -376,8 +434,14 @@ void node_GroupNodeOutputRecords( // AST: `-ImplicitCastExpr 0x{{.+}} 'ThreadNodeOutputRecords':'ThreadNodeOutputRecords' // AST: `-DeclRefExpr 0x{{.+}} 'ThreadNodeOutputRecords':'ThreadNodeOutputRecords' lvalue Var 0x{{.+}} 'outRec' 'ThreadNodeOutputRecords':'ThreadNodeOutputRecords' // attributes. -// AST: |-HLSLNumThreadsAttr 0x{{.+}} 1 1 1 -// AST: |-HLSLNodeDispatchGridAttr 0x{{.+}} 1 1 1 +// AST: |-HLSLNumThreadsAttr 0x{{.+}} +// AST: |-IntegerLiteral 0x{{.+}} 1 +// AST: |-IntegerLiteral 0x{{.+}} 1 +// AST: `-IntegerLiteral 0x{{.+}} 1 +// AST: |-HLSLNodeDispatchGridAttr 0x{{.+}} +// AST: |-IntegerLiteral 0x{{.+}} 1 +// AST: |-IntegerLiteral 0x{{.+}} 1 +// AST: `-IntegerLiteral 0x{{.+}} 1 // AST: |-HLSLNodeLaunchAttr 0x{{.+}} "broadcasting" // AST: `-HLSLShaderAttr 0x{{.+}} "node" diff --git a/tools/clang/test/SemaHLSL/hlsl/workgraph/ast-EmptyNodeOutputArrayTypes.hlsl b/tools/clang/test/SemaHLSL/hlsl/workgraph/ast-EmptyNodeOutputArrayTypes.hlsl index aa5a7123f6..278161d65d 100644 --- a/tools/clang/test/SemaHLSL/hlsl/workgraph/ast-EmptyNodeOutputArrayTypes.hlsl +++ b/tools/clang/test/SemaHLSL/hlsl/workgraph/ast-EmptyNodeOutputArrayTypes.hlsl @@ -54,7 +54,8 @@ void node_2_0( // CHECK: `-FunctionDecl 0x{{.+}} node_2_0 'void (EmptyNodeOutputArray)' // CHECK-NEXT: |-ParmVarDecl 0x[[Param:[0-9a-f]+]] {{.+}} used OutputArray_2_0 'EmptyNodeOutputArray' -// CHECK-NEXT: | |-HLSLMaxRecordsAttr 0x{{.+}} 41 +// CHECK-NEXT: | |-HLSLMaxRecordsAttr 0x{{.+}} +// CHECK-NEXT: | | `-IntegerLiteral 0x{{.+}} 'literal int' 41 // CHECK-NEXT: | |-HLSLNodeArraySizeAttr 0x{{.+}} 131 // CHECK-NEXT: | `-HLSLAllowSparseNodesAttr 0x{{.+}} // CHECK-NEXT: |-CompoundStmt 0x{{.+}} @@ -68,7 +69,13 @@ void node_2_0( // CHECK-NEXT: | | `-IntegerLiteral 0x{{.+}}{{.+}} 'literal int' 1 // CHECK-NEXT: | `-ImplicitCastExpr 0x{{.+}} 'unsigned int' // CHECK-NEXT: | `-IntegerLiteral 0x{{.+}} 'literal int' 10 -// CHECK-NEXT: |-HLSLNumThreadsAttr 0x{{.+}} 1 1 1 -// CHECK-NEXT: |-HLSLNodeDispatchGridAttr 0x{{.+}} 1 1 1 +// CHECK-NEXT: |-HLSLNumThreadsAttr 0x{{.+}} +// CHECK-NEXT: | |-IntegerLiteral 0x{{.+}} 'literal int' 1 +// CHECK-NEXT: | |-IntegerLiteral 0x{{.+}} 'literal int' 1 +// CHECK-NEXT: | `-IntegerLiteral 0x{{.+}} 'literal int' 1 +// CHECK-NEXT: |-HLSLNodeDispatchGridAttr 0x{{.+}} +// CHECK-NEXT: | |-IntegerLiteral 0x{{.+}} 'literal int' 1 +// CHECK-NEXT: | |-IntegerLiteral 0x{{.+}} 'literal int' 1 +// CHECK-NEXT: | `-IntegerLiteral 0x{{.+}} 'literal int' 1 // CHECK-NEXT: |-HLSLNodeLaunchAttr 0x{{.+}} "broadcasting" // CHECK-NEXT: `-HLSLShaderAttr 0x{{.+}} "node" diff --git a/tools/clang/test/SemaHLSL/hlsl/workgraph/ast-NodeOutputArrayTypes.hlsl b/tools/clang/test/SemaHLSL/hlsl/workgraph/ast-NodeOutputArrayTypes.hlsl index 6d381c5953..5bf163edbf 100644 --- a/tools/clang/test/SemaHLSL/hlsl/workgraph/ast-NodeOutputArrayTypes.hlsl +++ b/tools/clang/test/SemaHLSL/hlsl/workgraph/ast-NodeOutputArrayTypes.hlsl @@ -152,7 +152,8 @@ void node_1_1( // CHECK:`-FunctionDecl 0x{{.+}} line:16:6 node_1_1 'void (NodeOutputArray)' // CHECK-NEXT: |-ParmVarDecl 0x[[ParmVar:[0-9a-f]+]] col:30 used OutputArray_1_1 'NodeOutputArray':'NodeOutputArray' -// CHECK-NEXT: | |-HLSLMaxRecordsAttr 0x{{.+}} 37 +// CHECK-NEXT: | |-HLSLMaxRecordsAttr 0x{{.+}} +// CHECK-NEXT: | | `-IntegerLiteral 0x{{.+}} 'literal int' 37 // CHECK-NEXT: | `-HLSLUnboundedSparseNodesAttr 0x{{.+}} // CHECK-NEXT: |-CompoundStmt 0x{{.+}} // CHECK-NEXT: | |-DeclStmt 0x{{.+}} @@ -170,7 +171,13 @@ void node_1_1( // CHECK-NEXT: | `-CXXMemberCallExpr 0x{{.+}} 'void' // CHECK-NEXT: | `-MemberExpr 0x{{.+}} '' .OutputComplete 0x[[OutComplete]] // CHECK-NEXT: | `-DeclRefExpr 0x{{.+}} 'ThreadNodeOutputRecords':'ThreadNodeOutputRecords' lvalue Var 0x[[OutRec]] 'outRec' 'ThreadNodeOutputRecords':'ThreadNodeOutputRecords' -// CHECK-NEXT: |-HLSLNumThreadsAttr 0x{{.+}} 1 1 1 -// CHECK-NEXT: |-HLSLNodeDispatchGridAttr 0x{{.+}} 1 1 1 +// CHECK-NEXT: |-HLSLNumThreadsAttr 0x{{.+}} +// CHECK-NEXT: | |-IntegerLiteral 0x{{.+}} 'literal int' 1 +// CHECK-NEXT: | |-IntegerLiteral 0x{{.+}} 'literal int' 1 +// CHECK-NEXT: | `-IntegerLiteral 0x{{.+}} 'literal int' 1 +// CHECK-NEXT: |-HLSLNodeDispatchGridAttr 0x{{.+}} +// CHECK-NEXT: | |-IntegerLiteral 0x{{.+}} 'literal int' 1 +// CHECK-NEXT: | |-IntegerLiteral 0x{{.+}} 'literal int' 1 +// CHECK-NEXT: | `-IntegerLiteral 0x{{.+}} 'literal int' 1 // CHECK-NEXT: |-HLSLNodeLaunchAttr 0x{{.+}} "broadcasting" // CHECK-NEXT: `-HLSLShaderAttr 0x{{.+}} "node"