From 61ffba8a49f0051487625b3446f4342a87bcf004 Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Sat, 23 Sep 2023 15:40:06 +0530 Subject: [PATCH 1/3] Added visit_SubroutineCall for the ASR symbolic pass --- integration_tests/symbolics_09.py | 8 +++- src/libasr/pass/replace_symbolic.cpp | 58 ++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 2 deletions(-) diff --git a/integration_tests/symbolics_09.py b/integration_tests/symbolics_09.py index 1b6181c0fa..2133f0b3d2 100644 --- a/integration_tests/symbolics_09.py +++ b/integration_tests/symbolics_09.py @@ -1,4 +1,4 @@ -from sympy import Symbol, pi, S +from sympy import Symbol, pi, sin, cos from lpython import S, i32 def addInteger(x: S, y: S, z: S, i: i32): @@ -9,7 +9,11 @@ def call_addInteger(): a: S = Symbol("x") b: S = Symbol("y") c: S = pi - addInteger(a, b, c, 2) + d: S = sin(a) + e: S = cos(b) + addInteger(c, d, e, 2) + addInteger(c, sin(a), cos(b), 2) + addInteger(c, sin(Symbol("x")), cos(Symbol("y")), 2) def main0(): call_addInteger() diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 6a49461485..1a0df58703 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -695,6 +695,45 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorparent; + Vec call_args; + call_args.reserve(al, 1); + + for (size_t i=0; i(*val) && ASR::is_a(*ASRUtils::expr_type(val))) { + ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast(val); + ASR::ttype_t *type = ASRUtils::TYPE(ASR::make_SymbolicExpression_t(al, x.base.base.loc)); + std::string symengine_var = symengine_stack.push(); + ASR::symbol_t *arg = ASR::down_cast(ASR::make_Variable_t( + al, x.base.base.loc, current_scope, s2c(al, symengine_var), nullptr, 0, ASR::intentType::Local, + nullptr, nullptr, ASR::storage_typeType::Default, type, nullptr, + ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false)); + current_scope->add_symbol(s2c(al, symengine_var), arg); + for (auto &item : current_scope->get_scope()) { + if (ASR::is_a(*item.second)) { + ASR::Variable_t *s = ASR::down_cast(item.second); + this->visit_Variable(*s); + } + } + + ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg)); + process_intrinsic_function(al, x.base.base.loc, intrinsic_func, module_scope, target); + + ASR::call_arg_t call_arg; + call_arg.loc = x.base.base.loc; + call_arg.m_value = target; + call_args.push_back(al, call_arg); + } else { + call_args.push_back(al, x.m_args[i]); + } + } + ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, x.m_name, + x.m_name, call_args.p, call_args.n, nullptr)); + pass_result.push_back(al, stmt); + } + void visit_Print(const ASR::Print_t &x) { std::vector print_tmp; SymbolTable* module_scope = current_scope->parent; @@ -739,6 +778,25 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor call_args; + call_args.reserve(al, 1); + ASR::call_arg_t call_arg; + call_arg.loc = x.base.base.loc; + call_arg.m_value = target; + call_args.push_back(al, call_arg); + ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc, + basic_str_sym, basic_str_sym, call_args.p, call_args.n, + ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), nullptr, nullptr)); + print_tmp.push_back(function_call); + } else if (ASR::is_a(*val)) { + ASR::Cast_t* cast_t = ASR::down_cast(val); + ASR::symbol_t *var_sym = nullptr; + this->visit_Cast(*cast_t); + var_sym = current_scope->get_symbol(symengine_stack.pop()); + ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + // Now create the FunctionCall node for basic_str ASR::symbol_t* basic_str_sym = declare_basic_str_function(al, x.base.base.loc, module_scope); Vec call_args; From 46df3e52a4621877ca8c842219c5a0fff6c84d27 Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Sat, 23 Sep 2023 16:01:29 +0530 Subject: [PATCH 2/3] fixed failing tests --- src/libasr/pass/replace_symbolic.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 1a0df58703..4719a8ff68 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -792,6 +792,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*val)) { ASR::Cast_t* cast_t = ASR::down_cast(val); + if(cast_t->m_kind != ASR::cast_kindType::IntegerToSymbolicExpression) return; ASR::symbol_t *var_sym = nullptr; this->visit_Cast(*cast_t); var_sym = current_scope->get_symbol(symengine_stack.pop()); From 37cc6ebbdbb16e7107e7218578d815dffe1ddc41 Mon Sep 17 00:00:00 2001 From: anutosh491 Date: Sat, 23 Sep 2023 18:00:21 +0530 Subject: [PATCH 3/3] Added support for casting within visit_SubroutineCall --- integration_tests/symbolics_09.py | 2 +- src/libasr/pass/replace_symbolic.cpp | 14 ++++++++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/integration_tests/symbolics_09.py b/integration_tests/symbolics_09.py index 2133f0b3d2..18567769a4 100644 --- a/integration_tests/symbolics_09.py +++ b/integration_tests/symbolics_09.py @@ -13,7 +13,7 @@ def call_addInteger(): e: S = cos(b) addInteger(c, d, e, 2) addInteger(c, sin(a), cos(b), 2) - addInteger(c, sin(Symbol("x")), cos(Symbol("y")), 2) + addInteger(pi, sin(Symbol("x")), cos(Symbol("y")), 2) def main0(): call_addInteger() diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 4719a8ff68..7531c1ffeb 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -721,6 +721,17 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*val)) { + ASR::Cast_t* cast_t = ASR::down_cast(val); + if(cast_t->m_kind != ASR::cast_kindType::IntegerToSymbolicExpression) return; + this->visit_Cast(*cast_t); + ASR::symbol_t *var_sym = current_scope->get_symbol(symengine_stack.pop()); + ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + ASR::call_arg_t call_arg; call_arg.loc = x.base.base.loc; call_arg.m_value = target; @@ -793,9 +804,8 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*val)) { ASR::Cast_t* cast_t = ASR::down_cast(val); if(cast_t->m_kind != ASR::cast_kindType::IntegerToSymbolicExpression) return; - ASR::symbol_t *var_sym = nullptr; this->visit_Cast(*cast_t); - var_sym = current_scope->get_symbol(symengine_stack.pop()); + ASR::symbol_t *var_sym = current_scope->get_symbol(symengine_stack.pop()); ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); // Now create the FunctionCall node for basic_str