From 19104bdcd26a78ad0b5a1c98dcff331e7e3249f2 Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Thu, 28 Nov 2024 20:17:23 +0100 Subject: [PATCH] Support member calls with xvalue bases in the reverse mode. Currently, we assume that the base of a member call always has a clear memory location (that it can be referenced with a variable). This is true for lvalues, e.g. ``` obj1.mem_fn_1(i, j); // can be referenced with `obj/_d_obj` ``` However, not true for xvalues, e.g. ``` (obj1 + obj2).mem_fn_1(i, j); ``` In the latter case, we need to define an additional variable to store the derivative of ``obj1 + obj2``. Fixes #917 --- lib/Differentiator/ReverseModeVisitor.cpp | 15 ++++- test/Gradient/UserDefinedTypes.C | 79 +++++++++++++++++++++++ 2 files changed, 91 insertions(+), 3 deletions(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 62b7eeafa..afb3d8045 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1715,10 +1715,19 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) { baseOriginalE = MCE->getImplicitObjectArgument(); else if (const auto* OCE = dyn_cast(CE)) baseOriginalE = OCE->getArg(0); - - baseDiff = Visit(baseOriginalE); + if (baseOriginalE->isXValue()) { + QualType dBaseTy = getNonConstType(baseOriginalE->getType(), m_Context, m_Sema); + VarDecl* dBaseDecl = BuildVarDecl(dBaseTy, "_r", getZeroInit(dBaseTy)); + PreCallStmts.push_back(BuildDeclStmt(dBaseDecl)); + DeclRefExpr* dBaseRef = BuildDeclRef(dBaseDecl); + baseDiff = Visit(baseOriginalE, dBaseRef); + baseDiff.updateStmtDx(Clone(dBaseRef)); + } else + baseDiff = Visit(baseOriginalE); baseExpr = baseDiff.getExpr(); - Expr* baseDiffStore = GlobalStoreAndRef(baseDiff.getExpr()); + Expr* baseDiffStore = GlobalStoreAndRef(baseDiff.getExpr(), "_t", /*force=*/true); + if (baseOriginalE->isXValue()) + baseExpr = baseDiffStore; baseDiff.updateStmt(baseDiffStore); Expr* baseDerivative = baseDiff.getExpr_dx(); if (!baseDerivative->getType()->isPointerType()) diff --git a/test/Gradient/UserDefinedTypes.C b/test/Gradient/UserDefinedTypes.C index 0623e8cd1..d17ecc82e 100644 --- a/test/Gradient/UserDefinedTypes.C +++ b/test/Gradient/UserDefinedTypes.C @@ -518,6 +518,57 @@ double fn15(double x, double y) { // CHECK-NEXT: _d_arr.elements[0] += 1; // CHECK-NEXT:} +class SimpleFunctions1 { +public: + SimpleFunctions1() noexcept : x(0), y(0) {} + SimpleFunctions1(double p_x, double p_y) noexcept : x(p_x), y(p_y) {} + double x; + double y; + double mem_fn_1(double i, double j) { return (x + y) * i + i * j * j; } + SimpleFunctions1 operator+(const SimpleFunctions1& other) const { + SimpleFunctions1 res(x + other.x, y + other.y); + return res; + } +}; + +// CHECK: void operator_plus_pullback(const SimpleFunctions1 &other, SimpleFunctions1 _d_y, SimpleFunctions1 *_d_this, SimpleFunctions1 *_d_other) const; + +// CHECK: void mem_fn_1_pullback(double i, double j, double _d_y, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j); + +double fn16(double i, double j) { + SimpleFunctions1 obj1(2, 3); + SimpleFunctions1 obj2(3, 5); + return (obj1 + obj2).mem_fn_1(i, j); +} + +// CHECK: void fn16_grad(double i, double j, double *_d_i, double *_d_j) { +// CHECK-NEXT: SimpleFunctions1 obj1(2, 3); +// CHECK-NEXT: SimpleFunctions1 _d_obj1(obj1); +// CHECK-NEXT: clad::zero_init(_d_obj1); +// CHECK-NEXT: SimpleFunctions1 obj2(3, 5); +// CHECK-NEXT: SimpleFunctions1 _d_obj2(obj2); +// CHECK-NEXT: clad::zero_init(_d_obj2); +// CHECK-NEXT: SimpleFunctions1 _t0 = obj1; +// CHECK-NEXT: SimpleFunctions1 _t1 = obj1.operator+(obj2); +// CHECK-NEXT: { +// CHECK-NEXT: double _r4 = 0.; +// CHECK-NEXT: double _r5 = 0.; +// CHECK-NEXT: SimpleFunctions1 _r6 = {}; +// CHECK-NEXT: _t1.mem_fn_1_pullback(i, j, 1, &_r6, &_r4, &_r5); +// CHECK-NEXT: *_d_i += _r4; +// CHECK-NEXT: *_d_j += _r5; +// CHECK-NEXT: _t0.operator_plus_pullback(obj2, _r6, &_d_obj1, &_d_obj2); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: double _r2 = 0.; +// CHECK-NEXT: double _r3 = 0.; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: double _r0 = 0.; +// CHECK-NEXT: double _r1 = 0.; +// CHECK-NEXT: } +// CHECK-NEXT:} + void print(const Tangent& t) { for (int i = 0; i < 5; ++i) { printf("%.2f", t.data[i]); @@ -591,6 +642,9 @@ int main() { TEST_GRADIENT(fn14, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {30.00, 22.00} INIT_GRADIENT(fn15); + + INIT_GRADIENT(fn16); + TEST_GRADIENT(fn16, /*numOfDerivativeArgs=*/2, 2, 3, &d_i, &d_j); // CHECK-EXEC: {22.00, 12.00} } // CHECK: void sum_pullback(Tangent &t, double _d_y, Tangent *_d_t) { @@ -739,4 +793,29 @@ int main() { // CHECK-NEXT: double _t1 = this->b; // CHECK-NEXT: this->b = arg.b; // CHECK-NEXT: return {*this, (*_d_this)}; +// CHECK-NEXT:} + +// CHECK: void operator_plus_pullback(const SimpleFunctions1 &other, SimpleFunctions1 _d_y, SimpleFunctions1 *_d_this, SimpleFunctions1 *_d_other) const { +// CHECK-NEXT: SimpleFunctions1 res(this->x + other.x, this->y + other.y); +// CHECK-NEXT: SimpleFunctions1 _d_res(res); +// CHECK-NEXT: clad::zero_init(_d_res); +// CHECK-NEXT: { +// CHECK-NEXT: double _r0 = 0.; +// CHECK-NEXT: double _r1 = 0.; +// CHECK-NEXT: (*_d_this).x += _r0; +// CHECK-NEXT: (*_d_other).x += _r0; +// CHECK-NEXT: (*_d_this).y += _r1; +// CHECK-NEXT: (*_d_other).y += _r1; +// CHECK-NEXT: } +// CHECK-NEXT:} + +// CHECK: void mem_fn_1_pullback(double i, double j, double _d_y, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j) { +// CHECK-NEXT: { +// CHECK-NEXT: (*_d_this).x += _d_y * i; +// CHECK-NEXT: (*_d_this).y += _d_y * i; +// CHECK-NEXT: *_d_i += (this->x + this->y) * _d_y; +// CHECK-NEXT: *_d_i += _d_y * j * j; +// CHECK-NEXT: *_d_j += i * _d_y * j; +// CHECK-NEXT: *_d_j += i * j * _d_y; +// CHECK-NEXT: } // CHECK-NEXT:} \ No newline at end of file