Skip to content

Commit

Permalink
Support default arguments of object types in the reverse mode.
Browse files Browse the repository at this point in the history
The major use case of this feature is supporting ``std::initializer_list``-based constructors of STL containers. Most of these constructors take a second, defaulted argument for the allocator.
  • Loading branch information
PetroZarytskyi authored and vgvassilev committed Feb 26, 2025
1 parent f066366 commit 57a47e7
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 24 deletions.
9 changes: 7 additions & 2 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
#define CLAD_REVERSE_MODE_VISITOR_H

#include "clad/Differentiator/Compatibility.h"
#include "clad/Differentiator/VisitorBase.h"
#include "clad/Differentiator/ReverseModeVisitorDirectionKinds.h"
#include "clad/Differentiator/ParseDiffArgsTypes.h"
#include "clad/Differentiator/ReverseModeVisitorDirectionKinds.h"
#include "clad/Differentiator/VisitorBase.h"

#include "clang/AST/ExprCXX.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/AST/StmtVisitor.h"
#include "clang/Sema/Sema.h"
Expand Down Expand Up @@ -370,6 +372,7 @@ namespace clad {
virtual StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS);
StmtDiff VisitConditionalOperator(const clang::ConditionalOperator* CO);
StmtDiff VisitCXXBoolLiteralExpr(const clang::CXXBoolLiteralExpr* BL);
StmtDiff VisitCXXBindTemporaryExpr(const clang::CXXBindTemporaryExpr* BTE);
StmtDiff VisitCharacterLiteral(const clang::CharacterLiteral* CL);
StmtDiff VisitStringLiteral(const clang::StringLiteral* SL);
StmtDiff VisitCXXDefaultArgExpr(const clang::CXXDefaultArgExpr* DE);
Expand Down Expand Up @@ -403,6 +406,8 @@ namespace clad {
StmtDiff VisitBreakStmt(const clang::BreakStmt* BS);
StmtDiff
VisitCXXStdInitializerListExpr(const clang::CXXStdInitializerListExpr* ILE);
StmtDiff
VisitCXXTemporaryObjectExpr(const clang::CXXTemporaryObjectExpr* TOE);
StmtDiff VisitCXXThisExpr(const clang::CXXThisExpr* CTE);
StmtDiff VisitCXXNewExpr(const clang::CXXNewExpr* CNE);
StmtDiff VisitCXXDeleteExpr(const clang::CXXDeleteExpr* CDE);
Expand Down
11 changes: 11 additions & 0 deletions include/clad/Differentiator/STLBuiltins.h
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,17 @@ void constructor_pullback(
}
}

// Pullback overload for constructing a std::vector from a
// std::initializer_list (which is replaced with clad::array).
//
// For every index covered by `init`, the adjoint stored in `*d_v` is
// accumulated into the matching slot of `*d_init` and then reset to zero.
// The primal vector `v` is not read.
template <typename T>
void constructor_pullback(::std::vector<T>* v, clad::array<T> init,
                          ::std::vector<T>* d_v, clad::array<T>* d_init) {
  unsigned idx = 0;
  while (idx < init.size()) {
    // Hand the element's adjoint over to the initializer-list adjoint ...
    (*d_init)[idx] += (*d_v)[idx];
    // ... and clear it in the vector's adjoint.
    (*d_v)[idx] = 0;
    ++idx;
  }
}

