Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Asr pass refactor #2795

Merged
merged 12 commits into from
Aug 16, 2024
73 changes: 73 additions & 0 deletions src/libasr/pass/pass_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1247,7 +1247,80 @@ namespace LCompilers {
}
}
}
ASR::symbol_t* get_struct_member(Allocator& al, ASR::symbol_t* struct_type_sym, std::string &call_name,
const Location &loc, SymbolTable* current_scope) {
ASR::Struct_t* struct_type = ASR::down_cast<ASR::Struct_t>(struct_type_sym);
std::string struct_var_name = struct_type->m_name;
std::string struct_member_name = call_name;
ASR::symbol_t* struct_member = struct_type->m_symtab->resolve_symbol(struct_member_name);
ASR::symbol_t* struct_mem_asr_owner = ASRUtils::get_asr_owner(struct_member);
if( !struct_member || !struct_mem_asr_owner ||
!ASR::is_a<ASR::Struct_t>(*struct_mem_asr_owner) ) {
throw LCompilersException(struct_member_name + " not present in " +
struct_var_name + " dataclass");
}
std::string import_name = struct_var_name + "_" + struct_member_name;
ASR::symbol_t* import_struct_member = current_scope->resolve_symbol(import_name);
bool import_from_struct = true;
if( import_struct_member ) {
if( ASR::is_a<ASR::ExternalSymbol_t>(*import_struct_member) ) {
ASR::ExternalSymbol_t* ext_sym = ASR::down_cast<ASR::ExternalSymbol_t>(import_struct_member);
if( ext_sym->m_external == struct_member &&
std::string(ext_sym->m_module_name) == struct_var_name ) {
import_from_struct = false;
}
}
}
if( import_from_struct ) {
import_name = current_scope->get_unique_name(import_name, false);
import_struct_member = ASR::down_cast<ASR::symbol_t>(ASR::make_ExternalSymbol_t(al,
loc, current_scope, s2c(al, import_name),
struct_member, s2c(al, struct_var_name), nullptr, 0,
s2c(al, struct_member_name), ASR::accessType::Public));
current_scope->add_symbol(import_name, import_struct_member);
}
return import_struct_member;
}

ASR::asr_t* make_call_helper(Allocator &al, ASR::symbol_t* s, Vec<ASR::call_arg_t> args,
std::string call_name, const Location &loc) {

ASR::symbol_t *s_generic = nullptr, *stemp = s;
// Type map for generic functions
std::map<std::string, ASR::ttype_t*> subs;
std::map<std::string, ASR::symbol_t*> rt_subs;
// handling ExternalSymbol
s = ASRUtils::symbol_get_past_external(s);

if (ASR::is_a<ASR::Function_t>(*s)) {
ASR::Function_t *func = ASR::down_cast<ASR::Function_t>(s);
if (args.size() < func->n_args) {
std::string missed_args_names =" ";
size_t missed_args_count =0;
for (size_t def_arg = args.size(); def_arg < func->n_args; def_arg++){
ASR::Variable_t* var = ASRUtils::EXPR2VAR(func->m_args[def_arg]);
if(var->m_symbolic_value == nullptr) {
missed_args_names+= "'" + std::string(var->m_name) + "' and ";
missed_args_count++;
} else {
ASR::call_arg_t call_arg;
call_arg.m_value = var->m_symbolic_value;
call_arg.loc = (var->m_symbolic_value->base).loc;
args.push_back(al,call_arg);
}
}
if(missed_args_count > 0){
missed_args_names = missed_args_names.substr(0,missed_args_names.length() - 5);
LCompilersException("Number of arguments does not match in the function call");

}
}
return ASRUtils::make_SubroutineCall_t_util(al, loc, stemp,
s_generic, args.p, args.size(), nullptr, nullptr, false, false);
} else {
throw LCompilersException("Unsupported call type for " + call_name);
}
}
Thirumalai-Shaktivel marked this conversation as resolved.
Show resolved Hide resolved
} // namespace PassUtils

} // namespace LCompilers
54 changes: 43 additions & 11 deletions src/libasr/pass/pass_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,12 @@ namespace LCompilers {
*/
};

