Skip to content

Commit

Permalink
fix: union type bug
Browse files Browse the repository at this point in the history
  • Loading branch information
mtshiba committed Sep 15, 2024
1 parent a0810ad commit 3b9bbdf
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 46 deletions.
6 changes: 4 additions & 2 deletions crates/erg_compiler/context/compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1778,13 +1778,15 @@ impl Context {
/// intersection_add(Int and ?T(:> NoneType), Str) == Never
/// ```
fn intersection_add(&self, intersection: &Type, elem: &Type) -> Type {
let ands = intersection.ands();
let mut ands = intersection.ands();
let bounded = ands.iter().map(|t| t.lower_bounded());
for t in bounded {
if self.subtype_of(&t, elem) {
return intersection.clone();
} else if self.supertype_of(&t, elem) {
return constructors::ands(ands.linear_exclude(&t).include(elem.clone()));
ands.retain(|ty| ty != &t);
ands.push(elem.clone());
return constructors::ands(ands);
}
}
and(intersection.clone(), elem.clone())
Expand Down
104 changes: 87 additions & 17 deletions crates/erg_compiler/context/unify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,27 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> {
}
Ok(())
}
(Or(l), Or(r)) | (And(l), And(r)) if l.len() == r.len() => {
(And(l), And(r)) if l.len() == r.len() => {
let mut r = r.clone();
for _ in 0..r.len() {
if l.iter()
.zip(r.iter())
.all(|(l, r)| self.occur(l, r).is_ok())
{
return Ok(());
}
r.rotate_left(1);
}
Err(TyCheckErrors::from(TyCheckError::subtyping_error(
self.ctx.cfg.input.clone(),
line!() as usize,
maybe_sub,
maybe_sup,
self.loc.loc(),
self.ctx.caused_by(),
)))
}
(Or(l), Or(r)) if l.len() == r.len() => {
let l = l.to_vec();
let mut r = r.to_vec();
for _ in 0..r.len() {
Expand All @@ -176,13 +196,25 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> {
self.ctx.caused_by(),
)))
}
(lhs, Or(tys)) | (lhs, And(tys)) => {
(lhs, And(tys)) => {
for ty in tys.iter() {
self.occur(lhs, ty)?;
}
Ok(())
}
(lhs, Or(tys)) => {
for ty in tys.iter() {
self.occur(lhs, ty)?;
}
Ok(())
}
(Or(tys), rhs) | (And(tys), rhs) => {
(And(tys), rhs) => {
for ty in tys.iter() {
self.occur(ty, rhs)?;
}
Ok(())
}
(Or(tys), rhs) => {
for ty in tys.iter() {
self.occur(ty, rhs)?;
}
Expand Down Expand Up @@ -287,13 +319,25 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> {
}
Ok(())
}
(lhs, Or(tys)) | (lhs, And(tys)) => {
(lhs, And(tys)) => {
for ty in tys.iter() {
self.occur_inner(lhs, ty)?;
}
Ok(())
}
(Or(tys), rhs) | (And(tys), rhs) => {
(lhs, Or(tys)) => {
for ty in tys.iter() {
self.occur_inner(lhs, ty)?;
}
Ok(())
}
(And(tys), rhs) => {
for ty in tys.iter() {
self.occur_inner(ty, rhs)?;
}
Ok(())
}
(Or(tys), rhs) => {
for ty in tys.iter() {
self.occur_inner(ty, rhs)?;
}
Expand Down Expand Up @@ -1208,24 +1252,52 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> {
// self.sub_unify(&lsub, &union, loc, param_name)?;
maybe_sup.update_tyvar(union, intersec, self.undoable, false);
}
(And(ltys), And(rtys)) => {
let mut rtys = rtys.clone();
for _ in 0..rtys.len() {
if ltys
.iter()
.zip(rtys.iter())
.all(|(l, r)| self.ctx.subtype_of(l, r))
{
for (l, r) in ltys.iter().zip(rtys.iter()) {
self.sub_unify(l, r)?;
}
return Ok(());
}
rtys.rotate_left(1);
}
return Err(TyCheckErrors::from(TyCheckError::type_mismatch_error(
self.ctx.cfg.input.clone(),
line!() as usize,
self.loc.loc(),
self.ctx.caused_by(),
self.param_name.as_ref().unwrap_or(&Str::ever("_")),
None,
maybe_sup,
maybe_sub,
self.ctx.get_candidates(maybe_sub),
self.ctx.get_simple_type_mismatch_hint(maybe_sup, maybe_sub),
)));
}
// (Int or ?T) <: (?U or Int)
// OK: (Int <: Int); (?T <: ?U)
// NG: (Int <: ?U); (?T <: Int)
(Or(ltys), Or(rtys)) | (And(ltys), And(rtys)) => {
let lvars = ltys.to_vec();
let mut rvars = rtys.to_vec();
for _ in 0..rvars.len() {
if lvars
(Or(ltys), Or(rtys)) => {
let ltys = ltys.to_vec();
let mut rtys = rtys.to_vec();
for _ in 0..rtys.len() {
if ltys
.iter()
.zip(rvars.iter())
.zip(rtys.iter())
.all(|(l, r)| self.ctx.subtype_of(l, r))
{
for (l, r) in ltys.iter().zip(rtys.iter()) {
self.sub_unify(l, r)?;
}
break;
return Ok(());
}
rvars.rotate_left(1);
rtys.rotate_left(1);
}
return Err(TyCheckErrors::from(TyCheckError::type_mismatch_error(
self.ctx.cfg.input.clone(),
Expand Down Expand Up @@ -1625,8 +1697,7 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> {
(And(tys), _) => {
for ty in tys {
if self.ctx.subtype_of(ty, maybe_sup) {
self.sub_unify(ty, maybe_sup)?;
break;
return self.sub_unify(ty, maybe_sup);
}
}
self.sub_unify(tys.iter().next().unwrap(), maybe_sup)?;
Expand All @@ -1635,8 +1706,7 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> {
(_, Or(tys)) => {
for ty in tys {
if self.ctx.subtype_of(maybe_sub, ty) {
self.sub_unify(maybe_sub, ty)?;
break;
return self.sub_unify(maybe_sub, ty);
}
}
self.sub_unify(maybe_sub, tys.iter().next().unwrap())?;
Expand Down
Loading

0 comments on commit 3b9bbdf

Please sign in to comment.