diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 41b174f5c..5b1920d6a 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -404,12 +404,6 @@ DerivativeAndOverload BaseForwardModeVisitor::DerivePushforward() { "Doesn't support recursive diff. Use DiffPlan."); m_DerivativeInFlight = true; - auto originalFnEffectiveName = - utils::ComputeEffectiveFnName(m_DiffReq.Function); - - IdentifierInfo* derivedFnII = &m_Context.Idents.get( - originalFnEffectiveName + GetPushForwardFunctionSuffix()); - DeclarationNameInfo derivedFnName(derivedFnII, m_DiffReq->getLocation()); llvm::SmallVector paramTypes; llvm::SmallVector derivedParamTypes; @@ -446,9 +440,15 @@ DerivativeAndOverload BaseForwardModeVisitor::DerivePushforward() { auto* DC = const_cast(m_DiffReq->getDeclContext()); m_Sema.CurContext = DC; - SourceLocation loc{m_DiffReq->getLocation()}; - DeclWithContext cloneFunctionResult = m_Builder.cloneFunction( - m_DiffReq.Function, *this, DC, loc, derivedFnName, derivedFnType); + auto originalFnEffectiveName = utils::ComputeEffectiveFnName(FD); + + IdentifierInfo* derivedFnII = &m_Context.Idents.get( + originalFnEffectiveName + GetPushForwardFunctionSuffix()); + SourceLocation loc{FD->getLocation()}; + DeclarationNameInfo derivedFnName(derivedFnII, loc); + + DeclWithContext cloneFunctionResult = + m_Builder.cloneFunction(FD, *this, DC, loc, derivedFnName, derivedFnType); m_Derivative = cloneFunctionResult.first; llvm::SmallVector params; @@ -518,7 +518,6 @@ DerivativeAndOverload BaseForwardModeVisitor::DerivePushforward() { Stmt* derivativeBody = endBlock(); m_Derivative->setBody(derivativeBody); - endScope(); // Function body scope // Size >= current derivative order means that there exists a declaration