diff --git a/crates/erg_compiler/context/inquire.rs b/crates/erg_compiler/context/inquire.rs index 7630ab196..577a458cb 100644 --- a/crates/erg_compiler/context/inquire.rs +++ b/crates/erg_compiler/context/inquire.rs @@ -199,6 +199,21 @@ impl Context { .map(|(opt_name, vi)| (opt_name.as_ref().unwrap(), vi)) } + pub fn get_method_kv(&self, name: &str) -> Option<(&VarName, &VarInfo)> { + #[cfg(feature = "py_compat")] + let name = self.erg_to_py_names.get(name).map_or(name, |s| &s[..]); + self.get_var_kv(name) + .or_else(|| { + for methods in self.methods_list.iter() { + if let Some(vi) = methods.get_method_kv(name) { + return Some(vi); + } + } + None + }) + .or_else(|| self.get_outer().and_then(|ctx| ctx.get_method_kv(name))) + } + pub fn get_singular_ctxs_by_hir_expr( &self, obj: &hir::Expr, diff --git a/crates/erg_compiler/context/register.rs b/crates/erg_compiler/context/register.rs index 5cc228b68..559e8f7eb 100644 --- a/crates/erg_compiler/context/register.rs +++ b/crates/erg_compiler/context/register.rs @@ -741,6 +741,91 @@ impl Context { } } + fn unify_params_t( + &self, + sig: &ast::SubrSignature, + registered_t: &SubrType, + params: &hir::Params, + body_t: &Type, + body_loc: &impl Locational, + ) -> TyCheckResult<()> { + let name = &sig.ident.name; + let mut errs = TyCheckErrors::empty(); + for (param, pt) in params + .non_defaults + .iter() + .zip(registered_t.non_default_params.iter()) + { + pt.typ().lower(); + if let Err(es) = self.force_sub_unify(¶m.vi.t, pt.typ(), param, None) { + errs.extend(es); + } + pt.typ().lift(); + } + // TODO: var_params: [Int; _], pt: Int + /*if let Some((var_params, pt)) = params.var_params.as_deref().zip(registered_t.var_params.as_ref()) { + pt.typ().lower(); + if let Err(es) = self.force_sub_unify(&var_params.vi.t, pt.typ(), var_params, None) { + errs.extend(es); + } + pt.typ().lift(); + }*/ + for (param, pt) in params + .defaults + .iter() + .zip(registered_t.default_params.iter()) + { + pt.typ().lower(); + if let Err(es) = self.force_sub_unify(¶m.sig.vi.t, pt.typ(), param, None) { + errs.extend(es); + } + pt.typ().lift(); + } + let spec_ret_t = registered_t.return_t.as_ref(); + // spec_ret_t.lower(); + let unify_return_result = if let Some(t_spec) = sig.return_t_spec.as_ref() { + self.force_sub_unify(body_t, spec_ret_t, t_spec, None) + } else { + self.force_sub_unify(body_t, spec_ret_t, body_loc, None) + }; + // spec_ret_t.lift(); + if let Err(unify_errs) = unify_return_result { + let es = TyCheckErrors::new( + unify_errs + .into_iter() + .map(|e| { + let expect = if cfg!(feature = "debug") { + spec_ret_t.clone() + } else { + self.readable_type(spec_ret_t.clone()) + }; + let found = if cfg!(feature = "debug") { + body_t.clone() + } else { + self.readable_type(body_t.clone()) + }; + TyCheckError::return_type_error( + self.cfg.input.clone(), + line!() as usize, + e.core.get_loc_with_fallback(), + e.caused_by, + readable_name(name.inspect()), + &expect, + &found, + // e.core.get_hint().map(|s| s.to_string()), + ) + }) + .collect(), + ); + errs.extend(es); + } + if errs.is_empty() { + Ok(()) + } else { + Err(errs) + } + } + /// ## Errors /// * TypeError: if `return_t` != typeof `body` /// * AssignError: if `name` has already been registered @@ -748,6 +833,7 @@ impl Context { &mut self, sig: &ast::SubrSignature, id: DefId, + params: &hir::Params, body_t: &Type, body_loc: &impl Locational, ) -> Result { @@ -772,63 +858,27 @@ impl Context { }; let name = &sig.ident.name; // FIXME: constでない関数 - let t = self.get_current_scope_var(name).map(|vi| &vi.t).unwrap(); - debug_assert!(t.is_subr(), "{t} is not subr"); - let empty = vec![]; - let non_default_params = t.non_default_params().unwrap_or(&empty); - let var_args = t.var_params(); - let default_params = t.default_params().unwrap_or(&empty); - if let Some(spec_ret_t) = t.return_t() { - let unify_result = if let Some(t_spec) = sig.return_t_spec.as_ref() { - self.sub_unify(body_t, spec_ret_t, t_spec, None) - } else { - self.sub_unify(body_t, spec_ret_t, body_loc, None) - }; - if let Err(unify_errs) = unify_result { - let es = TyCheckErrors::new( - unify_errs - .into_iter() - .map(|e| { - let expect = if cfg!(feature = "debug") { - spec_ret_t.clone() - } else { - self.readable_type(spec_ret_t.clone()) - }; - let found = if cfg!(feature = "debug") { - body_t.clone() - } else { - self.readable_type(body_t.clone()) - }; - TyCheckError::return_type_error( - self.cfg.input.clone(), - line!() as usize, - e.core.get_loc_with_fallback(), - e.caused_by, - readable_name(name.inspect()), - &expect, - &found, - // e.core.get_hint().map(|s| s.to_string()), - ) - }) - .collect(), - ); - errs.extend(es); - } + let subr_t = self.get_current_scope_var(name).map(|vi| &vi.t).unwrap(); + let Ok(subr_t) = <&SubrType>::try_from(subr_t) else { + panic!("{subr_t} is not subr"); + }; + if let Err(es) = self.unify_params_t(sig, subr_t, params, body_t, body_loc) { + errs.extend(es); } // NOTE: not `body_t.clone()` because the body may contain `return` - let return_t = t.return_t().unwrap().clone(); + let return_t = subr_t.return_t.as_ref().clone(); let sub_t = if sig.ident.is_procedural() { proc( - non_default_params.clone(), - var_args.cloned(), - default_params.clone(), + subr_t.non_default_params.clone(), + subr_t.var_params.as_deref().cloned(), + subr_t.default_params.clone(), return_t, ) } else { func( - non_default_params.clone(), - var_args.cloned(), - default_params.clone(), + subr_t.non_default_params.clone(), + subr_t.var_params.as_deref().cloned(), + subr_t.default_params.clone(), return_t, ) }; diff --git a/crates/erg_compiler/context/test.rs b/crates/erg_compiler/context/test.rs index 784ac6d73..db6621628 100644 --- a/crates/erg_compiler/context/test.rs +++ b/crates/erg_compiler/context/test.rs @@ -24,6 +24,22 @@ impl Context { } } + pub fn assert_attr_type(&self, receiver_t: &Type, attr: &str, ty: &Type) -> Result<(), ()> { + let Some(ctx) = self.get_nominal_type_ctx(receiver_t) else { + panic!("type not found: {receiver_t}"); + }; + let Some((_, vi)) = ctx.get_method_kv(attr) else { + panic!("attribute not found: {attr}"); + }; + println!("{attr}: {}", vi.t); + if vi.t.structural_eq(ty) { + Ok(()) + } else { + println!("{attr} is not the type of {ty}"); + Err(()) + } + } + pub fn test_refinement_subtyping(&self) -> Result<(), ()> { // Nat :> {I: Int | I >= 1} ? let lhs = Nat; diff --git a/crates/erg_compiler/context/unify.rs b/crates/erg_compiler/context/unify.rs index fbb087456..49e41086a 100644 --- a/crates/erg_compiler/context/unify.rs +++ b/crates/erg_compiler/context/unify.rs @@ -31,6 +31,7 @@ pub struct Unifier<'c, 'l, 'u, L: Locational> { ctx: &'c Context, loc: &'l L, undoable: Option<&'u UndoableLinkedList>, + change_generalized: bool, param_name: Option, } @@ -39,12 +40,14 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { ctx: &'c Context, loc: &'l L, undoable: Option<&'u UndoableLinkedList>, + change_generalized: bool, param_name: Option, ) -> Self { Self { ctx, loc, undoable, + change_generalized, param_name, } } @@ -326,7 +329,11 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { } Ok(()) } - (TyParam::FreeVar(sub_fv), _) if sub_fv.is_generalized() => Ok(()), + (TyParam::FreeVar(sub_fv), _) + if !self.change_generalized && sub_fv.is_generalized() => + { + Ok(()) + } (TyParam::FreeVar(sub_fv), sup_tp) => { match &*sub_fv.borrow() { FreeKind::Linked(l) | FreeKind::UndoableLinked { t: l, .. } => { @@ -366,7 +373,11 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { ))) } } - (_, TyParam::FreeVar(sup_fv)) if sup_fv.is_generalized() => Ok(()), + (_, TyParam::FreeVar(sup_fv)) + if !self.change_generalized && sup_fv.is_generalized() => + { + Ok(()) + } (sub_tp, TyParam::FreeVar(sup_fv)) => { match &*sup_fv.borrow() { FreeKind::Linked(l) | FreeKind::UndoableLinked { t: l, .. } => { @@ -760,7 +771,8 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { (FreeVar(sub_fv), FreeVar(sup_fv)) if sub_fv.constraint_is_sandwiched() && sup_fv.constraint_is_sandwiched() => { - if sub_fv.is_generalized() || sup_fv.is_generalized() { + if !self.change_generalized && (sub_fv.is_generalized() || sup_fv.is_generalized()) + { log!(info "generalized:\nmaybe_sub: {maybe_sub}\nmaybe_sup: {maybe_sup}"); return Ok(()); } @@ -860,7 +872,7 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { }, FreeVar(sup_fv), ) if sup_fv.constraint_is_sandwiched() => { - if sup_fv.is_generalized() { + if !self.change_generalized && sup_fv.is_generalized() { log!(info "generalized:\nmaybe_sub: {maybe_sub}\nmaybe_sup: {maybe_sup}"); return Ok(()); } @@ -958,7 +970,7 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { // e.g. Structural({ .method = (self: T) -> Int })/T (Structural(sub), FreeVar(sup_fv)) if sup_fv.is_unbound() && sub.contains_tvar(sup_fv) => {} - (_, FreeVar(sup_fv)) if sup_fv.is_generalized() => {} + (_, FreeVar(sup_fv)) if !self.change_generalized && sup_fv.is_generalized() => {} (_, FreeVar(sup_fv)) if sup_fv.is_unbound() => { // * sub_unify(Nat, ?E(<: Eq(?E))) // sub !<: l => OK (sub will widen) @@ -1037,7 +1049,7 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { (FreeVar(sub_fv), Ref(sup)) if sub_fv.is_unbound() => { self.sub_unify(maybe_sub, sup)?; } - (FreeVar(sub_fv), _) if sub_fv.is_generalized() => {} + (FreeVar(sub_fv), _) if !self.change_generalized && sub_fv.is_generalized() => {} (FreeVar(sub_fv), _) if sub_fv.is_unbound() => { // sub !<: r => Error // * sub_unify(?T(:> Int, <: _), Nat): (/* Error */) @@ -1165,7 +1177,7 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { .iter() .zip(sup_subr.non_default_params.iter()) .try_for_each(|(sub, sup)| { - if sub.typ().is_generalized() { + if !self.change_generalized && sub.typ().is_generalized() { Ok(()) } // contravariant @@ -1179,7 +1191,7 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { .iter() .find(|sub_pt| sub_pt.name() == sup_pt.name()) { - if sup_pt.typ().is_generalized() { + if !self.change_generalized && sup_pt.typ().is_generalized() { continue; } // contravariant @@ -1203,7 +1215,7 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { .zip(sup_subr.non_default_params.iter()) .try_for_each(|(sub, sup)| { // contravariant - if sup.typ().is_generalized() { + if !self.change_generalized && sup.typ().is_generalized() { Ok(()) } else { self.sub_unify(sup.typ(), sub.typ()) @@ -1216,7 +1228,7 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { .find(|sub_pt| sub_pt.name() == sup_pt.name()) { // contravariant - if sup_pt.typ().is_generalized() { + if !self.change_generalized && sup_pt.typ().is_generalized() { continue; } self.sub_unify(sup_pt.typ(), sub_pt.typ())?; @@ -1555,7 +1567,7 @@ impl Context { maybe_sup: &Type, loc: &impl Locational, ) -> TyCheckResult<()> { - let unifier = Unifier::new(self, loc, None, None); + let unifier = Unifier::new(self, loc, None, false, None); unifier.occur(maybe_sub, maybe_sup) } @@ -1567,7 +1579,7 @@ impl Context { loc: &impl Locational, is_structural: bool, ) -> TyCheckResult<()> { - let unifier = Unifier::new(self, loc, None, None); + let unifier = Unifier::new(self, loc, None, false, None); unifier.sub_unify_tp(maybe_sub, maybe_sup, variance, is_structural) } @@ -1579,7 +1591,19 @@ impl Context { loc: &impl Locational, param_name: Option<&Str>, ) -> TyCheckResult<()> { - let unifier = Unifier::new(self, loc, None, param_name.cloned()); + let unifier = Unifier::new(self, loc, None, false, param_name.cloned()); + unifier.sub_unify(maybe_sub, maybe_sup) + } + + /// This will rewrite generalized type variables. + pub(crate) fn force_sub_unify( + &self, + maybe_sub: &Type, + maybe_sup: &Type, + loc: &impl Locational, + param_name: Option<&Str>, + ) -> TyCheckResult<()> { + let unifier = Unifier::new(self, loc, None, true, param_name.cloned()); unifier.sub_unify(maybe_sub, maybe_sup) } @@ -1591,12 +1615,12 @@ impl Context { list: &UndoableLinkedList, param_name: Option<&Str>, ) -> TyCheckResult<()> { - let unifier = Unifier::new(self, loc, Some(list), param_name.cloned()); + let unifier = Unifier::new(self, loc, Some(list), false, param_name.cloned()); unifier.sub_unify(maybe_sub, maybe_sup) } pub(crate) fn unify(&self, lhs: &Type, rhs: &Type) -> Option { - let unifier = Unifier::new(self, &(), None, None); + let unifier = Unifier::new(self, &(), None, false, None); unifier.unify(lhs, rhs) } } diff --git a/crates/erg_compiler/lower.rs b/crates/erg_compiler/lower.rs index 30918c951..0dcd39ae8 100644 --- a/crates/erg_compiler/lower.rs +++ b/crates/erg_compiler/lower.rs @@ -1956,25 +1956,26 @@ impl ASTLowerer { fn lower_subr_block( &mut self, - subr_t: SubrType, + registered_subr_t: SubrType, sig: ast::SubrSignature, decorators: Set, body: ast::DefBody, ) -> LowerResult { - let params = self.lower_params(sig.params.clone(), Some(&subr_t))?; + let params = self.lower_params(sig.params.clone(), Some(®istered_subr_t))?; if let Err(errs) = self.module.context.register_const(&body.block) { self.errs.extend(errs); } - let return_t = subr_t + let return_t = registered_subr_t .return_t .has_no_unbound_var() - .then_some(subr_t.return_t.as_ref()); + .then_some(registered_subr_t.return_t.as_ref()); match self.lower_block(body.block, return_t) { Ok(block) => { let found_body_t = self.module.context.squash_tyvar(block.t()); let vi = match self.module.context.outer.as_mut().unwrap().assign_subr( &sig, body.id, + ¶ms, &found_body_t, block.last().unwrap(), ) { @@ -2009,6 +2010,7 @@ impl ASTLowerer { let vi = match self.module.context.outer.as_mut().unwrap().assign_subr( &sig, ast::DefId(0), + ¶ms, &Type::Failure, &sig, ) { diff --git a/crates/erg_compiler/tests/infer.er b/crates/erg_compiler/tests/infer.er index f095f861c..844c0210d 100644 --- a/crates/erg_compiler/tests/infer.er +++ b/crates/erg_compiler/tests/infer.er @@ -25,3 +25,8 @@ f! t = for! arr, t => result.extend! f! t result + +c_new x, y = C.new x, y +C = Class Int +C. + new x, y = Self::__new__ x + y diff --git a/crates/erg_compiler/tests/test.rs b/crates/erg_compiler/tests/test.rs index 569cd5eda..80bbec857 100644 --- a/crates/erg_compiler/tests/test.rs +++ b/crates/erg_compiler/tests/test.rs @@ -1,3 +1,5 @@ +use std::vec; + use erg_common::config::ErgConfig; use erg_common::error::MultiErrorDisplay; use erg_common::io::Output; @@ -79,6 +81,12 @@ fn _test_infer_types() -> Result<(), ()> { let t = type_q("T"); let f_t = proc1(t.clone(), unknown_len_array_mut(t)).quantify(); module.context.assert_var_type("f!", &f_t)?; + let r = type_q("R"); + let add_r = poly("Add", vec![ty_tp(r.clone())]); + let c = mono("::C"); + let c_new_t = func2(add_r, r, c.clone()).quantify(); + module.context.assert_var_type("c_new", &c_new_t)?; + module.context.assert_attr_type(&c, "new", &c_new_t)?; Ok(()) } diff --git a/crates/erg_compiler/ty/free.rs b/crates/erg_compiler/ty/free.rs index d08fe909e..5257a4c0a 100644 --- a/crates/erg_compiler/ty/free.rs +++ b/crates/erg_compiler/ty/free.rs @@ -23,7 +23,7 @@ static UNBOUND_ID: AtomicUsize = AtomicUsize::new(0); pub trait HasLevel { fn level(&self) -> Option; fn set_level(&self, lev: Level); - fn lower(&self, level: Level) { + fn set_lower(&self, level: Level) { if self.level() < Some(level) { self.set_level(level); } @@ -33,6 +33,11 @@ pub trait HasLevel { self.set_level(lev.saturating_add(1)); } } + fn lower(&self) { + if let Some(lev) = self.level() { + self.set_level(lev.saturating_sub(1)); + } + } fn generalize(&self) { self.set_level(GENERIC_LEVEL); }