Skip to content

Commit

Permalink
chore: add hir::Methods
Browse files Browse the repository at this point in the history
  • Loading branch information
mtshiba committed Nov 4, 2023
1 parent 234fa3f commit 21c937e
Show file tree
Hide file tree
Showing 14 changed files with 109 additions and 64 deletions.
46 changes: 28 additions & 18 deletions crates/els/hir_visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ impl<'a> HIRVisitor<'a> {
) -> Option<Vec<Str>> {
let ns = class_def.sig.ident().to_string_notype();
cur_ns.push(Str::from(ns));
self.get_exprs_ns(cur_ns, class_def.methods.iter(), pos)
self.get_exprs_ns(cur_ns, class_def.all_methods(), pos)
}

fn get_patch_def_ns(
Expand Down Expand Up @@ -282,7 +282,9 @@ impl<'a> HIRVisitor<'a> {
Expr::Tuple(tuple) => self.get_expr_from_tuple(expr, tuple, pos),
Expr::TypeAsc(type_asc) => self.get_expr(&type_asc.expr, pos),
Expr::Dummy(dummy) => self.get_expr_from_dummy(dummy, pos),
Expr::Compound(block) | Expr::Code(block) => self.get_expr_from_block(block, pos),
Expr::Compound(block) | Expr::Code(block) => {
self.get_expr_from_block(block.iter(), pos)
}
Expr::ReDef(redef) => self.get_expr_from_redef(expr, redef, pos),
Expr::Import(_) => None,
}
Expand Down Expand Up @@ -362,7 +364,7 @@ impl<'a> HIRVisitor<'a> {
pos: Position,
) -> Option<&Expr> {
self.return_expr_if_same(expr, def.sig.ident().raw.name.token(), pos)
.or_else(|| self.get_expr_from_block(&def.body.block, pos))
.or_else(|| self.get_expr_from_block(def.body.block.iter(), pos))
.or_else(|| self.return_expr_if_contains(expr, pos, def))
}

Expand All @@ -376,12 +378,16 @@ impl<'a> HIRVisitor<'a> {
.require_or_sup
.as_ref()
.and_then(|req_sup| self.get_expr(req_sup, pos))
.or_else(|| self.get_expr_from_block(&class_def.methods, pos))
.or_else(|| self.get_expr_from_block(class_def.all_methods(), pos))
.or_else(|| self.return_expr_if_contains(expr, pos, class_def))
}

fn get_expr_from_block<'e>(&'e self, block: &'e Block, pos: Position) -> Option<&Expr> {
for chunk in block.iter() {
fn get_expr_from_block<'e>(
&'e self,
block: impl Iterator<Item = &'e Expr>,
pos: Position,
) -> Option<&Expr> {
for chunk in block {
if let Some(expr) = self.get_expr(chunk, pos) {
return Some(expr);
}
Expand All @@ -396,7 +402,7 @@ impl<'a> HIRVisitor<'a> {
pos: Position,
) -> Option<&Expr> {
self.get_expr_from_acc(expr, &redef.attr, pos)
.or_else(|| self.get_expr_from_block(&redef.block, pos))
.or_else(|| self.get_expr_from_block(redef.block.iter(), pos))
}

fn get_expr_from_dummy<'e>(&'e self, dummy: &'e Dummy, pos: Position) -> Option<&Expr> {
Expand All @@ -416,7 +422,7 @@ impl<'a> HIRVisitor<'a> {
) -> Option<&Expr> {
self.return_expr_if_same(expr, patch_def.sig.name().token(), pos)
.or_else(|| self.get_expr(&patch_def.base, pos))
.or_else(|| self.get_expr_from_block(&patch_def.methods, pos))
.or_else(|| self.get_expr_from_block(patch_def.methods.iter(), pos))
.or_else(|| self.return_expr_if_contains(expr, pos, patch_def))
}

Expand All @@ -431,7 +437,7 @@ impl<'a> HIRVisitor<'a> {
{
return Some(expr);
}
self.get_expr_from_block(&lambda.body, pos)
self.get_expr_from_block(lambda.body.iter(), pos)
}

fn get_expr_from_array<'e>(
Expand Down Expand Up @@ -487,7 +493,7 @@ impl<'a> HIRVisitor<'a> {
return Some(expr);
}
for field in record.attrs.iter() {
if let Some(expr) = self.get_expr_from_block(&field.body.block, pos) {
if let Some(expr) = self.get_expr_from_block(field.body.block.iter(), pos) {
return Some(expr);
}
}
Expand Down Expand Up @@ -577,7 +583,7 @@ impl<'a> HIRVisitor<'a> {
Expr::Tuple(tuple) => self.get_tuple_info(tuple, token),
Expr::TypeAsc(type_asc) => self.get_tasc_info(type_asc, token),
Expr::Dummy(dummy) => self.get_dummy_info(dummy, token),
Expr::Compound(block) | Expr::Code(block) => self.get_block_info(block, token),
Expr::Compound(block) | Expr::Code(block) => self.get_block_info(block.iter(), token),
Expr::ReDef(redef) => self.get_redef_info(redef, token),
Expr::Import(_) => None,
}
Expand Down Expand Up @@ -690,7 +696,7 @@ impl<'a> HIRVisitor<'a> {

fn get_def_info(&self, def: &Def, token: &Token) -> Option<VarInfo> {
self.get_sig_info(&def.sig, token)
.or_else(|| self.get_block_info(&def.body.block, token))
.or_else(|| self.get_block_info(def.body.block.iter(), token))
}

fn get_class_def_info(&self, class_def: &ClassDef, token: &Token) -> Option<VarInfo> {
Expand All @@ -699,17 +705,21 @@ impl<'a> HIRVisitor<'a> {
.as_ref()
.and_then(|req_sup| self.get_expr_info(req_sup, token))
.or_else(|| self.get_sig_info(&class_def.sig, token))
.or_else(|| self.get_block_info(&class_def.methods, token))
.or_else(|| self.get_block_info(class_def.all_methods(), token))
}

fn get_patch_def_info(&self, patch_def: &PatchDef, token: &Token) -> Option<VarInfo> {
self.get_expr_info(&patch_def.base, token)
.or_else(|| self.get_sig_info(&patch_def.sig, token))
.or_else(|| self.get_block_info(&patch_def.methods, token))
.or_else(|| self.get_block_info(patch_def.methods.iter(), token))
}

fn get_block_info(&self, block: &Block, token: &Token) -> Option<VarInfo> {
for chunk in block.iter() {
fn get_block_info<'e>(
&self,
block: impl Iterator<Item = &'e Expr>,
token: &Token,
) -> Option<VarInfo> {
for chunk in block {
if let Some(expr) = self.get_expr_info(chunk, token) {
return Some(expr);
}
Expand All @@ -719,7 +729,7 @@ impl<'a> HIRVisitor<'a> {

fn get_redef_info(&self, redef: &ReDef, token: &Token) -> Option<VarInfo> {
self.get_acc_info(&redef.attr, token)
.or_else(|| self.get_block_info(&redef.block, token))
.or_else(|| self.get_block_info(redef.block.iter(), token))
}

fn get_dummy_info(&self, dummy: &Dummy, token: &Token) -> Option<VarInfo> {
Expand All @@ -733,7 +743,7 @@ impl<'a> HIRVisitor<'a> {

fn get_lambda_info(&self, lambda: &Lambda, token: &Token) -> Option<VarInfo> {
self.get_params_info(&lambda.params, token)
.or_else(|| self.get_block_info(&lambda.body, token))
.or_else(|| self.get_block_info(lambda.body.iter(), token))
}

fn get_array_info(&self, arr: &Array, token: &Token) -> Option<VarInfo> {
Expand Down
3 changes: 1 addition & 2 deletions crates/els/inlay_hint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,7 @@ impl<'s, C: BuildRunnable, P: Parsable> InlayHintGenerator<'s, C, P> {

fn get_class_def_hint(&self, class_def: &ClassDef) -> Vec<InlayHint> {
class_def
.methods
.iter()
.all_methods()
.flat_map(|expr| self.get_expr_hint(expr))
.collect()
}
Expand Down
2 changes: 1 addition & 1 deletion crates/els/symbol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ impl<Checker: BuildRunnable, Parser: Parsable> Server<Checker, Parser> {
res.extend(symbol);
}
}
for method in def.methods.iter() {
for method in def.all_methods() {
let symbol = self.symbol(method);
res.extend(symbol);
}
Expand Down
7 changes: 4 additions & 3 deletions crates/erg_compiler/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3339,7 +3339,7 @@ impl PyCodeGenerator {
log!(info "entered {}", fn_name!());
let name = class.sig.ident().inspect().clone();
self.unit_size += 1;
let firstlineno = match class.methods.get(0).and_then(|def| def.ln_begin()) {
let firstlineno = match class.methods_list.get(0).and_then(|def| def.ln_begin()) {
Some(l) => l,
None => class.sig.ln_begin().unwrap_or(0),
};
Expand All @@ -3362,8 +3362,9 @@ impl PyCodeGenerator {
if class.need_to_gen_new {
self.emit_new_func(&class.sig, class.__new__);
}
if !class.methods.is_empty() {
self.emit_simple_block(class.methods);
let methods = ClassDef::take_all_methods(class.methods_list);
if !methods.is_empty() {
self.emit_simple_block(methods);
}
if self.stack_len() == init_stack_len {
self.emit_load_const(ValueObj::None);
Expand Down
2 changes: 1 addition & 1 deletion crates/erg_compiler/context/generalize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1295,7 +1295,7 @@ impl Context {
}
}
hir::Expr::ClassDef(class_def) => {
for def in class_def.methods.iter_mut() {
for def in class_def.all_methods_mut() {
self.resolve_expr_t(def, qnames)?;
}
Ok(())
Expand Down
4 changes: 2 additions & 2 deletions crates/erg_compiler/context/inquire.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3405,7 +3405,7 @@ impl Context {
candidates.collect()
}

pub(crate) fn is_class(&self, typ: &Type) -> bool {
pub fn is_class(&self, typ: &Type) -> bool {
match typ {
Type::And(_l, _r) => false,
Type::Never => true,
Expand All @@ -3429,7 +3429,7 @@ impl Context {
}
}

pub(crate) fn is_trait(&self, typ: &Type) -> bool {
pub fn is_trait(&self, typ: &Type) -> bool {
match typ {
Type::Never => false,
Type::FreeVar(fv) if fv.is_linked() => self.is_class(&fv.crack()),
Expand Down
20 changes: 10 additions & 10 deletions crates/erg_compiler/desugar_hir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,22 +41,22 @@ impl HIRDesugarer {
match chunk {
Expr::ClassDef(class_def) => {
let class = Expr::Accessor(Accessor::Ident(class_def.sig.ident().clone()));
let methods = std::mem::take(class_def.methods.ref_mut_payload());
let (methods, static_members): (Vec<_>, Vec<_>) = methods
.into_iter()
.partition(|attr| matches!(attr, Expr::Def(def) if def.sig.is_subr()));
class_def.methods.extend(methods);
let static_members = static_members
.into_iter()
.map(|expr| match expr {
let mut static_members = vec![];
for methods_ in class_def.methods_list.iter_mut() {
let block = std::mem::take(&mut methods_.defs);
let (methods, statics): (Vec<_>, Vec<_>) = block
.into_iter()
.partition(|attr| matches!(attr, Expr::Def(def) if def.sig.is_subr()));
methods_.defs.extend(methods);
static_members.extend(statics.into_iter().map(|expr| match expr {
Expr::Def(def) => {
let acc = class.clone().attr(def.sig.into_ident());
let redef = ReDef::new(acc, def.body.block);
Expr::ReDef(redef)
}
_ => expr,
})
.collect::<Vec<_>>();
}));
}
if !static_members.is_empty() {
*chunk = Expr::Compound(Block::new(
[vec![std::mem::take(chunk)], static_members].concat(),
Expand Down
4 changes: 2 additions & 2 deletions crates/erg_compiler/effectcheck.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ impl SideEffectChecker {
self.check_expr(req_sup);
}
// TODO: grow
for def in class_def.methods.iter() {
for def in class_def.all_methods() {
self.check_expr(def);
}
}
Expand Down Expand Up @@ -324,7 +324,7 @@ impl SideEffectChecker {
if let Some(req_sup) = &class_def.require_or_sup {
self.check_expr(req_sup);
}
for def in class_def.methods.iter() {
for def in class_def.all_methods() {
self.check_expr(def);
}
}
Expand Down
61 changes: 46 additions & 15 deletions crates/erg_compiler/hir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2301,14 +2301,19 @@ impl Def {

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Methods {
pub class: TypeSpec,
pub vis: Token, // `.` or `::`
pub defs: RecordAttrs, // TODO: allow declaration
pub class: Type,
pub impl_trait: Option<Type>,
pub defs: Block,
}

impl NestedDisplay for Methods {
fn fmt_nest(&self, f: &mut fmt::Formatter<'_>, level: usize) -> fmt::Result {
writeln!(f, "{}{}", self.class, self.vis.content)?;
writeln!(
f,
"{} {}",
self.class,
fmt_option!("|<: ", &self.impl_trait, "|"),
)?;
self.defs.fmt_nest(f, level + 1)
}
}
Expand All @@ -2317,16 +2322,16 @@ impl NestedDisplay for Methods {
impl NoTypeDisplay for Methods {
fn to_string_notype(&self) -> String {
format!(
"{}{} {}",
"{} {} {}",
self.class,
self.vis.content,
fmt_option!("|<: ", &self.impl_trait, "|"),
self.defs.to_string_notype()
)
}
}

impl_display_from_nested!(Methods);
impl_locational!(Methods, class, defs);
impl_locational!(Methods, defs);

impl HasType for Methods {
#[inline]
Expand All @@ -2348,8 +2353,12 @@ impl HasType for Methods {
}

impl Methods {
pub const fn new(class: TypeSpec, vis: Token, defs: RecordAttrs) -> Self {
Self { class, vis, defs }
pub const fn new(class: Type, impl_trait: Option<Type>, defs: Block) -> Self {
Self {
class,
impl_trait,
defs,
}
}
}

Expand All @@ -2361,26 +2370,32 @@ pub struct ClassDef {
/// The type of `new` that is automatically defined if not defined
pub need_to_gen_new: bool,
pub __new__: Type,
pub methods: Block,
pub methods_list: Vec<Methods>,
}

impl NestedDisplay for ClassDef {
fn fmt_nest(&self, f: &mut fmt::Formatter<'_>, level: usize) -> fmt::Result {
self.sig.fmt_nest(f, level)?;
writeln!(f, ":")?;
self.methods.fmt_nest(f, level + 1)
fmt_lines(self.methods_list.iter(), f, level)
}
}

// TODO
impl NoTypeDisplay for ClassDef {
fn to_string_notype(&self) -> String {
format!("{}: {}", self.sig, self.methods.to_string_notype())
let methods = self
.methods_list
.iter()
.map(|m| m.to_string_notype())
.collect::<Vec<_>>()
.join("\n");
format!("{}: {methods}", self.sig)
}
}

impl_display_from_nested!(ClassDef);
impl_locational!(ClassDef, sig, lossy methods);
impl_locational!(ClassDef, sig, lossy methods_list);

impl HasType for ClassDef {
#[inline]
Expand Down Expand Up @@ -2408,16 +2423,32 @@ impl ClassDef {
require_or_sup: Option<Expr>,
need_to_gen_new: bool,
__new__: Type,
methods: Block,
methods_list: Vec<Methods>,
) -> Self {
Self {
obj,
sig,
require_or_sup: require_or_sup.map(Box::new),
need_to_gen_new,
__new__,
methods,
methods_list,
}
}

pub fn all_methods(&self) -> impl Iterator<Item = &Expr> {
self.methods_list.iter().flat_map(|m| m.defs.iter())
}

pub fn all_methods_mut(&mut self) -> impl Iterator<Item = &mut Expr> {
self.methods_list.iter_mut().flat_map(|m| m.defs.iter_mut())
}

pub fn take_all_methods(methods_list: Vec<Methods>) -> Block {
let mut joined = Block::empty();
for methods in methods_list {
joined.extend(methods.defs);
}
joined
}
}

Expand Down
Loading

0 comments on commit 21c937e

Please sign in to comment.