Skip to content

Commit

Permalink
Introduce ActBeforeDifferentiatingLoopCondition to fix the error esti…
Browse files Browse the repository at this point in the history
…mation mode.
  • Loading branch information
PetroZarytskyi committed Jan 6, 2024
1 parent 6904f1a commit dea82f4
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 4 deletions.
1 change: 1 addition & 0 deletions include/clad/Differentiator/ErrorEstimator.h
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ class ErrorEstimationHandler : public ExternalRMVSource {
void ActOnEndOfDerivedFnBody() override;
void ActBeforeDifferentiatingStmtInVisitCompoundStmt() override;
void ActAfterProcessingStmtInVisitCompoundStmt() override;
void ActBeforeDifferentiatingLoopCondition() override;
void ActBeforeDifferentiatingSingleStmtBranchInVisitIfStmt() override;
void ActBeforeFinalisingVisitBranchSingleStmtInIfVisitStmt() override;
void ActBeforeDifferentiatingLoopInitStmt() override;
Expand Down
3 changes: 3 additions & 0 deletions include/clad/Differentiator/ExternalRMVSource.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,9 @@ class ExternalRMVSource {
/// branch in `VisitBranch` lambda in
virtual void ActBeforeFinalisingVisitBranchSingleStmtInIfVisitStmt() {}

/// This is called just before differentiating loop conditions.
virtual void ActBeforeDifferentiatingLoopCondition() {}

/// This is called just before differentiating init statement of loops.
virtual void ActBeforeDifferentiatingLoopInitStmt() {}

Expand Down
1 change: 1 addition & 0 deletions include/clad/Differentiator/MultiplexExternalRMVSource.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class MultiplexExternalRMVSource : public ExternalRMVSource {
void ActAfterProcessingStmtInVisitCompoundStmt() override;
void ActBeforeDifferentiatingSingleStmtBranchInVisitIfStmt() override;
void ActBeforeFinalisingVisitBranchSingleStmtInIfVisitStmt() override;
void ActBeforeDifferentiatingLoopCondition() override;
void ActBeforeDifferentiatingLoopInitStmt() override;
void ActBeforeDifferentiatingSingleStmtLoopBody() override;
void ActAfterProcessingSingleStmtBodyInVisitForLoop() override;
Expand Down
5 changes: 5 additions & 0 deletions lib/Differentiator/ErrorEstimator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,11 @@ void ErrorEstimationHandler::ActAfterProcessingStmtInVisitCompoundStmt() {
EmitErrorEstimationStmts(direction::reverse);
}

void ErrorEstimationHandler::
ActBeforeDifferentiatingLoopCondition() {
m_ShouldEmit.push(true);
}

void ErrorEstimationHandler::
ActBeforeDifferentiatingSingleStmtBranchInVisitIfStmt() {
m_ShouldEmit.push(true);
Expand Down
7 changes: 7 additions & 0 deletions lib/Differentiator/MultiplexExternalRMVSource.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,13 @@ void MultiplexExternalRMVSource::
}
}

void MultiplexExternalRMVSource::
ActBeforeDifferentiatingLoopCondition() {
for (auto source : m_Sources) {
source->ActBeforeDifferentiatingLoopCondition();
}
}

void MultiplexExternalRMVSource::ActBeforeDifferentiatingLoopInitStmt() {
for (auto source : m_Sources) {
source->ActBeforeDifferentiatingLoopInitStmt();
Expand Down
14 changes: 10 additions & 4 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3273,8 +3273,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
StmtDiff bodyDiff = nullptr;
StmtDiff condDiff = nullptr;
if (isa<CompoundStmt>(body)) {
if (cond)
if (cond) {
if (m_ExternalSource)
m_ExternalSource->ActBeforeDifferentiatingLoopCondition();
condDiff = DifferentiateSingleStmt(cond, /*dfdS=*/nullptr);
}
bodyDiff = Visit(body);

bodyDiff.updateStmt(utils::PrependAndCreateCompoundStmt(
Expand All @@ -3287,8 +3290,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
beginScope(Scope::DeclScope);
beginBlock(direction::forward);
addToCurrentBlock(counterIncrement);
if (cond)

if (cond) {
if (m_ExternalSource)
m_ExternalSource->ActBeforeDifferentiatingLoopCondition();
condDiff = DifferentiateSingleStmt(cond, /*dfdS=*/nullptr);
}
if (m_ExternalSource)
m_ExternalSource->ActBeforeDifferentiatingSingleStmtLoopBody();
bodyDiff = DifferentiateSingleStmt(body, /*dfdS=*/nullptr);
Expand Down Expand Up @@ -3338,10 +3345,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

// `for` loops have counter decrement expression in the
// loop iteration-expression.
if (!isForLoop) {
if (!isForLoop)
bodyDiff.updateStmtDx(utils::AppendAndCreateCompoundStmt(
m_Context, bodyDiff.getStmt_dx(), counterDecrement));
}
return {bodyDiff, condDiff};
}

Expand Down

0 comments on commit dea82f4

Please sign in to comment.