diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index 52ccd6d074..929bf1b63a 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -839,6 +839,7 @@ RUN(NAME class_01 LABELS cpython llvm llvm_jit) RUN(NAME class_02 LABELS cpython llvm llvm_jit) RUN(NAME class_03 LABELS cpython llvm llvm_jit) RUN(NAME class_04 LABELS cpython llvm llvm_jit) +RUN(NAME class_05 LABELS cpython llvm llvm_jit) # callback_04 is to test emulation. So just run with cpython RUN(NAME callback_04 IMPORT_PATH .. LABELS cpython) diff --git a/integration_tests/class_05.py b/integration_tests/class_05.py index b0d92ff7ed..4bccf5fc63 100644 --- a/integration_tests/class_05.py +++ b/integration_tests/class_05.py @@ -1,18 +1,32 @@ from lpython import i32 -class Base: - def __init__(self:"Base"): - self.x_A : i32 = 10 +class Animal: + def __init__(self:"Animal"): + self.species: str = "Generic Animal" + self.age: i32 = 0 + +class Dog(Animal): + def __init__(self:"Dog", name:str, age:i32): + # super().__init__() + self.species: str = "Dog" + self.name: str = name + self.age: i32 = age + +class Cat(Animal): + def __init__(self:"Cat", name: str, age: i32): + # super().__init__() + self.species: str = "Cat" + self.name:str = name + self.age: i32 = age -class Derived(Base): - def __init__(self:"Derived") : - super().__init__() - self.y_B : i32 = 6 - def get_x_A(self:"Derived"): - print(self.x_A) def main(): - d : Derived = Derived() - print(d.x_A) - print(d.y_B) - d.get_x_A() + dog: Dog = Dog("Buddy", 5) + cat: Cat = Cat("Whiskers", 3) + op1: str = str(dog.name+" is a "+str(dog.age)+"-year-old "+dog.species+".") + print(op1) + assert op1 == "Buddy is a 5-year-old Dog." + op2: str = str(cat.name+ " is a "+ str(cat.age)+ "-year-old "+ cat.species+ ".") + print(op2) + assert op2 == "Whiskers is a 3-year-old Cat." + main() diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index fad8029a00..dc41debfa4 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -3301,10 +3301,21 @@ class CommonVisitor : public AST::BaseVisitor { current_scope->add_symbol(x_m_name, class_type); } } else { - if( x.n_bases > 0 ) { - throw SemanticError("Inheritance in classes isn't supported yet.", + ASR::symbol_t* parent = nullptr; + if( x.n_bases > 1 ) { + throw SemanticError("Multiple inheritance in classes isn't supported yet.", x.base.base.loc); } + else if (x.n_bases == 1) { + AST::Name_t* n = nullptr; + if ( AST::is_a(*x.m_bases[0]) ) { + n = AST::down_cast(x.m_bases[0]); + } else { + throw SemanticError("Expected a Name here",x.base.base.loc); + } + parent = current_scope->resolve_symbol(n->m_id); + LCOMPILERS_ASSERT(ASR::is_a(*parent)); + } SymbolTable *parent_scope = current_scope; if( ASR::symbol_t* sym = current_scope->resolve_symbol(x_m_name) ) { LCOMPILERS_ASSERT(ASR::is_a(*sym)); @@ -3324,6 +3335,15 @@ class CommonVisitor : public AST::BaseVisitor { } } else { current_scope = al.make_new(parent_scope); + // if ( parent ) { + // ASR::Struct_t* base_st = ASR::down_cast(parent); + // SymbolTable* base_scope = base_st->m_symtab; + // for (auto i : base_scope->scope) { + // std::string name = i.first; + // ASR::symbol_t* sym = i.second; + // current_scope->add_symbol(name,sym); + // } + // } Vec member_names; Vec member_fn_names; Vec member_init; @@ -3344,7 +3364,7 @@ class CommonVisitor : public AST::BaseVisitor { member_names.p, member_names.size(), member_fn_names.p, member_fn_names.size(), class_abi, ASR::accessType::Public, false, false, member_init.p, member_init.size(), - nullptr, nullptr)); + nullptr, parent)); parent_scope->add_symbol(x.m_name, class_sym); visit_ClassMembers(x, member_names, member_fn_names, struct_dependencies, member_init, false, class_abi, true); @@ -6239,10 +6259,14 @@ class BodyVisitor : public CommonVisitor { for( size_t i = 0; i < der_type->n_members && !member_found; i++ ) { member_found = std::string(der_type->m_members[i]) == member_name; } - if( !member_found ) { + if( !member_found && !der_type->m_parent ) { throw SemanticError("No member " + member_name + " found in " + std::string(der_type->m_name), loc); + } else if ( !member_found && der_type->m_parent ) { + ASR::ttype_t* parent_type = ASRUtils::TYPE(ASRUtils::make_StructType_t_util(al, loc,der_type->m_parent)); + visit_AttributeUtil(parent_type,attr_char,t,loc); + return; } ASR::expr_t *val = ASR::down_cast(ASR::make_Var_t(al, loc, t)); ASR::symbol_t* member_sym = der_type->m_symtab->resolve_symbol(member_name);