diff --git a/crates/els/completion.rs b/crates/els/completion.rs index 47c465110..13cb818ff 100644 --- a/crates/els/completion.rs +++ b/crates/els/completion.rs @@ -42,24 +42,30 @@ fn comp_item_kind(t: &Type, muty: Mutability) -> CompletionItemKind { Type::Subr(_) | Type::Quantified(_) => CompletionItemKind::FUNCTION, Type::ClassType => CompletionItemKind::CLASS, Type::TraitType => CompletionItemKind::INTERFACE, - Type::Or(l, r) => { - let l = comp_item_kind(l, muty); - let r = comp_item_kind(r, muty); - if l == r { - l + Type::Or(tys) => { + let fst = comp_item_kind(tys.iter().next().unwrap(), muty); + if tys + .iter() + .map(|t| comp_item_kind(t, muty)) + .all(|k| k == fst) + { + fst } else if muty.is_const() { CompletionItemKind::CONSTANT } else { CompletionItemKind::VARIABLE } } - Type::And(l, r) => { - let l = comp_item_kind(l, muty); - let r = comp_item_kind(r, muty); - if l == CompletionItemKind::VARIABLE { - r + Type::And(tys) => { + for k in tys.iter().map(|t| comp_item_kind(t, muty)) { + if k != CompletionItemKind::VARIABLE { + return k; + } + } + if muty.is_const() { + CompletionItemKind::CONSTANT } else { - l + CompletionItemKind::VARIABLE } } Type::Refinement(r) => comp_item_kind(&r.t, muty), diff --git a/crates/erg_common/set.rs b/crates/erg_common/set.rs index c925cf1e9..8a2fe180a 100644 --- a/crates/erg_common/set.rs +++ b/crates/erg_common/set.rs @@ -382,6 +382,20 @@ impl Set { self.insert(other); self } + + /// ``` + /// # use erg_common::set; + /// assert_eq!(set!{1, 2}.product(&set!{3, 4}), set!{(&1, &3), (&1, &4), (&2, &3), (&2, &4)}); + /// ``` + pub fn product<'l, 'r, U: Hash + Eq>(&'l self, other: &'r Set) -> Set<(&'l T, &'r U)> { + let mut res = set! {}; + for x in self.iter() { + for y in other.iter() { + res.insert((x, y)); + } + } + res + } } impl Set { diff --git a/crates/erg_common/traits.rs b/crates/erg_common/traits.rs index 7c0c84bf1..168a268f7 100644 --- a/crates/erg_common/traits.rs +++ b/crates/erg_common/traits.rs @@ -1407,6 +1407,8 @@ impl Immutable for &T {} impl Immutable for Option {} impl Immutable for Vec {} impl Immutable for [T] {} +impl Immutable for (T, U) {} +impl Immutable for (T, U, V) {} impl Immutable for Box {} impl Immutable for std::rc::Rc {} impl Immutable for std::sync::Arc {} diff --git a/crates/erg_common/triple.rs b/crates/erg_common/triple.rs index a743d43ab..a3626d3d3 100644 --- a/crates/erg_common/triple.rs +++ b/crates/erg_common/triple.rs @@ -18,6 +18,16 @@ impl fmt::Display for Triple { } impl Triple { + pub const fn is_ok(&self) -> bool { + matches!(self, Triple::Ok(_)) + } + pub const fn is_err(&self) -> bool { + matches!(self, Triple::Err(_)) + } + pub const fn is_none(&self) -> bool { + matches!(self, Triple::None) + } + pub fn none_then(self, err: E) -> Result { match self { Triple::None => Err(err), diff --git a/crates/erg_compiler/context/compare.rs b/crates/erg_compiler/context/compare.rs index 8e3dc2908..911e53757 100644 --- a/crates/erg_compiler/context/compare.rs +++ b/crates/erg_compiler/context/compare.rs @@ -754,9 +754,7 @@ impl Context { self.structural_supertype_of(&l, rhs) } // {1, 2, 3} :> {1} or {2, 3} == true - (Refinement(_refine), Or(l, r)) => { - self.supertype_of(lhs, l) && self.supertype_of(lhs, r) - } + (Refinement(_refine), Or(tys)) => tys.iter().all(|ty| self.supertype_of(lhs, ty)), // ({I: Int | True} :> Int) == true // {N: Nat | ...} :> Int) == false // ({I: Int | I >= 0} :> Int) == false @@ -808,41 +806,20 @@ impl Context { self.sub_unify(&inst, l, &(), None).is_ok() } // Int or Str :> Str or Int == (Int :> Str && Str :> Int) || (Int :> Int && Str :> Str) == true - (Or(l_1, l_2), Or(r_1, r_2)) => { - if l_1.is_union_type() && self.supertype_of(l_1, rhs) { - return true; - } - if l_2.is_union_type() && self.supertype_of(l_2, rhs) { - return true; - } - (self.supertype_of(l_1, r_1) && self.supertype_of(l_2, r_2)) - || (self.supertype_of(l_1, r_2) && self.supertype_of(l_2, r_1)) - } + // Int or Str or NoneType :> Str or Int + (Or(l), Or(r)) => r.iter().all(|r| l.iter().any(|l| self.supertype_of(l, r))), // not Nat :> not Int == true (Not(l), Not(r)) => self.subtype_of(l, r), // (Int or Str) :> Nat == Int :> Nat || Str :> Nat == true // (Num or Show) :> Show == Num :> Show || Show :> Num == true - (Or(l_or, r_or), rhs) => self.supertype_of(l_or, rhs) || self.supertype_of(r_or, rhs), + (Or(ors), rhs) => ors.iter().any(|or| self.supertype_of(or, rhs)), // Int :> (Nat or Str) == Int :> Nat && Int :> Str == false - (lhs, Or(l_or, r_or)) => self.supertype_of(lhs, l_or) && self.supertype_of(lhs, r_or), - (And(l_1, l_2), And(r_1, r_2)) => { - if l_1.is_intersection_type() && self.supertype_of(l_1, rhs) { - return true; - } - if l_2.is_intersection_type() && self.supertype_of(l_2, rhs) { - return true; - } - (self.supertype_of(l_1, r_1) && self.supertype_of(l_2, r_2)) - || (self.supertype_of(l_1, r_2) && self.supertype_of(l_2, r_1)) - } + (lhs, Or(ors)) => ors.iter().all(|or| self.supertype_of(lhs, or)), + (And(l), And(r)) => r.iter().any(|r| l.iter().all(|l| self.supertype_of(l, r))), // (Num and Show) :> Show == false - (And(l_and, r_and), rhs) => { - self.supertype_of(l_and, rhs) && self.supertype_of(r_and, rhs) - } + (And(ands), rhs) => ands.iter().all(|and| self.supertype_of(and, rhs)), // Show :> (Num and Show) == true - (lhs, And(l_and, r_and)) => { - self.supertype_of(lhs, l_and) || self.supertype_of(lhs, r_and) - } + (lhs, And(ands)) => ands.iter().any(|and| self.supertype_of(lhs, and)), // Not(Eq) :> Float == !(Eq :> Float) == true (Not(_), Obj) => false, (Not(l), rhs) => !self.supertype_of(l, rhs), @@ -914,18 +891,18 @@ impl Context { Type::NamedTuple(fields) => fields.iter().cloned().collect(), Type::Refinement(refine) => self.fields(&refine.t), Type::Structural(t) => self.fields(t), - Type::Or(l, r) => { - let l_fields = self.fields(l); - let r_fields = self.fields(r); - let l_field_names = l_fields.keys().collect::>(); - let r_field_names = r_fields.keys().collect::>(); - let field_names = l_field_names.intersection(&r_field_names); + Type::Or(tys) => { + let or_fields = tys.iter().map(|t| self.fields(t)).collect::>(); + let field_names = or_fields + .iter() + .flat_map(|fs| fs.keys()) + .collect::>(); let mut fields = Dict::new(); - for (name, l_t, r_t) in field_names + for (name, tys) in field_names .iter() - .map(|&name| (name, &l_fields[name], &r_fields[name])) + .map(|&name| (name, or_fields.iter().filter_map(|fields| fields.get(name)))) { - let union = self.union(l_t, r_t); + let union = tys.fold(Never, |acc, ty| self.union(&acc, ty)); fields.insert(name.clone(), union); } fields @@ -1408,6 +1385,8 @@ impl Context { /// union({ .a = Int }, { .a = Str }) == { .a = Int or Str } /// union({ .a = Int }, { .a = Int; .b = Int }) == { .a = Int } or { .a = Int; .b = Int } # not to lost `b` information /// union((A and B) or C) == (A or C) and (B or C) + /// union(Never, Int) == Int + /// union(Obj, Int) == Obj /// ``` pub(crate) fn union(&self, lhs: &Type, rhs: &Type) -> Type { if lhs == rhs { @@ -1470,10 +1449,9 @@ impl Context { (Some(sub), Some(sup)) => bounded(sub.clone(), sup.clone()), _ => self.simple_union(lhs, rhs), }, - (other, or @ Or(_, _)) | (or @ Or(_, _), other) => self.union_add(or, other), + (other, or @ Or(_)) | (or @ Or(_), other) => self.union_add(or, other), // (A and B) or C ==> (A or C) and (B or C) - (and_t @ And(_, _), other) | (other, and_t @ And(_, _)) => { - let ands = and_t.ands(); + (And(ands), other) | (other, And(ands)) => { let mut t = Type::Obj; for branch in ands.iter() { let union = self.union(branch, other); @@ -1657,6 +1635,12 @@ impl Context { /// Returns intersection of two types (`A and B`). /// If `A` and `B` have a subtype relationship, it is equal to `min(A, B)`. + /// ```erg + /// intersection(Nat, Int) == Nat + /// intersection(Int, Str) == Never + /// intersection(Obj, Int) == Int + /// intersection(Never, Int) == Never + /// ``` pub(crate) fn intersection(&self, lhs: &Type, rhs: &Type) -> Type { if lhs == rhs { return lhs.clone(); @@ -1687,12 +1671,9 @@ impl Context { (_, Not(r)) => self.diff(lhs, r), (Not(l), _) => self.diff(rhs, l), // A and B and A == A and B - (other, and @ And(_, _)) | (and @ And(_, _), other) => { - self.intersection_add(and, other) - } + (other, and @ And(_)) | (and @ And(_), other) => self.intersection_add(and, other), // (A or B) and C == (A and C) or (B and C) - (or_t @ Or(_, _), other) | (other, or_t @ Or(_, _)) => { - let ors = or_t.ors(); + (Or(ors), other) | (other, Or(ors)) => { if ors.iter().any(|t| t.has_unbound_var()) { return self.simple_intersection(lhs, rhs); } @@ -1827,21 +1808,21 @@ impl Context { fn narrow_type_by_pred(&self, t: Type, pred: &Predicate) -> Type { match (t, pred) { ( - Type::Or(l, r), + Type::Or(tys), Predicate::NotEqual { rhs: TyParam::Value(v), .. }, ) if v.is_none() => { - let l = self.narrow_type_by_pred(*l, pred); - let r = self.narrow_type_by_pred(*r, pred); - if l.is_nonetype() { - r - } else if r.is_nonetype() { - l - } else { - or(l, r) + let mut new_tys = Set::new(); + for ty in tys { + let ty = self.narrow_type_by_pred(ty, pred); + if ty.is_nonelike() { + continue; + } + new_tys.insert(ty); } + Type::checked_or(new_tys) } (Type::Refinement(refine), _) => { let t = self.narrow_type_by_pred(*refine.t, pred); @@ -1983,8 +1964,12 @@ impl Context { guard.target.clone(), self.complement(&guard.to), )), - Or(l, r) => self.intersection(&self.complement(l), &self.complement(r)), - And(l, r) => self.union(&self.complement(l), &self.complement(r)), + Or(ors) => ors + .iter() + .fold(Obj, |l, r| self.intersection(&l, &self.complement(r))), + And(ands) => ands + .iter() + .fold(Never, |l, r| self.union(&l, &self.complement(r))), other => not(other.clone()), } } @@ -2002,7 +1987,14 @@ impl Context { match lhs { Type::FreeVar(fv) if fv.is_linked() => self.diff(&fv.crack(), rhs), // Type::And(l, r) => self.intersection(&self.diff(l, rhs), &self.diff(r, rhs)), - Type::Or(l, r) => self.union(&self.diff(l, rhs), &self.diff(r, rhs)), + Type::Or(tys) => { + let mut new_tys = vec![]; + for ty in tys { + let diff = self.diff(ty, rhs); + new_tys.push(diff); + } + new_tys.into_iter().fold(Never, |l, r| self.union(&l, &r)) + } _ => lhs.clone(), } } diff --git a/crates/erg_compiler/context/eval.rs b/crates/erg_compiler/context/eval.rs index 60e1e64f3..3dce1c24f 100644 --- a/crates/erg_compiler/context/eval.rs +++ b/crates/erg_compiler/context/eval.rs @@ -2285,44 +2285,42 @@ impl Context { Err((t, errs)) } } - Type::And(l, r) => { - let l = match self.eval_t_params(*l, level, t_loc) { - Ok(l) => l, - Err((l, es)) => { - errs.extend(es); - l - } - }; - let r = match self.eval_t_params(*r, level, t_loc) { - Ok(r) => r, - Err((r, es)) => { - errs.extend(es); - r + Type::And(ands) => { + let mut new_ands = set! {}; + for and in ands.into_iter() { + match self.eval_t_params(and, level, t_loc) { + Ok(and) => { + new_ands.insert(and); + } + Err((and, es)) => { + new_ands.insert(and); + errs.extend(es); + } } - }; - let intersec = self.intersection(&l, &r); + } + let intersec = new_ands + .into_iter() + .fold(Type::Obj, |l, r| self.intersection(&l, &r)); if errs.is_empty() { Ok(intersec) } else { Err((intersec, errs)) } } - Type::Or(l, r) => { - let l = match self.eval_t_params(*l, level, t_loc) { - Ok(l) => l, - Err((l, es)) => { - errs.extend(es); - l - } - }; - let r = match self.eval_t_params(*r, level, t_loc) { - Ok(r) => r, - Err((r, es)) => { - errs.extend(es); - r + Type::Or(ors) => { + let mut new_ors = set! {}; + for or in ors.into_iter() { + match self.eval_t_params(or, level, t_loc) { + Ok(or) => { + new_ors.insert(or); + } + Err((or, es)) => { + new_ors.insert(or); + errs.extend(es); + } } - }; - let union = self.union(&l, &r); + } + let union = new_ors.into_iter().fold(Never, |l, r| self.union(&l, &r)); if errs.is_empty() { Ok(union) } else { diff --git a/crates/erg_compiler/context/generalize.rs b/crates/erg_compiler/context/generalize.rs index 13a0049ad..d8899e701 100644 --- a/crates/erg_compiler/context/generalize.rs +++ b/crates/erg_compiler/context/generalize.rs @@ -289,17 +289,21 @@ impl Generalizer { } proj_call(lhs, attr_name, args) } - And(l, r) => { - let l = self.generalize_t(*l, uninit); - let r = self.generalize_t(*r, uninit); + And(ands) => { // not `self.intersection` because types are generalized - and(l, r) + let ands = ands + .into_iter() + .map(|t| self.generalize_t(t, uninit)) + .collect(); + Type::checked_and(ands) } - Or(l, r) => { - let l = self.generalize_t(*l, uninit); - let r = self.generalize_t(*r, uninit); + Or(ors) => { // not `self.union` because types are generalized - or(l, r) + let ors = ors + .into_iter() + .map(|t| self.generalize_t(t, uninit)) + .collect(); + Type::checked_or(ors) } Not(l) => not(self.generalize_t(*l, uninit)), Structural(ty) => { @@ -1045,15 +1049,23 @@ impl<'c, 'q, 'l, L: Locational> Dereferencer<'c, 'q, 'l, L> { let pred = self.deref_pred(*refine.pred)?; Ok(refinement(refine.var, t, pred)) } - And(l, r) => { - let l = self.deref_tyvar(*l)?; - let r = self.deref_tyvar(*r)?; - Ok(self.ctx.intersection(&l, &r)) + And(ands) => { + let mut new_ands = vec![]; + for t in ands.into_iter() { + new_ands.push(self.deref_tyvar(t)?); + } + Ok(new_ands + .into_iter() + .fold(Type::Obj, |acc, t| self.ctx.intersection(&acc, &t))) } - Or(l, r) => { - let l = self.deref_tyvar(*l)?; - let r = self.deref_tyvar(*r)?; - Ok(self.ctx.union(&l, &r)) + Or(ors) => { + let mut new_ors = vec![]; + for t in ors.into_iter() { + new_ors.push(self.deref_tyvar(t)?); + } + Ok(new_ors + .into_iter() + .fold(Type::Never, |acc, t| self.ctx.union(&acc, &t))) } Not(ty) => { let ty = self.deref_tyvar(*ty)?; @@ -1733,22 +1745,33 @@ impl Context { /// ``` pub(crate) fn squash_tyvar(&self, typ: Type) -> Type { match typ { - Or(l, r) => { - let l = self.squash_tyvar(*l); - let r = self.squash_tyvar(*r); + Or(tys) => { + let new_tys = tys + .into_iter() + .map(|t| self.squash_tyvar(t)) + .collect::>(); + let mut union = Never; // REVIEW: - if l.is_unnamed_unbound_var() && r.is_unnamed_unbound_var() { - match (self.subtype_of(&l, &r), self.subtype_of(&r, &l)) { - (true, true) | (true, false) => { - let _ = self.sub_unify(&l, &r, &(), None); + if new_tys.iter().all(|t| t.is_unnamed_unbound_var()) { + for ty in new_tys.iter() { + if union == Never { + union = ty.clone(); + continue; } - (false, true) => { - let _ = self.sub_unify(&r, &l, &(), None); + match (self.subtype_of(&union, ty), self.subtype_of(&union, ty)) { + (true, true) | (true, false) => { + let _ = self.sub_unify(&union, ty, &(), None); + } + (false, true) => { + let _ = self.sub_unify(ty, &union, &(), None); + } + _ => {} } - _ => {} } } - self.union(&l, &r) + new_tys + .into_iter() + .fold(Never, |acc, t| self.union(&acc, &t)) } FreeVar(ref fv) if fv.constraint_is_sandwiched() => { let (sub_t, super_t) = fv.get_subsup().unwrap(); diff --git a/crates/erg_compiler/context/hint.rs b/crates/erg_compiler/context/hint.rs index e7826d022..b7b795224 100644 --- a/crates/erg_compiler/context/hint.rs +++ b/crates/erg_compiler/context/hint.rs @@ -116,9 +116,12 @@ impl Context { return Some(hint); } } - (Type::And(l, r), found) => { - let left = self.readable_type(l.as_ref().clone()); - let right = self.readable_type(r.as_ref().clone()); + (Type::And(tys), found) if tys.len() == 2 => { + let mut iter = tys.iter(); + let l = iter.next().unwrap(); + let r = iter.next().unwrap(); + let left = self.readable_type(l.clone()); + let right = self.readable_type(r.clone()); if self.supertype_of(l, found) { let msg = switch_lang!( "japanese" => format!("型{found}は{left}のサブタイプですが、{right}のサブタイプではありません"), diff --git a/crates/erg_compiler/context/inquire.rs b/crates/erg_compiler/context/inquire.rs index 70ea0c5be..44aa46b3e 100644 --- a/crates/erg_compiler/context/inquire.rs +++ b/crates/erg_compiler/context/inquire.rs @@ -1046,35 +1046,35 @@ impl Context { } Type::Structural(t) => self.get_attr_info_from_attributive(t, ident, namespace), // TODO: And - Type::Or(l, r) => { - let l_info = self.get_attr_info_from_attributive(l, ident, namespace); - let r_info = self.get_attr_info_from_attributive(r, ident, namespace); - match (l_info, r_info) { - (Triple::Ok(l), Triple::Ok(r)) => { - let res = self.union(&l.t, &r.t); - let vis = if l.vis.is_public() && r.vis.is_public() { - Visibility::DUMMY_PUBLIC - } else { - Visibility::DUMMY_PRIVATE - }; - let vi = VarInfo::new( - res, - l.muty, - vis, - l.kind, - l.comptime_decos, - l.ctx, - l.py_name, - l.def_loc, - ); - Triple::Ok(vi) - } - (Triple::Ok(_), Triple::Err(e)) | (Triple::Err(e), Triple::Ok(_)) => { - Triple::Err(e) + Type::Or(tys) => { + let mut info = Triple::::None; + for ty in tys { + match ( + self.get_attr_info_from_attributive(ty, ident, namespace), + &info, + ) { + (Triple::Ok(vi), Triple::Ok(vi_)) => { + let res = self.union(&vi.t, &vi_.t); + let vis = if vi.vis.is_public() && vi_.vis.is_public() { + Visibility::DUMMY_PUBLIC + } else { + Visibility::DUMMY_PRIVATE + }; + let vi = VarInfo { t: res, vis, ..vi }; + info = Triple::Ok(vi); + } + (Triple::Ok(vi), Triple::None) => { + info = Triple::Ok(vi); + } + (Triple::Err(err), _) => { + info = Triple::Err(err); + break; + } + (Triple::None, _) => {} + (_, Triple::Err(_)) => unreachable!(), } - (Triple::Err(e1), Triple::Err(_e2)) => Triple::Err(e1), - _ => Triple::None, } + info } _other => Triple::None, } @@ -1952,7 +1952,7 @@ impl Context { res } } - Type::And(_, _) => { + Type::And(_) => { let instance = self.resolve_overload( obj, instance.clone(), @@ -3012,32 +3012,30 @@ impl Context { self.get_nominal_super_type_ctxs(&Type) } } - Type::And(l, r) => { - match ( - self.get_nominal_super_type_ctxs(l), - self.get_nominal_super_type_ctxs(r), - ) { - // TODO: sort - (Some(l), Some(r)) => Some([l, r].concat()), - (Some(l), None) => Some(l), - (None, Some(r)) => Some(r), - (None, None) => None, - } - } - // TODO - Type::Or(l, r) => match (l.as_ref(), r.as_ref()) { - (Type::FreeVar(l), Type::FreeVar(r)) - if l.is_unbound_and_sandwiched() && r.is_unbound_and_sandwiched() => + Type::And(tys) => { + let mut acc = vec![]; + for ctxs in tys + .iter() + .filter_map(|t| self.get_nominal_super_type_ctxs(t)) { - let (_lsub, lsup) = l.get_subsup().unwrap(); - let (_rsub, rsup) = r.get_subsup().unwrap(); - self.get_nominal_super_type_ctxs(&self.union(&lsup, &rsup)) + acc.extend(ctxs); } - (Type::Refinement(l), Type::Refinement(r)) if l.t == r.t => { - self.get_nominal_super_type_ctxs(&l.t) + if acc.is_empty() { + None + } else { + Some(acc) } - _ => self.get_nominal_type_ctx(&Obj).map(|ctx| vec![ctx]), - }, + } + Type::Or(tys) => { + let union = tys + .iter() + .fold(Never, |l, r| self.union(&l, &r.upper_bounded())); + if union.is_union_type() { + self.get_nominal_super_type_ctxs(&Obj) + } else { + self.get_nominal_super_type_ctxs(&union) + } + } _ => self .get_simple_nominal_super_type_ctxs(t) .map(|ctxs| ctxs.collect()), @@ -3231,7 +3229,7 @@ impl Context { .unwrap_or(self) .rec_local_get_mono_type("GenericNamedTuple"); } - Type::Or(_l, _r) => { + Type::Or(_) => { if let Some(ctx) = self.get_nominal_type_ctx(&poly("Or", vec![])) { return Some(ctx); } @@ -3366,26 +3364,27 @@ impl Context { match trait_ { // And(Add, Sub) == intersection({Int <: Add(Int), Bool <: Add(Bool) ...}, {Int <: Sub(Int), ...}) // == {Int <: Add(Int) and Sub(Int), ...} - Type::And(l, r) => { - let l_impls = self.get_trait_impls(l); - let l_base = Set::from_iter(l_impls.iter().map(|ti| &ti.sub_type)); - let r_impls = self.get_trait_impls(r); - let r_base = Set::from_iter(r_impls.iter().map(|ti| &ti.sub_type)); - let bases = l_base.intersection(&r_base); + Type::And(tys) => { + let impls = tys + .iter() + .flat_map(|ty| self.get_trait_impls(ty)) + .collect::>(); + let bases = impls.iter().map(|ti| &ti.sub_type); let mut isec = set! {}; - for base in bases.into_iter() { - let lti = l_impls.iter().find(|ti| &ti.sub_type == base).unwrap(); - let rti = r_impls.iter().find(|ti| &ti.sub_type == base).unwrap(); - let sup_trait = self.intersection(<i.sup_trait, &rti.sup_trait); - isec.insert(TraitImpl::new(lti.sub_type.clone(), sup_trait, None)); + for base in bases { + let base_impls = impls.iter().filter(|ti| ti.sub_type == *base); + let sup_trait = + base_impls.fold(Obj, |l, r| self.intersection(&l, &r.sup_trait)); + if sup_trait != Obj { + isec.insert(TraitImpl::new(base.clone(), sup_trait, None)); + } } isec } - Type::Or(l, r) => { - let l_impls = self.get_trait_impls(l); - let r_impls = self.get_trait_impls(r); + Type::Or(tys) => { // FIXME: - l_impls.union(&r_impls) + tys.iter() + .fold(set! {}, |acc, ty| acc.union(&self.get_trait_impls(ty))) } _ => self.get_simple_trait_impls(trait_), } @@ -3955,11 +3954,11 @@ impl Context { pub fn is_class(&self, typ: &Type) -> bool { match typ { - Type::And(_l, _r) => false, + Type::And(_) => false, Type::Never => true, Type::FreeVar(fv) if fv.is_linked() => self.is_class(&fv.crack()), Type::FreeVar(_) => false, - Type::Or(l, r) => self.is_class(l) && self.is_class(r), + Type::Or(tys) => tys.iter().all(|t| self.is_class(t)), Type::Proj { lhs, rhs } => self .get_proj_candidates(lhs, rhs) .iter() @@ -3982,7 +3981,8 @@ impl Context { Type::Never => false, Type::FreeVar(fv) if fv.is_linked() => self.is_class(&fv.crack()), Type::FreeVar(_) => false, - Type::And(l, r) | Type::Or(l, r) => self.is_trait(l) && self.is_trait(r), + Type::And(tys) => tys.iter().any(|t| self.is_trait(t)), + Type::Or(tys) => tys.iter().all(|t| self.is_trait(t)), Type::Proj { lhs, rhs } => self .get_proj_candidates(lhs, rhs) .iter() diff --git a/crates/erg_compiler/context/instantiate.rs b/crates/erg_compiler/context/instantiate.rs index 42d0c36f6..be4028548 100644 --- a/crates/erg_compiler/context/instantiate.rs +++ b/crates/erg_compiler/context/instantiate.rs @@ -953,15 +953,21 @@ impl Context { let t = self.instantiate_t_inner(*t, tmp_tv_cache, loc)?; Ok(t.structuralize()) } - And(l, r) => { - let l = self.instantiate_t_inner(*l, tmp_tv_cache, loc)?; - let r = self.instantiate_t_inner(*r, tmp_tv_cache, loc)?; - Ok(self.intersection(&l, &r)) + And(tys) => { + let mut new_tys = vec![]; + for ty in tys.iter().cloned() { + new_tys.push(self.instantiate_t_inner(ty, tmp_tv_cache, loc)?); + } + Ok(new_tys + .into_iter() + .fold(Obj, |l, r| self.intersection(&l, &r))) } - Or(l, r) => { - let l = self.instantiate_t_inner(*l, tmp_tv_cache, loc)?; - let r = self.instantiate_t_inner(*r, tmp_tv_cache, loc)?; - Ok(self.union(&l, &r)) + Or(tys) => { + let mut new_tys = vec![]; + for ty in tys.iter().cloned() { + new_tys.push(self.instantiate_t_inner(ty, tmp_tv_cache, loc)?); + } + Ok(new_tys.into_iter().fold(Never, |l, r| self.union(&l, &r))) } Not(ty) => { let ty = self.instantiate_t_inner(*ty, tmp_tv_cache, loc)?; @@ -998,10 +1004,12 @@ impl Context { let t = fv.crack().clone(); self.instantiate(t, callee) } - And(lhs, rhs) => { - let lhs = self.instantiate(*lhs, callee)?; - let rhs = self.instantiate(*rhs, callee)?; - Ok(lhs & rhs) + And(tys) => { + let tys = tys + .into_iter() + .map(|t| self.instantiate(t, callee)) + .collect::>>()?; + Ok(tys.into_iter().fold(Obj, |l, r| l & r)) } Quantified(quant) => { let mut tmp_tv_cache = TyVarCache::new(self.level, self); @@ -1028,22 +1036,16 @@ impl Context { )?; } } - Type::And(l, r) => { - if let Some(self_t) = l.self_t() { - self.sub_unify( - callee.ref_t(), - self_t, - callee, - Some(&Str::ever("self")), - )?; - } - if let Some(self_t) = r.self_t() { - self.sub_unify( - callee.ref_t(), - self_t, - callee, - Some(&Str::ever("self")), - )?; + Type::And(tys) => { + for ty in tys { + if let Some(self_t) = ty.self_t() { + self.sub_unify( + callee.ref_t(), + self_t, + callee, + Some(&Str::ever("self")), + )?; + } } } other => unreachable!("{other}"), @@ -1066,10 +1068,12 @@ impl Context { let t = fv.crack().clone(); self.instantiate_dummy(t) } - And(lhs, rhs) => { - let lhs = self.instantiate_dummy(*lhs)?; - let rhs = self.instantiate_dummy(*rhs)?; - Ok(lhs & rhs) + And(tys) => { + let tys = tys + .into_iter() + .map(|t| self.instantiate_dummy(t)) + .collect::>>()?; + Ok(tys.into_iter().fold(Obj, |l, r| l & r)) } Quantified(quant) => { let mut tmp_tv_cache = TyVarCache::new(self.level, self); diff --git a/crates/erg_compiler/context/unify.rs b/crates/erg_compiler/context/unify.rs index 3dc1d5519..10afe974c 100644 --- a/crates/erg_compiler/context/unify.rs +++ b/crates/erg_compiler/context/unify.rs @@ -155,17 +155,38 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { } Ok(()) } - (Or(l, r), Or(l2, r2)) | (And(l, r), And(l2, r2)) => self - .occur(l, l2) - .and(self.occur(r, r2)) - .or(self.occur(l, r2).and(self.occur(r, l2))), - (lhs, Or(l, r)) | (lhs, And(l, r)) => { - self.occur_inner(lhs, l)?; - self.occur_inner(lhs, r) - } - (Or(l, r), rhs) | (And(l, r), rhs) => { - self.occur_inner(l, rhs)?; - self.occur_inner(r, rhs) + (Or(l), Or(r)) | (And(l), And(r)) if l.len() == r.len() => { + let l = l.to_vec(); + let mut r = r.to_vec(); + for _ in 0..r.len() { + if l.iter() + .zip(r.iter()) + .all(|(l, r)| self.occur(l, r).is_ok()) + { + return Ok(()); + } + r.rotate_left(1); + } + Err(TyCheckErrors::from(TyCheckError::subtyping_error( + self.ctx.cfg.input.clone(), + line!() as usize, + maybe_sub, + maybe_sup, + self.loc.loc(), + self.ctx.caused_by(), + ))) + } + (lhs, Or(tys)) | (lhs, And(tys)) => { + for ty in tys.iter() { + self.occur(lhs, ty)?; + } + Ok(()) + } + (Or(tys), rhs) | (And(tys), rhs) => { + for ty in tys.iter() { + self.occur(ty, rhs)?; + } + Ok(()) } _ => Ok(()), } @@ -266,13 +287,17 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { } Ok(()) } - (lhs, Or(l, r)) | (lhs, And(l, r)) => { - self.occur_inner(lhs, l)?; - self.occur_inner(lhs, r) + (lhs, Or(tys)) | (lhs, And(tys)) => { + for ty in tys.iter() { + self.occur_inner(lhs, ty)?; + } + Ok(()) } - (Or(l, r), rhs) | (And(l, r), rhs) => { - self.occur_inner(l, rhs)?; - self.occur_inner(r, rhs) + (Or(tys), rhs) | (And(tys), rhs) => { + for ty in tys.iter() { + self.occur_inner(ty, rhs)?; + } + Ok(()) } _ => Ok(()), } @@ -1186,35 +1211,42 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { // (Int or ?T) <: (?U or Int) // OK: (Int <: Int); (?T <: ?U) // NG: (Int <: ?U); (?T <: Int) - (Or(l1, r1), Or(l2, r2)) | (And(l1, r1), And(l2, r2)) => { - if self.ctx.subtype_of(l1, l2) && self.ctx.subtype_of(r1, r2) { - let (l_sup, r_sup) = if !l1.is_unbound_var() - && !r2.is_unbound_var() - && self.ctx.subtype_of(l1, r2) + (Or(ltys), Or(rtys)) | (And(ltys), And(rtys)) => { + let lvars = ltys.to_vec(); + let mut rvars = rtys.to_vec(); + for _ in 0..rvars.len() { + if lvars + .iter() + .zip(rvars.iter()) + .all(|(l, r)| self.ctx.subtype_of(l, r)) { - (r2, l2) - } else { - (l2, r2) - }; - self.sub_unify(l1, l_sup)?; - self.sub_unify(r1, r_sup)?; - } else { - self.sub_unify(l1, r2)?; - self.sub_unify(r1, l2)?; + for (l, r) in ltys.iter().zip(rtys.iter()) { + self.sub_unify(l, r)?; + } + break; + } + rvars.rotate_left(1); } + return Err(TyCheckErrors::from(TyCheckError::type_mismatch_error( + self.ctx.cfg.input.clone(), + line!() as usize, + self.loc.loc(), + self.ctx.caused_by(), + self.param_name.as_ref().unwrap_or(&Str::ever("_")), + None, + maybe_sup, + maybe_sub, + self.ctx.get_candidates(maybe_sub), + self.ctx.get_simple_type_mismatch_hint(maybe_sup, maybe_sub), + ))); } // NG: Nat <: ?T or Int ==> Nat or Int (?T = Nat) // OK: Nat <: ?T or Int ==> ?T or Int - (sub, Or(l, r)) - if l.is_unbound_var() - && !sub.is_unbound_var() - && !r.is_unbound_var() - && self.ctx.subtype_of(sub, r) => {} - (sub, Or(l, r)) - if r.is_unbound_var() - && !sub.is_unbound_var() - && !l.is_unbound_var() - && self.ctx.subtype_of(sub, l) => {} + (sub, Or(tys)) + if !sub.is_unbound_var() + && tys + .iter() + .any(|ty| !ty.is_unbound_var() && self.ctx.subtype_of(sub, ty)) => {} // e.g. Structural({ .method = (self: T) -> Int })/T (Structural(sub), FreeVar(sup_fv)) if sup_fv.is_unbound() && sub.contains_tvar(sup_fv) => {} @@ -1578,30 +1610,36 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { } } // (X or Y) <: Z is valid when X <: Z and Y <: Z - (Or(l, r), _) => { - self.sub_unify(l, maybe_sup)?; - self.sub_unify(r, maybe_sup)?; + (Or(tys), _) => { + for ty in tys { + self.sub_unify(ty, maybe_sup)?; + } } // X <: (Y and Z) is valid when X <: Y and X <: Z - (_, And(l, r)) => { - self.sub_unify(maybe_sub, l)?; - self.sub_unify(maybe_sub, r)?; + (_, And(tys)) => { + for ty in tys { + self.sub_unify(maybe_sub, ty)?; + } } // (X and Y) <: Z is valid when X <: Z or Y <: Z - (And(l, r), _) => { - if self.ctx.subtype_of(l, maybe_sup) { - self.sub_unify(l, maybe_sup)?; - } else { - self.sub_unify(r, maybe_sup)?; + (And(tys), _) => { + for ty in tys { + if self.ctx.subtype_of(ty, maybe_sup) { + self.sub_unify(ty, maybe_sup)?; + break; + } } + self.sub_unify(tys.iter().next().unwrap(), maybe_sup)?; } // X <: (Y or Z) is valid when X <: Y or X <: Z - (_, Or(l, r)) => { - if self.ctx.subtype_of(maybe_sub, l) { - self.sub_unify(maybe_sub, l)?; - } else { - self.sub_unify(maybe_sub, r)?; + (_, Or(tys)) => { + for ty in tys { + if self.ctx.subtype_of(maybe_sub, ty) { + self.sub_unify(maybe_sub, ty)?; + break; + } } + self.sub_unify(maybe_sub, tys.iter().next().unwrap())?; } (Ref(sub), Ref(sup)) => { self.sub_unify(sub, sup)?; @@ -1843,27 +1881,27 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> { /// ``` fn unify(&self, lhs: &Type, rhs: &Type) -> Option { match (lhs, rhs) { - (Type::Or(l, r), other) | (other, Type::Or(l, r)) => { - if let Some(t) = self.unify(l, other) { - return self.unify(&t, l); - } else if let Some(t) = self.unify(r, other) { - return self.unify(&t, l); + (Or(tys), other) | (other, Or(tys)) => { + for ty in tys { + if let Some(t) = self.unify(ty, other) { + return self.unify(&t, ty); + } } return None; } - (Type::FreeVar(fv), _) if fv.is_linked() => return self.unify(&fv.crack(), rhs), - (_, Type::FreeVar(fv)) if fv.is_linked() => return self.unify(lhs, &fv.crack()), + (FreeVar(fv), _) if fv.is_linked() => return self.unify(&fv.crack(), rhs), + (_, FreeVar(fv)) if fv.is_linked() => return self.unify(lhs, &fv.crack()), // TODO: unify(?T, ?U) ? - (Type::FreeVar(_), Type::FreeVar(_)) => {} - (Type::FreeVar(fv), _) if fv.constraint_is_sandwiched() => { + (FreeVar(_), FreeVar(_)) => {} + (FreeVar(fv), _) if fv.constraint_is_sandwiched() => { let sub = fv.get_sub()?; return self.unify(&sub, rhs); } - (_, Type::FreeVar(fv)) if fv.constraint_is_sandwiched() => { + (_, FreeVar(fv)) if fv.constraint_is_sandwiched() => { let sub = fv.get_sub()?; return self.unify(lhs, &sub); } - (Type::Refinement(lhs), Type::Refinement(rhs)) => { + (Refinement(lhs), Refinement(rhs)) => { if let Some(_union) = self.unify(&lhs.t, &rhs.t) { return Some(self.ctx.union_refinement(lhs, rhs).into()); } diff --git a/crates/erg_compiler/lower.rs b/crates/erg_compiler/lower.rs index 04f45d4c7..948a49de4 100644 --- a/crates/erg_compiler/lower.rs +++ b/crates/erg_compiler/lower.rs @@ -1502,9 +1502,10 @@ impl GenericASTLowerer { } _ => {} }, - Type::And(lhs, rhs) => { - self.push_guard(nth, kind, lhs); - self.push_guard(nth, kind, rhs); + Type::And(tys) => { + for ty in tys { + self.push_guard(nth, kind, ty); + } } _ => {} } diff --git a/crates/erg_compiler/ty/constructors.rs b/crates/erg_compiler/ty/constructors.rs index 3e78dc33c..eb9ffd61d 100644 --- a/crates/erg_compiler/ty/constructors.rs +++ b/crates/erg_compiler/ty/constructors.rs @@ -593,35 +593,11 @@ pub fn refinement(var: Str, t: Type, pred: Predicate) -> Type { } pub fn and(lhs: Type, rhs: Type) -> Type { - match (lhs, rhs) { - (Type::And(l, r), other) | (other, Type::And(l, r)) => { - if l.as_ref() == &other { - and(*r, other) - } else if r.as_ref() == &other { - and(*l, other) - } else { - Type::And(Box::new(Type::And(l, r)), Box::new(other)) - } - } - (Type::Obj, other) | (other, Type::Obj) => other, - (lhs, rhs) => Type::And(Box::new(lhs), Box::new(rhs)), - } + lhs & rhs } pub fn or(lhs: Type, rhs: Type) -> Type { - match (lhs, rhs) { - (Type::Or(l, r), other) | (other, Type::Or(l, r)) => { - if l.as_ref() == &other { - or(*r, other) - } else if r.as_ref() == &other { - or(*l, other) - } else { - Type::Or(Box::new(Type::Or(l, r)), Box::new(other)) - } - } - (Type::Never, other) | (other, Type::Never) => other, - (lhs, rhs) => Type::Or(Box::new(lhs), Box::new(rhs)), - } + lhs | rhs } pub fn ors(tys: impl IntoIterator) -> Type { diff --git a/crates/erg_compiler/ty/mod.rs b/crates/erg_compiler/ty/mod.rs index 7f545ed14..c5c2d9344 100644 --- a/crates/erg_compiler/ty/mod.rs +++ b/crates/erg_compiler/ty/mod.rs @@ -1410,8 +1410,8 @@ pub enum Type { Refinement(RefinementType), // e.g. |T: Type| T -> T Quantified(Box), - And(Box, Box), - Or(Box, Box), + And(Set), + Or(Set), Not(Box), // NOTE: It was found that adding a new variant above `Poly` may cause a subtyping bug, // possibly related to enum internal numbering, but the cause is unknown. @@ -1504,8 +1504,8 @@ impl PartialEq for Type { (Self::NamedTuple(lhs), Self::NamedTuple(rhs)) => lhs == rhs, (Self::Refinement(l), Self::Refinement(r)) => l == r, (Self::Quantified(l), Self::Quantified(r)) => l == r, - (Self::And(_, _), Self::And(_, _)) => self.ands().linear_eq(&other.ands()), - (Self::Or(_, _), Self::Or(_, _)) => self.ors().linear_eq(&other.ors()), + (Self::And(l), Self::And(r)) => l.linear_eq(r), + (Self::Or(l), Self::Or(r)) => l.linear_eq(r), (Self::Not(l), Self::Not(r)) => l == r, ( Self::Poly { @@ -1659,20 +1659,28 @@ impl LimitedDisplay for Type { write!(f, "|")?; quantified.limited_fmt(f, limit - 1) } - Self::And(lhs, rhs) => { - lhs.limited_fmt(f, limit - 1)?; - write!(f, " and ")?; - rhs.limited_fmt(f, limit - 1) + Self::And(ands) => { + for (i, ty) in ands.iter().enumerate() { + if i > 0 { + write!(f, " and ")?; + } + ty.limited_fmt(f, limit - 1)?; + } + Ok(()) + } + Self::Or(ors) => { + for (i, ty) in ors.iter().enumerate() { + if i > 0 { + write!(f, " or ")?; + } + ty.limited_fmt(f, limit - 1)?; + } + Ok(()) } Self::Not(ty) => { write!(f, "not ")?; ty.limited_fmt(f, limit - 1) } - Self::Or(lhs, rhs) => { - lhs.limited_fmt(f, limit - 1)?; - write!(f, " or ")?; - rhs.limited_fmt(f, limit - 1) - } Self::Poly { name, params } => { write!(f, "{name}(")?; if !DEBUG_MODE && self.is_module() { @@ -1852,14 +1860,40 @@ impl From> for Type { impl BitAnd for Type { type Output = Self; fn bitand(self, rhs: Self) -> Self::Output { - Self::And(Box::new(self), Box::new(rhs)) + match (self, rhs) { + (Self::And(l), Self::And(r)) => Self::And(l.union(&r)), + (Self::Obj, other) | (other, Self::Obj) => other, + (Self::Never, _) | (_, Self::Never) => Self::Never, + (Self::And(mut l), r) => { + l.insert(r); + Self::And(l) + } + (l, Self::And(mut r)) => { + r.insert(l); + Self::And(r) + } + (l, r) => Self::checked_and(set! {l, r}), + } } } impl BitOr for Type { type Output = Self; fn bitor(self, rhs: Self) -> Self::Output { - Self::Or(Box::new(self), Box::new(rhs)) + match (self, rhs) { + (Self::Or(l), Self::Or(r)) => Self::Or(l.union(&r)), + (Self::Obj, _) | (_, Self::Obj) => Self::Obj, + (Self::Never, other) | (other, Self::Never) => other, + (Self::Or(mut l), r) => { + l.insert(r); + Self::Or(l) + } + (l, Self::Or(mut r)) => { + r.insert(l); + Self::Or(r) + } + (l, r) => Self::checked_or(set! {l, r}), + } } } @@ -1974,17 +2008,7 @@ impl HasLevel for Type { .filter_map(|o| *o) .min() } - Self::And(lhs, rhs) | Self::Or(lhs, rhs) => { - let l = lhs - .level() - .unwrap_or(GENERIC_LEVEL) - .min(rhs.level().unwrap_or(GENERIC_LEVEL)); - if l == GENERIC_LEVEL { - None - } else { - Some(l) - } - } + Self::And(tys) | Self::Or(tys) => tys.iter().filter_map(|t| t.level()).min(), Self::Not(ty) => ty.level(), Self::Record(attrs) => attrs.values().filter_map(|t| t.level()).min(), Self::NamedTuple(attrs) => attrs.iter().filter_map(|(_, t)| t.level()).min(), @@ -2065,9 +2089,10 @@ impl HasLevel for Type { Self::Quantified(quant) => { quant.set_level(level); } - Self::And(lhs, rhs) | Self::Or(lhs, rhs) => { - lhs.set_level(level); - rhs.set_level(level); + Self::And(tys) | Self::Or(tys) => { + for t in tys.iter() { + t.set_level(level); + } } Self::Not(ty) => ty.set_level(level), Self::Record(attrs) => { @@ -2218,9 +2243,7 @@ impl StructuralEq for Type { (Self::Guard(l), Self::Guard(r)) => l.structural_eq(r), // NG: (l.structural_eq(l2) && r.structural_eq(r2)) // || (l.structural_eq(r2) && r.structural_eq(l2)) - (Self::And(_, _), Self::And(_, _)) => { - let self_ands = self.ands(); - let other_ands = other.ands(); + (Self::And(self_ands), Self::And(other_ands)) => { if self_ands.len() != other_ands.len() { return false; } @@ -2234,9 +2257,7 @@ impl StructuralEq for Type { } true } - (Self::Or(_, _), Self::Or(_, _)) => { - let self_ors = self.ors(); - let other_ors = other.ors(); + (Self::Or(self_ors), Self::Or(other_ors)) => { if self_ors.len() != other_ors.len() { return false; } @@ -2330,10 +2351,30 @@ impl Type { } } + pub fn checked_or(tys: Set) -> Self { + if tys.is_empty() { + panic!("tys is empty"); + } else if tys.len() == 1 { + tys.into_iter().next().unwrap() + } else { + Self::Or(tys) + } + } + + pub fn checked_and(tys: Set) -> Self { + if tys.is_empty() { + panic!("tys is empty"); + } else if tys.len() == 1 { + tys.into_iter().next().unwrap() + } else { + Self::And(tys) + } + } + pub fn quantify(self) -> Self { debug_assert!(self.is_subr(), "{self} is not subr"); match self { - Self::And(lhs, rhs) => lhs.quantify() & rhs.quantify(), + Self::And(tys) => Self::And(tys.into_iter().map(|t| t.quantify()).collect()), other => Self::Quantified(Box::new(other)), } } @@ -2431,7 +2472,7 @@ impl Type { Self::Quantified(t) => t.is_procedure(), Self::Subr(subr) if subr.kind == SubrKind::Proc => true, Self::Refinement(refine) => refine.t.is_procedure(), - Self::And(lhs, rhs) => lhs.is_procedure() && rhs.is_procedure(), + Self::And(tys) => tys.iter().any(|t| t.is_procedure()), _ => false, } } @@ -2449,6 +2490,7 @@ impl Type { name.ends_with('!') } Self::Refinement(refine) => refine.t.is_mut_type(), + Self::And(tys) => tys.iter().any(|t| t.is_mut_type()), _ => false, } } @@ -2467,6 +2509,7 @@ impl Type { Self::Poly { name, params, .. } if &name[..] == "Tuple" => params.is_empty(), Self::Refinement(refine) => refine.t.is_nonelike(), Self::Bounded { sup, .. } => sup.is_nonelike(), + Self::And(tys) => tys.iter().any(|t| t.is_nonelike()), _ => false, } } @@ -2516,7 +2559,7 @@ impl Type { pub fn is_union_type(&self) -> bool { match self { Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_union_type(), - Self::Or(_, _) => true, + Self::Or(_) => true, Self::Refinement(refine) => refine.t.is_union_type(), _ => false, } @@ -2549,7 +2592,7 @@ impl Type { pub fn is_intersection_type(&self) -> bool { match self { Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_intersection_type(), - Self::And(_, _) => true, + Self::And(_) => true, Self::Refinement(refine) => refine.t.is_intersection_type(), _ => false, } @@ -2563,11 +2606,11 @@ impl Type { fv.do_avoiding_recursion(|| sub.union_size().max(sup.union_size())) } // Or(Or(Int, Str), Nat) == 3 - Self::Or(l, r) => l.union_size() + r.union_size(), + Self::Or(tys) => tys.len(), Self::Refinement(refine) => refine.t.union_size(), Self::Ref(t) => t.union_size(), Self::RefMut { before, after: _ } => before.union_size(), - Self::And(lhs, rhs) => lhs.union_size().max(rhs.union_size()), + Self::And(tys) => tys.iter().map(|ty| ty.union_size()).max().unwrap_or(1), Self::Not(ty) => ty.union_size(), Self::Callable { param_ts, return_t } => param_ts .iter() @@ -2608,7 +2651,7 @@ impl Type { match self { Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_refinement(), Self::Refinement(_) => true, - Self::And(l, r) => l.is_refinement() && r.is_refinement(), + Self::And(tys) => tys.iter().any(|t| t.is_refinement()), _ => false, } } @@ -2617,6 +2660,7 @@ impl Type { match self { Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_singleton_refinement(), Self::Refinement(refine) => matches!(refine.pred.as_ref(), Predicate::Equal { .. }), + Self::And(tys) => tys.iter().any(|t| t.is_singleton_refinement()), _ => false, } } @@ -2626,6 +2670,7 @@ impl Type { Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_record(), Self::Record(_) => true, Self::Refinement(refine) => refine.t.is_record(), + Self::And(tys) => tys.iter().any(|t| t.is_record()), _ => false, } } @@ -2639,6 +2684,7 @@ impl Type { Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_erg_module(), Self::Refinement(refine) => refine.t.is_erg_module(), Self::Poly { name, .. } => &name[..] == "Module", + Self::And(tys) => tys.iter().any(|t| t.is_erg_module()), _ => false, } } @@ -2648,6 +2694,7 @@ impl Type { Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_py_module(), Self::Refinement(refine) => refine.t.is_py_module(), Self::Poly { name, .. } => &name[..] == "PyModule", + Self::And(tys) => tys.iter().any(|t| t.is_py_module()), _ => false, } } @@ -2658,7 +2705,7 @@ impl Type { Self::Refinement(refine) => refine.t.is_method(), Self::Subr(subr) => subr.is_method(), Self::Quantified(quant) => quant.is_method(), - Self::And(l, r) => l.is_method() && r.is_method(), + Self::And(tys) => tys.iter().any(|t| t.is_method()), _ => false, } } @@ -2669,7 +2716,7 @@ impl Type { Self::Subr(_) => true, Self::Quantified(quant) => quant.is_subr(), Self::Refinement(refine) => refine.t.is_subr(), - Self::And(l, r) => l.is_subr() && r.is_subr(), + Self::And(tys) => tys.iter().any(|t| t.is_subr()), _ => false, } } @@ -2680,7 +2727,10 @@ impl Type { Self::Subr(subr) => Some(subr.kind), Self::Refinement(refine) => refine.t.subr_kind(), Self::Quantified(quant) => quant.subr_kind(), - Self::And(l, r) => l.subr_kind().and_then(|k| r.subr_kind().map(|k2| k | k2)), + Self::And(tys) => tys + .iter() + .filter_map(|t| t.subr_kind()) + .fold(None, |a, b| Some(a? | b)), _ => None, } } @@ -2690,7 +2740,7 @@ impl Type { Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_quantified_subr(), Self::Quantified(_) => true, Self::Refinement(refine) => refine.t.is_quantified_subr(), - Self::And(l, r) => l.is_quantified_subr() && r.is_quantified_subr(), + Self::And(tys) => tys.iter().any(|t| t.is_quantified_subr()), _ => false, } } @@ -2727,6 +2777,7 @@ impl Type { Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_iterable(), Self::Poly { name, .. } => &name[..] == "Iterable", Self::Refinement(refine) => refine.t.is_iterable(), + Self::And(tys) => tys.iter().any(|t| t.is_iterable()), _ => false, } } @@ -2817,6 +2868,7 @@ impl Type { Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_structural(), Self::Structural(_) => true, Self::Refinement(refine) => refine.t.is_structural(), + Self::And(tys) => tys.iter().any(|t| t.is_structural()), _ => false, } } @@ -2826,6 +2878,7 @@ impl Type { Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_failure(), Self::Refinement(refine) => refine.t.is_failure(), Self::Failure => true, + Self::And(tys) => tys.iter().any(|t| t.is_failure()), _ => false, } } @@ -2902,8 +2955,8 @@ impl Type { Self::ProjCall { lhs, args, .. } => { lhs.contains_tvar(target) || args.iter().any(|t| t.contains_tvar(target)) } - Self::And(lhs, rhs) => lhs.contains_tvar(target) || rhs.contains_tvar(target), - Self::Or(lhs, rhs) => lhs.contains_tvar(target) || rhs.contains_tvar(target), + Self::And(tys) => tys.iter().any(|t| t.contains_tvar(target)), + Self::Or(tys) => tys.iter().any(|t| t.contains_tvar(target)), Self::Not(t) => t.contains_tvar(target), Self::Ref(t) => t.contains_tvar(target), Self::RefMut { before, after } => { @@ -2962,8 +3015,8 @@ impl Type { Self::ProjCall { lhs, args, .. } => { lhs.has_type_satisfies(f) || args.iter().any(|t| t.has_type_satisfies(f)) } - Self::And(lhs, rhs) | Self::Or(lhs, rhs) => { - lhs.has_type_satisfies(f) || rhs.has_type_satisfies(f) + Self::And(tys) | Self::Or(tys) => { + tys.iter().any(|t| t.has_type_satisfies(f)) } Self::Not(t) => t.has_type_satisfies(f), Self::Ref(t) => t.has_type_satisfies(f), @@ -3030,8 +3083,8 @@ impl Type { Self::ProjCall { lhs, args, .. } => { lhs.contains_type(target) || args.iter().any(|t| t.contains_type(target)) } - Self::And(lhs, rhs) => lhs.contains_type(target) || rhs.contains_type(target), - Self::Or(lhs, rhs) => lhs.contains_type(target) || rhs.contains_type(target), + Self::And(tys) => tys.iter().any(|t| t.contains_type(target)), + Self::Or(tys) => tys.iter().any(|t| t.contains_type(target)), Self::Not(t) => t.contains_type(target), Self::Ref(t) => t.contains_type(target), Self::RefMut { before, after } => { @@ -3068,8 +3121,8 @@ impl Type { Self::ProjCall { lhs, args, .. } => { lhs.contains_tp(target) || args.iter().any(|t| t.contains_tp(target)) } - Self::And(lhs, rhs) => lhs.contains_tp(target) || rhs.contains_tp(target), - Self::Or(lhs, rhs) => lhs.contains_tp(target) || rhs.contains_tp(target), + Self::And(tys) => tys.iter().any(|t| t.contains_tp(target)), + Self::Or(tys) => tys.iter().any(|t| t.contains_tp(target)), Self::Not(t) => t.contains_tp(target), Self::Ref(t) => t.contains_tp(target), Self::RefMut { before, after } => { @@ -3102,8 +3155,8 @@ impl Type { Self::ProjCall { lhs, args, .. } => { lhs.contains_value(target) || args.iter().any(|t| t.contains_value(target)) } - Self::And(lhs, rhs) => lhs.contains_value(target) || rhs.contains_value(target), - Self::Or(lhs, rhs) => lhs.contains_value(target) || rhs.contains_value(target), + Self::And(tys) => tys.iter().any(|t| t.contains_value(target)), + Self::Or(tys) => tys.iter().any(|t| t.contains_value(target)), Self::Not(t) => t.contains_value(target), Self::Ref(t) => t.contains_value(target), Self::RefMut { before, after } => { @@ -3146,9 +3199,7 @@ impl Type { Self::ProjCall { lhs, args, .. } => { lhs.contains_type(self) || args.iter().any(|t| t.contains_type(self)) } - Self::And(lhs, rhs) | Self::Or(lhs, rhs) => { - lhs.contains_type(self) || rhs.contains_type(self) - } + Self::And(tys) | Self::Or(tys) => tys.iter().any(|t| t.contains_type(self)), Self::Not(t) => t.contains_type(self), Self::Ref(t) => t.contains_type(self), Self::RefMut { before, after } => { @@ -3218,9 +3269,9 @@ impl Type { Self::Inf => Str::ever("Inf"), Self::NegInf => Str::ever("NegInf"), Self::Mono(name) => name.clone(), - Self::And(_, _) => Str::ever("And"), + Self::And(_) => Str::ever("And"), Self::Not(_) => Str::ever("Not"), - Self::Or(_, _) => Str::ever("Or"), + Self::Or(_) => Str::ever("Or"), Self::Ref(_) => Str::ever("Ref"), Self::RefMut { .. } => Str::ever("RefMut"), Self::Subr(SubrType { @@ -3310,7 +3361,7 @@ impl Type { match self { Self::FreeVar(fv) if fv.is_linked() => fv.crack().contains_intersec(typ), Self::Refinement(refine) => refine.t.contains_intersec(typ), - Self::And(t1, t2) => t1.contains_intersec(typ) || t2.contains_intersec(typ), + Self::And(tys) => tys.iter().any(|t| t.contains_intersec(typ)), _ => self == typ, } } @@ -3319,7 +3370,16 @@ impl Type { match self { Self::FreeVar(fv) if fv.is_linked() => fv.crack().union_pair(), Self::Refinement(refine) => refine.t.union_pair(), - Self::Or(t1, t2) => Some((*t1.clone(), *t2.clone())), + Self::Or(tys) if tys.len() == 2 => { + let mut iter = tys.iter(); + Some((iter.next().unwrap().clone(), iter.next().unwrap().clone())) + } + Self::Or(tys) => { + let mut iter = tys.iter(); + let t1 = iter.next().unwrap().clone(); + let t2 = iter.cloned().collect(); + Some((t1, Type::Or(t2))) + } _ => None, } } @@ -3329,7 +3389,7 @@ impl Type { match self { Type::FreeVar(fv) if fv.is_linked() => fv.crack().contains_union(typ), Type::Refinement(refine) => refine.t.contains_union(typ), - Type::Or(t1, t2) => t1.contains_union(typ) || t2.contains_union(typ), + Type::Or(tys) => tys.iter().any(|t| t.contains_union(typ)), _ => self == typ, } } @@ -3343,11 +3403,7 @@ impl Type { .into_iter() .map(|t| t.quantify()) .collect(), - Type::And(t1, t2) => { - let mut types = t1.intersection_types(); - types.extend(t2.intersection_types()); - types - } + Type::And(tys) => tys.iter().cloned().collect(), _ => vec![self.clone()], } } @@ -3423,9 +3479,7 @@ impl Type { match self { Self::FreeVar(fv) if fv.is_unbound() => true, Self::FreeVar(fv) if fv.is_linked() => fv.crack().is_totally_unbound(), - Self::Or(t1, t2) | Self::And(t1, t2) => { - t1.is_totally_unbound() && t2.is_totally_unbound() - } + Self::Or(tys) | Self::And(tys) => tys.iter().all(|t| t.is_totally_unbound()), Self::Not(t) => t.is_totally_unbound(), _ => false, } @@ -3532,9 +3586,10 @@ impl Type { sub.destructive_coerce(); self.destructive_link(&sub); } - Type::And(l, r) | Type::Or(l, r) => { - l.destructive_coerce(); - r.destructive_coerce(); + Type::And(tys) | Type::Or(tys) => { + for t in tys { + t.destructive_coerce(); + } } Type::Not(l) => l.destructive_coerce(), Type::Poly { params, .. } => { @@ -3587,9 +3642,10 @@ impl Type { sub.undoable_coerce(list); self.undoable_link(&sub, list); } - Type::And(l, r) | Type::Or(l, r) => { - l.undoable_coerce(list); - r.undoable_coerce(list); + Type::And(tys) | Type::Or(tys) => { + for t in tys { + t.undoable_coerce(list); + } } Type::Not(l) => l.undoable_coerce(list), Type::Poly { params, .. } => { @@ -3648,7 +3704,9 @@ impl Type { .map(|t| t.qvars_inner()) .unwrap_or_else(|| set! {}), ), - Self::And(lhs, rhs) | Self::Or(lhs, rhs) => lhs.qvars_inner().concat(rhs.qvars_inner()), + Self::And(tys) | Self::Or(tys) => tys + .iter() + .fold(set! {}, |acc, t| acc.concat(t.qvars_inner())), Self::Not(ty) => ty.qvars_inner(), Self::Callable { param_ts, return_t } => param_ts .iter() @@ -3716,7 +3774,7 @@ impl Type { Self::RefMut { before, after } => { before.has_qvar() || after.as_ref().map(|t| t.has_qvar()).unwrap_or(false) } - Self::And(lhs, rhs) | Self::Or(lhs, rhs) => lhs.has_qvar() || rhs.has_qvar(), + Self::And(tys) | Self::Or(tys) => tys.iter().any(|t| t.has_qvar()), Self::Not(ty) => ty.has_qvar(), Self::Callable { param_ts, return_t } => { param_ts.iter().any(|t| t.has_qvar()) || return_t.has_qvar() @@ -3769,9 +3827,7 @@ impl Type { .map(|t| t.has_undoable_linked_var()) .unwrap_or(false) } - Self::And(lhs, rhs) | Self::Or(lhs, rhs) => { - lhs.has_undoable_linked_var() || rhs.has_undoable_linked_var() - } + Self::And(tys) | Self::Or(tys) => tys.iter().any(|t| t.has_undoable_linked_var()), Self::Not(ty) => ty.has_undoable_linked_var(), Self::Callable { param_ts, return_t } => { param_ts.iter().any(|t| t.has_undoable_linked_var()) @@ -3810,9 +3866,7 @@ impl Type { before.has_unbound_var() || after.as_ref().map(|t| t.has_unbound_var()).unwrap_or(false) } - Self::And(lhs, rhs) | Self::Or(lhs, rhs) => { - lhs.has_unbound_var() || rhs.has_unbound_var() - } + Self::And(tys) | Self::Or(tys) => tys.iter().any(|t| t.has_unbound_var()), Self::Not(ty) => ty.has_unbound_var(), Self::Callable { param_ts, return_t } => { param_ts.iter().any(|t| t.has_unbound_var()) || return_t.has_unbound_var() @@ -3867,7 +3921,7 @@ impl Type { Self::Refinement(refine) => refine.t.typarams_len(), // REVIEW: Self::Ref(_) | Self::RefMut { .. } => Some(1), - Self::And(_, _) | Self::Or(_, _) => Some(2), + Self::And(tys) | Self::Or(tys) => Some(tys.len()), Self::Not(_) => Some(1), Self::Subr(subr) => Some( subr.non_default_params.len() @@ -3933,9 +3987,7 @@ impl Type { Self::FreeVar(_unbound) => vec![], Self::Refinement(refine) => refine.t.typarams(), Self::Ref(t) | Self::RefMut { before: t, .. } => vec![TyParam::t(*t.clone())], - Self::And(lhs, rhs) | Self::Or(lhs, rhs) => { - vec![TyParam::t(*lhs.clone()), TyParam::t(*rhs.clone())] - } + Self::And(tys) | Self::Or(tys) => tys.iter().cloned().map(TyParam::t).collect(), Self::Not(t) => vec![TyParam::t(*t.clone())], Self::Subr(subr) => subr.typarams(), Self::Quantified(quant) => quant.typarams(), @@ -4156,8 +4208,8 @@ impl Type { let r = r.iter().map(|(k, v)| (k.clone(), v.derefine())).collect(); Self::NamedTuple(r) } - Self::And(l, r) => l.derefine() & r.derefine(), - Self::Or(l, r) => l.derefine() | r.derefine(), + Self::And(tys) => Self::checked_and(tys.iter().map(|t| t.derefine()).collect()), + Self::Or(tys) => Self::checked_or(tys.iter().map(|t| t.derefine()).collect()), Self::Not(ty) => !ty.derefine(), Self::Proj { lhs, rhs } => lhs.derefine().proj(rhs.clone()), Self::ProjCall { @@ -4224,22 +4276,28 @@ impl Type { }); self } - Self::And(l, r) => { - if l.addr_eq(target) { - return r.eliminate_subsup(target); - } else if r.addr_eq(target) { - return l.eliminate_subsup(target); - } - l.eliminate_subsup(target) & r.eliminate_subsup(target) - } - Self::Or(l, r) => { - if l.addr_eq(target) { - return r.eliminate_subsup(target); - } else if r.addr_eq(target) { - return l.eliminate_subsup(target); - } - l.eliminate_subsup(target) | r.eliminate_subsup(target) - } + Self::And(tys) => Self::checked_and( + tys.into_iter() + .filter_map(|t| { + if t.addr_eq(target) { + None + } else { + Some(t.eliminate_subsup(target)) + } + }) + .collect(), + ), + Self::Or(tys) => Self::checked_or( + tys.into_iter() + .filter_map(|t| { + if t.addr_eq(target) { + None + } else { + Some(t.eliminate_subsup(target)) + } + }) + .collect(), + ), other => other, } } @@ -4294,8 +4352,16 @@ impl Type { before: Box::new(before.eliminate_recursion(target)), after: after.map(|t| Box::new(t.eliminate_recursion(target))), }, - Self::And(l, r) => l.eliminate_recursion(target) & r.eliminate_recursion(target), - Self::Or(l, r) => l.eliminate_recursion(target) | r.eliminate_recursion(target), + Self::And(tys) => Self::checked_and( + tys.into_iter() + .map(|t| t.eliminate_recursion(target)) + .collect(), + ), + Self::Or(tys) => Self::checked_or( + tys.into_iter() + .map(|t| t.eliminate_recursion(target)) + .collect(), + ), Self::Not(ty) => !ty.eliminate_recursion(target), Self::Proj { lhs, rhs } => lhs.eliminate_recursion(target).proj(rhs), Self::ProjCall { @@ -4448,8 +4514,8 @@ impl Type { before: Box::new(before.map(f)), after: after.map(|t| Box::new(t.map(f))), }, - Self::And(l, r) => l.map(f) & r.map(f), - Self::Or(l, r) => l.map(f) | r.map(f), + Self::And(tys) => Self::checked_and(tys.into_iter().map(|t| t.map(f)).collect()), + Self::Or(tys) => Self::checked_or(tys.into_iter().map(|t| t.map(f)).collect()), Self::Not(ty) => !ty.map(f), Self::Proj { lhs, rhs } => lhs.map(f).proj(rhs), Self::ProjCall { @@ -4542,8 +4608,12 @@ impl Type { before: Box::new(before._replace_tp(target, to)), after: after.map(|t| Box::new(t._replace_tp(target, to))), }, - Self::And(l, r) => l._replace_tp(target, to) & r._replace_tp(target, to), - Self::Or(l, r) => l._replace_tp(target, to) | r._replace_tp(target, to), + Self::And(tys) => { + Self::checked_and(tys.into_iter().map(|t| t._replace_tp(target, to)).collect()) + } + Self::Or(tys) => { + Self::checked_or(tys.into_iter().map(|t| t._replace_tp(target, to)).collect()) + } Self::Not(ty) => !ty._replace_tp(target, to), Self::Proj { lhs, rhs } => lhs._replace_tp(target, to).proj(rhs), Self::ProjCall { @@ -4619,8 +4689,8 @@ impl Type { before: Box::new(before.map_tp(f)), after: after.map(|t| Box::new(t.map_tp(f))), }, - Self::And(l, r) => l.map_tp(f) & r.map_tp(f), - Self::Or(l, r) => l.map_tp(f) | r.map_tp(f), + Self::And(tys) => Self::checked_and(tys.into_iter().map(|t| t.map_tp(f)).collect()), + Self::Or(tys) => Self::checked_or(tys.into_iter().map(|t| t.map_tp(f)).collect()), Self::Not(ty) => !ty.map_tp(f), Self::Proj { lhs, rhs } => lhs.map_tp(f).proj(rhs), Self::ProjCall { @@ -4708,8 +4778,16 @@ impl Type { after, }) } - Self::And(l, r) => Ok(l.try_map_tp(f)? & r.try_map_tp(f)?), - Self::Or(l, r) => Ok(l.try_map_tp(f)? | r.try_map_tp(f)?), + Self::And(tys) => Ok(Self::checked_and( + tys.into_iter() + .map(|t| t.try_map_tp(f)) + .collect::>()?, + )), + Self::Or(tys) => Ok(Self::checked_or( + tys.into_iter() + .map(|t| t.try_map_tp(f)) + .collect::>()?, + )), Self::Not(ty) => Ok(!ty.try_map_tp(f)?), Self::Proj { lhs, rhs } => Ok(lhs.try_map_tp(f)?.proj(rhs)), Self::ProjCall { @@ -4742,12 +4820,28 @@ impl Type { *refine.t = refine.t.replace_param(target, to); Self::Refinement(refine) } - Self::And(l, r) => l.replace_param(target, to) & r.replace_param(target, to), + Self::And(tys) => Self::And( + tys.into_iter() + .map(|t| t.replace_param(target, to)) + .collect(), + ), Self::Guard(guard) => Self::Guard(guard.replace_param(target, to)), _ => self, } } + pub fn eliminate_and_or(&mut self) { + match self { + Self::And(tys) if tys.len() == 1 => { + *self = tys.take_all().into_iter().next().unwrap(); + } + Self::Or(tys) if tys.len() == 1 => { + *self = tys.take_all().into_iter().next().unwrap(); + } + _ => {} + } + } + pub fn replace_params<'l, 'r>( mut self, target: impl Iterator, @@ -4815,8 +4909,8 @@ impl Type { } Self::NamedTuple(r) } - Self::And(l, r) => l.normalize() & r.normalize(), - Self::Or(l, r) => l.normalize() | r.normalize(), + Self::And(tys) => Self::checked_and(tys.into_iter().map(|t| t.normalize()).collect()), + Self::Or(tys) => Self::checked_or(tys.into_iter().map(|t| t.normalize()).collect()), Self::Not(ty) => !ty.normalize(), Self::Structural(ty) => ty.normalize().structuralize(), Self::Quantified(quant) => quant.normalize().quantify(), @@ -4848,14 +4942,36 @@ impl Type { free.get_sub().unwrap_or(self.clone()) } else { match self { - Self::And(l, r) => l.lower_bounded() & r.lower_bounded(), - Self::Or(l, r) => l.lower_bounded() | r.lower_bounded(), + Self::And(tys) => { + Self::checked_and(tys.iter().map(|t| t.lower_bounded()).collect()) + } + Self::Or(tys) => Self::checked_or(tys.iter().map(|t| t.lower_bounded()).collect()), Self::Not(ty) => !ty.lower_bounded(), _ => self.clone(), } } } + /// ```erg + /// assert Int.upper_bounded() == Int + /// assert ?T(<: Str).upper_bounded() == Str + /// assert (?T(<: Str) or ?U(<: Int)).upper_bounded() == (Str or Int) + /// ``` + pub fn upper_bounded(&self) -> Type { + if let Ok(free) = <&FreeTyVar>::try_from(self) { + free.get_super().unwrap_or(self.clone()) + } else { + match self { + Self::And(tys) => { + Self::checked_and(tys.iter().map(|t| t.upper_bounded()).collect()) + } + Self::Or(tys) => Self::checked_or(tys.iter().map(|t| t.upper_bounded()).collect()), + Self::Not(ty) => !ty.upper_bounded(), + _ => self.clone(), + } + } + } + pub(crate) fn addr_eq(&self, other: &Type) -> bool { match (self, other) { (Self::FreeVar(slf), _) if slf.is_linked() => slf.crack().addr_eq(other), @@ -5018,7 +5134,7 @@ impl Type { match self { Self::FreeVar(fv) if fv.is_linked() => fv.crack().ands(), Self::Refinement(refine) => refine.t.ands(), - Self::And(l, r) => l.ands().union(&r.ands()), + Self::And(tys) => tys.clone(), _ => set![self.clone()], } } @@ -5031,7 +5147,7 @@ impl Type { match self { Self::FreeVar(fv) if fv.is_linked() => fv.crack().ors(), Self::Refinement(refine) => refine.t.ors(), - Self::Or(l, r) => l.ors().union(&r.ors()), + Self::Or(tys) => tys.clone(), _ => set![self.clone()], } } @@ -5076,7 +5192,7 @@ impl Type { Self::Callable { param_ts, .. } => { param_ts.iter().flat_map(|t| t.contained_ts()).collect() } - Self::And(l, r) | Self::Or(l, r) => l.contained_ts().union(&r.contained_ts()), + Self::And(tys) | Self::Or(tys) => tys.iter().flat_map(|t| t.contained_ts()).collect(), Self::Not(t) => t.contained_ts(), Self::Bounded { sub, sup } => sub.contained_ts().union(&sup.contained_ts()), Self::Quantified(ty) | Self::Structural(ty) => ty.contained_ts(), @@ -5145,9 +5261,18 @@ impl Type { } return_t.dereference(); } - Self::And(l, r) | Self::Or(l, r) => { - l.dereference(); - r.dereference(); + Self::And(tys) | Self::Or(tys) => { + *tys = tys + .take_all() + .into_iter() + .map(|mut t| { + t.dereference(); + t + }) + .collect(); + if tys.len() == 1 { + *self = tys.take_all().into_iter().next().unwrap(); + } } Self::Not(ty) => { ty.dereference(); @@ -5255,7 +5380,7 @@ impl Type { set.extend(return_t.variables()); set } - Self::And(l, r) | Self::Or(l, r) => l.variables().union(&r.variables()), + Self::And(tys) | Self::Or(tys) => tys.iter().flat_map(|t| t.variables()).collect(), Self::Not(ty) => ty.variables(), Self::Bounded { sub, sup } => sub.variables().union(&sup.variables()), Self::Quantified(ty) | Self::Structural(ty) => ty.variables(), @@ -5367,13 +5492,16 @@ impl<'t> ReplaceTable<'t> { self.iterate(l, r); } } - (Type::And(l, r), Type::And(l2, r2)) => { - self.iterate(l, l2); - self.iterate(r, r2); + // FIXME: + (Type::And(tys), Type::And(tys2)) => { + for (l, r) in tys.iter().zip(tys2.iter()) { + self.iterate(l, r); + } } - (Type::Or(l, r), Type::Or(l2, r2)) => { - self.iterate(l, l2); - self.iterate(r, r2); + (Type::Or(tys), Type::Or(tys2)) => { + for (l, r) in tys.iter().zip(tys2.iter()) { + self.iterate(l, r); + } } (Type::Not(t), Type::Not(t2)) => { self.iterate(t, t2);