Skip to content

Commit

Permalink
Support std::initializer_list parameters in the reverse mode.
Browse files Browse the repository at this point in the history
Previously, we replaced ``std::initializer_list`` variables with ``clad::array`` in the reverse mode so that they become modifiable. This PR moves the type-replacement logic from ``RMV::DifferentiateVarDecl`` to ``RMV::CloneType`` so that it can also be used to handle constructors with ``std::initializer_list`` parameters.

Fixes vgvassilev#1082.
  • Loading branch information
PetroZarytskyi committed Jan 10, 2025
1 parent 80c1e82 commit 938e0ef
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 37 deletions.
8 changes: 8 additions & 0 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,14 @@ namespace clad {
/// \returns True if the statement was added to the block, false otherwise.
bool AddToGlobalBlock(clang::Stmt* S) { return addToBlock(S, m_Globals); }

/// Updates size references in VariableArrayType and replaces
/// std::initializer_list with clad::array.
/// \param[in] T The type to clone.
/// \returns The cloned type. If the clone is an lvalue reference type, the
/// result is an lvalue reference as well.
clang::QualType CloneType(clang::QualType T);

/// If E is a CXXStdInitializerListExpr that wraps an InitListExpr, returns
/// a literal of the context's size type holding the number of initializers.
/// Otherwise (including when E is null), returns nullptr.
clang::Expr* getStdInitListSizeExpr(const clang::Expr* E);

/// Stores the result of an expression in a temporary variable (of the same
/// type as is the result of the expression) and returns a reference to it.
/// If force decl creation is true, this will always create a temporary
Expand Down
25 changes: 25 additions & 0 deletions include/clad/Differentiator/STLBuiltins.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define CLAD_STL_BUILTINS_H

#include <array>
#include <clad/Differentiator/Array.h>
#include <clad/Differentiator/BuiltinDerivatives.h>
#include <clad/Differentiator/FunctionTraits.h>
#include <initializer_list>
Expand Down Expand Up @@ -463,6 +464,19 @@ constructor_reverse_forw(::clad::ConstructorReverseForwTag<::std::vector<T>>,
return {v, d_v};
}

// Forward-sweep helper for a ::std::vector<T> built from a list that is
// passed in as a clad::array<T>. Returns the vector holding a copy of the
// elements of `list`, paired with an adjoint vector of the same length whose
// elements are all initialized from 0. `d_list` is not read.
template <typename T>
::clad::ValueAndAdjoint<::std::vector<T>, ::std::vector<T>>
constructor_reverse_forw(::clad::ConstructorReverseForwTag<::std::vector<T>>,
                         const clad::array<T>& list,
                         const clad::array<T>& d_list) {
  // Primal: allocate the full length first, then copy element by element.
  ::std::vector<T> primal(list.size());
  const T* src = list.begin();
  for (auto dst = primal.begin(); dst != primal.end(); ++dst, ++src)
    *dst = *src;
  // Adjoint: one zero per list element.
  ::std::vector<T> adjoint(list.size(), 0);
  return {primal, adjoint};
}

template <typename T, typename S, typename U>
void constructor_pullback(::std::vector<T>* v, S count, U val,
typename ::std::vector<T>::allocator_type alloc,
Expand All @@ -473,6 +487,17 @@ void constructor_pullback(::std::vector<T>* v, S count, U val,
d_v->clear();
}

// An overload for std::initializer_list (which is replaced with
// clad::array).
// For each index below init.size(), adds the vector adjoint element to the
// matching element of the list adjoint and then resets the vector adjoint
// element to 0. `v` is not read.
template <typename T>
void constructor_pullback(::std::vector<T>* v, clad::array<T> init,
                          ::std::vector<T>* d_v, clad::array<T>* d_init) {
  // Use the type returned by size() for the index: a plain `unsigned`
  // counter can be narrower than the size type and would wrap before
  // reaching a size that does not fit in it.
  const auto count = init.size();
  for (decltype(init.size()) i = 0; i < count; ++i) {
    (*d_init)[i] += (*d_v)[i];
    (*d_v)[i] = 0;
  }
}

template <typename T, typename U, typename dU>
void assign_pullback(::std::vector<T>* v,
typename ::std::vector<T>::size_type n, U /*val*/,
Expand Down
89 changes: 53 additions & 36 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,19 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return nullptr;
}

// Returns a literal of the context's size type equal to the number of
// initializers when E (ignoring implicit nodes) is a
// CXXStdInitializerListExpr whose sub-expression (ignoring implicit nodes)
// is an InitListExpr. Returns nullptr in every other case, including when E
// itself is null.
Expr* ReverseModeVisitor::getStdInitListSizeExpr(const Expr* E) {
  if (!E)
    return nullptr;

  const auto* stdInitList =
      dyn_cast<CXXStdInitializerListExpr>(E->IgnoreImplicit());
  if (!stdInitList)
    return nullptr;

  const auto* initList =
      dyn_cast<InitListExpr>(stdInitList->getSubExpr()->IgnoreImplicit());
  if (!initList)
    return nullptr;

  unsigned count = initList->getNumInits();
  return ConstantFolder::synthesizeLiteral(m_Context.getSizeType(), m_Context,
                                           count);
}

Expr* ReverseModeVisitor::CladTapeResult::Last() {
LookupResult& Back = V.GetCladTapeBack();
CXXScopeSpec CSS;
Expand Down Expand Up @@ -2501,6 +2514,25 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return StmtDiff(op, ResultRef, nullptr, valueForRevPass);
}

// Clones T through VisitorBase::CloneType and, when the value type of T is
// a std::initializer_list, substitutes a clad::array of the same element
// type.
QualType ReverseModeVisitor::CloneType(QualType T) {
  const QualType baseClone = VisitorBase::CloneType(T);
  // The reference is stripped below; only an lvalue reference is re-applied
  // to the final result.
  const bool wrapInLValueRef = baseClone->isLValueReferenceType();

  // std::initializer_list is temporary by design, so it is not possible to
  // create modifiable adjoints for it; clad::array is used in its place.
  QualType listElemTy;
  const bool isStdInitList =
      m_Sema.isStdInitializerList(utils::GetValueType(T), &listElemTy);
  QualType result = isStdInitList ? GetCladArrayOfType(listElemTy)
                                  : baseClone.getNonReferenceType();

  return wrapInLValueRef ? m_Context.getLValueReferenceType(result) : result;
}

DeclDiff<VarDecl> ReverseModeVisitor::DifferentiateVarDecl(const VarDecl* VD,
bool keepLocal) {
StmtDiff initDiff;
Expand All @@ -2516,6 +2548,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
QualType VDCloneType;
QualType VDDerivedType;
QualType VDType = VD->getType();
VarDecl::InitializationStyle VDStyle = VD->getInitStyle();
// If the cloned declaration is moved to the function global scope,
// change its type for the corresponding adjoint type.
if (promoteToFnScope) {
Expand All @@ -2535,37 +2568,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
bool isInitializedByNewExpr = false;
bool initializeDerivedVar = true;

// We need to replace std::initializer_list with clad::array because the
// former is temporary by design and it's not possible to create modifiable
// adjoints.
if (m_Sema.isStdInitializerList(utils::GetValueType(VDType),
/*Element=*/nullptr)) {
if (const Expr* init = VD->getInit()) {
if (const auto* CXXILE =
dyn_cast<CXXStdInitializerListExpr>(init->IgnoreImplicit())) {
if (const auto* ILE = dyn_cast<InitListExpr>(
CXXILE->getSubExpr()->IgnoreImplicit())) {
VDDerivedType =
GetCladArrayOfType(ILE->getInit(/*Init=*/0)->getType());
unsigned numInits = ILE->getNumInits();
VDDerivedInit = ConstantFolder::synthesizeLiteral(
m_Context.getSizeType(), m_Context, numInits);
VDCloneType = VDDerivedType;
}
} else if (isRefType) {
initDiff = Visit(init);
if (promoteToFnScope) {
VDDerivedInit = BuildOp(UO_AddrOf, initDiff.getExpr_dx());
VDDerivedType = VDDerivedInit->getType();
} else {
VDDerivedInit = initDiff.getExpr_dx();
VDDerivedType =
m_Context.getLValueReferenceType(VDDerivedInit->getType());
}
VDCloneType = VDDerivedType;
}
}
}
if (Expr* size = getStdInitListSizeExpr(VD->getInit()))
VDDerivedInit = size;

// Check if the variable is pointer type and initialized by new expression
if (isPointerType && VD->getInit() && isa<CXXNewExpr>(VD->getInit()))
Expand Down Expand Up @@ -2629,6 +2633,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
resetConstructorPullbackCallInfo();
if (initDiff.getForwSweepExpr_dx())
VDDerivedInit = initDiff.getForwSweepExpr_dx();
// ListInit style combined with `_t0.value`/`_t0.adjoint` inits will be
// displayed incorrectly.
if (VDStyle == VarDecl::InitializationStyle::ListInit)
VDStyle = VarDecl::InitializationStyle::CallInit;
}

// FIXME: Remove the special cases introduced by `specialThisDiffCase`
Expand Down Expand Up @@ -2675,7 +2683,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (initializeDerivedVar)
VDDerived =
BuildGlobalVarDecl(VDDerivedType, "_d_" + VD->getNameAsString(),
VDDerivedInit, false, nullptr, VD->getInitStyle());
VDDerivedInit, false, nullptr, VDStyle);

if (!m_DiffReq.shouldHaveAdjoint((VD)))
VDDerived = nullptr;
Expand Down Expand Up @@ -2758,7 +2766,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
else
VDClone = BuildGlobalVarDecl(VDCloneType, VD->getNameAsString(),
initDiff.getExpr(), VD->isDirectInit(),
nullptr, VD->getInitStyle());
nullptr, VDStyle);
if (isPointerType && derivedVDE) {
if (promoteToFnScope) {
Expr* assignDerivativeE = BuildOp(BinaryOperatorKind::BO_Assign,
Expand Down Expand Up @@ -3108,7 +3116,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
StmtDiff subExprDiff = Visit(EWC->getSubExpr(), dfdx());
// FIXME: We are unable to create cleanup objects currently, this can be
// potentially problematic
return StmtDiff(subExprDiff.getExpr(), subExprDiff.getExpr_dx());
return StmtDiff(subExprDiff.getStmt(), subExprDiff.getStmt_dx(),
subExprDiff.getForwSweepStmt_dx(),
subExprDiff.getRevSweepStmt());
}

bool ReverseModeVisitor::ShouldRecompute(const Expr* E) {
Expand Down Expand Up @@ -3986,6 +3996,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

// FIXME: Restore arguments passed as non-const reference.
for (const auto* arg : CE->arguments()) {
// FIXME: Use this workaround to support some custom constructors.
// Remove when default arguments are supported.
if (isa<CXXDefaultArgExpr>(arg->IgnoreImplicit()))
break;
QualType ArgTy = arg->getType();
StmtDiff argDiff{};
Expr* adjointArg = nullptr;
Expand All @@ -4008,8 +4022,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// double _r0 = 0;
// SomeClass_pullback(c, u, ..., &_d_c, &_r0, ...);
// _d_u += _r0;
QualType dArgTy = getNonConstType(ArgTy, m_Context, m_Sema);
VarDecl* dArgDecl = BuildVarDecl(dArgTy, "_r", getZeroInit(dArgTy));
QualType dArgTy = getNonConstType(CloneType(ArgTy), m_Context, m_Sema);
Expr* init = getStdInitListSizeExpr(arg);
if (!init)
init = getZeroInit(dArgTy);
VarDecl* dArgDecl = BuildVarDecl(dArgTy, "_r", init);
prePullbackCallStmts.push_back(BuildDeclStmt(dArgDecl));
adjointArg = BuildDeclRef(dArgDecl);
argDiff = Visit(arg, BuildDeclRef(dArgDecl));
Expand Down
32 changes: 31 additions & 1 deletion test/Gradient/STLCustomDerivatives.C
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,11 @@ double fn21(double x, double y) {
return a[0];
}

// Builds a std::vector from the braced list {u, v} and returns v - 2 * u
// through element access, so the expected gradient is {-2, 1}.
double fn22(double u, double v) {
std::vector<double> ls{u, v};
return ls[1] - 2 * ls[0];
}

int main() {
double d_i, d_j;
INIT_GRADIENT(fn10);
Expand All @@ -198,6 +203,7 @@ int main() {
INIT_GRADIENT(fn19);
INIT_GRADIENT(fn20);
INIT_GRADIENT(fn21);
INIT_GRADIENT(fn22);

TEST_GRADIENT(fn10, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {1.00, 1.00}
TEST_GRADIENT(fn11, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {2.00, 1.00}
Expand All @@ -211,6 +217,7 @@ int main() {
TEST_GRADIENT(fn19, /*numOfDerivativeArgs=*/2, 3, 4, &d_i, &d_j); // CHECK-EXEC: {3.00, 2.00}
TEST_GRADIENT(fn20, /*numOfDerivativeArgs=*/2, 3, 4, &d_i, &d_j); // CHECK-EXEC: {11.00, 1.00}
TEST_GRADIENT(fn21, /*numOfDerivativeArgs=*/2, 3, 4, &d_i, &d_j); // CHECK-EXEC: {6.00, 0.00}
TEST_GRADIENT(fn22, /*numOfDerivativeArgs=*/2, 3, 4, &d_i, &d_j); // CHECK-EXEC: {-2.00, 1.00}
}

// CHECK: void fn10_grad(double u, double v, double *_d_u, double *_d_v) {
Expand Down Expand Up @@ -840,4 +847,27 @@ int main() {
// CHECK-NEXT: {{.*}}value_type _r0 = 0.;
// CHECK-NEXT: {{.*}}push_back_pullback(&_t0, 0{{.*}}, &_d_a, &_r0);
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: void fn22_grad(double u, double v, double *_d_u, double *_d_v) {
// CHECK-NEXT: {{.*}} _t0 = {{.*}}::class_functions::constructor_reverse_forw(clad::ConstructorReverseForwTag<{{.*}}> >(), {{.*u, v.*}}, {});
// CHECK-NEXT: std::vector<double> _d_ls(_t0.adjoint);
// CHECK-NEXT: std::vector<double> ls(_t0.value);
// CHECK-NEXT: std::vector<double> _t1 = ls;
// CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t2 = clad::custom_derivatives::class_functions::operator_subscript_reverse_forw(&ls, 1, &_d_ls, {{0U|0UL|0}});
// CHECK-NEXT: std::vector<double> _t4 = ls;
// CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t5 = clad::custom_derivatives::class_functions::operator_subscript_reverse_forw(&ls, 0, &_d_ls, {{0U|0UL|0}});
// CHECK-NEXT: {{.*}}value_type _t3 = _t5.value;
// CHECK-NEXT: {
// CHECK-NEXT: {{.*}}size_type _r1 = 0{{.*}};
// CHECK-NEXT: clad::custom_derivatives::class_functions::operator_subscript_pullback(&_t1, 1, 1, &_d_ls, &_r1);
// CHECK-NEXT: {{.*}}size_type _r2 = 0{{.*}};
// CHECK-NEXT: clad::custom_derivatives::class_functions::operator_subscript_pullback(&_t4, 0, 2 * -1, &_d_ls, &_r2);
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: clad::array<double> _r0 = {{2U|2UL|2ULL}};
// CHECK-NEXT: clad::custom_derivatives::class_functions::constructor_pullback(&ls, {u, v}, &_d_ls, &_r0);
// CHECK-NEXT: *_d_u += _r0[0];
// CHECK-NEXT: *_d_v += _r0[1];
// CHECK-NEXT: }
// CHECK-NEXT: }

0 comments on commit 938e0ef

Please sign in to comment.