From 938e0efcea7874918bd0dcd2e1d5c592045b7749 Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Sat, 21 Dec 2024 14:57:07 +0100 Subject: [PATCH] Support std::initializer_list parameters in the reverse mode. Previously, we replaced ``std::initializer_list`` variables with ``clad::array`` in the reverse mode so that they become modifiable. This PR moves the logic for type replacement from ``RMV::DifferentiateVarDecl`` to ``RMV::CloneType`` so that it can also be used to handle constructors with ``std::initializer_list`` parameters. Fixes #1082. --- .../clad/Differentiator/ReverseModeVisitor.h | 8 ++ include/clad/Differentiator/STLBuiltins.h | 25 ++++++ lib/Differentiator/ReverseModeVisitor.cpp | 89 +++++++++++-------- test/Gradient/STLCustomDerivatives.C | 32 ++++++- 4 files changed, 117 insertions(+), 37 deletions(-) diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index 2d0141810..5fee7c282 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -200,6 +200,14 @@ namespace clad { /// \returns True if the statement was added to the block, false otherwise. bool AddToGlobalBlock(clang::Stmt* S) { return addToBlock(S, m_Globals); } + /// Updates size references in VariableArrayType and replaces + /// std::initializer_list with clad::array. + clang::QualType CloneType(clang::QualType T); + + /// If E is a CXXStdInitializerListExpr, returns its size expr. + /// Otherwise, returns nullptr. + clang::Expr* getStdInitListSizeExpr(const clang::Expr* E); + /// Stores the result of an expression in a temporary variable (of the same /// type as is the result of the expression) and returns a reference to it. 
/// If force decl creation is true, this will allways create a temporary diff --git a/include/clad/Differentiator/STLBuiltins.h b/include/clad/Differentiator/STLBuiltins.h index d6b169ec0..35dcc53fc 100644 --- a/include/clad/Differentiator/STLBuiltins.h +++ b/include/clad/Differentiator/STLBuiltins.h @@ -2,6 +2,7 @@ #define CLAD_STL_BUILTINS_H #include +#include #include #include #include @@ -463,6 +464,19 @@ constructor_reverse_forw(::clad::ConstructorReverseForwTag<::std::vector>, return {v, d_v}; } +template +::clad::ValueAndAdjoint<::std::vector, ::std::vector> +constructor_reverse_forw(::clad::ConstructorReverseForwTag<::std::vector>, + const clad::array& list, + const clad::array& d_list) { + ::std::vector v(list.size()); + const T* iter = list.begin(); + for (T& el : v) + el = *(iter++); + ::std::vector d_v(list.size(), 0); + return {v, d_v}; +} + template void constructor_pullback(::std::vector* v, S count, U val, typename ::std::vector::allocator_type alloc, @@ -473,6 +487,17 @@ void constructor_pullback(::std::vector* v, S count, U val, d_v->clear(); } +// A specialization for std::initializer_list (which is replaced with +// clad::array). 
+template +void constructor_pullback(::std::vector* v, clad::array init, + ::std::vector* d_v, clad::array* d_init) { + for (unsigned i = 0; i < init.size(); ++i) { + (*d_init)[i] += (*d_v)[i]; + (*d_v)[i] = 0; + } +} + template void assign_pullback(::std::vector* v, typename ::std::vector::size_type n, U /*val*/, diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 3d4b5245a..98604455a 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -60,6 +60,19 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return nullptr; } +Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) { + if (E) + if (const auto* CXXILE = + dyn_cast(E->IgnoreImplicit())) + if (const auto* ILE = + dyn_cast(CXXILE->getSubExpr()->IgnoreImplicit())) { + unsigned numInits = ILE->getNumInits(); + return ConstantFolder::synthesizeLiteral(m_Context.getSizeType(), + m_Context, numInits); + } + return nullptr; +} + Expr* ReverseModeVisitor::CladTapeResult::Last() { LookupResult& Back = V.GetCladTapeBack(); CXXScopeSpec CSS; @@ -2501,6 +2514,25 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return StmtDiff(op, ResultRef, nullptr, valueForRevPass); } + QualType ReverseModeVisitor::CloneType(QualType T) { + QualType dT = VisitorBase::CloneType(T); + + bool isLValueRefType = dT->isLValueReferenceType(); + dT = dT.getNonReferenceType(); + + // We need to replace std::initializer_list with clad::array because the + // former is temporary by design and it's not possible to create modifiable + // adjoints. 
+ QualType elemType; + if (m_Sema.isStdInitializerList(utils::GetValueType(T), &elemType)) + dT = GetCladArrayOfType(elemType); + + if (isLValueRefType) + return m_Context.getLValueReferenceType(dT); + + return dT; + } + DeclDiff ReverseModeVisitor::DifferentiateVarDecl(const VarDecl* VD, bool keepLocal) { StmtDiff initDiff; @@ -2516,6 +2548,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, QualType VDCloneType; QualType VDDerivedType; QualType VDType = VD->getType(); + VarDecl::InitializationStyle VDStyle = VD->getInitStyle(); // If the cloned declaration is moved to the function global scope, // change its type for the corresponding adjoint type. if (promoteToFnScope) { @@ -2535,37 +2568,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, bool isInitializedByNewExpr = false; bool initializeDerivedVar = true; - // We need to replace std::initializer_list with clad::array because the - // former is temporary by design and it's not possible to create modifiable - // adjoints. 
- if (m_Sema.isStdInitializerList(utils::GetValueType(VDType), - /*Element=*/nullptr)) { - if (const Expr* init = VD->getInit()) { - if (const auto* CXXILE = - dyn_cast(init->IgnoreImplicit())) { - if (const auto* ILE = dyn_cast( - CXXILE->getSubExpr()->IgnoreImplicit())) { - VDDerivedType = - GetCladArrayOfType(ILE->getInit(/*Init=*/0)->getType()); - unsigned numInits = ILE->getNumInits(); - VDDerivedInit = ConstantFolder::synthesizeLiteral( - m_Context.getSizeType(), m_Context, numInits); - VDCloneType = VDDerivedType; - } - } else if (isRefType) { - initDiff = Visit(init); - if (promoteToFnScope) { - VDDerivedInit = BuildOp(UO_AddrOf, initDiff.getExpr_dx()); - VDDerivedType = VDDerivedInit->getType(); - } else { - VDDerivedInit = initDiff.getExpr_dx(); - VDDerivedType = - m_Context.getLValueReferenceType(VDDerivedInit->getType()); - } - VDCloneType = VDDerivedType; - } - } - } + if (Expr* size = getStdInitListSizeExpr(VD->getInit())) + VDDerivedInit = size; // Check if the variable is pointer type and initialized by new expression if (isPointerType && VD->getInit() && isa(VD->getInit())) @@ -2629,6 +2633,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, resetConstructorPullbackCallInfo(); if (initDiff.getForwSweepExpr_dx()) VDDerivedInit = initDiff.getForwSweepExpr_dx(); + // ListInit style combined with `_t0.value`/`_t0.adjoint` inits will be + // displayed incorrectly. 
+ if (VDStyle == VarDecl::InitializationStyle::ListInit) + VDStyle = VarDecl::InitializationStyle::CallInit; } // FIXME: Remove the special cases introduced by `specialThisDiffCase` @@ -2675,7 +2683,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (initializeDerivedVar) VDDerived = BuildGlobalVarDecl(VDDerivedType, "_d_" + VD->getNameAsString(), - VDDerivedInit, false, nullptr, VD->getInitStyle()); + VDDerivedInit, false, nullptr, VDStyle); if (!m_DiffReq.shouldHaveAdjoint((VD))) VDDerived = nullptr; @@ -2758,7 +2766,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, else VDClone = BuildGlobalVarDecl(VDCloneType, VD->getNameAsString(), initDiff.getExpr(), VD->isDirectInit(), - nullptr, VD->getInitStyle()); + nullptr, VDStyle); if (isPointerType && derivedVDE) { if (promoteToFnScope) { Expr* assignDerivativeE = BuildOp(BinaryOperatorKind::BO_Assign, @@ -3108,7 +3116,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, StmtDiff subExprDiff = Visit(EWC->getSubExpr(), dfdx()); // FIXME: We are unable to create cleanup objects currently, this can be // potentially problematic - return StmtDiff(subExprDiff.getExpr(), subExprDiff.getExpr_dx()); + return StmtDiff(subExprDiff.getStmt(), subExprDiff.getStmt_dx(), + subExprDiff.getForwSweepStmt_dx(), + subExprDiff.getRevSweepStmt()); } bool ReverseModeVisitor::ShouldRecompute(const Expr* E) { @@ -3986,6 +3996,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // FIXME: Restore arguments passed as non-const reference. for (const auto* arg : CE->arguments()) { + // FIXME: Use this workaround to support some custom constructors. + // Remove when default arguments are supported. 
+ if (isa(arg->IgnoreImplicit())) + break; QualType ArgTy = arg->getType(); StmtDiff argDiff{}; Expr* adjointArg = nullptr; @@ -4008,8 +4022,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // double _r0 = 0; // SomeClass_pullback(c, u, ..., &_d_c, &_r0, ...); // _d_u += _r0; - QualType dArgTy = getNonConstType(ArgTy, m_Context, m_Sema); - VarDecl* dArgDecl = BuildVarDecl(dArgTy, "_r", getZeroInit(dArgTy)); + QualType dArgTy = getNonConstType(CloneType(ArgTy), m_Context, m_Sema); + Expr* init = getStdInitListSizeExpr(arg); + if (!init) + init = getZeroInit(dArgTy); + VarDecl* dArgDecl = BuildVarDecl(dArgTy, "_r", init); prePullbackCallStmts.push_back(BuildDeclStmt(dArgDecl)); adjointArg = BuildDeclRef(dArgDecl); argDiff = Visit(arg, BuildDeclRef(dArgDecl)); diff --git a/test/Gradient/STLCustomDerivatives.C b/test/Gradient/STLCustomDerivatives.C index 53bf46441..937b1939a 100644 --- a/test/Gradient/STLCustomDerivatives.C +++ b/test/Gradient/STLCustomDerivatives.C @@ -184,6 +184,11 @@ double fn21(double x, double y) { return a[0]; } +double fn22(double u, double v) { + std::vector ls{u, v}; + return ls[1] - 2 * ls[0]; +} + int main() { double d_i, d_j; INIT_GRADIENT(fn10); @@ -198,6 +203,7 @@ int main() { INIT_GRADIENT(fn19); INIT_GRADIENT(fn20); INIT_GRADIENT(fn21); + INIT_GRADIENT(fn22); TEST_GRADIENT(fn10, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {1.00, 1.00} TEST_GRADIENT(fn11, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {2.00, 1.00} @@ -211,6 +217,7 @@ int main() { TEST_GRADIENT(fn19, /*numOfDerivativeArgs=*/2, 3, 4, &d_i, &d_j); // CHECK-EXEC: {3.00, 2.00} TEST_GRADIENT(fn20, /*numOfDerivativeArgs=*/2, 3, 4, &d_i, &d_j); // CHECK-EXEC: {11.00, 1.00} TEST_GRADIENT(fn21, /*numOfDerivativeArgs=*/2, 3, 4, &d_i, &d_j); // CHECK-EXEC: {6.00, 0.00} + TEST_GRADIENT(fn22, /*numOfDerivativeArgs=*/2, 3, 4, &d_i, &d_j); // CHECK-EXEC: {-2.00, 1.00} } // CHECK: void fn10_grad(double u, double v, double *_d_u, double 
*_d_v) { @@ -840,4 +847,27 @@ int main() { // CHECK-NEXT: {{.*}}value_type _r0 = 0.; // CHECK-NEXT: {{.*}}push_back_pullback(&_t0, 0{{.*}}, &_d_a, &_r0); // CHECK-NEXT: } -// CHECK-NEXT: } \ No newline at end of file +// CHECK-NEXT: } + +// CHECK: void fn22_grad(double u, double v, double *_d_u, double *_d_v) { +// CHECK-NEXT: {{.*}} _t0 = {{.*}}::class_functions::constructor_reverse_forw(clad::ConstructorReverseForwTag<{{.*}}> >(), {{.*u, v.*}}, {}); +// CHECK-NEXT: std::vector _d_ls(_t0.adjoint); +// CHECK-NEXT: std::vector ls(_t0.value); +// CHECK-NEXT: std::vector _t1 = ls; +// CHECK-NEXT: clad::ValueAndAdjoint _t2 = clad::custom_derivatives::class_functions::operator_subscript_reverse_forw(&ls, 1, &_d_ls, {{0U|0UL|0}}); +// CHECK-NEXT: std::vector _t4 = ls; +// CHECK-NEXT: clad::ValueAndAdjoint _t5 = clad::custom_derivatives::class_functions::operator_subscript_reverse_forw(&ls, 0, &_d_ls, {{0U|0UL|0}}); +// CHECK-NEXT: {{.*}}value_type _t3 = _t5.value; +// CHECK-NEXT: { +// CHECK-NEXT: {{.*}}size_type _r1 = 0{{.*}}; +// CHECK-NEXT: clad::custom_derivatives::class_functions::operator_subscript_pullback(&_t1, 1, 1, &_d_ls, &_r1); +// CHECK-NEXT: {{.*}}size_type _r2 = 0{{.*}}; +// CHECK-NEXT: clad::custom_derivatives::class_functions::operator_subscript_pullback(&_t4, 0, 2 * -1, &_d_ls, &_r2); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: clad::array _r0 = {{2U|2UL|2ULL}}; +// CHECK-NEXT: clad::custom_derivatives::class_functions::constructor_pullback(&ls, {u, v}, &_d_ls, &_r0); +// CHECK-NEXT: *_d_u += _r0[0]; +// CHECK-NEXT: *_d_v += _r0[1]; +// CHECK-NEXT: } +// CHECK-NEXT: }