Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change And/Or-type structures #521

Merged
merged 14 commits into from
Sep 18, 2024
28 changes: 17 additions & 11 deletions crates/els/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,24 +42,30 @@ fn comp_item_kind(t: &Type, muty: Mutability) -> CompletionItemKind {
Type::Subr(_) | Type::Quantified(_) => CompletionItemKind::FUNCTION,
Type::ClassType => CompletionItemKind::CLASS,
Type::TraitType => CompletionItemKind::INTERFACE,
Type::Or(l, r) => {
let l = comp_item_kind(l, muty);
let r = comp_item_kind(r, muty);
if l == r {
l
Type::Or(tys) => {
let fst = comp_item_kind(tys.iter().next().unwrap(), muty);
if tys
.iter()
.map(|t| comp_item_kind(t, muty))
.all(|k| k == fst)
{
fst
} else if muty.is_const() {
CompletionItemKind::CONSTANT
} else {
CompletionItemKind::VARIABLE
}
}
Type::And(l, r) => {
let l = comp_item_kind(l, muty);
let r = comp_item_kind(r, muty);
if l == CompletionItemKind::VARIABLE {
r
Type::And(tys) => {
for k in tys.iter().map(|t| comp_item_kind(t, muty)) {
if k != CompletionItemKind::VARIABLE {
return k;
}
}
if muty.is_const() {
CompletionItemKind::CONSTANT
} else {
l
CompletionItemKind::VARIABLE
}
}
Type::Refinement(r) => comp_item_kind(&r.t, muty),
Expand Down
12 changes: 12 additions & 0 deletions crates/erg_common/dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,18 @@ impl<K, V> Dict<K, V> {
}
}

/// ```
/// # use erg_common::dict;
/// # use erg_common::dict::Dict;
/// let mut dict = Dict::with_capacity(3);
/// assert_eq!(dict.capacity(), 3);
/// dict.insert("a", 1);
/// assert_eq!(dict.capacity(), 3);
/// dict.insert("b", 2);
/// dict.insert("c", 3);
/// dict.insert("d", 4);
/// assert_ne!(dict.capacity(), 3);
/// ```
pub fn with_capacity(capacity: usize) -> Self {
Self {
dict: FxHashMap::with_capacity_and_hasher(capacity, Default::default()),
Expand Down
11 changes: 11 additions & 0 deletions crates/erg_common/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,17 @@ impl RecursionCounter {

#[macro_export]
macro_rules! set_recursion_limit {
(panic, $msg:expr, $limit:expr) => {
use std::sync::atomic::AtomicU32;

static COUNTER: AtomicU32 = AtomicU32::new($limit);

let counter = $crate::macros::RecursionCounter::new(&COUNTER);
if counter.limit_reached() {
$crate::log!(err "Recursion limit reached");
panic!($msg);
}
};
($returns:expr, $limit:expr) => {
use std::sync::atomic::AtomicU32;

Expand Down
14 changes: 14 additions & 0 deletions crates/erg_common/set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,20 @@ impl<T: Hash + Eq + Clone> Set<T> {
self.insert(other);
self
}

/// ```
/// # use erg_common::set;
/// assert_eq!(set!{1, 2}.product(&set!{3, 4}), set!{(&1, &3), (&1, &4), (&2, &3), (&2, &4)});
/// ```
pub fn product<'l, 'r, U: Hash + Eq>(&'l self, other: &'r Set<U>) -> Set<(&'l T, &'r U)> {
let mut res = set! {};
for x in self.iter() {
for y in other.iter() {
res.insert((x, y));
}
}
res
}
}

impl<T: Hash + Ord> Set<T> {
Expand Down
2 changes: 2 additions & 0 deletions crates/erg_common/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1407,6 +1407,8 @@ impl<T: Immutable + ?Sized> Immutable for &T {}
impl<T: Immutable> Immutable for Option<T> {}
impl<T: Immutable> Immutable for Vec<T> {}
impl<T: Immutable> Immutable for [T] {}
impl<T: Immutable, U: Immutable> Immutable for (T, U) {}
impl<T: Immutable, U: Immutable, V: Immutable> Immutable for (T, U, V) {}
impl<T: Immutable + ?Sized> Immutable for Box<T> {}
impl<T: Immutable + ?Sized> Immutable for std::rc::Rc<T> {}
impl<T: Immutable + ?Sized> Immutable for std::sync::Arc<T> {}
10 changes: 10 additions & 0 deletions crates/erg_common/triple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,16 @@ impl<T: fmt::Display, E: fmt::Display> fmt::Display for Triple<T, E> {
}

impl<T, E> Triple<T, E> {
pub const fn is_ok(&self) -> bool {
matches!(self, Triple::Ok(_))
}
pub const fn is_err(&self) -> bool {
matches!(self, Triple::Err(_))
}
pub const fn is_none(&self) -> bool {
matches!(self, Triple::None)
}

pub fn none_then(self, err: E) -> Result<T, E> {
match self {
Triple::None => Err(err),
Expand Down
127 changes: 69 additions & 58 deletions crates/erg_compiler/context/compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -763,9 +763,7 @@ impl Context {
self.structural_supertype_of(&l, rhs)
}
// {1, 2, 3} :> {1} or {2, 3} == true
(Refinement(_refine), Or(l, r)) => {
self.supertype_of(lhs, l) && self.supertype_of(lhs, r)
}
(Refinement(_refine), Or(tys)) => tys.iter().all(|ty| self.supertype_of(lhs, ty)),
// ({I: Int | True} :> Int) == true
// {N: Nat | ...} :> Int) == false
// ({I: Int | I >= 0} :> Int) == false
Expand Down Expand Up @@ -817,41 +815,37 @@ impl Context {
self.sub_unify(&inst, l, &(), None).is_ok()
}
// Int or Str :> Str or Int == (Int :> Str && Str :> Int) || (Int :> Int && Str :> Str) == true
(Or(l_1, l_2), Or(r_1, r_2)) => {
if l_1.is_union_type() && self.supertype_of(l_1, rhs) {
return true;
}
if l_2.is_union_type() && self.supertype_of(l_2, rhs) {
return true;
}
(self.supertype_of(l_1, r_1) && self.supertype_of(l_2, r_2))
|| (self.supertype_of(l_1, r_2) && self.supertype_of(l_2, r_1))
}
// Int or Str or NoneType :> Str or Int
// Int or Str or NoneType :> Str or NoneType or Nat
(Or(l), Or(r)) => r.iter().all(|r| l.iter().any(|l| self.supertype_of(l, r))),
// not Nat :> not Int == true
(Not(l), Not(r)) => self.subtype_of(l, r),
// (Int or Str) :> Nat == Int :> Nat || Str :> Nat == true
// (Num or Show) :> Show == Num :> Show || Show :> Num == true
(Or(l_or, r_or), rhs) => self.supertype_of(l_or, rhs) || self.supertype_of(r_or, rhs),
(Or(ors), rhs) => ors.iter().any(|or| self.supertype_of(or, rhs)),
// Int :> (Nat or Str) == Int :> Nat && Int :> Str == false
(lhs, Or(l_or, r_or)) => self.supertype_of(lhs, l_or) && self.supertype_of(lhs, r_or),
(And(l_1, l_2), And(r_1, r_2)) => {
if l_1.is_intersection_type() && self.supertype_of(l_1, rhs) {
(lhs, Or(ors)) => ors.iter().all(|or| self.supertype_of(lhs, or)),
// Hash and Eq :> HashEq and ... == true
// Add(T) and Eq :> Add(Int) and Eq == true
(And(l), And(r)) => {
if r.iter().any(|r| l.iter().all(|l| self.supertype_of(l, r))) {
return true;
}
if l_2.is_intersection_type() && self.supertype_of(l_2, rhs) {
return true;
if l.len() == r.len() {
let mut r = r.clone();
for _ in 1..l.len() {
if l.iter().zip(&r).all(|(l, r)| self.supertype_of(l, r)) {
return true;
}
r.rotate_left(1);
}
}
(self.supertype_of(l_1, r_1) && self.supertype_of(l_2, r_2))
|| (self.supertype_of(l_1, r_2) && self.supertype_of(l_2, r_1))
false
}
// (Num and Show) :> Show == false
(And(l_and, r_and), rhs) => {
self.supertype_of(l_and, rhs) && self.supertype_of(r_and, rhs)
}
(And(ands), rhs) => ands.iter().all(|and| self.supertype_of(and, rhs)),
// Show :> (Num and Show) == true
(lhs, And(l_and, r_and)) => {
self.supertype_of(lhs, l_and) || self.supertype_of(lhs, r_and)
}
(lhs, And(ands)) => ands.iter().any(|and| self.supertype_of(lhs, and)),
// Not(Eq) :> Float == !(Eq :> Float) == true
(Not(_), Obj) => false,
(Not(l), rhs) => !self.supertype_of(l, rhs),
Expand Down Expand Up @@ -923,18 +917,18 @@ impl Context {
Type::NamedTuple(fields) => fields.iter().cloned().collect(),
Type::Refinement(refine) => self.fields(&refine.t),
Type::Structural(t) => self.fields(t),
Type::Or(l, r) => {
let l_fields = self.fields(l);
let r_fields = self.fields(r);
let l_field_names = l_fields.keys().collect::<Set<_>>();
let r_field_names = r_fields.keys().collect::<Set<_>>();
let field_names = l_field_names.intersection(&r_field_names);
Type::Or(tys) => {
let or_fields = tys.iter().map(|t| self.fields(t)).collect::<Set<_>>();
let field_names = or_fields
.iter()
.flat_map(|fs| fs.keys())
.collect::<Set<_>>();
let mut fields = Dict::new();
for (name, l_t, r_t) in field_names
for (name, tys) in field_names
.iter()
.map(|&name| (name, &l_fields[name], &r_fields[name]))
.map(|&name| (name, or_fields.iter().filter_map(|fields| fields.get(name))))
{
let union = self.union(l_t, r_t);
let union = tys.fold(Never, |acc, ty| self.union(&acc, ty));
fields.insert(name.clone(), union);
}
fields
Expand Down Expand Up @@ -1417,6 +1411,8 @@ impl Context {
/// union({ .a = Int }, { .a = Str }) == { .a = Int or Str }
/// union({ .a = Int }, { .a = Int; .b = Int }) == { .a = Int } or { .a = Int; .b = Int } # not to lost `b` information
/// union((A and B) or C) == (A or C) and (B or C)
/// union(Never, Int) == Int
/// union(Obj, Int) == Obj
/// ```
pub(crate) fn union(&self, lhs: &Type, rhs: &Type) -> Type {
if lhs == rhs {
Expand Down Expand Up @@ -1479,10 +1475,9 @@ impl Context {
(Some(sub), Some(sup)) => bounded(sub.clone(), sup.clone()),
_ => self.simple_union(lhs, rhs),
},
(other, or @ Or(_, _)) | (or @ Or(_, _), other) => self.union_add(or, other),
(other, or @ Or(_)) | (or @ Or(_), other) => self.union_add(or, other),
// (A and B) or C ==> (A or C) and (B or C)
(and_t @ And(_, _), other) | (other, and_t @ And(_, _)) => {
let ands = and_t.ands();
(And(ands), other) | (other, And(ands)) => {
let mut t = Type::Obj;
for branch in ands.iter() {
let union = self.union(branch, other);
Expand Down Expand Up @@ -1666,6 +1661,12 @@ impl Context {

/// Returns intersection of two types (`A and B`).
/// If `A` and `B` have a subtype relationship, it is equal to `min(A, B)`.
/// ```erg
/// intersection(Nat, Int) == Nat
/// intersection(Int, Str) == Never
/// intersection(Obj, Int) == Int
/// intersection(Never, Int) == Never
/// ```
pub(crate) fn intersection(&self, lhs: &Type, rhs: &Type) -> Type {
if lhs == rhs {
return lhs.clone();
Expand Down Expand Up @@ -1696,12 +1697,9 @@ impl Context {
(_, Not(r)) => self.diff(lhs, r),
(Not(l), _) => self.diff(rhs, l),
// A and B and A == A and B
(other, and @ And(_, _)) | (and @ And(_, _), other) => {
self.intersection_add(and, other)
}
(other, and @ And(_)) | (and @ And(_), other) => self.intersection_add(and, other),
// (A or B) and C == (A and C) or (B and C)
(or_t @ Or(_, _), other) | (other, or_t @ Or(_, _)) => {
let ors = or_t.ors();
(Or(ors), other) | (other, Or(ors)) => {
if ors.iter().any(|t| t.has_unbound_var()) {
return self.simple_intersection(lhs, rhs);
}
Expand Down Expand Up @@ -1797,13 +1795,15 @@ impl Context {
/// intersection_add(Int and ?T(:> NoneType), Str) == Never
/// ```
fn intersection_add(&self, intersection: &Type, elem: &Type) -> Type {
let ands = intersection.ands();
let mut ands = intersection.ands();
let bounded = ands.iter().map(|t| t.lower_bounded());
for t in bounded {
if self.subtype_of(&t, elem) {
return intersection.clone();
} else if self.supertype_of(&t, elem) {
return constructors::ands(ands.linear_exclude(&t).include(elem.clone()));
ands.retain(|ty| ty != &t);
ands.push(elem.clone());
return constructors::ands(ands);
}
}
and(intersection.clone(), elem.clone())
Expand Down Expand Up @@ -1836,21 +1836,21 @@ impl Context {
fn narrow_type_by_pred(&self, t: Type, pred: &Predicate) -> Type {
match (t, pred) {
(
Type::Or(l, r),
Type::Or(tys),
Predicate::NotEqual {
rhs: TyParam::Value(v),
..
},
) if v.is_none() => {
let l = self.narrow_type_by_pred(*l, pred);
let r = self.narrow_type_by_pred(*r, pred);
if l.is_nonetype() {
r
} else if r.is_nonetype() {
l
} else {
or(l, r)
let mut new_tys = Set::new();
for ty in tys {
let ty = self.narrow_type_by_pred(ty, pred);
if ty.is_nonelike() {
continue;
}
new_tys.insert(ty);
}
Type::checked_or(new_tys)
}
(Type::Refinement(refine), _) => {
let t = self.narrow_type_by_pred(*refine.t, pred);
Expand Down Expand Up @@ -1992,8 +1992,12 @@ impl Context {
guard.target.clone(),
self.complement(&guard.to),
)),
Or(l, r) => self.intersection(&self.complement(l), &self.complement(r)),
And(l, r) => self.union(&self.complement(l), &self.complement(r)),
Or(ors) => ors
.iter()
.fold(Obj, |l, r| self.intersection(&l, &self.complement(r))),
And(ands) => ands
.iter()
.fold(Never, |l, r| self.union(&l, &self.complement(r))),
other => not(other.clone()),
}
}
Expand All @@ -2011,7 +2015,14 @@ impl Context {
match lhs {
Type::FreeVar(fv) if fv.is_linked() => self.diff(&fv.crack(), rhs),
// Type::And(l, r) => self.intersection(&self.diff(l, rhs), &self.diff(r, rhs)),
Type::Or(l, r) => self.union(&self.diff(l, rhs), &self.diff(r, rhs)),
Type::Or(tys) => {
let mut new_tys = vec![];
for ty in tys {
let diff = self.diff(ty, rhs);
new_tys.push(diff);
}
new_tys.into_iter().fold(Never, |l, r| self.union(&l, &r))
}
_ => lhs.clone(),
}
}
Expand Down
Loading
Loading