Skip to content

Commit

Permalink
Make declarations function global.
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Jan 26, 2024
1 parent ae4b94f commit e4e71e4
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 38 deletions.
11 changes: 10 additions & 1 deletion include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,8 @@ namespace clad {
clang::Expr* Init = nullptr, bool DirectInit = false,
clang::TypeSourceInfo* TSI = nullptr,
clang::VarDecl::InitializationStyle IS =
clang::VarDecl::InitializationStyle::CInit);
clang::VarDecl::InitializationStyle::CInit,
clang::Scope* scope=nullptr);
/// Builds variable declaration to be used inside the derivative
/// body.
/// \param[in] Type The type of variable declaration to build.
Expand All @@ -311,6 +312,14 @@ namespace clad {
clang::TypeSourceInfo* TSI = nullptr,
clang::VarDecl::InitializationStyle IS =
clang::VarDecl::InitializationStyle::CInit);
/// Builds variable declaration to be used inside the derivative
/// body in the derivative function global scope.
clang::VarDecl*
BuildGlobalVarDecl(clang::QualType Type, llvm::StringRef prefix = "_t",
clang::Expr* Init = nullptr, bool DirectInit = false,
clang::TypeSourceInfo* TSI = nullptr,
clang::VarDecl::InitializationStyle IS =
clang::VarDecl::InitializationStyle::CInit);
/// Creates a namespace declaration and enters its context. All subsequent
/// Stmts are built inside that namespace, until
/// m_Sema.PopDeclContextIsUsed.
Expand Down
66 changes: 31 additions & 35 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
auto* gradientParam = gradientParams[i];

