Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Asr pass refactor #2795

Merged
merged 12 commits into from
Aug 16, 2024
34 changes: 34 additions & 0 deletions src/libasr/pass/pass_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1247,6 +1247,40 @@ namespace LCompilers {
}
}
}
ASR::symbol_t* get_struct_member(Allocator& al, ASR::symbol_t* struct_type_sym, std::string &call_name,
const Location &loc, SymbolTable* current_scope) {
ASR::Struct_t* struct_type = ASR::down_cast<ASR::Struct_t>(struct_type_sym);
std::string struct_var_name = struct_type->m_name;
std::string struct_member_name = call_name;
ASR::symbol_t* struct_member = struct_type->m_symtab->resolve_symbol(struct_member_name);
ASR::symbol_t* struct_mem_asr_owner = ASRUtils::get_asr_owner(struct_member);
if( !struct_member || !struct_mem_asr_owner ||
!ASR::is_a<ASR::Struct_t>(*struct_mem_asr_owner) ) {
throw LCompilersException(struct_member_name + " not present in " +
struct_var_name + " dataclass");
}
std::string import_name = struct_var_name + "_" + struct_member_name;
ASR::symbol_t* import_struct_member = current_scope->resolve_symbol(import_name);
bool import_from_struct = true;
if( import_struct_member ) {
if( ASR::is_a<ASR::ExternalSymbol_t>(*import_struct_member) ) {
ASR::ExternalSymbol_t* ext_sym = ASR::down_cast<ASR::ExternalSymbol_t>(import_struct_member);
if( ext_sym->m_external == struct_member &&
std::string(ext_sym->m_module_name) == struct_var_name ) {
import_from_struct = false;
}
}
}
if( import_from_struct ) {
import_name = current_scope->get_unique_name(import_name, false);
import_struct_member = ASR::down_cast<ASR::symbol_t>(ASR::make_ExternalSymbol_t(al,
loc, current_scope, s2c(al, import_name),
struct_member, s2c(al, struct_var_name), nullptr, 0,
s2c(al, struct_member_name), ASR::accessType::Public));
current_scope->add_symbol(import_name, import_struct_member);
}
return import_struct_member;
}

} // namespace PassUtils

Expand Down
52 changes: 41 additions & 11 deletions src/libasr/pass/pass_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,9 @@ namespace LCompilers {
*/
};

ASR::symbol_t* get_struct_member(Allocator& al, ASR::symbol_t* struct_type_sym, std::string &call_name,
const Location &loc, SymbolTable* current_scope);

