Skip to content

Commit

Permalink
Use CloneParmVarDecl and the parameters stored in the DiffRequest. NFC
Browse files Browse the repository at this point in the history
  • Loading branch information
vgvassilev committed Dec 30, 2024
1 parent a2e1c50 commit 443fcd8
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 20 deletions.
5 changes: 0 additions & 5 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,6 @@ namespace clad {
// several private/protected members of the visitor classes.
friend class ErrorEstimationHandler;
llvm::SmallVector<const clang::ValueDecl*, 16> m_IndependentVars;
/// Set used to keep track of parameter variables w.r.t which the
/// the derivative (gradient) is being computed. This is separate from the
/// m_Variables map because all other intermediate variables will
/// not be stored here.
std::unordered_set<const clang::ValueDecl*> m_ParamVarsWithDiff;
/// In addition to a sequence of forward-accumulated Stmts (m_Blocks), in
/// the reverse mode we also accumulate Stmts for the reverse pass which
/// will be executed on return.
Expand Down
27 changes: 12 additions & 15 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1528,8 +1528,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// If the arg is used for differentiation of the function, then we
// cannot free it in the end as it's the result to be returned to the
// user.
if (m_ParamVarsWithDiff.find(DRE->getDecl()) ==
m_ParamVarsWithDiff.end())
if (std::find(m_DiffReq.DVI.begin(), m_DiffReq.DVI.end(),
DRE->getDecl()) == m_DiffReq.DVI.end())
DerivedCallArgs.push_back(ArgDiff.getExpr_dx());
}
}
Expand Down Expand Up @@ -4436,19 +4436,17 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

for (auto* PVD : m_DiffReq->parameters()) {
auto* newPVD = utils::BuildParmVarDecl(
m_Sema, m_Derivative, PVD->getIdentifier(), PVD->getType(),
PVD->getStorageClass(), /*DefArg=*/nullptr, PVD->getTypeSourceInfo());
params.push_back(newPVD);

if (newPVD->getIdentifier())
m_Sema.PushOnScopeChains(newPVD, getCurrentScope(),
/*AddToContext=*/false);
else {
IdentifierInfo* newName = CreateUniqueIdentifier("arg");
newPVD->setDeclName(newName);
IdentifierInfo* PVDII = PVD->getIdentifier();
// Implicitly created special member functions have no parameter names.
if (!PVD->getDeclName())
PVDII = CreateUniqueIdentifier("arg");
auto* newPVD = CloneParmVarDecl(PVD, PVDII,
/*pushOnScopeChains=*/true,
/*cloneDefaultArg=*/false);
if (!PVD->getDeclName()) // We can't use lookup-based replacements
m_DeclReplacements[PVD] = newPVD;
}

params.push_back(newPVD);

auto* it = std::find(std::begin(diffParams), std::end(diffParams), PVD);
if (it != std::end(diffParams)) {
Expand Down Expand Up @@ -4480,7 +4478,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
m_Variables[*it] =
utils::BuildParenExpr(m_Sema, m_Variables[*it]);
}
m_ParamVarsWithDiff.emplace(*it);
}
}
}
Expand Down

0 comments on commit 443fcd8

Please sign in to comment.