From c6f3adf8d80ddb32714487e5f39d73177acf9c53 Mon Sep 17 00:00:00 2001 From: parth Date: Sat, 3 Feb 2024 04:47:46 +0530 Subject: [PATCH] Support differentiation of switch condition --- lib/Differentiator/ReverseModeVisitor.cpp | 27 +++++----- test/Gradient/Switch.C | 63 +++++++++++++++++++++++ 2 files changed, 78 insertions(+), 12 deletions(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 0904dfe02..26baa95b6 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -9,6 +9,7 @@ #include "ConstantFolder.h" #include "TBRAnalyzer.h" +#include "clad/Differentiator/DerivativeBuilder.h" #include "clad/Differentiator/DiffPlanner.h" #include "clad/Differentiator/ErrorEstimator.h" #include "clad/Differentiator/ExternalRMVSource.h" @@ -17,7 +18,9 @@ #include "clang/AST/ASTContext.h" #include "clang/AST/Expr.h" +#include "clang/AST/Stmt.h" #include "clang/AST/TemplateBase.h" +#include "clang/Basic/TokenKinds.h" #include "clang/Sema/Lookup.h" #include "clang/Sema/Overload.h" #include "clang/Sema/Scope.h" @@ -3252,7 +3255,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, StmtDiff ReverseModeVisitor::VisitSwitchStmt(const SwitchStmt* SS) { // Scope and blocks for the compound statement that encloses the switch // statement in both the forward and the reverse pass. Block is required - // handling condition variable and switch-init statement. + // for handling condition variable and switch-init statement. beginScope(Scope::DeclScope); beginBlock(direction::forward); beginBlock(direction::reverse); @@ -3271,14 +3274,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, addToCurrentBlock(condVarDiff.getStmt(), direction::forward); addToCurrentBlock(condVarDiff.getStmt_dx(), direction::reverse); } - // Condition is only cloned, and not differentiated. - // Its because conditions generally contain non-differentiable constructs, - // but this behaviour will lead to incorrect results if the condition - // expression modifies any variable. - Expr* condClone = (SS->getCond() ? Clone(SS->getCond()) : nullptr); + StmtDiff condDiff = DifferentiateSingleStmt(SS->getCond()); + addToCurrentBlock(condDiff.getStmt(), direction::forward); + addToCurrentBlock(condDiff.getStmt_dx(), direction::reverse); Expr* condExpr = nullptr; - llvm::Optional condTape; + clad_compat::llvm_Optional condTape; if (isInsideLoop) { // If we are inside a loop, condition will be stored and used as follows: @@ -3289,10 +3290,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // reverse block: // switch (...) { ... } // clad::pop(...); - condTape.emplace(MakeCladTapeFor(condClone, "_cond")); + condTape.emplace(MakeCladTapeFor(condDiff.getExpr(), "_cond")); condExpr = condTape->Push; } else { - condExpr = GlobalStoreAndRef(condClone, "_cond").getExpr(); + condExpr = GlobalStoreAndRef(condDiff.getExpr(), "_cond").getExpr(); } auto activeBreakContHandler = PushBreakContStmtHandler( @@ -3361,6 +3362,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } } } + if (!breakCond) + breakCond = m_Sema.ActOnCXXBoolLiteral(noLoc, tok::kw_true).get(); SSData->defaultIfBreakExpr->setCond(breakCond); } @@ -3378,7 +3381,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, activeBreakContHandler->UpdateForwAndRevBlocks(bodyDiff); // Registers all the cases to the switch statement. - for (auto SC : SSData->cases) + for (auto *SC : SSData->cases) forwardSS->addSwitchCase(SC); forwardSS = @@ -3425,8 +3428,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, StmtDiff ReverseModeVisitor::VisitDefaultStmt(const DefaultStmt* DS) { beginBlock(direction::reverse); beginBlock(direction::forward); - auto SSData = GetActiveSwitchStmtInfo(); - auto newDefaultStmt = + auto *SSData = GetActiveSwitchStmtInfo(); + DefaultStmt *newDefaultStmt = new (m_Sema.getASTContext()) DefaultStmt(noLoc, noLoc, nullptr); Stmt* ifThen = m_Sema.ActOnBreakStmt(noLoc, getCurrentScope()).get(); Stmt* ifBreakExpr = clad_compat::IfStmt_Create( diff --git a/test/Gradient/Switch.C b/test/Gradient/Switch.C index c4b0286f2..9adc52072 100644 --- a/test/Gradient/Switch.C +++ b/test/Gradient/Switch.C @@ -3,6 +3,7 @@ //CHECK-NOT: {{.*error|warning|note:.*}} #include "clad/Differentiator/Differentiator.h" +#include "../TestUtils.h" double fn1(double i, double j) { double res = 0; @@ -513,6 +514,65 @@ double fn5(double i, double j) { // CHECK-NEXT: } // CHECK-NEXT: } +double fn6(double u, double v) { + int res = 0; + double temp = 0; + switch(res = u * v) { + default: + temp = 1; + } + return res; +} + +// CHECK: void fn6_grad(double u, double v, clad::array_ref _d_u, clad::array_ref _d_v) { +// CHECK-NEXT: int _d_res = 0; +// CHECK-NEXT: double _d_temp = 0; +// CHECK-NEXT: int _t0; +// CHECK-NEXT: int _cond0; +// CHECK-NEXT: double _t1; +// CHECK-NEXT: clad::tape _t2 = {}; +// CHECK-NEXT: int res = 0; +// CHECK-NEXT: double temp = 0; +// CHECK-NEXT: { +// CHECK-NEXT: _t0 = res; +// CHECK-NEXT: res = u * v; +// CHECK-NEXT: _cond0 = res = u * v; +// CHECK-NEXT: switch (_cond0) { +// CHECK-NEXT: { +// CHECK-NEXT: default: +// CHECK-NEXT: temp = 1; +// CHECK-NEXT: _t1 = temp; +// CHECK-NEXT: } +// CHECK-NEXT: clad::push(_t2, 1UL); +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: _d_res += 1; +// CHECK-NEXT: { +// CHECK-NEXT: switch (clad::pop(_t2)) { +// CHECK-NEXT: case 1UL: +// CHECK-NEXT: ; +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: temp = _t1; +// CHECK-NEXT: double _r_d1 = _d_temp; +// CHECK-NEXT: _d_temp -= _r_d1; +// CHECK-NEXT: } +// CHECK-NEXT: if (true) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: res = _t0; +// CHECK-NEXT: int _r_d0 = _d_res; +// CHECK-NEXT: _d_res -= _r_d0; +// CHECK-NEXT: * _d_u += _r_d0 * v; +// CHECK-NEXT: * _d_v += u * _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + #define TEST_2(F, x, y) \ { \ result[0] = result[1] = 0; \ @@ -530,4 +590,7 @@ int main() { TEST_2(fn3, 3, 5); // CHECK-EXEC: {162.00, 90.00} TEST_2(fn4, 3, 5); // CHECK-EXEC: {10.00, 6.00} TEST_2(fn5, 3, 5); // CHECK-EXEC: {5.00, 3.00} + + INIT_GRADIENT(fn6); + TEST_GRADIENT(fn6, 2, 3, 5, &result[0], &result[1]); // CHECK-EXEC: {5.00, 3.00} } \ No newline at end of file