Skip to content

Commit

Permalink
fix: Type::{And, Or}(Set<Type>)
Browse files Browse the repository at this point in the history
  • Loading branch information
mtshiba committed Sep 14, 2024
1 parent 82bc710 commit b0c3137
Show file tree
Hide file tree
Showing 14 changed files with 661 additions and 466 deletions.
28 changes: 17 additions & 11 deletions crates/els/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,24 +42,30 @@ fn comp_item_kind(t: &Type, muty: Mutability) -> CompletionItemKind {
Type::Subr(_) | Type::Quantified(_) => CompletionItemKind::FUNCTION,
Type::ClassType => CompletionItemKind::CLASS,
Type::TraitType => CompletionItemKind::INTERFACE,
Type::Or(l, r) => {
let l = comp_item_kind(l, muty);
let r = comp_item_kind(r, muty);
if l == r {
l
Type::Or(tys) => {
let fst = comp_item_kind(tys.iter().next().unwrap(), muty);
if tys
.iter()
.map(|t| comp_item_kind(t, muty))
.all(|k| k == fst)
{
fst
} else if muty.is_const() {
CompletionItemKind::CONSTANT
} else {
CompletionItemKind::VARIABLE
}
}
Type::And(l, r) => {
let l = comp_item_kind(l, muty);
let r = comp_item_kind(r, muty);
if l == CompletionItemKind::VARIABLE {
r
Type::And(tys) => {
for k in tys.iter().map(|t| comp_item_kind(t, muty)) {
if k != CompletionItemKind::VARIABLE {
return k;
}
}
if muty.is_const() {
CompletionItemKind::CONSTANT
} else {
l
CompletionItemKind::VARIABLE
}
}
Type::Refinement(r) => comp_item_kind(&r.t, muty),
Expand Down
14 changes: 14 additions & 0 deletions crates/erg_common/set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,20 @@ impl<T: Hash + Eq + Clone> Set<T> {
self.insert(other);
self
}

/// ```
/// # use erg_common::set;
/// assert_eq!(set!{1, 2}.product(&set!{3, 4}), set!{(&1, &3), (&1, &4), (&2, &3), (&2, &4)});
/// ```
pub fn product<'l, 'r, U: Hash + Eq>(&'l self, other: &'r Set<U>) -> Set<(&'l T, &'r U)> {
let mut res = set! {};
for x in self.iter() {
for y in other.iter() {
res.insert((x, y));
}
}
res
}
}

impl<T: Hash + Ord> Set<T> {
Expand Down
2 changes: 2 additions & 0 deletions crates/erg_common/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1407,6 +1407,8 @@ impl<T: Immutable + ?Sized> Immutable for &T {}
impl<T: Immutable> Immutable for Option<T> {}
impl<T: Immutable> Immutable for Vec<T> {}
impl<T: Immutable> Immutable for [T] {}
impl<T: Immutable, U: Immutable> Immutable for (T, U) {}
impl<T: Immutable, U: Immutable, V: Immutable> Immutable for (T, U, V) {}
impl<T: Immutable + ?Sized> Immutable for Box<T> {}
impl<T: Immutable + ?Sized> Immutable for std::rc::Rc<T> {}
impl<T: Immutable + ?Sized> Immutable for std::sync::Arc<T> {}
10 changes: 10 additions & 0 deletions crates/erg_common/triple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,16 @@ impl<T: fmt::Display, E: fmt::Display> fmt::Display for Triple<T, E> {
}

impl<T, E> Triple<T, E> {
pub const fn is_ok(&self) -> bool {
matches!(self, Triple::Ok(_))
}
pub const fn is_err(&self) -> bool {
matches!(self, Triple::Err(_))
}
pub const fn is_none(&self) -> bool {
matches!(self, Triple::None)
}

pub fn none_then(self, err: E) -> Result<T, E> {
match self {
Triple::None => Err(err),
Expand Down
112 changes: 52 additions & 60 deletions crates/erg_compiler/context/compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -754,9 +754,7 @@ impl Context {
self.structural_supertype_of(&l, rhs)
}
// {1, 2, 3} :> {1} or {2, 3} == true
(Refinement(_refine), Or(l, r)) => {
self.supertype_of(lhs, l) && self.supertype_of(lhs, r)
}
(Refinement(_refine), Or(tys)) => tys.iter().all(|ty| self.supertype_of(lhs, ty)),
// ({I: Int | True} :> Int) == true
// {N: Nat | ...} :> Int) == false
// ({I: Int | I >= 0} :> Int) == false
Expand Down Expand Up @@ -808,41 +806,20 @@ impl Context {
self.sub_unify(&inst, l, &(), None).is_ok()
}
// Int or Str :> Str or Int == (Int :> Str && Str :> Int) || (Int :> Int && Str :> Str) == true
(Or(l_1, l_2), Or(r_1, r_2)) => {
if l_1.is_union_type() && self.supertype_of(l_1, rhs) {
return true;
}
if l_2.is_union_type() && self.supertype_of(l_2, rhs) {
return true;
}
(self.supertype_of(l_1, r_1) && self.supertype_of(l_2, r_2))
|| (self.supertype_of(l_1, r_2) && self.supertype_of(l_2, r_1))
}
// Int or Str or NoneType :> Str or Int
(Or(l), Or(r)) => r.iter().all(|r| l.iter().any(|l| self.supertype_of(l, r))),
// not Nat :> not Int == true
(Not(l), Not(r)) => self.subtype_of(l, r),
// (Int or Str) :> Nat == Int :> Nat || Str :> Nat == true
// (Num or Show) :> Show == Num :> Show || Show :> Num == true
(Or(l_or, r_or), rhs) => self.supertype_of(l_or, rhs) || self.supertype_of(r_or, rhs),
(Or(ors), rhs) => ors.iter().any(|or| self.supertype_of(or, rhs)),
// Int :> (Nat or Str) == Int :> Nat && Int :> Str == false
(lhs, Or(l_or, r_or)) => self.supertype_of(lhs, l_or) && self.supertype_of(lhs, r_or),
(And(l_1, l_2), And(r_1, r_2)) => {
if l_1.is_intersection_type() && self.supertype_of(l_1, rhs) {
return true;
}
if l_2.is_intersection_type() && self.supertype_of(l_2, rhs) {
return true;
}
(self.supertype_of(l_1, r_1) && self.supertype_of(l_2, r_2))
|| (self.supertype_of(l_1, r_2) && self.supertype_of(l_2, r_1))
}
(lhs, Or(ors)) => ors.iter().all(|or| self.supertype_of(lhs, or)),
(And(l), And(r)) => r.iter().any(|r| l.iter().all(|l| self.supertype_of(l, r))),
// (Num and Show) :> Show == false
(And(l_and, r_and), rhs) => {
self.supertype_of(l_and, rhs) && self.supertype_of(r_and, rhs)
}
(And(ands), rhs) => ands.iter().all(|and| self.supertype_of(and, rhs)),
// Show :> (Num and Show) == true
(lhs, And(l_and, r_and)) => {
self.supertype_of(lhs, l_and) || self.supertype_of(lhs, r_and)
}
(lhs, And(ands)) => ands.iter().any(|and| self.supertype_of(lhs, and)),
// Not(Eq) :> Float == !(Eq :> Float) == true
(Not(_), Obj) => false,
(Not(l), rhs) => !self.supertype_of(l, rhs),
Expand Down Expand Up @@ -914,18 +891,18 @@ impl Context {
Type::NamedTuple(fields) => fields.iter().cloned().collect(),
Type::Refinement(refine) => self.fields(&refine.t),
Type::Structural(t) => self.fields(t),
Type::Or(l, r) => {
let l_fields = self.fields(l);
let r_fields = self.fields(r);
let l_field_names = l_fields.keys().collect::<Set<_>>();
let r_field_names = r_fields.keys().collect::<Set<_>>();
let field_names = l_field_names.intersection(&r_field_names);
Type::Or(tys) => {
let or_fields = tys.iter().map(|t| self.fields(t)).collect::<Set<_>>();
let field_names = or_fields
.iter()
.flat_map(|fs| fs.keys())
.collect::<Set<_>>();
let mut fields = Dict::new();
for (name, l_t, r_t) in field_names
for (name, tys) in field_names
.iter()
.map(|&name| (name, &l_fields[name], &r_fields[name]))
.map(|&name| (name, or_fields.iter().filter_map(|fields| fields.get(name))))
{
let union = self.union(l_t, r_t);
let union = tys.fold(Never, |acc, ty| self.union(&acc, ty));
fields.insert(name.clone(), union);
}
fields
Expand Down Expand Up @@ -1408,6 +1385,8 @@ impl Context {
/// union({ .a = Int }, { .a = Str }) == { .a = Int or Str }
/// union({ .a = Int }, { .a = Int; .b = Int }) == { .a = Int } or { .a = Int; .b = Int } # not to lost `b` information
/// union((A and B) or C) == (A or C) and (B or C)
/// union(Never, Int) == Int
/// union(Obj, Int) == Obj
/// ```
pub(crate) fn union(&self, lhs: &Type, rhs: &Type) -> Type {
if lhs == rhs {
Expand Down Expand Up @@ -1470,10 +1449,9 @@ impl Context {
(Some(sub), Some(sup)) => bounded(sub.clone(), sup.clone()),
_ => self.simple_union(lhs, rhs),
},
(other, or @ Or(_, _)) | (or @ Or(_, _), other) => self.union_add(or, other),
(other, or @ Or(_)) | (or @ Or(_), other) => self.union_add(or, other),
// (A and B) or C ==> (A or C) and (B or C)
(and_t @ And(_, _), other) | (other, and_t @ And(_, _)) => {
let ands = and_t.ands();
(And(ands), other) | (other, And(ands)) => {
let mut t = Type::Obj;
for branch in ands.iter() {
let union = self.union(branch, other);
Expand Down Expand Up @@ -1657,6 +1635,12 @@ impl Context {

/// Returns intersection of two types (`A and B`).
/// If `A` and `B` have a subtype relationship, it is equal to `min(A, B)`.
/// ```erg
/// intersection(Nat, Int) == Nat
/// intersection(Int, Str) == Never
/// intersection(Obj, Int) == Int
/// intersection(Never, Int) == Never
/// ```
pub(crate) fn intersection(&self, lhs: &Type, rhs: &Type) -> Type {
if lhs == rhs {
return lhs.clone();
Expand Down Expand Up @@ -1687,12 +1671,9 @@ impl Context {
(_, Not(r)) => self.diff(lhs, r),
(Not(l), _) => self.diff(rhs, l),
// A and B and A == A and B
(other, and @ And(_, _)) | (and @ And(_, _), other) => {
self.intersection_add(and, other)
}
(other, and @ And(_)) | (and @ And(_), other) => self.intersection_add(and, other),
// (A or B) and C == (A and C) or (B and C)
(or_t @ Or(_, _), other) | (other, or_t @ Or(_, _)) => {
let ors = or_t.ors();
(Or(ors), other) | (other, Or(ors)) => {
if ors.iter().any(|t| t.has_unbound_var()) {
return self.simple_intersection(lhs, rhs);
}
Expand Down Expand Up @@ -1827,21 +1808,21 @@ impl Context {
fn narrow_type_by_pred(&self, t: Type, pred: &Predicate) -> Type {
match (t, pred) {
(
Type::Or(l, r),
Type::Or(tys),
Predicate::NotEqual {
rhs: TyParam::Value(v),
..
},
) if v.is_none() => {
let l = self.narrow_type_by_pred(*l, pred);
let r = self.narrow_type_by_pred(*r, pred);
if l.is_nonetype() {
r
} else if r.is_nonetype() {
l
} else {
or(l, r)
let mut new_tys = Set::new();
for ty in tys {
let ty = self.narrow_type_by_pred(ty, pred);
if ty.is_nonelike() {
continue;
}
new_tys.insert(ty);
}
Type::checked_or(new_tys)
}
(Type::Refinement(refine), _) => {
let t = self.narrow_type_by_pred(*refine.t, pred);
Expand Down Expand Up @@ -1983,8 +1964,12 @@ impl Context {
guard.target.clone(),
self.complement(&guard.to),
)),
Or(l, r) => self.intersection(&self.complement(l), &self.complement(r)),
And(l, r) => self.union(&self.complement(l), &self.complement(r)),
Or(ors) => ors
.iter()
.fold(Obj, |l, r| self.intersection(&l, &self.complement(r))),
And(ands) => ands
.iter()
.fold(Never, |l, r| self.union(&l, &self.complement(r))),
other => not(other.clone()),
}
}
Expand All @@ -2002,7 +1987,14 @@ impl Context {
match lhs {
Type::FreeVar(fv) if fv.is_linked() => self.diff(&fv.crack(), rhs),
// Type::And(l, r) => self.intersection(&self.diff(l, rhs), &self.diff(r, rhs)),
Type::Or(l, r) => self.union(&self.diff(l, rhs), &self.diff(r, rhs)),
Type::Or(tys) => {
let mut new_tys = vec![];
for ty in tys {
let diff = self.diff(ty, rhs);
new_tys.push(diff);
}
new_tys.into_iter().fold(Never, |l, r| self.union(&l, &r))
}
_ => lhs.clone(),
}
}
Expand Down
58 changes: 28 additions & 30 deletions crates/erg_compiler/context/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2285,44 +2285,42 @@ impl Context {
Err((t, errs))
}
}
Type::And(l, r) => {
let l = match self.eval_t_params(*l, level, t_loc) {
Ok(l) => l,
Err((l, es)) => {
errs.extend(es);
l
}
};
let r = match self.eval_t_params(*r, level, t_loc) {
Ok(r) => r,
Err((r, es)) => {
errs.extend(es);
r
Type::And(ands) => {
let mut new_ands = set! {};
for and in ands.into_iter() {
match self.eval_t_params(and, level, t_loc) {
Ok(and) => {
new_ands.insert(and);
}
Err((and, es)) => {
new_ands.insert(and);
errs.extend(es);
}
}
};
let intersec = self.intersection(&l, &r);
}
let intersec = new_ands
.into_iter()
.fold(Type::Obj, |l, r| self.intersection(&l, &r));
if errs.is_empty() {
Ok(intersec)
} else {
Err((intersec, errs))
}
}
Type::Or(l, r) => {
let l = match self.eval_t_params(*l, level, t_loc) {
Ok(l) => l,
Err((l, es)) => {
errs.extend(es);
l
}
};
let r = match self.eval_t_params(*r, level, t_loc) {
Ok(r) => r,
Err((r, es)) => {
errs.extend(es);
r
Type::Or(ors) => {
let mut new_ors = set! {};
for or in ors.into_iter() {
match self.eval_t_params(or, level, t_loc) {
Ok(or) => {
new_ors.insert(or);
}
Err((or, es)) => {
new_ors.insert(or);
errs.extend(es);
}
}
};
let union = self.union(&l, &r);
}
let union = new_ors.into_iter().fold(Never, |l, r| self.union(&l, &r));
if errs.is_empty() {
Ok(union)
} else {
Expand Down
Loading

0 comments on commit b0c3137

Please sign in to comment.