template <typename T, typename U, typename dU>
void assign_pullback(::std::vector<T>* v,
typename ::std::vector<T>::size_type n, U /*val*/,
Expand Down
3 changes: 2 additions & 1 deletion lib/Differentiator/CladUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,8 @@ namespace clad {
// The argument is passed by reference if it's passed as an L-value.
// However, if arg is a MaterializeTemporaryExpr, then arg is a
// temporary variable passed as a const reference.
bool isRefType = arg->isLValue() && !isa<MaterializeTemporaryExpr>(arg);
bool isRefType = arg->isLValue() && !isa<MaterializeTemporaryExpr>(arg) &&
!isa<CXXDefaultArgExpr>(arg);
return isRefType || isArrayOrPointerType(arg->getType());
}

Expand Down
55 changes: 35 additions & 20 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
#include "llvm/Support/SaveAndRestore.h"
#include <llvm/ADT/STLExtras.h>
#include <llvm/ADT/StringRef.h>
#include <llvm/Support/Casting.h>
#include <llvm/Support/raw_ostream.h>

#include <algorithm>
Expand Down Expand Up @@ -4021,6 +4022,16 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
return {clonedCTE, m_ThisExprDerivative};
}

/// Handles a temporary-object construction expression by cloning it.
/// The returned StmtDiff is built from the cloned expression only.
StmtDiff ReverseModeVisitor::VisitCXXTemporaryObjectExpr(
    const clang::CXXTemporaryObjectExpr* TOE) {
  StmtDiff clonedTemporary = Clone(TOE);
  return clonedTemporary;
}

/// Looks through a CXXBindTemporaryExpr: the wrapped sub-expression is
/// visited with the current dfdx() and its result is returned unchanged.
StmtDiff ReverseModeVisitor::VisitCXXBindTemporaryExpr(
    const clang::CXXBindTemporaryExpr* BTE) {
  const Expr* boundSubExpr = BTE->getSubExpr();
  return Visit(boundSubExpr, dfdx());
}

