diff --git a/integration_tests/symbolics_09.py b/integration_tests/symbolics_09.py index 1b6181c0fa..18567769a4 100644 --- a/integration_tests/symbolics_09.py +++ b/integration_tests/symbolics_09.py @@ -1,4 +1,4 @@ -from sympy import Symbol, pi, S +from sympy import Symbol, pi, sin, cos from lpython import S, i32 def addInteger(x: S, y: S, z: S, i: i32): @@ -9,7 +9,11 @@ def call_addInteger(): a: S = Symbol("x") b: S = Symbol("y") c: S = pi - addInteger(a, b, c, 2) + d: S = sin(a) + e: S = cos(b) + addInteger(c, d, e, 2) + addInteger(c, sin(a), cos(b), 2) + addInteger(pi, sin(Symbol("x")), cos(Symbol("y")), 2) def main0(): call_addInteger() diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 6a49461485..7531c1ffeb 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -695,6 +695,56 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorparent; + Vec call_args; + call_args.reserve(al, 1); + + for (size_t i=0; i(*val) && ASR::is_a(*ASRUtils::expr_type(val))) { + ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast(val); + ASR::ttype_t *type = ASRUtils::TYPE(ASR::make_SymbolicExpression_t(al, x.base.base.loc)); + std::string symengine_var = symengine_stack.push(); + ASR::symbol_t *arg = ASR::down_cast(ASR::make_Variable_t( + al, x.base.base.loc, current_scope, s2c(al, symengine_var), nullptr, 0, ASR::intentType::Local, + nullptr, nullptr, ASR::storage_typeType::Default, type, nullptr, + ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false)); + current_scope->add_symbol(s2c(al, symengine_var), arg); + for (auto &item : current_scope->get_scope()) { + if (ASR::is_a(*item.second)) { + ASR::Variable_t *s = ASR::down_cast(item.second); + this->visit_Variable(*s); + } + } + + ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg)); + process_intrinsic_function(al, x.base.base.loc, intrinsic_func, module_scope, target); + + ASR::call_arg_t call_arg; + call_arg.loc = x.base.base.loc; + call_arg.m_value = target; + call_args.push_back(al, call_arg); + } else if (ASR::is_a(*val)) { + ASR::Cast_t* cast_t = ASR::down_cast(val); + if(cast_t->m_kind != ASR::cast_kindType::IntegerToSymbolicExpression) return; + this->visit_Cast(*cast_t); + ASR::symbol_t *var_sym = current_scope->get_symbol(symengine_stack.pop()); + ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + + ASR::call_arg_t call_arg; + call_arg.loc = x.base.base.loc; + call_arg.m_value = target; + call_args.push_back(al, call_arg); + } else { + call_args.push_back(al, x.m_args[i]); + } + } + ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, x.m_name, + x.m_name, call_args.p, call_args.n, nullptr)); + pass_result.push_back(al, stmt); + } + void visit_Print(const ASR::Print_t &x) { std::vector print_tmp; SymbolTable* module_scope = current_scope->parent; @@ -739,6 +789,25 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor call_args; + call_args.reserve(al, 1); + ASR::call_arg_t call_arg; + call_arg.loc = x.base.base.loc; + call_arg.m_value = target; + call_args.push_back(al, call_arg); + ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc, + basic_str_sym, basic_str_sym, call_args.p, call_args.n, + ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), nullptr, nullptr)); + print_tmp.push_back(function_call); + } else if (ASR::is_a(*val)) { + ASR::Cast_t* cast_t = ASR::down_cast(val); + if(cast_t->m_kind != ASR::cast_kindType::IntegerToSymbolicExpression) return; + this->visit_Cast(*cast_t); + ASR::symbol_t *var_sym = current_scope->get_symbol(symengine_stack.pop()); + ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym)); + // Now create the FunctionCall node for basic_str ASR::symbol_t* basic_str_sym = declare_basic_str_function(al, x.base.base.loc, module_scope); Vec call_args;