Skip to content

Commit

Permalink
fix: refinement type assert cast bug
Browse files Browse the repository at this point in the history
  • Loading branch information
mtshiba committed Feb 14, 2024
1 parent 121738d commit 13f303f
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 1 deletion.
8 changes: 8 additions & 0 deletions crates/erg_common/triple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,14 @@ impl<T> Triple<T, T> {
Triple::Ok(a) | Triple::Err(a) => Some(a),
}
}

pub fn merge_or(self, default: T) -> T {
match self {
Triple::None => default,
Triple::Ok(ok) => ok,
Triple::Err(err) => err,
}
}
}

impl<T, E: std::error::Error> Triple<T, E> {
Expand Down
4 changes: 3 additions & 1 deletion crates/erg_compiler/context/inquire.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3657,6 +3657,7 @@ impl Context {
/// ```erg
/// recover_typarams(Int, Nat) == Nat
/// recover_typarams(Array!(Int, _), Array(Nat, 2)) == Array!(Nat, 2)
/// recover_typarams(Str or NoneType, {"a", "b"}) == {"a", "b"}
/// ```
/// ```erg
/// # REVIEW: should be?
Expand All @@ -3667,7 +3668,8 @@ impl Context {
let is_never =
self.subtype_of(&intersec, &Type::Never) && guard.to.as_ref() != &Type::Never;
if !is_never {
return Ok(intersec);
let min = self.min(&intersec, &guard.to).merge_or(&intersec);
return Ok(min.clone());
}
if guard.to.is_monomorphic() {
if self.related(base, &guard.to) {
Expand Down
8 changes: 8 additions & 0 deletions crates/erg_compiler/tests/infer.er
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,11 @@ c_new x, y = C.new x, y
C = Class Int
C.
new x, y = Self x + y

val!() =
for! [{ "a": "b" }], (pkg as {Str: Str}) =>
x = pkg.get("a", "c")
assert x in {"b"}
val!::return x
"d"
val = val!()
3 changes: 3 additions & 0 deletions crates/erg_compiler/tests/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ fn _test_infer_types() -> Result<(), ()> {
let c_new_t = func2(add_r, r, c.clone()).quantify();
module.context.assert_var_type("c_new", &c_new_t)?;
module.context.assert_attr_type(&c, "new", &c_new_t)?;
module
.context
.assert_var_type("val", &v_enum(set! { "b".into(), "d".into() }))?;
Ok(())
}

Expand Down

0 comments on commit 13f303f

Please sign in to comment.