StmtDiff ReverseModeVisitor::VisitCXXNewExpr(const clang::CXXNewExpr* CNE) {
StmtDiff initializerDiff;
if (CNE->hasInitializer())
Expand Down Expand Up @@ -4142,29 +4153,33 @@ Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
// created. The caller which triggers 'VisitCXXConstructExpr' is
// responsible for updating these args.
Expr* thisE = getZeroInit(recordPointerType);
Expr* dThisE = getZeroInit(recordPointerType);
if (!m_TrackConstructorPullbackInfo && dfdx())
Expr* dThisE = nullptr;
if (m_TrackConstructorPullbackInfo)
dThisE = getZeroInit(recordPointerType);
else if (dfdx())
dThisE = BuildOp(UnaryOperatorKind::UO_AddrOf, dfdx(),
m_DiffReq->getLocation());

pullbackArgs.push_back(thisE);
pullbackArgs.append(primalArgs.begin(), primalArgs.end());
pullbackArgs.push_back(dThisE);
pullbackArgs.append(adjointArgs.begin(), adjointArgs.end());

Stmts& curRevBlock = getCurrentBlock(direction::reverse);
Stmts::iterator it = std::begin(curRevBlock) + insertionPoint;
curRevBlock.insert(it, prePullbackCallStmts.begin(),
prePullbackCallStmts.end());
it += prePullbackCallStmts.size();
std::string customPullbackName = "constructor_pullback";
if (Expr* customPullbackCall =
m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
customPullbackName, pullbackArgs, getCurrentScope(), CE)) {
curRevBlock.insert(it, customPullbackCall);
if (m_TrackConstructorPullbackInfo) {
setConstructorPullbackCallInfo(llvm::cast<CallExpr>(customPullbackCall),
primalArgs.size() + 1);
if (dThisE) {
pullbackArgs.push_back(thisE);
pullbackArgs.append(primalArgs.begin(), primalArgs.end());
pullbackArgs.push_back(dThisE);
pullbackArgs.append(adjointArgs.begin(), adjointArgs.end());

Stmts& curRevBlock = getCurrentBlock(direction::reverse);
Stmts::iterator it = std::begin(curRevBlock) + insertionPoint;
curRevBlock.insert(it, prePullbackCallStmts.begin(),
prePullbackCallStmts.end());
it += prePullbackCallStmts.size();
std::string customPullbackName = "constructor_pullback";
if (Expr* customPullbackCall =
m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
customPullbackName, pullbackArgs, getCurrentScope(), CE)) {
curRevBlock.insert(it, customPullbackCall);
if (m_TrackConstructorPullbackInfo) {
setConstructorPullbackCallInfo(
llvm::cast<CallExpr>(customPullbackCall), primalArgs.size() + 1);
}
}
}
// FIXME: If no compatible custom constructor pullback is found then try
Expand Down
91 changes: 90 additions & 1 deletion test/Gradient/STLCustomDerivatives.C
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,15 @@ double fn25(double u, double v) {
return prod;
}

// Gradient test input: a std::vector is brace-initialized from {u, v} on
// every loop iteration, so its constructor call sits inside the loop body.
// Each iteration computes u = u + v, so after three iterations the result is
// u + 3 * v; the driver below expects the gradient {1.00, 3.00} at (1, 1).
double fn26(double u, double v) {
for (int i = 0; i < 3; ++i) {
// A fresh vector is constructed from the current u and v.
std::vector<double> ls{u, v};
ls[1] += ls[0];
// Feed the sum back into u so the next iteration depends on this one.
u = ls[1];
}
return u;
}

int main() {
double d_i, d_j;
INIT_GRADIENT(fn10);
Expand All @@ -236,6 +245,7 @@ int main() {
INIT_GRADIENT(fn23);
INIT_GRADIENT(fn24);
INIT_GRADIENT(fn25);
INIT_GRADIENT(fn26);

TEST_GRADIENT(fn10, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {1.00, 1.00}
TEST_GRADIENT(fn11, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {2.00, 1.00}
Expand All @@ -253,6 +263,7 @@ int main() {
TEST_GRADIENT(fn23, /*numOfDerivativeArgs=*/2, 1, 1, &d_i, &d_j); // CHECK-EXEC: {1.00, 3.00}
TEST_GRADIENT(fn24, /*NumOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {1.00, 5.00}
TEST_GRADIENT(fn25, /*NumOfDerivativeArgs=*/2, 3, 1, &d_i, &d_j); // CHECK-EXEC: {48.00, 48.00}
TEST_GRADIENT(fn26, /*numOfDerivativeArgs=*/2, 1, 1, &d_i, &d_j); // CHECK-EXEC: {1.00, 3.00}
}

// CHECK: void fn10_grad(double u, double v, double *_d_u, double *_d_v) {
Expand Down Expand Up @@ -1079,4 +1090,82 @@ int main() {
// CHECK-NEXT: vec = clad::pop(_t2);
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }



// CHECK: void fn26_grad(double u, double v, double *_d_u, double *_d_v) {
// CHECK-NEXT: int _d_i = 0;
// CHECK-NEXT: int i = 0;
// CHECK-NEXT: clad::tape<std::vector<double> > _t1 = {};
// CHECK-NEXT: clad::tape<std::vector<double> > _t2 = {};
// CHECK-NEXT: std::vector<double> ls = {};
// CHECK-NEXT: std::vector<double> _d_ls{};
// CHECK-NEXT: clad::tape<std::vector<double> > _t3 = {};
// CHECK-NEXT: clad::tape<clad::ValueAndAdjoint<double &, double &> > _t4 = {};
// CHECK-NEXT: clad::tape<double> _t5 = {};
// CHECK-NEXT: clad::tape<std::vector<double> > _t6 = {};
// CHECK-NEXT: clad::tape<clad::ValueAndAdjoint<double &, double &> > _t7 = {};
// CHECK-NEXT: clad::tape<double> _t8 = {};
// CHECK-NEXT: clad::tape<std::vector<double> > _t9 = {};
// CHECK-NEXT: clad::tape<clad::ValueAndAdjoint<double &, double &> > _t10 = {};
// CHECK-NEXT: unsigned {{int|long|long long}} _t0 = {{0U|0UL|0ULL}};
// CHECK-NEXT: for (i = 0; ; ++i) {
// CHECK-NEXT: {
// CHECK-NEXT: if (!(i < 3))
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: _t0++;
// CHECK-NEXT: clad::push(_t1, std::move(_d_ls));
// CHECK-NEXT: clad::push(_t2, std::move(ls)) , ls = {{.*{u, v}.*}};
// CHECK-NEXT: _d_ls = ls;
// CHECK-NEXT: clad::zero_init(_d_ls);
// CHECK-NEXT: clad::push(_t3, ls);
// CHECK-NEXT: clad::push(_t4, clad::custom_derivatives::class_functions::operator_subscript_reverse_forw(&ls, 1, &_d_ls, {{0U|0UL|0}}));
// CHECK-NEXT: clad::push(_t5, clad::back(_t4).value);
// CHECK-NEXT: clad::push(_t6, ls);
// CHECK-NEXT: clad::push(_t7, clad::custom_derivatives::class_functions::operator_subscript_reverse_forw(&ls, 0, &_d_ls, {{0U|0UL|0}}));
// CHECK-NEXT: clad::back(_t4).value += clad::back(_t7).value;
// CHECK-NEXT: clad::push(_t8, u);
// CHECK-NEXT: clad::push(_t9, ls);
// CHECK-NEXT: clad::push(_t10, clad::custom_derivatives::class_functions::operator_subscript_reverse_forw(&ls, 1, &_d_ls, {{0U|0UL|0}}));
// CHECK-NEXT: u = clad::back(_t10).value;
// CHECK-NEXT: }
// CHECK-NEXT: *_d_u += 1;
// CHECK-NEXT: for (;; _t0--) {
// CHECK-NEXT: {
// CHECK-NEXT: if (!_t0)
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: --i;
// CHECK-NEXT: {
// CHECK-NEXT: u = clad::pop(_t8);
// CHECK-NEXT: double _r_d1 = *_d_u;
// CHECK-NEXT: *_d_u = 0.;
// CHECK-NEXT: {{.*}}size_type _r{{3|4}} = {{0U|0UL|0ULL}};
// CHECK-NEXT: clad::custom_derivatives::class_functions::operator_subscript_pullback(&clad::back(_t9), 1, _r_d1, &_d_ls, &_r{{3|4}});
// CHECK-NEXT: clad::pop(_t9);
// CHECK-NEXT: clad::pop(_t10);
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: clad::back(_t4).value = clad::pop(_t5);
// CHECK-NEXT: double _r_d0 = clad::back(_t4).adjoint;
// CHECK-NEXT: {{.*}}size_type _r{{2|3}} = {{0U|0UL|0ULL}};
// CHECK-NEXT: clad::custom_derivatives::class_functions::operator_subscript_pullback(&clad::back(_t6), 0, _r_d0, &_d_ls, &_r{{2|3}});
// CHECK-NEXT: clad::pop(_t6);
// CHECK-NEXT: clad::pop(_t7);
// CHECK-NEXT: {{.*}}size_type _r{{1|2}} = {{0U|0UL|0ULL}};
// CHECK-NEXT: clad::custom_derivatives::class_functions::operator_subscript_pullback(&clad::back(_t3), 1, 0., &_d_ls, &_r{{1|2}});
// CHECK-NEXT: clad::pop(_t3);
// CHECK-NEXT: clad::pop(_t4);
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: clad::array<double> _r0 = {{2U|2UL|2ULL}};
// CHECK: clad::custom_derivatives::class_functions::constructor_pullback(&ls, {u, v},{{.*}} &_d_ls, &_r0{{.*}});
// CHECK-NEXT: *_d_u += _r0[0];
// CHECK-NEXT: *_d_v += _r0[1];
// CHECK-NEXT: _d_ls = clad::pop(_t1);
// CHECK-NEXT: ls = clad::pop(_t2);
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }

0 comments on commit 57a47e7

Please sign in to comment.