From 443fcd8fe53c7769b54bb56975750780c8f3fab7 Mon Sep 17 00:00:00 2001 From: Vassil Vassilev Date: Mon, 30 Dec 2024 18:27:34 +0000 Subject: [PATCH] Use CloneParmVarDecl and the parameters stored in the DiffRequest. NFC --- .../clad/Differentiator/ReverseModeVisitor.h | 5 ---- lib/Differentiator/ReverseModeVisitor.cpp | 27 +++++++++---------- 2 files changed, 12 insertions(+), 20 deletions(-) diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index 70a44f6aa..910b67c49 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -40,11 +40,6 @@ namespace clad { // several private/protected members of the visitor classes. friend class ErrorEstimationHandler; llvm::SmallVector m_IndependentVars; - /// Set used to keep track of parameter variables w.r.t which the - /// the derivative (gradient) is being computed. This is separate from the - /// m_Variables map because all other intermediate variables will - /// not be stored here. - std::unordered_set m_ParamVarsWithDiff; /// In addition to a sequence of forward-accumulated Stmts (m_Blocks), in /// the reverse mode we also accumulate Stmts for the reverse pass which /// will be executed on return. diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 5a935fb78..a0bcb252a 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1528,8 +1528,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // If the arg is used for differentiation of the function, then we // cannot free it in the end as it's the result to be returned to the // user. - if (m_ParamVarsWithDiff.find(DRE->getDecl()) == - m_ParamVarsWithDiff.end()) + if (std::find(m_DiffReq.DVI.begin(), m_DiffReq.DVI.end(), + DRE->getDecl()) == m_DiffReq.DVI.end()) DerivedCallArgs.push_back(ArgDiff.getExpr_dx()); } } @@ -4436,19 +4436,17 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } for (auto* PVD : m_DiffReq->parameters()) { - auto* newPVD = utils::BuildParmVarDecl( - m_Sema, m_Derivative, PVD->getIdentifier(), PVD->getType(), - PVD->getStorageClass(), /*DefArg=*/nullptr, PVD->getTypeSourceInfo()); - params.push_back(newPVD); - - if (newPVD->getIdentifier()) - m_Sema.PushOnScopeChains(newPVD, getCurrentScope(), - /*AddToContext=*/false); - else { - IdentifierInfo* newName = CreateUniqueIdentifier("arg"); - newPVD->setDeclName(newName); + IdentifierInfo* PVDII = PVD->getIdentifier(); + // Implicitly created special member functions have no parameter names. + if (!PVD->getDeclName()) + PVDII = CreateUniqueIdentifier("arg"); + auto* newPVD = CloneParmVarDecl(PVD, PVDII, + /*pushOnScopeChains=*/true, + /*cloneDefaultArg=*/false); + if (!PVD->getDeclName()) // We can't use lookup-based replacements m_DeclReplacements[PVD] = newPVD; - } + + params.push_back(newPVD); auto* it = std::find(std::begin(diffParams), std::end(diffParams), PVD); if (it != std::end(diffParams)) { @@ -4480,7 +4478,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, m_Variables[*it] = utils::BuildParenExpr(m_Sema, m_Variables[*it]); } - m_ParamVarsWithDiff.emplace(*it); } } }