Skip to content

Commit

Permalink
feat: improve record type narrowing
Browse files Browse the repository at this point in the history
  • Loading branch information
mtshiba committed Mar 24, 2024
1 parent 35f55c6 commit ce5eafc
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 4 deletions.
4 changes: 2 additions & 2 deletions crates/erg_compiler/context/compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1304,7 +1304,7 @@ impl Context {
/// union(Array(Int, 2), Array(Str, 2)) == Array(Int or Str, 2)
/// union(Array(Int, 2), Array(Str, 3)) == Array(Int, 2) or Array(Int, 3)
/// union({ .a = Int }, { .a = Str }) == { .a = Int or Str }
/// union({ .a = Int }, { .a = Int; .b = Int }) == { .a = Int }
/// union({ .a = Int }, { .a = Int; .b = Int }) == { .a = Int } or { .a = Int; .b = Int } # not to lost `b` information
/// union((A and B) or C) == (A or C) and (B or C)
/// ```
pub(crate) fn union(&self, lhs: &Type, rhs: &Type) -> Type {
Expand All @@ -1326,7 +1326,7 @@ impl Context {
union
}
}
(Record(l), Record(r)) => {
(Record(l), Record(r)) if l.len() == r.len() && l.len() == 1 => {
let mut union = Dict::new();
for (l_k, l_v) in l.iter() {
if let Some((r_k, r_v)) = r.get_key_value(l_k) {
Expand Down
30 changes: 30 additions & 0 deletions crates/erg_compiler/context/inquire.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1030,6 +1030,36 @@ impl Context {
}
}
Type::Structural(t) => self.get_attr_info_from_attributive(t, ident, namespace),
// TODO: And
Type::Or(l, r) => {
let l_info = self.get_attr_info_from_attributive(l, ident, namespace);
let r_info = self.get_attr_info_from_attributive(r, ident, namespace);
match (l_info, r_info) {
(Triple::Ok(l), Triple::Ok(r)) => {
let res = self.union(&l.t, &r.t);
let vis = if l.vis.is_public() && r.vis.is_public() {
Visibility::DUMMY_PUBLIC
} else {
Visibility::DUMMY_PRIVATE
};
let vi = VarInfo::new(
res,
l.muty,
vis,
l.kind,
l.comptime_decos,
l.ctx,
l.py_name,
l.def_loc,
);
Triple::Ok(vi)
},
(Triple::Ok(_), Triple::Err(e))
| (Triple::Err(e), Triple::Ok(_)) => Triple::Err(e),
(Triple::Err(e1), Triple::Err(_e2)) => Triple::Err(e1),
_ => Triple::None,
}
}
_other => Triple::None,
}
}
Expand Down
10 changes: 10 additions & 0 deletions crates/erg_compiler/context/register.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2405,6 +2405,16 @@ impl Context {
if expr == target.as_ref() {
return Some(*guard.to.clone());
}
// { r.x in Int } => { r in Structural { .x = Int } }
else if let ast::Expr::Accessor(ast::Accessor::Attr(attr)) = target.as_ref() {
if attr.obj.as_ref() == expr {
let mut rec = Dict::new();
let vis = self.instantiate_vis_modifier(&attr.ident.vis).ok()?;
let field = Field::new(vis, attr.ident.inspect().clone());
rec.insert(field, *guard.to.clone());
return Some(Type::Record(rec).structuralize());
}
}
}
}
None
Expand Down
5 changes: 4 additions & 1 deletion crates/erg_compiler/lower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2954,7 +2954,10 @@ impl<A: ASTBuildable> GenericASTLowerer<A> {
if let Some(casted) = casted {
// e.g. casted == {x: Obj | x != None}, expr: Int or NoneType => intersec == Int
let intersec = self.module.context.intersection(expr.ref_t(), &casted);
if expr.ref_t().is_projection() || intersec != Type::Never {
// bad narrowing: C and Structural { foo = Foo }
if expr.ref_t().is_projection()
|| (intersec != Type::Never && intersec.ands().iter().all(|t| !t.is_structural()))
{
if let Some(ref_mut_t) = expr.ref_mut_t() {
*ref_mut_t = intersec;
}
Expand Down
3 changes: 3 additions & 0 deletions tests/should_err/record.er
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,6 @@ C = Class()

_ = { 1: C.new() }
_ = { C.new(): 1 } # ERR

ints_or_strs _: {.x = Int; .y = Int} or {.x = Str; .y = Str} = None
ints_or_strs({.x = 1; .y = "a"}) # ERR
12 changes: 12 additions & 0 deletions tests/should_ok/assert_cast.er
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,15 @@ opt_i as Int or NoneType = 1
rec = { .opt_i; }
if rec.opt_i != None, do:
assert rec.opt_i.abs() == 1

ints_or_strs(r: {.x = Int; .y = Int} or {.x = Str; .y = Str}): Nat =
if r.x in Int:
do: r.y.abs()
do: 0
assert ints_or_strs({.x = 1; .y = 2}) == 2

int_or_strs(rec: { .x = Int } or { .x = Str; .y = Str }): Str =
if rec.x in Str:
do: rec.y
do: str rec.x
assert int_or_strs({.x = 1}) == "1"
2 changes: 1 addition & 1 deletion tests/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ fn exec_record() -> Result<(), ()> {

#[test]
fn exec_record_err() -> Result<(), ()> {
expect_failure("tests/should_err/record.er", 0, 1)
expect_failure("tests/should_err/record.er", 0, 2)
}

#[test]
Expand Down

0 comments on commit ce5eafc

Please sign in to comment.