From e4e71e4ed29e38ac163be2f29b10e8ea545fb94b Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Wed, 24 Jan 2024 00:18:17 +0200 Subject: [PATCH] Make declarations function global. --- include/clad/Differentiator/VisitorBase.h | 11 +++- lib/Differentiator/ReverseModeVisitor.cpp | 66 +++++++++++------------ lib/Differentiator/VisitorBase.cpp | 15 +++++- 3 files changed, 54 insertions(+), 38 deletions(-) diff --git a/include/clad/Differentiator/VisitorBase.h b/include/clad/Differentiator/VisitorBase.h index f85085a86..9457232c1 100644 --- a/include/clad/Differentiator/VisitorBase.h +++ b/include/clad/Differentiator/VisitorBase.h @@ -294,7 +294,8 @@ namespace clad { clang::Expr* Init = nullptr, bool DirectInit = false, clang::TypeSourceInfo* TSI = nullptr, clang::VarDecl::InitializationStyle IS = - clang::VarDecl::InitializationStyle::CInit); + clang::VarDecl::InitializationStyle::CInit, + clang::Scope* scope=nullptr); /// Builds variable declaration to be used inside the derivative /// body. /// \param[in] Type The type of variable declaration to build. @@ -311,6 +312,14 @@ namespace clad { clang::TypeSourceInfo* TSI = nullptr, clang::VarDecl::InitializationStyle IS = clang::VarDecl::InitializationStyle::CInit); + /// Builds variable declaration to be used inside the derivative + /// body in the derivative function global scope. + clang::VarDecl* + BuildGlobalVarDecl(clang::QualType Type, llvm::StringRef prefix = "_t", + clang::Expr* Init = nullptr, bool DirectInit = false, + clang::TypeSourceInfo* TSI = nullptr, + clang::VarDecl::InitializationStyle IS = + clang::VarDecl::InitializationStyle::CInit); /// Creates a namespace declaration and enters its context. All subsequent /// Stmts are built inside that namespace, until /// m_Sema.PopDeclContextIsUsed. diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index d0ec52a6c..40c2365e8 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -215,7 +215,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, auto* gradientParam = gradientParams[i]; auto* gradientVD = - BuildVarDecl(gradientParam->getType(), gradientParam->getName(), + BuildGlobalVarDecl(gradientParam->getType(), gradientParam->getName(), BuildDeclRef(overloadParam)); callArgs.push_back(BuildDeclRef(gradientVD)); addToCurrentBlock(BuildDeclStmt(gradientVD)); @@ -581,7 +581,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (utils::isArrayOrPointerType(VDDerivedType)) continue; auto* VDDerived = - BuildVarDecl(VDDerivedType, "_d_" + param->getNameAsString(), + BuildGlobalVarDecl(VDDerivedType, "_d_" + param->getNameAsString(), getZeroInit(VDDerivedType)); m_Variables[param] = BuildDeclRef(VDDerived); addToBlock(BuildDeclStmt(VDDerived), m_Globals); @@ -2540,7 +2540,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (const auto* AT = dyn_cast(VD->getType())) { Expr* init = getArraySizeExpr(AT, m_Context, *this); VDDerivedInit = init; - VDDerived = BuildVarDecl(VDDerivedType, "_d_" + VD->getNameAsString(), + VDDerived = BuildGlobalVarDecl(VDDerivedType, "_d_" + VD->getNameAsString(), VDDerivedInit, false, nullptr, clang::VarDecl::InitializationStyle::CallInit); } else { @@ -2614,12 +2614,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // necessary to preserve the old tests. if (VDDerivedType->isRecordType()) VDDerived = - BuildVarDecl(VDDerivedType, "_d_" + VD->getNameAsString(), + BuildGlobalVarDecl(VDDerivedType, "_d_" + VD->getNameAsString(), VDDerivedInit, VD->isDirectInit(), m_Context.getTrivialTypeSourceInfo(VDDerivedType), VD->getInitStyle()); else - VDDerived = BuildVarDecl(VDDerivedType, "_d_" + VD->getNameAsString(), + VDDerived = BuildGlobalVarDecl(VDDerivedType, "_d_" + VD->getNameAsString(), VDDerivedInit); } @@ -2661,11 +2661,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // Here separate behaviour for record and non-record types is only // necessary to preserve the old tests. if (VD->getType()->isRecordType()) { - VDClone = BuildVarDecl(VD->getType(), VD->getNameAsString(), + VDClone = BuildGlobalVarDecl(VD->getType(), VD->getNameAsString(), initDiff.getExpr(), VD->isDirectInit(), VD->getTypeSourceInfo(), VD->getInitStyle()); } else { - VDClone = BuildVarDecl(CloneType(VD->getType()), VD->getNameAsString(), + VDClone = BuildGlobalVarDecl(CloneType(VD->getType()), VD->getNameAsString(), initDiff.getExpr(), VD->isDirectInit()); } Expr* derivedVDE = BuildDeclRef(VDDerived); @@ -2786,27 +2786,25 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // with function globals and replace initializations with // assignments. This is a temporary measure to avoid the bug that arises // from overwriting local variables on different loop passes. - if (isInsideLoop) { - if (VD->getType()->isBuiltinType()) { - auto* decl = VDDiff.getDecl(); - /// The same variable will be assigned with new values every - /// loop iteration so the const qualifier must be dropped. - if (decl->getType().isConstQualified()) { - QualType nonConstType = - getNonConstType(decl->getType(), m_Context, m_Sema); - decl->setType(nonConstType); - } - if (decl->getInit()) { - auto* declRef = BuildDeclRef(decl); - auto pushPop = - StoreAndRestore(declRef, /*prefix=*/"_t", /*force=*/true); - if (pushPop.getExpr() != declRef) - addToCurrentBlock(pushPop.getExpr_dx(), direction::reverse); - auto* assignment = BuildOp(BO_Assign, declRef, decl->getInit()); - inits.push_back(BuildOp(BO_Comma, pushPop.getExpr(), assignment)); - } - decl->setInit(getZeroInit(VD->getType())); + if (VD->getType()->isBuiltinType() && m_Reverse.size()>2) { + auto* decl = VDDiff.getDecl(); + /// The same variable will be assigned with new values every + /// loop iteration so the const qualifier must be dropped. + if (decl->getType().isConstQualified()) { + QualType nonConstType = + getNonConstType(decl->getType(), m_Context, m_Sema); + decl->setType(nonConstType); } + if (decl->getInit()) { + auto* declRef = BuildDeclRef(decl); + auto pushPop = + StoreAndRestore(declRef, /*prefix=*/"_t", /*force=*/true); + if (pushPop.getExpr() != declRef) + addToCurrentBlock(pushPop.getExpr_dx(), direction::reverse); + auto* assignment = BuildOp(BO_Assign, declRef, decl->getInit()); + inits.push_back(BuildOp(BO_Comma, pushPop.getExpr(), assignment)); + } + decl->setInit(getZeroInit(VD->getType())); } decls.push_back(VDDiff.getDecl()); @@ -2843,14 +2841,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, /// with function globals and replace initializations with assignments. /// This is a temporary measure to avoid the bug that arises from /// overwriting local variables on different loop passes. - if (isInsideLoop) { - if (auto* VD = dyn_cast(decls[0])) { - if (VD->getType()->isBuiltinType()) { - addToBlock(DSClone, m_Globals); - Stmt* initAssignments = MakeCompoundStmt(inits); - initAssignments = unwrapIfSingleStmt(initAssignments); - return StmtDiff(initAssignments); - } + if (auto* VD = dyn_cast(decls[0])) { + if (VD->getType()->isBuiltinType() && m_Reverse.size()>2) { + addToBlock(DSClone, m_Globals); + Stmt* initAssignments = MakeCompoundStmt(inits); + initAssignments = unwrapIfSingleStmt(initAssignments); + return StmtDiff(initAssignments); } } diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index fc0857fb8..79a6a9a64 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -127,7 +127,8 @@ namespace clad { VarDecl* VisitorBase::BuildVarDecl(QualType Type, IdentifierInfo* Identifier, Expr* Init, bool DirectInit, TypeSourceInfo* TSI, - VarDecl::InitializationStyle IS) { + VarDecl::InitializationStyle IS, + Scope* scope) { // add namespace specifier in variable declaration if needed. Type = utils::AddNamespaceSpecifier(m_Sema, m_Context, Type); @@ -143,8 +144,10 @@ namespace clad { m_Sema.ActOnUninitializedDecl(VD); } m_Sema.FinalizeDeclaration(VD); + if (!scope) + scope = getCurrentScope(); // Add the identifier to the scope and IdResolver - m_Sema.PushOnScopeChains(VD, getCurrentScope(), /*AddToContext*/ false); + m_Sema.PushOnScopeChains(VD, scope, /*AddToContext*/ false); return VD; } @@ -162,6 +165,14 @@ namespace clad { TSI, IS); } + VarDecl* VisitorBase::BuildGlobalVarDecl(QualType Type, llvm::StringRef prefix, + Expr* Init, bool DirectInit, + TypeSourceInfo* TSI, + VarDecl::InitializationStyle IS) { + return BuildVarDecl(Type, CreateUniqueIdentifier(prefix), Init, DirectInit, + TSI, IS, m_DerivativeFnScope); + } + NamespaceDecl* VisitorBase::BuildNamespaceDecl(IdentifierInfo* II, bool isInline) { // Check if the namespace is being redeclared.