Skip to content

Commit

Permalink
Add support for differentiating switch stmt in the reverse mode AD.
Browse files Browse the repository at this point in the history
  • Loading branch information
parth-07 committed Feb 4, 2024
1 parent ed8422c commit fb1251e
Show file tree
Hide file tree
Showing 7 changed files with 980 additions and 28 deletions.
6 changes: 4 additions & 2 deletions include/clad/Differentiator/CladUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,9 @@ namespace clad {
llvm::SmallVectorImpl<clang::Expr*>& Exprs);

bool ContainsFunctionCalls(const clang::Stmt* E);
} // namespace utils
}

void SetSwitchCaseSubStmt(clang::SwitchCase* SC, clang::Stmt* subStmt);
} // namespace utils
} // namespace clad

#endif
34 changes: 30 additions & 4 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,9 @@ namespace clad {
StmtDiff
VisitMaterializeTemporaryExpr(const clang::MaterializeTemporaryExpr* MTE);
StmtDiff VisitCXXStaticCastExpr(const clang::CXXStaticCastExpr* SCE);
StmtDiff VisitSwitchStmt(const clang::SwitchStmt* SS);
StmtDiff VisitCaseStmt(const clang::CaseStmt* CS);
StmtDiff VisitDefaultStmt(const clang::DefaultStmt* DS);
VarDeclDiff DifferentiateVarDecl(const clang::VarDecl* VD);
StmtDiff VisitSubstNonTypeTemplateParmExpr(
const clang::SubstNonTypeTemplateParmExpr* NTTP);
Expand Down Expand Up @@ -483,7 +486,7 @@ namespace clad {
clang::Stmt* forLoopIncDiff = nullptr,
bool isForLoop = false);

/// This class modifies forward and reverse blocks of the loop
/// This class modifies forward and reverse blocks of the loop/switch
/// body so that `break` and `continue` statements are correctly
/// handled. `break` and `continue` statements are handled by
/// enclosing entire reverse block loop body in a switch statement
Expand Down Expand Up @@ -526,6 +529,7 @@ namespace clad {

ReverseModeVisitor& m_RMV;

const bool m_IsInvokedBySwitchStmt = false;
/// Builds and returns a literal expression of type `std::size_t` with
/// `value` as value.
clang::Expr* CreateSizeTLiteralExpr(std::size_t value);
Expand All @@ -540,7 +544,8 @@ namespace clad {
clang::Expr* CreateCFTapePushExpr(std::size_t value);

public:
BreakContStmtHandler(ReverseModeVisitor& RMV) : m_RMV(RMV) {}
BreakContStmtHandler(ReverseModeVisitor& RMV, bool forSwitchStmt = false)
: m_RMV(RMV), m_IsInvokedBySwitchStmt(forSwitchStmt) {}

/// Begins control flow switch statement scope.
/// Control flow switch statement is used to refer to the
Expand Down Expand Up @@ -572,8 +577,8 @@ namespace clad {
BreakContStmtHandler* GetActiveBreakContStmtHandler() {
return &m_BreakContStmtHandlers.back();
}
BreakContStmtHandler* PushBreakContStmtHandler() {
m_BreakContStmtHandlers.emplace_back(*this);
BreakContStmtHandler* PushBreakContStmtHandler(bool forSwitchStmt = false) {
m_BreakContStmtHandlers.emplace_back(*this, forSwitchStmt);
return &m_BreakContStmtHandlers.back();
}
void PopBreakContStmtHandler() {
Expand Down Expand Up @@ -607,6 +612,27 @@ namespace clad {

clang::QualType ComputeAdjointType(clang::QualType T);
clang::QualType ComputeParamType(clang::QualType T);
/// Stores data required for differentiating a switch statement.
class SwitchStmtInfo {
public:
llvm::SmallVector<clang::SwitchCase*, 16> cases;
clang::Expr* switchStmtCond = nullptr;
clang::IfStmt* defaultIfBreakExpr = nullptr;
};

/// Maintains a stack of `SwitchStmtInfo`.
llvm::SmallVector<SwitchStmtInfo, 4> m_SwitchStmtsData;

SwitchStmtInfo* GetActiveSwitchStmtInfo() {
return &m_SwitchStmtsData.back();
}

SwitchStmtInfo* PushSwitchStmtInfo() {
m_SwitchStmtsData.emplace_back();
return &m_SwitchStmtsData.back();
}

void PopSwitchStmtInfo() { m_SwitchStmtsData.pop_back(); }
};
} // end namespace clad

Expand Down
14 changes: 2 additions & 12 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1577,16 +1577,6 @@ static SwitchCase* getContainedSwitchCaseStmt(const CompoundStmt* CS) {
return nullptr;
}

static void setSwitchCaseSubStmt(SwitchCase* SC, Stmt* subStmt) {
if (auto caseStmt = dyn_cast<CaseStmt>(SC)) {
caseStmt->setSubStmt(subStmt);
} else if (auto defaultStmt = dyn_cast<DefaultStmt>(SC)) {
defaultStmt->setSubStmt(subStmt);
} else {
assert(0 && "Unsupported switch case statement");
}
}

/// Returns top switch statement in the `SwitchStack` of the given
/// Function Scope.
static SwitchStmt* getTopSwitchStmtOfSwitchStack(sema::FunctionScopeInfo* FSI) {
Expand Down Expand Up @@ -1659,7 +1649,7 @@ StmtDiff BaseForwardModeVisitor::VisitSwitchStmt(const SwitchStmt* SS) {
// been processed aka when all the statments in switch statement body
// have been processed.
if (activeSC) {
setSwitchCaseSubStmt(activeSC, endBlock());
utils::SetSwitchCaseSubStmt(activeSC, endBlock());
endScope();
activeSC = nullptr;
}
Expand Down Expand Up @@ -1687,7 +1677,7 @@ BaseForwardModeVisitor::DeriveSwitchStmtBodyHelper(const Stmt* stmt,
// corresponding to the active switch case label, and update its
// substatement.
if (activeSC) {
setSwitchCaseSubStmt(activeSC, endBlock());
utils::SetSwitchCaseSubStmt(activeSC, endBlock());
endScope();
}
// sub statement will be updated later, either when the corresponding
Expand Down
7 changes: 7 additions & 0 deletions lib/Differentiator/CladUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -632,5 +632,12 @@ namespace clad {
finder.TraverseStmt(const_cast<Stmt*>(S));
return finder.hasCallExpr;
}

void SetSwitchCaseSubStmt(SwitchCase* SC, Stmt* subStmt) {
if (auto caseStmt = dyn_cast<CaseStmt>(SC))
caseStmt->setSubStmt(subStmt);
else
cast<DefaultStmt>(SC)->setSubStmt(subStmt);
}
} // namespace utils
} // namespace clad
219 changes: 209 additions & 10 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "ConstantFolder.h"

#include "TBRAnalyzer.h"
#include "clad/Differentiator/DerivativeBuilder.h"
#include "clad/Differentiator/DiffPlanner.h"
#include "clad/Differentiator/ErrorEstimator.h"
#include "clad/Differentiator/ExternalRMVSource.h"
Expand All @@ -17,7 +18,9 @@

#include "clang/AST/ASTContext.h"
#include "clang/AST/Expr.h"
#include "clang/AST/Stmt.h"
#include "clang/AST/TemplateBase.h"
#include "clang/Basic/TokenKinds.h"
#include "clang/Sema/Lookup.h"
#include "clang/Sema/Overload.h"
#include "clang/Sema/Scope.h"
Expand Down Expand Up @@ -3237,6 +3240,211 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return {forwardDS, reverseBlock};
}

// Basic idea used for differentiating switch statement is that in the reverse
// pass, processing of the differentiated statments of the switch statement
// body should start either from a `break` statement or from the last
// statement of the switch statement body and always end at a switch
// case/default statement.
//
// Therefore, here we keep track of which `break` was hit in the forward pass,
// or if we no `break` statement was hit at all in a variable or clad tape.
// This information is further used by an auxilliary switch statement in the
// reverse pass to jump the execution to the correct point (that is,
// differentiated statement of the statement just before the `break` statement
// that was hit in the forward pass)
StmtDiff ReverseModeVisitor::VisitSwitchStmt(const SwitchStmt* SS) {
// Scope and blocks for the compound statement that encloses the switch
// statement in both the forward and the reverse pass. Block is required
// for handling condition variable and switch-init statement.
beginScope(Scope::DeclScope);
beginBlock(direction::forward);
beginBlock(direction::reverse);

// Handles switch init statement
if (SS->getInit()) {
StmtDiff switchInitDiff = DifferentiateSingleStmt(SS->getInit());
addToCurrentBlock(switchInitDiff.getStmt(), direction::forward);
addToCurrentBlock(switchInitDiff.getStmt_dx(), direction::reverse);
}

// Handles condition variable
if (SS->getConditionVariable()) {
StmtDiff condVarDiff =
DifferentiateSingleStmt(SS->getConditionVariableDeclStmt());
addToCurrentBlock(condVarDiff.getStmt(), direction::forward);
addToCurrentBlock(condVarDiff.getStmt_dx(), direction::reverse);
}

StmtDiff condDiff = DifferentiateSingleStmt(SS->getCond());
addToCurrentBlock(condDiff.getStmt(), direction::forward);
addToCurrentBlock(condDiff.getStmt_dx(), direction::reverse);
Expr* condExpr = nullptr;
clad_compat::llvm_Optional<CladTapeResult> condTape;

if (isInsideLoop) {
// If we are inside a loop, condition will be stored and used as follows:
//
// forward block:
// switch (clad::push(..., cond)) { ... }
//
// reverse block:
// switch (...) { ... }
// clad::pop(...);
condTape.emplace(MakeCladTapeFor(condDiff.getExpr(), "_cond"));
condExpr = condTape->Push;
} else {
condExpr = GlobalStoreAndRef(condDiff.getExpr(), "_cond").getExpr();
}

auto activeBreakContHandler = PushBreakContStmtHandler(
/*forSwitchStmt=*/true);
activeBreakContHandler->BeginCFSwitchStmtScope();
auto SSData = PushSwitchStmtInfo();

if (isInsideLoop)
SSData->switchStmtCond = condTape->Last();
else
SSData->switchStmtCond = condExpr;

// scope for the switch statement body.
beginScope(Scope::DeclScope);

const Stmt* body = SS->getBody();
StmtDiff bodyDiff = nullptr;
if (isa<CompoundStmt>(body))
bodyDiff = Visit(body);
else
bodyDiff = DifferentiateSingleStmt(body);

// Each switch case statement of the original function gets transformed to
// an if condition in the reverse pass. The if condition decides at runtime
// whether the processing of the differentiated statements of the switch
// statement body should stop or continue. This is based on the fact that
// processing of statements of switch statement body always starts at a case
// statement. For example,
// ```
// case 3:
// ```
// gets transformed to,
//
// ```
// if (3 == _cond)
// break;
// ```
//
// This kind of if expression cannot by easily formed for the default
// statement, thus, we instead compare value of the switch condition with
// the values of all the case statements to determine if the default
// statement was selected in the forward pass.
// Therefore,
//
// ```
// default:
// ```
//
// will get transformed to something like,
//
// ```
// if (_cond != 1 && _cond != 2 && _cond != 3)
// break;
// ```
if (SSData->defaultIfBreakExpr) {
Expr* breakCond = nullptr;
for (auto SC : SSData->cases) {
if (auto CS = dyn_cast<CaseStmt>(SC)) {
if (breakCond) {
breakCond = BuildOp(BinaryOperatorKind::BO_LAnd, breakCond,
BuildOp(BinaryOperatorKind::BO_NE,
SSData->switchStmtCond, CS->getLHS()));
} else {
breakCond = BuildOp(BinaryOperatorKind::BO_NE,
SSData->switchStmtCond, CS->getLHS());
}
}
}
if (!breakCond)
breakCond = m_Sema.ActOnCXXBoolLiteral(noLoc, tok::kw_true).get();
SSData->defaultIfBreakExpr->setCond(breakCond);
}

activeBreakContHandler->EndCFSwitchStmtScope();

// If switch statement contains no cases, then, no statement of the switch
// statement body will be processed in both the forward and the reverse
// pass. Thus, we do not need to add them in the differentiated function.
if (!(SSData->cases.empty())) {
Sema::ConditionResult condRes = m_Sema.ActOnCondition(
getCurrentScope(), noLoc, condExpr, Sema::ConditionKind::Switch);
SwitchStmt* forwardSS =
clad_compat::Sema_ActOnStartOfSwitchStmt(m_Sema, nullptr, condRes)
.getAs<SwitchStmt>();
activeBreakContHandler->UpdateForwAndRevBlocks(bodyDiff);

// Registers all the cases to the switch statement.
for (auto* SC : SSData->cases)
forwardSS->addSwitchCase(SC);

forwardSS =
m_Sema.ActOnFinishSwitchStmt(noLoc, forwardSS, bodyDiff.getStmt())
.getAs<SwitchStmt>();

addToCurrentBlock(forwardSS, direction::forward);
if (isInsideLoop)
addToCurrentBlock(condTape->Pop, direction::reverse);
addToCurrentBlock(bodyDiff.getStmt_dx(), direction::reverse);
}

PopBreakContStmtHandler();
PopSwitchStmtInfo();
return {endBlock(direction::forward), endBlock(direction::reverse)};
}

StmtDiff ReverseModeVisitor::VisitCaseStmt(const CaseStmt* CS) {
beginBlock(direction::forward);
beginBlock(direction::reverse);
auto SSData = GetActiveSwitchStmtInfo();

Expr* lhsClone = (CS->getLHS() ? Clone(CS->getLHS()) : nullptr);
Expr* rhsClone = (CS->getRHS() ? Clone(CS->getRHS()) : nullptr);

auto newSC = clad_compat::CaseStmt_Create(m_Sema.getASTContext(), lhsClone,
rhsClone, noLoc, noLoc, noLoc);

Expr* ifCond = BuildOp(BinaryOperatorKind::BO_EQ, newSC->getLHS(),
SSData->switchStmtCond);
Stmt* ifThen = m_Sema.ActOnBreakStmt(noLoc, getCurrentScope()).get();
Stmt* ifBreakExpr = clad_compat::IfStmt_Create(
m_Context, noLoc, false, nullptr, nullptr, ifCond, noLoc, noLoc, ifThen,
noLoc, nullptr);
SSData->cases.push_back(newSC);
addToCurrentBlock(ifBreakExpr, direction::reverse);
addToCurrentBlock(newSC, direction::forward);
auto diff = DifferentiateSingleStmt(CS->getSubStmt());
utils::SetSwitchCaseSubStmt(newSC, diff.getStmt());
addToCurrentBlock(diff.getStmt_dx(), direction::reverse);
return {endBlock(direction::forward), endBlock(direction::reverse)};
}

StmtDiff ReverseModeVisitor::VisitDefaultStmt(const DefaultStmt* DS) {
beginBlock(direction::reverse);
beginBlock(direction::forward);
auto* SSData = GetActiveSwitchStmtInfo();
DefaultStmt* newDefaultStmt =
new (m_Sema.getASTContext()) DefaultStmt(noLoc, noLoc, nullptr);
Stmt* ifThen = m_Sema.ActOnBreakStmt(noLoc, getCurrentScope()).get();
Stmt* ifBreakExpr = clad_compat::IfStmt_Create(
m_Context, noLoc, false, nullptr, nullptr, nullptr, noLoc, noLoc,
ifThen, noLoc, nullptr);
SSData->cases.push_back(newDefaultStmt);
SSData->defaultIfBreakExpr = cast<IfStmt>(ifBreakExpr);
addToCurrentBlock(ifBreakExpr, direction::reverse);
addToCurrentBlock(newDefaultStmt, direction::forward);
auto diff = DifferentiateSingleStmt(DS->getSubStmt());
utils::SetSwitchCaseSubStmt(newDefaultStmt, diff.getStmt());
addToCurrentBlock(diff.getStmt_dx(), direction::reverse);
return {endBlock(direction::forward), endBlock(direction::reverse)};
}

StmtDiff ReverseModeVisitor::DifferentiateLoopBody(const Stmt* body,
LoopCounter& loopCounter,
Stmt* condVarDiff,
Expand Down Expand Up @@ -3380,10 +3588,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

CaseStmt* ReverseModeVisitor::BreakContStmtHandler::GetNextCFCaseStmt() {
// End scope for currenly active case statement, if any.
if (!m_SwitchCases.empty())
m_RMV.endScope();

++m_CaseCounter;
auto* counterLiteral = CreateSizeTLiteralExpr(m_CaseCounter);
CaseStmt* CS = clad_compat::CaseStmt_Create(m_RMV.m_Context, counterLiteral,
Expand All @@ -3397,8 +3601,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// corresponding next statements.
CS->setSubStmt(m_RMV.m_Sema.ActOnNullStmt(noLoc).get());

// begin scope for the new active switch case statement.
m_RMV.beginScope(Scope::DeclScope);
m_SwitchCases.push_back(CS);
return CS;
}
Expand All @@ -3412,12 +3614,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

void ReverseModeVisitor::BreakContStmtHandler::UpdateForwAndRevBlocks(
StmtDiff& bodyDiff) {
if (m_SwitchCases.empty())
if (m_SwitchCases.empty() && !m_IsInvokedBySwitchStmt)
return;

// end scope for last switch case.
m_RMV.endScope();

// Add case statement in the beginning of the reverse block
// and corresponding push expression for this case statement
// at the end of the forward block to cover the case when no
Expand Down
Loading

0 comments on commit fb1251e

Please sign in to comment.