Skip to content

Commit

Permalink
Enabled accessing base class members from derived class
Browse files Browse the repository at this point in the history
  • Loading branch information
tanay-man committed Aug 12, 2024
1 parent 81a2905 commit a288818
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 17 deletions.
1 change: 1 addition & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -839,6 +839,7 @@ RUN(NAME class_01 LABELS cpython llvm llvm_jit)
RUN(NAME class_02 LABELS cpython llvm llvm_jit)
RUN(NAME class_03 LABELS cpython llvm llvm_jit)
RUN(NAME class_04 LABELS cpython llvm llvm_jit)
RUN(NAME class_05 LABELS cpython llvm llvm_jit)

# callback_04 is to test emulation. So just run with cpython
RUN(NAME callback_04 IMPORT_PATH .. LABELS cpython)
Expand Down
40 changes: 27 additions & 13 deletions integration_tests/class_05.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,32 @@
from lpython import i32

class Base:
def __init__(self:"Base"):
self.x_A : i32 = 10
class Animal:
def __init__(self:"Animal"):
self.species: str = "Generic Animal"
self.age: i32 = 0

class Dog(Animal):
def __init__(self:"Dog", name:str, age:i32):
# super().__init__()
self.species: str = "Dog"
self.name: str = name
self.age: i32 = age

class Cat(Animal):
def __init__(self:"Cat", name: str, age: i32):
# super().__init__()
self.species: str = "Cat"
self.name:str = name
self.age: i32 = age

class Derived(Base):
def __init__(self:"Derived") :
super().__init__()
self.y_B : i32 = 6
def get_x_A(self:"Derived"):
print(self.x_A)
def main():
d : Derived = Derived()
print(d.x_A)
print(d.y_B)
d.get_x_A()
dog: Dog = Dog("Buddy", 5)
cat: Cat = Cat("Whiskers", 3)
op1: str = str(dog.name+" is a "+str(dog.age)+"-year-old "+dog.species+".")
print(op1)
assert op1 == "Buddy is a 5-year-old Dog."
op2: str = str(cat.name+ " is a "+ str(cat.age)+ "-year-old "+ cat.species+ ".")
print(op2)
assert op2 == "Whiskers is a 3-year-old Cat."

main()
32 changes: 28 additions & 4 deletions src/lpython/semantics/python_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3301,10 +3301,21 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
current_scope->add_symbol(x_m_name, class_type);
}
} else {
if( x.n_bases > 0 ) {
throw SemanticError("Inheritance in classes isn't supported yet.",
ASR::symbol_t* parent = nullptr;
if( x.n_bases > 1 ) {
throw SemanticError("Multiple inheritance in classes isn't supported yet.",
x.base.base.loc);
}
else if (x.n_bases == 1) {
AST::Name_t* n = nullptr;
if ( AST::is_a<AST::Name_t>(*x.m_bases[0]) ) {
n = AST::down_cast<AST::Name_t>(x.m_bases[0]);
} else {
throw SemanticError("Expected a Name here",x.base.base.loc);
}
parent = current_scope->resolve_symbol(n->m_id);
LCOMPILERS_ASSERT(ASR::is_a<ASR::Struct_t>(*parent));
}
SymbolTable *parent_scope = current_scope;
if( ASR::symbol_t* sym = current_scope->resolve_symbol(x_m_name) ) {
LCOMPILERS_ASSERT(ASR::is_a<ASR::Struct_t>(*sym));
Expand All @@ -3324,6 +3335,15 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
}
} else {
current_scope = al.make_new<SymbolTable>(parent_scope);
// if ( parent ) {
// ASR::Struct_t* base_st = ASR::down_cast<ASR::Struct_t>(parent);
// SymbolTable* base_scope = base_st->m_symtab;
// for (auto i : base_scope->scope) {
// std::string name = i.first;
// ASR::symbol_t* sym = i.second;
// current_scope->add_symbol(name,sym);
// }
// }
Vec<char*> member_names;
Vec<char*> member_fn_names;
Vec<ASR::call_arg_t> member_init;
Expand All @@ -3344,7 +3364,7 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
member_names.p, member_names.size(), member_fn_names.p,
member_fn_names.size(), class_abi, ASR::accessType::Public,
false, false, member_init.p, member_init.size(),
nullptr, nullptr));
nullptr, parent));
parent_scope->add_symbol(x.m_name, class_sym);
visit_ClassMembers(x, member_names, member_fn_names,
struct_dependencies, member_init, false, class_abi, true);
Expand Down Expand Up @@ -6239,10 +6259,14 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
for( size_t i = 0; i < der_type->n_members && !member_found; i++ ) {
member_found = std::string(der_type->m_members[i]) == member_name;
}
if( !member_found ) {
if( !member_found && !der_type->m_parent ) {
throw SemanticError("No member " + member_name +
" found in " + std::string(der_type->m_name),
loc);
} else if ( !member_found && der_type->m_parent ) {
ASR::ttype_t* parent_type = ASRUtils::TYPE(ASRUtils::make_StructType_t_util(al, loc,der_type->m_parent));
visit_AttributeUtil(parent_type,attr_char,t,loc);
return;
}
ASR::expr_t *val = ASR::down_cast<ASR::expr_t>(ASR::make_Var_t(al, loc, t));
ASR::symbol_t* member_sym = der_type->m_symtab->resolve_symbol(member_name);
Expand Down

0 comments on commit a288818

Please sign in to comment.