auto* gradientVD =
BuildVarDecl(gradientParam->getType(), gradientParam->getName(),
BuildGlobalVarDecl(gradientParam->getType(), gradientParam->getName(),
BuildDeclRef(overloadParam));
callArgs.push_back(BuildDeclRef(gradientVD));
addToCurrentBlock(BuildDeclStmt(gradientVD));
Expand Down Expand Up @@ -581,7 +581,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (utils::isArrayOrPointerType(VDDerivedType))
continue;
auto* VDDerived =
BuildVarDecl(VDDerivedType, "_d_" + param->getNameAsString(),
BuildGlobalVarDecl(VDDerivedType, "_d_" + param->getNameAsString(),
getZeroInit(VDDerivedType));
m_Variables[param] = BuildDeclRef(VDDerived);
addToBlock(BuildDeclStmt(VDDerived), m_Globals);
Expand Down Expand Up @@ -2540,7 +2540,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (const auto* AT = dyn_cast<ArrayType>(VD->getType())) {
Expr* init = getArraySizeExpr(AT, m_Context, *this);
VDDerivedInit = init;
VDDerived = BuildVarDecl(VDDerivedType, "_d_" + VD->getNameAsString(),
VDDerived = BuildGlobalVarDecl(VDDerivedType, "_d_" + VD->getNameAsString(),
VDDerivedInit, false, nullptr,
clang::VarDecl::InitializationStyle::CallInit);
} else {
Expand Down Expand Up @@ -2614,12 +2614,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// necessary to preserve the old tests.
if (VDDerivedType->isRecordType())
VDDerived =
BuildVarDecl(VDDerivedType, "_d_" + VD->getNameAsString(),
BuildGlobalVarDecl(VDDerivedType, "_d_" + VD->getNameAsString(),
VDDerivedInit, VD->isDirectInit(),
m_Context.getTrivialTypeSourceInfo(VDDerivedType),
VD->getInitStyle());
else
VDDerived = BuildVarDecl(VDDerivedType, "_d_" + VD->getNameAsString(),
VDDerived = BuildGlobalVarDecl(VDDerivedType, "_d_" + VD->getNameAsString(),
VDDerivedInit);
}

Expand Down Expand Up @@ -2661,11 +2661,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// Here separate behaviour for record and non-record types is only
// necessary to preserve the old tests.
if (VD->getType()->isRecordType()) {
VDClone = BuildVarDecl(VD->getType(), VD->getNameAsString(),
VDClone = BuildGlobalVarDecl(VD->getType(), VD->getNameAsString(),
initDiff.getExpr(), VD->isDirectInit(),
VD->getTypeSourceInfo(), VD->getInitStyle());
} else {
VDClone = BuildVarDecl(CloneType(VD->getType()), VD->getNameAsString(),
VDClone = BuildGlobalVarDecl(CloneType(VD->getType()), VD->getNameAsString(),
initDiff.getExpr(), VD->isDirectInit());
}
Expr* derivedVDE = BuildDeclRef(VDDerived);
Expand Down Expand Up @@ -2786,27 +2786,25 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// with function globals and replace initializations with
// assignments. This is a temporary measure to avoid the bug that arises
// from overwriting local variables on different loop passes.
if (isInsideLoop) {
if (VD->getType()->isBuiltinType()) {
auto* decl = VDDiff.getDecl();
/// The same variable will be assigned with new values every
/// loop iteration so the const qualifier must be dropped.
if (decl->getType().isConstQualified()) {
QualType nonConstType =
getNonConstType(decl->getType(), m_Context, m_Sema);
decl->setType(nonConstType);
}
if (decl->getInit()) {
auto* declRef = BuildDeclRef(decl);
auto pushPop =
StoreAndRestore(declRef, /*prefix=*/"_t", /*force=*/true);
if (pushPop.getExpr() != declRef)
addToCurrentBlock(pushPop.getExpr_dx(), direction::reverse);
auto* assignment = BuildOp(BO_Assign, declRef, decl->getInit());
inits.push_back(BuildOp(BO_Comma, pushPop.getExpr(), assignment));
}
decl->setInit(getZeroInit(VD->getType()));
if (VD->getType()->isBuiltinType() && m_Reverse.size()>2) {
auto* decl = VDDiff.getDecl();
/// The same variable will be assigned with new values every
/// loop iteration so the const qualifier must be dropped.
if (decl->getType().isConstQualified()) {
QualType nonConstType =
getNonConstType(decl->getType(), m_Context, m_Sema);
decl->setType(nonConstType);
}
if (decl->getInit()) {
auto* declRef = BuildDeclRef(decl);
auto pushPop =
StoreAndRestore(declRef, /*prefix=*/"_t", /*force=*/true);
if (pushPop.getExpr() != declRef)
addToCurrentBlock(pushPop.getExpr_dx(), direction::reverse);
auto* assignment = BuildOp(BO_Assign, declRef, decl->getInit());
inits.push_back(BuildOp(BO_Comma, pushPop.getExpr(), assignment));
}
decl->setInit(getZeroInit(VD->getType()));
}

decls.push_back(VDDiff.getDecl());
Expand Down Expand Up @@ -2843,14 +2841,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
/// with function globals and replace initializations with assignments.
/// This is a temporary measure to avoid the bug that arises from
/// overwriting local variables on different loop passes.
if (isInsideLoop) {
if (auto* VD = dyn_cast<VarDecl>(decls[0])) {
if (VD->getType()->isBuiltinType()) {
addToBlock(DSClone, m_Globals);
Stmt* initAssignments = MakeCompoundStmt(inits);
initAssignments = unwrapIfSingleStmt(initAssignments);
return StmtDiff(initAssignments);
}
if (auto* VD = dyn_cast<VarDecl>(decls[0])) {
if (VD->getType()->isBuiltinType() && m_Reverse.size()>2) {
addToBlock(DSClone, m_Globals);
Stmt* initAssignments = MakeCompoundStmt(inits);
initAssignments = unwrapIfSingleStmt(initAssignments);
return StmtDiff(initAssignments);
}
}

Expand Down
15 changes: 13 additions & 2 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ namespace clad {
VarDecl* VisitorBase::BuildVarDecl(QualType Type, IdentifierInfo* Identifier,
Expr* Init, bool DirectInit,
TypeSourceInfo* TSI,
VarDecl::InitializationStyle IS) {
VarDecl::InitializationStyle IS,
Scope* scope) {

// add namespace specifier in variable declaration if needed.
Type = utils::AddNamespaceSpecifier(m_Sema, m_Context, Type);
Expand All @@ -143,8 +144,10 @@ namespace clad {
m_Sema.ActOnUninitializedDecl(VD);
}
m_Sema.FinalizeDeclaration(VD);
if (!scope)
scope = getCurrentScope();
// Add the identifier to the scope and IdResolver
m_Sema.PushOnScopeChains(VD, getCurrentScope(), /*AddToContext*/ false);
m_Sema.PushOnScopeChains(VD, scope, /*AddToContext*/ false);
return VD;
}

Expand All @@ -162,6 +165,14 @@ namespace clad {
TSI, IS);
}

VarDecl* VisitorBase::BuildGlobalVarDecl(QualType Type, llvm::StringRef prefix,
Expr* Init, bool DirectInit,
TypeSourceInfo* TSI,
VarDecl::InitializationStyle IS) {
return BuildVarDecl(Type, CreateUniqueIdentifier(prefix), Init, DirectInit,
TSI, IS, m_DerivativeFnScope);
}

NamespaceDecl* VisitorBase::BuildNamespaceDecl(IdentifierInfo* II,
bool isInline) {
// Check if the namespace is being redeclared.
Expand Down

0 comments on commit e4e71e4

Please sign in to comment.