From 6329ae89043a10a54e0ac7f0131dc9ca077514f0 Mon Sep 17 00:00:00 2001 From: Vassil Vassilev Date: Fri, 3 Jan 2025 20:32:39 +0000 Subject: [PATCH] Clean up the code computing the parameters and their types. This patch harmonizes the steps to build the derived function prototype along with its parameters. It refactors a lot of tricky code that grew organically and included many workarounds. This is only a single step in a direction of merging the common code patterns across visitors. As a bonus we do not generate broken code when we have derivative parameters named `_d_x`. That's common in computing hessians. --- .../clad/Differentiator/ReverseModeVisitor.h | 18 +- lib/Differentiator/ReverseModeVisitor.cpp | 314 +++++++++--------- test/Hessian/NestedFunctionCalls.C | 20 +- 3 files changed, 177 insertions(+), 175 deletions(-) diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index 1f6f24225..77c7b11ad 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -23,6 +23,10 @@ #include #include +namespace llvm { +template class SmallVectorImpl; +} + namespace clad { class ErrorEstimationHandler; class ExternalRMVSource; @@ -682,19 +686,11 @@ namespace clad { ///\paramp[in] source An external RMV source void AddExternalSource(ExternalRMVSource& source); - /// Computes and returns the sequence of derived function parameter types. - /// - /// Information about the original function and the differentiation mode - /// are taken from the data member variables. - llvm::SmallVector ComputeParamTypes(const DiffParams& diffParams); + /// Computes and returns the derived function prototype. + clang::QualType ComputeDerivativeFunctionType(); /// Builds and returns the sequence of derived function parameters. - /// - /// Information about the original function, derived function, derived - /// function parameter types and the differentiation mode are implicitly - /// taken from the data member variables. - llvm::SmallVector - BuildParams(DiffParams& diffParams); + void BuildParams(llvm::SmallVectorImpl& params); clang::QualType ComputeAdjointType(clang::QualType T); clang::QualType ComputeParamType(clang::QualType T); diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 8c7856da5..1bca6a06c 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -209,40 +209,23 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (m_ExternalSource) m_ExternalSource->ActAfterParsingDiffArgs(m_DiffReq, args); - // If we are in error estimation mode, we have an extra `double&` - // parameter that stores the final error - unsigned numExtraParam = 0; - if (m_ExternalSource) - m_ExternalSource->ActBeforeCreatingDerivedFnParamTypes(numExtraParam); - - auto paramTypes = ComputeParamTypes(args); - - if (m_ExternalSource) - m_ExternalSource->ActAfterCreatingDerivedFnParamTypes(paramTypes); - // If reverse mode differentiates only part of the arguments it needs to // generate an overload that can take in all the diff variables bool shouldCreateOverload = false; // FIXME: Gradient overload doesn't know how to handle additional parameters // added by the plugins yet. - if (numExtraParam == 0) + if (!m_ExternalSource) shouldCreateOverload = true; if (!m_DiffReq.DeclarationOnly && !m_DiffReq.DerivedFDPrototypes.empty()) // If the overload is already created, we don't need to create it again. shouldCreateOverload = false; - const auto* originalFnType = - dyn_cast(m_DiffReq->getType()); // For a function f of type R(A1, A2, ..., An), // the type of the gradient function is void(A1, A2, ..., An, R*, R*, ..., // R*) . the type of the jacobian function is void(A1, A2, ..., An, R*, R*) // and for error estimation, the function type is // void(A1, A2, ..., An, R*, R*, ..., R*, double&) - QualType gradientFunctionType = m_Context.getFunctionType( - m_Context.VoidTy, - llvm::ArrayRef(paramTypes.data(), paramTypes.size()), - // Cast to function pointer. - originalFnType->getExtProtoInfo()); + QualType gradientFunctionType = ComputeDerivativeFunctionType(); // Check if the function is already declared as a custom derivative. std::string gradientName = m_DiffReq.ComputeDerivativeName(); @@ -284,7 +267,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (m_ExternalSource) m_ExternalSource->ActAfterCreatingDerivedFnScope(); - auto params = BuildParams(args); + llvm::SmallVector params; + BuildParams(params); if (m_ExternalSource) m_ExternalSource->ActAfterCreatingDerivedFnParams(params); @@ -364,16 +348,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (m_ExternalSource) m_ExternalSource->ActAfterParsingDiffArgs(m_DiffReq, args); - auto derivativeName = m_DiffReq.ComputeDerivativeName(); - auto DNI = utils::BuildDeclarationNameInfo(m_Sema, derivativeName); - - auto paramTypes = ComputeParamTypes(args); - const auto* originalFnType = - dyn_cast(m_DiffReq->getType()); - - if (m_ExternalSource) - m_ExternalSource->ActAfterCreatingDerivedFnParamTypes(paramTypes); - llvm::SaveAndRestore saveContext(m_Sema.CurContext); llvm::SaveAndRestore saveScope(getCurrentScope(), getEnclosingNamespaceOrTUScope()); @@ -382,8 +356,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, m_Sema.CurContext = const_cast(m_DiffReq->getDeclContext()); SourceLocation validLoc{m_DiffReq->getLocation()}; - QualType pullbackFnType = m_Context.getFunctionType( - m_Context.VoidTy, paramTypes, originalFnType->getExtProtoInfo()); + QualType pullbackFnType = ComputeDerivativeFunctionType(); + auto derivativeName = m_DiffReq.ComputeDerivativeName(); + auto DNI = utils::BuildDeclarationNameInfo(m_Sema, derivativeName); DeclWithContext fnBuildRes = m_Builder.cloneFunction(m_DiffReq.Function, *this, m_Sema.CurContext, validLoc, DNI, pullbackFnType); @@ -400,7 +375,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (m_ExternalSource) m_ExternalSource->ActAfterCreatingDerivedFnScope(); - auto params = BuildParams(args); + llvm::SmallVector params; + BuildParams(params); if (m_ExternalSource) m_ExternalSource->ActAfterCreatingDerivedFnParams(params); @@ -4341,153 +4317,183 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return m_Context.getPointerType(TValueType); } - llvm::SmallVector - ReverseModeVisitor::ComputeParamTypes(const DiffParams& diffParams) { - llvm::SmallVector paramTypes; - paramTypes.reserve(m_DiffReq->getNumParams() * 2); - for (auto* PVD : m_DiffReq->parameters()) - paramTypes.push_back(PVD->getType()); - // TODO: Add DiffMode::experimental_pullback support here as well. - if (m_DiffReq.Mode == DiffMode::reverse || - m_DiffReq.Mode == DiffMode::experimental_pullback) { - QualType effectiveReturnType = - m_DiffReq->getReturnType().getNonReferenceType(); - // FIXME: We ignore the pointer return type for pullbacks. - if (m_DiffReq.Mode == DiffMode::experimental_pullback && - !effectiveReturnType->isVoidType() && - !effectiveReturnType->isPointerType()) - paramTypes.push_back(effectiveReturnType); - - if (const auto* MD = dyn_cast(m_DiffReq.Function)) { - const CXXRecordDecl* RD = MD->getParent(); - if (MD->isInstance() && !RD->isLambda()) { - QualType thisType = MD->getThisType(); - paramTypes.push_back( - GetParameterDerivativeType(effectiveReturnType, thisType)); - } - } - - for (auto* PVD : m_DiffReq->parameters()) { - const auto* it = - std::find(std::begin(diffParams), std::end(diffParams), PVD); - if (it != std::end(diffParams)) - paramTypes.push_back(ComputeParamType(PVD->getType())); - } + static bool needsDThis(const FunctionDecl* FD) { + if (const auto* MD = dyn_cast(FD)) { + const CXXRecordDecl* RD = MD->getParent(); + if (MD->isInstance() && !RD->isLambda()) + return true; } - return paramTypes; + return false; } - llvm::SmallVector - ReverseModeVisitor::BuildParams(DiffParams& diffParams) { - llvm::SmallVector params; - llvm::SmallVector paramDerivatives; - params.reserve(m_DiffReq->getNumParams() + diffParams.size()); - const auto* derivativeFnType = - cast(m_Derivative->getType()); - std::size_t dParamTypesIdx = m_DiffReq->getNumParams(); + // FIXME: Merge with BaseForwardModeVisitor::ComputeDerivativeFunctionType. + clang::QualType ReverseModeVisitor::ComputeDerivativeFunctionType() { + const FunctionDecl* FD = m_DiffReq.Function; + + const auto* FnProtoTy = cast(FD->getType()); + llvm::SmallVector FnTypes(FnProtoTy->getParamTypes().begin(), + FnProtoTy->getParamTypes().end()); + QualType oRetTy = FD->getReturnType(); + QualType dRetTy = oRetTy.getNonReferenceType(); + + // FIXME: We ignore the pointer return type for pullbacks. + bool HasRet = false; if (m_DiffReq.Mode == DiffMode::experimental_pullback && - !m_DiffReq->getReturnType()->isVoidType() && - !m_DiffReq->getReturnType()->isPointerType()) { - ++dParamTypesIdx; + !dRetTy->isVoidType() && !dRetTy->isPointerType()) { + FnTypes.push_back(dRetTy); + HasRet = true; } - if (const auto* MD = dyn_cast(m_DiffReq.Function)) { - const CXXRecordDecl* RD = MD->getParent(); - if (MD->isInstance() && !RD->isLambda()) { - auto* thisDerivativePVD = utils::BuildParmVarDecl( - m_Sema, m_Derivative, CreateUniqueIdentifier("_d_this"), - derivativeFnType->getParamType(dParamTypesIdx)); - paramDerivatives.push_back(thisDerivativePVD); - - if (thisDerivativePVD->getIdentifier()) - m_Sema.PushOnScopeChains(thisDerivativePVD, getCurrentScope(), - /*AddToContext=*/false); - - // This can instantiate an array_ref and needs a fake source location. - SourceLocation fakeLoc = utils::GetValidSLoc(m_Sema); - Expr* deref = BuildOp(UnaryOperatorKind::UO_Deref, - BuildDeclRef(thisDerivativePVD), fakeLoc); - m_ThisExprDerivative = utils::BuildParenExpr(m_Sema, deref); - ++dParamTypesIdx; - } + bool HasThis = needsDThis(FD); + if (HasThis) { + const auto* MD = cast(FD); + QualType thisTy = GetParameterDerivativeType(dRetTy, MD->getThisType()); + FnTypes.push_back(thisTy); + } + + // Iterate over all but the "this" type and extend the signature to add the + // extra parameters. + for (size_t i = 0, e = FnTypes.size() - HasThis - HasRet; i < e; ++i) { + QualType PVDTy = FnTypes[i]; + // Check if (IsDifferentiableType(PVDTy)) + // FIXME: We can't use std::find(DVI.begin(), DVI.end()) because the + // operator== considers params and intervals as different entities and + // breaks the hessian tests. We should implement more robust checks in + // DiffInputVarInfo to check if this is a variable we differentiate wrt. + for (const DiffInputVarInfo& VarInfo : m_DiffReq.DVI) + if (VarInfo.param == FD->getParamDecl(i)) + FnTypes.push_back(ComputeParamType(PVDTy)); } - for (auto* PVD : m_DiffReq->parameters()) { + if (m_ExternalSource) + m_ExternalSource->ActAfterCreatingDerivedFnParamTypes(FnTypes); + + FunctionProtoType::ExtProtoInfo EPI = FnProtoTy->getExtProtoInfo(); + return m_Context.getFunctionType(m_Context.VoidTy, FnTypes, EPI); + ; + } + + void + ReverseModeVisitor::BuildParams(llvm::SmallVectorImpl& params) { + const FunctionDecl* FD = m_DiffReq.Function; + for (ParmVarDecl* PVD : FD->parameters()) { IdentifierInfo* PVDII = PVD->getIdentifier(); // Implicitly created special member functions have no parameter names. if (!PVD->getDeclName()) PVDII = CreateUniqueIdentifier("arg"); + auto* newPVD = CloneParmVarDecl(PVD, PVDII, /*pushOnScopeChains=*/true, /*cloneDefaultArg=*/false); + + // Point m_IndependentVars to the argument of the newly created param. + m_IndependentVars.push_back(newPVD); + if (!PVD->getDeclName()) // We can't use lookup-based replacements m_DeclReplacements[PVD] = newPVD; params.push_back(newPVD); + } - auto* it = std::find(std::begin(diffParams), std::end(diffParams), PVD); - if (it != std::end(diffParams)) { - *it = newPVD; - if (m_DiffReq.Mode == DiffMode::reverse || - m_DiffReq.Mode == DiffMode::experimental_pullback) { - QualType dType = derivativeFnType->getParamType(dParamTypesIdx); - IdentifierInfo* dII = - CreateUniqueIdentifier("_d_" + newPVD->getNameAsString()); - auto* dPVD = utils::BuildParmVarDecl(m_Sema, m_Derivative, dII, dType, - PVD->getStorageClass()); - paramDerivatives.push_back(dPVD); - ++dParamTypesIdx; - - if (dPVD->getIdentifier()) - m_Sema.PushOnScopeChains(dPVD, getCurrentScope(), - /*AddToContext=*/false); - - if (utils::isArrayOrPointerType(PVD->getType())) { - m_Variables[*it] = (Expr*)BuildDeclRef(dPVD); - } else { - QualType valueType = dPVD->getType()->getPointeeType(); - m_Variables[*it] = - BuildOp(UO_Deref, BuildDeclRef(dPVD), m_DiffReq->getLocation()); - // Add additional paranthesis if derivative is of record type - // because `*derivative.someField` will be incorrectly evaluated if - // the derived function is compiled standalone. - if (valueType->isRecordType()) - m_Variables[*it] = - utils::BuildParenExpr(m_Sema, m_Variables[*it]); - } - } + bool HasRet = false; + // FIXME: We ignore the pointer return type for pullbacks. + QualType dRetTy = FD->getReturnType().getNonReferenceType(); + if (m_DiffReq.Mode == DiffMode::experimental_pullback && + !dRetTy->isVoidType() && !dRetTy->isPointerType()) { + auto paramNameExists = [¶ms](llvm::StringRef name) { + for (ParmVarDecl* PVD : params) + if (PVD->getName() == name) + return true; + return false; + }; + + // Make sure that we have no other parameter with the same name. + // FIXME: This is to avoid changing a lot of tests which for some reason + // add d_y when passing the return type value. We should probably not pick + // a more appropriate name. + std::string identifier = "y"; + for (unsigned idx = 0;; ++idx) { + if (idx) + identifier += std::to_string(idx - 1); + if (!paramNameExists(identifier)) + break; } + IdentifierInfo* II = &m_Context.Idents.get("_d_" + identifier); + ParmVarDecl* retPVD = + utils::BuildParmVarDecl(m_Sema, m_Derivative, II, dRetTy); + m_Sema.PushOnScopeChains(retPVD, getCurrentScope(), + /*AddToContext=*/false); + + params.push_back(retPVD); + m_Pullback = BuildDeclRef(retPVD); + HasRet = true; } - if (m_DiffReq.Mode == DiffMode::experimental_pullback && - !m_DiffReq->getReturnType()->isVoidType() && - !m_DiffReq->getReturnType()->isPointerType()) { - IdentifierInfo* pullbackParamII = CreateUniqueIdentifier("_d_y"); - QualType pullbackType = - derivativeFnType->getParamType(m_DiffReq->getNumParams()); - ParmVarDecl* pullbackPVD = utils::BuildParmVarDecl( - m_Sema, m_Derivative, pullbackParamII, pullbackType); - paramDerivatives.insert(paramDerivatives.begin(), pullbackPVD); - - if (pullbackPVD->getIdentifier()) - m_Sema.PushOnScopeChains(pullbackPVD, getCurrentScope(), - /*AddToContext=*/false); - - m_Pullback = BuildDeclRef(pullbackPVD); - ++dParamTypesIdx; + bool HasThis = needsDThis(FD); + // If we are differentiating an instance member function then create a + // parameter for representing derivative of `this` pointer with respect to + // the independent parameter. + if (HasThis) { + IdentifierInfo* dThisII = &m_Context.Idents.get("_d_this"); + const auto* MD = cast(FD); + QualType thisTy = GetParameterDerivativeType(dRetTy, MD->getThisType()); + + auto* dPVD = + utils::BuildParmVarDecl(m_Sema, m_Sema.CurContext, dThisII, thisTy); + m_Sema.PushOnScopeChains(dPVD, getCurrentScope(), /*AddToContext=*/false); + params.push_back(dPVD); + // FIXME: Replace m_ThisExprDerivative in favor of lookups of _d_this. + // This can instantiate an array_ref and needs a fake source location. + SourceLocation fakeLoc = utils::GetValidSLoc(m_Sema); + Expr* deref = + BuildOp(UnaryOperatorKind::UO_Deref, BuildDeclRef(dPVD), fakeLoc); + m_ThisExprDerivative = utils::BuildParenExpr(m_Sema, deref); + // m_ThisExprDerivative = BuildDeclRef(dPVD); } - params.insert(params.end(), paramDerivatives.begin(), - paramDerivatives.end()); - // FIXME: If we do not consider diffParams as an independent argument for - // jacobian mode, then we should keep diffParams list empty for jacobian - // mode and thus remove the if condition. - if (m_DiffReq.Mode == DiffMode::reverse || - m_DiffReq.Mode == DiffMode::experimental_pullback) - m_IndependentVars.insert(m_IndependentVars.end(), diffParams.begin(), - diffParams.end()); - return params; + const auto* FnType = cast(m_Derivative->getType()); + for (size_t i = 0, s = params.size(), p = s; i < s - HasThis - HasRet; + ++i) { + const ParmVarDecl* oPVD = FD->getParamDecl(i); + + // FIXME: We can't use std::find(DVI.begin(), DVI.end()) because the + // operator== considers params and intervals as different entities and + // breaks the hessian tests. We should implement more robust checks in + // DiffInputVarInfo to check if this is a variable we differentiate wrt. + bool IsSelected = false; + for (const DiffInputVarInfo& VarInfo : m_DiffReq.DVI) { + if (VarInfo.param == oPVD) { + IsSelected = true; + break; + } + } + + if (!IsSelected) + continue; + + const ParmVarDecl* PVD = params[i]; + IdentifierInfo* II = + CreateUniqueIdentifier("_d_" + PVD->getNameAsString()); + QualType dPVDTy = FnType->getParamType(p++); + auto* dPVD = utils::BuildParmVarDecl(m_Sema, m_Derivative, II, dPVDTy, + PVD->getStorageClass()); + m_Sema.PushOnScopeChains(dPVD, getCurrentScope(), /*AddToContext=*/false); + // Ensure that parameters passed by value are always dereferenced on use. + // For example d_x in f(float x, float *d_x) should be used as *d_x. + if (utils::isArrayOrPointerType(oPVD->getType())) { + m_Variables[PVD] = BuildDeclRef(dPVD); + + } else { + Expr* Deref = + BuildOp(UO_Deref, BuildDeclRef(dPVD), oPVD->getLocation()); + if (dPVDTy->getPointeeType()->isRecordType()) + Deref = utils::BuildParenExpr(m_Sema, Deref); + m_Variables[PVD] = Deref; + } + + params.push_back(dPVD); + } } Expr* ReverseModeVisitor::BuildCallToCustomForwPassFn( diff --git a/test/Hessian/NestedFunctionCalls.C b/test/Hessian/NestedFunctionCalls.C index b5c784d84..200b382d9 100644 --- a/test/Hessian/NestedFunctionCalls.C +++ b/test/Hessian/NestedFunctionCalls.C @@ -36,7 +36,7 @@ double f2(double x, double y){ // CHECK-NEXT: return _d_ans; // CHECK-NEXT: } -// CHECK: void f_pushforward_pullback(double x, double y, double _d_x, double _d_y, clad::ValueAndPushforward _d_y0, double *_d_x, double *_d_y, double *_d__d_x, double *_d__d_y); +// CHECK: void f_pushforward_pullback(double x, double y, double _d_x, double _d_y, clad::ValueAndPushforward _d_y0, double *_d_x0, double *_d_y1, double *_d__d_x, double *_d__d_y); // CHECK: void f2_darg0_grad(double x, double y, double *_d_x, double *_d_y) { // CHECK-NEXT: double _d__d_x = 0.; @@ -109,19 +109,19 @@ double f2(double x, double y){ // CHECK-NEXT: return {x * x + y * y, _d_x * x + x * _d_x + _d_y * y + y * _d_y}; // CHECK-NEXT: } -// CHECK: void f_pushforward_pullback(double x, double y, double _d_x, double _d_y, clad::ValueAndPushforward _d_y0, double *_d_x, double *_d_y, double *_d__d_x, double *_d__d_y) { +// CHECK: void f_pushforward_pullback(double x, double y, double _d_x, double _d_y, clad::ValueAndPushforward _d_y0, double *_d_x0, double *_d_y1, double *_d__d_x, double *_d__d_y) { // CHECK-NEXT: { -// CHECK-NEXT: *_d_x += _d_y0.value * x; -// CHECK-NEXT: *_d_x += x * _d_y0.value; -// CHECK-NEXT: *_d_y += _d_y0.value * y; -// CHECK-NEXT: *_d_y += y * _d_y0.value; +// CHECK-NEXT: *_d_x0 += _d_y0.value * x; +// CHECK-NEXT: *_d_x0 += x * _d_y0.value; +// CHECK-NEXT: *_d_y1 += _d_y0.value * y; +// CHECK-NEXT: *_d_y1 += y * _d_y0.value; // CHECK-NEXT: *_d__d_x += _d_y0.pushforward * x; -// CHECK-NEXT: *_d_x += _d_x * _d_y0.pushforward; -// CHECK-NEXT: *_d_x += _d_y0.pushforward * _d_x; +// CHECK-NEXT: *_d_x0 += _d_x * _d_y0.pushforward; +// CHECK-NEXT: *_d_x0 += _d_y0.pushforward * _d_x; // CHECK-NEXT: *_d__d_x += x * _d_y0.pushforward; // CHECK-NEXT: *_d__d_y += _d_y0.pushforward * y; -// CHECK-NEXT: *_d_y += _d_y * _d_y0.pushforward; -// CHECK-NEXT: *_d_y += _d_y0.pushforward * _d_y; +// CHECK-NEXT: *_d_y1 += _d_y * _d_y0.pushforward; +// CHECK-NEXT: *_d_y1 += _d_y0.pushforward * _d_y; // CHECK-NEXT: *_d__d_y += y * _d_y0.pushforward; // CHECK-NEXT: } // CHECK-NEXT: }