diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index 1a82108ec..14181682e 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -8,9 +8,11 @@ #define CLAD_REVERSE_MODE_VISITOR_H #include "clad/Differentiator/Compatibility.h" -#include "clad/Differentiator/VisitorBase.h" -#include "clad/Differentiator/ReverseModeVisitorDirectionKinds.h" #include "clad/Differentiator/ParseDiffArgsTypes.h" +#include "clad/Differentiator/ReverseModeVisitorDirectionKinds.h" +#include "clad/Differentiator/VisitorBase.h" + +#include "clang/AST/ExprCXX.h" #include "clang/AST/RecursiveASTVisitor.h" #include "clang/AST/StmtVisitor.h" #include "clang/Sema/Sema.h" @@ -370,6 +372,7 @@ namespace clad { virtual StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS); StmtDiff VisitConditionalOperator(const clang::ConditionalOperator* CO); StmtDiff VisitCXXBoolLiteralExpr(const clang::CXXBoolLiteralExpr* BL); + StmtDiff VisitCXXBindTemporaryExpr(const clang::CXXBindTemporaryExpr* BTE); StmtDiff VisitCharacterLiteral(const clang::CharacterLiteral* CL); StmtDiff VisitStringLiteral(const clang::StringLiteral* SL); StmtDiff VisitCXXDefaultArgExpr(const clang::CXXDefaultArgExpr* DE); @@ -403,6 +406,8 @@ namespace clad { StmtDiff VisitBreakStmt(const clang::BreakStmt* BS); StmtDiff VisitCXXStdInitializerListExpr(const clang::CXXStdInitializerListExpr* ILE); + StmtDiff + VisitCXXTemporaryObjectExpr(const clang::CXXTemporaryObjectExpr* TOE); StmtDiff VisitCXXThisExpr(const clang::CXXThisExpr* CTE); StmtDiff VisitCXXNewExpr(const clang::CXXNewExpr* CNE); StmtDiff VisitCXXDeleteExpr(const clang::CXXDeleteExpr* CDE); diff --git a/include/clad/Differentiator/STLBuiltins.h b/include/clad/Differentiator/STLBuiltins.h index 87e136f16..a52dd9f8f 100644 --- a/include/clad/Differentiator/STLBuiltins.h +++ b/include/clad/Differentiator/STLBuiltins.h @@ -477,6 +477,17 @@ void constructor_pullback( } } +// A specialization for std::initializer_list (which is replaced with +// clad::array). +template +void constructor_pullback(::std::vector* v, clad::array init, + ::std::vector* d_v, clad::array* d_init) { + for (unsigned i = 0; i < init.size(); ++i) { + (*d_init)[i] += (*d_v)[i]; + (*d_v)[i] = 0; + } +} + template void assign_pullback(::std::vector* v, typename ::std::vector::size_type n, U /*val*/, diff --git a/lib/Differentiator/CladUtils.cpp b/lib/Differentiator/CladUtils.cpp index 346a677a9..d224c8750 100644 --- a/lib/Differentiator/CladUtils.cpp +++ b/lib/Differentiator/CladUtils.cpp @@ -323,7 +323,8 @@ namespace clad { // The argument is passed by reference if it's passed as an L-value. // However, if arg is a MaterializeTemporaryExpr, then arg is a // temporary variable passed as a const reference. - bool isRefType = arg->isLValue() && !isa(arg); + bool isRefType = arg->isLValue() && !isa(arg) && + !isa(arg); return isRefType || isArrayOrPointerType(arg->getType()); } diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 6d0afbd73..d7eb0b5f9 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -41,6 +41,7 @@ #include "llvm/Support/SaveAndRestore.h" #include #include +#include #include #include @@ -4021,6 +4022,16 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) { return {clonedCTE, m_ThisExprDerivative}; } + StmtDiff ReverseModeVisitor::VisitCXXTemporaryObjectExpr( + const clang::CXXTemporaryObjectExpr* TOE) { + return Clone(TOE); + } + + StmtDiff ReverseModeVisitor::VisitCXXBindTemporaryExpr( + const clang::CXXBindTemporaryExpr* BTE) { + return Visit(BTE->getSubExpr(), dfdx()); + } + StmtDiff ReverseModeVisitor::VisitCXXNewExpr(const clang::CXXNewExpr* CNE) { StmtDiff initializerDiff; if (CNE->hasInitializer()) @@ -4142,29 +4153,33 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) { // created. The caller which triggers 'VisitCXXConstructExpr' is // responsible for updating these args. Expr* thisE = getZeroInit(recordPointerType); - Expr* dThisE = getZeroInit(recordPointerType); - if (!m_TrackConstructorPullbackInfo && dfdx()) + Expr* dThisE = nullptr; + if (m_TrackConstructorPullbackInfo) + dThisE = getZeroInit(recordPointerType); + else if (dfdx()) dThisE = BuildOp(UnaryOperatorKind::UO_AddrOf, dfdx(), m_DiffReq->getLocation()); - pullbackArgs.push_back(thisE); - pullbackArgs.append(primalArgs.begin(), primalArgs.end()); - pullbackArgs.push_back(dThisE); - pullbackArgs.append(adjointArgs.begin(), adjointArgs.end()); - - Stmts& curRevBlock = getCurrentBlock(direction::reverse); - Stmts::iterator it = std::begin(curRevBlock) + insertionPoint; - curRevBlock.insert(it, prePullbackCallStmts.begin(), - prePullbackCallStmts.end()); - it += prePullbackCallStmts.size(); - std::string customPullbackName = "constructor_pullback"; - if (Expr* customPullbackCall = - m_Builder.BuildCallToCustomDerivativeOrNumericalDiff( - customPullbackName, pullbackArgs, getCurrentScope(), CE)) { - curRevBlock.insert(it, customPullbackCall); - if (m_TrackConstructorPullbackInfo) { - setConstructorPullbackCallInfo(llvm::cast(customPullbackCall), - primalArgs.size() + 1); + if (dThisE) { + pullbackArgs.push_back(thisE); + pullbackArgs.append(primalArgs.begin(), primalArgs.end()); + pullbackArgs.push_back(dThisE); + pullbackArgs.append(adjointArgs.begin(), adjointArgs.end()); + + Stmts& curRevBlock = getCurrentBlock(direction::reverse); + Stmts::iterator it = std::begin(curRevBlock) + insertionPoint; + curRevBlock.insert(it, prePullbackCallStmts.begin(), + prePullbackCallStmts.end()); + it += prePullbackCallStmts.size(); + std::string customPullbackName = "constructor_pullback"; + if (Expr* customPullbackCall = + m_Builder.BuildCallToCustomDerivativeOrNumericalDiff( + customPullbackName, pullbackArgs, getCurrentScope(), CE)) { + curRevBlock.insert(it, customPullbackCall); + if (m_TrackConstructorPullbackInfo) { + setConstructorPullbackCallInfo( + llvm::cast(customPullbackCall), primalArgs.size() + 1); + } } } // FIXME: If no compatible custom constructor pullback is found then try diff --git a/test/Gradient/STLCustomDerivatives.C b/test/Gradient/STLCustomDerivatives.C index 8702eab4a..56a8d4217 100644 --- a/test/Gradient/STLCustomDerivatives.C +++ b/test/Gradient/STLCustomDerivatives.C @@ -218,6 +218,15 @@ double fn25(double u, double v) { return prod; } +double fn26(double u, double v) { + for (int i = 0; i < 3; ++i) { + std::vector ls{u, v}; + ls[1] += ls[0]; + u = ls[1]; + } + return u; +} + int main() { double d_i, d_j; INIT_GRADIENT(fn10); @@ -236,6 +245,7 @@ int main() { INIT_GRADIENT(fn23); INIT_GRADIENT(fn24); INIT_GRADIENT(fn25); + INIT_GRADIENT(fn26); TEST_GRADIENT(fn10, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {1.00, 1.00} TEST_GRADIENT(fn11, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {2.00, 1.00} @@ -253,6 +263,7 @@ int main() { TEST_GRADIENT(fn23, /*numOfDerivativeArgs=*/2, 1, 1, &d_i, &d_j); // CHECK-EXEC: {1.00, 3.00} TEST_GRADIENT(fn24, /*NumOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {1.00, 5.00} TEST_GRADIENT(fn25, /*NumOfDerivativeArgs=*/2, 3, 1, &d_i, &d_j); // CHECK-EXEC: {48.00, 48.00} + TEST_GRADIENT(fn26, /*numOfDerivativeArgs=*/2, 1, 1, &d_i, &d_j); // CHECK-EXEC: {1.00, 3.00} } // CHECK: void fn10_grad(double u, double v, double *_d_u, double *_d_v) { @@ -1079,4 +1090,82 @@ int main() { // CHECK-NEXT: vec = clad::pop(_t2); // CHECK-NEXT: } // CHECK-NEXT: } -// CHECK-NEXT: } \ No newline at end of file +// CHECK-NEXT: } + + + +// CHECK: void fn26_grad(double u, double v, double *_d_u, double *_d_v) { +// CHECK-NEXT: int _d_i = 0; +// CHECK-NEXT: int i = 0; +// CHECK-NEXT: clad::tape > _t1 = {}; +// CHECK-NEXT: clad::tape > _t2 = {}; +// CHECK-NEXT: std::vector ls = {}; +// CHECK-NEXT: std::vector _d_ls{}; +// CHECK-NEXT: clad::tape > _t3 = {}; +// CHECK-NEXT: clad::tape > _t4 = {}; +// CHECK-NEXT: clad::tape _t5 = {}; +// CHECK-NEXT: clad::tape > _t6 = {}; +// CHECK-NEXT: clad::tape > _t7 = {}; +// CHECK-NEXT: clad::tape _t8 = {}; +// CHECK-NEXT: clad::tape > _t9 = {}; +// CHECK-NEXT: clad::tape > _t10 = {}; +// CHECK-NEXT: unsigned {{int|long|long long}} _t0 = {{0U|0UL|0ULL}}; +// CHECK-NEXT: for (i = 0; ; ++i) { +// CHECK-NEXT: { +// CHECK-NEXT: if (!(i < 3)) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: _t0++; +// CHECK-NEXT: clad::push(_t1, std::move(_d_ls)); +// CHECK-NEXT: clad::push(_t2, std::move(ls)) , ls = {{.*{u, v}.*}}; +// CHECK-NEXT: _d_ls = ls; +// CHECK-NEXT: clad::zero_init(_d_ls); +// CHECK-NEXT: clad::push(_t3, ls); +// CHECK-NEXT: clad::push(_t4, clad::custom_derivatives::class_functions::operator_subscript_reverse_forw(&ls, 1, &_d_ls, {{0U|0UL|0}})); +// CHECK-NEXT: clad::push(_t5, clad::back(_t4).value); +// CHECK-NEXT: clad::push(_t6, ls); +// CHECK-NEXT: clad::push(_t7, clad::custom_derivatives::class_functions::operator_subscript_reverse_forw(&ls, 0, &_d_ls, {{0U|0UL|0}})); +// CHECK-NEXT: clad::back(_t4).value += clad::back(_t7).value; +// CHECK-NEXT: clad::push(_t8, u); +// CHECK-NEXT: clad::push(_t9, ls); +// CHECK-NEXT: clad::push(_t10, clad::custom_derivatives::class_functions::operator_subscript_reverse_forw(&ls, 1, &_d_ls, {{0U|0UL|0}})); +// CHECK-NEXT: u = clad::back(_t10).value; +// CHECK-NEXT: } +// CHECK-NEXT: *_d_u += 1; +// CHECK-NEXT: for (;; _t0--) { +// CHECK-NEXT: { +// CHECK-NEXT: if (!_t0) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: --i; +// CHECK-NEXT: { +// CHECK-NEXT: u = clad::pop(_t8); +// CHECK-NEXT: double _r_d1 = *_d_u; +// CHECK-NEXT: *_d_u = 0.; +// CHECK-NEXT: {{.*}}size_type _r{{3|4}} = {{0U|0UL|0ULL}}; +// CHECK-NEXT: clad::custom_derivatives::class_functions::operator_subscript_pullback(&clad::back(_t9), 1, _r_d1, &_d_ls, &_r{{3|4}}); +// CHECK-NEXT: clad::pop(_t9); +// CHECK-NEXT: clad::pop(_t10); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: clad::back(_t4).value = clad::pop(_t5); +// CHECK-NEXT: double _r_d0 = clad::back(_t4).adjoint; +// CHECK-NEXT: {{.*}}size_type _r{{2|3}} = {{0U|0UL|0ULL}}; +// CHECK-NEXT: clad::custom_derivatives::class_functions::operator_subscript_pullback(&clad::back(_t6), 0, _r_d0, &_d_ls, &_r{{2|3}}); +// CHECK-NEXT: clad::pop(_t6); +// CHECK-NEXT: clad::pop(_t7); +// CHECK-NEXT: {{.*}}size_type _r{{1|2}} = {{0U|0UL|0ULL}}; +// CHECK-NEXT: clad::custom_derivatives::class_functions::operator_subscript_pullback(&clad::back(_t3), 1, 0., &_d_ls, &_r{{1|2}}); +// CHECK-NEXT: clad::pop(_t3); +// CHECK-NEXT: clad::pop(_t4); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: clad::array _r0 = {{2U|2UL|2ULL}}; +// CHECK: clad::custom_derivatives::class_functions::constructor_pullback(&ls, {u, v},{{.*}} &_d_ls, &_r0{{.*}}); +// CHECK-NEXT: *_d_u += _r0[0]; +// CHECK-NEXT: *_d_v += _r0[1]; +// CHECK-NEXT: _d_ls = clad::pop(_t1); +// CHECK-NEXT: ls = clad::pop(_t2); +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: }