Skip to content

Commit

Permalink
fix: check subtype field types mismatch
Browse files Browse the repository at this point in the history
  • Loading branch information
mtshiba committed Oct 1, 2024
1 parent be428cf commit a490811
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 6 deletions.
10 changes: 10 additions & 0 deletions crates/erg_common/dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,16 @@ impl<K: Hash + Eq + Immutable, V> Dict<K, V> {
{
self.dict.remove_entry(k)
}

pub fn remove_entries<'q, Q>(&mut self, keys: impl IntoIterator<Item = &'q Q>)
where
K: Borrow<Q>,
Q: Hash + Eq + ?Sized + 'q,
{
for k in keys {
self.remove_entry(k);
}
}
}

impl<K: Hash + Eq, V> Dict<K, V> {
Expand Down
16 changes: 15 additions & 1 deletion crates/erg_compiler/context/compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1736,7 +1736,21 @@ impl Context {
}
}
// {.i = Int} and {.s = Str} == {.i = Int; .s = Str}
(Record(l), Record(r)) => Type::Record(l.clone().concat(r.clone())),
// {.i = Int} and {.i = Nat} == {.i = Nat}
// {i = Int} and {.i = Int} == {.i = Nat}
(Record(l), Record(r)) => {
let mut new = l.clone().concat(r.clone());
let duplicates = l.keys().filter(|&k| r.contains_key(k)).collect::<Set<_>>();
for k in duplicates {
let v = &mut new[k];
*v = self.intersection(&l[k], &r[k]);
let (k, t) = new.get_key_value(k).unwrap();
if k.vis.is_private() {
new.insert(Field::public(k.symbol.clone()), t.clone());
}
}
Type::Record(new)
}
// {i = Int; j = Int} and not {i = Int} == {j = Int}
// not {i = Int} and {i = Int; j = Int} == {j = Int}
(other @ Record(rec), Not(t)) | (Not(t), other @ Record(rec)) => match t.as_ref() {
Expand Down
100 changes: 95 additions & 5 deletions crates/erg_compiler/context/register.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2075,6 +2075,22 @@ impl Context {
TypeObj::Builtin { t, .. } => Some(t),
TypeObj::Generated(t) => t.base_or_sup().map(|t| t.typ()),
};
let invalid_fields = if let Some(TypeObj::Builtin {
t: Type::Record(rec),
..
}) = gen.additional()
{
if let Err((fields, es)) =
self.check_subtype_instance_attrs(sup.typ(), rec, call)
{
errs.extend(es);
fields
} else {
Set::new()
}
} else {
Set::new()
};
// `Super.Requirement := {x = Int}` and `Self.Additional := {y = Int}`
// => `Self.Requirement := {x = Int; y = Int}`
let call_t = {
Expand All @@ -2089,8 +2105,15 @@ impl Context {
errs.extend(es);
}
}
let param_t = if let Some(Type::Record(rec)) = param_t {
let mut rec = rec.clone();
rec.remove_entries(&invalid_fields);
Some(Type::Record(rec))
} else {
param_t.cloned()
};
let nd_params = param_t
.map(|t| self.intersection(t, additional.typ()))
.map(|pt| self.intersection(&pt, additional.typ()))
.or(Some(additional.typ().clone()))
.map_or(vec![], |t| vec![ParamTy::Pos(t)]);
(nd_params, None, vec![], None)
Expand Down Expand Up @@ -2121,8 +2144,15 @@ impl Context {
let (nd_params, var_params, d_params, kw_var_params) = if let Some(additional) =
gen.additional()
{
let param_t = if let Some(Type::Record(rec)) = param_t {
let mut rec = rec.clone();
rec.remove_entries(&invalid_fields);
Some(Type::Record(rec))
} else {
param_t.cloned()
};
let nd_params = param_t
.map(|t| self.intersection(t, additional.typ()))
.map(|pt| self.intersection(&pt, additional.typ()))
.or(Some(additional.typ().clone()))
.map_or(vec![], |t| vec![ParamTy::Pos(t)]);
(nd_params, None, vec![], None)
Expand Down Expand Up @@ -2209,12 +2239,68 @@ impl Context {
}
}

fn check_subtype_instance_attrs(
&self,
sup: &Type,
rec: &Dict<Field, Type>,
call: Option<&ast::Call>,
) -> Result<(), (Set<Field>, CompileErrors)> {
let mut errs = CompileErrors::empty();
let mut invalid_fields = Set::new();
let sup_ctx = self.get_nominal_type_ctx(sup);
let additional = call.and_then(|call| {
if let Some(ast::Expr::Record(record)) = call.args.get_with_key("Additional") {
Some(record)
} else {
None
}
});
for (field, sub_t) in rec.iter() {
let loc = additional
.as_ref()
.and_then(|record| {
record
.keys()
.iter()
.find(|id| id.inspect() == &field.symbol)
.map(|name| name.loc())
})
.unwrap_or_default();
let varname = VarName::from_str(field.symbol.clone());
if let Some(sup_ctx) = sup_ctx {
if let Some(sup_vi) = sup_ctx.decls.get(&varname) {
if !self.subtype_of(sub_t, &sup_vi.t) {
invalid_fields.insert(field.clone());
errs.push(CompileError::type_mismatch_error(
self.cfg.input.clone(),
line!() as usize,
loc,
self.caused_by(),
&field.symbol,
None,
&sup_vi.t,
sub_t,
None,
None,
));
}
}
}
}
if errs.is_empty() {
Ok(())
} else {
Err((invalid_fields, errs))
}
}

fn register_instance_attrs(
&self,
ctx: &mut Context,
rec: &Dict<Field, Type>,
call: Option<&ast::Call>,
) -> CompileResult<()> {
let mut errs = CompileErrors::empty();
let record = call.and_then(|call| {
if let Some(ast::Expr::Record(record)) = call
.args
Expand Down Expand Up @@ -2248,16 +2334,20 @@ impl Context {
);
// self.index().register(&vi);
if let Some(_ent) = ctx.decls.insert(varname.clone(), vi) {
return Err(CompileErrors::from(CompileError::duplicate_decl_error(
errs.push(CompileError::duplicate_decl_error(
self.cfg.input.clone(),
line!() as usize,
varname.loc(),
self.caused_by(),
varname.inspect(),
)));
));
}
}
Ok(())
if errs.is_empty() {
Ok(())
} else {
Err(errs)
}
}

fn gen_class_new_method(
Expand Down
9 changes: 9 additions & 0 deletions tests/should_err/inherit.er
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
@Inheritable
Base = Class {
.value = Str;
}

C = Inherit Base, Additional := {
.value = Int # ERR
}
_ = C.new { .value = 10 }
10 changes: 10 additions & 0 deletions tests/should_ok/inherit.er
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,13 @@ D.
d: D = D.new(1)
print! d
assert D.const == 0

@Inheritable
Base = Class {
.value = Int
}
E = Inherit Base, Additional := {
.value = Nat
}
e = E.new { .value = 10 }
discard e.value.times!
5 changes: 5 additions & 0 deletions tests/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,11 @@ fn exec_infer_union_array() -> Result<(), ()> {
expect_failure("tests/should_err/infer_union_array.er", 2, 1)
}

#[test]
fn exec_inherit_err() -> Result<(), ()> {
expect_failure("tests/should_err/inherit.er", 0, 1)
}

#[test]
fn exec_init_del_err() -> Result<(), ()> {
expect_failure("tests/should_err/init_del.er", 0, 1)
Expand Down

0 comments on commit a490811

Please sign in to comment.