diff --git a/crates/els/hir_visitor.rs b/crates/els/hir_visitor.rs index 841f4c179..c75361479 100644 --- a/crates/els/hir_visitor.rs +++ b/crates/els/hir_visitor.rs @@ -186,7 +186,7 @@ impl<'a> HIRVisitor<'a> { ) -> Option> { let ns = class_def.sig.ident().to_string_notype(); cur_ns.push(Str::from(ns)); - self.get_exprs_ns(cur_ns, class_def.methods.iter(), pos) + self.get_exprs_ns(cur_ns, class_def.all_methods(), pos) } fn get_patch_def_ns( @@ -282,7 +282,9 @@ impl<'a> HIRVisitor<'a> { Expr::Tuple(tuple) => self.get_expr_from_tuple(expr, tuple, pos), Expr::TypeAsc(type_asc) => self.get_expr(&type_asc.expr, pos), Expr::Dummy(dummy) => self.get_expr_from_dummy(dummy, pos), - Expr::Compound(block) | Expr::Code(block) => self.get_expr_from_block(block, pos), + Expr::Compound(block) | Expr::Code(block) => { + self.get_expr_from_block(block.iter(), pos) + } Expr::ReDef(redef) => self.get_expr_from_redef(expr, redef, pos), Expr::Import(_) => None, } @@ -362,7 +364,7 @@ impl<'a> HIRVisitor<'a> { pos: Position, ) -> Option<&Expr> { self.return_expr_if_same(expr, def.sig.ident().raw.name.token(), pos) - .or_else(|| self.get_expr_from_block(&def.body.block, pos)) + .or_else(|| self.get_expr_from_block(def.body.block.iter(), pos)) .or_else(|| self.return_expr_if_contains(expr, pos, def)) } @@ -376,12 +378,16 @@ impl<'a> HIRVisitor<'a> { .require_or_sup .as_ref() .and_then(|req_sup| self.get_expr(req_sup, pos)) - .or_else(|| self.get_expr_from_block(&class_def.methods, pos)) + .or_else(|| self.get_expr_from_block(class_def.all_methods(), pos)) .or_else(|| self.return_expr_if_contains(expr, pos, class_def)) } - fn get_expr_from_block<'e>(&'e self, block: &'e Block, pos: Position) -> Option<&Expr> { - for chunk in block.iter() { + fn get_expr_from_block<'e>( + &'e self, + block: impl Iterator, + pos: Position, + ) -> Option<&Expr> { + for chunk in block { if let Some(expr) = self.get_expr(chunk, pos) { return Some(expr); } @@ -396,7 +402,7 @@ impl<'a> HIRVisitor<'a> { pos: Position, ) -> Option<&Expr> { self.get_expr_from_acc(expr, &redef.attr, pos) - .or_else(|| self.get_expr_from_block(&redef.block, pos)) + .or_else(|| self.get_expr_from_block(redef.block.iter(), pos)) } fn get_expr_from_dummy<'e>(&'e self, dummy: &'e Dummy, pos: Position) -> Option<&Expr> { @@ -416,7 +422,7 @@ impl<'a> HIRVisitor<'a> { ) -> Option<&Expr> { self.return_expr_if_same(expr, patch_def.sig.name().token(), pos) .or_else(|| self.get_expr(&patch_def.base, pos)) - .or_else(|| self.get_expr_from_block(&patch_def.methods, pos)) + .or_else(|| self.get_expr_from_block(patch_def.methods.iter(), pos)) .or_else(|| self.return_expr_if_contains(expr, pos, patch_def)) } @@ -431,7 +437,7 @@ impl<'a> HIRVisitor<'a> { { return Some(expr); } - self.get_expr_from_block(&lambda.body, pos) + self.get_expr_from_block(lambda.body.iter(), pos) } fn get_expr_from_array<'e>( @@ -487,7 +493,7 @@ impl<'a> HIRVisitor<'a> { return Some(expr); } for field in record.attrs.iter() { - if let Some(expr) = self.get_expr_from_block(&field.body.block, pos) { + if let Some(expr) = self.get_expr_from_block(field.body.block.iter(), pos) { return Some(expr); } } @@ -577,7 +583,7 @@ impl<'a> HIRVisitor<'a> { Expr::Tuple(tuple) => self.get_tuple_info(tuple, token), Expr::TypeAsc(type_asc) => self.get_tasc_info(type_asc, token), Expr::Dummy(dummy) => self.get_dummy_info(dummy, token), - Expr::Compound(block) | Expr::Code(block) => self.get_block_info(block, token), + Expr::Compound(block) | Expr::Code(block) => self.get_block_info(block.iter(), token), Expr::ReDef(redef) => self.get_redef_info(redef, token), Expr::Import(_) => None, } @@ -690,7 +696,7 @@ impl<'a> HIRVisitor<'a> { fn get_def_info(&self, def: &Def, token: &Token) -> Option { self.get_sig_info(&def.sig, token) - .or_else(|| self.get_block_info(&def.body.block, token)) + .or_else(|| self.get_block_info(def.body.block.iter(), token)) } fn get_class_def_info(&self, class_def: &ClassDef, token: &Token) -> Option { @@ -699,17 +705,21 @@ impl<'a> HIRVisitor<'a> { .as_ref() .and_then(|req_sup| self.get_expr_info(req_sup, token)) .or_else(|| self.get_sig_info(&class_def.sig, token)) - .or_else(|| self.get_block_info(&class_def.methods, token)) + .or_else(|| self.get_block_info(class_def.all_methods(), token)) } fn get_patch_def_info(&self, patch_def: &PatchDef, token: &Token) -> Option { self.get_expr_info(&patch_def.base, token) .or_else(|| self.get_sig_info(&patch_def.sig, token)) - .or_else(|| self.get_block_info(&patch_def.methods, token)) + .or_else(|| self.get_block_info(patch_def.methods.iter(), token)) } - fn get_block_info(&self, block: &Block, token: &Token) -> Option { - for chunk in block.iter() { + fn get_block_info<'e>( + &self, + block: impl Iterator, + token: &Token, + ) -> Option { + for chunk in block { if let Some(expr) = self.get_expr_info(chunk, token) { return Some(expr); } @@ -719,7 +729,7 @@ impl<'a> HIRVisitor<'a> { fn get_redef_info(&self, redef: &ReDef, token: &Token) -> Option { self.get_acc_info(&redef.attr, token) - .or_else(|| self.get_block_info(&redef.block, token)) + .or_else(|| self.get_block_info(redef.block.iter(), token)) } fn get_dummy_info(&self, dummy: &Dummy, token: &Token) -> Option { @@ -733,7 +743,7 @@ impl<'a> HIRVisitor<'a> { fn get_lambda_info(&self, lambda: &Lambda, token: &Token) -> Option { self.get_params_info(&lambda.params, token) - .or_else(|| self.get_block_info(&lambda.body, token)) + .or_else(|| self.get_block_info(lambda.body.iter(), token)) } fn get_array_info(&self, arr: &Array, token: &Token) -> Option { diff --git a/crates/els/inlay_hint.rs b/crates/els/inlay_hint.rs index 54907c9d0..75d98ad96 100644 --- a/crates/els/inlay_hint.rs +++ b/crates/els/inlay_hint.rs @@ -227,8 +227,7 @@ impl<'s, C: BuildRunnable, P: Parsable> InlayHintGenerator<'s, C, P> { fn get_class_def_hint(&self, class_def: &ClassDef) -> Vec { class_def - .methods - .iter() + .all_methods() .flat_map(|expr| self.get_expr_hint(expr)) .collect() } diff --git a/crates/els/symbol.rs b/crates/els/symbol.rs index 8800882e2..bd5658660 100644 --- a/crates/els/symbol.rs +++ b/crates/els/symbol.rs @@ -163,7 +163,7 @@ impl Server { res.extend(symbol); } } - for method in def.methods.iter() { + for method in def.all_methods() { let symbol = self.symbol(method); res.extend(symbol); } diff --git a/crates/erg_compiler/codegen.rs b/crates/erg_compiler/codegen.rs index 4289e2dcb..32bb8f890 100644 --- a/crates/erg_compiler/codegen.rs +++ b/crates/erg_compiler/codegen.rs @@ -3339,7 +3339,7 @@ impl PyCodeGenerator { log!(info "entered {}", fn_name!()); let name = class.sig.ident().inspect().clone(); self.unit_size += 1; - let firstlineno = match class.methods.get(0).and_then(|def| def.ln_begin()) { + let firstlineno = match class.methods_list.get(0).and_then(|def| def.ln_begin()) { Some(l) => l, None => class.sig.ln_begin().unwrap_or(0), }; @@ -3362,8 +3362,9 @@ impl PyCodeGenerator { if class.need_to_gen_new { self.emit_new_func(&class.sig, class.__new__); } - if !class.methods.is_empty() { - self.emit_simple_block(class.methods); + let methods = ClassDef::take_all_methods(class.methods_list); + if !methods.is_empty() { + self.emit_simple_block(methods); } if self.stack_len() == init_stack_len { self.emit_load_const(ValueObj::None); diff --git a/crates/erg_compiler/context/generalize.rs b/crates/erg_compiler/context/generalize.rs index 156cb668d..94414de1f 100644 --- a/crates/erg_compiler/context/generalize.rs +++ b/crates/erg_compiler/context/generalize.rs @@ -1295,7 +1295,7 @@ impl Context { } } hir::Expr::ClassDef(class_def) => { - for def in class_def.methods.iter_mut() { + for def in class_def.all_methods_mut() { self.resolve_expr_t(def, qnames)?; } Ok(()) diff --git a/crates/erg_compiler/context/inquire.rs b/crates/erg_compiler/context/inquire.rs index 577a458cb..85fca9bff 100644 --- a/crates/erg_compiler/context/inquire.rs +++ b/crates/erg_compiler/context/inquire.rs @@ -3405,7 +3405,7 @@ impl Context { candidates.collect() } - pub(crate) fn is_class(&self, typ: &Type) -> bool { + pub fn is_class(&self, typ: &Type) -> bool { match typ { Type::And(_l, _r) => false, Type::Never => true, @@ -3429,7 +3429,7 @@ impl Context { } } - pub(crate) fn is_trait(&self, typ: &Type) -> bool { + pub fn is_trait(&self, typ: &Type) -> bool { match typ { Type::Never => false, Type::FreeVar(fv) if fv.is_linked() => self.is_class(&fv.crack()), diff --git a/crates/erg_compiler/desugar_hir.rs b/crates/erg_compiler/desugar_hir.rs index 443eb8f23..60cb36106 100644 --- a/crates/erg_compiler/desugar_hir.rs +++ b/crates/erg_compiler/desugar_hir.rs @@ -41,22 +41,22 @@ impl HIRDesugarer { match chunk { Expr::ClassDef(class_def) => { let class = Expr::Accessor(Accessor::Ident(class_def.sig.ident().clone())); - let methods = std::mem::take(class_def.methods.ref_mut_payload()); - let (methods, static_members): (Vec<_>, Vec<_>) = methods - .into_iter() - .partition(|attr| matches!(attr, Expr::Def(def) if def.sig.is_subr())); - class_def.methods.extend(methods); - let static_members = static_members - .into_iter() - .map(|expr| match expr { + let mut static_members = vec![]; + for methods_ in class_def.methods_list.iter_mut() { + let block = std::mem::take(&mut methods_.defs); + let (methods, statics): (Vec<_>, Vec<_>) = block + .into_iter() + .partition(|attr| matches!(attr, Expr::Def(def) if def.sig.is_subr())); + methods_.defs.extend(methods); + static_members.extend(statics.into_iter().map(|expr| match expr { Expr::Def(def) => { let acc = class.clone().attr(def.sig.into_ident()); let redef = ReDef::new(acc, def.body.block); Expr::ReDef(redef) } _ => expr, - }) - .collect::>(); + })); + } if !static_members.is_empty() { *chunk = Expr::Compound(Block::new( [vec![std::mem::take(chunk)], static_members].concat(), diff --git a/crates/erg_compiler/effectcheck.rs b/crates/erg_compiler/effectcheck.rs index bdfaf27dd..20808305f 100644 --- a/crates/erg_compiler/effectcheck.rs +++ b/crates/erg_compiler/effectcheck.rs @@ -95,7 +95,7 @@ impl SideEffectChecker { self.check_expr(req_sup); } // TODO: grow - for def in class_def.methods.iter() { + for def in class_def.all_methods() { self.check_expr(def); } } @@ -324,7 +324,7 @@ impl SideEffectChecker { if let Some(req_sup) = &class_def.require_or_sup { self.check_expr(req_sup); } - for def in class_def.methods.iter() { + for def in class_def.all_methods() { self.check_expr(def); } } diff --git a/crates/erg_compiler/hir.rs b/crates/erg_compiler/hir.rs index aa1afc55f..481fb330f 100644 --- a/crates/erg_compiler/hir.rs +++ b/crates/erg_compiler/hir.rs @@ -2301,14 +2301,19 @@ impl Def { #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Methods { - pub class: TypeSpec, - pub vis: Token, // `.` or `::` - pub defs: RecordAttrs, // TODO: allow declaration + pub class: Type, + pub impl_trait: Option, + pub defs: Block, } impl NestedDisplay for Methods { fn fmt_nest(&self, f: &mut fmt::Formatter<'_>, level: usize) -> fmt::Result { - writeln!(f, "{}{}", self.class, self.vis.content)?; + writeln!( + f, + "{} {}", + self.class, + fmt_option!("|<: ", &self.impl_trait, "|"), + )?; self.defs.fmt_nest(f, level + 1) } } @@ -2317,16 +2322,16 @@ impl NestedDisplay for Methods { impl NoTypeDisplay for Methods { fn to_string_notype(&self) -> String { format!( - "{}{} {}", + "{} {} {}", self.class, - self.vis.content, + fmt_option!("|<: ", &self.impl_trait, "|"), self.defs.to_string_notype() ) } } impl_display_from_nested!(Methods); -impl_locational!(Methods, class, defs); +impl_locational!(Methods, defs); impl HasType for Methods { #[inline] @@ -2348,8 +2353,12 @@ impl HasType for Methods { } impl Methods { - pub const fn new(class: TypeSpec, vis: Token, defs: RecordAttrs) -> Self { - Self { class, vis, defs } + pub const fn new(class: Type, impl_trait: Option, defs: Block) -> Self { + Self { + class, + impl_trait, + defs, + } } } @@ -2361,26 +2370,32 @@ pub struct ClassDef { /// The type of `new` that is automatically defined if not defined pub need_to_gen_new: bool, pub __new__: Type, - pub methods: Block, + pub methods_list: Vec, } impl NestedDisplay for ClassDef { fn fmt_nest(&self, f: &mut fmt::Formatter<'_>, level: usize) -> fmt::Result { self.sig.fmt_nest(f, level)?; writeln!(f, ":")?; - self.methods.fmt_nest(f, level + 1) + fmt_lines(self.methods_list.iter(), f, level) } } // TODO impl NoTypeDisplay for ClassDef { fn to_string_notype(&self) -> String { - format!("{}: {}", self.sig, self.methods.to_string_notype()) + let methods = self + .methods_list + .iter() + .map(|m| m.to_string_notype()) + .collect::>() + .join("\n"); + format!("{}: {methods}", self.sig) } } impl_display_from_nested!(ClassDef); -impl_locational!(ClassDef, sig, lossy methods); +impl_locational!(ClassDef, sig, lossy methods_list); impl HasType for ClassDef { #[inline] @@ -2408,7 +2423,7 @@ impl ClassDef { require_or_sup: Option, need_to_gen_new: bool, __new__: Type, - methods: Block, + methods_list: Vec, ) -> Self { Self { obj, @@ -2416,8 +2431,24 @@ impl ClassDef { require_or_sup: require_or_sup.map(Box::new), need_to_gen_new, __new__, - methods, + methods_list, + } + } + + pub fn all_methods(&self) -> impl Iterator { + self.methods_list.iter().flat_map(|m| m.defs.iter()) + } + + pub fn all_methods_mut(&mut self) -> impl Iterator { + self.methods_list.iter_mut().flat_map(|m| m.defs.iter_mut()) + } + + pub fn take_all_methods(methods_list: Vec) -> Block { + let mut joined = Block::empty(); + for methods in methods_list { + joined.extend(methods.defs); } + joined } } diff --git a/crates/erg_compiler/link_hir.rs b/crates/erg_compiler/link_hir.rs index 6b90aeabf..441bedbac 100644 --- a/crates/erg_compiler/link_hir.rs +++ b/crates/erg_compiler/link_hir.rs @@ -197,7 +197,7 @@ impl<'a> HIRLinker<'a> { } } Expr::ClassDef(class_def) => { - for def in class_def.methods.iter_mut() { + for def in class_def.all_methods_mut() { Self::resolve_pymod_path(def); } } @@ -332,7 +332,7 @@ impl<'a> HIRLinker<'a> { } } Expr::ClassDef(class_def) => { - for def in class_def.methods.iter_mut() { + for def in class_def.all_methods_mut() { self.replace_import(def); } } diff --git a/crates/erg_compiler/lint.rs b/crates/erg_compiler/lint.rs index e3e83e343..68ce7a595 100644 --- a/crates/erg_compiler/lint.rs +++ b/crates/erg_compiler/lint.rs @@ -122,7 +122,7 @@ impl GenericASTLowerer { } } hir::Expr::ClassDef(class_def) => { - for chunk in class_def.methods.iter() { + for chunk in class_def.all_methods() { if let Err(ws) = self.expr_use_check(chunk) { warns.extend(ws); } @@ -233,7 +233,7 @@ impl GenericASTLowerer { } } Expr::ClassDef(class_def) => { - for chunk in class_def.methods.iter() { + for chunk in class_def.all_methods() { self.check_doc_comment(chunk); } } @@ -298,7 +298,7 @@ impl GenericASTLowerer { fn warn_implicit_union_chunk(&mut self, chunk: &Expr) { match chunk { Expr::ClassDef(class_def) => { - for chunk in class_def.methods.iter() { + for chunk in class_def.all_methods() { self.warn_implicit_union_chunk(chunk); } } diff --git a/crates/erg_compiler/lower.rs b/crates/erg_compiler/lower.rs index a57ceb9f2..6e77ce2d8 100644 --- a/crates/erg_compiler/lower.rs +++ b/crates/erg_compiler/lower.rs @@ -2097,8 +2097,9 @@ impl GenericASTLowerer { fn lower_class_def(&mut self, class_def: ast::ClassDef) -> LowerResult { log!(info "entered {}({class_def})", fn_name!()); let mut hir_def = self.lower_def(class_def.def)?; - let mut hir_methods = hir::Block::empty(); + let mut hir_methods_list = vec![]; for methods in class_def.methods_list.into_iter() { + let mut hir_methods = hir::Block::empty(); let (class, impl_trait) = self .module .context @@ -2198,7 +2199,9 @@ impl GenericASTLowerer { if let Err(err) = self.check_trait_impl(impl_trait.clone(), &class) { self.errs.push(err); } - self.check_collision_and_push(methods.id, class, impl_trait.map(|(t, _)| t)); + let impl_trait = impl_trait.map(|(t, _)| t); + self.check_collision_and_push(methods.id, class.clone(), impl_trait.clone()); + hir_methods_list.push(hir::Methods::new(class, impl_trait, hir_methods)); } let class = self.module.context.gen_type(&hir_def.sig.ident().raw); let Some(class_ctx) = self.module.context.get_nominal_type_ctx(&class) else { @@ -2242,7 +2245,7 @@ impl GenericASTLowerer { require_or_sup, need_to_gen_new, __new__.t.clone(), - hir_methods, + hir_methods_list, )) } diff --git a/crates/erg_compiler/ownercheck.rs b/crates/erg_compiler/ownercheck.rs index 393ae4e1c..6e7019775 100644 --- a/crates/erg_compiler/ownercheck.rs +++ b/crates/erg_compiler/ownercheck.rs @@ -131,7 +131,7 @@ impl OwnershipChecker { if let Some(req_sup) = &class_def.require_or_sup { self.check_expr(req_sup, Ownership::Owned, false); } - for def in class_def.methods.iter() { + for def in class_def.all_methods() { self.check_expr(def, Ownership::Owned, true); } } diff --git a/crates/erg_compiler/transpile.rs b/crates/erg_compiler/transpile.rs index ee390c3bd..24b496a40 100644 --- a/crates/erg_compiler/transpile.rs +++ b/crates/erg_compiler/transpile.rs @@ -1112,7 +1112,8 @@ impl PyScriptGenerator { code += &" ".repeat(self.level + 1); code += &format!("def new(x): return {class_name}.__call__(x)\n"); } - code += &self.transpile_block(classdef.methods, Discard); + let methods = ClassDef::take_all_methods(classdef.methods_list); + code += &self.transpile_block(methods, Discard); code }