Skip to content

Commit

Permalink
Remove excessive stores for multiplication in reverse mode.
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Nov 28, 2023
1 parent 9d26a46 commit 05aabe0
Show file tree
Hide file tree
Showing 33 changed files with 876 additions and 1,609 deletions.
1 change: 1 addition & 0 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,7 @@ namespace clad {
clang::Stmt* Clone(const clang::Stmt* S);
/// A shorthand to simplify cloning of expressions.
clang::Expr* Clone(const clang::Expr* E);
/// A shorthand to clone references to declarations. For a reference to a
/// variable that has an entry in m_DeclReplacements, the returned
/// expression refers to the replacement declaration instead; the
/// reference is rebuilt when the referenced declaration's context differs
/// from the current Sema context.
clang::DeclRefExpr* Clone(const clang::DeclRefExpr* DRE);
/// Cloning types is necessary since a VariableArrayType
/// stores a pointer to its size expression.
clang::QualType CloneType(clang::QualType T);
Expand Down
34 changes: 25 additions & 9 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1378,6 +1378,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
addToCurrentBlock(add_assign, direction::reverse);
}
}
llvm::errs() << "\n\n\n";
DRE->dump();
llvm::errs() << "|/ |/ |/\n";
clonedDRE->dump();
Clone(DRE)->dump();
llvm::errs() << "–––––––––––\n\n\n";
return StmtDiff(clonedDRE, it->second, it->second);
}
}
Expand Down Expand Up @@ -2188,29 +2194,39 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// to reduce cloning complexity and only clones once. Storing it in a
// global variable allows to save current result and make it accessible
// in the reverse pass.
auto RDelayed = DelayedGlobalStoreAndRef(R);
StmtDiff RResult = RDelayed.Result;
std::unique_ptr<DelayedStoreResult> RDelayed;
Expr::EvalResult dummy;
bool RisNotConst = !clad_compat::Expr_EvaluateAsConstantExpr(R, dummy, m_Context);
StmtDiff RResult;
if (R->HasSideEffects(m_Context) && RisNotConst) {
RDelayed = std::unique_ptr<DelayedStoreResult>(new DelayedStoreResult(DelayedGlobalStoreAndRef(R)));
RResult = RDelayed->Result;
} else {
// RResult = StmtDiff(Clone(R));
// R->dump();
RResult = Visit(R);
}

