From f201b71dec4abe2f3c637dd9cf126bf42368a2cf Mon Sep 17 00:00:00 2001 From: Vassil Vassilev Date: Wed, 25 Dec 2024 10:56:38 +0000 Subject: [PATCH] Harmonize Derive and DerivePullback --- lib/Differentiator/ReverseModeVisitor.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 6db6ee915..ecc17c951 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -398,9 +398,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (m_ExternalSource) m_ExternalSource->ActAfterCreatingDerivedFnParamTypes(paramTypes); - QualType pullbackFnType = m_Context.getFunctionType( - m_Context.VoidTy, paramTypes, originalFnType->getExtProtoInfo()); - llvm::SaveAndRestore saveContext(m_Sema.CurContext); llvm::SaveAndRestore saveScope(getCurrentScope(), getEnclosingNamespaceOrTUScope()); @@ -409,6 +406,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, m_Sema.CurContext = const_cast(m_DiffReq->getDeclContext()); SourceLocation validLoc{m_DiffReq->getLocation()}; + QualType pullbackFnType = m_Context.getFunctionType( + m_Context.VoidTy, paramTypes, originalFnType->getExtProtoInfo()); DeclWithContext fnBuildRes = m_Builder.cloneFunction(m_DiffReq.Function, *this, m_Sema.CurContext, validLoc, DNI, pullbackFnType);