From 16a652b3dece8a4dedf0c3f9ba6b77a132447cf4 Mon Sep 17 00:00:00 2001 From: Parth Date: Tue, 13 Feb 2024 02:53:54 +0530 Subject: [PATCH] Add support for differentiating switch stmt in the reverse mode AD. (#339) This commit adds support for differentiating switch statements in the reverse mode AD. The basic idea used to differentiate switch statement is that in the forward pass, processing of the statements of the switch statement body always starts from a case/default label and ends at a break statement or at the end of the switch body. Similarly, in the reverse pass, processing of the differentiated statements of the switch statement body will start from the statement just above the break statement that was hit or from the last differentiated statement in the case when no break statement was hit. Thus, we can keep track of which break statement was hit in the forward pass or if no break statement was hit at all in a variable. This information is further used by an auxiliary switch statement in the reverse pass to jump the execution to the correct point (that is, differentiated statement of the statement just before the break statement that was hit in the forward pass). In this strategy, each switch case statement of the original function gets transformed to an if condition in the reverse pass. The if condition decides at runtime whether the processing of the differentiated statements of the switch statement body should stop or continue. This is again based on the fact that the processing of statements of the switch statement body always starts at a case statement. For an example, consider this code snippet: ```cpp switch (count) { case 0: a += i; break; case 2: a += 4 * i; break; default: a += 10 * i; } case 0 of this code snippet gets transformed to the following in the differentiated function: forward pass: { case 0: a += i; } { clad::push(_t0, 1UL); // this is used to keep track if this break was hit; 1UL is used to represent the case number break; } reverse pass: case 1UL:; // this case is selected if the corresponding break was hit in the forward pass { { double _r_d0 = _d_a; _d_a += _r_d0; *_d_i += _r_d0; _d_a -= _r_d0; } if (0 == _cond0) // If case 0: was selected in the forward pass, we should break out of processing differentiated switch stmt body here. break; } ``` --- include/clad/Differentiator/CladUtils.h | 6 +- .../clad/Differentiator/ReverseModeVisitor.h | 33 +- lib/Differentiator/BaseForwardModeVisitor.cpp | 14 +- lib/Differentiator/CladUtils.cpp | 7 + lib/Differentiator/ReverseModeVisitor.cpp | 219 +++++- test/Gradient/Switch.C | 716 ++++++++++++++++++ test/Gradient/SwitchInit.C | 133 ++++ 7 files changed, 1100 insertions(+), 28 deletions(-) create mode 100644 test/Gradient/Switch.C create mode 100644 test/Gradient/SwitchInit.C diff --git a/include/clad/Differentiator/CladUtils.h b/include/clad/Differentiator/CladUtils.h index 0e5a6d308..715707e0d 100644 --- a/include/clad/Differentiator/CladUtils.h +++ b/include/clad/Differentiator/CladUtils.h @@ -324,7 +324,9 @@ namespace clad { llvm::SmallVectorImpl& Exprs); bool ContainsFunctionCalls(const clang::Stmt* E); - } // namespace utils -} + + void SetSwitchCaseSubStmt(clang::SwitchCase* SC, clang::Stmt* subStmt); + } // namespace utils + } // namespace clad #endif diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index 129d02b98..52473036b 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -374,6 +374,9 @@ namespace clad { StmtDiff VisitMaterializeTemporaryExpr(const clang::MaterializeTemporaryExpr* MTE); StmtDiff VisitCXXStaticCastExpr(const clang::CXXStaticCastExpr* SCE); + StmtDiff VisitSwitchStmt(const clang::SwitchStmt* SS); + StmtDiff VisitCaseStmt(const clang::CaseStmt* CS); + StmtDiff VisitDefaultStmt(const clang::DefaultStmt* DS); VarDeclDiff DifferentiateVarDecl(const clang::VarDecl* VD); StmtDiff VisitSubstNonTypeTemplateParmExpr( const clang::SubstNonTypeTemplateParmExpr* NTTP); @@ -485,7 +488,7 @@ namespace clad { clang::Stmt* forLoopIncDiff = nullptr, bool isForLoop = false); - /// This class modifies forward and reverse blocks of the loop + /// This class modifies forward and reverse blocks of the loop/switch /// body so that `break` and `continue` statements are correctly /// handled. `break` and `continue` statements are handled by /// enclosing entire reverse block loop body in a switch statement @@ -528,6 +531,7 @@ namespace clad { ReverseModeVisitor& m_RMV; + const bool m_IsInvokedBySwitchStmt = false; /// Builds and returns a literal expression of type `std::size_t` with /// `value` as value. clang::Expr* CreateSizeTLiteralExpr(std::size_t value); @@ -542,7 +546,8 @@ namespace clad { clang::Expr* CreateCFTapePushExpr(std::size_t value); public: - BreakContStmtHandler(ReverseModeVisitor& RMV) : m_RMV(RMV) {} + BreakContStmtHandler(ReverseModeVisitor& RMV, bool forSwitchStmt = false) + : m_RMV(RMV), m_IsInvokedBySwitchStmt(forSwitchStmt) {} /// Begins control flow switch statement scope. /// Control flow switch statement is used to refer to the @@ -574,8 +579,8 @@ namespace clad { BreakContStmtHandler* GetActiveBreakContStmtHandler() { return &m_BreakContStmtHandlers.back(); } - BreakContStmtHandler* PushBreakContStmtHandler() { - m_BreakContStmtHandlers.emplace_back(*this); + BreakContStmtHandler* PushBreakContStmtHandler(bool forSwitchStmt = false) { + m_BreakContStmtHandlers.emplace_back(*this, forSwitchStmt); return &m_BreakContStmtHandlers.back(); } void PopBreakContStmtHandler() { @@ -609,6 +614,26 @@ namespace clad { clang::QualType ComputeAdjointType(clang::QualType T); clang::QualType ComputeParamType(clang::QualType T); + /// Stores data required for differentiating a switch statement. + struct SwitchStmtInfo { + llvm::SmallVector cases; + clang::Expr* switchStmtCond = nullptr; + clang::IfStmt* defaultIfBreakExpr = nullptr; + }; + + /// Maintains a stack of `SwitchStmtInfo`. + llvm::SmallVector m_SwitchStmtsData; + + SwitchStmtInfo* GetActiveSwitchStmtInfo() { + return &m_SwitchStmtsData.back(); + } + + SwitchStmtInfo* PushSwitchStmtInfo() { + m_SwitchStmtsData.emplace_back(); + return &m_SwitchStmtsData.back(); + } + + void PopSwitchStmtInfo() { m_SwitchStmtsData.pop_back(); } }; } // end namespace clad diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index be71936a0..37f32bf81 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -1592,16 +1592,6 @@ static SwitchCase* getContainedSwitchCaseStmt(const CompoundStmt* CS) { return nullptr; } -static void setSwitchCaseSubStmt(SwitchCase* SC, Stmt* subStmt) { - if (auto caseStmt = dyn_cast(SC)) { - caseStmt->setSubStmt(subStmt); - } else if (auto defaultStmt = dyn_cast(SC)) { - defaultStmt->setSubStmt(subStmt); - } else { - assert(0 && "Unsupported switch case statement"); - } -} - /// Returns top switch statement in the `SwitchStack` of the given /// Function Scope. static SwitchStmt* getTopSwitchStmtOfSwitchStack(sema::FunctionScopeInfo* FSI) { @@ -1674,7 +1664,7 @@ StmtDiff BaseForwardModeVisitor::VisitSwitchStmt(const SwitchStmt* SS) { // been processed aka when all the statments in switch statement body // have been processed. if (activeSC) { - setSwitchCaseSubStmt(activeSC, endBlock()); + utils::SetSwitchCaseSubStmt(activeSC, endBlock()); endScope(); activeSC = nullptr; } @@ -1702,7 +1692,7 @@ BaseForwardModeVisitor::DeriveSwitchStmtBodyHelper(const Stmt* stmt, // corresponding to the active switch case label, and update its // substatement. if (activeSC) { - setSwitchCaseSubStmt(activeSC, endBlock()); + utils::SetSwitchCaseSubStmt(activeSC, endBlock()); endScope(); } // sub statement will be updated later, either when the corresponding diff --git a/lib/Differentiator/CladUtils.cpp b/lib/Differentiator/CladUtils.cpp index 16da9e5cf..6be4b8349 100644 --- a/lib/Differentiator/CladUtils.cpp +++ b/lib/Differentiator/CladUtils.cpp @@ -632,5 +632,12 @@ namespace clad { finder.TraverseStmt(const_cast(S)); return finder.hasCallExpr; } + + void SetSwitchCaseSubStmt(SwitchCase* SC, Stmt* subStmt) { + if (auto* caseStmt = dyn_cast(SC)) + caseStmt->setSubStmt(subStmt); + else + cast(SC)->setSubStmt(subStmt); + } } // namespace utils } // namespace clad diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index ab496b1ab..bddb9de11 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -9,6 +9,7 @@ #include "ConstantFolder.h" #include "TBRAnalyzer.h" +#include "clad/Differentiator/DerivativeBuilder.h" #include "clad/Differentiator/DiffPlanner.h" #include "clad/Differentiator/ErrorEstimator.h" #include "clad/Differentiator/ExternalRMVSource.h" @@ -17,7 +18,9 @@ #include "clang/AST/ASTContext.h" #include "clang/AST/Expr.h" +#include "clang/AST/Stmt.h" #include "clang/AST/TemplateBase.h" +#include "clang/Basic/TokenKinds.h" #include "clang/Sema/Lookup.h" #include "clang/Sema/Overload.h" #include "clang/Sema/Scope.h" @@ -3258,6 +3261,211 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return {forwardDS, reverseBlock}; } + // Basic idea used for differentiating switch statement is that in the reverse + // pass, processing of the differentiated statments of the switch statement + // body should start either from a `break` statement or from the last + // statement of the switch statement body and always end at a switch + // case/default statement. + // + // Therefore, here we keep track of which `break` was hit in the forward pass, + // or if we no `break` statement was hit at all in a variable or clad tape. + // This information is further used by an auxilliary switch statement in the + // reverse pass to jump the execution to the correct point (that is, + // differentiated statement of the statement just before the `break` statement + // that was hit in the forward pass) + StmtDiff ReverseModeVisitor::VisitSwitchStmt(const SwitchStmt* SS) { + // Scope and blocks for the compound statement that encloses the switch + // statement in both the forward and the reverse pass. Block is required + // for handling condition variable and switch-init statement. + beginScope(Scope::DeclScope); + beginBlock(direction::forward); + beginBlock(direction::reverse); + + // Handles switch init statement + if (SS->getInit()) { + StmtDiff switchInitDiff = DifferentiateSingleStmt(SS->getInit()); + addToCurrentBlock(switchInitDiff.getStmt(), direction::forward); + addToCurrentBlock(switchInitDiff.getStmt_dx(), direction::reverse); + } + + // Handles condition variable + if (SS->getConditionVariable()) { + StmtDiff condVarDiff = + DifferentiateSingleStmt(SS->getConditionVariableDeclStmt()); + addToCurrentBlock(condVarDiff.getStmt(), direction::forward); + addToCurrentBlock(condVarDiff.getStmt_dx(), direction::reverse); + } + + StmtDiff condDiff = DifferentiateSingleStmt(SS->getCond()); + addToCurrentBlock(condDiff.getStmt(), direction::forward); + addToCurrentBlock(condDiff.getStmt_dx(), direction::reverse); + Expr* condExpr = nullptr; + clad_compat::llvm_Optional condTape; + + if (isInsideLoop) { + // If we are inside a loop, condition will be stored and used as follows: + // + // forward block: + // switch (clad::push(..., cond)) { ... } + // + // reverse block: + // switch (...) { ... } + // clad::pop(...); + condTape.emplace(MakeCladTapeFor(condDiff.getExpr(), "_cond")); + condExpr = condTape->Push; + } else { + condExpr = GlobalStoreAndRef(condDiff.getExpr(), "_cond").getExpr(); + } + + auto* activeBreakContHandler = PushBreakContStmtHandler( + /*forSwitchStmt=*/true); + activeBreakContHandler->BeginCFSwitchStmtScope(); + auto* SSData = PushSwitchStmtInfo(); + + if (isInsideLoop) + SSData->switchStmtCond = condTape->Last(); + else + SSData->switchStmtCond = condExpr; + + // scope for the switch statement body. + beginScope(Scope::DeclScope); + + const Stmt* body = SS->getBody(); + StmtDiff bodyDiff = nullptr; + if (isa(body)) + bodyDiff = Visit(body); + else + bodyDiff = DifferentiateSingleStmt(body); + + // Each switch case statement of the original function gets transformed to + // an if condition in the reverse pass. The if condition decides at runtime + // whether the processing of the differentiated statements of the switch + // statement body should stop or continue. This is based on the fact that + // processing of statements of switch statement body always starts at a case + // statement. For example, + // ``` + // case 3: + // ``` + // gets transformed to, + // + // ``` + // if (3 == _cond) + // break; + // ``` + // + // This kind of if expression cannot by easily formed for the default + // statement, thus, we instead compare value of the switch condition with + // the values of all the case statements to determine if the default + // statement was selected in the forward pass. + // Therefore, + // + // ``` + // default: + // ``` + // + // will get transformed to something like, + // + // ``` + // if (_cond != 1 && _cond != 2 && _cond != 3) + // break; + // ``` + if (SSData->defaultIfBreakExpr) { + Expr* breakCond = nullptr; + for (auto* SC : SSData->cases) { + if (auto* CS = dyn_cast(SC)) { + if (breakCond) { + breakCond = BuildOp(BinaryOperatorKind::BO_LAnd, breakCond, + BuildOp(BinaryOperatorKind::BO_NE, + SSData->switchStmtCond, CS->getLHS())); + } else { + breakCond = BuildOp(BinaryOperatorKind::BO_NE, + SSData->switchStmtCond, CS->getLHS()); + } + } + } + if (!breakCond) + breakCond = m_Sema.ActOnCXXBoolLiteral(noLoc, tok::kw_true).get(); + SSData->defaultIfBreakExpr->setCond(breakCond); + } + + activeBreakContHandler->EndCFSwitchStmtScope(); + + // If switch statement contains no cases, then, no statement of the switch + // statement body will be processed in both the forward and the reverse + // pass. Thus, we do not need to add them in the differentiated function. + if (!(SSData->cases.empty())) { + Sema::ConditionResult condRes = m_Sema.ActOnCondition( + getCurrentScope(), noLoc, condExpr, Sema::ConditionKind::Switch); + SwitchStmt* forwardSS = + clad_compat::Sema_ActOnStartOfSwitchStmt(m_Sema, nullptr, condRes) + .getAs(); + activeBreakContHandler->UpdateForwAndRevBlocks(bodyDiff); + + // Registers all the cases to the switch statement. + for (auto* SC : SSData->cases) + forwardSS->addSwitchCase(SC); + + forwardSS = + m_Sema.ActOnFinishSwitchStmt(noLoc, forwardSS, bodyDiff.getStmt()) + .getAs(); + + addToCurrentBlock(forwardSS, direction::forward); + if (isInsideLoop) + addToCurrentBlock(condTape->Pop, direction::reverse); + addToCurrentBlock(bodyDiff.getStmt_dx(), direction::reverse); + } + + PopBreakContStmtHandler(); + PopSwitchStmtInfo(); + return {endBlock(direction::forward), endBlock(direction::reverse)}; + } + + StmtDiff ReverseModeVisitor::VisitCaseStmt(const CaseStmt* CS) { + beginBlock(direction::forward); + beginBlock(direction::reverse); + SwitchStmtInfo* SSData = GetActiveSwitchStmtInfo(); + + Expr* lhsClone = (CS->getLHS() ? Clone(CS->getLHS()) : nullptr); + Expr* rhsClone = (CS->getRHS() ? Clone(CS->getRHS()) : nullptr); + + auto* newSC = clad_compat::CaseStmt_Create(m_Sema.getASTContext(), lhsClone, + rhsClone, noLoc, noLoc, noLoc); + + Expr* ifCond = BuildOp(BinaryOperatorKind::BO_EQ, newSC->getLHS(), + SSData->switchStmtCond); + Stmt* ifThen = m_Sema.ActOnBreakStmt(noLoc, getCurrentScope()).get(); + Stmt* ifBreakExpr = clad_compat::IfStmt_Create( + m_Context, noLoc, false, nullptr, nullptr, ifCond, noLoc, noLoc, ifThen, + noLoc, nullptr); + SSData->cases.push_back(newSC); + addToCurrentBlock(ifBreakExpr, direction::reverse); + addToCurrentBlock(newSC, direction::forward); + auto diff = DifferentiateSingleStmt(CS->getSubStmt()); + utils::SetSwitchCaseSubStmt(newSC, diff.getStmt()); + addToCurrentBlock(diff.getStmt_dx(), direction::reverse); + return {endBlock(direction::forward), endBlock(direction::reverse)}; + } + + StmtDiff ReverseModeVisitor::VisitDefaultStmt(const DefaultStmt* DS) { + beginBlock(direction::reverse); + beginBlock(direction::forward); + auto* SSData = GetActiveSwitchStmtInfo(); + auto* newDefaultStmt = + new (m_Sema.getASTContext()) DefaultStmt(noLoc, noLoc, nullptr); + Stmt* ifThen = m_Sema.ActOnBreakStmt(noLoc, getCurrentScope()).get(); + Stmt* ifBreakExpr = clad_compat::IfStmt_Create( + m_Context, noLoc, false, nullptr, nullptr, nullptr, noLoc, noLoc, + ifThen, noLoc, nullptr); + SSData->cases.push_back(newDefaultStmt); + SSData->defaultIfBreakExpr = cast(ifBreakExpr); + addToCurrentBlock(ifBreakExpr, direction::reverse); + addToCurrentBlock(newDefaultStmt, direction::forward); + auto diff = DifferentiateSingleStmt(DS->getSubStmt()); + utils::SetSwitchCaseSubStmt(newDefaultStmt, diff.getStmt()); + addToCurrentBlock(diff.getStmt_dx(), direction::reverse); + return {endBlock(direction::forward), endBlock(direction::reverse)}; + } + StmtDiff ReverseModeVisitor::DifferentiateLoopBody(const Stmt* body, LoopCounter& loopCounter, Stmt* condVarDiff, @@ -3401,10 +3609,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } CaseStmt* ReverseModeVisitor::BreakContStmtHandler::GetNextCFCaseStmt() { - // End scope for currenly active case statement, if any. - if (!m_SwitchCases.empty()) - m_RMV.endScope(); - ++m_CaseCounter; auto* counterLiteral = CreateSizeTLiteralExpr(m_CaseCounter); CaseStmt* CS = clad_compat::CaseStmt_Create(m_RMV.m_Context, counterLiteral, @@ -3418,8 +3622,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // corresponding next statements. CS->setSubStmt(m_RMV.m_Sema.ActOnNullStmt(noLoc).get()); - // begin scope for the new active switch case statement. - m_RMV.beginScope(Scope::DeclScope); m_SwitchCases.push_back(CS); return CS; } @@ -3433,12 +3635,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, void ReverseModeVisitor::BreakContStmtHandler::UpdateForwAndRevBlocks( StmtDiff& bodyDiff) { - if (m_SwitchCases.empty()) + if (m_SwitchCases.empty() && !m_IsInvokedBySwitchStmt) return; - // end scope for last switch case. - m_RMV.endScope(); - // Add case statement in the beginning of the reverse block // and corresponding push expression for this case statement // at the end of the forward block to cover the case when no diff --git a/test/Gradient/Switch.C b/test/Gradient/Switch.C new file mode 100644 index 000000000..686a3f1b1 --- /dev/null +++ b/test/Gradient/Switch.C @@ -0,0 +1,716 @@ +// RUN: %cladclang %s -I%S/../../include -oSwitch.out 2>&1 -lstdc++ -lm | FileCheck %s +// RUN: ./Switch.out | FileCheck -check-prefix=CHECK-EXEC %s +//CHECK-NOT: {{.*error|warning|note:.*}} + +#include "clad/Differentiator/Differentiator.h" +#include "../TestUtils.h" + +double fn1(double i, double j) { + double res = 0; + int count = 1; + switch (count) { + case 0: res += i * j; break; + case 1: res += i * i; { + case 2: res += j * j; + } + default: res += i * i * j * j; + } + return res; +} + +// CHECK: void fn1_grad(double i, double j, clad::array_ref _d_i, clad::array_ref _d_j) { +// CHECK-NEXT: double _d_res = 0; +// CHECK-NEXT: int _d_count = 0; +// CHECK-NEXT: int _cond0; +// CHECK-NEXT: double _t0; +// CHECK-NEXT: clad::tape _t1 = {}; +// CHECK-NEXT: double _t2; +// CHECK-NEXT: double _t3; +// CHECK-NEXT: double _t4; +// CHECK-NEXT: double res = 0; +// CHECK-NEXT: int count = 1; +// CHECK-NEXT: { +// CHECK-NEXT: _cond0 = count; +// CHECK-NEXT: switch (_cond0) { +// CHECK-NEXT: { +// CHECK-NEXT: case 0: +// CHECK-NEXT: res += i * j; +// CHECK-NEXT: _t0 = res; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: clad::push(_t1, 1UL); +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: case 1: +// CHECK-NEXT: res += i * i; +// CHECK-NEXT: _t2 = res; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: case 2: +// CHECK-NEXT: res += j * j; +// CHECK-NEXT: _t3 = res; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: default: +// CHECK-NEXT: res += i * i * j * j; +// CHECK-NEXT: _t4 = res; +// CHECK-NEXT: } +// CHECK-NEXT: clad::push(_t1, 2UL); +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: _d_res += 1; +// CHECK-NEXT: { +// CHECK-NEXT: switch (clad::pop(_t1)) { +// CHECK-NEXT: case 2UL: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = _t4; +// CHECK-NEXT: double _r_d3 = _d_res; +// CHECK-NEXT: * _d_i += _r_d3 * j * j * i; +// CHECK-NEXT: * _d_i += i * _r_d3 * j * j; +// CHECK-NEXT: * _d_j += i * i * _r_d3 * j; +// CHECK-NEXT: * _d_j += i * i * j * _r_d3; +// CHECK-NEXT: } +// CHECK-NEXT: if (_cond0 != 0 && _cond0 != 1 && _cond0 != 2) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = _t3; +// CHECK-NEXT: double _r_d2 = _d_res; +// CHECK-NEXT: * _d_j += _r_d2 * j; +// CHECK-NEXT: * _d_j += j * _r_d2; +// CHECK-NEXT: } +// CHECK-NEXT: if (2 == _cond0) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = _t2; +// CHECK-NEXT: double _r_d1 = _d_res; +// CHECK-NEXT: * _d_i += _r_d1 * i; +// CHECK-NEXT: * _d_i += i * _r_d1; +// CHECK-NEXT: } +// CHECK-NEXT: if (1 == _cond0) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: case 1UL: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = _t0; +// CHECK-NEXT: double _r_d0 = _d_res; +// CHECK-NEXT: * _d_i += _r_d0 * j; +// CHECK-NEXT: * _d_j += i * _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: if (0 == _cond0) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + +double fn2(double i, double j) { + double res = 0; + switch (int count = 2) { + res += i * i * j * j; + res += 50 * i; + case 0: res += i; break; + case 1: res += j; + case 2: res += i * j; break; + default: res += i + j; + } + return res; +} + +// CHECK: void fn2_grad(double i, double j, clad::array_ref _d_i, clad::array_ref _d_j) { +// CHECK-NEXT: double _d_res = 0; +// CHECK-NEXT: int _d_count = 0; +// CHECK-NEXT: int count = 0; +// CHECK-NEXT: int _cond0; +// CHECK-NEXT: double _t0; +// CHECK-NEXT: double _t1; +// CHECK-NEXT: double _t2; +// CHECK-NEXT: clad::tape _t3 = {}; +// CHECK-NEXT: double _t4; +// CHECK-NEXT: double _t5; +// CHECK-NEXT: double _t6; +// CHECK-NEXT: double res = 0; +// CHECK-NEXT: { +// CHECK-NEXT: count = 2; +// CHECK-NEXT: _cond0 = count; +// CHECK-NEXT: switch (_cond0) { +// CHECK-NEXT: _t0 = res; +// CHECK-NEXT: res += i * i * j * j; +// CHECK-NEXT: _t1 = res; +// CHECK-NEXT: res += 50 * i; +// CHECK-NEXT: { +// CHECK-NEXT: case 0: +// CHECK-NEXT: res += i; +// CHECK-NEXT: _t2 = res; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: clad::push(_t3, 1UL); +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: case 1: +// CHECK-NEXT: res += j; +// CHECK-NEXT: _t4 = res; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: case 2: +// CHECK-NEXT: res += i * j; +// CHECK-NEXT: _t5 = res; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: clad::push(_t3, 2UL); +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: default: +// CHECK-NEXT: res += i + j; +// CHECK-NEXT: _t6 = res; +// CHECK-NEXT: } +// CHECK-NEXT: clad::push(_t3, 3UL); +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: _d_res += 1; +// CHECK-NEXT: { +// CHECK-NEXT: switch (clad::pop(_t3)) { +// CHECK-NEXT: case 3UL: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = _t6; +// CHECK-NEXT: double _r_d5 = _d_res; +// CHECK-NEXT: * _d_i += _r_d5; +// CHECK-NEXT: * _d_j += _r_d5; +// CHECK-NEXT: } +// CHECK-NEXT: if (_cond0 != 0 && _cond0 != 1 && _cond0 != 2) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: case 2UL: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = _t5; +// CHECK-NEXT: double _r_d4 = _d_res; +// CHECK-NEXT: * _d_i += _r_d4 * j; +// CHECK-NEXT: * _d_j += i * _r_d4; +// CHECK-NEXT: } +// CHECK-NEXT: if (2 == _cond0) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = _t4; +// CHECK-NEXT: double _r_d3 = _d_res; +// CHECK-NEXT: * _d_j += _r_d3; +// CHECK-NEXT: } +// CHECK-NEXT: if (1 == _cond0) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: case 1UL: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = _t2; +// CHECK-NEXT: double _r_d2 = _d_res; +// CHECK-NEXT: * _d_i += _r_d2; +// CHECK-NEXT: } +// CHECK-NEXT: if (0 == _cond0) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: res = _t1; +// CHECK-NEXT: double _r_d1 = _d_res; +// CHECK-NEXT: * _d_i += 50 * _r_d1; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: res = _t0; +// CHECK-NEXT: double _r_d0 = _d_res; +// CHECK-NEXT: * _d_i += _r_d0 * j * j * i; +// CHECK-NEXT: * _d_i += i * _r_d0 * j * j; +// CHECK-NEXT: * _d_j += i * i * _r_d0 * j; +// CHECK-NEXT: * _d_j += i * i * j * _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + +double fn3(double i, double j) { + double res = 0; + int counter = 2; + while (counter--) { + switch (counter) { + case 0: res += i * i * j * j; + case 1: { + res += i * i; + } break; + case 2: res += j * j; + default: res += i + j; + } + } + return res; +} + +// CHECK: void fn3_grad(double i, double j, clad::array_ref _d_i, clad::array_ref _d_j) { +// CHECK-NEXT: double _d_res = 0; +// CHECK-NEXT: int _d_counter = 0; +// CHECK-NEXT: unsigned long _t0; +// CHECK-NEXT: clad::tape _cond0 = {}; +// CHECK-NEXT: clad::tape _t1 = {}; +// CHECK-NEXT: clad::tape _t2 = {}; +// CHECK-NEXT: clad::tape _t3 = {}; +// CHECK-NEXT: clad::tape _t4 = {}; +// CHECK-NEXT: clad::tape _t5 = {}; +// CHECK-NEXT: double res = 0; +// CHECK-NEXT: int counter = 2; +// CHECK-NEXT: _t0 = 0; +// CHECK-NEXT: while (counter--) +// CHECK-NEXT: { +// CHECK-NEXT: _t0++; +// CHECK-NEXT: { +// CHECK-NEXT: switch (clad::push(_cond0, counter)) { +// CHECK-NEXT: { +// CHECK-NEXT: case 0: +// CHECK-NEXT: res += i * i * j * j; +// CHECK-NEXT: clad::push(_t1, res); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: case 1: +// CHECK-NEXT: { +// CHECK-NEXT: clad::push(_t2, res); +// CHECK-NEXT: res += i * i; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: clad::push(_t3, 1UL); +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: case 2: +// CHECK-NEXT: res += j * j; +// CHECK-NEXT: clad::push(_t4, res); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: default: +// CHECK-NEXT: res += i + j; +// CHECK-NEXT: clad::push(_t5, res); +// CHECK-NEXT: } +// CHECK-NEXT: clad::push(_t3, 2UL); +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: _d_res += 1; +// CHECK-NEXT: while (_t0) +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: switch (clad::pop(_t3)) { +// CHECK-NEXT: case 2UL: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = clad::pop(_t5); +// CHECK-NEXT: double _r_d3 = _d_res; +// CHECK-NEXT: * _d_i += _r_d3; +// CHECK-NEXT: * _d_j += _r_d3; +// CHECK-NEXT: } +// CHECK-NEXT: if (clad::back(_cond0) != 0 && clad::back(_cond0) != 1 && clad::back(_cond0) != 2) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = clad::pop(_t4); +// CHECK-NEXT: double _r_d2 = _d_res; +// CHECK-NEXT: * _d_j += _r_d2 * j; +// CHECK-NEXT: * _d_j += j * _r_d2; +// CHECK-NEXT: } +// CHECK-NEXT: if (2 == clad::back(_cond0)) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: case 1UL: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = clad::pop(_t2); +// CHECK-NEXT: double _r_d1 = _d_res; +// CHECK-NEXT: * _d_i += _r_d1 * i; +// CHECK-NEXT: * _d_i += i * _r_d1; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: if (1 == clad::back(_cond0)) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = clad::pop(_t1); +// CHECK-NEXT: double _r_d0 = _d_res; +// CHECK-NEXT: * _d_i += _r_d0 * j * j * i; +// CHECK-NEXT: * _d_i += i * _r_d0 * j * j; +// CHECK-NEXT: * _d_j += i * i * _r_d0 * j; +// CHECK-NEXT: * _d_j += i * i * j * _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: if (0 == clad::back(_cond0)) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: clad::pop(_cond0); +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: _t0--; +// CHECK-NEXT: } +// CHECK-NEXT: } + +double fn4(double i, double j) { + double res = 0; + switch (1) { + case 0: res += i * i * j * j; break; + case 1: + int counter = 2; + while (counter--) { + res += i * j; + } + break; + } + return res; +} + +// CHECK: void fn4_grad(double i, double j, clad::array_ref _d_i, clad::array_ref _d_j) { +// CHECK-NEXT: double _d_res = 0; +// CHECK-NEXT: double _t0; +// CHECK-NEXT: clad::tape _t1 = {}; +// CHECK-NEXT: int _d_counter = 0; +// CHECK-NEXT: int counter = 0; +// CHECK-NEXT: unsigned long _t2; +// CHECK-NEXT: clad::tape _t3 = {}; +// CHECK-NEXT: double res = 0; +// CHECK-NEXT: { +// CHECK-NEXT: switch (1) { +// CHECK-NEXT: { +// CHECK-NEXT: case 0: +// CHECK-NEXT: res += i * i * j * j; +// CHECK-NEXT: _t0 = res; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: clad::push(_t1, 1UL); +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: case 1: +// CHECK-NEXT: counter = 2; +// CHECK-NEXT: } +// CHECK-NEXT: _t2 = 0; +// CHECK-NEXT: while (counter--) +// CHECK-NEXT: { +// CHECK-NEXT: _t2++; +// CHECK-NEXT: clad::push(_t3, res); +// CHECK-NEXT: res += i * j; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: clad::push(_t1, 2UL); +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: clad::push(_t1, 3UL); +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: _d_res += 1; +// CHECK-NEXT: { +// CHECK-NEXT: switch (clad::pop(_t1)) { +// CHECK-NEXT: case 3UL: +// CHECK-NEXT: ; +// CHECK-NEXT: case 2UL: +// CHECK-NEXT: ; +// CHECK-NEXT: while (_t2) +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = clad::pop(_t3); +// CHECK-NEXT: double _r_d1 = _d_res; +// CHECK-NEXT: * _d_i += _r_d1 * j; +// CHECK-NEXT: * _d_j += i * _r_d1; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: _t2--; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: if (1 == 1) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: case 1UL: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = _t0; +// CHECK-NEXT: double _r_d0 = _d_res; +// CHECK-NEXT: * _d_i += _r_d0 * j * j * i; +// CHECK-NEXT: * _d_i += i * _r_d0 * j * j; +// CHECK-NEXT: * _d_j += i * i * _r_d0 * j; +// CHECK-NEXT: * _d_j += i * i * j * _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: if (0 == 1) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + +double fn5(double i, double j) { + double res=0; + switch(int count = 1) + case 1: + res += i*j; + return res; +} + +// CHECK: void fn5_grad(double i, double j, clad::array_ref _d_i, clad::array_ref _d_j) { +// CHECK-NEXT: double _d_res = 0; +// CHECK-NEXT: int _d_count = 0; +// CHECK-NEXT: int count = 0; +// CHECK-NEXT: int _cond0; +// CHECK-NEXT: double _t0; +// CHECK-NEXT: clad::tape _t1 = {}; +// CHECK-NEXT: double res = 0; +// CHECK-NEXT: { +// CHECK-NEXT: count = 1; +// CHECK-NEXT: _cond0 = count; +// CHECK-NEXT: switch (_cond0) { +// CHECK-NEXT: case 1: +// CHECK-NEXT: res += i * j; +// CHECK-NEXT: _t0 = res; +// CHECK-NEXT: clad::push(_t1, 1UL); +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: _d_res += 1; +// CHECK-NEXT: { +// CHECK-NEXT: switch (clad::pop(_t1)) { +// CHECK-NEXT: case 1UL: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: res = _t0; +// CHECK-NEXT: double _r_d0 = _d_res; +// CHECK-NEXT: * _d_i += _r_d0 * j; +// CHECK-NEXT: * _d_j += i * _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: if (1 == _cond0) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + +double fn6(double u, double v) { + int res = 0; + double temp = 0; + switch(res = u * v) { + default: + temp = 1; + } + return res; +} + +// CHECK: void fn6_grad(double u, double v, clad::array_ref _d_u, clad::array_ref _d_v) { +// CHECK-NEXT: int _d_res = 0; +// CHECK-NEXT: double _d_temp = 0; +// CHECK-NEXT: int _t0; +// CHECK-NEXT: int _cond0; +// CHECK-NEXT: double _t1; +// CHECK-NEXT: clad::tape _t2 = {}; +// CHECK-NEXT: int res = 0; +// CHECK-NEXT: double temp = 0; +// CHECK-NEXT: { +// CHECK-NEXT: _t0 = res; +// CHECK-NEXT: res = u * v; +// CHECK-NEXT: _cond0 = res = u * v; +// CHECK-NEXT: switch (_cond0) { +// CHECK-NEXT: { +// CHECK-NEXT: default: +// CHECK-NEXT: temp = 1; +// CHECK-NEXT: _t1 = temp; +// CHECK-NEXT: } +// CHECK-NEXT: clad::push(_t2, 1UL); +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: _d_res += 1; +// CHECK-NEXT: { +// CHECK-NEXT: switch (clad::pop(_t2)) { +// CHECK-NEXT: case 1UL: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: temp = _t1; +// CHECK-NEXT: double _r_d1 = _d_temp; +// CHECK-NEXT: _d_temp -= _r_d1; +// CHECK-NEXT: } +// CHECK-NEXT: if (true) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: res = _t0; +// CHECK-NEXT: int _r_d0 = _d_res; +// CHECK-NEXT: _d_res -= _r_d0; +// CHECK-NEXT: * _d_u += _r_d0 * v; +// CHECK-NEXT: * _d_v += u * _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + +double fn7(double u, double v) { + double res = 0; + for (int i=0; i < 5; ++i) { + switch(i) { + case 0: + case 1: + case 2: + res += u; + break; + case 3: + default: + res += v; + break; + } + } + return res; +} + +// CHECK: void fn7_grad(double u, double v, clad::array_ref _d_u, clad::array_ref _d_v) { +// CHECK-NEXT: double _d_res = 0; +// CHECK-NEXT: unsigned long _t0; +// CHECK-NEXT: int _d_i = 0; +// CHECK-NEXT: int i = 0; +// CHECK-NEXT: clad::tape _cond0 = {}; +// CHECK-NEXT: clad::tape _t1 = {}; +// CHECK-NEXT: clad::tape _t2 = {}; +// CHECK-NEXT: clad::tape _t3 = {}; +// CHECK-NEXT: double res = 0; +// CHECK-NEXT: _t0 = 0; +// CHECK-NEXT: for (i = 0; i < 5; ++i) { +// CHECK-NEXT: _t0++; +// CHECK-NEXT: { +// CHECK-NEXT: switch (clad::push(_cond0, i)) { +// CHECK-NEXT: { +// CHECK-NEXT: case 0: +// CHECK-NEXT: { +// CHECK-NEXT: case 1: +// CHECK-NEXT: { +// CHECK-NEXT: case 2: +// CHECK-NEXT: res += u; +// CHECK-NEXT: clad::push(_t1, res); +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: clad::push(_t2, 1UL); +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: case 3: +// CHECK-NEXT: { +// CHECK-NEXT: default: +// CHECK-NEXT: res += v; +// CHECK-NEXT: clad::push(_t3, res); +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: clad::push(_t2, 2UL); +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: clad::push(_t2, 3UL); +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: _d_res += 1; +// CHECK-NEXT: for (; _t0; _t0--) { +// CHECK-NEXT: --i; +// CHECK-NEXT: { +// CHECK-NEXT: switch (clad::pop(_t2)) { +// CHECK-NEXT: case 3UL: +// CHECK-NEXT: ; +// CHECK-NEXT: case 2UL: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = clad::pop(_t3); +// CHECK-NEXT: double _r_d1 = _d_res; +// CHECK-NEXT: * _d_v += _r_d1; +// CHECK-NEXT: } +// CHECK-NEXT: if (clad::back(_cond0) != 0 && clad::back(_cond0) != 1 && clad::back(_cond0) != 2 && clad::back(_cond0) != 3) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: if (3 == clad::back(_cond0)) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: case 1UL: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = clad::pop(_t1); +// CHECK-NEXT: double _r_d0 = _d_res; +// CHECK-NEXT: * _d_u += _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: if (2 == clad::back(_cond0)) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: if (1 == clad::back(_cond0)) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: if (0 == clad::back(_cond0)) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: clad::pop(_cond0); +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + + +#define TEST_2(F, x, y) \ + { \ + result[0] = result[1] = 0; \ + auto d_##F = clad::gradient(F); \ + d_##F.execute(x, y, result, result + 1); \ + printf("{%.2f, %.2f}\n", result[0], result[1]); \ + } + +int main() { + double result[2] = {}; + clad::array_ref result_ref(result, 2); + + TEST_2(fn1, 3, 5); // CHECK-EXEC: {156.00, 100.00} + TEST_2(fn2, 3, 5); // CHECK-EXEC: {5.00, 3.00} + TEST_2(fn3, 3, 5); // CHECK-EXEC: {162.00, 90.00} + TEST_2(fn4, 3, 5); // CHECK-EXEC: {10.00, 6.00} + TEST_2(fn5, 3, 5); // CHECK-EXEC: {5.00, 3.00} + + INIT_GRADIENT(fn6); + INIT_GRADIENT(fn7); + + TEST_GRADIENT(fn6, 2, 3, 5, &result[0], &result[1]); // CHECK-EXEC: {5.00, 3.00} + TEST_GRADIENT(fn7, 2, 3, 5, &result[0], &result[1]); // CHECK-EXEC: {3.00, 2.00} +} \ No newline at end of file diff --git a/test/Gradient/SwitchInit.C b/test/Gradient/SwitchInit.C new file mode 100644 index 000000000..f2112c8f4 --- /dev/null +++ b/test/Gradient/SwitchInit.C @@ -0,0 +1,133 @@ +// RUN: %cladclang %s -I%S/../../include -std=c++17 -oSwitchInit.out 2>&1 -lstdc++ -lm | FileCheck %s +// RUN: ./SwitchInit.out | FileCheck -check-prefix=CHECK-EXEC %s +//CHECK-NOT: {{.*error|warning|note:.*}} + +#include "clad/Differentiator/Differentiator.h" + +double fn1(double i, double j) { + double res = 0; + switch (int count = 1;count) { + case 0: res += i * j; break; + case 1: res += i * i; { + case 2: res += j * j; + } + default: res += i * i * j * j; + } + return res; +} + +// CHECK: void fn1_grad(double i, double j, clad::array_ref _d_i, clad::array_ref _d_j) { +// CHECK-NEXT: double _d_res = 0; +// CHECK-NEXT: int _d_count = 0; +// CHECK-NEXT: int count = 0; +// CHECK-NEXT: int _cond0; +// CHECK-NEXT: double _t0; +// CHECK-NEXT: clad::tape _t1 = {}; +// CHECK-NEXT: double _t2; +// CHECK-NEXT: double _t3; +// CHECK-NEXT: double _t4; +// CHECK-NEXT: double res = 0; +// CHECK-NEXT: { +// CHECK-NEXT: count = 1; +// CHECK-NEXT: _cond0 = count; +// CHECK-NEXT: switch (_cond0) { +// CHECK-NEXT: { +// CHECK-NEXT: case 0: +// CHECK-NEXT: res += i * j; +// CHECK-NEXT: _t0 = res; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: clad::push(_t1, 1UL); +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: case 1: +// CHECK-NEXT: res += i * i; +// CHECK-NEXT: _t2 = res; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: case 2: +// CHECK-NEXT: res += j * j; +// CHECK-NEXT: _t3 = res; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: default: +// CHECK-NEXT: res += i * i * j * j; +// CHECK-NEXT: _t4 = res; +// CHECK-NEXT: } +// CHECK-NEXT: clad::push(_t1, 2UL); +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: _d_res += 1; +// CHECK-NEXT: { +// CHECK-NEXT: switch (clad::pop(_t1)) { +// CHECK-NEXT: case 2UL: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = _t4; +// CHECK-NEXT: double _r_d3 = _d_res; +// CHECK-NEXT: * _d_i += _r_d3 * j * j * i; +// CHECK-NEXT: * _d_i += i * _r_d3 * j * j; +// CHECK-NEXT: * _d_j += i * i * _r_d3 * j; +// CHECK-NEXT: * _d_j += i * i * j * _r_d3; +// CHECK-NEXT: } +// CHECK-NEXT: if (_cond0 != 0 && _cond0 != 1 && _cond0 != 2) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = _t3; +// CHECK-NEXT: double _r_d2 = _d_res; +// CHECK-NEXT: * _d_j += _r_d2 * j; +// CHECK-NEXT: * _d_j += j * _r_d2; +// CHECK-NEXT: } +// CHECK-NEXT: if (2 == _cond0) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = _t2; +// CHECK-NEXT: double _r_d1 = _d_res; +// CHECK-NEXT: * _d_i += _r_d1 * i; +// CHECK-NEXT: * _d_i += i * _r_d1; +// CHECK-NEXT: } +// CHECK-NEXT: if (1 == _cond0) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: case 1UL: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: res = _t0; +// CHECK-NEXT: double _r_d0 = _d_res; +// CHECK-NEXT: * _d_i += _r_d0 * j; +// CHECK-NEXT: * _d_j += i * _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: if (0 == _cond0) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + +#define TEST_2(F, x, y) \ + { \ + result[0] = result[1] = 0; \ + auto d_##F = clad::gradient(F); \ + d_##F.execute(x, y, result, result + 1); \ + printf("{%.2f, %.2f}\n", result[0], result[1]); \ + } + +int main() { + double result[2] = {}; + clad::array_ref result_ref(result, 2); + + TEST_2(fn1, 3, 5); // CHECK-EXEC: {156.00, 100.00} +} \ No newline at end of file