Expr* dl = nullptr;
if (dfdx()) {
dl = BuildOp(BO_Mul, dfdx(), RResult.getExpr_dx());
dl = BuildOp(BO_Mul, dfdx(), RResult.getRevSweepAsExpr());
dl = StoreAndRef(dl, direction::reverse);
}
Ldiff = Visit(L, dl);
// dxi/xr = xl
// df/dxr += df/dxi * dxi/xr = df/dxi * xl
// Store left multiplier and assign it with L.
Expr* LStored = Ldiff.getExpr();
// RDelayed.isConstant == true implies that R is a constant expression,
// therefore we can skip visiting it.
if (!RDelayed.isConstant) {
if (RisNotConst) {
Expr* dr = nullptr;
if (dfdx()) {
dr = BuildOp(BO_Mul, Ldiff.getRevSweepAsExpr(), dfdx());
dr = StoreAndRef(dr, direction::reverse);
}
Rdiff = Visit(R, dr);
// Assign right multiplier's variable with R.
RDelayed.Finalize(Rdiff.getExpr());
if (RDelayed)
RDelayed->Finalize(Rdiff.getExpr());
}
std::tie(Ldiff, Rdiff) = std::make_pair(LStored, RResult.getExpr());
} else if (opCode == BO_Div) {
Expand Down Expand Up @@ -2243,7 +2259,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Rdiff = Visit(R, dr);
RDelayed.Finalize(Rdiff.getExpr());
}
std::tie(Ldiff, Rdiff) = std::make_pair(Ldiff.getRevSweepAsExpr(), RResult.getRevSweepAsExpr());
std::tie(Ldiff, Rdiff) = std::make_pair(Ldiff.getExpr(), RResult.getExpr());
} else if (BinOp->isAssignmentOp()) {
if (L->isModifiableLvalue(m_Context) != Expr::MLV_Valid) {
diag(DiagnosticsEngine::Warning,
Expand Down Expand Up @@ -2970,7 +2986,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Expr* Push = CladTape.Push;
Expr* Pop = CladTape.Pop;
return DelayedStoreResult{*this,
StmtDiff{Push, Pop},
StmtDiff{Push, Pop, nullptr, Pop},
/*isConstant*/ false,
/*isInsideLoop*/ true, /*pNeedsUpdate=*/ true};
}
Expand Down
23 changes: 23 additions & 0 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -298,9 +298,32 @@ namespace clad {
return clonedStmt;
}
/// Clones an expression. A DeclRefExpr is forwarded to the dedicated
/// DeclRefExpr overload; any other expression is cloned through the
/// Stmt overload and the result is cast back to Expr.
Expr* VisitorBase::Clone(const Expr* E) {
  const auto* declRef = dyn_cast<DeclRefExpr>(E);
  if (declRef != nullptr)
    return Clone(declRef);

  // Bind to the Stmt base first so that the Stmt overload is selected
  // rather than this one.
  const Stmt* asStmt = E;
  Stmt* clonedStmt = Clone(asStmt);
  return llvm::cast<Expr>(clonedStmt);
}
/// Clones a reference to a declaration.
///
/// References to anything other than a VarDecl are cloned through the
/// generic Stmt overload. For a VarDecl, a reference to the replacement
/// declaration is built when m_DeclReplacements has an entry for it;
/// otherwise the original reference is cloned as a plain statement.
///
/// \param[in] DRE the reference to clone.
/// \returns the cloned (or rebuilt) reference.
DeclRefExpr* VisitorBase::Clone(const DeclRefExpr* DRE) {
  const auto* VD = dyn_cast<VarDecl>(DRE->getDecl());
  if (!VD)
    return cast<DeclRefExpr>(Clone(cast<Stmt>(DRE)));

  DeclRefExpr* clonedDRE = nullptr;
  auto it = m_DeclReplacements.find(VD);
  if (it != std::end(m_DeclReplacements)) {
    clonedDRE = BuildDeclRef(it->second);
  } else {
    // Go through the Stmt overload explicitly. Calling Clone(DRE) here
    // would select this very overload again and recurse without end.
    clonedDRE = cast<DeclRefExpr>(Clone(cast<Stmt>(DRE)));
  }

  // If current context is different than the context of the original
  // declaration (e.g. we are inside lambda), rebuild the DeclRefExpr
  // with Sema::BuildDeclRefExpr. This is required in some cases, e.g.
  // Sema::BuildDeclRefExpr is responsible for adding captured fields
  // to the underlying struct of a lambda.
  if (clonedDRE->getDecl()->getDeclContext() != m_Sema.CurContext) {
    auto* referencedDecl = cast<VarDecl>(clonedDRE->getDecl());
    clonedDRE = cast<DeclRefExpr>(BuildDeclRef(referencedDecl));
  }
  return clonedDRE;
}

QualType VisitorBase::CloneType(const QualType QT) {
auto clonedType = m_Builder.m_NodeCloner->CloneType(QT);
Expand Down
45 changes: 18 additions & 27 deletions test/Arrays/ArrayInputsReverseMode.C
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,10 @@ float helper(float x) {
}

// CHECK: void helper_pullback(float x, float _d_y, clad::array_ref<float> _d_x) {
// CHECK-NEXT: float _t0;
// CHECK-NEXT: _t0 = x;
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: {
// CHECK-NEXT: float _r0 = _d_y * _t0;
// CHECK-NEXT: float _r0 = _d_y * x;
// CHECK-NEXT: float _r1 = 2 * _d_y;
// CHECK-NEXT: * _d_x += _r1;
// CHECK-NEXT: }
Expand Down Expand Up @@ -208,30 +206,26 @@ double func4(double x) {
}

//CHECK: void func4_grad(double x, clad::array_ref<double> _d_x) {
//CHECK-NEXT: double _t0;
//CHECK-NEXT: double _t1;
//CHECK-NEXT: clad::array<double> _d_arr(3UL);
//CHECK-NEXT: double _d_sum = 0;
//CHECK-NEXT: unsigned long _t2;
//CHECK-NEXT: unsigned long _t0;
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: clad::tape<double> _t3 = {};
//CHECK-NEXT: _t0 = x;
//CHECK-NEXT: _t1 = x;
//CHECK-NEXT: double arr[3] = {x, 2 * _t0, x * _t1};
//CHECK-NEXT: clad::tape<double> _t1 = {};
//CHECK-NEXT: double arr[3] = {x, 2 * x, x * x};
//CHECK-NEXT: double sum = 0;
//CHECK-NEXT: _t2 = 0;
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: for (int i = 0; i < 3; i++) {
//CHECK-NEXT: _t2++;
//CHECK-NEXT: clad::push(_t3, sum);
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, sum);
//CHECK-NEXT: sum += addArr(arr, 3);
//CHECK-NEXT: }
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: _d_sum += 1;
//CHECK-NEXT: for (; _t2; _t2--) {
//CHECK-NEXT: for (; _t0; _t0--) {
//CHECK-NEXT: i--;
//CHECK-NEXT: {
//CHECK-NEXT: sum = clad::pop(_t3);
//CHECK-NEXT: sum = clad::pop(_t1);
//CHECK-NEXT: double _r_d0 = _d_sum;
//CHECK-NEXT: _d_sum += _r_d0;
//CHECK-NEXT: int _grad1 = 0;
Expand All @@ -243,10 +237,10 @@ double func4(double x) {
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: * _d_x += _d_arr[0];
//CHECK-NEXT: double _r0 = _d_arr[1] * _t0;
//CHECK-NEXT: double _r0 = _d_arr[1] * x;
//CHECK-NEXT: double _r1 = 2 * _d_arr[1];
//CHECK-NEXT: * _d_x += _r1;
//CHECK-NEXT: double _r2 = _d_arr[2] * _t1;
//CHECK-NEXT: double _r2 = _d_arr[2] * x;
//CHECK-NEXT: * _d_x += _r2;
//CHECK-NEXT: double _r3 = x * _d_arr[2];
//CHECK-NEXT: * _d_x += _r3;
Expand Down Expand Up @@ -334,15 +328,14 @@ double func6(double seed) {
//CHECK-NEXT: double _d_sum = 0;
//CHECK-NEXT: unsigned long _t0;
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: clad::tape<int> _t1 = {};
//CHECK-NEXT: clad::array<double> _d_arr(3UL);
//CHECK-NEXT: clad::tape<double> _t2 = {};
//CHECK-NEXT: clad::tape<double> _t1 = {};
//CHECK-NEXT: double sum = 0;
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: for (int i = 0; i < 3; i++) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: double arr[3] = {seed, seed * clad::push(_t1, i), seed + i};
//CHECK-NEXT: clad::push(_t2, sum);
//CHECK-NEXT: double arr[3] = {seed, seed * i, seed + i};
//CHECK-NEXT: clad::push(_t1, sum);
//CHECK-NEXT: sum += addArr(arr, 3);
//CHECK-NEXT: }
//CHECK-NEXT: goto _label0;
Expand All @@ -351,7 +344,7 @@ double func6(double seed) {
//CHECK-NEXT: for (; _t0; _t0--) {
//CHECK-NEXT: i--;
//CHECK-NEXT: {
//CHECK-NEXT: sum = clad::pop(_t2);
//CHECK-NEXT: sum = clad::pop(_t1);
//CHECK-NEXT: double _r_d0 = _d_sum;
//CHECK-NEXT: _d_sum += _r_d0;
//CHECK-NEXT: int _grad1 = 0;
Expand All @@ -362,7 +355,7 @@ double func6(double seed) {
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: * _d_seed += _d_arr[0];
//CHECK-NEXT: double _r0 = _d_arr[1] * clad::pop(_t1);
//CHECK-NEXT: double _r0 = _d_arr[1] * i;
//CHECK-NEXT: * _d_seed += _r0;
//CHECK-NEXT: double _r1 = seed * _d_arr[1];
//CHECK-NEXT: _d_i += _r1;
Expand All @@ -379,15 +372,13 @@ double inv_square(double *params) {

//CHECK: void inv_square_pullback(double *params, double _d_y, clad::array_ref<double> _d_params) {
//CHECK-NEXT: double _t0;
//CHECK-NEXT: double _t1;
//CHECK-NEXT: _t1 = params[0];
//CHECK-NEXT: _t0 = (params[0] * _t1);
//CHECK-NEXT: _t0 = (params[0] * params[0]);
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: double _r0 = _d_y / _t0;
//CHECK-NEXT: double _r1 = _d_y * -1 / (_t0 * _t0);
//CHECK-NEXT: double _r2 = _r1 * _t1;
//CHECK-NEXT: double _r2 = _r1 * params[0];
//CHECK-NEXT: _d_params[0] += _r2;
//CHECK-NEXT: double _r3 = params[0] * _r1;
//CHECK-NEXT: _d_params[0] += _r3;
Expand Down
12 changes: 3 additions & 9 deletions test/Arrays/Arrays.C
Original file line number Diff line number Diff line change
Expand Up @@ -94,26 +94,20 @@ double const_dot_product(double x, double y, double z) {
//CHECK: void const_dot_product_grad(double x, double y, double z, clad::array_ref<double> _d_x, clad::array_ref<double> _d_y, clad::array_ref<double> _d_z) {
//CHECK-NEXT: clad::array<double> _d_vars(3UL);
//CHECK-NEXT: clad::array<double> _d_consts(3UL);
//CHECK-NEXT: double _t0;
//CHECK-NEXT: double _t1;
//CHECK-NEXT: double _t2;
//CHECK-NEXT: double vars[3] = {x, y, z};
//CHECK-NEXT: double consts[3] = {1, 2, 3};
//CHECK-NEXT: _t0 = consts[0];
//CHECK-NEXT: _t1 = consts[1];
//CHECK-NEXT: _t2 = consts[2];
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: double _r0 = 1 * _t0;
//CHECK-NEXT: double _r0 = 1 * consts[0];
//CHECK-NEXT: _d_vars[0] += _r0;
//CHECK-NEXT: double _r1 = vars[0] * 1;
//CHECK-NEXT: _d_consts[0] += _r1;
//CHECK-NEXT: double _r2 = 1 * _t1;
//CHECK-NEXT: double _r2 = 1 * consts[1];
//CHECK-NEXT: _d_vars[1] += _r2;
//CHECK-NEXT: double _r3 = vars[1] * 1;
//CHECK-NEXT: _d_consts[1] += _r3;
//CHECK-NEXT: double _r4 = 1 * _t2;
//CHECK-NEXT: double _r4 = 1 * consts[2];
//CHECK-NEXT: _d_vars[2] += _r4;
//CHECK-NEXT: double _r5 = vars[2] * 1;
//CHECK-NEXT: _d_consts[2] += _r5;
Expand Down
16 changes: 5 additions & 11 deletions test/ErrorEstimation/Assignments.C
Original file line number Diff line number Diff line change
Expand Up @@ -57,25 +57,21 @@ float func2(float x, int y) {
//CHECK-NEXT: float _t0;
//CHECK-NEXT: double _delta_x = 0;
//CHECK-NEXT: float _EERepl_x0 = x;
//CHECK-NEXT: float _t1;
//CHECK-NEXT: float _t2;
//CHECK-NEXT: float _EERepl_x1;
//CHECK-NEXT: _t0 = x;
//CHECK-NEXT: _t1 = x;
//CHECK-NEXT: _t2 = x;
//CHECK-NEXT: x = y * _t1 + x * _t2;
//CHECK-NEXT: x = y * x + x * x;
//CHECK-NEXT: _EERepl_x1 = x;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: * _d_x += 1;
//CHECK-NEXT: {
//CHECK-NEXT: x = _t0;
//CHECK-NEXT: float _r_d0 = * _d_x;
//CHECK-NEXT: float _r0 = _r_d0 * _t1;
//CHECK-NEXT: float _r0 = _r_d0 * x;
//CHECK-NEXT: * _d_y += _r0;
//CHECK-NEXT: float _r1 = y * _r_d0;
//CHECK-NEXT: * _d_x += _r1;
//CHECK-NEXT: float _r2 = _r_d0 * _t2;
//CHECK-NEXT: float _r2 = _r_d0 * x;
//CHECK-NEXT: * _d_x += _r2;
//CHECK-NEXT: float _r3 = x * _r_d0;
//CHECK-NEXT: * _d_x += _r3;
Expand Down Expand Up @@ -194,14 +190,12 @@ float func6(float x) { return x; }
float func7(float x, float y) { return (x * y); }

//CHECK: void func7_grad(float x, float y, clad::array_ref<float> _d_x, clad::array_ref<float> _d_y, double &_final_error) {
//CHECK-NEXT: float _t0;
//CHECK-NEXT: double _ret_value0 = 0;
//CHECK-NEXT: _t0 = y;
//CHECK-NEXT: _ret_value0 = (x * _t0);
//CHECK-NEXT: _ret_value0 = (x * y);
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: float _r0 = 1 * _t0;
//CHECK-NEXT: float _r0 = 1 * y;
//CHECK-NEXT: * _d_x += _r0;
//CHECK-NEXT: float _r1 = x * 1;
//CHECK-NEXT: * _d_y += _r1;
Expand Down
Loading

0 comments on commit 05aabe0

Please sign in to comment.