Skip to content

Commit

Permalink
Added support for multi-level attribute for function call (#2794)
Browse files Browse the repository at this point in the history
  • Loading branch information
tanay-man authored Jul 31, 2024
1 parent 42f385f commit 9374feb
Show file tree
Hide file tree
Showing 10 changed files with 111 additions and 19 deletions.
1 change: 1 addition & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -838,6 +838,7 @@ RUN(NAME c_mangling LABELS cpython llvm llvm_jit c)
RUN(NAME class_01 LABELS cpython llvm llvm_jit)
RUN(NAME class_02 LABELS cpython llvm llvm_jit)
RUN(NAME class_03 LABELS cpython llvm llvm_jit)
RUN(NAME class_04 LABELS cpython llvm llvm_jit)

# callback_04 is to test emulation. So just run with cpython
RUN(NAME callback_04 IMPORT_PATH .. LABELS cpython)
Expand Down
46 changes: 46 additions & 0 deletions integration_tests/class_04.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from lpython import i32
class Person:
def __init__(self:"Person", first:str, last:str, birthyear:i32, sgender:str):
self.first:str = first
self.last:str = last
self.birthyear:i32 = birthyear
self.sgender:str = sgender

def describe(self:"Person"):
print("first: " + self.first)
print("last: " + self.last)
print("birthyear: " + str(self.birthyear))
print("sgender: " + self.sgender)

class Employee:
def __init__(self:"Employee", person:Person, hire_date:i32, department:str):
self.person:Person = person
self.hire_date:i32 = hire_date
self.department:str = department

def describe(self:"Employee"):
self.person.describe()
print("hire_date: " + str(self.hire_date))
print("department: " + self.department)

def main():
jack:Person = Person("Jack", "Smith", 1984, "M")
jill_p:Person = Person("Jill", "Smith", 1984, "F")
jill:Employee = Employee(jill_p, 2003, "sales")

jack.describe()
assert jack.first == "Jack"
assert jack.last == "Smith"
assert jack.birthyear == 1984
assert jack.sgender == "M"

jill.describe()
assert jill.person.first == "Jill"
assert jill.person.last == "Smith"
assert jill.person.birthyear == 1984
assert jill.person.sgender == "F"
assert jill.department == "sales"
assert jill.hire_date == 2003

if __name__ == '__main__':
main()
2 changes: 2 additions & 0 deletions src/libasr/codegen/asr_to_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3169,7 +3169,9 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
visit_EnumType(*et);
} else if (is_a<ASR::Struct_t>(*item.second)) {
ASR::Struct_t *st = down_cast<ASR::Struct_t>(item.second);
mangle_prefix = mangle_prefix + "__class_" + st->m_name + "_";
instantiate_methods(*st);
mangle_prefix = "__module_" + std::string(x.m_name) + "_";
}
}
finish_module_init_function_prototype(x);
Expand Down
3 changes: 2 additions & 1 deletion src/libasr/codegen/llvm_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1960,7 +1960,8 @@ namespace LCompilers {
while( struct_type_t != nullptr ) {
for( auto item: struct_type_t->m_symtab->get_scope() ) {
if( ASR::is_a<ASR::ClassProcedure_t>(*item.second) ||
ASR::is_a<ASR::CustomOperator_t>(*item.second) ) {
ASR::is_a<ASR::CustomOperator_t>(*item.second) ||
ASR::is_a<ASR::Function_t>(*item.second) ) {
continue ;
}
std::string mem_name = item.first;
Expand Down
46 changes: 44 additions & 2 deletions src/lpython/semantics/python_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2956,7 +2956,8 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
}

void get_members_init (const AST::FunctionDef_t &x,
Vec<char*>& member_names, Vec<ASR::call_arg_t> &member_init){
Vec<char*>& member_names, Vec<ASR::call_arg_t> &member_init,
SetChar& struct_dependencies){
if(x.n_decorator_list > 0) {
throw SemanticError("Decorators for __init__ not implemented",
x.base.base.loc);
Expand Down Expand Up @@ -2997,6 +2998,22 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
c_arg.loc = var_sym->base.loc;
c_arg.m_value = nullptr;
member_init.push_back(al, c_arg);
ASR::ttype_t* var_type = ASRUtils::type_get_past_pointer(ASRUtils::symbol_type(var_sym));
char* aggregate_type_name = nullptr;
if( ASR::is_a<ASR::StructType_t>(*var_type) ) {
aggregate_type_name = ASRUtils::symbol_name(
ASR::down_cast<ASR::StructType_t>(var_type)->m_derived_type);
} else if( ASR::is_a<ASR::Enum_t>(*var_type) ) {
aggregate_type_name = ASRUtils::symbol_name(
ASR::down_cast<ASR::Enum_t>(var_type)->m_enum_type);
} else if( ASR::is_a<ASR::Union_t>(*var_type) ) {
aggregate_type_name = ASRUtils::symbol_name(
ASR::down_cast<ASR::Union_t>(var_type)->m_union_type);
}
if( aggregate_type_name &&
!current_scope->get_symbol(std::string(aggregate_type_name)) ) {
struct_dependencies.push_back(al, aggregate_type_name);
}
}
}

Expand Down Expand Up @@ -3027,7 +3044,7 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
*f = AST::down_cast<AST::FunctionDef_t>(x.m_body[i]);
std::string f_name = f->m_name;
if (f_name == "__init__") {
this->get_members_init(*f, member_names, member_init);
this->get_members_init(*f, member_names, member_init, struct_dependencies);
this->visit_stmt(*x.m_body[i]);
member_fn_names.push_back(al, f->m_name);
} else {
Expand Down Expand Up @@ -8244,6 +8261,31 @@ we will have to use something else.
}
handle_builtin_attribute(subscript_expr, at->m_attr, loc, eles);
return;
} else if ( AST::is_a<AST::Attribute_t>(*at->m_value) ) {
AST::Attribute_t* at_m_value = AST::down_cast<AST::Attribute_t>(at->m_value);
visit_Attribute(*at_m_value);
ASR::expr_t* e = ASRUtils::EXPR(tmp);
if ( !ASR::is_a<ASR::StructInstanceMember_t>(*e) ) {
throw SemanticError("Expected a class variable here", loc);
}
if ( !ASR::is_a<ASR::StructType_t>(*ASRUtils::expr_type(e)) ) {
throw SemanticError("Only Classes supported in nested attribute call", loc);
}
ASR::StructType_t* der = ASR::down_cast<ASR::StructType_t>(ASRUtils::expr_type(e));
ASR::symbol_t* der_sym = ASRUtils::symbol_get_past_external(der->m_derived_type);
std::string call_name = at->m_attr;

Vec<ASR::call_arg_t> new_args; new_args.reserve(al, args.n + 1);
ASR::call_arg_t self_arg;
self_arg.loc = args[0].loc;
self_arg.m_value = e;
new_args.push_back(al, self_arg);
for (size_t i=0; i<args.n; i++) {
new_args.push_back(al, args[i]);
}
ASR::symbol_t* st = get_struct_member(der_sym, call_name, loc);
tmp = make_call_helper(al, st, current_scope, new_args, call_name, loc);
return;
} else {
throw SemanticError("Only Name type and constant integers supported in Call", loc);
}
Expand Down
2 changes: 1 addition & 1 deletion tests/errors/class_04.py → tests/errors/class01.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ def __init__(self:"coord", x:i32, y:i32):
p2: coord = p1
p2.x = 2
print(p1.x)
print(p2.x)
print(p2.x)
13 changes: 13 additions & 0 deletions tests/reference/asr-class01-4134616.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"basename": "asr-class01-4134616",
"cmd": "lpython --show-asr --no-color {infile} -o {outfile}",
"infile": "tests/errors/class01.py",
"infile_hash": "abc039698c8285a3831089abdd0cd9711894af57eb142d2222b3583d",
"outfile": null,
"outfile_hash": null,
"stdout": null,
"stdout_hash": null,
"stderr": "asr-class01-4134616.stderr",
"stderr_hash": "4f104cca0ef2ac39634223611165efee9d107c94292a98071d863ed0",
"returncode": 2
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
semantic error: Only Class constructor is allowed in the object assignment for now
--> tests/errors/class_04.py:9:1
--> tests/errors/class01.py:9:1
|
9 | p2: coord = p1
| ^^^^^^^^^^^^^^
13 changes: 0 additions & 13 deletions tests/reference/asr-class_04-b89178d.json

This file was deleted.

2 changes: 1 addition & 1 deletion tests/tests.toml
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ pass = "class_constructor"
cumulative = true

[[test]]
filename = "errors/class_04.py"
filename = "errors/class01.py"
asr = true

[[test]]
Expand Down

0 comments on commit 9374feb

Please sign in to comment.