From b0c31370c519de7d16149dd4db0fa78d91f6dad5 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Sat, 14 Sep 2024 21:20:05 +0900 Subject: [PATCH 01/12] fix: `Type::{And, Or}(Set)` --- crates/els/completion.rs | 28 +- crates/erg_common/set.rs | 14 + crates/erg_common/traits.rs | 2 + crates/erg_common/triple.rs | 10 + crates/erg_compiler/context/compare.rs | 112 +++--- crates/erg_compiler/context/eval.rs | 58 ++- crates/erg_compiler/context/generalize.rs | 77 ++-- crates/erg_compiler/context/hint.rs | 9 +- crates/erg_compiler/context/inquire.rs | 140 ++++---- crates/erg_compiler/context/instantiate.rs | 68 ++-- crates/erg_compiler/context/unify.rs | 174 +++++---- crates/erg_compiler/lower.rs | 7 +- crates/erg_compiler/ty/constructors.rs | 28 +- crates/erg_compiler/ty/mod.rs | 400 ++++++++++++++------- 14 files changed, 661 insertions(+), 466 deletions(-) diff --git a/crates/els/completion.rs b/crates/els/completion.rs index 47c465110..13cb818ff 100644 --- a/crates/els/completion.rs +++ b/crates/els/completion.rs @@ -42,24 +42,30 @@ fn comp_item_kind(t: &Type, muty: Mutability) -> CompletionItemKind { Type::Subr(_) | Type::Quantified(_) => CompletionItemKind::FUNCTION, Type::ClassType => CompletionItemKind::CLASS, Type::TraitType => CompletionItemKind::INTERFACE, - Type::Or(l, r) => { - let l = comp_item_kind(l, muty); - let r = comp_item_kind(r, muty); - if l == r { - l + Type::Or(tys) => { + let fst = comp_item_kind(tys.iter().next().unwrap(), muty); + if tys + .iter() + .map(|t| comp_item_kind(t, muty)) + .all(|k| k == fst) + { + fst } else if muty.is_const() { CompletionItemKind::CONSTANT } else { CompletionItemKind::VARIABLE } } - Type::And(l, r) => { - let l = comp_item_kind(l, muty); - let r = comp_item_kind(r, muty); - if l == CompletionItemKind::VARIABLE { - r + Type::And(tys) => { + for k in tys.iter().map(|t| comp_item_kind(t, muty)) { + if k != CompletionItemKind::VARIABLE { + return k; + } + } + if muty.is_const() { + CompletionItemKind::CONSTANT } else { - l + CompletionItemKind::VARIABLE } } Type::Refinement(r) => comp_item_kind(&r.t, muty), diff --git a/crates/erg_common/set.rs b/crates/erg_common/set.rs index c925cf1e9..8a2fe180a 100644 --- a/crates/erg_common/set.rs +++ b/crates/erg_common/set.rs @@ -382,6 +382,20 @@ impl Set { self.insert(other); self } + + /// ``` + /// # use erg_common::set; + /// assert_eq!(set!{1, 2}.product(&set!{3, 4}), set!{(&1, &3), (&1, &4), (&2, &3), (&2, &4)}); + /// ``` + pub fn product<'l, 'r, U: Hash + Eq>(&'l self, other: &'r Set) -> Set<(&'l T, &'r U)> { + let mut res = set! {}; + for x in self.iter() { + for y in other.iter() { + res.insert((x, y)); + } + } + res + } } impl Set { diff --git a/crates/erg_common/traits.rs b/crates/erg_common/traits.rs index 7c0c84bf1..168a268f7 100644 --- a/crates/erg_common/traits.rs +++ b/crates/erg_common/traits.rs @@ -1407,6 +1407,8 @@ impl Immutable for &T {} impl Immutable for Option {} impl Immutable for Vec {} impl Immutable for [T] {} +impl Immutable for (T, U) {} +impl Immutable for (T, U, V) {} impl Immutable for Box {} impl Immutable for std::rc::Rc {} impl Immutable for std::sync::Arc {} diff --git a/crates/erg_common/triple.rs b/crates/erg_common/triple.rs index a743d43ab..a3626d3d3 100644 --- a/crates/erg_common/triple.rs +++ b/crates/erg_common/triple.rs @@ -18,6 +18,16 @@ impl fmt::Display for Triple { } impl Triple { + pub const fn is_ok(&self) -> bool { + matches!(self, Triple::Ok(_)) + } + pub const fn is_err(&self) -> bool { + matches!(self, Triple::Err(_)) + } + pub const fn is_none(&self) -> bool { + matches!(self, Triple::None) + } + pub fn none_then(self, err: E) -> Result { match self { Triple::None => Err(err), diff --git a/crates/erg_compiler/context/compare.rs b/crates/erg_compiler/context/compare.rs index 8e3dc2908..911e53757 100644 --- a/crates/erg_compiler/context/compare.rs +++ b/crates/erg_compiler/context/compare.rs @@ -754,9 +754,7 @@ impl Context { self.structural_supertype_of(&l, rhs) } // {1, 2, 3} :> {1} or {2, 3} == true - (Refinement(_refine), Or(l, r)) => { - self.supertype_of(lhs, l) && self.supertype_of(lhs, r) - } + (Refinement(_refine), Or(tys)) => tys.iter().all(|ty| self.supertype_of(lhs, ty)), // ({I: Int | True} :> Int) == true // {N: Nat | ...} :> Int) == false // ({I: Int | I >= 0} :> Int) == false @@ -808,41 +806,20 @@ impl Context { self.sub_unify(&inst, l, &(), None).is_ok() } // Int or Str :> Str or Int == (Int :> Str && Str :> Int) || (Int :> Int && Str :> Str) == true - (Or(l_1, l_2), Or(r_1, r_2)) => { - if l_1.is_union_type() && self.supertype_of(l_1, rhs) { - return true; - } - if l_2.is_union_type() && self.supertype_of(l_2, rhs) { - return true; - } - (self.supertype_of(l_1, r_1) && self.supertype_of(l_2, r_2)) - || (self.supertype_of(l_1, r_2) && self.supertype_of(l_2, r_1)) - } + // Int or Str or NoneType :> Str or Int + (Or(l), Or(r)) => r.iter().all(|r| l.iter().any(|l| self.supertype_of(l, r))), // not Nat :> not Int == true (Not(l), Not(r)) => self.subtype_of(l, r), // (Int or Str) :> Nat == Int :> Nat || Str :> Nat == true // (Num or Show) :> Show == Num :> Show || Show :> Num == true - (Or(l_or, r_or), rhs) => self.supertype_of(l_or, rhs) || self.supertype_of(r_or, rhs), + (Or(ors), rhs) => ors.iter().any(|or| self.supertype_of(or, rhs)), // Int :> (Nat or Str) == Int :> Nat && Int :> Str == false - (lhs, Or(l_or, r_or)) => self.supertype_of(lhs, l_or) && self.supertype_of(lhs, r_or), - (And(l_1, l_2), And(r_1, r_2)) => { - if l_1.is_intersection_type() && self.supertype_of(l_1, rhs) { - return true; - } - if l_2.is_intersection_type() && self.supertype_of(l_2, rhs) { - return true; - } - (self.supertype_of(l_1, r_1) && self.supertype_of(l_2, r_2)) - || (self.supertype_of(l_1, r_2) && self.supertype_of(l_2, r_1)) - } + (lhs, Or(ors)) => ors.iter().all(|or| self.supertype_of(lhs, or)), + (And(l), And(r)) => r.iter().any(|r| l.iter().all(|l| self.supertype_of(l, r))), // (Num and Show) :> Show == false - (And(l_and, r_and), rhs) => { - self.supertype_of(l_and, rhs) && self.supertype_of(r_and, rhs) - } + (And(ands), rhs) => ands.iter().all(|and| self.supertype_of(and, rhs)), // Show :> (Num and Show) == true - (lhs, And(l_and, r_and)) => { - self.supertype_of(lhs, l_and) || self.supertype_of(lhs, r_and) - } + (lhs, And(ands)) => ands.iter().any(|and| self.supertype_of(lhs, and)), // Not(Eq) :> Float == !(Eq :> Float) == true (Not(_), Obj) => false, (Not(l), rhs) => !self.supertype_of(l, rhs), @@ -914,18 +891,18 @@ impl Context { Type::NamedTuple(fields) => fields.iter().cloned().collect(), Type::Refinement(refine) => self.fields(&refine.t), Type::Structural(t) => self.fields(t), - Type::Or(l, r) => { - let l_fields = self.fields(l); - let r_fields = self.fields(r); - let l_field_names = l_fields.keys().collect::>(); - let r_field_names = r_fields.keys().collect::>(); - let field_names = l_field_names.intersection(&r_field_names); + Type::Or(tys) => { + let or_fields = tys.iter().map(|t| self.fields(t)).collect::>(); + let field_names = or_fields + .iter() + .flat_map(|fs| fs.keys()) + .collect::>(); let mut fields = Dict::new(); - for (name, l_t, r_t) in field_names + for (name, tys) in field_names .iter() - .map(|&name| (name, &l_fields[name], &r_fields[name])) + .map(|&name| (name, or_fields.iter().filter_map(|fields| fields.get(name)))) { - let union = self.union(l_t, r_t); + let union = tys.fold(Never, |acc, ty| self.union(&acc, ty)); fields.insert(name.clone(), union); } fields @@ -1408,6 +1385,8 @@ impl Context { /// union({ .a = Int }, { .a = Str }) == { .a = Int or Str } /// union({ .a = Int }, { .a = Int; .b = Int }) == { .a = Int } or { .a = Int; .b = Int } # not to lost `b` information /// union((A and B) or C) == (A or C) and (B or C) + /// union(Never, Int) == Int + /// union(Obj, Int) == Obj /// ``` pub(crate) fn union(&self, lhs: &Type, rhs: &Type) -> Type { if lhs == rhs { @@ -1470,10 +1449,9 @@ impl Context { (Some(sub), Some(sup)) => bounded(sub.clone(), sup.clone()), _ => self.simple_union(lhs, rhs), }, - (other, or @ Or(_, _)) | (or @ Or(_, _), other) => self.union_add(or, other), + (other, or @ Or(_)) | (or @ Or(_), other) => self.union_add(or, other), // (A and B) or C ==> (A or C) and (B or C) - (and_t @ And(_, _), other) | (other, and_t @ And(_, _)) => { - let ands = and_t.ands(); + (And(ands), other) | (other, And(ands)) => { let mut t = Type::Obj; for branch in ands.iter() { let union = self.union(branch, other); @@ -1657,6 +1635,12 @@ impl Context { /// Returns intersection of two types (`A and B`). /// If `A` and `B` have a subtype relationship, it is equal to `min(A, B)`. + /// ```erg + /// intersection(Nat, Int) == Nat + /// intersection(Int, Str) == Never + /// intersection(Obj, Int) == Int + /// intersection(Never, Int) == Never + /// ``` pub(crate) fn intersection(&self, lhs: &Type, rhs: &Type) -> Type { if lhs == rhs { return lhs.clone(); @@ -1687,12 +1671,9 @@ impl Context { (_, Not(r)) => self.diff(lhs, r), (Not(l), _) => self.diff(rhs, l), // A and B and A == A and B - (other, and @ And(_, _)) | (and @ And(_, _), other) => { - self.intersection_add(and, other) - } + (other, and @ And(_)) | (and @ And(_), other) => self.intersection_add(and, other), // (A or B) and C == (A and C) or (B and C) - (or_t @ Or(_, _), other) | (other, or_t @ Or(_, _)) => { - let ors = or_t.ors(); + (Or(ors), other) | (other, Or(ors)) => { if ors.iter().any(|t| t.has_unbound_var()) { return self.simple_intersection(lhs, rhs); } @@ -1827,21 +1808,21 @@ impl Context { fn narrow_type_by_pred(&self, t: Type, pred: &Predicate) -> Type { match (t, pred) { ( - Type::Or(l, r), + Type::Or(tys), Predicate::NotEqual { rhs: TyParam::Value(v), .. }, ) if v.is_none() => { - let l = self.narrow_type_by_pred(*l, pred); - let r = self.narrow_type_by_pred(*r, pred); - if l.is_nonetype() { - r - } else if r.is_nonetype() { - l - } else { - or(l, r) + let mut new_tys = Set::new(); + for ty in tys { + let ty = self.narrow_type_by_pred(ty, pred); + if ty.is_nonelike() { + continue; + } + new_tys.insert(ty); } + Type::checked_or(new_tys) } (Type::Refinement(refine), _) => { let t = self.narrow_type_by_pred(*refine.t, pred); @@ -1983,8 +1964,12 @@ impl Context { guard.target.clone(), self.complement(&guard.to), )), - Or(l, r) => self.intersection(&self.complement(l), &self.complement(r)), - And(l, r) => self.union(&self.complement(l), &self.complement(r)), + Or(ors) => ors + .iter() + .fold(Obj, |l, r| self.intersection(&l, &self.complement(r))), + And(ands) => ands + .iter() + .fold(Never, |l, r| self.union(&l, &self.complement(r))), other => not(other.clone()), } } @@ -2002,7 +1987,14 @@ impl Context { match lhs { Type::FreeVar(fv) if fv.is_linked() => self.diff(&fv.crack(), rhs), // Type::And(l, r) => self.intersection(&self.diff(l, rhs), &self.diff(r, rhs)), - Type::Or(l, r) => self.union(&self.diff(l, rhs), &self.diff(r, rhs)), + Type::Or(tys) => { + let mut new_tys = vec![]; + for ty in tys { + let diff = self.diff(ty, rhs); + new_tys.push(diff); + } + new_tys.into_iter().fold(Never, |l, r| self.union(&l, &r)) + } _ => lhs.clone(), } } diff --git a/crates/erg_compiler/context/eval.rs b/crates/erg_compiler/context/eval.rs index 60e1e64f3..3dce1c24f 100644 --- a/crates/erg_compiler/context/eval.rs +++ b/crates/erg_compiler/context/eval.rs @@ -2285,44 +2285,42 @@ impl Context { Err((t, errs)) } } - Type::And(l, r) => { - let l = match self.eval_t_params(*l, level, t_loc) { - Ok(l) => l, - Err((l, es)) => { - errs.extend(es); - l - } - }; - let r = match self.eval_t_params(*r, level, t_loc) { - Ok(r) => r, - Err((r, es)) => { - errs.extend(es); - r + Type::And(ands) => { + let mut new_ands = set! {}; + for and in ands.into_iter() { + match self.eval_t_params(and, level, t_loc) { + Ok(and) => { + new_ands.insert(and); + } + Err((and, es)) => { + new_ands.insert(and); + errs.extend(es); + } } - }; - let intersec = self.intersection(&l, &r); + } + let intersec = new_ands + .into_iter() + .fold(Type::Obj, |l, r| self.intersection(&l, &r)); if errs.is_empty() { Ok(intersec) } else { Err((intersec, errs)) } } - Type::Or(l, r) => { - let l = match self.eval_t_params(*l, level, t_loc) { - Ok(l) => l, - Err((l, es)) => { - errs.extend(es); - l - } - }; - let r = match self.eval_t_params(*r, level, t_loc) { - Ok(r) => r, - Err((r, es)) => { - errs.extend(es); - r + Type::Or(ors) => { + let mut new_ors = set! {}; + for or in ors.into_iter() { + match self.eval_t_params(or, level, t_loc) { + Ok(or) => { + new_ors.insert(or); + } + Err((or, es)) => { + new_ors.insert(or); + errs.extend(es); + } } - }; - let union = self.union(&l, &r); + } + let union = new_ors.into_iter().fold(Never, |l, r| self.union(&l, &r)); if errs.is_empty() { Ok(union) } else { diff --git a/crates/erg_compiler/context/generalize.rs b/crates/erg_compiler/context/generalize.rs index 13a0049ad..d8899e701 100644 --- a/crates/erg_compiler/context/generalize.rs +++ b/crates/erg_compiler/context/generalize.rs @@ -289,17 +289,21 @@ impl Generalizer { } proj_call(lhs, attr_name, args) } - And(l, r) => { - let l = self.generalize_t(*l, uninit); - let r = self.generalize_t(*r, uninit); + And(ands) => { // not `self.intersection` because types are generalized - and(l, r) + let ands = ands + .into_iter() + .map(|t| self.generalize_t(t, uninit)) + .collect(); + Type::checked_and(ands) } - Or(l, r) => { - let l = self.generalize_t(*l, uninit); - let r = self.generalize_t(*r, uninit); + Or(ors) => { // not `self.union` because types are generalized - or(l, r) + let ors = ors + .into_iter() + .map(|t| self.generalize_t(t, uninit)) + .collect(); + Type::checked_or(ors) } Not(l) => not(self.generalize_t(*l, uninit)), Structural(ty) => { @@ -1045,15 +1049,23 @@ impl<'c, 'q, 'l, L: Locational> Dereferencer<'c, 'q, 'l, L> { let pred = self.deref_pred(*refine.pred)?; Ok(refinement(refine.var, t, pred)) } - And(l, r) => { - let l = self.deref_tyvar(*l)?; - let r = self.deref_tyvar(*r)?; - Ok(self.ctx.intersection(&l, &r)) + And(ands) => { + let mut new_ands = vec![]; + for t in ands.into_iter() { + new_ands.push(self.deref_tyvar(t)?); + } + Ok(new_ands + .into_iter() + .fold(Type::Obj, |acc, t| self.ctx.intersection(&acc, &t))) } - Or(l, r) => { - let l = self.deref_tyvar(*l)?; - let r = self.deref_tyvar(*r)?; - Ok(self.ctx.union(&l, &r)) + Or(ors) => { + let mut new_ors = vec![]; + for t in ors.into_iter() { + new_ors.push(self.deref_tyvar(t)?); + } + Ok(new_ors + .into_iter() + .fold(Type::Never, |acc, t| self.ctx.union(&acc, &t))) } Not(ty) => { let ty = self.deref_tyvar(*ty)?; @@ -1733,22 +1745,33 @@ impl Context { /// ``` pub(crate) fn squash_tyvar(&self, typ: Type) -> Type { match typ { - Or(l, r) => { - let l = self.squash_tyvar(*l); - let r = self.squash_tyvar(*r); + Or(tys) => { + let new_tys = tys + .into_iter() + .map(|t| self.squash_tyvar(t)) + .collect::>(); + let mut union = Never; // REVIEW: - if l.is_unnamed_unbound_var() && r.is_unnamed_unbound_var() { - match (self.subtype_of(&l, &r), self.subtype_of(&r, &l)) { - (true, true) | (true, false) => { - let _ = self.sub_unify(&l, &r, &(), None); + if new_tys.iter().all(|t| t.is_unnamed_unbound_var()) { + for ty in new_tys.iter() { + if union == Never { + union = ty.clone(); + continue; } - (false, true) => { - let _ = self.sub_unify(&r, &l, &(), None); + match (self.subtype_of(&union, ty), self.subtype_of(&union, ty)) { + (true, true) | (true, false) => { + let _ = self.sub_unify(&union, ty, &(), None); + } + (false, true) => { + let _ = self.sub_unify(ty, &union, &(), None); + } + _ => {} } - _ => {} } } - self.union(&l, &r) + new_tys + .into_iter() + .fold(Never, |acc, t| self.union(&acc, &t)) } FreeVar(ref fv) if fv.constraint_is_sandwiched() => { let (sub_t, super_t) = fv.get_subsup().unwrap(); diff --git a/crates/erg_compiler/context/hint.rs b/crates/erg_compiler/context/hint.rs index e7826d022..b7b795224 100644 --- a/crates/erg_compiler/context/hint.rs +++ b/crates/erg_compiler/context/hint.rs @@ -116,9 +116,12 @@ impl Context { return Some(hint); } } - (Type::And(l, r), found) => { - let left = self.readable_type(l.as_ref().clone()); - let right = self.readable_type(r.as_ref().clone()); + (Type::And(tys), found) if tys.len() == 2 => { + let mut iter = tys.iter(); + let l = iter.next().unwrap(); + let r = iter.next().unwrap(); + let left = self.readable_type(l.clone()); + let right = self.readable_type(r.clone()); if self.supertype_of(l, found) { let msg = switch_lang!( "japanese" => format!("型{found}は{left}のサブタイプですが、{right}のサブタイプではありません"), diff --git a/crates/erg_compiler/context/inquire.rs b/crates/erg_compiler/context/inquire.rs index 70ea0c5be..44aa46b3e 100644 --- a/crates/erg_compiler/context/inquire.rs +++ b/crates/erg_compiler/context/inquire.rs @@ -1046,35 +1046,35 @@ impl Context { } Type::Structural(t) => self.get_attr_info_from_attributive(t, ident, namespace), // TODO: And - Type::Or(l, r) => { - let l_info = self.get_attr_info_from_attributive(l, ident, namespace); - let r_info = self.get_attr_info_from_attributive(r, ident, namespace); - match (l_info, r_info) { - (Triple::Ok(l), Triple::Ok(r)) => { - let res = self.union(&l.t, &r.t); - let vis = if l.vis.is_public() && r.vis.is_public() { - Visibility::DUMMY_PUBLIC - } else { - Visibility::DUMMY_PRIVATE - }; - let vi = VarInfo::new( - res, - l.muty, - vis, - l.kind, - l.comptime_decos, - l.ctx, - l.py_name, - l.def_loc, - ); - Triple::Ok(vi) - } - (Triple::Ok(_), Triple::Err(e)) | (Triple::Err(e), Triple::Ok(_)) => { - Triple::Err(e) + Type::Or(tys) => { + let mut info = Triple::::None; + for ty in tys { + match ( + self.get_attr_info_from_attributive(ty, ident, namespace), + &info, + ) { + (Triple::Ok(vi), Triple::Ok(vi_)) => { + let res = self.union(&vi.t, &vi_.t); + let vis = if vi.vis.is_public() && vi_.vis.is_public() { + Visibility::DUMMY_PUBLIC + } else { + Visibility::DUMMY_PRIVATE + }; + let vi = VarInfo { t: res, vis, ..vi }; + info = Triple::Ok(vi); + } + (Triple::Ok(vi), Triple::None) => { + info = Triple::Ok(vi); + } + (Triple::Err(err), _) => { + info = Triple::Err(err); + break; + } + (Triple::None, _) => {} + (_, Triple::Err(_)) => unreachable!(), } - (Triple::Err(e1), Triple::Err(_e2)) => Triple::Err(e1), - _ => Triple::None, } + info } _other => Triple::None, } @@ -1952,7 +1952,7 @@ impl Context { res } } - Type::And(_, _) => { + Type::And(_) => { let instance = self.resolve_overload( obj, instance.clone(), @@ -3012,32 +3012,30 @@ impl Context { self.get_nominal_super_type_ctxs(&Type) } } - Type::And(l, r) => { - match ( - self.get_nominal_super_type_ctxs(l), - self.get_nominal_super_type_ctxs(r), - ) { - // TODO: sort - (Some(l), Some(r)) => Some([l, r].concat()), - (Some(l), None) => Some(l), - (None, Some(r)) => Some(r), - (None, None) => None, - } - } - // TODO - Type::Or(l, r) => match (l.as_ref(), r.as_ref()) { - (Type::FreeVar(l), Type::FreeVar(r)) - if l.is_unbound_and_sandwiched() && r.is_unbound_and_sandwiched() => + Type::And(tys) => { + let mut acc = vec![]; + for ctxs in tys + .iter() + .filter_map(|t| self.get_nominal_super_type_ctxs(t)) { - let (_lsub, lsup) = l.get_subsup().unwrap(); - let (_rsub, rsup) = r.get_subsup().unwrap(); - self.get_nominal_super_type_ctxs(&self.union(&lsup, &rsup)) + acc.extend(ctxs); } - (Type::Refinement(l), Type::Refinement(r)) if l.t == r.t => { - self.get_nominal_super_type_ctxs(&l.t) + if acc.is_empty() { + None + } else { + Some(acc) } - _ => self.get_nominal_type_ctx(&Obj).map(|ctx| vec![ctx]), - }, + } + Type::Or(tys) => { + let union = tys + .iter() + .fold(Never, |l, r| self.union(&l, &r.upper_bounded())); + if union.is_union_type() { + self.get_nominal_super_type_ctxs(&Obj) + } else { + self.get_nominal_super_type_ctxs(&union) + } + } _ => self .get_simple_nominal_super_type_ctxs(t) .map(|ctxs| ctxs.collect()), @@ -3231,7 +3229,7 @@ impl Context { .unwrap_or(self) .rec_local_get_mono_type("GenericNamedTuple"); } - Type::Or(_l, _r) => { + Type::Or(_) => { if let Some(ctx) = self.get_nominal_type_ctx(&poly("Or", vec![])) { return Some(ctx); } @@ -3366,26 +3364,27 @@ impl Context { match trait_ { // And(Add, Sub) == intersection({Int <: Add(Int), Bool <: Add(Bool) ...}, {Int <: Sub(Int), ...}) // == {Int <: Add(Int) and Sub(Int), ...} - Type::And(l, r) => { - let l_impls = self.get_trait_impls(l); - let l_base = Set::from_iter(l_impls.iter().map(|ti| &ti.sub_type)); - let r_impls = self.get_trait_impls(r); - let r_base = Set::from_iter(r_impls.iter().map(|ti| &ti.sub_type)); - let bases = l_base.intersection(&r_base); + Type::And(tys) => { + let impls = tys + .iter() + .flat_map(|ty| self.get_trait_impls(ty)) + .collect::>(); + let bases = impls.iter().map(|ti| &ti.sub_type); let mut isec = set! {}; - for base in bases.into_iter() { - let lti = l_impls.iter().find(|ti| &ti.sub_type == base).unwrap(); - let rti = r_impls.iter().find(|ti| &ti.sub_type == base).unwrap(); - let sup_trait = self.intersection(<i.sup_trait, &rti.sup_trait); - isec.insert(TraitImpl::new(lti.sub_type.clone(), sup_trait, None)); + for base in bases { + let base_impls = impls.iter().filter(|ti| ti.sub_type == *base); + let sup_trait = + base_impls.fold(Obj, |l, r| self.intersection(&l, &r.sup_trait)); + if sup_trait != Obj { + isec.insert(TraitImpl::new(base.clone(), sup_trait, None)); + } } isec } - Type::Or(l, r) => { - let l_impls = self.get_trait_impls(l); - let r_impls = self.get_trait_impls(r); + Type::Or(tys) => { // FIXME: - l_impls.union(&r_impls) + tys.iter() + .fold(set! {}, |acc, ty| acc.union(&self.get_trait_impls(ty))) } _ => self.get_simple_trait_impls(trait_), } @@ -3955,11 +3954,11 @@ impl Context { pub fn is_class(&self, typ: &Type) -> bool { match typ { - Type::And(_l, _r) => false, + Type::And(_) => false, Type::Never => true, Type::FreeVar(fv) if fv.is_linked() => self.is_class(&fv.crack()), Type::FreeVar(_) => false, - Type::Or(l, r) => self.is_class(l) && self.is_class(r), + Type::Or(tys) => tys.iter().all(|t| self.is_class(t)), Type::Proj { lhs, rhs } => self .get_proj_candidates(lhs, rhs) .iter() @@ -3982,7 +3981,8 @@ impl Context { Type::Never => false, Type::FreeVar(fv) if fv.is_linked() => self.is_class(&fv.crack()), Type::FreeVar(_) => false, - Type::And(l, r) | Type::Or(l, r) => self.is_trait(l) && self.is_trait(r), + Type::And(tys) => tys.iter().any(|t| self.is_trait(t)), + Type::Or(tys) => tys.iter().all(|t| self.is_trait(t)), Type::Proj { lhs, rhs } => self .get_proj_candidates(lhs, rhs) .iter() diff --git a/crates/erg_compiler/context/instantiate.rs b/crates/erg_compiler/context/instantiate.rs index 42d0c36f6..be4028548 100644 --- a/crates/erg_compiler/context/instantiate.rs +++ b/crates/erg_compiler/context/instantiate.rs @@ -953,15 +953,21 @@ impl Context { let t = self.instantiate_t_inner(*t, tmp_tv_cache, loc)?; Ok(t.structuralize()) } - And(l, r) => { - let l = self.instantiate_t_inner(*l, tmp_tv_cache, loc)?; - let r = self.instantiate_t_inner(*r, tmp_tv_cache, loc)?; - Ok(self.intersection(&l, &r)) + And(tys) => { + let mut new_tys = vec![]; + for ty in tys.iter().cloned() { + new_tys.push(self.instantiate_t_inner(ty, tmp_tv_cache, loc)?); + } + Ok(new_tys + .into_iter() + .fold(Obj, |l, r| self.intersection(&l, &r))) } - Or(l, r) => { - let l = self.instantiate_t_inner(*l, tmp_tv_cache, loc)?; - let r = self.instantiate_t_inner(*r, tmp_tv_cache, loc)?; - Ok(self.union(&l, &r)) + Or(tys) => { + let mut new_tys = vec![]; + for ty in tys.iter().cloned() { + new_tys.push(self.instantiate_t_inner(ty, tmp_tv_cache, loc)?); + } + Ok(new_tys.into_iter().fold(Never, |l, r| self.union(&l, &r))) } Not(ty) => { let ty = self.instantiate_t_inner(*ty, tmp_tv_cache, loc)?; @@ -998,10 +1004,12 @@ impl Context { let t = fv.crack().clone(); self.instantiate(t, callee) } - And(lhs, rhs) => { - let lhs = self.instantiate(*lhs, callee)?; - let rhs = self.instantiate(*rhs, callee)?; - Ok(lhs & rhs) + And(tys) => { + let tys = tys + .into_iter() + .map(|t| self.instantiate(t, callee)) + .collect::>>()?; + Ok(tys.into_iter().fold(Obj, |l, r| l & r)) } Quantified(quant) => { let mut tmp_tv_cache = TyVarCache::new(self.level, self); @@ -1028,22 +1036,16 @@ impl Context { )?; } } - Type::And(l, r) => { - if let Some(self_t) = l.self_t() { - self.sub_unify( - callee.ref_t(), - self_t, - callee, - Some(&Str::ever("self")), - )?; - } - if let Some(self_t) = r.self_t() { - self.sub_unify( - callee.ref_t(), - self_t, - callee, - Some(&Str::ever("self")), - )?; + Type::And(tys) => { + for ty in tys { + if let Some(self_t) = ty.self_t() { + self.sub_unify( + callee.ref_t(), + self_t, + callee, + Some(&Str::ever("self")), + )?; + } } } other => unreachable!("{other}"), @@ -1066,10 +1068,12 @@ impl Context { let t = fv.crack().clone(); self.instantiate_dummy(t) } - And(lhs, rhs) => { - let lhs = self.instantiate_dummy(*lhs)?; - let rhs = self.instantiate_dummy(*rhs)?; - Ok(lhs & rhs) + And(tys) => { + let tys = tys + .into_iter() + .map(|t| self.instantiate_dummy(t)) + .collect::>>()?; + Ok(tys.into_iter().fold(Obj, |l, r| l & r)) } Quantified(quant) => { let mut tmp_tv_cache = TyVarCache::new(self.level, self); diff --git a/crates/erg_compiler/context/unify.rs b/crates/erg_compiler/context/unify.rs index 3dc1d5519..10afe974c 100644 --- a/crates/erg_compiler/context/unify.rs +++ b/crates/erg_compiler/context/unify.rs @@ -155,17 +155,38 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { } Ok(()) } - (Or(l, r), Or(l2, r2)) | (And(l, r), And(l2, r2)) => self - .occur(l, l2) - .and(self.occur(r, r2)) - .or(self.occur(l, r2).and(self.occur(r, l2))), - (lhs, Or(l, r)) | (lhs, And(l, r)) => { - self.occur_inner(lhs, l)?; - self.occur_inner(lhs, r) - } - (Or(l, r), rhs) | (And(l, r), rhs) => { - self.occur_inner(l, rhs)?; - self.occur_inner(r, rhs) + (Or(l), Or(r)) | (And(l), And(r)) if l.len() == r.len() => { + let l = l.to_vec(); + let mut r = r.to_vec(); + for _ in 0..r.len() { + if l.iter() + .zip(r.iter()) + .all(|(l, r)| self.occur(l, r).is_ok()) + { + return Ok(()); + } + r.rotate_left(1); + } + Err(TyCheckErrors::from(TyCheckError::subtyping_error( + self.ctx.cfg.input.clone(), + line!() as usize, + maybe_sub, + maybe_sup, + self.loc.loc(), + self.ctx.caused_by(), + ))) + } + (lhs, Or(tys)) | (lhs, And(tys)) => { + for ty in tys.iter() { + self.occur(lhs, ty)?; + } + Ok(()) + } + (Or(tys), rhs) | (And(tys), rhs) => { + for ty in tys.iter() { + self.occur(ty, rhs)?; + } + Ok(()) } _ => Ok(()), } @@ -266,13 +287,17 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { } Ok(()) } - (lhs, Or(l, r)) | (lhs, And(l, r)) => { - self.occur_inner(lhs, l)?; - self.occur_inner(lhs, r) + (lhs, Or(tys)) | (lhs, And(tys)) => { + for ty in tys.iter() { + self.occur_inner(lhs, ty)?; + } + Ok(()) } - (Or(l, r), rhs) | (And(l, r), rhs) => { - self.occur_inner(l, rhs)?; - self.occur_inner(r, rhs) + (Or(tys), rhs) | (And(tys), rhs) => { + for ty in tys.iter() { + self.occur_inner(ty, rhs)?; + } + Ok(()) } _ => Ok(()), } @@ -1186,35 +1211,42 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { // (Int or ?T) <: (?U or Int) // OK: (Int <: Int); (?T <: ?U) // NG: (Int <: ?U); (?T <: Int) - (Or(l1, r1), Or(l2, r2)) | (And(l1, r1), And(l2, r2)) => { - if self.ctx.subtype_of(l1, l2) && self.ctx.subtype_of(r1, r2) { - let (l_sup, r_sup) = if !l1.is_unbound_var() - && !r2.is_unbound_var() - && self.ctx.subtype_of(l1, r2) + (Or(ltys), Or(rtys)) | (And(ltys), And(rtys)) => { + let lvars = ltys.to_vec(); + let mut rvars = rtys.to_vec(); + for _ in 0..rvars.len() { + if lvars + .iter() + .zip(rvars.iter()) + .all(|(l, r)| self.ctx.subtype_of(l, r)) { - (r2, l2) - } else { - (l2, r2) - }; - self.sub_unify(l1, l_sup)?; - self.sub_unify(r1, r_sup)?; - } else { - self.sub_unify(l1, r2)?; - self.sub_unify(r1, l2)?; + for (l, r) in ltys.iter().zip(rtys.iter()) { + self.sub_unify(l, r)?; + } + break; + } + rvars.rotate_left(1); } + return Err(TyCheckErrors::from(TyCheckError::type_mismatch_error( + self.ctx.cfg.input.clone(), + line!() as usize, + self.loc.loc(), + self.ctx.caused_by(), + self.param_name.as_ref().unwrap_or(&Str::ever("_")), + None, + maybe_sup, + maybe_sub, + self.ctx.get_candidates(maybe_sub), + self.ctx.get_simple_type_mismatch_hint(maybe_sup, maybe_sub), + ))); } // NG: Nat <: ?T or Int ==> Nat or Int (?T = Nat) // OK: Nat <: ?T or Int ==> ?T or Int - (sub, Or(l, r)) - if l.is_unbound_var() - && !sub.is_unbound_var() - && !r.is_unbound_var() - && self.ctx.subtype_of(sub, r) => {} - (sub, Or(l, r)) - if r.is_unbound_var() - && !sub.is_unbound_var() - && !l.is_unbound_var() - && self.ctx.subtype_of(sub, l) => {} + (sub, Or(tys)) + if !sub.is_unbound_var() + && tys + .iter() + .any(|ty| !ty.is_unbound_var() && self.ctx.subtype_of(sub, ty)) => {} // e.g. Structural({ .method = (self: T) -> Int })/T (Structural(sub), FreeVar(sup_fv)) if sup_fv.is_unbound() && sub.contains_tvar(sup_fv) => {} @@ -1578,30 +1610,36 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { } } // (X or Y) <: Z is valid when X <: Z and Y <: Z - (Or(l, r), _) => { - self.sub_unify(l, maybe_sup)?; - self.sub_unify(r, maybe_sup)?; + (Or(tys), _) => { + for ty in tys { + self.sub_unify(ty, maybe_sup)?; + } } // X <: (Y and Z) is valid when X <: Y and X <: Z - (_, And(l, r)) => { - self.sub_unify(maybe_sub, l)?; - self.sub_unify(maybe_sub, r)?; + (_, And(tys)) => { + for ty in tys { + self.sub_unify(maybe_sub, ty)?; + } } // (X and Y) <: Z is valid when X <: Z or Y <: Z - (And(l, r), _) => { - if self.ctx.subtype_of(l, maybe_sup) { - self.sub_unify(l, maybe_sup)?; - } else { - self.sub_unify(r, maybe_sup)?; + (And(tys), _) => { + for ty in tys { + if self.ctx.subtype_of(ty, maybe_sup) { + self.sub_unify(ty, maybe_sup)?; + break; + } } + self.sub_unify(tys.iter().next().unwrap(), maybe_sup)?; } // X <: (Y or Z) is valid when X <: Y or X <: Z - (_, Or(l, r)) => { - if self.ctx.subtype_of(maybe_sub, l) { - self.sub_unify(maybe_sub, l)?; - } else { - self.sub_unify(maybe_sub, r)?; + (_, Or(tys)) => { + for ty in tys { + if self.ctx.subtype_of(maybe_sub, ty) { + self.sub_unify(maybe_sub, ty)?; + break; + } } + self.sub_unify(maybe_sub, tys.iter().next().unwrap())?; } (Ref(sub), Ref(sup)) => { self.sub_unify(sub, sup)?; @@ -1843,27 +1881,27 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { /// ``` fn unify(&self, lhs: &Type, rhs: &Type) -> Option { match (lhs, rhs) { - (Type::Or(l, r), other) | (other, Type::Or(l, r)) => { - if let Some(t) = self.unify(l, other) { - return self.unify(&t, l); - } else if let Some(t) = self.unify(r, other) { - return self.unify(&t, l); + (Or(tys), other) | (other, Or(tys)) => { + for ty in tys { + if let Some(t) = self.unify(ty, other) { + return self.unify(&t, ty); + } } return None; } - (Type::FreeVar(fv), _) if fv.is_linked() => return self.unify(&fv.crack(), rhs), - (_, Type::FreeVar(fv)) if fv.is_linked() => return self.unify(lhs, &fv.crack()), + (FreeVar(fv), _) if fv.is_linked() => return self.unify(&fv.crack(), rhs), + (_, FreeVar(fv)) if fv.is_linked() => return self.unify(lhs, &fv.crack()), // TODO: unify(?T, ?U) ? - (Type::FreeVar(_), Type::FreeVar(_)) => {} - (Type::FreeVar(fv), _) if fv.constraint_is_sandwiched() => { + (FreeVar(_), FreeVar(_)) => {} + (FreeVar(fv), _) if fv.constraint_is_sandwiched() => { let sub = fv.get_sub()?; return self.unify(&sub, rhs); } - (_, Type::FreeVar(fv)) if fv.constraint_is_sandwiched() => { + (_, FreeVar(fv)) if fv.constraint_is_sandwiched() => { let sub = fv.get_sub()?; return self.unify(lhs, &sub); } - (Type::Refinement(lhs), Type::Refinement(rhs)) => { + (Refinement(lhs), Refinement(rhs)) => { if let Some(_union) = self.unify(&lhs.t, &rhs.t) { return Some(self.ctx.union_refinement(lhs, rhs).into()); } diff --git a/crates/erg_compiler/lower.rs b/crates/erg_compiler/lower.rs index 04f45d4c7..948a49de4 100644 --- a/crates/erg_compiler/lower.rs +++ b/crates/erg_compiler/lower.rs @@ -1502,9 +1502,10 @@ impl GenericASTLowerer { } _ => {} }, - Type::And(lhs, rhs) => { - self.push_guard(nth, kind, lhs); - self.push_guard(nth, kind, rhs); + Type::And(tys) => { + for ty in tys { + self.push_guard(nth, kind, ty); + } } _ => {} } diff --git a/crates/erg_compiler/ty/constructors.rs b/crates/erg_compiler/ty/constructors.rs index 3e78dc33c..eb9ffd61d 100644 --- a/crates/erg_compiler/ty/constructors.rs +++ b/crates/erg_compiler/ty/constructors.rs @@ -593,35 +593,11 @@ pub fn refinement(var: Str, t: Type, pred: Predicate) -> Type { } pub fn and(lhs: Type, rhs: Type) -> Type { - match (lhs, rhs) { - (Type::And(l, r), other) | (other, Type::And(l, r)) => { - if l.as_ref() == &other { - and(*r, other) - } else if r.as_ref() == &other { - and(*l, other) - } else { - Type::And(Box::new(Type::And(l, r)), Box::new(other)) - } - } - (Type::Obj, other) | (other, Type::Obj) => other, - (lhs, rhs) => Type::And(Box::new(lhs), Box::new(rhs)), - } + lhs & rhs } pub fn or(lhs: Type, rhs: Type) -> Type { - match (lhs, rhs) { - (Type::Or(l, r), other) | (other, Type::Or(l, r)) => { - if l.as_ref() == &other { - or(*r, other) - } else if r.as_ref() == &other { - or(*l, other) - } else { - Type::Or(Box::new(Type::Or(l, r)), Box::new(other)) - } - } - (Type::Never, other) | (other, Type::Never) => other, - (lhs, rhs) => Type::Or(Box::new(lhs), Box::new(rhs)), - } + lhs | rhs } pub fn ors(tys: impl IntoIterator) -> Type { diff --git a/crates/erg_compiler/ty/mod.rs b/crates/erg_compiler/ty/mod.rs index 7f545ed14..c5c2d9344 100644 --- a/crates/erg_compiler/ty/mod.rs +++ b/crates/erg_compiler/ty/mod.rs @@ -1410,8 +1410,8 @@ pub enum Type { Refinement(RefinementType), // e.g. |T: Type| T -> T Quantified(Box), - And(Box, Box), - Or(Box, Box), + And(Set), + Or(Set), Not(Box), // NOTE: It was found that adding a new variant above `Poly` may cause a subtyping bug, // possibly related to enum internal numbering, but the cause is unknown. @@ -1504,8 +1504,8 @@ impl PartialEq for Type { (Self::NamedTuple(lhs), Self::NamedTuple(rhs)) => lhs == rhs, (Self::Refinement(l), Self::Refinement(r)) => l == r, (Self::Quantified(l), Self::Quantified(r)) => l == r, - (Self::And(_, _), Self::And(_, _)) => self.ands().linear_eq(&other.ands()), - (Self::Or(_, _), Self::Or(_, _)) => self.ors().linear_eq(&other.ors()), + (Self::And(l), Self::And(r)) => l.linear_eq(r), + (Self::Or(l), Self::Or(r)) => l.linear_eq(r), (Self::Not(l), Self::Not(r)) => l == r, ( Self::Poly { @@ -1659,20 +1659,28 @@ impl LimitedDisplay for Type { write!(f, "|")?; quantified.limited_fmt(f, limit - 1) } - Self::And(lhs, rhs) => { - lhs.limited_fmt(f, limit - 1)?; - write!(f, " and ")?; - rhs.limited_fmt(f, limit - 1) + Self::And(ands) => { + for (i, ty) in ands.iter().enumerate() { + if i > 0 { + write!(f, " and ")?; + } + ty.limited_fmt(f, limit - 1)?; + } + Ok(()) + } + Self::Or(ors) => { + for (i, ty) in ors.iter().enumerate() { + if i > 0 { + write!(f, " or ")?; + } + ty.limited_fmt(f, limit - 1)?; + } + Ok(()) } Self::Not(ty) => { write!(f, "not ")?; ty.limited_fmt(f, limit - 1) } - Self::Or(lhs, rhs) => { - lhs.limited_fmt(f, limit - 1)?; - write!(f, " or ")?; - rhs.limited_fmt(f, limit - 1) - } Self::Poly { name, params } => { write!(f, "{name}(")?; if !DEBUG_MODE && self.is_module() { @@ -1852,14 +1860,40 @@ impl From> for Type { impl BitAnd for Type { type Output = Self; fn bitand(self, rhs: Self) -> Self::Output { - Self::And(Box::new(self), Box::new(rhs)) + match (self, rhs) { + (Self::And(l), Self::And(r)) => Self::And(l.union(&r)), + (Self::Obj, other) | (other, Self::Obj) => other, + (Self::Never, _) | (_, Self::Never) => Self::Never, + (Self::And(mut l), r) => { + l.insert(r); + Self::And(l) + } + (l, Self::And(mut r)) => { + r.insert(l); + Self::And(r) + } + (l, r) => Self::checked_and(set! {l, r}), + } } } impl BitOr for Type { type Output = Self; fn bitor(self, rhs: Self) -> Self::Output { - Self::Or(Box::new(self), Box::new(rhs)) + match (self, rhs) { + (Self::Or(l), Self::Or(r)) => Self::Or(l.union(&r)), + (Self::Obj, _) | (_, Self::Obj) => Self::Obj, + (Self::Never, other) | (other, Self::Never) => other, + (Self::Or(mut l), r) => { + l.insert(r); + Self::Or(l) + } + (l, Self::Or(mut r)) => { + r.insert(l); + Self::Or(r) + } + (l, r) => Self::checked_or(set! {l, r}), + } } } @@ -1974,17 +2008,7 @@ impl HasLevel for Type { .filter_map(|o| *o) .min() } - Self::And(lhs, rhs) | Self::Or(lhs, rhs) => { - let l = lhs - .level() - .unwrap_or(GENERIC_LEVEL) - .min(rhs.level().unwrap_or(GENERIC_LEVEL)); - if l == GENERIC_LEVEL { - None - } else { - Some(l) - } - } + Self::And(tys) | Self::Or(tys) => tys.iter().filter_map(|t| t.level()).min(), Self::Not(ty) => ty.level(), Self::Record(attrs) => attrs.values().filter_map(|t| t.level()).min(), Self::NamedTuple(attrs) => attrs.iter().filter_map(|(_, t)| t.level()).min(), @@ -2065,9 +2089,10 @@ impl HasLevel for Type { Self::Quantified(quant) => { quant.set_level(level); } - Self::And(lhs, rhs) | Self::Or(lhs, rhs) => { - lhs.set_level(level); - rhs.set_level(level); + Self::And(tys) | Self::Or(tys) => { + for t in tys.iter() { + t.set_level(level); + } } Self::Not(ty) => ty.set_level(level), Self::Record(attrs) => { @@ -2218,9 +2243,7 @@ impl StructuralEq for Type { (Self::Guard(l), Self::Guard(r)) => l.structural_eq(r), // NG: (l.structural_eq(l2) && r.structural_eq(r2)) // || (l.structural_eq(r2) && r.structural_eq(l2)) - (Self::And(_, _), Self::And(_, _)) => { - let self_ands = self.ands(); - let other_ands = other.ands(); + (Self::And(self_ands), Self::And(other_ands)) => { if self_ands.len() != other_ands.len() { return false; } @@ -2234,9 +2257,7 @@ impl StructuralEq for Type { } true } - (Self::Or(_, _), Self::Or(_, _)) => { - let self_ors = self.ors(); - let other_ors = other.ors(); + (Self::Or(self_ors), Self::Or(other_ors)) => { if self_ors.len() != other_ors.len() { return false; } @@ -2330,10 +2351,30 @@ impl Type { } } + pub fn checked_or(tys: Set) -> Self { + if tys.is_empty() { + panic!("tys is empty"); + } else if tys.len() == 1 { + tys.into_iter().next().unwrap() + } else { + Self::Or(tys) + } + } + + pub fn checked_and(tys: Set) -> Self { + if tys.is_empty() { + panic!("tys is empty"); + } else if tys.len() == 1 { + tys.into_iter().next().unwrap() + } else { + Self::And(tys) + } + } + pub fn quantify(self) -> Self { debug_assert!(self.is_subr(), "{self} is not subr"); match self { - Self::And(lhs, rhs) => lhs.quantify() & rhs.quantify(), + Self::And(tys) => Self::And(tys.into_iter().map(|t| t.quantify()).collect()), other => Self::Quantified(Box::new(other)), } } @@ -2431,7 +2472,7 @@ impl Type { Self::Quantified(t) => t.is_procedure(), Self::Subr(subr) if subr.kind == SubrKind::Proc => true, Self::Refinement(refine) => refine.t.is_procedure(), - Self::And(lhs, rhs) => lhs.is_procedure() && rhs.is_procedure(), + Self::And(tys) => tys.iter().any(|t| t.is_procedure()), _ => false, } } @@ -2449,6 +2490,7 @@ impl Type { name.ends_with('!') } Self::Refinement(refine) => refine.t.is_mut_type(), + Self::And(tys) => tys.iter().any(|t| t.is_mut_type()), _ => false, } } @@ -2467,6 +2509,7 @@ impl Type { Self::Poly { name, params, .. } if &name[..] == "Tuple" => params.is_empty(), Self::Refinement(refine) => refine.t.is_nonelike(), Self::Bounded { sup, .. } => sup.is_nonelike(), + Self::And(tys) => tys.iter().any(|t| t.is_nonelike()), _ => false, } } @@ -2516,7 +2559,7 @@ impl Type { pub fn is_union_type(&self) -> bool { match self { Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_union_type(), - Self::Or(_, _) => true, + Self::Or(_) => true, Self::Refinement(refine) => refine.t.is_union_type(), _ => false, } @@ -2549,7 +2592,7 @@ impl Type { pub fn is_intersection_type(&self) -> bool { match self { Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_intersection_type(), - Self::And(_, _) => true, + Self::And(_) => true, Self::Refinement(refine) => refine.t.is_intersection_type(), _ => false, } @@ -2563,11 +2606,11 @@ impl Type { fv.do_avoiding_recursion(|| sub.union_size().max(sup.union_size())) } // Or(Or(Int, Str), Nat) == 3 - Self::Or(l, r) => l.union_size() + r.union_size(), + Self::Or(tys) => tys.len(), Self::Refinement(refine) => refine.t.union_size(), Self::Ref(t) => t.union_size(), Self::RefMut { before, after: _ } => before.union_size(), - Self::And(lhs, rhs) => lhs.union_size().max(rhs.union_size()), + Self::And(tys) => tys.iter().map(|ty| ty.union_size()).max().unwrap_or(1), Self::Not(ty) => ty.union_size(), Self::Callable { param_ts, return_t } => param_ts .iter() @@ -2608,7 +2651,7 @@ impl Type { match self { Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_refinement(), Self::Refinement(_) => true, - Self::And(l, r) => l.is_refinement() && r.is_refinement(), + Self::And(tys) => tys.iter().any(|t| t.is_refinement()), _ => false, } } @@ -2617,6 +2660,7 @@ impl Type { match self { Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_singleton_refinement(), Self::Refinement(refine) => matches!(refine.pred.as_ref(), Predicate::Equal { .. }), + Self::And(tys) => tys.iter().any(|t| t.is_singleton_refinement()), _ => false, } } @@ -2626,6 +2670,7 @@ impl Type { Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_record(), Self::Record(_) => true, Self::Refinement(refine) => refine.t.is_record(), + Self::And(tys) => tys.iter().any(|t| t.is_record()), _ => false, } } @@ -2639,6 +2684,7 @@ impl Type { Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_erg_module(), Self::Refinement(refine) => refine.t.is_erg_module(), Self::Poly { name, .. } => &name[..] == "Module", + Self::And(tys) => tys.iter().any(|t| t.is_erg_module()), _ => false, } } @@ -2648,6 +2694,7 @@ impl Type { Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_py_module(), Self::Refinement(refine) => refine.t.is_py_module(), Self::Poly { name, .. } => &name[..] == "PyModule", + Self::And(tys) => tys.iter().any(|t| t.is_py_module()), _ => false, } } @@ -2658,7 +2705,7 @@ impl Type { Self::Refinement(refine) => refine.t.is_method(), Self::Subr(subr) => subr.is_method(), Self::Quantified(quant) => quant.is_method(), - Self::And(l, r) => l.is_method() && r.is_method(), + Self::And(tys) => tys.iter().any(|t| t.is_method()), _ => false, } } @@ -2669,7 +2716,7 @@ impl Type { Self::Subr(_) => true, Self::Quantified(quant) => quant.is_subr(), Self::Refinement(refine) => refine.t.is_subr(), - Self::And(l, r) => l.is_subr() && r.is_subr(), + Self::And(tys) => tys.iter().any(|t| t.is_subr()), _ => false, } } @@ -2680,7 +2727,10 @@ impl Type { Self::Subr(subr) => Some(subr.kind), Self::Refinement(refine) => refine.t.subr_kind(), Self::Quantified(quant) => quant.subr_kind(), - Self::And(l, r) => l.subr_kind().and_then(|k| r.subr_kind().map(|k2| k | k2)), + Self::And(tys) => tys + .iter() + .filter_map(|t| t.subr_kind()) + .fold(None, |a, b| Some(a? | b)), _ => None, } } @@ -2690,7 +2740,7 @@ impl Type { Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_quantified_subr(), Self::Quantified(_) => true, Self::Refinement(refine) => refine.t.is_quantified_subr(), - Self::And(l, r) => l.is_quantified_subr() && r.is_quantified_subr(), + Self::And(tys) => tys.iter().any(|t| t.is_quantified_subr()), _ => false, } } @@ -2727,6 +2777,7 @@ impl Type { Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_iterable(), Self::Poly { name, .. } => &name[..] == "Iterable", Self::Refinement(refine) => refine.t.is_iterable(), + Self::And(tys) => tys.iter().any(|t| t.is_iterable()), _ => false, } } @@ -2817,6 +2868,7 @@ impl Type { Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_structural(), Self::Structural(_) => true, Self::Refinement(refine) => refine.t.is_structural(), + Self::And(tys) => tys.iter().any(|t| t.is_structural()), _ => false, } } @@ -2826,6 +2878,7 @@ impl Type { Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_failure(), Self::Refinement(refine) => refine.t.is_failure(), Self::Failure => true, + Self::And(tys) => tys.iter().any(|t| t.is_failure()), _ => false, } } @@ -2902,8 +2955,8 @@ impl Type { Self::ProjCall { lhs, args, .. } => { lhs.contains_tvar(target) || args.iter().any(|t| t.contains_tvar(target)) } - Self::And(lhs, rhs) => lhs.contains_tvar(target) || rhs.contains_tvar(target), - Self::Or(lhs, rhs) => lhs.contains_tvar(target) || rhs.contains_tvar(target), + Self::And(tys) => tys.iter().any(|t| t.contains_tvar(target)), + Self::Or(tys) => tys.iter().any(|t| t.contains_tvar(target)), Self::Not(t) => t.contains_tvar(target), Self::Ref(t) => t.contains_tvar(target), Self::RefMut { before, after } => { @@ -2962,8 +3015,8 @@ impl Type { Self::ProjCall { lhs, args, .. } => { lhs.has_type_satisfies(f) || args.iter().any(|t| t.has_type_satisfies(f)) } - Self::And(lhs, rhs) | Self::Or(lhs, rhs) => { - lhs.has_type_satisfies(f) || rhs.has_type_satisfies(f) + Self::And(tys) | Self::Or(tys) => { + tys.iter().any(|t| t.has_type_satisfies(f)) } Self::Not(t) => t.has_type_satisfies(f), Self::Ref(t) => t.has_type_satisfies(f), @@ -3030,8 +3083,8 @@ impl Type { Self::ProjCall { lhs, args, .. } => { lhs.contains_type(target) || args.iter().any(|t| t.contains_type(target)) } - Self::And(lhs, rhs) => lhs.contains_type(target) || rhs.contains_type(target), - Self::Or(lhs, rhs) => lhs.contains_type(target) || rhs.contains_type(target), + Self::And(tys) => tys.iter().any(|t| t.contains_type(target)), + Self::Or(tys) => tys.iter().any(|t| t.contains_type(target)), Self::Not(t) => t.contains_type(target), Self::Ref(t) => t.contains_type(target), Self::RefMut { before, after } => { @@ -3068,8 +3121,8 @@ impl Type { Self::ProjCall { lhs, args, .. } => { lhs.contains_tp(target) || args.iter().any(|t| t.contains_tp(target)) } - Self::And(lhs, rhs) => lhs.contains_tp(target) || rhs.contains_tp(target), - Self::Or(lhs, rhs) => lhs.contains_tp(target) || rhs.contains_tp(target), + Self::And(tys) => tys.iter().any(|t| t.contains_tp(target)), + Self::Or(tys) => tys.iter().any(|t| t.contains_tp(target)), Self::Not(t) => t.contains_tp(target), Self::Ref(t) => t.contains_tp(target), Self::RefMut { before, after } => { @@ -3102,8 +3155,8 @@ impl Type { Self::ProjCall { lhs, args, .. } => { lhs.contains_value(target) || args.iter().any(|t| t.contains_value(target)) } - Self::And(lhs, rhs) => lhs.contains_value(target) || rhs.contains_value(target), - Self::Or(lhs, rhs) => lhs.contains_value(target) || rhs.contains_value(target), + Self::And(tys) => tys.iter().any(|t| t.contains_value(target)), + Self::Or(tys) => tys.iter().any(|t| t.contains_value(target)), Self::Not(t) => t.contains_value(target), Self::Ref(t) => t.contains_value(target), Self::RefMut { before, after } => { @@ -3146,9 +3199,7 @@ impl Type { Self::ProjCall { lhs, args, .. } => { lhs.contains_type(self) || args.iter().any(|t| t.contains_type(self)) } - Self::And(lhs, rhs) | Self::Or(lhs, rhs) => { - lhs.contains_type(self) || rhs.contains_type(self) - } + Self::And(tys) | Self::Or(tys) => tys.iter().any(|t| t.contains_type(self)), Self::Not(t) => t.contains_type(self), Self::Ref(t) => t.contains_type(self), Self::RefMut { before, after } => { @@ -3218,9 +3269,9 @@ impl Type { Self::Inf => Str::ever("Inf"), Self::NegInf => Str::ever("NegInf"), Self::Mono(name) => name.clone(), - Self::And(_, _) => Str::ever("And"), + Self::And(_) => Str::ever("And"), Self::Not(_) => Str::ever("Not"), - Self::Or(_, _) => Str::ever("Or"), + Self::Or(_) => Str::ever("Or"), Self::Ref(_) => Str::ever("Ref"), Self::RefMut { .. } => Str::ever("RefMut"), Self::Subr(SubrType { @@ -3310,7 +3361,7 @@ impl Type { match self { Self::FreeVar(fv) if fv.is_linked() => fv.crack().contains_intersec(typ), Self::Refinement(refine) => refine.t.contains_intersec(typ), - Self::And(t1, t2) => t1.contains_intersec(typ) || t2.contains_intersec(typ), + Self::And(tys) => tys.iter().any(|t| t.contains_intersec(typ)), _ => self == typ, } } @@ -3319,7 +3370,16 @@ impl Type { match self { Self::FreeVar(fv) if fv.is_linked() => fv.crack().union_pair(), Self::Refinement(refine) => refine.t.union_pair(), - Self::Or(t1, t2) => Some((*t1.clone(), *t2.clone())), + Self::Or(tys) if tys.len() == 2 => { + let mut iter = tys.iter(); + Some((iter.next().unwrap().clone(), iter.next().unwrap().clone())) + } + Self::Or(tys) => { + let mut iter = tys.iter(); + let t1 = iter.next().unwrap().clone(); + let t2 = iter.cloned().collect(); + Some((t1, Type::Or(t2))) + } _ => None, } } @@ -3329,7 +3389,7 @@ impl Type { match self { Type::FreeVar(fv) if fv.is_linked() => fv.crack().contains_union(typ), Type::Refinement(refine) => refine.t.contains_union(typ), - Type::Or(t1, t2) => t1.contains_union(typ) || t2.contains_union(typ), + Type::Or(tys) => tys.iter().any(|t| t.contains_union(typ)), _ => self == typ, } } @@ -3343,11 +3403,7 @@ impl Type { .into_iter() .map(|t| t.quantify()) .collect(), - Type::And(t1, t2) => { - let mut types = t1.intersection_types(); - types.extend(t2.intersection_types()); - types - } + Type::And(tys) => tys.iter().cloned().collect(), _ => vec![self.clone()], } } @@ -3423,9 +3479,7 @@ impl Type { match self { Self::FreeVar(fv) if fv.is_unbound() => true, Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_totally_unbound(), - Self::Or(t1, t2) | Self::And(t1, t2) => { - t1.is_totally_unbound() && t2.is_totally_unbound() - } + Self::Or(tys) | Self::And(tys) => tys.iter().all(|t| t.is_totally_unbound()), Self::Not(t) => t.is_totally_unbound(), _ => false, } @@ -3532,9 +3586,10 @@ impl Type { sub.destructive_coerce(); self.destructive_link(&sub); } - Type::And(l, r) | Type::Or(l, r) => { - l.destructive_coerce(); - r.destructive_coerce(); + Type::And(tys) | Type::Or(tys) => { + for t in tys { + t.destructive_coerce(); + } } Type::Not(l) => l.destructive_coerce(), Type::Poly { params, .. } => { @@ -3587,9 +3642,10 @@ impl Type { sub.undoable_coerce(list); self.undoable_link(&sub, list); } - Type::And(l, r) | Type::Or(l, r) => { - l.undoable_coerce(list); - r.undoable_coerce(list); + Type::And(tys) | Type::Or(tys) => { + for t in tys { + t.undoable_coerce(list); + } } Type::Not(l) => l.undoable_coerce(list), Type::Poly { params, .. } => { @@ -3648,7 +3704,9 @@ impl Type { .map(|t| t.qvars_inner()) .unwrap_or_else(|| set! {}), ), - Self::And(lhs, rhs) | Self::Or(lhs, rhs) => lhs.qvars_inner().concat(rhs.qvars_inner()), + Self::And(tys) | Self::Or(tys) => tys + .iter() + .fold(set! {}, |acc, t| acc.concat(t.qvars_inner())), Self::Not(ty) => ty.qvars_inner(), Self::Callable { param_ts, return_t } => param_ts .iter() @@ -3716,7 +3774,7 @@ impl Type { Self::RefMut { before, after } => { before.has_qvar() || after.as_ref().map(|t| t.has_qvar()).unwrap_or(false) } - Self::And(lhs, rhs) | Self::Or(lhs, rhs) => lhs.has_qvar() || rhs.has_qvar(), + Self::And(tys) | Self::Or(tys) => tys.iter().any(|t| t.has_qvar()), Self::Not(ty) => ty.has_qvar(), Self::Callable { param_ts, return_t } => { param_ts.iter().any(|t| t.has_qvar()) || return_t.has_qvar() @@ -3769,9 +3827,7 @@ impl Type { .map(|t| t.has_undoable_linked_var()) .unwrap_or(false) } - Self::And(lhs, rhs) | Self::Or(lhs, rhs) => { - lhs.has_undoable_linked_var() || rhs.has_undoable_linked_var() - } + Self::And(tys) | Self::Or(tys) => tys.iter().any(|t| t.has_undoable_linked_var()), Self::Not(ty) => ty.has_undoable_linked_var(), Self::Callable { param_ts, return_t } => { param_ts.iter().any(|t| t.has_undoable_linked_var()) @@ -3810,9 +3866,7 @@ impl Type { before.has_unbound_var() || after.as_ref().map(|t| t.has_unbound_var()).unwrap_or(false) } - Self::And(lhs, rhs) | Self::Or(lhs, rhs) => { - lhs.has_unbound_var() || rhs.has_unbound_var() - } + Self::And(tys) | Self::Or(tys) => tys.iter().any(|t| t.has_unbound_var()), Self::Not(ty) => ty.has_unbound_var(), Self::Callable { param_ts, return_t } => { param_ts.iter().any(|t| t.has_unbound_var()) || return_t.has_unbound_var() @@ -3867,7 +3921,7 @@ impl Type { Self::Refinement(refine) => refine.t.typarams_len(), // REVIEW: Self::Ref(_) | Self::RefMut { .. } => Some(1), - Self::And(_, _) | Self::Or(_, _) => Some(2), + Self::And(tys) | Self::Or(tys) => Some(tys.len()), Self::Not(_) => Some(1), Self::Subr(subr) => Some( subr.non_default_params.len() @@ -3933,9 +3987,7 @@ impl Type { Self::FreeVar(_unbound) => vec![], Self::Refinement(refine) => refine.t.typarams(), Self::Ref(t) | Self::RefMut { before: t, .. } => vec![TyParam::t(*t.clone())], - Self::And(lhs, rhs) | Self::Or(lhs, rhs) => { - vec![TyParam::t(*lhs.clone()), TyParam::t(*rhs.clone())] - } + Self::And(tys) | Self::Or(tys) => tys.iter().cloned().map(TyParam::t).collect(), Self::Not(t) => vec![TyParam::t(*t.clone())], Self::Subr(subr) => subr.typarams(), Self::Quantified(quant) => quant.typarams(), @@ -4156,8 +4208,8 @@ impl Type { let r = r.iter().map(|(k, v)| (k.clone(), v.derefine())).collect(); Self::NamedTuple(r) } - Self::And(l, r) => l.derefine() & r.derefine(), - Self::Or(l, r) => l.derefine() | r.derefine(), + Self::And(tys) => Self::checked_and(tys.iter().map(|t| t.derefine()).collect()), + Self::Or(tys) => Self::checked_or(tys.iter().map(|t| t.derefine()).collect()), Self::Not(ty) => !ty.derefine(), Self::Proj { lhs, rhs } => lhs.derefine().proj(rhs.clone()), Self::ProjCall { @@ -4224,22 +4276,28 @@ impl Type { }); self } - Self::And(l, r) => { - if l.addr_eq(target) { - return r.eliminate_subsup(target); - } else if r.addr_eq(target) { - return l.eliminate_subsup(target); - } - l.eliminate_subsup(target) & r.eliminate_subsup(target) - } - Self::Or(l, r) => { - if l.addr_eq(target) { - return r.eliminate_subsup(target); - } else if r.addr_eq(target) { - return l.eliminate_subsup(target); - } - l.eliminate_subsup(target) | r.eliminate_subsup(target) - } + Self::And(tys) => Self::checked_and( + tys.into_iter() + .filter_map(|t| { + if t.addr_eq(target) { + None + } else { + Some(t.eliminate_subsup(target)) + } + }) + .collect(), + ), + Self::Or(tys) => Self::checked_or( + tys.into_iter() + .filter_map(|t| { + if t.addr_eq(target) { + None + } else { + Some(t.eliminate_subsup(target)) + } + }) + .collect(), + ), other => other, } } @@ -4294,8 +4352,16 @@ impl Type { before: Box::new(before.eliminate_recursion(target)), after: after.map(|t| Box::new(t.eliminate_recursion(target))), }, - Self::And(l, r) => l.eliminate_recursion(target) & r.eliminate_recursion(target), - Self::Or(l, r) => l.eliminate_recursion(target) | r.eliminate_recursion(target), + Self::And(tys) => Self::checked_and( + tys.into_iter() + .map(|t| t.eliminate_recursion(target)) + .collect(), + ), + Self::Or(tys) => Self::checked_or( + tys.into_iter() + .map(|t| t.eliminate_recursion(target)) + .collect(), + ), Self::Not(ty) => !ty.eliminate_recursion(target), Self::Proj { lhs, rhs } => lhs.eliminate_recursion(target).proj(rhs), Self::ProjCall { @@ -4448,8 +4514,8 @@ impl Type { before: Box::new(before.map(f)), after: after.map(|t| Box::new(t.map(f))), }, - Self::And(l, r) => l.map(f) & r.map(f), - Self::Or(l, r) => l.map(f) | r.map(f), + Self::And(tys) => Self::checked_and(tys.into_iter().map(|t| t.map(f)).collect()), + Self::Or(tys) => Self::checked_or(tys.into_iter().map(|t| t.map(f)).collect()), Self::Not(ty) => !ty.map(f), Self::Proj { lhs, rhs } => lhs.map(f).proj(rhs), Self::ProjCall { @@ -4542,8 +4608,12 @@ impl Type { before: Box::new(before._replace_tp(target, to)), after: after.map(|t| Box::new(t._replace_tp(target, to))), }, - Self::And(l, r) => l._replace_tp(target, to) & r._replace_tp(target, to), - Self::Or(l, r) => l._replace_tp(target, to) | r._replace_tp(target, to), + Self::And(tys) => { + Self::checked_and(tys.into_iter().map(|t| t._replace_tp(target, to)).collect()) + } + Self::Or(tys) => { + Self::checked_or(tys.into_iter().map(|t| t._replace_tp(target, to)).collect()) + } Self::Not(ty) => !ty._replace_tp(target, to), Self::Proj { lhs, rhs } => lhs._replace_tp(target, to).proj(rhs), Self::ProjCall { @@ -4619,8 +4689,8 @@ impl Type { before: Box::new(before.map_tp(f)), after: after.map(|t| Box::new(t.map_tp(f))), }, - Self::And(l, r) => l.map_tp(f) & r.map_tp(f), - Self::Or(l, r) => l.map_tp(f) | r.map_tp(f), + Self::And(tys) => Self::checked_and(tys.into_iter().map(|t| t.map_tp(f)).collect()), + Self::Or(tys) => Self::checked_or(tys.into_iter().map(|t| t.map_tp(f)).collect()), Self::Not(ty) => !ty.map_tp(f), Self::Proj { lhs, rhs } => lhs.map_tp(f).proj(rhs), Self::ProjCall { @@ -4708,8 +4778,16 @@ impl Type { after, }) } - Self::And(l, r) => Ok(l.try_map_tp(f)? & r.try_map_tp(f)?), - Self::Or(l, r) => Ok(l.try_map_tp(f)? | r.try_map_tp(f)?), + Self::And(tys) => Ok(Self::checked_and( + tys.into_iter() + .map(|t| t.try_map_tp(f)) + .collect::>()?, + )), + Self::Or(tys) => Ok(Self::checked_or( + tys.into_iter() + .map(|t| t.try_map_tp(f)) + .collect::>()?, + )), Self::Not(ty) => Ok(!ty.try_map_tp(f)?), Self::Proj { lhs, rhs } => Ok(lhs.try_map_tp(f)?.proj(rhs)), Self::ProjCall { @@ -4742,12 +4820,28 @@ impl Type { *refine.t = refine.t.replace_param(target, to); Self::Refinement(refine) } - Self::And(l, r) => l.replace_param(target, to) & r.replace_param(target, to), + Self::And(tys) => Self::And( + tys.into_iter() + .map(|t| t.replace_param(target, to)) + .collect(), + ), Self::Guard(guard) => Self::Guard(guard.replace_param(target, to)), _ => self, } } + pub fn eliminate_and_or(&mut self) { + match self { + Self::And(tys) if tys.len() == 1 => { + *self = tys.take_all().into_iter().next().unwrap(); + } + Self::Or(tys) if tys.len() == 1 => { + *self = tys.take_all().into_iter().next().unwrap(); + } + _ => {} + } + } + pub fn replace_params<'l, 'r>( mut self, target: impl Iterator, @@ -4815,8 +4909,8 @@ impl Type { } Self::NamedTuple(r) } - Self::And(l, r) => l.normalize() & r.normalize(), - Self::Or(l, r) => l.normalize() | r.normalize(), + Self::And(tys) => Self::checked_and(tys.into_iter().map(|t| t.normalize()).collect()), + Self::Or(tys) => Self::checked_or(tys.into_iter().map(|t| t.normalize()).collect()), Self::Not(ty) => !ty.normalize(), Self::Structural(ty) => ty.normalize().structuralize(), Self::Quantified(quant) => quant.normalize().quantify(), @@ -4848,14 +4942,36 @@ impl Type { free.get_sub().unwrap_or(self.clone()) } else { match self { - Self::And(l, r) => l.lower_bounded() & r.lower_bounded(), - Self::Or(l, r) => l.lower_bounded() | r.lower_bounded(), + Self::And(tys) => { + Self::checked_and(tys.iter().map(|t| t.lower_bounded()).collect()) + } + Self::Or(tys) => Self::checked_or(tys.iter().map(|t| t.lower_bounded()).collect()), Self::Not(ty) => !ty.lower_bounded(), _ => self.clone(), } } } + /// ```erg + /// assert Int.upper_bounded() == Int + /// assert ?T(<: Str).upper_bounded() == Str + /// assert (?T(<: Str) or ?U(<: Int)).upper_bounded() == (Str or Int) + /// ``` + pub fn upper_bounded(&self) -> Type { + if let Ok(free) = <&FreeTyVar>::try_from(self) { + free.get_super().unwrap_or(self.clone()) + } else { + match self { + Self::And(tys) => { + Self::checked_and(tys.iter().map(|t| t.upper_bounded()).collect()) + } + Self::Or(tys) => Self::checked_or(tys.iter().map(|t| t.upper_bounded()).collect()), + Self::Not(ty) => !ty.upper_bounded(), + _ => self.clone(), + } + } + } + pub(crate) fn addr_eq(&self, other: &Type) -> bool { match (self, other) { (Self::FreeVar(slf), _) if slf.is_linked() => slf.crack().addr_eq(other), @@ -5018,7 +5134,7 @@ impl Type { match self { Self::FreeVar(fv) if fv.is_linked() => fv.crack().ands(), Self::Refinement(refine) => refine.t.ands(), - Self::And(l, r) => l.ands().union(&r.ands()), + Self::And(tys) => tys.clone(), _ => set![self.clone()], } } @@ -5031,7 +5147,7 @@ impl Type { match self { Self::FreeVar(fv) if fv.is_linked() => fv.crack().ors(), Self::Refinement(refine) => refine.t.ors(), - Self::Or(l, r) => l.ors().union(&r.ors()), + Self::Or(tys) => tys.clone(), _ => set![self.clone()], } } @@ -5076,7 +5192,7 @@ impl Type { Self::Callable { param_ts, .. } => { param_ts.iter().flat_map(|t| t.contained_ts()).collect() } - Self::And(l, r) | Self::Or(l, r) => l.contained_ts().union(&r.contained_ts()), + Self::And(tys) | Self::Or(tys) => tys.iter().flat_map(|t| t.contained_ts()).collect(), Self::Not(t) => t.contained_ts(), Self::Bounded { sub, sup } => sub.contained_ts().union(&sup.contained_ts()), Self::Quantified(ty) | Self::Structural(ty) => ty.contained_ts(), @@ -5145,9 +5261,18 @@ impl Type { } return_t.dereference(); } - Self::And(l, r) | Self::Or(l, r) => { - l.dereference(); - r.dereference(); + Self::And(tys) | Self::Or(tys) => { + *tys = tys + .take_all() + .into_iter() + .map(|mut t| { + t.dereference(); + t + }) + .collect(); + if tys.len() == 1 { + *self = tys.take_all().into_iter().next().unwrap(); + } } Self::Not(ty) => { ty.dereference(); @@ -5255,7 +5380,7 @@ impl Type { set.extend(return_t.variables()); set } - Self::And(l, r) | Self::Or(l, r) => l.variables().union(&r.variables()), + Self::And(tys) | Self::Or(tys) => tys.iter().flat_map(|t| t.variables()).collect(), Self::Not(ty) => ty.variables(), Self::Bounded { sub, sup } => sub.variables().union(&sup.variables()), Self::Quantified(ty) | Self::Structural(ty) => ty.variables(), @@ -5367,13 +5492,16 @@ impl<'t> ReplaceTable<'t> { self.iterate(l, r); } } - (Type::And(l, r), Type::And(l2, r2)) => { - self.iterate(l, l2); - self.iterate(r, r2); + // FIXME: + (Type::And(tys), Type::And(tys2)) => { + for (l, r) in tys.iter().zip(tys2.iter()) { + self.iterate(l, r); + } } - (Type::Or(l, r), Type::Or(l2, r2)) => { - self.iterate(l, l2); - self.iterate(r, r2); + (Type::Or(tys), Type::Or(tys2)) => { + for (l, r) in tys.iter().zip(tys2.iter()) { + self.iterate(l, r); + } } (Type::Not(t), Type::Not(t2)) => { self.iterate(t, t2); From 3b9bbdf1a5871c7a93df16fc33f874d7e195ce2b Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Sun, 15 Sep 2024 17:11:06 +0900 Subject: [PATCH 02/12] fix: union type bug --- crates/erg_compiler/context/compare.rs | 6 +- crates/erg_compiler/context/unify.rs | 104 +++++++++++++++++++++---- crates/erg_compiler/ty/mod.rs | 99 ++++++++++++++++------- 3 files changed, 163 insertions(+), 46 deletions(-) diff --git a/crates/erg_compiler/context/compare.rs b/crates/erg_compiler/context/compare.rs index 746a7e612..1d3d5da97 100644 --- a/crates/erg_compiler/context/compare.rs +++ b/crates/erg_compiler/context/compare.rs @@ -1778,13 +1778,15 @@ impl Context { /// intersection_add(Int and ?T(:> NoneType), Str) == Never /// ``` fn intersection_add(&self, intersection: &Type, elem: &Type) -> Type { - let ands = intersection.ands(); + let mut ands = intersection.ands(); let bounded = ands.iter().map(|t| t.lower_bounded()); for t in bounded { if self.subtype_of(&t, elem) { return intersection.clone(); } else if self.supertype_of(&t, elem) { - return constructors::ands(ands.linear_exclude(&t).include(elem.clone())); + ands.retain(|ty| ty != &t); + ands.push(elem.clone()); + return constructors::ands(ands); } } and(intersection.clone(), elem.clone()) diff --git a/crates/erg_compiler/context/unify.rs b/crates/erg_compiler/context/unify.rs index 10afe974c..d6657a90c 100644 --- a/crates/erg_compiler/context/unify.rs +++ b/crates/erg_compiler/context/unify.rs @@ -155,7 +155,27 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { } Ok(()) } - (Or(l), Or(r)) | (And(l), And(r)) if l.len() == r.len() => { + (And(l), And(r)) if l.len() == r.len() => { + let mut r = r.clone(); + for _ in 0..r.len() { + if l.iter() + .zip(r.iter()) + .all(|(l, r)| self.occur(l, r).is_ok()) + { + return Ok(()); + } + r.rotate_left(1); + } + Err(TyCheckErrors::from(TyCheckError::subtyping_error( + self.ctx.cfg.input.clone(), + line!() as usize, + maybe_sub, + maybe_sup, + self.loc.loc(), + self.ctx.caused_by(), + ))) + } + (Or(l), Or(r)) if l.len() == r.len() => { let l = l.to_vec(); let mut r = r.to_vec(); for _ in 0..r.len() { @@ -176,13 +196,25 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { self.ctx.caused_by(), ))) } - (lhs, Or(tys)) | (lhs, And(tys)) => { + (lhs, And(tys)) => { + for ty in tys.iter() { + self.occur(lhs, ty)?; + } + Ok(()) + } + (lhs, Or(tys)) => { for ty in tys.iter() { self.occur(lhs, ty)?; } Ok(()) } - (Or(tys), rhs) | (And(tys), rhs) => { + (And(tys), rhs) => { + for ty in tys.iter() { + self.occur(ty, rhs)?; + } + Ok(()) + } + (Or(tys), rhs) => { for ty in tys.iter() { self.occur(ty, rhs)?; } @@ -287,13 +319,25 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { } Ok(()) } - (lhs, Or(tys)) | (lhs, And(tys)) => { + (lhs, And(tys)) => { for ty in tys.iter() { self.occur_inner(lhs, ty)?; } Ok(()) } - (Or(tys), rhs) | (And(tys), rhs) => { + (lhs, Or(tys)) => { + for ty in tys.iter() { + self.occur_inner(lhs, ty)?; + } + Ok(()) + } + (And(tys), rhs) => { + for ty in tys.iter() { + self.occur_inner(ty, rhs)?; + } + Ok(()) + } + (Or(tys), rhs) => { for ty in tys.iter() { self.occur_inner(ty, rhs)?; } @@ -1208,24 +1252,52 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { // self.sub_unify(&lsub, &union, loc, param_name)?; maybe_sup.update_tyvar(union, intersec, self.undoable, false); } + (And(ltys), And(rtys)) => { + let mut rtys = rtys.clone(); + for _ in 0..rtys.len() { + if ltys + .iter() + .zip(rtys.iter()) + .all(|(l, r)| self.ctx.subtype_of(l, r)) + { + for (l, r) in ltys.iter().zip(rtys.iter()) { + self.sub_unify(l, r)?; + } + return Ok(()); + } + rtys.rotate_left(1); + } + return Err(TyCheckErrors::from(TyCheckError::type_mismatch_error( + self.ctx.cfg.input.clone(), + line!() as usize, + self.loc.loc(), + self.ctx.caused_by(), + self.param_name.as_ref().unwrap_or(&Str::ever("_")), + None, + maybe_sup, + maybe_sub, + self.ctx.get_candidates(maybe_sub), + self.ctx.get_simple_type_mismatch_hint(maybe_sup, maybe_sub), + ))); + } // (Int or ?T) <: (?U or Int) // OK: (Int <: Int); (?T <: ?U) // NG: (Int <: ?U); (?T <: Int) - (Or(ltys), Or(rtys)) | (And(ltys), And(rtys)) => { - let lvars = ltys.to_vec(); - let mut rvars = rtys.to_vec(); - for _ in 0..rvars.len() { - if lvars + (Or(ltys), Or(rtys)) => { + let ltys = ltys.to_vec(); + let mut rtys = rtys.to_vec(); + for _ in 0..rtys.len() { + if ltys .iter() - .zip(rvars.iter()) + .zip(rtys.iter()) .all(|(l, r)| self.ctx.subtype_of(l, r)) { for (l, r) in ltys.iter().zip(rtys.iter()) { self.sub_unify(l, r)?; } - break; + return Ok(()); } - rvars.rotate_left(1); + rtys.rotate_left(1); } return Err(TyCheckErrors::from(TyCheckError::type_mismatch_error( self.ctx.cfg.input.clone(), @@ -1625,8 +1697,7 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { (And(tys), _) => { for ty in tys { if self.ctx.subtype_of(ty, maybe_sup) { - self.sub_unify(ty, maybe_sup)?; - break; + return self.sub_unify(ty, maybe_sup); } } self.sub_unify(tys.iter().next().unwrap(), maybe_sup)?; @@ -1635,8 +1706,7 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { (_, Or(tys)) => { for ty in tys { if self.ctx.subtype_of(maybe_sub, ty) { - self.sub_unify(maybe_sub, ty)?; - break; + return self.sub_unify(maybe_sub, ty); } } self.sub_unify(maybe_sub, tys.iter().next().unwrap())?; diff --git a/crates/erg_compiler/ty/mod.rs b/crates/erg_compiler/ty/mod.rs index c5c2d9344..f44f57f01 100644 --- a/crates/erg_compiler/ty/mod.rs +++ b/crates/erg_compiler/ty/mod.rs @@ -1410,7 +1410,7 @@ pub enum Type { Refinement(RefinementType), // e.g. |T: Type| T -> T Quantified(Box), - And(Set), + And(Vec), Or(Set), Not(Box), // NOTE: It was found that adding a new variant above `Poly` may cause a subtyping bug, @@ -1504,7 +1504,7 @@ impl PartialEq for Type { (Self::NamedTuple(lhs), Self::NamedTuple(rhs)) => lhs == rhs, (Self::Refinement(l), Self::Refinement(r)) => l == r, (Self::Quantified(l), Self::Quantified(r)) => l == r, - (Self::And(l), Self::And(r)) => l.linear_eq(r), + (Self::And(l), Self::And(r)) => l.iter().collect::>().linear_eq(&r.iter().collect()), (Self::Or(l), Self::Or(r)) => l.linear_eq(r), (Self::Not(l), Self::Not(r)) => l == r, ( @@ -1861,18 +1861,18 @@ impl BitAnd for Type { type Output = Self; fn bitand(self, rhs: Self) -> Self::Output { match (self, rhs) { - (Self::And(l), Self::And(r)) => Self::And(l.union(&r)), + (Self::And(l), Self::And(r)) => Self::And([l, r].concat()), (Self::Obj, other) | (other, Self::Obj) => other, (Self::Never, _) | (_, Self::Never) => Self::Never, (Self::And(mut l), r) => { - l.insert(r); + l.push(r); Self::And(l) } (l, Self::And(mut r)) => { - r.insert(l); + r.push(l); Self::And(r) } - (l, r) => Self::checked_and(set! {l, r}), + (l, r) => Self::checked_and(vec! {l, r}), } } } @@ -2008,7 +2008,8 @@ impl HasLevel for Type { .filter_map(|o| *o) .min() } - Self::And(tys) | Self::Or(tys) => tys.iter().filter_map(|t| t.level()).min(), + Self::And(tys) => tys.iter().filter_map(|t| t.level()).min(), + Self::Or(tys) => tys.iter().filter_map(|t| t.level()).min(), Self::Not(ty) => ty.level(), Self::Record(attrs) => attrs.values().filter_map(|t| t.level()).min(), Self::NamedTuple(attrs) => attrs.iter().filter_map(|(_, t)| t.level()).min(), @@ -2089,7 +2090,12 @@ impl HasLevel for Type { Self::Quantified(quant) => { quant.set_level(level); } - Self::And(tys) | Self::Or(tys) => { + Self::And(tys) => { + for t in tys.iter() { + t.set_level(level); + } + } + Self::Or(tys) => { for t in tys.iter() { t.set_level(level); } @@ -2244,6 +2250,8 @@ impl StructuralEq for Type { // NG: (l.structural_eq(l2) && r.structural_eq(r2)) // || (l.structural_eq(r2) && r.structural_eq(l2)) (Self::And(self_ands), Self::And(other_ands)) => { + let self_ands = self_ands.iter().collect::>(); + let other_ands = other_ands.iter().collect::>(); if self_ands.len() != other_ands.len() { return false; } @@ -2361,7 +2369,7 @@ impl Type { } } - pub fn checked_and(tys: Set) -> Self { + pub fn checked_and(tys: Vec) -> Self { if tys.is_empty() { panic!("tys is empty"); } else if tys.len() == 1 { @@ -3015,7 +3023,10 @@ impl Type { Self::ProjCall { lhs, args, .. } => { lhs.has_type_satisfies(f) || args.iter().any(|t| t.has_type_satisfies(f)) } - Self::And(tys) | Self::Or(tys) => { + Self::And(tys) => { + tys.iter().any(|t| t.has_type_satisfies(f)) + } + Self::Or(tys) => { tys.iter().any(|t| t.has_type_satisfies(f)) } Self::Not(t) => t.has_type_satisfies(f), @@ -3199,7 +3210,8 @@ impl Type { Self::ProjCall { lhs, args, .. } => { lhs.contains_type(self) || args.iter().any(|t| t.contains_type(self)) } - Self::And(tys) | Self::Or(tys) => tys.iter().any(|t| t.contains_type(self)), + Self::And(tys) => tys.iter().any(|t| t.contains_type(self)), + Self::Or(tys) => tys.iter().any(|t| t.contains_type(self)), Self::Not(t) => t.contains_type(self), Self::Ref(t) => t.contains_type(self), Self::RefMut { before, after } => { @@ -3403,7 +3415,7 @@ impl Type { .into_iter() .map(|t| t.quantify()) .collect(), - Type::And(tys) => tys.iter().cloned().collect(), + Type::And(tys) => tys.clone(), _ => vec![self.clone()], } } @@ -3479,7 +3491,8 @@ impl Type { match self { Self::FreeVar(fv) if fv.is_unbound() => true, Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_totally_unbound(), - Self::Or(tys) | Self::And(tys) => tys.iter().all(|t| t.is_totally_unbound()), + Self::And(tys) => tys.iter().all(|t| t.is_totally_unbound()), + Self::Or(tys) => tys.iter().all(|t| t.is_totally_unbound()), Self::Not(t) => t.is_totally_unbound(), _ => false, } @@ -3586,7 +3599,12 @@ impl Type { sub.destructive_coerce(); self.destructive_link(&sub); } - Type::And(tys) | Type::Or(tys) => { + Type::And(tys) => { + for t in tys { + t.destructive_coerce(); + } + } + Type::Or(tys) => { for t in tys { t.destructive_coerce(); } @@ -3642,7 +3660,12 @@ impl Type { sub.undoable_coerce(list); self.undoable_link(&sub, list); } - Type::And(tys) | Type::Or(tys) => { + Type::And(tys) => { + for t in tys { + t.undoable_coerce(list); + } + } + Type::Or(tys) => { for t in tys { t.undoable_coerce(list); } @@ -3704,7 +3727,10 @@ impl Type { .map(|t| t.qvars_inner()) .unwrap_or_else(|| set! {}), ), - Self::And(tys) | Self::Or(tys) => tys + Self::And(tys) => tys + .iter() + .fold(set! {}, |acc, t| acc.concat(t.qvars_inner())), + Self::Or(tys) => tys .iter() .fold(set! {}, |acc, t| acc.concat(t.qvars_inner())), Self::Not(ty) => ty.qvars_inner(), @@ -3774,7 +3800,8 @@ impl Type { Self::RefMut { before, after } => { before.has_qvar() || after.as_ref().map(|t| t.has_qvar()).unwrap_or(false) } - Self::And(tys) | Self::Or(tys) => tys.iter().any(|t| t.has_qvar()), + Self::And(tys) => tys.iter().any(|t| t.has_qvar()), + Self::Or(tys) => tys.iter().any(|t| t.has_qvar()), Self::Not(ty) => ty.has_qvar(), Self::Callable { param_ts, return_t } => { param_ts.iter().any(|t| t.has_qvar()) || return_t.has_qvar() @@ -3827,7 +3854,8 @@ impl Type { .map(|t| t.has_undoable_linked_var()) .unwrap_or(false) } - Self::And(tys) | Self::Or(tys) => tys.iter().any(|t| t.has_undoable_linked_var()), + Self::And(tys) => tys.iter().any(|t| t.has_undoable_linked_var()), + Self::Or(tys) => tys.iter().any(|t| t.has_undoable_linked_var()), Self::Not(ty) => ty.has_undoable_linked_var(), Self::Callable { param_ts, return_t } => { param_ts.iter().any(|t| t.has_undoable_linked_var()) @@ -3866,7 +3894,8 @@ impl Type { before.has_unbound_var() || after.as_ref().map(|t| t.has_unbound_var()).unwrap_or(false) } - Self::And(tys) | Self::Or(tys) => tys.iter().any(|t| t.has_unbound_var()), + Self::And(tys) => tys.iter().any(|t| t.has_unbound_var()), + Self::Or(tys) => tys.iter().any(|t| t.has_unbound_var()), Self::Not(ty) => ty.has_unbound_var(), Self::Callable { param_ts, return_t } => { param_ts.iter().any(|t| t.has_unbound_var()) || return_t.has_unbound_var() @@ -3921,7 +3950,8 @@ impl Type { Self::Refinement(refine) => refine.t.typarams_len(), // REVIEW: Self::Ref(_) | Self::RefMut { .. } => Some(1), - Self::And(tys) | Self::Or(tys) => Some(tys.len()), + Self::And(tys) => Some(tys.len()), + Self::Or(tys) => Some(tys.len()), Self::Not(_) => Some(1), Self::Subr(subr) => Some( subr.non_default_params.len() @@ -3987,7 +4017,8 @@ impl Type { Self::FreeVar(_unbound) => vec![], Self::Refinement(refine) => refine.t.typarams(), Self::Ref(t) | Self::RefMut { before: t, .. } => vec![TyParam::t(*t.clone())], - Self::And(tys) | Self::Or(tys) => tys.iter().cloned().map(TyParam::t).collect(), + Self::And(tys) => tys.iter().cloned().map(TyParam::t).collect(), + Self::Or(tys) => tys.iter().cloned().map(TyParam::t).collect(), Self::Not(t) => vec![TyParam::t(*t.clone())], Self::Subr(subr) => subr.typarams(), Self::Quantified(quant) => quant.typarams(), @@ -4833,7 +4864,7 @@ impl Type { pub fn eliminate_and_or(&mut self) { match self { Self::And(tys) if tys.len() == 1 => { - *self = tys.take_all().into_iter().next().unwrap(); + *self = tys.remove(0); } Self::Or(tys) if tys.len() == 1 => { *self = tys.take_all().into_iter().next().unwrap(); @@ -5130,12 +5161,12 @@ impl Type { /// Add.ands() == {Add} /// (Add and Sub).ands() == {Add, Sub} /// ``` - pub fn ands(&self) -> Set { + pub fn ands(&self) -> Vec { match self { Self::FreeVar(fv) if fv.is_linked() => fv.crack().ands(), Self::Refinement(refine) => refine.t.ands(), Self::And(tys) => tys.clone(), - _ => set![self.clone()], + _ => vec![self.clone()], } } @@ -5192,7 +5223,8 @@ impl Type { Self::Callable { param_ts, .. } => { param_ts.iter().flat_map(|t| t.contained_ts()).collect() } - Self::And(tys) | Self::Or(tys) => tys.iter().flat_map(|t| t.contained_ts()).collect(), + Self::And(tys) => tys.iter().flat_map(|t| t.contained_ts()).collect(), + Self::Or(tys) => tys.iter().flat_map(|t| t.contained_ts()).collect(), Self::Not(t) => t.contained_ts(), Self::Bounded { sub, sup } => sub.contained_ts().union(&sup.contained_ts()), Self::Quantified(ty) | Self::Structural(ty) => ty.contained_ts(), @@ -5261,7 +5293,19 @@ impl Type { } return_t.dereference(); } - Self::And(tys) | Self::Or(tys) => { + Self::And(tys) => { + *tys = std::mem::take(tys) + .into_iter() + .map(|mut t| { + t.dereference(); + t + }) + .collect(); + if tys.len() == 1 { + *self = tys.remove(0); + } + } + Self::Or(tys) => { *tys = tys .take_all() .into_iter() @@ -5380,7 +5424,8 @@ impl Type { set.extend(return_t.variables()); set } - Self::And(tys) | Self::Or(tys) => tys.iter().flat_map(|t| t.variables()).collect(), + Self::And(tys) => tys.iter().flat_map(|t| t.variables()).collect(), + Self::Or(tys) => tys.iter().flat_map(|t| t.variables()).collect(), Self::Not(ty) => ty.variables(), Self::Bounded { sub, sup } => sub.variables().union(&sup.variables()), Self::Quantified(ty) | Self::Structural(ty) => ty.variables(), From 461e91703ae6e1d9ac318190d05e3b1919965638 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Sun, 15 Sep 2024 18:09:09 +0900 Subject: [PATCH 03/12] fix: union type bug (2) --- crates/erg_compiler/context/compare.rs | 18 +++++- .../erg_compiler/context/initialize/traits.rs | 8 +-- crates/erg_compiler/context/unify.rs | 16 +++-- crates/erg_compiler/lower.rs | 62 ++++++++----------- crates/erg_compiler/ty/mod.rs | 23 ++++--- 5 files changed, 73 insertions(+), 54 deletions(-) diff --git a/crates/erg_compiler/context/compare.rs b/crates/erg_compiler/context/compare.rs index 1d3d5da97..5707af044 100644 --- a/crates/erg_compiler/context/compare.rs +++ b/crates/erg_compiler/context/compare.rs @@ -824,7 +824,23 @@ impl Context { (Or(ors), rhs) => ors.iter().any(|or| self.supertype_of(or, rhs)), // Int :> (Nat or Str) == Int :> Nat && Int :> Str == false (lhs, Or(ors)) => ors.iter().all(|or| self.supertype_of(lhs, or)), - (And(l), And(r)) => r.iter().any(|r| l.iter().all(|l| self.supertype_of(l, r))), + // Hash and Eq :> HashEq and ... == true + // Add(T) and Eq :> Add(Int) and Eq == true + (And(l), And(r)) => { + if r.iter().any(|r| l.iter().all(|l| self.supertype_of(l, r))) { + return true; + } + if l.len() == r.len() { + let mut r = r.clone(); + for _ in 1..l.len() { + if l.iter().zip(&r).all(|(l, r)| self.supertype_of(l, r)) { + return true; + } + r.rotate_left(1); + } + } + false + } // (Num and Show) :> Show == false (And(ands), rhs) => ands.iter().all(|and| self.supertype_of(and, rhs)), // Show :> (Num and Show) == true diff --git a/crates/erg_compiler/context/initialize/traits.rs b/crates/erg_compiler/context/initialize/traits.rs index 40bad3696..ba94cb30b 100644 --- a/crates/erg_compiler/context/initialize/traits.rs +++ b/crates/erg_compiler/context/initialize/traits.rs @@ -588,10 +588,10 @@ impl Context { neg.register_builtin_erg_decl(OP_NEG, op_t, Visibility::BUILTIN_PUBLIC); neg.register_builtin_erg_decl(OUTPUT, Type, Visibility::BUILTIN_PUBLIC); /* Num */ - let mut num = Self::builtin_mono_trait(NUM, 2); - num.register_superclass(poly(ADD, vec![]), &add); - num.register_superclass(poly(SUB, vec![]), &sub); - num.register_superclass(poly(MUL, vec![]), &mul); + let num = Self::builtin_mono_trait(NUM, 2); + // num.register_superclass(poly(ADD, vec![]), &add); + // num.register_superclass(poly(SUB, vec![]), &sub); + // num.register_superclass(poly(MUL, vec![]), &mul); /* ToBool */ let mut to_bool = Self::builtin_mono_trait(TO_BOOL, 2); let _Slf = mono_q(SELF, subtypeof(mono(TO_BOOL))); diff --git a/crates/erg_compiler/context/unify.rs b/crates/erg_compiler/context/unify.rs index d6657a90c..cdc2f383f 100644 --- a/crates/erg_compiler/context/unify.rs +++ b/crates/erg_compiler/context/unify.rs @@ -1951,13 +1951,21 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { /// ``` fn unify(&self, lhs: &Type, rhs: &Type) -> Option { match (lhs, rhs) { + (Never, other) | (other, Never) => { + return Some(other.clone()); + } (Or(tys), other) | (other, Or(tys)) => { + let mut unified = Never; for ty in tys { if let Some(t) = self.unify(ty, other) { - return self.unify(&t, ty); + unified = self.ctx.union(&unified, &t); } } - return None; + if unified != Never { + return Some(unified); + } else { + return None; + } } (FreeVar(fv), _) if fv.is_linked() => return self.unify(&fv.crack(), rhs), (_, FreeVar(fv)) if fv.is_linked() => return self.unify(lhs, &fv.crack()), @@ -1981,11 +1989,11 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { let l_sups = self.ctx.get_super_classes(lhs)?; let r_sups = self.ctx.get_super_classes(rhs)?; for l_sup in l_sups { - if self.ctx.supertype_of(&l_sup, &Obj) { + if l_sup == Obj || self.ctx.is_trait(&l_sup) { continue; } for r_sup in r_sups.clone() { - if self.ctx.supertype_of(&r_sup, &Obj) { + if r_sup == Obj || self.ctx.is_trait(&r_sup) { continue; } if let Some(t) = self.ctx.max(&l_sup, &r_sup).either() { diff --git a/crates/erg_compiler/lower.rs b/crates/erg_compiler/lower.rs index 948a49de4..523d3489c 100644 --- a/crates/erg_compiler/lower.rs +++ b/crates/erg_compiler/lower.rs @@ -362,10 +362,9 @@ impl GenericASTLowerer { } } - fn elem_err(&self, l: &Type, r: &Type, elem: &hir::Expr) -> LowerErrors { + fn elem_err(&self, union: Type, elem: &hir::Expr) -> LowerErrors { let elem_disp_notype = elem.to_string_notype(); - let l = self.module.context.readable_type(l.clone()); - let r = self.module.context.readable_type(r.clone()); + let union = self.module.context.readable_type(union); LowerErrors::from(LowerError::syntax_error( self.cfg.input.clone(), line!() as usize, @@ -379,10 +378,10 @@ impl GenericASTLowerer { ) .to_owned(), Some(switch_lang!( - "japanese" => format!("[..., {elem_disp_notype}: {l} or {r}]など明示的に型を指定してください"), - "simplified_chinese" => format!("请明确指定类型,例如: [..., {elem_disp_notype}: {l} or {r}]"), - "traditional_chinese" => format!("請明確指定類型,例如: [..., {elem_disp_notype}: {l} or {r}]"), - "english" => format!("please specify the type explicitly, e.g. [..., {elem_disp_notype}: {l} or {r}]"), + "japanese" => format!("[..., {elem_disp_notype}: {union}]など明示的に型を指定してください"), + "simplified_chinese" => format!("请明确指定类型,例如: [..., {elem_disp_notype}: {union}]"), + "traditional_chinese" => format!("請明確指定類型,例如: [..., {elem_disp_notype}: {union}]"), + "english" => format!("please specify the type explicitly, e.g. [..., {elem_disp_notype}: {union}]"), )), )) } @@ -453,36 +452,25 @@ impl GenericASTLowerer { union: &Type, elem: &hir::Expr, ) -> LowerResult<()> { - if ERG_MODE && expect_elem.is_none() { - if let Some((l, r)) = union_.union_pair() { - match (l.is_unbound_var(), r.is_unbound_var()) { - // e.g. [1, "a"] - (false, false) => { - if let hir::Expr::TypeAsc(type_asc) = elem { - // e.g. [1, "a": Str or NoneType] - if !self - .module - .context - .supertype_of(&type_asc.spec.spec_t, union) - { - return Err(self.elem_err(&l, &r, elem)); - } // else(OK): e.g. [1, "a": Str or Int] - } - // OK: ?T(:> {"a"}) or ?U(:> {"b"}) or {"c", "d"} => {"a", "b", "c", "d"} <: Str - else if self - .module - .context - .coerce(union_.derefine(), &()) - .map_or(true, |coerced| coerced.union_pair().is_some()) - { - return Err(self.elem_err(&l, &r, elem)); - } - } - // TODO: check if the type is compatible with the other type - (true, false) => {} - (false, true) => {} - (true, true) => {} - } + if ERG_MODE && expect_elem.is_none() && union_.union_size() > 1 { + if let hir::Expr::TypeAsc(type_asc) = elem { + // e.g. [1, "a": Str or NoneType] + if !self + .module + .context + .supertype_of(&type_asc.spec.spec_t, union) + { + return Err(self.elem_err(union_.clone(), elem)); + } // else(OK): e.g. [1, "a": Str or Int] + } + // OK: ?T(:> {"a"}) or ?U(:> {"b"}) or {"c", "d"} => {"a", "b", "c", "d"} <: Str + else if self + .module + .context + .coerce(union_.derefine(), &()) + .map_or(true, |coerced| coerced.union_pair().is_some()) + { + return Err(self.elem_err(union_.clone(), elem)); } } Ok(()) diff --git a/crates/erg_compiler/ty/mod.rs b/crates/erg_compiler/ty/mod.rs index f44f57f01..f2e933458 100644 --- a/crates/erg_compiler/ty/mod.rs +++ b/crates/erg_compiler/ty/mod.rs @@ -1504,7 +1504,9 @@ impl PartialEq for Type { (Self::NamedTuple(lhs), Self::NamedTuple(rhs)) => lhs == rhs, (Self::Refinement(l), Self::Refinement(r)) => l == r, (Self::Quantified(l), Self::Quantified(r)) => l == r, - (Self::And(l), Self::And(r)) => l.iter().collect::>().linear_eq(&r.iter().collect()), + (Self::And(l), Self::And(r)) => { + l.iter().collect::>().linear_eq(&r.iter().collect()) + } (Self::Or(l), Self::Or(r)) => l.linear_eq(r), (Self::Not(l), Self::Not(r)) => l == r, ( @@ -1872,7 +1874,7 @@ impl BitAnd for Type { r.push(l); Self::And(r) } - (l, r) => Self::checked_and(vec! {l, r}), + (l, r) => Self::checked_and(vec![l, r]), } } } @@ -3023,12 +3025,8 @@ impl Type { Self::ProjCall { lhs, args, .. } => { lhs.has_type_satisfies(f) || args.iter().any(|t| t.has_type_satisfies(f)) } - Self::And(tys) => { - tys.iter().any(|t| t.has_type_satisfies(f)) - } - Self::Or(tys) => { - tys.iter().any(|t| t.has_type_satisfies(f)) - } + Self::And(tys) => tys.iter().any(|t| t.has_type_satisfies(f)), + Self::Or(tys) => tys.iter().any(|t| t.has_type_satisfies(f)), Self::Not(t) => t.has_type_satisfies(f), Self::Ref(t) => t.has_type_satisfies(f), Self::RefMut { before, after } => { @@ -3396,6 +3394,15 @@ impl Type { } } + pub fn union_types(&self) -> Option> { + match self { + Self::FreeVar(fv) if fv.is_linked() => fv.crack().union_types(), + Self::Refinement(refine) => refine.t.union_types(), + Self::Or(tys) => Some(tys.clone()), + _ => None, + } + } + /// assert!((A or B).contains_union(B)) pub fn contains_union(&self, typ: &Type) -> bool { match self { From 5508f652fc75a3b1a4c802d2416ef614bf22b9a3 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Mon, 16 Sep 2024 11:53:10 +0900 Subject: [PATCH 04/12] Update promise.rs --- crates/erg_compiler/module/promise.rs | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/crates/erg_compiler/module/promise.rs b/crates/erg_compiler/module/promise.rs index 42525989e..e57b5ea68 100644 --- a/crates/erg_compiler/module/promise.rs +++ b/crates/erg_compiler/module/promise.rs @@ -1,7 +1,7 @@ use std::fmt; use std::thread::{current, JoinHandle, ThreadId}; -use erg_common::consts::DEBUG_MODE; +use erg_common::consts::{DEBUG_MODE, SINGLE_THREAD}; use erg_common::dict::Dict; use erg_common::pathutil::NormalizedPathBuf; use erg_common::shared::Shared; @@ -169,12 +169,19 @@ impl SharedPromises { } pub fn join(&self, path: &NormalizedPathBuf) -> std::thread::Result<()> { + if !self.graph.entries().contains(path) { + return Err(Box::new(format!("not registered: {path}"))); + } if self.graph.ancestors(path).contains(&self.root) { // cycle detected, `self.path` must not in the dependencies // Erg analysis processes never join ancestor threads (although joining ancestors itself is allowed in Rust) // self.wait_until_finished(path); return Ok(()); } + if SINGLE_THREAD { + assert!(self.is_joined(path)); + return Ok(()); + } // Suppose A depends on B and C, and B depends on C. // In this case, B must join C before A joins C. Otherwise, a deadlock will occur. let children = self.graph.children(path); From 5eb0a50a23d4de87d1c05362e77913b1a971534d Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Mon, 16 Sep 2024 13:15:44 +0900 Subject: [PATCH 05/12] debug: mark --- crates/erg_compiler/build_package.rs | 10 +++++++++- crates/erg_compiler/context/inquire.rs | 2 +- crates/erg_compiler/module/cache.rs | 1 + crates/erg_compiler/module/promise.rs | 7 +++++-- 4 files changed, 16 insertions(+), 4 deletions(-) diff --git a/crates/erg_compiler/build_package.rs b/crates/erg_compiler/build_package.rs index 0c46ec8f1..636280ca4 100644 --- a/crates/erg_compiler/build_package.rs +++ b/crates/erg_compiler/build_package.rs @@ -418,7 +418,7 @@ impl let res = self.resolve(&mut ast, &cfg); debug_assert!(res.is_ok(), "{:?}", res.unwrap_err()); log!(info "Dependency resolution process completed"); - log!("graph:\n{}", self.shared.graph.display()); + println!("graph:\n{}", self.shared.graph.display()); if self.parse_errors.errors.is_empty() { self.shared.warns.extend(self.parse_errors.warns.flush()); // continue analysis if ELS mode @@ -838,9 +838,12 @@ impl write!(out, "Checking 0/{nmods}").unwrap(); out.flush().unwrap(); } + println!("here?: {path}"); + let mut limit = 100000; while let Some(ancestor) = ancestors.pop() { if graph.ancestors(&ancestor).is_empty() { graph.remove(&ancestor); + limit = 100000; if let Some(entry) = self.asts.remove(&ancestor) { if print_progress { let name = ancestor.file_name().unwrap_or_default().to_string_lossy(); @@ -861,6 +864,10 @@ impl self.build_inlined_module(&ancestor, graph); } } else { + limit -= 1; + if limit == 0 { + panic!("{ancestor} is in a circular dependency"); + } ancestors.insert(0, ancestor); } } @@ -943,6 +950,7 @@ impl } } }; + println!("Start to analyze {path}"); if SINGLE_THREAD { run(); self.shared.promises.mark_as_joined(path); diff --git a/crates/erg_compiler/context/inquire.rs b/crates/erg_compiler/context/inquire.rs index 44aa46b3e..76bfd2b12 100644 --- a/crates/erg_compiler/context/inquire.rs +++ b/crates/erg_compiler/context/inquire.rs @@ -73,7 +73,7 @@ impl Context { } if self.shared.is_some() && self.promises().is_registered(&path) && !self.mod_cached(&path) { - let _result = self.promises().join(&path); + self.promises().join(&path).unwrap(); } self.opt_mod_cache()? .raw_ref_ctx(&path) diff --git a/crates/erg_compiler/module/cache.rs b/crates/erg_compiler/module/cache.rs index 76fa220ff..482b6e744 100644 --- a/crates/erg_compiler/module/cache.rs +++ b/crates/erg_compiler/module/cache.rs @@ -340,6 +340,7 @@ impl SharedModuleCache { where NormalizedPathBuf: Borrow, { + println!("343"); let mut cache = loop { if let Some(cache) = self.0.try_borrow_mut() { break cache; diff --git a/crates/erg_compiler/module/promise.rs b/crates/erg_compiler/module/promise.rs index e57b5ea68..924215391 100644 --- a/crates/erg_compiler/module/promise.rs +++ b/crates/erg_compiler/module/promise.rs @@ -160,6 +160,7 @@ impl SharedPromises { } pub fn wait_until_finished(&self, path: &NormalizedPathBuf) { + println!("163"); if self.promises.borrow().get(path).is_none() { panic!("not registered: {path}"); } @@ -179,9 +180,11 @@ impl SharedPromises { return Ok(()); } if SINGLE_THREAD { + println!("182: {path}"); assert!(self.is_joined(path)); return Ok(()); } + println!("!?: {path}"); // Suppose A depends on B and C, and B depends on C. // In this case, B must join C before A joins C. Otherwise, a deadlock will occur. let children = self.graph.children(path); @@ -223,14 +226,14 @@ impl SharedPromises { paths.push(path.clone()); } for path in paths { - let _result = self.join(&path); + self.join(&path).unwrap(); } } pub fn join_all(&self) { let paths = self.promises.borrow().keys().cloned().collect::>(); for path in paths { - let _result = self.join(&path); + self.join(&path).unwrap(); } } From fd76f56ba4a702892a7b23e11d99a650193e94f0 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Mon, 16 Sep 2024 21:32:01 +0900 Subject: [PATCH 06/12] Revert "debug: mark" This reverts commit 5eb0a50a23d4de87d1c05362e77913b1a971534d. --- crates/erg_compiler/build_package.rs | 10 +--------- crates/erg_compiler/context/inquire.rs | 2 +- crates/erg_compiler/module/cache.rs | 1 - crates/erg_compiler/module/promise.rs | 7 ++----- 4 files changed, 4 insertions(+), 16 deletions(-) diff --git a/crates/erg_compiler/build_package.rs b/crates/erg_compiler/build_package.rs index 636280ca4..0c46ec8f1 100644 --- a/crates/erg_compiler/build_package.rs +++ b/crates/erg_compiler/build_package.rs @@ -418,7 +418,7 @@ impl let res = self.resolve(&mut ast, &cfg); debug_assert!(res.is_ok(), "{:?}", res.unwrap_err()); log!(info "Dependency resolution process completed"); - println!("graph:\n{}", self.shared.graph.display()); + log!("graph:\n{}", self.shared.graph.display()); if self.parse_errors.errors.is_empty() { self.shared.warns.extend(self.parse_errors.warns.flush()); // continue analysis if ELS mode @@ -838,12 +838,9 @@ impl write!(out, "Checking 0/{nmods}").unwrap(); out.flush().unwrap(); } - println!("here?: {path}"); - let mut limit = 100000; while let Some(ancestor) = ancestors.pop() { if graph.ancestors(&ancestor).is_empty() { graph.remove(&ancestor); - limit = 100000; if let Some(entry) = self.asts.remove(&ancestor) { if print_progress { let name = ancestor.file_name().unwrap_or_default().to_string_lossy(); @@ -864,10 +861,6 @@ impl self.build_inlined_module(&ancestor, graph); } } else { - limit -= 1; - if limit == 0 { - panic!("{ancestor} is in a circular dependency"); - } ancestors.insert(0, ancestor); } } @@ -950,7 +943,6 @@ impl } } }; - println!("Start to analyze {path}"); if SINGLE_THREAD { run(); self.shared.promises.mark_as_joined(path); diff --git a/crates/erg_compiler/context/inquire.rs b/crates/erg_compiler/context/inquire.rs index 8c3562a85..30be2cfc8 100644 --- a/crates/erg_compiler/context/inquire.rs +++ b/crates/erg_compiler/context/inquire.rs @@ -73,7 +73,7 @@ impl Context { } if self.shared.is_some() && self.promises().is_registered(&path) && !self.mod_cached(&path) { - self.promises().join(&path).unwrap(); + let _result = self.promises().join(&path); } self.opt_mod_cache()? .raw_ref_ctx(&path) diff --git a/crates/erg_compiler/module/cache.rs b/crates/erg_compiler/module/cache.rs index 482b6e744..76fa220ff 100644 --- a/crates/erg_compiler/module/cache.rs +++ b/crates/erg_compiler/module/cache.rs @@ -340,7 +340,6 @@ impl SharedModuleCache { where NormalizedPathBuf: Borrow, { - println!("343"); let mut cache = loop { if let Some(cache) = self.0.try_borrow_mut() { break cache; diff --git a/crates/erg_compiler/module/promise.rs b/crates/erg_compiler/module/promise.rs index 924215391..e57b5ea68 100644 --- a/crates/erg_compiler/module/promise.rs +++ b/crates/erg_compiler/module/promise.rs @@ -160,7 +160,6 @@ impl SharedPromises { } pub fn wait_until_finished(&self, path: &NormalizedPathBuf) { - println!("163"); if self.promises.borrow().get(path).is_none() { panic!("not registered: {path}"); } @@ -180,11 +179,9 @@ impl SharedPromises { return Ok(()); } if SINGLE_THREAD { - println!("182: {path}"); assert!(self.is_joined(path)); return Ok(()); } - println!("!?: {path}"); // Suppose A depends on B and C, and B depends on C. // In this case, B must join C before A joins C. Otherwise, a deadlock will occur. let children = self.graph.children(path); @@ -226,14 +223,14 @@ impl SharedPromises { paths.push(path.clone()); } for path in paths { - self.join(&path).unwrap(); + let _result = self.join(&path); } } pub fn join_all(&self) { let paths = self.promises.borrow().keys().cloned().collect::>(); for path in paths { - self.join(&path).unwrap(); + let _result = self.join(&path); } } From cb9380f3aa8208f07397cbf75eb5f727b47881b4 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Mon, 16 Sep 2024 21:59:22 +0900 Subject: [PATCH 07/12] Update mod.rs --- crates/erg_compiler/ty/mod.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/crates/erg_compiler/ty/mod.rs b/crates/erg_compiler/ty/mod.rs index a9bec0a63..f2e6b9227 100644 --- a/crates/erg_compiler/ty/mod.rs +++ b/crates/erg_compiler/ty/mod.rs @@ -1956,6 +1956,7 @@ impl HasType for Type { impl HasLevel for Type { fn level(&self) -> Option { + println!("lev: {self}"); match self { Self::FreeVar(v) => v.level(), Self::Ref(t) => t.level(), @@ -2034,7 +2035,10 @@ impl HasLevel for Type { Some(min) } } - Self::Structural(ty) => ty.level(), + Self::Structural(ty) => { + set_recursion_limit!(None, 128); + ty.level() + } Self::Guard(guard) => guard.to.level(), Self::Quantified(quant) => quant.level(), Self::Bounded { sub, sup } => { From 9d88e8d7e7c245410e5b04f0cb6f401377068000 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Tue, 17 Sep 2024 02:00:51 +0900 Subject: [PATCH 08/12] fix: type variable linking bug --- crates/erg_common/dict.rs | 12 ++++++ crates/erg_common/macros.rs | 11 ++++++ .../context/initialize/classes.rs | 39 ++++++++++--------- crates/erg_compiler/ty/free.rs | 11 +++++- crates/erg_compiler/ty/mod.rs | 32 +++++++++++---- 5 files changed, 78 insertions(+), 27 deletions(-) diff --git a/crates/erg_common/dict.rs b/crates/erg_common/dict.rs index 9f46fd13c..c4b2767c5 100644 --- a/crates/erg_common/dict.rs +++ b/crates/erg_common/dict.rs @@ -129,6 +129,18 @@ impl Dict { } } + /// ``` + /// # use erg_common::dict; + /// # use erg_common::dict::Dict; + /// let mut dict = Dict::with_capacity(3); + /// assert_eq!(dict.capacity(), 3); + /// dict.insert("a", 1); + /// assert_eq!(dict.capacity(), 3); + /// dict.insert("b", 2); + /// dict.insert("c", 3); + /// dict.insert("d", 4); + /// assert_ne!(dict.capacity(), 3); + /// ``` pub fn with_capacity(capacity: usize) -> Self { Self { dict: FxHashMap::with_capacity_and_hasher(capacity, Default::default()), diff --git a/crates/erg_common/macros.rs b/crates/erg_common/macros.rs index 93a6130ff..8e4e3e8e8 100644 --- a/crates/erg_common/macros.rs +++ b/crates/erg_common/macros.rs @@ -627,6 +627,17 @@ impl RecursionCounter { #[macro_export] macro_rules! set_recursion_limit { + (panic, $msg:expr, $limit:expr) => { + use std::sync::atomic::AtomicU32; + + static COUNTER: AtomicU32 = AtomicU32::new($limit); + + let counter = $crate::macros::RecursionCounter::new(&COUNTER); + if counter.limit_reached() { + $crate::log!(err "Recursion limit reached"); + panic!($msg); + } + }; ($returns:expr, $limit:expr) => { use std::sync::atomic::AtomicU32; diff --git a/crates/erg_compiler/context/initialize/classes.rs b/crates/erg_compiler/context/initialize/classes.rs index 1d9fba476..ae54c018c 100644 --- a/crates/erg_compiler/context/initialize/classes.rs +++ b/crates/erg_compiler/context/initialize/classes.rs @@ -16,6 +16,9 @@ use crate::varinfo::Mutability; use Mutability::*; impl Context { + // NOTE: Registering traits that a class implements requires type checking, + // which means that registering a class requires that the preceding types have already been registered, + // so `register_builtin_type` should be called as early as possible. pub(super) fn init_builtin_classes(&mut self) { let vis = if PYTHON_MODE { Visibility::BUILTIN_PUBLIC @@ -29,6 +32,7 @@ impl Context { let N = mono_q_tp(TY_N, instanceof(Nat)); let M = mono_q_tp(TY_M, instanceof(Nat)); let never = Self::builtin_mono_class(NEVER, 1); + self.register_builtin_type(Never, never, vis.clone(), Const, Some(NEVER)); /* Obj */ let mut obj = Self::builtin_mono_class(OBJ, 2); obj.register_py_builtin( @@ -2965,6 +2969,21 @@ impl Context { None, union, ); + self.register_builtin_type( + mono(GENERIC_TUPLE), + generic_tuple, + vis.clone(), + Const, + Some(FUNC_TUPLE), + ); + self.register_builtin_type( + homo_tuple_t, + homo_tuple, + vis.clone(), + Const, + Some(FUNC_TUPLE), + ); + self.register_builtin_type(_tuple_t, tuple_, vis.clone(), Const, Some(FUNC_TUPLE)); /* Or (true or type) */ let or_t = poly(OR, vec![ty_tp(L), ty_tp(R)]); let mut or = Self::builtin_poly_class(OR, vec![PS::t_nd(TY_L), PS::t_nd(TY_R)], 2); @@ -3673,6 +3692,8 @@ impl Context { Some(FUNC_UPDATE), ); list_mut_.register_trait_methods(list_mut_t.clone(), list_mut_mutable); + self.register_builtin_type(lis_t, list_, vis.clone(), Const, Some(LIST)); + self.register_builtin_type(list_mut_t, list_mut_, vis.clone(), Const, Some(LIST)); /* ByteArray! */ let bytearray_mut_t = mono(MUT_BYTEARRAY); let mut bytearray_mut = Self::builtin_mono_class(MUT_BYTEARRAY, 2); @@ -4213,7 +4234,6 @@ impl Context { let mut qfunc_meta_type = Self::builtin_mono_class(QUANTIFIED_FUNC_META_TYPE, 2); qfunc_meta_type.register_superclass(mono(QUANTIFIED_PROC_META_TYPE), &qproc_meta_type); qfunc_meta_type.register_superclass(mono(QUANTIFIED_FUNC), &qfunc); - self.register_builtin_type(Never, never, vis.clone(), Const, Some(NEVER)); self.register_builtin_type(Obj, obj, vis.clone(), Const, Some(FUNC_OBJECT)); // self.register_type(mono(RECORD), vec![], record, Visibility::BUILTIN_PRIVATE, Const); let name = if PYTHON_MODE { FUNC_INT } else { INT }; @@ -4261,7 +4281,6 @@ impl Context { Const, Some(UNSIZED_LIST), ); - self.register_builtin_type(lis_t, list_, vis.clone(), Const, Some(LIST)); self.register_builtin_type(mono(SLICE), slice, vis.clone(), Const, Some(FUNC_SLICE)); self.register_builtin_type( mono(GENERIC_SET), @@ -4274,21 +4293,6 @@ impl Context { self.register_builtin_type(g_dict_t, generic_dict, vis.clone(), Const, Some(DICT)); self.register_builtin_type(dict_t, dict_, vis.clone(), Const, Some(DICT)); self.register_builtin_type(mono(BYTES), bytes, vis.clone(), Const, Some(BYTES)); - self.register_builtin_type( - mono(GENERIC_TUPLE), - generic_tuple, - vis.clone(), - Const, - Some(FUNC_TUPLE), - ); - self.register_builtin_type( - homo_tuple_t, - homo_tuple, - vis.clone(), - Const, - Some(FUNC_TUPLE), - ); - self.register_builtin_type(_tuple_t, tuple_, vis.clone(), Const, Some(FUNC_TUPLE)); self.register_builtin_type(mono(RECORD), record, vis.clone(), Const, Some(RECORD)); self.register_builtin_type( mono(RECORD_META_TYPE), @@ -4411,7 +4415,6 @@ impl Context { Some(MEMORYVIEW), ); self.register_builtin_type(mono(MUT_FILE), file_mut, vis.clone(), Const, Some(FILE)); - self.register_builtin_type(list_mut_t, list_mut_, vis.clone(), Const, Some(LIST)); self.register_builtin_type( bytearray_mut_t, bytearray_mut, diff --git a/crates/erg_compiler/ty/free.rs b/crates/erg_compiler/ty/free.rs index 4cb6d0cf1..e9cad3bfb 100644 --- a/crates/erg_compiler/ty/free.rs +++ b/crates/erg_compiler/ty/free.rs @@ -774,7 +774,12 @@ impl Free { let placeholder = placeholder.unwrap_or(&Type::Failure); let is_recursive = self.is_recursive(); if is_recursive { - self.undoable_link(placeholder); + let target = Type::FreeVar(self.clone()); + let placeholder_ = placeholder + .clone() + .eliminate_subsup(&target) + .eliminate_and_or_recursion(&target); + self.undoable_link(&placeholder_); } let res = f(); if is_recursive { @@ -884,7 +889,9 @@ impl Free { let placeholder = placeholder.unwrap_or(&TyParam::Failure); let is_recursive = self.is_recursive(); if is_recursive { - self.undoable_link(placeholder); + let target = TyParam::FreeVar(self.clone()); + let placeholder_ = placeholder.clone().eliminate_recursion(&target); + self.undoable_link(&placeholder_); } let res = f(); if is_recursive { diff --git a/crates/erg_compiler/ty/mod.rs b/crates/erg_compiler/ty/mod.rs index f2e6b9227..abc8dcf68 100644 --- a/crates/erg_compiler/ty/mod.rs +++ b/crates/erg_compiler/ty/mod.rs @@ -4313,7 +4313,7 @@ impl Type { /// (T or U).eliminate_subsup(T) == U /// ?X(<: T or U).eliminate_subsup(T) == ?X(<: U) /// ``` - pub fn eliminate_subsup(self, target: &Type) -> Self { + pub(crate) fn eliminate_subsup(self, target: &Type) -> Self { match self { Self::FreeVar(fv) if fv.is_linked() => fv.unwrap_linked().eliminate_subsup(target), Self::FreeVar(ref fv) if fv.constraint_is_sandwiched() => { @@ -4361,7 +4361,7 @@ impl Type { /// ?T(<: K(X)).eliminate_recursion(X) == ?T(<: K(X)) /// Tuple(X).eliminate_recursion(X) == Tuple(Never) /// ``` - pub fn eliminate_recursion(self, target: &Type) -> Self { + pub(crate) fn eliminate_recursion(self, target: &Type) -> Self { if self.is_free_var() && self.addr_eq(target) { return Self::Never; } @@ -4409,11 +4409,13 @@ impl Type { }, Self::And(tys) => Self::checked_and( tys.into_iter() + .filter(|t| !t.addr_eq(target)) .map(|t| t.eliminate_recursion(target)) .collect(), ), Self::Or(tys) => Self::checked_or( tys.into_iter() + .filter(|t| !t.addr_eq(target)) .map(|t| t.eliminate_recursion(target)) .collect(), ), @@ -4441,6 +4443,18 @@ impl Type { } } + pub(crate) fn eliminate_and_or_recursion(self, target: &Type) -> Self { + match self { + Self::And(tys) => { + Self::checked_and(tys.into_iter().filter(|t| !t.addr_eq(target)).collect()) + } + Self::Or(tys) => { + Self::checked_or(tys.into_iter().filter(|t| !t.addr_eq(target)).collect()) + } + _ => self, + } + } + pub fn replace(self, target: &Type, to: &Type) -> Type { let table = ReplaceTable::make(target, to); table.replace(self) @@ -4597,7 +4611,7 @@ impl Type { /// Unlike `replace`, this does not make a look-up table. fn _replace(mut self, target: &Type, to: &Type) -> Type { - if self.structural_eq(target) { + if &self == target { self = to.clone(); } self.map(&mut |t| t._replace(target, to)) @@ -5049,8 +5063,8 @@ impl Type { } match self { Self::FreeVar(fv) => { - let to = to.clone().eliminate_subsup(self).eliminate_recursion(self); - fv.link(&to); + let to_ = to.clone().eliminate_subsup(self).eliminate_recursion(self); + fv.link(&to_); } Self::Refinement(refine) => refine.t.destructive_link(to), _ => { @@ -5072,8 +5086,12 @@ impl Type { } match self { Self::FreeVar(fv) => { - let to = to.clone().eliminate_subsup(self); // FIXME: .eliminate_recursion(self) - fv.undoable_link(&to); + // NOTE: we can't use `eliminate_recursion` + let to_ = to + .clone() + .eliminate_subsup(self) + .eliminate_and_or_recursion(self); + fv.undoable_link(&to_); } Self::Refinement(refine) => refine.t.undoable_link(to, list), _ => { From b84fee183d1a66ace5a0178252d5b7b8c1121e82 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Tue, 17 Sep 2024 02:07:37 +0900 Subject: [PATCH 09/12] Revert "Update mod.rs" This reverts commit cb9380f3aa8208f07397cbf75eb5f727b47881b4. --- crates/erg_compiler/ty/mod.rs | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/crates/erg_compiler/ty/mod.rs b/crates/erg_compiler/ty/mod.rs index abc8dcf68..44a08145d 100644 --- a/crates/erg_compiler/ty/mod.rs +++ b/crates/erg_compiler/ty/mod.rs @@ -1956,7 +1956,6 @@ impl HasType for Type { impl HasLevel for Type { fn level(&self) -> Option { - println!("lev: {self}"); match self { Self::FreeVar(v) => v.level(), Self::Ref(t) => t.level(), @@ -2035,10 +2034,7 @@ impl HasLevel for Type { Some(min) } } - Self::Structural(ty) => { - set_recursion_limit!(None, 128); - ty.level() - } + Self::Structural(ty) => ty.level(), Self::Guard(guard) => guard.to.level(), Self::Quantified(quant) => quant.level(), Self::Bounded { sub, sup } => { From a876b34145d5ce7de31e74ffd90526509cea2c49 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Tue, 17 Sep 2024 02:38:07 +0900 Subject: [PATCH 10/12] Update unify.rs --- crates/erg_compiler/context/unify.rs | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/crates/erg_compiler/context/unify.rs b/crates/erg_compiler/context/unify.rs index 26814a770..3df47f5c5 100644 --- a/crates/erg_compiler/context/unify.rs +++ b/crates/erg_compiler/context/unify.rs @@ -70,6 +70,7 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { /// occur(?T(<: Str) or ?U(<: Int), ?T(<: Str)) ==> Error /// occur(?T(<: ?U or Y), ?U) ==> OK /// occur(?T, ?T.Output) ==> OK + /// occur(?T, ?T or Int) ==> Error /// ``` fn occur(&self, maybe_sub: &Type, maybe_sup: &Type) -> TyCheckResult<()> { if maybe_sub == maybe_sup { @@ -160,7 +161,7 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { for _ in 0..r.len() { if l.iter() .zip(r.iter()) - .all(|(l, r)| self.occur(l, r).is_ok()) + .all(|(l, r)| self.occur_inner(l, r).is_ok()) { return Ok(()); } @@ -181,7 +182,7 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { for _ in 0..r.len() { if l.iter() .zip(r.iter()) - .all(|(l, r)| self.occur(l, r).is_ok()) + .all(|(l, r)| self.occur_inner(l, r).is_ok()) { return Ok(()); } @@ -198,25 +199,25 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { } (lhs, And(tys)) => { for ty in tys.iter() { - self.occur(lhs, ty)?; + self.occur_inner(lhs, ty)?; } Ok(()) } (lhs, Or(tys)) => { for ty in tys.iter() { - self.occur(lhs, ty)?; + self.occur_inner(lhs, ty)?; } Ok(()) } (And(tys), rhs) => { for ty in tys.iter() { - self.occur(ty, rhs)?; + self.occur_inner(ty, rhs)?; } Ok(()) } (Or(tys), rhs) => { for ty in tys.iter() { - self.occur(ty, rhs)?; + self.occur_inner(ty, rhs)?; } Ok(()) } From 1f51d188ea4e544c4421a0d29546e9d58927d8eb Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Tue, 17 Sep 2024 16:04:45 +0900 Subject: [PATCH 11/12] fix: `Type::has_type_satisfies` --- crates/erg_compiler/ty/mod.rs | 263 +++++----------------------- crates/erg_compiler/ty/predicate.rs | 37 +++- crates/erg_compiler/ty/typaram.rs | 14 +- crates/erg_compiler/ty/value.rs | 2 +- 4 files changed, 81 insertions(+), 235 deletions(-) diff --git a/crates/erg_compiler/ty/mod.rs b/crates/erg_compiler/ty/mod.rs index 44a08145d..32f782779 100644 --- a/crates/erg_compiler/ty/mod.rs +++ b/crates/erg_compiler/ty/mod.rs @@ -537,48 +537,15 @@ impl SubrType { } pub fn contains_tvar(&self, target: &FreeTyVar) -> bool { - self.non_default_params - .iter() - .any(|pt| pt.typ().contains_tvar(target)) - || self - .var_params - .as_ref() - .map_or(false, |pt| pt.typ().contains_tvar(target)) - || self.default_params.iter().any(|pt| { - pt.typ().contains_tvar(target) - || pt.default_typ().is_some_and(|t| t.contains_tvar(target)) - }) - || self.return_t.contains_tvar(target) + self.has_type_satisfies(|t| t.contains_tvar(target)) } pub fn contains_type(&self, target: &Type) -> bool { - self.non_default_params - .iter() - .any(|pt| pt.typ().contains_type(target)) - || self - .var_params - .as_ref() - .map_or(false, |pt| pt.typ().contains_type(target)) - || self.default_params.iter().any(|pt| { - pt.typ().contains_type(target) - || pt.default_typ().is_some_and(|t| t.contains_type(target)) - }) - || self.return_t.contains_type(target) + self.has_type_satisfies(|t| t.contains_type(target)) } pub fn contains_tp(&self, target: &TyParam) -> bool { - self.non_default_params - .iter() - .any(|pt| pt.typ().contains_tp(target)) - || self - .var_params - .as_ref() - .map_or(false, |pt| pt.typ().contains_tp(target)) - || self.default_params.iter().any(|pt| { - pt.typ().contains_tp(target) - || pt.default_typ().is_some_and(|t| t.contains_tp(target)) - }) - || self.return_t.contains_tp(target) + self.has_type_satisfies(|t| t.contains_tp(target)) } pub fn map(self, f: &mut impl FnMut(Type) -> Type) -> Self { @@ -708,48 +675,27 @@ impl SubrType { Set::multi_intersection(qnames_sets).extended(structural_qname) } - pub fn has_qvar(&self) -> bool { - self.non_default_params.iter().any(|pt| pt.typ().has_qvar()) - || self - .var_params - .as_ref() - .map_or(false, |pt| pt.typ().has_qvar()) + pub fn has_type_satisfies(&self, f: impl Fn(&Type) -> bool + Copy) -> bool { + self.non_default_params.iter().any(|pt| f(pt.typ())) + || self.var_params.as_ref().map_or(false, |pt| f(pt.typ())) || self .default_params .iter() - .any(|pt| pt.typ().has_qvar() || pt.default_typ().is_some_and(|t| t.has_qvar())) - || self.return_t.has_qvar() + .any(|pt| f(pt.typ()) || pt.default_typ().is_some_and(f)) + || self.kw_var_params.as_ref().map_or(false, |pt| f(pt.typ())) + || f(&self.return_t) + } + + pub fn has_qvar(&self) -> bool { + self.has_type_satisfies(|t| t.has_qvar()) } pub fn has_unbound_var(&self) -> bool { - self.non_default_params - .iter() - .any(|pt| pt.typ().has_unbound_var()) - || self - .var_params - .as_ref() - .map_or(false, |pt| pt.typ().has_unbound_var()) - || self.default_params.iter().any(|pt| { - pt.typ().has_unbound_var() || pt.default_typ().is_some_and(|t| t.has_unbound_var()) - }) - || self.return_t.has_unbound_var() + self.has_type_satisfies(|t| t.has_unbound_var()) } pub fn has_undoable_linked_var(&self) -> bool { - self.non_default_params - .iter() - .any(|pt| pt.typ().has_undoable_linked_var()) - || self - .var_params - .as_ref() - .map_or(false, |pt| pt.typ().has_undoable_linked_var()) - || self.default_params.iter().any(|pt| { - pt.typ().has_undoable_linked_var() - || pt - .default_typ() - .is_some_and(|t| t.has_undoable_linked_var()) - }) - || self.return_t.has_undoable_linked_var() + self.has_type_satisfies(|t| t.has_undoable_linked_var()) } pub fn typarams(&self) -> Vec { @@ -2953,91 +2899,37 @@ impl Type { }) .unwrap_or(false) } - Self::Record(rec) => rec.iter().any(|(_, t)| t.contains_tvar(target)), - Self::NamedTuple(rec) => rec.iter().any(|(_, t)| t.contains_tvar(target)), - Self::Poly { params, .. } => params.iter().any(|tp| tp.contains_tvar(target)), - Self::Quantified(t) => t.contains_tvar(target), - Self::Subr(subr) => subr.contains_tvar(target), - // TODO: preds - Self::Refinement(refine) => refine.t.contains_tvar(target), - Self::Structural(ty) => ty.contains_tvar(target), - Self::Proj { lhs, .. } => lhs.contains_tvar(target), - Self::ProjCall { lhs, args, .. } => { - lhs.contains_tvar(target) || args.iter().any(|t| t.contains_tvar(target)) - } - Self::And(tys) => tys.iter().any(|t| t.contains_tvar(target)), - Self::Or(tys) => tys.iter().any(|t| t.contains_tvar(target)), - Self::Not(t) => t.contains_tvar(target), - Self::Ref(t) => t.contains_tvar(target), - Self::RefMut { before, after } => { - before.contains_tvar(target) - || after.as_ref().map_or(false, |t| t.contains_tvar(target)) - } - Self::Bounded { sub, sup } => sub.contains_tvar(target) || sup.contains_tvar(target), - Self::Callable { param_ts, return_t } => { - param_ts.iter().any(|t| t.contains_tvar(target)) || return_t.contains_tvar(target) - } - Self::Guard(guard) => guard.to.contains_tvar(target), - mono_type_pattern!() => false, + _ => self.has_type_satisfies(|t| t.contains_tvar(target)), } } pub fn has_type_satisfies(&self, f: impl Fn(&Type) -> bool + Copy) -> bool { match self { - Self::FreeVar(fv) if fv.is_linked() => fv.crack().has_type_satisfies(f), - Self::FreeVar(fv) if fv.constraint_is_typeof() => { - fv.get_type().unwrap().has_type_satisfies(f) - } + Self::FreeVar(fv) if fv.is_linked() => f(&fv.crack()), + Self::FreeVar(fv) if fv.constraint_is_typeof() => f(&fv.get_type().unwrap()), Self::FreeVar(fv) => fv .get_subsup() - .map(|(sub, sup)| { - fv.do_avoiding_recursion(|| { - sub.has_type_satisfies(f) || sup.has_type_satisfies(f) - }) - }) + .map(|(sub, sup)| fv.do_avoiding_recursion(|| f(&sub) || f(&sup))) .unwrap_or(false), - Self::Record(rec) => rec.iter().any(|(_, t)| t.has_type_satisfies(f)), - Self::NamedTuple(rec) => rec.iter().any(|(_, t)| t.has_type_satisfies(f)), + Self::Record(rec) => rec.values().any(f), + Self::NamedTuple(rec) => rec.iter().any(|(_, t)| f(t)), Self::Poly { params, .. } => params.iter().any(|tp| tp.has_type_satisfies(f)), - Self::Quantified(t) => t.has_type_satisfies(f), - Self::Subr(subr) => { - subr.non_default_params - .iter() - .any(|pt| pt.typ().has_type_satisfies(f)) - || subr - .var_params - .as_ref() - .map_or(false, |pt| pt.typ().has_type_satisfies(f)) - || subr - .default_params - .iter() - .any(|pt| pt.typ().has_type_satisfies(f)) - || subr - .default_params - .iter() - .any(|pt| pt.default_typ().map_or(false, |t| t.has_type_satisfies(f))) - || subr.return_t.has_type_satisfies(f) - } - // TODO: preds - Self::Refinement(refine) => refine.t.has_type_satisfies(f), - Self::Structural(ty) => ty.has_type_satisfies(f), - Self::Proj { lhs, .. } => lhs.has_type_satisfies(f), + Self::Quantified(t) => f(t), + Self::Subr(subr) => subr.has_type_satisfies(f), + Self::Refinement(refine) => f(&refine.t) || refine.pred.has_type_satisfies(f), + Self::Structural(ty) => f(ty), + Self::Proj { lhs, .. } => f(lhs), Self::ProjCall { lhs, args, .. } => { - lhs.has_type_satisfies(f) || args.iter().any(|t| t.has_type_satisfies(f)) - } - Self::And(tys) => tys.iter().any(|t| t.has_type_satisfies(f)), - Self::Or(tys) => tys.iter().any(|t| t.has_type_satisfies(f)), - Self::Not(t) => t.has_type_satisfies(f), - Self::Ref(t) => t.has_type_satisfies(f), - Self::RefMut { before, after } => { - before.has_type_satisfies(f) - || after.as_ref().map_or(false, |t| t.has_type_satisfies(f)) - } - Self::Bounded { sub, sup } => sub.has_type_satisfies(f) || sup.has_type_satisfies(f), - Self::Callable { param_ts, return_t } => { - param_ts.iter().any(|t| t.has_type_satisfies(f)) || return_t.has_type_satisfies(f) - } - Self::Guard(guard) => guard.to.has_type_satisfies(f), + lhs.has_type_satisfies(f) || args.iter().any(|tp| tp.has_type_satisfies(f)) + } + Self::And(tys) => tys.iter().any(f), + Self::Or(tys) => tys.iter().any(f), + Self::Not(t) => f(t), + Self::Ref(t) => f(t), + Self::RefMut { before, after } => f(before) || after.as_ref().map_or(false, |t| f(t)), + Self::Bounded { sub, sup } => f(sub) || f(sup), + Self::Callable { param_ts, return_t } => param_ts.iter().any(f) || f(return_t), + Self::Guard(guard) => f(&guard.to), mono_type_pattern!() => false, } } @@ -3810,31 +3702,15 @@ impl Type { opt_t.map_or(false, |t| t.has_qvar()) } } - Self::Ref(ty) => ty.has_qvar(), - Self::RefMut { before, after } => { - before.has_qvar() || after.as_ref().map(|t| t.has_qvar()).unwrap_or(false) - } - Self::And(tys) => tys.iter().any(|t| t.has_qvar()), - Self::Or(tys) => tys.iter().any(|t| t.has_qvar()), - Self::Not(ty) => ty.has_qvar(), - Self::Callable { param_ts, return_t } => { - param_ts.iter().any(|t| t.has_qvar()) || return_t.has_qvar() - } Self::Subr(subr) => subr.has_qvar(), Self::Quantified(_) => false, // Self::Quantified(quant) => quant.has_qvar(), - Self::Record(r) => r.values().any(|t| t.has_qvar()), - Self::NamedTuple(r) => r.iter().any(|(_, t)| t.has_qvar()), Self::Refinement(refine) => refine.t.has_qvar() || refine.pred.has_qvar(), Self::Poly { params, .. } => params.iter().any(|tp| tp.has_qvar()), - Self::Proj { lhs, .. } => lhs.has_qvar(), Self::ProjCall { lhs, args, .. } => { lhs.has_qvar() || args.iter().any(|tp| tp.has_qvar()) } - Self::Structural(ty) => ty.has_qvar(), - Self::Guard(guard) => guard.to.has_qvar(), - Self::Bounded { sub, sup } => sub.has_qvar() || sup.has_qvar(), - mono_type_pattern!() => false, + _ => self.has_type_satisfies(|t| t.has_qvar()), } } @@ -3860,39 +3736,12 @@ impl Type { opt_t.map_or(false, |t| t.has_undoable_linked_var()) } } - Self::Ref(ty) => ty.has_undoable_linked_var(), - Self::RefMut { before, after } => { - before.has_undoable_linked_var() - || after - .as_ref() - .map(|t| t.has_undoable_linked_var()) - .unwrap_or(false) - } - Self::And(tys) => tys.iter().any(|t| t.has_undoable_linked_var()), - Self::Or(tys) => tys.iter().any(|t| t.has_undoable_linked_var()), - Self::Not(ty) => ty.has_undoable_linked_var(), - Self::Callable { param_ts, return_t } => { - param_ts.iter().any(|t| t.has_undoable_linked_var()) - || return_t.has_undoable_linked_var() - } Self::Subr(subr) => subr.has_undoable_linked_var(), - Self::Quantified(quant) => quant.has_undoable_linked_var(), - Self::Record(r) => r.values().any(|t| t.has_undoable_linked_var()), - Self::NamedTuple(r) => r.iter().any(|(_, t)| t.has_undoable_linked_var()), - Self::Refinement(refine) => { - refine.t.has_undoable_linked_var() || refine.pred.has_undoable_linked_var() - } Self::Poly { params, .. } => params.iter().any(|tp| tp.has_undoable_linked_var()), - Self::Proj { lhs, .. } => lhs.has_undoable_linked_var(), Self::ProjCall { lhs, args, .. } => { lhs.has_undoable_linked_var() || args.iter().any(|tp| tp.has_undoable_linked_var()) } - Self::Structural(ty) => ty.has_undoable_linked_var(), - Self::Guard(guard) => guard.to.has_undoable_linked_var(), - Self::Bounded { sub, sup } => { - sub.has_undoable_linked_var() || sup.has_undoable_linked_var() - } - mono_type_pattern!() => false, + _ => self.has_type_satisfies(|t| t.has_undoable_linked_var()), } } @@ -3903,45 +3752,13 @@ impl Type { pub fn has_unbound_var(&self) -> bool { match self { Self::FreeVar(fv) => fv.has_unbound_var(), - Self::Ref(t) => t.has_unbound_var(), - Self::RefMut { before, after } => { - before.has_unbound_var() - || after.as_ref().map(|t| t.has_unbound_var()).unwrap_or(false) - } - Self::And(tys) => tys.iter().any(|t| t.has_unbound_var()), - Self::Or(tys) => tys.iter().any(|t| t.has_unbound_var()), - Self::Not(ty) => ty.has_unbound_var(), - Self::Callable { param_ts, return_t } => { - param_ts.iter().any(|t| t.has_unbound_var()) || return_t.has_unbound_var() - } - Self::Subr(subr) => { - subr.non_default_params - .iter() - .any(|pt| pt.typ().has_unbound_var()) - || subr - .var_params - .as_ref() - .map(|pt| pt.typ().has_unbound_var()) - .unwrap_or(false) - || subr.default_params.iter().any(|pt| { - pt.typ().has_unbound_var() - || pt.default_typ().is_some_and(|t| t.has_unbound_var()) - }) - || subr.return_t.has_unbound_var() - } - Self::Record(r) => r.values().any(|t| t.has_unbound_var()), - Self::NamedTuple(r) => r.iter().any(|(_, t)| t.has_unbound_var()), + Self::Subr(subr) => subr.has_unbound_var(), Self::Refinement(refine) => refine.t.has_unbound_var() || refine.pred.has_unbound_var(), - Self::Quantified(quant) => quant.has_unbound_var(), Self::Poly { params, .. } => params.iter().any(|p| p.has_unbound_var()), - Self::Proj { lhs, .. } => lhs.has_unbound_var(), Self::ProjCall { lhs, args, .. } => { lhs.has_unbound_var() || args.iter().any(|t| t.has_unbound_var()) } - Self::Structural(ty) => ty.has_unbound_var(), - Self::Guard(guard) => guard.to.has_unbound_var(), - Self::Bounded { sub, sup } => sub.has_unbound_var() || sup.has_unbound_var(), - mono_type_pattern!() => false, + _ => self.has_type_satisfies(|t| t.has_unbound_var()), } } diff --git a/crates/erg_compiler/ty/predicate.rs b/crates/erg_compiler/ty/predicate.rs index 1e6c489ca..d43fd61dc 100644 --- a/crates/erg_compiler/ty/predicate.rs +++ b/crates/erg_compiler/ty/predicate.rs @@ -656,7 +656,8 @@ impl Predicate { pub fn qvars(&self) -> Set<(Str, Constraint)> { match self { - Self::Value(_) | Self::Const(_) | Self::Failure => set! {}, + Self::Const(_) | Self::Failure => set! {}, + Self::Value(val) => val.qvars(), Self::Call { receiver, args, .. } => { let mut set = receiver.qvars(); for arg in args { @@ -680,9 +681,35 @@ impl Predicate { } } + pub fn has_type_satisfies(&self, f: impl Fn(&Type) -> bool + Copy) -> bool { + match self { + Self::Const(_) | Self::Failure => false, + Self::Value(val) => val.has_type_satisfies(f), + Self::Call { receiver, args, .. } => { + receiver.has_type_satisfies(f) || args.iter().any(|a| a.has_type_satisfies(f)) + } + Self::Attr { receiver, .. } => receiver.has_type_satisfies(f), + Self::Equal { rhs, .. } + | Self::GreaterEqual { rhs, .. } + | Self::LessEqual { rhs, .. } + | Self::NotEqual { rhs, .. } => rhs.has_type_satisfies(f), + Self::GeneralEqual { lhs, rhs } + | Self::GeneralLessEqual { lhs, rhs } + | Self::GeneralGreaterEqual { lhs, rhs } + | Self::GeneralNotEqual { lhs, rhs } => { + lhs.has_type_satisfies(f) || rhs.has_type_satisfies(f) + } + Self::Or(lhs, rhs) | Self::And(lhs, rhs) => { + lhs.has_type_satisfies(f) || rhs.has_type_satisfies(f) + } + Self::Not(pred) => pred.has_type_satisfies(f), + } + } + pub fn has_qvar(&self) -> bool { match self { - Self::Value(_) | Self::Const(_) | Self::Failure => false, + Self::Const(_) | Self::Failure => false, + Self::Value(val) => val.has_qvar(), Self::Call { receiver, args, .. } => { receiver.has_qvar() || args.iter().any(|a| a.has_qvar()) } @@ -702,7 +729,8 @@ impl Predicate { pub fn has_unbound_var(&self) -> bool { match self { - Self::Value(_) | Self::Const(_) | Self::Failure => false, + Self::Const(_) | Self::Failure => false, + Self::Value(val) => val.has_unbound_var(), Self::Call { receiver, args, .. } => { receiver.has_unbound_var() || args.iter().any(|a| a.has_unbound_var()) } @@ -724,7 +752,8 @@ impl Predicate { pub fn has_undoable_linked_var(&self) -> bool { match self { - Self::Value(_) | Self::Const(_) | Self::Failure => false, + Self::Const(_) | Self::Failure => false, + Self::Value(val) => val.has_undoable_linked_var(), Self::Call { receiver, args, .. } => { receiver.has_undoable_linked_var() || args.iter().any(|a| a.has_undoable_linked_var()) diff --git a/crates/erg_compiler/ty/typaram.rs b/crates/erg_compiler/ty/typaram.rs index d104c298b..3ec3aca90 100644 --- a/crates/erg_compiler/ty/typaram.rs +++ b/crates/erg_compiler/ty/typaram.rs @@ -1268,21 +1268,21 @@ impl TyParam { pub fn has_type_satisfies(&self, f: impl Fn(&Type) -> bool + Copy) -> bool { match self { Self::FreeVar(fv) if fv.is_linked() => fv.crack().has_type_satisfies(f), - Self::FreeVar(fv) => fv.get_type().map_or(false, |t| t.has_type_satisfies(f)), - Self::Type(t) => t.has_type_satisfies(f), - Self::Erased(t) => t.has_type_satisfies(f), + Self::FreeVar(fv) => fv.get_type().map_or(false, |t| f(&t)), + Self::Type(t) => f(t), + Self::Erased(t) => f(t), Self::Proj { obj, .. } => obj.has_type_satisfies(f), Self::ProjCall { obj, args, .. } => { - obj.has_type_satisfies(f) || args.iter().any(|t| t.has_type_satisfies(f)) + obj.has_type_satisfies(f) || args.iter().any(|tp| tp.has_type_satisfies(f)) } - Self::List(ts) | Self::Tuple(ts) => ts.iter().any(|t| t.has_type_satisfies(f)), + Self::List(ts) | Self::Tuple(ts) => ts.iter().any(|tp| tp.has_type_satisfies(f)), Self::UnsizedList(elem) => elem.has_type_satisfies(f), - Self::Set(ts) => ts.iter().any(|t| t.has_type_satisfies(f)), + Self::Set(ts) => ts.iter().any(|tp| tp.has_type_satisfies(f)), Self::Dict(ts) => ts .iter() .any(|(k, v)| k.has_type_satisfies(f) || v.has_type_satisfies(f)), Self::Record(rec) | Self::DataClass { fields: rec, .. } => { - rec.iter().any(|(_, tp)| tp.has_type_satisfies(f)) + rec.values().any(|tp| tp.has_type_satisfies(f)) } Self::Lambda(lambda) => lambda.body.iter().any(|tp| tp.has_type_satisfies(f)), Self::UnaryOp { val, .. } => val.has_type_satisfies(f), diff --git a/crates/erg_compiler/ty/value.rs b/crates/erg_compiler/ty/value.rs index 878495016..994040acc 100644 --- a/crates/erg_compiler/ty/value.rs +++ b/crates/erg_compiler/ty/value.rs @@ -2080,7 +2080,7 @@ impl ValueObj { pub fn has_type_satisfies(&self, f: impl Fn(&Type) -> bool + Copy) -> bool { match self { - Self::Type(t) => t.typ().has_type_satisfies(f), + Self::Type(t) => f(t.typ()), Self::List(ts) | Self::Tuple(ts) => ts.iter().any(|t| t.has_type_satisfies(f)), Self::UnsizedList(elem) => elem.has_type_satisfies(f), Self::Set(ts) => ts.iter().any(|t| t.has_type_satisfies(f)), From df837d70d3fb39e2febfd8e57be528b1898f2c0a Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Tue, 17 Sep 2024 17:13:28 +0900 Subject: [PATCH 12/12] fix: sub-unification bug --- crates/erg_compiler/context/compare.rs | 1 + crates/erg_compiler/context/unify.rs | 88 ++++++++++++-------------- tests/should_err/and.er | 3 + tests/should_err/or.er | 3 + tests/should_ok/and.er | 7 ++ tests/should_ok/or.er | 7 ++ tests/test.rs | 20 ++++++ 7 files changed, 82 insertions(+), 47 deletions(-) create mode 100644 tests/should_err/and.er create mode 100644 tests/should_err/or.er create mode 100644 tests/should_ok/and.er create mode 100644 tests/should_ok/or.er diff --git a/crates/erg_compiler/context/compare.rs b/crates/erg_compiler/context/compare.rs index 5707af044..f6125066d 100644 --- a/crates/erg_compiler/context/compare.rs +++ b/crates/erg_compiler/context/compare.rs @@ -816,6 +816,7 @@ impl Context { } // Int or Str :> Str or Int == (Int :> Str && Str :> Int) || (Int :> Int && Str :> Str) == true // Int or Str or NoneType :> Str or Int + // Int or Str or NoneType :> Str or NoneType or Nat (Or(l), Or(r)) => r.iter().all(|r| l.iter().any(|l| self.supertype_of(l, r))), // not Nat :> not Int == true (Not(l), Not(r)) => self.subtype_of(l, r), diff --git a/crates/erg_compiler/context/unify.rs b/crates/erg_compiler/context/unify.rs index 3df47f5c5..fea7e0dcc 100644 --- a/crates/erg_compiler/context/unify.rs +++ b/crates/erg_compiler/context/unify.rs @@ -156,6 +156,7 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { } Ok(()) } + // FIXME: This is not correct, we must visit all permutations of the types (And(l), And(r)) if l.len() == r.len() => { let mut r = r.clone(); for _ in 0..r.len() { @@ -1297,65 +1298,58 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { // self.sub_unify(&lsub, &union, loc, param_name)?; maybe_sup.update_tyvar(union, intersec, self.undoable, false); } + // TODO: Preferentially compare same-structure types (e.g. K(?T) <: K(?U)) (And(ltys), And(rtys)) => { - let mut rtys = rtys.clone(); - for _ in 0..rtys.len() { - if ltys - .iter() - .zip(rtys.iter()) - .all(|(l, r)| self.ctx.subtype_of(l, r)) - { - for (l, r) in ltys.iter().zip(rtys.iter()) { - self.sub_unify(l, r)?; + let mut ltys_ = ltys.clone(); + let mut rtys_ = rtys.clone(); + // Show and EqHash and T <: Eq and Show and Ord + // => EqHash and T <: Eq and Ord + for lty in ltys.iter() { + if let Some(idx) = rtys_.iter().position(|r| r == lty) { + rtys_.remove(idx); + let idx = ltys_.iter().position(|l| l == lty).unwrap(); + ltys_.remove(idx); + } + } + // EqHash and T <: Eq and Ord + for lty in ltys_.iter() { + // lty: EqHash + // rty: Eq, Ord + for rty in rtys_.iter() { + if self.ctx.subtype_of(lty, rty) { + self.sub_unify(lty, rty)?; + continue; } - return Ok(()); } - rtys.rotate_left(1); } - return Err(TyCheckErrors::from(TyCheckError::type_mismatch_error( - self.ctx.cfg.input.clone(), - line!() as usize, - self.loc.loc(), - self.ctx.caused_by(), - self.param_name.as_ref().unwrap_or(&Str::ever("_")), - None, - maybe_sup, - maybe_sub, - self.ctx.get_candidates(maybe_sub), - self.ctx.get_simple_type_mismatch_hint(maybe_sup, maybe_sub), - ))); } + // TODO: Preferentially compare same-structure types (e.g. K(?T) <: K(?U)) + // Nat or Str or NoneType <: NoneType or ?T or Int + // => Str <: ?T // (Int or ?T) <: (?U or Int) // OK: (Int <: Int); (?T <: ?U) // NG: (Int <: ?U); (?T <: Int) (Or(ltys), Or(rtys)) => { - let ltys = ltys.to_vec(); - let mut rtys = rtys.to_vec(); - for _ in 0..rtys.len() { - if ltys - .iter() - .zip(rtys.iter()) - .all(|(l, r)| self.ctx.subtype_of(l, r)) - { - for (l, r) in ltys.iter().zip(rtys.iter()) { - self.sub_unify(l, r)?; + let mut ltys_ = ltys.clone(); + let mut rtys_ = rtys.clone(); + // Nat or T or Str <: Str or Int or NoneType + // => Nat or T <: Int or NoneType + for lty in ltys { + if rtys_.linear_remove(lty) { + ltys_.linear_remove(lty); + } + } + // Nat or T <: Int or NoneType + for lty in ltys_.iter() { + // lty: Nat + // rty: Int, NoneType + for rty in rtys_.iter() { + if self.ctx.subtype_of(lty, rty) { + self.sub_unify(lty, rty)?; + continue; } - return Ok(()); } - rtys.rotate_left(1); } - return Err(TyCheckErrors::from(TyCheckError::type_mismatch_error( - self.ctx.cfg.input.clone(), - line!() as usize, - self.loc.loc(), - self.ctx.caused_by(), - self.param_name.as_ref().unwrap_or(&Str::ever("_")), - None, - maybe_sup, - maybe_sub, - self.ctx.get_candidates(maybe_sub), - self.ctx.get_simple_type_mismatch_hint(maybe_sup, maybe_sub), - ))); } // NG: Nat <: ?T or Int ==> Nat or Int (?T = Nat) // OK: Nat <: ?T or Int ==> ?T or Int diff --git a/tests/should_err/and.er b/tests/should_err/and.er new file mode 100644 index 000000000..a16daf02c --- /dev/null +++ b/tests/should_err/and.er @@ -0,0 +1,3 @@ +a as Eq and Hash and Show and Add(Str) = "a" +f _: Ord and Eq and Show and Hash = None +f a # ERR diff --git a/tests/should_err/or.er b/tests/should_err/or.er new file mode 100644 index 000000000..20e8937da --- /dev/null +++ b/tests/should_err/or.er @@ -0,0 +1,3 @@ +a as Int or Str or NoneType = 1 +f _: Nat or NoneType or Str = None +f a # ERR diff --git a/tests/should_ok/and.er b/tests/should_ok/and.er new file mode 100644 index 000000000..4f3bee0d7 --- /dev/null +++ b/tests/should_ok/and.er @@ -0,0 +1,7 @@ +a as Eq and Hash and Show = 1 +f _: Eq and Show and Hash = None +f a + +b as Eq and Hash and Ord and Show = 1 +g _: Ord and Eq and Show and Hash = None +g b diff --git a/tests/should_ok/or.er b/tests/should_ok/or.er new file mode 100644 index 000000000..2e36f19b0 --- /dev/null +++ b/tests/should_ok/or.er @@ -0,0 +1,7 @@ +a as Nat or Str or NoneType = 1 +f _: Int or NoneType or Str = None +f a + +b as Nat or Str or NoneType or Bool = 1 +g _: Int or NoneType or Bool or Str = None +g b diff --git a/tests/test.rs b/tests/test.rs index 3259eee09..4ed20fb92 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -16,6 +16,11 @@ fn exec_advanced_type_spec() -> Result<(), ()> { expect_success("tests/should_ok/advanced_type_spec.er", 5) } +#[test] +fn exec_and() -> Result<(), ()> { + expect_success("tests/should_ok/and.er", 0) +} + #[test] fn exec_args_expansion() -> Result<(), ()> { expect_success("tests/should_ok/args_expansion.er", 0) @@ -327,6 +332,11 @@ fn exec_operators() -> Result<(), ()> { expect_success("tests/should_ok/operators.er", 0) } +#[test] +fn exec_or() -> Result<(), ()> { + expect_success("tests/should_ok/or.er", 0) +} + #[test] fn exec_patch() -> Result<(), ()> { expect_success("examples/patch.er", 0) @@ -527,6 +537,11 @@ fn exec_list_member_err() -> Result<(), ()> { expect_failure("tests/should_err/list_member.er", 0, 3) } +#[test] +fn exec_and_err() -> Result<(), ()> { + expect_failure("tests/should_err/and.er", 0, 1) +} + #[test] fn exec_as() -> Result<(), ()> { expect_failure("tests/should_err/as.er", 0, 6) @@ -634,6 +649,11 @@ fn exec_move_check() -> Result<(), ()> { expect_failure("examples/move_check.er", 1, 1) } +#[test] +fn exec_or_err() -> Result<(), ()> { + expect_failure("tests/should_err/or.er", 0, 1) +} + #[test] fn exec_poly_type_spec_err() -> Result<(), ()> { expect_failure("tests/should_err/poly_type_spec.er", 0, 3)