Skip to content

Commit

Permalink
Support member calls with xvalue bases in the reverse mode.
Browse files Browse the repository at this point in the history
Currently, we assume that the base of a member call always has a clear memory location (that it can be referenced with a variable). This is true for lvalues, e.g.
 ```
obj1.mem_fn_1(i, j); // can be referenced with `obj/_d_obj`
```

However, not true for xvalues, e.g.
```
(obj1 + obj2).mem_fn_1(i, j);
```
In the latter case, we need to define an additional variable to store the derivative of ``obj1 + obj2``.
Fixes #917
  • Loading branch information
PetroZarytskyi committed Jan 16, 2025
1 parent c8ce282 commit 19104bd
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 3 deletions.
15 changes: 12 additions & 3 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1715,10 +1715,19 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
baseOriginalE = MCE->getImplicitObjectArgument();
else if (const auto* OCE = dyn_cast<CXXOperatorCallExpr>(CE))
baseOriginalE = OCE->getArg(0);

baseDiff = Visit(baseOriginalE);
if (baseOriginalE->isXValue()) {
QualType dBaseTy = getNonConstType(baseOriginalE->getType(), m_Context, m_Sema);
VarDecl* dBaseDecl = BuildVarDecl(dBaseTy, "_r", getZeroInit(dBaseTy));
PreCallStmts.push_back(BuildDeclStmt(dBaseDecl));
DeclRefExpr* dBaseRef = BuildDeclRef(dBaseDecl);
baseDiff = Visit(baseOriginalE, dBaseRef);
baseDiff.updateStmtDx(Clone(dBaseRef));
} else
baseDiff = Visit(baseOriginalE);
baseExpr = baseDiff.getExpr();
Expr* baseDiffStore = GlobalStoreAndRef(baseDiff.getExpr());
Expr* baseDiffStore = GlobalStoreAndRef(baseDiff.getExpr(), "_t", /*force=*/true);
if (baseOriginalE->isXValue())
baseExpr = baseDiffStore;
baseDiff.updateStmt(baseDiffStore);
Expr* baseDerivative = baseDiff.getExpr_dx();
if (!baseDerivative->getType()->isPointerType())
Expand Down
79 changes: 79 additions & 0 deletions test/Gradient/UserDefinedTypes.C
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,57 @@ double fn15(double x, double y) {
// CHECK-NEXT: _d_arr.elements[0] += 1;
// CHECK-NEXT:}

class SimpleFunctions1 {
public:
SimpleFunctions1() noexcept : x(0), y(0) {}
SimpleFunctions1(double p_x, double p_y) noexcept : x(p_x), y(p_y) {}
double x;
double y;
double mem_fn_1(double i, double j) { return (x + y) * i + i * j * j; }
SimpleFunctions1 operator+(const SimpleFunctions1& other) const {
SimpleFunctions1 res(x + other.x, y + other.y);
return res;
}
};

// CHECK: void operator_plus_pullback(const SimpleFunctions1 &other, SimpleFunctions1 _d_y, SimpleFunctions1 *_d_this, SimpleFunctions1 *_d_other) const;

// CHECK: void mem_fn_1_pullback(double i, double j, double _d_y, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j);

double fn16(double i, double j) {
SimpleFunctions1 obj1(2, 3);
SimpleFunctions1 obj2(3, 5);
return (obj1 + obj2).mem_fn_1(i, j);
}

// CHECK: void fn16_grad(double i, double j, double *_d_i, double *_d_j) {
// CHECK-NEXT: SimpleFunctions1 obj1(2, 3);
// CHECK-NEXT: SimpleFunctions1 _d_obj1(obj1);
// CHECK-NEXT: clad::zero_init(_d_obj1);
// CHECK-NEXT: SimpleFunctions1 obj2(3, 5);
// CHECK-NEXT: SimpleFunctions1 _d_obj2(obj2);
// CHECK-NEXT: clad::zero_init(_d_obj2);
// CHECK-NEXT: SimpleFunctions1 _t0 = obj1;
// CHECK-NEXT: SimpleFunctions1 _t1 = obj1.operator+(obj2);
// CHECK-NEXT: {
// CHECK-NEXT: double _r4 = 0.;
// CHECK-NEXT: double _r5 = 0.;
// CHECK-NEXT: SimpleFunctions1 _r6 = {};
// CHECK-NEXT: _t1.mem_fn_1_pullback(i, j, 1, &_r6, &_r4, &_r5);
// CHECK-NEXT: *_d_i += _r4;
// CHECK-NEXT: *_d_j += _r5;
// CHECK-NEXT: _t0.operator_plus_pullback(obj2, _r6, &_d_obj1, &_d_obj2);
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: double _r2 = 0.;
// CHECK-NEXT: double _r3 = 0.;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: double _r0 = 0.;
// CHECK-NEXT: double _r1 = 0.;
// CHECK-NEXT: }
// CHECK-NEXT:}

void print(const Tangent& t) {
for (int i = 0; i < 5; ++i) {
printf("%.2f", t.data[i]);
Expand Down Expand Up @@ -591,6 +642,9 @@ int main() {
TEST_GRADIENT(fn14, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {30.00, 22.00}

INIT_GRADIENT(fn15);

INIT_GRADIENT(fn16);
TEST_GRADIENT(fn16, /*numOfDerivativeArgs=*/2, 2, 3, &d_i, &d_j); // CHECK-EXEC: {22.00, 12.00}
}

// CHECK: void sum_pullback(Tangent &t, double _d_y, Tangent *_d_t) {
Expand Down Expand Up @@ -739,4 +793,29 @@ int main() {
// CHECK-NEXT: double _t1 = this->b;
// CHECK-NEXT: this->b = arg.b;
// CHECK-NEXT: return {*this, (*_d_this)};
// CHECK-NEXT:}

// CHECK: void operator_plus_pullback(const SimpleFunctions1 &other, SimpleFunctions1 _d_y, SimpleFunctions1 *_d_this, SimpleFunctions1 *_d_other) const {
// CHECK-NEXT: SimpleFunctions1 res(this->x + other.x, this->y + other.y);
// CHECK-NEXT: SimpleFunctions1 _d_res(res);
// CHECK-NEXT: clad::zero_init(_d_res);
// CHECK-NEXT: {
// CHECK-NEXT: double _r0 = 0.;
// CHECK-NEXT: double _r1 = 0.;
// CHECK-NEXT: (*_d_this).x += _r0;
// CHECK-NEXT: (*_d_other).x += _r0;
// CHECK-NEXT: (*_d_this).y += _r1;
// CHECK-NEXT: (*_d_other).y += _r1;
// CHECK-NEXT: }
// CHECK-NEXT:}

// CHECK: void mem_fn_1_pullback(double i, double j, double _d_y, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j) {
// CHECK-NEXT: {
// CHECK-NEXT: (*_d_this).x += _d_y * i;
// CHECK-NEXT: (*_d_this).y += _d_y * i;
// CHECK-NEXT: *_d_i += (this->x + this->y) * _d_y;
// CHECK-NEXT: *_d_i += _d_y * j * j;
// CHECK-NEXT: *_d_j += i * _d_y * j;
// CHECK-NEXT: *_d_j += i * j * _d_y;
// CHECK-NEXT: }
// CHECK-NEXT:}

0 comments on commit 19104bd

Please sign in to comment.