From 5c7b19b50fb47b921f993768f21d8cab6218efe0 Mon Sep 17 00:00:00 2001 From: Vaibhav Thakkar Date: Thu, 26 Oct 2023 01:11:54 +0530 Subject: [PATCH] Fix gradient of fxns with const reference parameters --- lib/Differentiator/ReverseModeVisitor.cpp | 24 ++++++++++--- test/Gradient/FunctionCalls.C | 44 +++++++++++++++++++++++ 2 files changed, 64 insertions(+), 4 deletions(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index b6eb0691d..5009c42d9 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1455,6 +1455,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, FD->getParamDecl(i - static_cast<unsigned long>(isCXXOperatorCall)); StmtDiff argDiff{}; bool passByRef = utils::IsReferenceOrPointerType(PVD->getType()); + if (passByRef && isa<MaterializeTemporaryExpr>(arg)) { + // If the argument is a temporary variable, this means that param type + // is a reference to a const type and we are passing a temporary + // variable to it. In this case, we should not pass the derivative + // argument by reference. + passByRef = false; + } // We do not need to create result arg for arguments passed by reference // because the derivatives of arguments passed by reference are directly // modified by the derived callee function.
@@ -1498,7 +1505,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // same as the call expression as it is the type used to declare the // _gradX array Expr* dArg; - dArg = StoreAndRef(/*E=*/nullptr, arg->getType(), direction::reverse, "_r", + QualType argType = utils::GetValueType(arg->getType()); + dArg = StoreAndRef(/*E=*/nullptr, argType, direction::reverse, "_r", /*forceDeclCreation=*/true); ArgResultDecls.push_back( cast<VarDecl>(cast<DeclRefExpr>(dArg)->getDecl())); @@ -1673,6 +1681,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, auto PVD = FD->getParamDecl(idx); bool passByRef = utils::IsReferenceOrPointerType(PVD->getType()); + if (passByRef && isa<MaterializeTemporaryExpr>(CE->getArg(idx))) { + // If the argument is a temporary variable, this means that param type + // is a reference to a const type and we are passing a temporary + // variable to it. In this case, we should not pass the derivative + // argument by reference. + passByRef = false; + } if (passByRef) { // If derivative type is constant array type instead of // `clad::array_ref` or `clad::array` type, then create an @@ -1700,13 +1715,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } else { // Declare: diffArgType _grad; Expr* initVal = nullptr; - if (!PVD->getType()->isRecordType()) { + QualType gradVarType = utils::GetValueType(PVD->getType()); + if (!gradVarType->isRecordType()) { // If the argument is not a class type, then initialize the grad // variable with 0.
initVal = - ConstantFolder::synthesizeLiteral(PVD->getType(), m_Context, 0); + ConstantFolder::synthesizeLiteral(gradVarType, m_Context, 0); } - gradVarDecl = BuildVarDecl(PVD->getType(), gradVarII, initVal); + gradVarDecl = BuildVarDecl(gradVarType, gradVarII, initVal); // Pass the address of the declared variable gradVarExpr = BuildDeclRef(gradVarDecl); gradArgExpr = diff --git a/test/Gradient/FunctionCalls.C b/test/Gradient/FunctionCalls.C index 3f440c36d..aa4afad44 100644 --- a/test/Gradient/FunctionCalls.C +++ b/test/Gradient/FunctionCalls.C @@ -491,6 +491,48 @@ double fn8(double x, double y) { // CHECK-NEXT: } // CHECK-NEXT: } +double custom_max(const double& a, const double& b) { + return a > b ? a : b; +} + +// CHECK: void custom_max_pullback(const double &a, const double &b, double _d_y, clad::array_ref<double> _d_a, clad::array_ref<double> _d_b) { +// CHECK-NEXT: bool _cond0; +// CHECK-NEXT: _cond0 = a > b; +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: if (_cond0) +// CHECK-NEXT: * _d_a += _d_y; +// CHECK-NEXT: else +// CHECK-NEXT: * _d_b += _d_y; +// CHECK-NEXT: } + +double fn9(double x, double y) { + return custom_max(x*y, y); +} + +// CHECK: void fn9_grad(double x, double y, clad::array_ref<double> _d_x, clad::array_ref<double> _d_y) { +// CHECK-NEXT: double _t0; +// CHECK-NEXT: double _t1; +// CHECK-NEXT: double _t2; +// CHECK-NEXT: double _t3; +// CHECK-NEXT: _t1 = x; +// CHECK-NEXT: _t0 = y; +// CHECK-NEXT: _t2 = _t1 * _t0; +// CHECK-NEXT: _t3 = y; +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: { +// CHECK-NEXT: double _grad0 = 0.; +// CHECK-NEXT: custom_max_pullback(_t2, _t3, 1, &_grad0, &* _d_y); +// CHECK-NEXT: double _r0 = _grad0; +// CHECK-NEXT: double _r1 = _r0 * _t0; +// CHECK-NEXT: * _d_x += _r1; +// CHECK-NEXT: double _r2 = _t1 * _r0; +// CHECK-NEXT: * _d_y += _r2; +// CHECK-NEXT: double _r3 = * _d_y; +// CHECK-NEXT: } +// CHECK-NEXT: } + template<typename T> void reset(T* arr, int n) { for (int i=0; i