diff --git a/crates/erg_compiler/context/eval.rs b/crates/erg_compiler/context/eval.rs index 9bd280aab..56c7145d0 100644 --- a/crates/erg_compiler/context/eval.rs +++ b/crates/erg_compiler/context/eval.rs @@ -163,7 +163,9 @@ impl<'c> Substituter<'c> { let qtps = qt.typarams(); let stps = st.typarams(); if qt.qual_name() != st.qual_name() || qtps.len() != stps.len() { - if let Some(sub) = st.get_sub() { + if let Some(inner) = st.ref_inner().or_else(|| st.ref_mut_inner()) { + return Self::substitute_typarams(ctx, qt, &inner); + } else if let Some(sub) = st.get_sub() { return Self::substitute_typarams(ctx, qt, &sub); } log!(err "{qt} / {st}"); diff --git a/crates/erg_compiler/context/mod.rs b/crates/erg_compiler/context/mod.rs index 9e063f5ed..40a44fd87 100644 --- a/crates/erg_compiler/context/mod.rs +++ b/crates/erg_compiler/context/mod.rs @@ -1121,6 +1121,7 @@ impl Context { }; log!(info "{}: current namespace: {name}", fn_name!()); self.outer = Some(Box::new(mem::take(self))); + // self.level += 1; if let Some(tv_cache) = tv_cache.as_ref() { self.assign_bounds(tv_cache) }; diff --git a/crates/erg_compiler/context/unify.rs b/crates/erg_compiler/context/unify.rs index 6336d94b2..b4d7db7e4 100644 --- a/crates/erg_compiler/context/unify.rs +++ b/crates/erg_compiler/context/unify.rs @@ -872,7 +872,7 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { maybe_sup.update_constraint(new_constraint, self.undoable, false); maybe_sub.link(maybe_sup, self.undoable); } else { - maybe_sup.update_constraint(new_constraint, self.undoable, false); + maybe_sub.update_constraint(new_constraint, self.undoable, false); maybe_sup.link(maybe_sub, self.undoable); } } diff --git a/crates/erg_compiler/lower.rs b/crates/erg_compiler/lower.rs index d6396b237..f34d51fe5 100644 --- a/crates/erg_compiler/lower.rs +++ b/crates/erg_compiler/lower.rs @@ -1611,7 +1611,14 @@ impl ASTLowerer { body.t(), ) }; - let t = if ty.has_qvar() { ty.quantify() } else { ty }; + let t = if ty.has_unbound_var() { + // TODO: + // ty.lift(); + // self.module.context.generalize_t(ty) + ty.quantify() + } else { + ty + }; Ok(hir::Lambda::new(id, params, lambda.op, body, t)) } diff --git a/crates/erg_compiler/tests/infer.er b/crates/erg_compiler/tests/infer.er index d80ad7090..f095f861c 100644 --- a/crates/erg_compiler/tests/infer.er +++ b/crates/erg_compiler/tests/infer.er @@ -1,13 +1,15 @@ id x = x +id2 = x -> x if__ cond, then, else = if cond, then, else for__! i, proc! = for! i, proc! add x, y = x + y +add2 = (x, y) -> x + y abs_ x = x.abs() - +abs2 = x -> x.abs() Norm = Trait { .norm = (self: Self) -> Nat } norm x = x.norm() diff --git a/crates/erg_compiler/tests/test.rs b/crates/erg_compiler/tests/test.rs index e4f1cb7ef..569cd5eda 100644 --- a/crates/erg_compiler/tests/test.rs +++ b/crates/erg_compiler/tests/test.rs @@ -36,6 +36,7 @@ fn _test_infer_types() -> Result<(), ()> { let u = type_q("U"); let id_t = func1(t.clone(), t.clone()).quantify(); module.context.assert_var_type("id", &id_t)?; + module.context.assert_var_type("id2", &id_t)?; let tu = or(t.clone(), u.clone()); let if_t = nd_func( vec![ @@ -62,8 +63,10 @@ fn _test_infer_types() -> Result<(), ()> { let o = a.clone().proj("Output"); let add_t = func2(a, t, o).quantify(); module.context.assert_var_type("add", &add_t)?; + module.context.assert_var_type("add2", &add_t)?; let abs_t = func1(Int, Nat); module.context.assert_var_type("abs_", &abs_t)?; + module.context.assert_var_type("abs2", &abs_t)?; let norm_t = func1(mono("::Norm"), Nat); module.context.assert_var_type("norm", &norm_t)?; let a_t = array_t( diff --git a/crates/erg_compiler/ty/free.rs b/crates/erg_compiler/ty/free.rs index 692a980f1..78e913bfb 100644 --- a/crates/erg_compiler/ty/free.rs +++ b/crates/erg_compiler/ty/free.rs @@ -222,6 +222,22 @@ impl Constraint { _ => None, } } + + /// e.g. + /// ```erg + /// old_sub: ?T, constraint: (:> ?T or NoneType, <: Obj) + /// -> constraint: (:> NoneType, <: Obj) + /// ``` + pub fn eliminate_recursion(self, target: &Type) -> Self { + match self { + Self::Sandwiched { sub, sup } => { + let sub = sub.eliminate(target); + let sup = sup.eliminate(target); + Self::new_sandwiched(sub, sup) + } + other => other, + } + } } pub trait CanbeFree { diff --git a/crates/erg_compiler/ty/mod.rs b/crates/erg_compiler/ty/mod.rs index 31c5c7646..11380912d 100644 --- a/crates/erg_compiler/ty/mod.rs +++ b/crates/erg_compiler/ty/mod.rs @@ -2041,6 +2041,33 @@ impl Type { } } + pub fn ref_inner(&self) -> Option { + match self { + Self::FreeVar(fv) if fv.is_linked() => fv.crack().ref_inner(), + Self::Ref(t) => Some(t.as_ref().clone()), + Self::Refinement(refine) => refine.t.ref_inner(), + _ => None, + } + } + + pub fn is_refmut(&self) -> bool { + match self { + Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_refmut(), + Self::RefMut { .. } => true, + Self::Refinement(refine) => refine.t.is_refmut(), + _ => false, + } + } + + pub fn ref_mut_inner(&self) -> Option { + match self { + Self::FreeVar(fv) if fv.is_linked() => fv.crack().ref_mut_inner(), + Self::RefMut { before, .. } => Some(before.as_ref().clone()), + Self::Refinement(refine) => refine.t.ref_mut_inner(), + _ => None, + } + } + pub fn is_structural(&self) -> bool { match self { Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_structural(), @@ -3198,6 +3225,32 @@ impl Type { } } + pub fn eliminate(self, target: &Type) -> Self { + match self { + Self::FreeVar(fv) if fv.is_linked() => { + let t = fv.crack().clone(); + t.eliminate(target) + } + Self::And(l, r) => { + if l.addr_eq(target) { + return r.eliminate(target); + } else if r.addr_eq(target) { + return l.eliminate(target); + } + l.eliminate(target) & r.eliminate(target) + } + Self::Or(l, r) => { + if l.addr_eq(target) { + return r.eliminate(target); + } else if r.addr_eq(target) { + return l.eliminate(target); + } + l.eliminate(target) | r.eliminate(target) + } + other => other, + } + } + pub fn replace(self, target: &Type, to: &Type) -> Type { let table = ReplaceTable::make(target, to); table.replace(self) @@ -3506,6 +3559,7 @@ impl Type { list: Option<&UndoableLinkedList>, in_instantiation: bool, ) { + let new_constraint = new_constraint.eliminate_recursion(self); if let Some(list) = list { self.undoable_update_constraint(new_constraint, list); } else {