From 896276745d5d7bd78508b802b01b3a4e8dc5bc1c Mon Sep 17 00:00:00 2001 From: Luc Grosheintz Date: Thu, 10 Oct 2024 10:54:41 +0200 Subject: [PATCH] Simplify `try_replace_tagged_statement`. --- .../sympy_replace_solutions_visitor.cpp | 24 ++++--------------- .../sympy_replace_solutions_visitor.hpp | 4 +--- src/visitors/visitor_utils.cpp | 15 ++++++++---- src/visitors/visitor_utils.hpp | 3 +++ 4 files changed, 19 insertions(+), 27 deletions(-) diff --git a/src/visitors/sympy_replace_solutions_visitor.cpp b/src/visitors/sympy_replace_solutions_visitor.cpp index fb0838e8a7..a30e8f00e2 100644 --- a/src/visitors/sympy_replace_solutions_visitor.cpp +++ b/src/visitors/sympy_replace_solutions_visitor.cpp @@ -161,8 +161,7 @@ void SympyReplaceSolutionsVisitor::visit_statement_block(ast::StatementBlock& no void SympyReplaceSolutionsVisitor::try_replace_tagged_statement( const ast::Node& node, - std::shared_ptr get_lhs(const ast::Node& node), - std::shared_ptr get_rhs(const ast::Node& node)) { + std::shared_ptr get_lhs(const ast::Node& node)) { interleaves_counter.new_equation(true); const auto& statement = std::static_pointer_cast( @@ -176,8 +175,7 @@ void SympyReplaceSolutionsVisitor::try_replace_tagged_statement( switch (policy) { case ReplacePolicy::VALUE: { - const auto dependencies = statement_dependencies(get_lhs(node), get_rhs(node)); - const auto& key = dependencies.first; + const auto key = statement_dependencies_key(get_lhs(node)); if (solution_statements.is_var_assigned_here(key)) { logger->debug("SympyReplaceSolutionsVisitor :: marking for replacement {}", @@ -216,11 +214,7 @@ void SympyReplaceSolutionsVisitor::visit_diff_eq_expression(ast::DiffEqExpressio return dynamic_cast(node).get_expression()->get_lhs(); }; - auto get_rhs = [](const ast::Node& node) -> std::shared_ptr { - return dynamic_cast(node).get_expression()->get_rhs(); - }; - - try_replace_tagged_statement(node, get_lhs, get_rhs); + try_replace_tagged_statement(node, get_lhs); } void SympyReplaceSolutionsVisitor::visit_lin_equation(ast::LinEquation& node) { @@ -229,11 +223,7 @@ void SympyReplaceSolutionsVisitor::visit_lin_equation(ast::LinEquation& node) { return dynamic_cast(node).get_lhs(); }; - auto get_rhs = [](const ast::Node& node) -> std::shared_ptr { - return dynamic_cast(node).get_rhs(); - }; - - try_replace_tagged_statement(node, get_lhs, get_rhs); + try_replace_tagged_statement(node, get_lhs); } @@ -243,11 +233,7 @@ void SympyReplaceSolutionsVisitor::visit_non_lin_equation(ast::NonLinEquation& n return dynamic_cast(node).get_lhs(); }; - auto get_rhs = [](const ast::Node& node) -> std::shared_ptr { - return dynamic_cast(node).get_rhs(); - }; - - try_replace_tagged_statement(node, get_lhs, get_rhs); + try_replace_tagged_statement(node, get_lhs); } diff --git a/src/visitors/sympy_replace_solutions_visitor.hpp b/src/visitors/sympy_replace_solutions_visitor.hpp index 42bc4da2d0..c055d57c05 100644 --- a/src/visitors/sympy_replace_solutions_visitor.hpp +++ b/src/visitors/sympy_replace_solutions_visitor.hpp @@ -249,12 +249,10 @@ class SympyReplaceSolutionsVisitor: public AstVisitor { * * \param node it can be Diff_Eq_Expression/LinEquation/NonLinEquation * \param get_lhs method with witch we may get the lhs (in case we need it) - * \param get_rhs method with witch we may get the rhs (in case we need it) */ void try_replace_tagged_statement( const ast::Node& node, - std::shared_ptr get_lhs(const ast::Node& node), - std::shared_ptr get_rhs(const ast::Node& node)); + std::shared_ptr get_lhs(const ast::Node& node)); /** * \struct InterleavesCounter diff --git a/src/visitors/visitor_utils.cpp b/src/visitors/visitor_utils.cpp index ca9e1a0b61..b279e20871 100644 --- a/src/visitors/visitor_utils.cpp +++ b/src/visitors/visitor_utils.cpp @@ -250,19 +250,24 @@ std::string to_json(const ast::Ast& node, bool compact, bool expand, bool add_nm return stream.str(); } +std::string statement_dependencies_key(const std::shared_ptr& lhs) { + if (!lhs->is_var_name()) { + return ""; + } + + const auto& lhs_var_name = std::dynamic_pointer_cast(lhs); + return get_full_var_name(*lhs_var_name); +} + std::pair> statement_dependencies( const std::shared_ptr& lhs, const std::shared_ptr& rhs) { - std::string key; + std::string key = statement_dependencies_key(lhs); std::unordered_set out; - if (!lhs->is_var_name()) { return {key, out}; } - const auto& lhs_var_name = std::dynamic_pointer_cast(lhs); - key = get_full_var_name(*lhs_var_name); - visitor::AstLookupVisitor lookup_visitor; lookup_visitor.lookup(*rhs, ast::AstNodeType::VAR_NAME); auto rhs_nodes = lookup_visitor.get_nodes(); diff --git a/src/visitors/visitor_utils.hpp b/src/visitors/visitor_utils.hpp index 97d8fa0d33..e1e7c99896 100644 --- a/src/visitors/visitor_utils.hpp +++ b/src/visitors/visitor_utils.hpp @@ -124,6 +124,9 @@ std::string to_json(const ast::Ast& node, bool expand = false, bool add_nmodl = false); +/// The `result.first` of `statement_dependencies`. +std::string statement_dependencies_key(const std::shared_ptr& lhs); + /// If \p lhs and \p rhs combined represent an assignment (we assume to have an "=" in between them) /// we extract the variables on which the assigned variable depends on. We provide the input with /// lhs and rhs because there are a few nodes that have this similar structure but slightly