From aa5cbdfd44148837e27a9acf44e012280bb42a29 Mon Sep 17 00:00:00 2001 From: Vassil Vassilev Date: Fri, 27 Dec 2024 20:40:30 +0000 Subject: [PATCH] Reroute pushforwards to the generic ::Derive method. --- .../Differentiator/BaseForwardModeVisitor.h | 4 +- lib/Differentiator/BaseForwardModeVisitor.cpp | 213 +++++++----------- lib/Differentiator/DerivativeBuilder.cpp | 4 +- 3 files changed, 80 insertions(+), 141 deletions(-) diff --git a/include/clad/Differentiator/BaseForwardModeVisitor.h b/include/clad/Differentiator/BaseForwardModeVisitor.h index 107955895..1ba1a16dd 100644 --- a/include/clad/Differentiator/BaseForwardModeVisitor.h +++ b/include/clad/Differentiator/BaseForwardModeVisitor.h @@ -37,9 +37,7 @@ class BaseForwardModeVisitor /// DerivativeAndOverload Derive(); - DerivativeAndOverload DerivePushforward(); - - virtual void ExecuteInsidePushforwardFunctionBlock(); + virtual void ExecuteInsidePushforwardFunctionBlock() {} static bool IsDifferentiableType(clang::QualType T); diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 29cf3d309..e54043cc6 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -12,6 +12,7 @@ #include "clad/Differentiator/DiffMode.h" #include "clad/Differentiator/DiffPlanner.h" #include "clad/Differentiator/ErrorEstimator.h" +#include "clad/Differentiator/ParseDiffArgsTypes.h" #include "clang/AST/ASTContext.h" #include "clang/AST/ASTLambda.h" @@ -28,10 +29,13 @@ #include "clang/Sema/SemaInternal.h" #include "clang/Sema/Template.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" #include "llvm/IR/Constants.h" #include "llvm/Support/SaveAndRestore.h" #include +#include #include #include @@ -67,75 +71,80 @@ bool IsRealNonReferenceType(QualType T) { DerivativeAndOverload BaseForwardModeVisitor::Derive() { const FunctionDecl* FD = m_DiffReq.Function; - assert(m_DiffReq.Mode == DiffMode::forward); + assert(m_DiffReq.Mode == DiffMode::forward || + m_DiffReq.Mode == DiffMode::experimental_pushforward || + m_DiffReq.Mode == DiffMode::experimental_vector_pushforward); assert(!m_DerivativeInFlight && "Doesn't support recursive diff. Use DiffPlan."); - m_DerivativeInFlight = true; - - DiffInputVarsInfo DVI = m_DiffReq.DVI; - - // FIXME: Shouldn't we give error here that no arg is specified? - if (DVI.empty()) - return {}; - - DiffInputVarInfo diffVarInfo = DVI.back(); - - // Check that only one arg is requested and if the arg requested is of array - // or pointer type, only one of the indices have been requested - if (DVI.size() > 1 || (isArrayOrPointerType(diffVarInfo.param->getType()) && - (diffVarInfo.paramIndexInterval.size() != 1))) { - diag(DiagnosticsEngine::Error, - m_DiffReq.Args ? m_DiffReq.Args->getEndLoc() : noLoc, - "Forward mode differentiation w.r.t. several parameters at once is " - "not " - "supported, call 'clad::differentiate' for each parameter " - "separately"); - return {}; - } - // FIXME: implement gradient-vector products to fix the issue. - assert((DVI.size() == 1) && - "nested forward mode differentiation for several args is broken"); - - // FIXME: Differentiation variable cannot always be represented just by - // `ValueDecl*` variable. For example -- `u.mem1.mem2,`, `arr[7]` etc. - // FIXME: independent variable is misleading terminology, what we actually - // mean here is 'variable' with respect to which differentiation is being - // performed. Mathematically, independent variables are all the function - // parameters, thus, does not convey the intendend meaning. - m_IndependentVar = DVI.back().param; - // If param is not real (i.e. floating point or integral), a pointer to a - // real type, or an array of a real type we cannot differentiate it. - // FIXME: we should support custom numeric types in the future. - if (isArrayOrPointerType(m_IndependentVar->getType())) { - if (!m_IndependentVar->getType() - ->getPointeeOrArrayElementType() - ->isRealType()) { - diag(DiagnosticsEngine::Error, m_IndependentVar->getEndLoc(), - "attempted differentiation w.r.t. a parameter ('%0') which is not" - " an array or pointer of a real type", - {m_IndependentVar->getNameAsString()}); + llvm::SaveAndRestore saveInFlight(m_DerivativeInFlight, + /*NewValue=*/true); + + if (m_DiffReq.Mode == DiffMode::forward) { + DiffInputVarsInfo DVI = m_DiffReq.DVI; + + // FIXME: Shouldn't we give error here that no arg is specified? + if (DVI.empty()) return {}; - } - m_IndependentVarIndex = diffVarInfo.paramIndexInterval.Start; - } else { - QualType T = m_IndependentVar->getType(); - bool isField = false; - if (auto RD = diffVarInfo.param->getType()->getAsCXXRecordDecl()) { - llvm::SmallVector ref(diffVarInfo.fields.begin(), - diffVarInfo.fields.end()); - T = utils::ComputeMemExprPathType(m_Sema, RD, ref); - isField = true; - } - if (!IsRealNonReferenceType(T)) { - diag(DiagnosticsEngine::Error, m_DiffReq.Args->getEndLoc(), - "Attempted differentiation w.r.t. %0 '%1' which is not " - "of real type.", - {(isField ? "member" : "parameter"), diffVarInfo.source}); + + DiffInputVarInfo diffVarInfo = DVI.back(); + + // Check that only one arg is requested and if the arg requested is of array + // or pointer type, only one of the indices have been requested + if (DVI.size() > 1 || (isArrayOrPointerType(diffVarInfo.param->getType()) && + (diffVarInfo.paramIndexInterval.size() != 1))) { + diag(DiagnosticsEngine::Error, + m_DiffReq.Args ? m_DiffReq.Args->getEndLoc() : noLoc, + "Forward mode differentiation w.r.t. several parameters at once is " + "not " + "supported, call 'clad::differentiate' for each parameter " + "separately"); return {}; } - } + // FIXME: implement gradient-vector products to fix the issue. + assert((DVI.size() == 1) && + "nested forward mode differentiation for several args is broken"); + + // FIXME: Differentiation variable cannot always be represented just by + // `ValueDecl*` variable. For example -- `u.mem1.mem2,`, `arr[7]` etc. + // FIXME: independent variable is misleading terminology, what we actually + // mean here is 'variable' with respect to which differentiation is being + // performed. Mathematically, independent variables are all the function + // parameters, thus, does not convey the intendend meaning. + m_IndependentVar = DVI.back().param; + // If param is not real (i.e. floating point or integral), a pointer to a + // real type, or an array of a real type we cannot differentiate it. + // FIXME: we should support custom numeric types in the future. + if (isArrayOrPointerType(m_IndependentVar->getType())) { + if (!m_IndependentVar->getType() + ->getPointeeOrArrayElementType() + ->isRealType()) { + diag(DiagnosticsEngine::Error, m_IndependentVar->getEndLoc(), + "attempted differentiation w.r.t. a parameter ('%0') which is not" + " an array or pointer of a real type", + {m_IndependentVar->getNameAsString()}); + return {}; + } + m_IndependentVarIndex = diffVarInfo.paramIndexInterval.Start; + } else { + QualType T = m_IndependentVar->getType(); + bool isField = false; + if (auto* RD = diffVarInfo.param->getType()->getAsCXXRecordDecl()) { + llvm::SmallVector ref(diffVarInfo.fields.begin(), + diffVarInfo.fields.end()); + T = utils::ComputeMemExprPathType(m_Sema, RD, ref); + isField = true; + } + if (!IsRealNonReferenceType(T)) { + diag(DiagnosticsEngine::Error, m_DiffReq.Args->getEndLoc(), + "Attempted differentiation w.r.t. %0 '%1' which is not " + "of real type.", + {(isField ? "member" : "parameter"), diffVarInfo.source}); + return {}; + } + } + } // Check if the function is already declared as a custom derivative. std::string gradientName = m_DiffReq.ComputeDerivativeName(); @@ -177,6 +186,10 @@ DerivativeAndOverload BaseForwardModeVisitor::Derive() { m_DerivativeFnScope = getCurrentScope(); beginBlock(); + // FIXME: Remove the override in VectorPushForwardModeVisitor. + if (m_DiffReq.Mode == DiffMode::experimental_vector_pushforward) + ExecuteInsidePushforwardFunctionBlock(); + if (m_DiffReq.Mode == DiffMode::forward) GenerateSeeds(derivedFD); @@ -202,8 +215,6 @@ DerivativeAndOverload BaseForwardModeVisitor::Derive() { m_Sema.PopDeclContext(); endScope(); // Function decl scope - m_DerivativeInFlight = false; - return DerivativeAndOverload{result.first, /*OverloadFunctionDecl=*/nullptr}; } @@ -285,7 +296,7 @@ void BaseForwardModeVisitor::SetupDerivativeParameters( // independent parameter. if (const auto* MD = dyn_cast(FD)) { if (MD->isInstance()) { - IdentifierInfo* dThisII = CreateUniqueIdentifier("_d_this"); + IdentifierInfo* dThisII = &m_Context.Idents.get("_d_this"); auto* dPVD = utils::BuildParmVarDecl(m_Sema, m_Sema.CurContext, dThisII, MD->getThisType()); m_Sema.PushOnScopeChains(dPVD, getCurrentScope(), /*AddToContext=*/false); @@ -302,7 +313,7 @@ void BaseForwardModeVisitor::SetupDerivativeParameters( if (!BaseForwardModeVisitor::IsDifferentiableType(PVD->getType())) continue; - IdentifierInfo* II = CreateUniqueIdentifier("_d_" + PVD->getNameAsString()); + IdentifierInfo* II = &m_Context.Idents.get("_d_" + PVD->getNameAsString()); auto* dPVD = utils::BuildParmVarDecl( m_Sema, m_Derivative, II, GetPushForwardDerivativeType(PVD->getType()), PVD->getStorageClass()); @@ -437,76 +448,6 @@ void BaseForwardModeVisitor::GenerateSeeds(const clang::FunctionDecl* dFD) { } } -void BaseForwardModeVisitor::ExecuteInsidePushforwardFunctionBlock() { - Stmt* bodyDiff = Visit(m_DiffReq->getBody()).getStmt(); - auto* CS = cast(bodyDiff); - for (Stmt* S : CS->body()) - addToCurrentBlock(S); -} - -DerivativeAndOverload BaseForwardModeVisitor::DerivePushforward() { - const FunctionDecl* FD = m_DiffReq.Function; - assert(m_DiffReq.Mode == GetPushForwardMode()); - assert(!m_DerivativeInFlight && - "Doesn't support recursive diff. Use DiffPlan."); - m_DerivativeInFlight = true; - - llvm::SaveAndRestore saveContext(m_Sema.CurContext); - llvm::SaveAndRestore saveScope(getCurrentScope(), - getEnclosingNamespaceOrTUScope()); - - IdentifierInfo* II = &m_Context.Idents.get(m_DiffReq.ComputeDerivativeName()); - SourceLocation loc{FD->getLocation()}; - DeclarationNameInfo derivedFnName(II, loc); - - // FIXME: We should not use const_cast to get the decl context here. - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - auto* DC = const_cast(m_DiffReq->getDeclContext()); - m_Sema.CurContext = DC; - QualType derivedFnType = ComputeDerivativeFunctionType(); - DeclWithContext cloneFunctionResult = - m_Builder.cloneFunction(FD, *this, DC, loc, derivedFnName, derivedFnType); - m_Derivative = cloneFunctionResult.first; - - beginScope(Scope::FunctionPrototypeScope | Scope::FunctionDeclarationScope | - Scope::DeclScope); - m_Sema.PushFunctionScope(); - m_Sema.PushDeclContext(getCurrentScope(), m_Derivative); - - llvm::SmallVector params; - SetupDerivativeParameters(params); - m_Derivative->setParams(params); - - m_Derivative->setBody(nullptr); - - if (!m_DiffReq.DeclarationOnly) { - beginScope(Scope::FnScope | Scope::DeclScope); - m_DerivativeFnScope = getCurrentScope(); - beginBlock(); - - // execute the functor inside the function body. - ExecuteInsidePushforwardFunctionBlock(); - - Stmt* derivativeBody = endBlock(); - m_Derivative->setBody(derivativeBody); - endScope(); // Function body scope - - // Size >= current derivative order means that there exists a declaration - // or prototype for the currently derived function. - if (m_DiffReq.DerivedFDPrototypes.size() >= - m_DiffReq.CurrentDerivativeOrder) - m_Derivative->setPreviousDeclaration( - m_DiffReq.DerivedFDPrototypes[m_DiffReq.CurrentDerivativeOrder - 1]); - } - - m_Sema.PopFunctionScopeInfo(); - m_Sema.PopDeclContext(); - endScope(); // Function decl scope - - m_DerivativeInFlight = false; - return DerivativeAndOverload{cloneFunctionResult.first}; -} - StmtDiff BaseForwardModeVisitor::VisitStmt(const Stmt* S) { diag(DiagnosticsEngine::Warning, S->getBeginLoc(), "attempted to differentiate unsupported statement, no changes applied"); diff --git a/lib/Differentiator/DerivativeBuilder.cpp b/lib/Differentiator/DerivativeBuilder.cpp index de4150735..d7482609c 100644 --- a/lib/Differentiator/DerivativeBuilder.cpp +++ b/lib/Differentiator/DerivativeBuilder.cpp @@ -431,13 +431,13 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) { result = V.Derive(); } else if (request.Mode == DiffMode::experimental_pushforward) { PushForwardModeVisitor V(*this, request); - result = V.DerivePushforward(); + result = V.Derive(); } else if (request.Mode == DiffMode::vector_forward_mode) { VectorForwardModeVisitor V(*this, request); result = V.DeriveVectorMode(); } else if (request.Mode == DiffMode::experimental_vector_pushforward) { VectorPushForwardModeVisitor V(*this, request); - result = V.DerivePushforward(); + result = V.Derive(); } else if (request.Mode == DiffMode::reverse) { ReverseModeVisitor V(*this, request); result = V.Derive();