Skip to content

Commit

Permalink
fix: sub-unification bug
Browse files Browse the repository at this point in the history
  • Loading branch information
mtshiba committed Sep 17, 2024
1 parent 1f51d18 commit df837d7
Show file tree
Hide file tree
Showing 7 changed files with 82 additions and 47 deletions.
1 change: 1 addition & 0 deletions crates/erg_compiler/context/compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -816,6 +816,7 @@ impl Context {
}
// Int or Str :> Str or Int == (Int :> Str && Str :> Int) || (Int :> Int && Str :> Str) == true
// Int or Str or NoneType :> Str or Int
// Int or Str or NoneType :> Str or NoneType or Nat
(Or(l), Or(r)) => r.iter().all(|r| l.iter().any(|l| self.supertype_of(l, r))),
// not Nat :> not Int == true
(Not(l), Not(r)) => self.subtype_of(l, r),
Expand Down
88 changes: 41 additions & 47 deletions crates/erg_compiler/context/unify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> {
}
Ok(())
}
// FIXME: This is not correct, we must visit all permutations of the types
(And(l), And(r)) if l.len() == r.len() => {
let mut r = r.clone();
for _ in 0..r.len() {
Expand Down Expand Up @@ -1297,65 +1298,58 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> {
// self.sub_unify(&lsub, &union, loc, param_name)?;
maybe_sup.update_tyvar(union, intersec, self.undoable, false);
}
// TODO: Preferentially compare same-structure types (e.g. K(?T) <: K(?U))
(And(ltys), And(rtys)) => {
let mut rtys = rtys.clone();
for _ in 0..rtys.len() {
if ltys
.iter()
.zip(rtys.iter())
.all(|(l, r)| self.ctx.subtype_of(l, r))
{
for (l, r) in ltys.iter().zip(rtys.iter()) {
self.sub_unify(l, r)?;
let mut ltys_ = ltys.clone();
let mut rtys_ = rtys.clone();
// Show and EqHash and T <: Eq and Show and Ord
// => EqHash and T <: Eq and Ord
for lty in ltys.iter() {
if let Some(idx) = rtys_.iter().position(|r| r == lty) {
rtys_.remove(idx);
let idx = ltys_.iter().position(|l| l == lty).unwrap();
ltys_.remove(idx);
}
}
// EqHash and T <: Eq and Ord
for lty in ltys_.iter() {
// lty: EqHash
// rty: Eq, Ord
for rty in rtys_.iter() {
if self.ctx.subtype_of(lty, rty) {
self.sub_unify(lty, rty)?;
continue;
}
return Ok(());
}
rtys.rotate_left(1);
}
return Err(TyCheckErrors::from(TyCheckError::type_mismatch_error(
self.ctx.cfg.input.clone(),
line!() as usize,
self.loc.loc(),
self.ctx.caused_by(),
self.param_name.as_ref().unwrap_or(&Str::ever("_")),
None,
maybe_sup,
maybe_sub,
self.ctx.get_candidates(maybe_sub),
self.ctx.get_simple_type_mismatch_hint(maybe_sup, maybe_sub),
)));
}
// TODO: Preferentially compare same-structure types (e.g. K(?T) <: K(?U))
// Nat or Str or NoneType <: NoneType or ?T or Int
// => Str <: ?T
// (Int or ?T) <: (?U or Int)
// OK: (Int <: Int); (?T <: ?U)
// NG: (Int <: ?U); (?T <: Int)
(Or(ltys), Or(rtys)) => {
let ltys = ltys.to_vec();
let mut rtys = rtys.to_vec();
for _ in 0..rtys.len() {
if ltys
.iter()
.zip(rtys.iter())
.all(|(l, r)| self.ctx.subtype_of(l, r))
{
for (l, r) in ltys.iter().zip(rtys.iter()) {
self.sub_unify(l, r)?;
let mut ltys_ = ltys.clone();
let mut rtys_ = rtys.clone();
// Nat or T or Str <: Str or Int or NoneType
// => Nat or T <: Int or NoneType
for lty in ltys {
if rtys_.linear_remove(lty) {
ltys_.linear_remove(lty);
}
}
// Nat or T <: Int or NoneType
for lty in ltys_.iter() {
// lty: Nat
// rty: Int, NoneType
for rty in rtys_.iter() {
if self.ctx.subtype_of(lty, rty) {
self.sub_unify(lty, rty)?;
continue;
}
return Ok(());
}
rtys.rotate_left(1);
}
return Err(TyCheckErrors::from(TyCheckError::type_mismatch_error(
self.ctx.cfg.input.clone(),
line!() as usize,
self.loc.loc(),
self.ctx.caused_by(),
self.param_name.as_ref().unwrap_or(&Str::ever("_")),
None,
maybe_sup,
maybe_sub,
self.ctx.get_candidates(maybe_sub),
self.ctx.get_simple_type_mismatch_hint(maybe_sup, maybe_sub),
)));
}
// NG: Nat <: ?T or Int ==> Nat or Int (?T = Nat)
// OK: Nat <: ?T or Int ==> ?T or Int
Expand Down
3 changes: 3 additions & 0 deletions tests/should_err/and.er
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
a as Eq and Hash and Show and Add(Str) = "a"
f _: Ord and Eq and Show and Hash = None
f a # ERR
3 changes: 3 additions & 0 deletions tests/should_err/or.er
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
a as Int or Str or NoneType = 1
f _: Nat or NoneType or Str = None
f a # ERR
7 changes: 7 additions & 0 deletions tests/should_ok/and.er
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
a as Eq and Hash and Show = 1
f _: Eq and Show and Hash = None
f a

b as Eq and Hash and Ord and Show = 1
g _: Ord and Eq and Show and Hash = None
g b
7 changes: 7 additions & 0 deletions tests/should_ok/or.er
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
a as Nat or Str or NoneType = 1
f _: Int or NoneType or Str = None
f a

b as Nat or Str or NoneType or Bool = 1
g _: Int or NoneType or Bool or Str = None
g b
20 changes: 20 additions & 0 deletions tests/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ fn exec_advanced_type_spec() -> Result<(), ()> {
expect_success("tests/should_ok/advanced_type_spec.er", 5)
}

#[test]
fn exec_and() -> Result<(), ()> {
expect_success("tests/should_ok/and.er", 0)
}

#[test]
fn exec_args_expansion() -> Result<(), ()> {
expect_success("tests/should_ok/args_expansion.er", 0)
Expand Down Expand Up @@ -327,6 +332,11 @@ fn exec_operators() -> Result<(), ()> {
expect_success("tests/should_ok/operators.er", 0)
}

#[test]
fn exec_or() -> Result<(), ()> {
expect_success("tests/should_ok/or.er", 0)
}

#[test]
fn exec_patch() -> Result<(), ()> {
expect_success("examples/patch.er", 0)
Expand Down Expand Up @@ -527,6 +537,11 @@ fn exec_list_member_err() -> Result<(), ()> {
expect_failure("tests/should_err/list_member.er", 0, 3)
}

#[test]
fn exec_and_err() -> Result<(), ()> {
expect_failure("tests/should_err/and.er", 0, 1)
}

#[test]
fn exec_as() -> Result<(), ()> {
expect_failure("tests/should_err/as.er", 0, 6)
Expand Down Expand Up @@ -634,6 +649,11 @@ fn exec_move_check() -> Result<(), ()> {
expect_failure("examples/move_check.er", 1, 1)
}

#[test]
fn exec_or_err() -> Result<(), ()> {
expect_failure("tests/should_err/or.er", 0, 1)
}

#[test]
fn exec_poly_type_spec_err() -> Result<(), ()> {
expect_failure("tests/should_err/poly_type_spec.er", 0, 3)
Expand Down

0 comments on commit df837d7

Please sign in to comment.