Skip to content

Commit

Permalink
Simplify try_replace_tagged_statement.
Browse files Browse the repository at this point in the history
  • Loading branch information
1uc committed Oct 10, 2024
1 parent b0d96ba commit 8962767
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 27 deletions.
24 changes: 5 additions & 19 deletions src/visitors/sympy_replace_solutions_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,7 @@ void SympyReplaceSolutionsVisitor::visit_statement_block(ast::StatementBlock& no

void SympyReplaceSolutionsVisitor::try_replace_tagged_statement(
const ast::Node& node,
std::shared_ptr<ast::Expression> get_lhs(const ast::Node& node),
std::shared_ptr<ast::Expression> get_rhs(const ast::Node& node)) {
std::shared_ptr<ast::Expression> get_lhs(const ast::Node& node)) {
interleaves_counter.new_equation(true);

const auto& statement = std::static_pointer_cast<ast::Statement>(
Expand All @@ -176,8 +175,7 @@ void SympyReplaceSolutionsVisitor::try_replace_tagged_statement(

switch (policy) {
case ReplacePolicy::VALUE: {
const auto dependencies = statement_dependencies(get_lhs(node), get_rhs(node));
const auto& key = dependencies.first;
const auto key = statement_dependencies_key(get_lhs(node));

if (solution_statements.is_var_assigned_here(key)) {
logger->debug("SympyReplaceSolutionsVisitor :: marking for replacement {}",
Expand Down Expand Up @@ -216,11 +214,7 @@ void SympyReplaceSolutionsVisitor::visit_diff_eq_expression(ast::DiffEqExpressio
return dynamic_cast<const ast::DiffEqExpression&>(node).get_expression()->get_lhs();
};

auto get_rhs = [](const ast::Node& node) -> std::shared_ptr<ast::Expression> {
return dynamic_cast<const ast::DiffEqExpression&>(node).get_expression()->get_rhs();
};

try_replace_tagged_statement(node, get_lhs, get_rhs);
try_replace_tagged_statement(node, get_lhs);
}

void SympyReplaceSolutionsVisitor::visit_lin_equation(ast::LinEquation& node) {
Expand All @@ -229,11 +223,7 @@ void SympyReplaceSolutionsVisitor::visit_lin_equation(ast::LinEquation& node) {
return dynamic_cast<const ast::LinEquation&>(node).get_lhs();
};

auto get_rhs = [](const ast::Node& node) -> std::shared_ptr<ast::Expression> {
return dynamic_cast<const ast::LinEquation&>(node).get_rhs();
};

try_replace_tagged_statement(node, get_lhs, get_rhs);
try_replace_tagged_statement(node, get_lhs);
}


Expand All @@ -243,11 +233,7 @@ void SympyReplaceSolutionsVisitor::visit_non_lin_equation(ast::NonLinEquation& n
return dynamic_cast<const ast::NonLinEquation&>(node).get_lhs();
};

auto get_rhs = [](const ast::Node& node) -> std::shared_ptr<ast::Expression> {
return dynamic_cast<const ast::NonLinEquation&>(node).get_rhs();
};

try_replace_tagged_statement(node, get_lhs, get_rhs);
try_replace_tagged_statement(node, get_lhs);
}


Expand Down
4 changes: 1 addition & 3 deletions src/visitors/sympy_replace_solutions_visitor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -249,12 +249,10 @@ class SympyReplaceSolutionsVisitor: public AstVisitor {
*
* \param node it can be Diff_Eq_Expression/LinEquation/NonLinEquation
* \param get_lhs method with witch we may get the lhs (in case we need it)
* \param get_rhs method with witch we may get the rhs (in case we need it)
*/
void try_replace_tagged_statement(
const ast::Node& node,
std::shared_ptr<ast::Expression> get_lhs(const ast::Node& node),
std::shared_ptr<ast::Expression> get_rhs(const ast::Node& node));
std::shared_ptr<ast::Expression> get_lhs(const ast::Node& node));

/**
* \struct InterleavesCounter
Expand Down
15 changes: 10 additions & 5 deletions src/visitors/visitor_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,19 +250,24 @@ std::string to_json(const ast::Ast& node, bool compact, bool expand, bool add_nm
return stream.str();
}

std::string statement_dependencies_key(const std::shared_ptr<ast::Expression>& lhs) {
if (!lhs->is_var_name()) {
return "";
}

const auto& lhs_var_name = std::dynamic_pointer_cast<ast::VarName>(lhs);
return get_full_var_name(*lhs_var_name);
}

std::pair<std::string, std::unordered_set<std::string>> statement_dependencies(
const std::shared_ptr<ast::Expression>& lhs,
const std::shared_ptr<ast::Expression>& rhs) {
std::string key;
std::string key = statement_dependencies_key(lhs);
std::unordered_set<std::string> out;

if (!lhs->is_var_name()) {
return {key, out};
}

const auto& lhs_var_name = std::dynamic_pointer_cast<ast::VarName>(lhs);
key = get_full_var_name(*lhs_var_name);

visitor::AstLookupVisitor lookup_visitor;
lookup_visitor.lookup(*rhs, ast::AstNodeType::VAR_NAME);
auto rhs_nodes = lookup_visitor.get_nodes();
Expand Down
3 changes: 3 additions & 0 deletions src/visitors/visitor_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ std::string to_json(const ast::Ast& node,
bool expand = false,
bool add_nmodl = false);

/// The `result.first` of `statement_dependencies`.
std::string statement_dependencies_key(const std::shared_ptr<ast::Expression>& lhs);

/// If \p lhs and \p rhs combined represent an assignment (we assume to have an "=" in between them)
/// we extract the variables on which the assigned variable depends on. We provide the input with
/// lhs and rhs because there are a few nodes that have this similar structure but slightly
Expand Down

0 comments on commit 8962767

Please sign in to comment.