namespace ReplacerUtils {
template <typename T>
void replace_StructConstructor(ASR::StructConstructor_t* x,
Expand All @@ -578,6 +581,33 @@ namespace LCompilers {
bool perform_cast=false,
ASR::cast_kindType cast_kind=ASR::cast_kindType::IntegerToInteger,
ASR::ttype_t* casted_type=nullptr) {
if ( ASR::is_a<ASR::Struct_t>(*(x->m_dt_sym)) ) {
ASR::Struct_t* st = ASR::down_cast<ASR::Struct_t>(x->m_dt_sym);
if ( st->n_member_functions > 0 ) {
remove_original_statement = true;
if ( !ASR::is_a<ASR::Var_t>(*(replacer->result_var)) ) {
throw LCompilersException("Expected a var here");
}
ASR::Var_t* target = ASR::down_cast<ASR::Var_t>(replacer->result_var);
ASR::call_arg_t first_arg;
first_arg.loc = x->base.base.loc; first_arg.m_value = replacer->result_var;
Vec<ASR::call_arg_t> new_args; new_args.reserve(replacer->al,x->n_args+1);
new_args.push_back(replacer->al, first_arg);
for( size_t i = 0; i < x->n_args; i++ ) {
new_args.push_back(replacer->al, x->m_args[i]);
}
ASR::StructType_t* type = ASR::down_cast<ASR::StructType_t>(
(ASR::down_cast<ASR::Variable_t>(target->m_v))->m_type);
std::string call_name = "__init__";
ASR::symbol_t* call_sym = get_struct_member(replacer->al,type->m_derived_type, call_name,
x->base.base.loc, replacer->current_scope);
result_vec->push_back(replacer->al, ASRUtils::STMT(
ASRUtils::make_SubroutineCall_t_util(replacer->al,
x->base.base.loc, call_sym, nullptr, new_args.p, new_args.size(),
nullptr, nullptr, false, false)));
return;
}
}
if( x->n_args == 0 ) {
if( !inside_symtab ) {
remove_original_statement = true;
Expand All @@ -598,22 +628,22 @@ namespace LCompilers {
}

std::deque<ASR::symbol_t*> constructor_arg_syms;
ASR::StructType_t* dt_der = ASR::down_cast<ASR::StructType_t>(x->m_type);
ASR::Struct_t* dt_dertype = ASR::down_cast<ASR::Struct_t>(
ASRUtils::symbol_get_past_external(dt_der->m_derived_type));
while( dt_dertype ) {
for( int i = (int) dt_dertype->n_members - 1; i >= 0; i-- ) {
ASR::StructType_t* dt_dertype = ASR::down_cast<ASR::StructType_t>(x->m_type);
ASR::Struct_t* dt_der = ASR::down_cast<ASR::Struct_t>(
ASRUtils::symbol_get_past_external(dt_dertype->m_derived_type));
while( dt_der ) {
for( int i = (int) dt_der->n_members - 1; i >= 0; i-- ) {
constructor_arg_syms.push_front(
dt_dertype->m_symtab->get_symbol(
dt_dertype->m_members[i]));
dt_der->m_symtab->get_symbol(
dt_der->m_members[i]));
}
if( dt_dertype->m_parent != nullptr ) {
if( dt_der->m_parent != nullptr ) {
ASR::symbol_t* dt_der_sym = ASRUtils::symbol_get_past_external(
dt_dertype->m_parent);
dt_der->m_parent);
LCOMPILERS_ASSERT(ASR::is_a<ASR::Struct_t>(*dt_der_sym));
dt_dertype = ASR::down_cast<ASR::Struct_t>(dt_der_sym);
dt_der = ASR::down_cast<ASR::Struct_t>(dt_der_sym);
} else {
dt_dertype = nullptr;
dt_der = nullptr;
}
}
LCOMPILERS_ASSERT(constructor_arg_syms.size() == x->n_args);
Expand Down
35 changes: 4 additions & 31 deletions src/lpython/semantics/python_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1282,16 +1282,13 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
args, st, loc);
}
if ( st->n_member_functions > 0 ) {
// Empty struct constructor
// Initializers handled in init proc call
Vec<ASR::call_arg_t>empty_args;
empty_args.reserve(al, 1);
for (size_t i = 0; i < st->n_members; i++) {
empty_args.push_back(al, st->m_initializers[i]);
if ( n_kwargs>0 ) {
throw SemanticError("Keyword args are not supported", loc);
}
ASR::ttype_t* der_type = ASRUtils::TYPE(ASRUtils::make_StructType_t_util(al, loc, stemp));
return ASR::make_StructConstructor_t(al, loc, stemp, empty_args.p,
empty_args.size(), der_type, nullptr);
return ASR::make_StructConstructor_t(al, loc, stemp, args.p,
args.size(), der_type, nullptr);
}

if ( args.size() > 0 && args.size() > st->n_members ) {
Expand Down Expand Up @@ -5316,17 +5313,6 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
if ( call->n_keywords>0 ) {
throw SemanticError("Kwargs not implemented yet", x.base.base.loc);
}
Vec<ASR::call_arg_t> args;
args.reserve(al, call->n_args + 1);
ASR::call_arg_t self_arg;
self_arg.loc = x.base.base.loc;
self_arg.m_value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, sym));
args.push_back(al, self_arg);
visit_expr_list(call->m_args, call->n_args, args);
ASR::symbol_t* der = ASR::down_cast<ASR::StructType_t>((var->m_type))->m_derived_type;
std::string call_name = "__init__";
ASR::symbol_t* call_sym = get_struct_member(der, call_name, x.base.base.loc);
tmp = make_call_helper(al, call_sym, current_scope, args, call_name, x.base.base.loc);
}
}
}
Expand Down Expand Up @@ -5611,23 +5597,10 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
overloaded));
if ( target->type == ASR::exprType::Var &&
tmp_value->type == ASR::exprType::StructConstructor ) {
Vec<ASR::call_arg_t> new_args; new_args.reserve(al, 1);
ASR::call_arg_t self_arg;
self_arg.loc = x.base.base.loc;
ASR::symbol_t* st = ASR::down_cast<ASR::Var_t>(target)->m_v;
self_arg.m_value = target;
new_args.push_back(al,self_arg);
AST::Call_t* call = AST::down_cast<AST::Call_t>(x.m_value);
if ( call->n_keywords>0 ) {
throw SemanticError("Kwargs not implemented yet", x.base.base.loc);
}
visit_expr_list(call->m_args, call->n_args, new_args);
ASR::symbol_t* der = ASR::down_cast<ASR::StructType_t>(
ASR::down_cast<ASR::Variable_t>(st)->m_type)->m_derived_type;
std::string call_name = "__init__";
ASR::symbol_t* call_sym = get_struct_member(der, call_name, x.base.base.loc);
tmp_vec.push_back(make_call_helper(al, call_sym,
current_scope, new_args, call_name, x.base.base.loc));
}
}
// to make sure that we add only those statements in tmp_vec
Expand Down
Loading