diff --git a/src/libasr/pass/pass_utils.cpp b/src/libasr/pass/pass_utils.cpp index 35b99f5c58..e15eb84477 100644 --- a/src/libasr/pass/pass_utils.cpp +++ b/src/libasr/pass/pass_utils.cpp @@ -1247,6 +1247,40 @@ namespace LCompilers { } } } + ASR::symbol_t* get_struct_member(Allocator& al, ASR::symbol_t* struct_type_sym, std::string &call_name, + const Location &loc, SymbolTable* current_scope) { + ASR::Struct_t* struct_type = ASR::down_cast(struct_type_sym); + std::string struct_var_name = struct_type->m_name; + std::string struct_member_name = call_name; + ASR::symbol_t* struct_member = struct_type->m_symtab->resolve_symbol(struct_member_name); + ASR::symbol_t* struct_mem_asr_owner = ASRUtils::get_asr_owner(struct_member); + if( !struct_member || !struct_mem_asr_owner || + !ASR::is_a(*struct_mem_asr_owner) ) { + throw LCompilersException(struct_member_name + " not present in " + + struct_var_name + " dataclass"); + } + std::string import_name = struct_var_name + "_" + struct_member_name; + ASR::symbol_t* import_struct_member = current_scope->resolve_symbol(import_name); + bool import_from_struct = true; + if( import_struct_member ) { + if( ASR::is_a(*import_struct_member) ) { + ASR::ExternalSymbol_t* ext_sym = ASR::down_cast(import_struct_member); + if( ext_sym->m_external == struct_member && + std::string(ext_sym->m_module_name) == struct_var_name ) { + import_from_struct = false; + } + } + } + if( import_from_struct ) { + import_name = current_scope->get_unique_name(import_name, false); + import_struct_member = ASR::down_cast(ASR::make_ExternalSymbol_t(al, + loc, current_scope, s2c(al, import_name), + struct_member, s2c(al, struct_var_name), nullptr, 0, + s2c(al, struct_member_name), ASR::accessType::Public)); + current_scope->add_symbol(import_name, import_struct_member); + } + return import_struct_member; + } } // namespace PassUtils diff --git a/src/libasr/pass/pass_utils.h b/src/libasr/pass/pass_utils.h index e88563f72e..2025e78113 100644 --- a/src/libasr/pass/pass_utils.h +++ b/src/libasr/pass/pass_utils.h @@ -570,6 +570,9 @@ namespace LCompilers { */ }; + ASR::symbol_t* get_struct_member(Allocator& al, ASR::symbol_t* struct_type_sym, std::string &call_name, + const Location &loc, SymbolTable* current_scope); + namespace ReplacerUtils { template void replace_StructConstructor(ASR::StructConstructor_t* x, @@ -578,6 +581,33 @@ namespace LCompilers { bool perform_cast=false, ASR::cast_kindType cast_kind=ASR::cast_kindType::IntegerToInteger, ASR::ttype_t* casted_type=nullptr) { + if ( ASR::is_a(*(x->m_dt_sym)) ) { + ASR::Struct_t* st = ASR::down_cast(x->m_dt_sym); + if ( st->n_member_functions > 0 ) { + remove_original_statement = true; + if ( !ASR::is_a(*(replacer->result_var)) ) { + throw LCompilersException("Expected a var here"); + } + ASR::Var_t* target = ASR::down_cast(replacer->result_var); + ASR::call_arg_t first_arg; + first_arg.loc = x->base.base.loc; first_arg.m_value = replacer->result_var; + Vec new_args; new_args.reserve(replacer->al,x->n_args+1); + new_args.push_back(replacer->al, first_arg); + for( size_t i = 0; i < x->n_args; i++ ) { + new_args.push_back(replacer->al, x->m_args[i]); + } + ASR::StructType_t* type = ASR::down_cast( + (ASR::down_cast(target->m_v))->m_type); + std::string call_name = "__init__"; + ASR::symbol_t* call_sym = get_struct_member(replacer->al,type->m_derived_type, call_name, + x->base.base.loc, replacer->current_scope); + result_vec->push_back(replacer->al, ASRUtils::STMT( + ASRUtils::make_SubroutineCall_t_util(replacer->al, + x->base.base.loc, call_sym, nullptr, new_args.p, new_args.size(), + nullptr, nullptr, false, false))); + return; + } + } if( x->n_args == 0 ) { if( !inside_symtab ) { remove_original_statement = true; @@ -598,22 +628,22 @@ namespace LCompilers { } std::deque constructor_arg_syms; - ASR::StructType_t* dt_der = ASR::down_cast(x->m_type); - ASR::Struct_t* dt_dertype = ASR::down_cast( - ASRUtils::symbol_get_past_external(dt_der->m_derived_type)); - while( dt_dertype ) { - for( int i = (int) dt_dertype->n_members - 1; i >= 0; i-- ) { + ASR::StructType_t* dt_dertype = ASR::down_cast(x->m_type); + ASR::Struct_t* dt_der = ASR::down_cast( + ASRUtils::symbol_get_past_external(dt_dertype->m_derived_type)); + while( dt_der ) { + for( int i = (int) dt_der->n_members - 1; i >= 0; i-- ) { constructor_arg_syms.push_front( - dt_dertype->m_symtab->get_symbol( - dt_dertype->m_members[i])); + dt_der->m_symtab->get_symbol( + dt_der->m_members[i])); } - if( dt_dertype->m_parent != nullptr ) { + if( dt_der->m_parent != nullptr ) { ASR::symbol_t* dt_der_sym = ASRUtils::symbol_get_past_external( - dt_dertype->m_parent); + dt_der->m_parent); LCOMPILERS_ASSERT(ASR::is_a(*dt_der_sym)); - dt_dertype = ASR::down_cast(dt_der_sym); + dt_der = ASR::down_cast(dt_der_sym); } else { - dt_dertype = nullptr; + dt_der = nullptr; } } LCOMPILERS_ASSERT(constructor_arg_syms.size() == x->n_args); diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index 24daf36684..fad8029a00 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -1282,16 +1282,13 @@ class CommonVisitor : public AST::BaseVisitor { args, st, loc); } if ( st->n_member_functions > 0 ) { - // Empty struct constructor // Initializers handled in init proc call - Vecempty_args; - empty_args.reserve(al, 1); - for (size_t i = 0; i < st->n_members; i++) { - empty_args.push_back(al, st->m_initializers[i]); + if ( n_kwargs>0 ) { + throw SemanticError("Keyword args are not supported", loc); } ASR::ttype_t* der_type = ASRUtils::TYPE(ASRUtils::make_StructType_t_util(al, loc, stemp)); - return ASR::make_StructConstructor_t(al, loc, stemp, empty_args.p, - empty_args.size(), der_type, nullptr); + return ASR::make_StructConstructor_t(al, loc, stemp, args.p, + args.size(), der_type, nullptr); } if ( args.size() > 0 && args.size() > st->n_members ) { @@ -5316,17 +5313,6 @@ class BodyVisitor : public CommonVisitor { if ( call->n_keywords>0 ) { throw SemanticError("Kwargs not implemented yet", x.base.base.loc); } - Vec args; - args.reserve(al, call->n_args + 1); - ASR::call_arg_t self_arg; - self_arg.loc = x.base.base.loc; - self_arg.m_value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, sym)); - args.push_back(al, self_arg); - visit_expr_list(call->m_args, call->n_args, args); - ASR::symbol_t* der = ASR::down_cast((var->m_type))->m_derived_type; - std::string call_name = "__init__"; - ASR::symbol_t* call_sym = get_struct_member(der, call_name, x.base.base.loc); - tmp = make_call_helper(al, call_sym, current_scope, args, call_name, x.base.base.loc); } } } @@ -5611,23 +5597,10 @@ class BodyVisitor : public CommonVisitor { overloaded)); if ( target->type == ASR::exprType::Var && tmp_value->type == ASR::exprType::StructConstructor ) { - Vec new_args; new_args.reserve(al, 1); - ASR::call_arg_t self_arg; - self_arg.loc = x.base.base.loc; - ASR::symbol_t* st = ASR::down_cast(target)->m_v; - self_arg.m_value = target; - new_args.push_back(al,self_arg); AST::Call_t* call = AST::down_cast(x.m_value); if ( call->n_keywords>0 ) { throw SemanticError("Kwargs not implemented yet", x.base.base.loc); } - visit_expr_list(call->m_args, call->n_args, new_args); - ASR::symbol_t* der = ASR::down_cast( - ASR::down_cast(st)->m_type)->m_derived_type; - std::string call_name = "__init__"; - ASR::symbol_t* call_sym = get_struct_member(der, call_name, x.base.base.loc); - tmp_vec.push_back(make_call_helper(al, call_sym, - current_scope, new_args, call_name, x.base.base.loc)); } } // to make sure that we add only those statements in tmp_vec