ASR::symbol_t* get_struct_member(Allocator& al, ASR::symbol_t* struct_type_sym, std::string &call_name,
const Location &loc, SymbolTable* current_scope);

ASR::asr_t* make_call_helper(Allocator &al, ASR::symbol_t* s, Vec<ASR::call_arg_t> args,
std::string call_name, const Location &loc);

Thirumalai-Shaktivel marked this conversation as resolved.
Show resolved Hide resolved
namespace ReplacerUtils {
template <typename T>
void replace_StructConstructor(ASR::StructConstructor_t* x,
Expand All @@ -578,6 +584,32 @@ namespace LCompilers {
bool perform_cast=false,
ASR::cast_kindType cast_kind=ASR::cast_kindType::IntegerToInteger,
ASR::ttype_t* casted_type=nullptr) {
if ( ASR::is_a<ASR::Struct_t>(*(x->m_dt_sym)) ) {
ASR::Struct_t* st = ASR::down_cast<ASR::Struct_t>(x->m_dt_sym);
if ( st->n_member_functions > 0 ) {
remove_original_statement = true;
if ( !ASR::is_a<ASR::Var_t>(*(replacer->result_var)) ) {
throw LCompilersException("Expected a var here");
}
ASR::Var_t* target = ASR::down_cast<ASR::Var_t>(replacer->result_var);
ASR::call_arg_t first_arg;
first_arg.loc = x->base.base.loc; first_arg.m_value = replacer->result_var;
Vec<ASR::call_arg_t> new_args; new_args.reserve(replacer->al,x->n_args+1);
new_args.push_back(replacer->al, first_arg);
for( size_t i = 0; i < x->n_args; i++ ) {
new_args.push_back(replacer->al, x->m_args[i]);
}
ASR::StructType_t* type = ASR::down_cast<ASR::StructType_t>(
(ASR::down_cast<ASR::Variable_t>(target->m_v))->m_type);
std::string call_name = "__init__";
ASR::symbol_t* call_sym = get_struct_member(replacer->al,type->m_derived_type, call_name,
x->base.base.loc, replacer->current_scope);
result_vec->push_back(replacer->al,
ASRUtils::STMT(make_call_helper(replacer->al,call_sym,
new_args,call_name,x->base.base.loc)));
Thirumalai-Shaktivel marked this conversation as resolved.
Show resolved Hide resolved
return;
}
}
if( x->n_args == 0 ) {
if( !inside_symtab ) {
remove_original_statement = true;
Expand All @@ -598,22 +630,22 @@ namespace LCompilers {
}

std::deque<ASR::symbol_t*> constructor_arg_syms;
ASR::StructType_t* dt_der = ASR::down_cast<ASR::StructType_t>(x->m_type);
ASR::Struct_t* dt_dertype = ASR::down_cast<ASR::Struct_t>(
ASRUtils::symbol_get_past_external(dt_der->m_derived_type));
while( dt_dertype ) {
for( int i = (int) dt_dertype->n_members - 1; i >= 0; i-- ) {
ASR::StructType_t* dt_dertype = ASR::down_cast<ASR::StructType_t>(x->m_type);
ASR::Struct_t* dt_der = ASR::down_cast<ASR::Struct_t>(
ASRUtils::symbol_get_past_external(dt_dertype->m_derived_type));
while( dt_der ) {
for( int i = (int) dt_der->n_members - 1; i >= 0; i-- ) {
constructor_arg_syms.push_front(
dt_dertype->m_symtab->get_symbol(
dt_dertype->m_members[i]));
dt_der->m_symtab->get_symbol(
dt_der->m_members[i]));
}
if( dt_dertype->m_parent != nullptr ) {
if( dt_der->m_parent != nullptr ) {
ASR::symbol_t* dt_der_sym = ASRUtils::symbol_get_past_external(
dt_dertype->m_parent);
dt_der->m_parent);
LCOMPILERS_ASSERT(ASR::is_a<ASR::Struct_t>(*dt_der_sym));
dt_dertype = ASR::down_cast<ASR::Struct_t>(dt_der_sym);
dt_der = ASR::down_cast<ASR::Struct_t>(dt_der_sym);
} else {
dt_dertype = nullptr;
dt_der = nullptr;
}
}
LCOMPILERS_ASSERT(constructor_arg_syms.size() == x->n_args);
Expand Down
35 changes: 4 additions & 31 deletions src/lpython/semantics/python_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1282,16 +1282,13 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
args, st, loc);
}
if ( st->n_member_functions > 0 ) {
// Empty struct constructor
// Initializers handled in init proc call
Vec<ASR::call_arg_t>empty_args;
empty_args.reserve(al, 1);
for (size_t i = 0; i < st->n_members; i++) {
empty_args.push_back(al, st->m_initializers[i]);
if ( n_kwargs>0 ) {
throw SemanticError("Keyword args are not supported", loc);
}
ASR::ttype_t* der_type = ASRUtils::TYPE(ASRUtils::make_StructType_t_util(al, loc, stemp));
return ASR::make_StructConstructor_t(al, loc, stemp, empty_args.p,
empty_args.size(), der_type, nullptr);
return ASR::make_StructConstructor_t(al, loc, stemp, args.p,
args.size(), der_type, nullptr);
}

if ( args.size() > 0 && args.size() > st->n_members ) {
Expand Down Expand Up @@ -5316,17 +5313,6 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
if ( call->n_keywords>0 ) {
throw SemanticError("Kwargs not implemented yet", x.base.base.loc);
}
Vec<ASR::call_arg_t> args;
args.reserve(al, call->n_args + 1);
ASR::call_arg_t self_arg;
self_arg.loc = x.base.base.loc;
self_arg.m_value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, sym));
args.push_back(al, self_arg);
visit_expr_list(call->m_args, call->n_args, args);
ASR::symbol_t* der = ASR::down_cast<ASR::StructType_t>((var->m_type))->m_derived_type;
std::string call_name = "__init__";
ASR::symbol_t* call_sym = get_struct_member(der, call_name, x.base.base.loc);
tmp = make_call_helper(al, call_sym, current_scope, args, call_name, x.base.base.loc);
}
}
}
Expand Down Expand Up @@ -5611,23 +5597,10 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
overloaded));
if ( target->type == ASR::exprType::Var &&
tmp_value->type == ASR::exprType::StructConstructor ) {
Vec<ASR::call_arg_t> new_args; new_args.reserve(al, 1);
ASR::call_arg_t self_arg;
self_arg.loc = x.base.base.loc;
ASR::symbol_t* st = ASR::down_cast<ASR::Var_t>(target)->m_v;
self_arg.m_value = target;
new_args.push_back(al,self_arg);
AST::Call_t* call = AST::down_cast<AST::Call_t>(x.m_value);
if ( call->n_keywords>0 ) {
throw SemanticError("Kwargs not implemented yet", x.base.base.loc);
}
visit_expr_list(call->m_args, call->n_args, new_args);
ASR::symbol_t* der = ASR::down_cast<ASR::StructType_t>(
ASR::down_cast<ASR::Variable_t>(st)->m_type)->m_derived_type;
std::string call_name = "__init__";
ASR::symbol_t* call_sym = get_struct_member(der, call_name, x.base.base.loc);
tmp_vec.push_back(make_call_helper(al, call_sym,
current_scope, new_args, call_name, x.base.base.loc));
}
}
// to make sure that we add only those statements in tmp_vec
Expand Down
Loading