diff --git a/include/clad/Differentiator/CladUtils.h b/include/clad/Differentiator/CladUtils.h index 0e5a6d308..715707e0d 100644 --- a/include/clad/Differentiator/CladUtils.h +++ b/include/clad/Differentiator/CladUtils.h @@ -324,7 +324,9 @@ namespace clad { llvm::SmallVectorImpl& Exprs); bool ContainsFunctionCalls(const clang::Stmt* E); - } // namespace utils -} + + void SetSwitchCaseSubStmt(clang::SwitchCase* SC, clang::Stmt* subStmt); + } // namespace utils + } // namespace clad #endif diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index c0010bce3..9e0dc3ca5 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -289,7 +289,7 @@ namespace clad { llvm::StringRef prefix = "_t"); struct CladTapeResult { - ReverseModeVisitor& V; + ReverseModeVisitor& V; clang::Expr* Push; clang::Expr* Pop; clang::Expr* Ref; @@ -372,6 +372,9 @@ namespace clad { StmtDiff VisitMaterializeTemporaryExpr(const clang::MaterializeTemporaryExpr* MTE); StmtDiff VisitCXXStaticCastExpr(const clang::CXXStaticCastExpr* SCE); + StmtDiff VisitSwitchStmt(const clang::SwitchStmt* SS); + StmtDiff VisitCaseStmt(const clang::CaseStmt* CS); + StmtDiff VisitDefaultStmt(const clang::DefaultStmt* DS); VarDeclDiff DifferentiateVarDecl(const clang::VarDecl* VD); StmtDiff VisitSubstNonTypeTemplateParmExpr( const clang::SubstNonTypeTemplateParmExpr* NTTP); @@ -483,7 +486,8 @@ namespace clad { clang::Stmt* forLoopIncDiff = nullptr, bool isForLoop = false); - /// This class modifies forward and reverse blocks of the loop + + /// This class modifies forward and reverse blocks of the loop/switch /// body so that `break` and `continue` statements are correctly /// handled. `break` and `continue` statements are handled by /// enclosing entire reverse block loop body in a switch statement @@ -526,6 +530,7 @@ namespace clad { ReverseModeVisitor& m_RMV; + const bool m_IsInvokedBySwitchStmt = false; /// Builds and returns a literal expression of type `std::size_t` with /// `value` as value. clang::Expr* CreateSizeTLiteralExpr(std::size_t value); @@ -540,7 +545,8 @@ namespace clad { clang::Expr* CreateCFTapePushExpr(std::size_t value); public: - BreakContStmtHandler(ReverseModeVisitor& RMV) : m_RMV(RMV) {} + BreakContStmtHandler(ReverseModeVisitor& RMV, bool forSwitchStmt = false) + : m_RMV(RMV), m_IsInvokedBySwitchStmt(forSwitchStmt) {} /// Begins control flow switch statement scope. /// Control flow switch statement is used to refer to the @@ -572,8 +578,8 @@ namespace clad { BreakContStmtHandler* GetActiveBreakContStmtHandler() { return &m_BreakContStmtHandlers.back(); } - BreakContStmtHandler* PushBreakContStmtHandler() { - m_BreakContStmtHandlers.emplace_back(*this); + BreakContStmtHandler* PushBreakContStmtHandler(bool forSwitchStmt=false) { + m_BreakContStmtHandlers.emplace_back(*this, forSwitchStmt); return &m_BreakContStmtHandlers.back(); } void PopBreakContStmtHandler() { @@ -607,6 +613,29 @@ namespace clad { clang::QualType ComputeAdjointType(clang::QualType T); clang::QualType ComputeParamType(clang::QualType T); + /// Stores data required for differentiating a switch statement. + class SwitchStmtInfo { + public: + llvm::SmallVector cases; + clang::Expr* switchStmtCond = nullptr; + clang::IfStmt* defaultIfBreakExpr = nullptr; + }; + + /// Maintains a stack of `SwitchStmtInfo`. + llvm::SmallVector m_SwitchStmtsData; + + SwitchStmtInfo* GetActiveSwitchStmtInfo() { + return &m_SwitchStmtsData.back(); + } + + SwitchStmtInfo* PushSwitchStmtInfo() { + m_SwitchStmtsData.emplace_back(); + return &m_SwitchStmtsData.back(); + } + + void PopSwitchStmtInfo() { + m_SwitchStmtsData.pop_back(); + } }; } // end namespace clad diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index c6689ec6e..cc297e4d5 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -1577,16 +1577,6 @@ static SwitchCase* getContainedSwitchCaseStmt(const CompoundStmt* CS) { return nullptr; } -static void setSwitchCaseSubStmt(SwitchCase* SC, Stmt* subStmt) { - if (auto caseStmt = dyn_cast(SC)) { - caseStmt->setSubStmt(subStmt); - } else if (auto defaultStmt = dyn_cast(SC)) { - defaultStmt->setSubStmt(subStmt); - } else { - assert(0 && "Unsupported switch case statement"); - } -} - /// Returns top switch statement in the `SwitchStack` of the given /// Function Scope. static SwitchStmt* getTopSwitchStmtOfSwitchStack(sema::FunctionScopeInfo* FSI) { @@ -1659,7 +1649,7 @@ StmtDiff BaseForwardModeVisitor::VisitSwitchStmt(const SwitchStmt* SS) { // been processed aka when all the statments in switch statement body // have been processed. if (activeSC) { - setSwitchCaseSubStmt(activeSC, endBlock()); + utils::SetSwitchCaseSubStmt(activeSC, endBlock()); endScope(); activeSC = nullptr; } @@ -1687,7 +1677,7 @@ BaseForwardModeVisitor::DeriveSwitchStmtBodyHelper(const Stmt* stmt, // corresponding to the active switch case label, and update its // substatement. if (activeSC) { - setSwitchCaseSubStmt(activeSC, endBlock()); + utils::SetSwitchCaseSubStmt(activeSC, endBlock()); endScope(); } // sub statement will be updated later, either when the corresponding diff --git a/lib/Differentiator/CladUtils.cpp b/lib/Differentiator/CladUtils.cpp index 16da9e5cf..775e40776 100644 --- a/lib/Differentiator/CladUtils.cpp +++ b/lib/Differentiator/CladUtils.cpp @@ -632,5 +632,12 @@ namespace clad { finder.TraverseStmt(const_cast(S)); return finder.hasCallExpr; } + + void SetSwitchCaseSubStmt(SwitchCase* SC, Stmt* subStmt) { + if (auto caseStmt = dyn_cast(SC)) + caseStmt->setSubStmt(subStmt); + else + cast(SC)->setSubStmt(subStmt); + } } // namespace utils } // namespace clad diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index d0ec52a6c..be0b7a644 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -3237,6 +3237,214 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return {forwardDS, reverseBlock}; } + // Basic idea used for differentiating switch statement is that in the reverse + // pass, processing of the differentiated statments of the switch statement + // body should start either from a `break` statement or from the last + // statement of the switch statement body and always end at a switch + // case/default statement. + // + // Therefore, here we keep track of which `break` was hit in the forward pass, + // or if we no `break` statement was hit at all in a variable or clad tape. + // This information is further used by an auxilliary switch statement in the + // reverse pass to jump the execution to the correct point (that is, + // differentiated statement of the statement just before the `break` statement + // that was hit in the forward pass) + StmtDiff ReverseModeVisitor::VisitSwitchStmt(const SwitchStmt* SS) { + // Scope and blocks for the compound statement that encloses the switch statement + // in both the forward and the reverse pass. + // Block is required handling condition variable and switch-init statement. + beginScope(Scope::DeclScope); + beginBlock(direction::forward); + beginBlock(direction::reverse); + + // Handles switch init statement + if (SS->getInit()) { + StmtDiff switchInitDiff = DifferentiateSingleStmt(SS->getInit()); + addToCurrentBlock(switchInitDiff.getStmt(), direction::forward); + addToCurrentBlock(switchInitDiff.getStmt_dx(), direction::reverse); + } + + // Handles condition variable + if (SS->getConditionVariable()) { + StmtDiff condVarDiff = DifferentiateSingleStmt(SS->getConditionVariableDeclStmt()); + addToCurrentBlock(condVarDiff.getStmt(), direction::forward); + addToCurrentBlock(condVarDiff.getStmt_dx(), direction::reverse); + } + // Condition is only cloned, and not differentiated. + // Its because conditions generally contain non-differentiable constructs, but + // this behaviour will lead to incorrect results if the condition expression + // modifies any variable. + Expr* condClone = (SS->getCond() ? Clone(SS->getCond()) : nullptr); + + Expr* condExpr = nullptr; + llvm::Optional condTape; + + if (isInsideLoop) { + // If we are inside a loop, condition will be stored and used as follows: + // + // forward block: + // switch (clad::push(..., cond)) { ... } + // + // reverse block: + // switch (...) { ... } + // clad::pop(...); + condTape.emplace(MakeCladTapeFor(condClone, "_cond")); + condExpr = condTape->Push; + } else { + condExpr = GlobalStoreAndRef(condClone, "_cond").getExpr(); + } + + auto activeBreakContHandler = PushBreakContStmtHandler( + /*forSwitchStmt=*/true); + activeBreakContHandler->BeginCFSwitchStmtScope(); + auto SSData = PushSwitchStmtInfo(); + + if (isInsideLoop) + SSData->switchStmtCond = condTape->Last(); + else + SSData->switchStmtCond = condExpr; + + // scope for the switch statement body. + beginScope(Scope::DeclScope); + + const Stmt* body = SS->getBody(); + StmtDiff bodyDiff = nullptr; + if (isa(body)) { + bodyDiff = Visit(body); + } else { + bodyDiff = DifferentiateSingleStmt(body); + } + + // Each switch case statement of the original function gets transformed to + // an if condition in the reverse pass. The if condition decides at runtime + // whether the processing of the differentiated statements of the switch statement + // body should stop or continue. This is based on the fact that processing + // of statements of switch statement body always starts at a case statement. + // For example, + // ``` + // case 3: + // ``` + // gets transformed to, + // + // ``` + // if (3 == _cond) + // break; + // ``` + // + // This kind of if expression cannot by easily formed for the default + // statement, thus, we instead compare value of the switch condition with + // the values of all the case statements to determine if the default + // statement was selected in the forward pass. + // Therefore, + // + // ``` + // default: + // ``` + // + // will get transformed to something like, + // + // ``` + // if (_cond != 1 && _cond != 2 && _cond != 3) + // break; + // ``` + if (SSData->defaultIfBreakExpr) { + Expr* breakCond = nullptr; + for (auto SC : SSData->cases) { + if (auto CS = dyn_cast(SC)) { + if (breakCond) { + breakCond = BuildOp(BinaryOperatorKind::BO_LAnd, breakCond, + BuildOp(BinaryOperatorKind::BO_NE, + SSData->switchStmtCond, CS->getLHS())); + } else { + breakCond = BuildOp(BinaryOperatorKind::BO_NE, SSData->switchStmtCond, CS->getLHS()); + } + } + } + SSData->defaultIfBreakExpr->setCond(breakCond); + } + + activeBreakContHandler->EndCFSwitchStmtScope(); + + // If switch statement contains no cases, then, no statement of the switch statement body + // will be processed in both the forward and the reverse pass. Thus, we do not need + // to add them in the differentiated function. + if (!(SSData->cases.empty())) { + Sema::ConditionResult condRes = m_Sema.ActOnCondition( + getCurrentScope(), noLoc, condExpr, Sema::ConditionKind::Switch); + SwitchStmt* forwardSS = + clad_compat::Sema_ActOnStartOfSwitchStmt(m_Sema, nullptr, condRes) + .getAs(); + activeBreakContHandler->UpdateForwAndRevBlocks(bodyDiff); + + // Registers all the cases to the switch statement. + for (auto SC : SSData->cases) { + forwardSS->addSwitchCase(SC); + } + + forwardSS = m_Sema + .ActOnFinishSwitchStmt(noLoc, forwardSS, + bodyDiff.getStmt()) + .getAs(); + + addToCurrentBlock(forwardSS, direction::forward); + if (isInsideLoop) { + addToCurrentBlock(condTape->Pop, direction::reverse); + } + addToCurrentBlock(bodyDiff.getStmt_dx(), direction::reverse); + } + + PopBreakContStmtHandler(); + PopSwitchStmtInfo(); + return {endBlock(direction::forward), endBlock(direction::reverse)}; + } + + StmtDiff ReverseModeVisitor::VisitCaseStmt(const CaseStmt* CS) { + beginBlock(direction::forward); + beginBlock(direction::reverse); + auto SSData = GetActiveSwitchStmtInfo(); + + Expr* lhsClone = (CS->getLHS() ? Clone(CS->getLHS()) : nullptr); + Expr* rhsClone = (CS->getRHS() ? Clone(CS->getRHS()) : nullptr); + + auto newSC = clad_compat::CaseStmt_Create(m_Sema.getASTContext(), lhsClone, + rhsClone, noLoc, noLoc, noLoc); + + Expr* ifCond = BuildOp(BinaryOperatorKind::BO_EQ, newSC->getLHS(), + SSData->switchStmtCond); + Stmt* ifThen = m_Sema.ActOnBreakStmt(noLoc, getCurrentScope()).get(); + Stmt* ifBreakExpr = clad_compat::IfStmt_Create(m_Context, noLoc, false, + nullptr, nullptr, ifCond, + noLoc, noLoc, ifThen, noLoc, + nullptr); + SSData->cases.push_back(newSC); + addToCurrentBlock(ifBreakExpr, direction::reverse); + addToCurrentBlock(newSC, direction::forward); + auto diff = DifferentiateSingleStmt(CS->getSubStmt()); + utils::SetSwitchCaseSubStmt(newSC, diff.getStmt()); + addToCurrentBlock(diff.getStmt_dx(), direction::reverse); + return {endBlock(direction::forward), endBlock(direction::reverse)}; + } + + StmtDiff ReverseModeVisitor::VisitDefaultStmt(const DefaultStmt* DS) { + beginBlock(direction::reverse); + beginBlock(direction::forward); + auto SSData = GetActiveSwitchStmtInfo(); + auto newDefaultStmt = new (m_Sema.getASTContext()) DefaultStmt(noLoc, noLoc, nullptr); + Stmt* ifThen = m_Sema.ActOnBreakStmt(noLoc, getCurrentScope()).get(); + Stmt* ifBreakExpr = clad_compat::IfStmt_Create(m_Context, noLoc, false, + nullptr, nullptr, nullptr, + noLoc, noLoc, ifThen, noLoc, + nullptr); + SSData->cases.push_back(newDefaultStmt); + SSData->defaultIfBreakExpr = cast(ifBreakExpr); + addToCurrentBlock(ifBreakExpr, direction::reverse); + addToCurrentBlock(newDefaultStmt, direction::forward); + auto diff = DifferentiateSingleStmt(DS->getSubStmt()); + utils::SetSwitchCaseSubStmt(newDefaultStmt, diff.getStmt()); + addToCurrentBlock(diff.getStmt_dx(), direction::reverse); + return {endBlock(direction::forward), endBlock(direction::reverse)}; + } + StmtDiff ReverseModeVisitor::DifferentiateLoopBody(const Stmt* body, LoopCounter& loopCounter, Stmt* condVarDiff, @@ -3380,10 +3588,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } CaseStmt* ReverseModeVisitor::BreakContStmtHandler::GetNextCFCaseStmt() { - // End scope for currenly active case statement, if any. - if (!m_SwitchCases.empty()) - m_RMV.endScope(); - ++m_CaseCounter; auto* counterLiteral = CreateSizeTLiteralExpr(m_CaseCounter); CaseStmt* CS = clad_compat::CaseStmt_Create(m_RMV.m_Context, counterLiteral, @@ -3397,8 +3601,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // corresponding next statements. CS->setSubStmt(m_RMV.m_Sema.ActOnNullStmt(noLoc).get()); - // begin scope for the new active switch case statement. - m_RMV.beginScope(Scope::DeclScope); m_SwitchCases.push_back(CS); return CS; } @@ -3412,12 +3614,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, void ReverseModeVisitor::BreakContStmtHandler::UpdateForwAndRevBlocks( StmtDiff& bodyDiff) { - if (m_SwitchCases.empty()) + if (m_SwitchCases.empty() && !m_IsInvokedBySwitchStmt) return; - // end scope for last switch case. - m_RMV.endScope(); - // Add case statement in the beginning of the reverse block // and corresponding push expression for this case statement // at the end of the forward block to cover the case when no diff --git a/test/Gradient/Switch.C b/test/Gradient/Switch.C new file mode 100644 index 000000000..c4b0286f2 --- /dev/null +++ b/test/Gradient/Switch.C @@ -0,0 +1,533 @@ +// RUN: %cladclang %s -I%S/../../include -oSwitch.out 2>&1 -lstdc++ -lm | FileCheck %s +// RUN: ./Switch.out | FileCheck -check-prefix=CHECK-EXEC %s +//CHECK-NOT: {{.*error|warning|note:.*}} + +#include "clad/Differentiator/Differentiator.h" + +double fn1(double i, double j) { + double res = 0; + int count = 1; + switch (count) { + case 0: res += i * j; break; + case 1: res += i * i; { + case 2: res += j * j; + } + default: res += i * i * j * j; + } + return res; +} + +// CHECK: void fn1_grad(double i, double j, clad::array_ref _d_i, clad::array_ref _d_j) { +// CHECK-NEXT: double _d_res = 0; +// CHECK-NEXT: int _d_count = 0; +// CHECK-NEXT: int _cond0; +// CHECK-NEXT: double _t0; +// CHECK-NEXT: clad::tape _t1 = {}; +// CHECK-NEXT: double _t2; +// CHECK-NEXT: double _t3; +// CHECK-NEXT: double _t4; +// CHECK-NEXT: double res = 0; +// CHECK-NEXT: int count = 1; +// CHECK-NEXT: { +// CHECK-NEXT: _cond0 = count; +// CHECK-NEXT: switch (_cond0) { +// CHECK-NEXT: { +// CHECK-NEXT: case 0: +// CHECK-NEXT: res += i * j; +// CHECK-NEXT: _t0 = res; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: clad::push(_t1, 1UL); +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: case 1: +// CHECK-NEXT: res += i * i; +// CHECK-NEXT: _t2 = res; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: case 2: +// CHECK-NEXT: res += j * j; +// CHECK-NEXT: _t3 = res; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: default: +// CHECK-NEXT: res += i * i * j * j; +// CHECK-NEXT: _t4 = res; +// CHECK-NEXT: } +// CHECK-NEXT: clad::push(_t1, 2UL); +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: _d_res += 1; +// CHECK-NEXT: { +// CHECK-NEXT: switch (clad::pop(_t1)) { +// CHECK-NEXT: case 2UL: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = _t4; +// CHECK-NEXT: double _r_d3 = _d_res; +// CHECK-NEXT: * _d_i += _r_d3 * j * j * i; +// CHECK-NEXT: * _d_i += i * _r_d3 * j * j; +// CHECK-NEXT: * _d_j += i * i * _r_d3 * j; +// CHECK-NEXT: * _d_j += i * i * j * _r_d3; +// CHECK-NEXT: } +// CHECK-NEXT: if (_cond0 != 0 && _cond0 != 1 && _cond0 != 2) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = _t3; +// CHECK-NEXT: double _r_d2 = _d_res; +// CHECK-NEXT: * _d_j += _r_d2 * j; +// CHECK-NEXT: * _d_j += j * _r_d2; +// CHECK-NEXT: } +// CHECK-NEXT: if (2 == _cond0) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = _t2; +// CHECK-NEXT: double _r_d1 = _d_res; +// CHECK-NEXT: * _d_i += _r_d1 * i; +// CHECK-NEXT: * _d_i += i * _r_d1; +// CHECK-NEXT: } +// CHECK-NEXT: if (1 == _cond0) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: case 1UL: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = _t0; +// CHECK-NEXT: double _r_d0 = _d_res; +// CHECK-NEXT: * _d_i += _r_d0 * j; +// CHECK-NEXT: * _d_j += i * _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: if (0 == _cond0) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + +double fn2(double i, double j) { + double res = 0; + switch (int count = 2) { + res += i * i * j * j; + res += 50 * i; + case 0: res += i; break; + case 1: res += j; + case 2: res += i * j; break; + default: res += i + j; + } + return res; +} + +// CHECK: void fn2_grad(double i, double j, clad::array_ref _d_i, clad::array_ref _d_j) { +// CHECK-NEXT: double _d_res = 0; +// CHECK-NEXT: int _d_count = 0; +// CHECK-NEXT: int _cond0; +// CHECK-NEXT: double _t0; +// CHECK-NEXT: double _t1; +// CHECK-NEXT: double _t2; +// CHECK-NEXT: clad::tape _t3 = {}; +// CHECK-NEXT: double _t4; +// CHECK-NEXT: double _t5; +// CHECK-NEXT: double _t6; +// CHECK-NEXT: double res = 0; +// CHECK-NEXT: { +// CHECK-NEXT: int count = 2; +// CHECK-NEXT: _cond0 = count; +// CHECK-NEXT: switch (_cond0) { +// CHECK-NEXT: _t0 = res; +// CHECK-NEXT: res += i * i * j * j; +// CHECK-NEXT: _t1 = res; +// CHECK-NEXT: res += 50 * i; +// CHECK-NEXT: { +// CHECK-NEXT: case 0: +// CHECK-NEXT: res += i; +// CHECK-NEXT: _t2 = res; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: clad::push(_t3, 1UL); +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: case 1: +// CHECK-NEXT: res += j; +// CHECK-NEXT: _t4 = res; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: case 2: +// CHECK-NEXT: res += i * j; +// CHECK-NEXT: _t5 = res; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: clad::push(_t3, 2UL); +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: default: +// CHECK-NEXT: res += i + j; +// CHECK-NEXT: _t6 = res; +// CHECK-NEXT: } +// CHECK-NEXT: clad::push(_t3, 3UL); +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: _d_res += 1; +// CHECK-NEXT: { +// CHECK-NEXT: switch (clad::pop(_t3)) { +// CHECK-NEXT: case 3UL: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = _t6; +// CHECK-NEXT: double _r_d5 = _d_res; +// CHECK-NEXT: * _d_i += _r_d5; +// CHECK-NEXT: * _d_j += _r_d5; +// CHECK-NEXT: } +// CHECK-NEXT: if (_cond0 != 0 && _cond0 != 1 && _cond0 != 2) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: case 2UL: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = _t5; +// CHECK-NEXT: double _r_d4 = _d_res; +// CHECK-NEXT: * _d_i += _r_d4 * j; +// CHECK-NEXT: * _d_j += i * _r_d4; +// CHECK-NEXT: } +// CHECK-NEXT: if (2 == _cond0) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = _t4; +// CHECK-NEXT: double _r_d3 = _d_res; +// CHECK-NEXT: * _d_j += _r_d3; +// CHECK-NEXT: } +// CHECK-NEXT: if (1 == _cond0) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: case 1UL: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = _t2; +// CHECK-NEXT: double _r_d2 = _d_res; +// CHECK-NEXT: * _d_i += _r_d2; +// CHECK-NEXT: } +// CHECK-NEXT: if (0 == _cond0) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: res = _t1; +// CHECK-NEXT: double _r_d1 = _d_res; +// CHECK-NEXT: * _d_i += 50 * _r_d1; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: res = _t0; +// CHECK-NEXT: double _r_d0 = _d_res; +// CHECK-NEXT: * _d_i += _r_d0 * j * j * i; +// CHECK-NEXT: * _d_i += i * _r_d0 * j * j; +// CHECK-NEXT: * _d_j += i * i * _r_d0 * j; +// CHECK-NEXT: * _d_j += i * i * j * _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + +double fn3(double i, double j) { + double res = 0; + int counter = 2; + while (counter--) { + switch (counter) { + case 0: res += i * i * j * j; + case 1: { + res += i * i; + } break; + case 2: res += j * j; + default: res += i + j; + } + } + return res; +} + +// CHECK: void fn3_grad(double i, double j, clad::array_ref _d_i, clad::array_ref _d_j) { +// CHECK-NEXT: double _d_res = 0; +// CHECK-NEXT: int _d_counter = 0; +// CHECK-NEXT: unsigned long _t0; +// CHECK-NEXT: clad::tape _cond0 = {}; +// CHECK-NEXT: clad::tape _t1 = {}; +// CHECK-NEXT: clad::tape _t2 = {}; +// CHECK-NEXT: clad::tape _t3 = {}; +// CHECK-NEXT: clad::tape _t4 = {}; +// CHECK-NEXT: clad::tape _t5 = {}; +// CHECK-NEXT: double res = 0; +// CHECK-NEXT: int counter = 2; +// CHECK-NEXT: _t0 = 0; +// CHECK-NEXT: while (counter--) +// CHECK-NEXT: { +// CHECK-NEXT: _t0++; +// CHECK-NEXT: { +// CHECK-NEXT: switch (clad::push(_cond0, counter)) { +// CHECK-NEXT: { +// CHECK-NEXT: case 0: +// CHECK-NEXT: res += i * i * j * j; +// CHECK-NEXT: clad::push(_t1, res); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: case 1: +// CHECK-NEXT: { +// CHECK-NEXT: clad::push(_t2, res); +// CHECK-NEXT: res += i * i; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: clad::push(_t3, 1UL); +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: case 2: +// CHECK-NEXT: res += j * j; +// CHECK-NEXT: clad::push(_t4, res); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: default: +// CHECK-NEXT: res += i + j; +// CHECK-NEXT: clad::push(_t5, res); +// CHECK-NEXT: } +// CHECK-NEXT: clad::push(_t3, 2UL); +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: _d_res += 1; +// CHECK-NEXT: while (_t0) +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: switch (clad::pop(_t3)) { +// CHECK-NEXT: case 2UL: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = clad::pop(_t5); +// CHECK-NEXT: double _r_d3 = _d_res; +// CHECK-NEXT: * _d_i += _r_d3; +// CHECK-NEXT: * _d_j += _r_d3; +// CHECK-NEXT: } +// CHECK-NEXT: if (clad::back(_cond0) != 0 && clad::back(_cond0) != 1 && clad::back(_cond0) != 2) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = clad::pop(_t4); +// CHECK-NEXT: double _r_d2 = _d_res; +// CHECK-NEXT: * _d_j += _r_d2 * j; +// CHECK-NEXT: * _d_j += j * _r_d2; +// CHECK-NEXT: } +// CHECK-NEXT: if (2 == clad::back(_cond0)) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: case 1UL: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = clad::pop(_t2); +// CHECK-NEXT: double _r_d1 = _d_res; +// CHECK-NEXT: * _d_i += _r_d1 * i; +// CHECK-NEXT: * _d_i += i * _r_d1; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: if (1 == clad::back(_cond0)) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = clad::pop(_t1); +// CHECK-NEXT: double _r_d0 = _d_res; +// CHECK-NEXT: * _d_i += _r_d0 * j * j * i; +// CHECK-NEXT: * _d_i += i * _r_d0 * j * j; +// CHECK-NEXT: * _d_j += i * i * _r_d0 * j; +// CHECK-NEXT: * _d_j += i * i * j * _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: if (0 == clad::back(_cond0)) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: clad::pop(_cond0); +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: _t0--; +// CHECK-NEXT: } +// CHECK-NEXT: } + +double fn4(double i, double j) { + double res = 0; + switch (1) { + case 0: res += i * i * j * j; break; + case 1: + int counter = 2; + while (counter--) { + res += i * j; + } + break; + } + return res; +} + +// CHECK: void fn4_grad(double i, double j, clad::array_ref _d_i, clad::array_ref _d_j) { +// CHECK-NEXT: double _d_res = 0; +// CHECK-NEXT: double _t0; +// CHECK-NEXT: clad::tape _t1 = {}; +// CHECK-NEXT: int _d_counter = 0; +// CHECK-NEXT: unsigned long _t2; +// CHECK-NEXT: clad::tape _t3 = {}; +// CHECK-NEXT: double res = 0; +// CHECK-NEXT: { +// CHECK-NEXT: switch (1) { +// CHECK-NEXT: { +// CHECK-NEXT: case 0: +// CHECK-NEXT: res += i * i * j * j; +// CHECK-NEXT: _t0 = res; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: clad::push(_t1, 1UL); +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: case 1: +// CHECK-NEXT: int counter = 2; +// CHECK-NEXT: } +// CHECK-NEXT: _t2 = 0; +// CHECK-NEXT: while (counter--) +// CHECK-NEXT: { +// CHECK-NEXT: _t2++; +// CHECK-NEXT: clad::push(_t3, res); +// CHECK-NEXT: res += i * j; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: clad::push(_t1, 2UL); +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: clad::push(_t1, 3UL); +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: _d_res += 1; +// CHECK-NEXT: { +// CHECK-NEXT: switch (clad::pop(_t1)) { +// CHECK-NEXT: case 3UL: +// CHECK-NEXT: ; +// CHECK-NEXT: case 2UL: +// CHECK-NEXT: ; +// CHECK-NEXT: while (_t2) +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = clad::pop(_t3); +// CHECK-NEXT: double _r_d1 = _d_res; +// CHECK-NEXT: * _d_i += _r_d1 * j; +// CHECK-NEXT: * _d_j += i * _r_d1; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: _t2--; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: if (1 == 1) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: case 1UL: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = _t0; +// CHECK-NEXT: double _r_d0 = _d_res; +// CHECK-NEXT: * _d_i += _r_d0 * j * j * i; +// CHECK-NEXT: * _d_i += i * _r_d0 * j * j; +// CHECK-NEXT: * _d_j += i * i * _r_d0 * j; +// CHECK-NEXT: * _d_j += i * i * j * _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: if (0 == 1) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + +double fn5(double i, double j) { + double res=0; + switch(int count = 1) + case 1: + res += i*j; + return res; +} + +// CHECK: void fn5_grad(double i, double j, clad::array_ref _d_i, clad::array_ref _d_j) { +// CHECK-NEXT: double _d_res = 0; +// CHECK-NEXT: int _d_count = 0; +// CHECK-NEXT: int _cond0; +// CHECK-NEXT: double _t0; +// CHECK-NEXT: clad::tape _t1 = {}; +// CHECK-NEXT: double res = 0; +// CHECK-NEXT: { +// CHECK-NEXT: int count = 1; +// CHECK-NEXT: _cond0 = count; +// CHECK-NEXT: switch (_cond0) { +// CHECK-NEXT: case 1: +// CHECK-NEXT: res += i * j; +// CHECK-NEXT: _t0 = res; +// CHECK-NEXT: clad::push(_t1, 1UL); +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: _d_res += 1; +// CHECK-NEXT: { +// CHECK-NEXT: switch (clad::pop(_t1)) { +// CHECK-NEXT: case 1UL: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: res = _t0; +// CHECK-NEXT: double _r_d0 = _d_res; +// CHECK-NEXT: * _d_i += _r_d0 * j; +// CHECK-NEXT: * _d_j += i * _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: if (1 == _cond0) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + +#define TEST_2(F, x, y) \ + { \ + result[0] = result[1] = 0; \ + auto d_##F = clad::gradient(F); \ + d_##F.execute(x, y, result, result + 1); \ + printf("{%.2f, %.2f}\n", result[0], result[1]); \ + } + +int main() { + double result[2] = {}; + clad::array_ref result_ref(result, 2); + + TEST_2(fn1, 3, 5); // CHECK-EXEC: {156.00, 100.00} + TEST_2(fn2, 3, 5); // CHECK-EXEC: {5.00, 3.00} + TEST_2(fn3, 3, 5); // CHECK-EXEC: {162.00, 90.00} + TEST_2(fn4, 3, 5); // CHECK-EXEC: {10.00, 6.00} + TEST_2(fn5, 3, 5); // CHECK-EXEC: {5.00, 3.00} +} \ No newline at end of file diff --git a/test/Gradient/SwitchInit.C b/test/Gradient/SwitchInit.C new file mode 100644 index 000000000..355b91855 --- /dev/null +++ b/test/Gradient/SwitchInit.C @@ -0,0 +1,132 @@ +// RUN: %cladclang %s -I%S/../../include -std=c++17 -oSwitchInit.out 2>&1 -lstdc++ -lm | FileCheck %s +// RUN: ./SwitchInit.out | FileCheck -check-prefix=CHECK-EXEC %s +//CHECK-NOT: {{.*error|warning|note:.*}} + +#include "clad/Differentiator/Differentiator.h" + +double fn1(double i, double j) { + double res = 0; + switch (int count = 1;count) { + case 0: res += i * j; break; + case 1: res += i * i; { + case 2: res += j * j; + } + default: res += i * i * j * j; + } + return res; +} + +// CHECK: void fn1_grad(double i, double j, clad::array_ref _d_i, clad::array_ref _d_j) { +// CHECK-NEXT: double _d_res = 0; +// CHECK-NEXT: int _d_count = 0; +// CHECK-NEXT: int _cond0; +// CHECK-NEXT: double _t0; +// CHECK-NEXT: clad::tape _t1 = {}; +// CHECK-NEXT: double _t2; +// CHECK-NEXT: double _t3; +// CHECK-NEXT: double _t4; +// CHECK-NEXT: double res = 0; +// CHECK-NEXT: { +// CHECK-NEXT: int count = 1; +// CHECK-NEXT: _cond0 = count; +// CHECK-NEXT: switch (_cond0) { +// CHECK-NEXT: { +// CHECK-NEXT: case 0: +// CHECK-NEXT: res += i * j; +// CHECK-NEXT: _t0 = res; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: clad::push(_t1, 1UL); +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: case 1: +// CHECK-NEXT: res += i * i; +// CHECK-NEXT: _t2 = res; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: case 2: +// CHECK-NEXT: res += j * j; +// CHECK-NEXT: _t3 = res; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: default: +// CHECK-NEXT: res += i * i * j * j; +// CHECK-NEXT: _t4 = res; +// CHECK-NEXT: } +// CHECK-NEXT: clad::push(_t1, 2UL); +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: _d_res += 1; +// CHECK-NEXT: { +// CHECK-NEXT: switch (clad::pop(_t1)) { +// CHECK-NEXT: case 2UL: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = _t4; +// CHECK-NEXT: double _r_d3 = _d_res; +// CHECK-NEXT: * _d_i += _r_d3 * j * j * i; +// CHECK-NEXT: * _d_i += i * _r_d3 * j * j; +// CHECK-NEXT: * _d_j += i * i * _r_d3 * j; +// CHECK-NEXT: * _d_j += i * i * j * _r_d3; +// CHECK-NEXT: } +// CHECK-NEXT: if (_cond0 != 0 && _cond0 != 1 && _cond0 != 2) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = _t3; +// CHECK-NEXT: double _r_d2 = _d_res; +// CHECK-NEXT: * _d_j += _r_d2 * j; +// CHECK-NEXT: * _d_j += j * _r_d2; +// CHECK-NEXT: } +// CHECK-NEXT: if (2 == _cond0) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = _t2; +// CHECK-NEXT: double _r_d1 = _d_res; +// CHECK-NEXT: * _d_i += _r_d1 * i; +// CHECK-NEXT: * _d_i += i * _r_d1; +// CHECK-NEXT: } +// CHECK-NEXT: if (1 == _cond0) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: case 1UL: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = _t0; +// CHECK-NEXT: double _r_d0 = _d_res; +// CHECK-NEXT: * _d_i += _r_d0 * j; +// CHECK-NEXT: * _d_j += i * _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: if (0 == _cond0) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + +#define TEST_2(F, x, y) \ + { \ + result[0] = result[1] = 0; \ + auto d_##F = clad::gradient(F); \ + d_##F.execute(x, y, result, result + 1); \ + printf("{%.2f, %.2f}\n", result[0], result[1]); \ + } + +int main() { + double result[2] = {}; + clad::array_ref result_ref(result, 2); + + TEST_2(fn1, 3, 5); // CHECK-EXEC: {156.00, 100.00} +} \ No newline at end of file