diff --git a/crates/erg_common/dict.rs b/crates/erg_common/dict.rs index ad62cf597..2e9cd8900 100644 --- a/crates/erg_common/dict.rs +++ b/crates/erg_common/dict.rs @@ -329,6 +329,16 @@ impl Dict { { self.dict.remove_entry(k) } + + pub fn remove_entries<'q, Q>(&mut self, keys: impl IntoIterator) + where + K: Borrow, + Q: Hash + Eq + ?Sized + 'q, + { + for k in keys { + self.remove_entry(k); + } + } } impl Dict { diff --git a/crates/erg_compiler/context/compare.rs b/crates/erg_compiler/context/compare.rs index d5ebb611e..9c8879f8f 100644 --- a/crates/erg_compiler/context/compare.rs +++ b/crates/erg_compiler/context/compare.rs @@ -1736,7 +1736,21 @@ impl Context { } } // {.i = Int} and {.s = Str} == {.i = Int; .s = Str} - (Record(l), Record(r)) => Type::Record(l.clone().concat(r.clone())), + // {.i = Int} and {.i = Nat} == {.i = Nat} + // {i = Int} and {.i = Int} == {.i = Nat} + (Record(l), Record(r)) => { + let mut new = l.clone().concat(r.clone()); + let duplicates = l.keys().filter(|&k| r.contains_key(k)).collect::>(); + for k in duplicates { + let v = &mut new[k]; + *v = self.intersection(&l[k], &r[k]); + let (k, t) = new.get_key_value(k).unwrap(); + if k.vis.is_private() { + new.insert(Field::public(k.symbol.clone()), t.clone()); + } + } + Type::Record(new) + } // {i = Int; j = Int} and not {i = Int} == {j = Int} // not {i = Int} and {i = Int; j = Int} == {j = Int} (other @ Record(rec), Not(t)) | (Not(t), other @ Record(rec)) => match t.as_ref() { diff --git a/crates/erg_compiler/context/register.rs b/crates/erg_compiler/context/register.rs index d5baea3a3..b30419082 100644 --- a/crates/erg_compiler/context/register.rs +++ b/crates/erg_compiler/context/register.rs @@ -2075,6 +2075,22 @@ impl Context { TypeObj::Builtin { t, .. } => Some(t), TypeObj::Generated(t) => t.base_or_sup().map(|t| t.typ()), }; + let invalid_fields = if let Some(TypeObj::Builtin { + t: Type::Record(rec), + .. + }) = gen.additional() + { + if let Err((fields, es)) = + self.check_subtype_instance_attrs(sup.typ(), rec, call) + { + errs.extend(es); + fields + } else { + Set::new() + } + } else { + Set::new() + }; // `Super.Requirement := {x = Int}` and `Self.Additional := {y = Int}` // => `Self.Requirement := {x = Int; y = Int}` let call_t = { @@ -2089,8 +2105,15 @@ impl Context { errs.extend(es); } } + let param_t = if let Some(Type::Record(rec)) = param_t { + let mut rec = rec.clone(); + rec.remove_entries(&invalid_fields); + Some(Type::Record(rec)) + } else { + param_t.cloned() + }; let nd_params = param_t - .map(|t| self.intersection(t, additional.typ())) + .map(|pt| self.intersection(&pt, additional.typ())) .or(Some(additional.typ().clone())) .map_or(vec![], |t| vec![ParamTy::Pos(t)]); (nd_params, None, vec![], None) @@ -2121,8 +2144,15 @@ impl Context { let (nd_params, var_params, d_params, kw_var_params) = if let Some(additional) = gen.additional() { + let param_t = if let Some(Type::Record(rec)) = param_t { + let mut rec = rec.clone(); + rec.remove_entries(&invalid_fields); + Some(Type::Record(rec)) + } else { + param_t.cloned() + }; let nd_params = param_t - .map(|t| self.intersection(t, additional.typ())) + .map(|pt| self.intersection(&pt, additional.typ())) .or(Some(additional.typ().clone())) .map_or(vec![], |t| vec![ParamTy::Pos(t)]); (nd_params, None, vec![], None) @@ -2209,12 +2239,68 @@ impl Context { } } + fn check_subtype_instance_attrs( + &self, + sup: &Type, + rec: &Dict, + call: Option<&ast::Call>, + ) -> Result<(), (Set, CompileErrors)> { + let mut errs = CompileErrors::empty(); + let mut invalid_fields = Set::new(); + let sup_ctx = self.get_nominal_type_ctx(sup); + let additional = call.and_then(|call| { + if let Some(ast::Expr::Record(record)) = call.args.get_with_key("Additional") { + Some(record) + } else { + None + } + }); + for (field, sub_t) in rec.iter() { + let loc = additional + .as_ref() + .and_then(|record| { + record + .keys() + .iter() + .find(|id| id.inspect() == &field.symbol) + .map(|name| name.loc()) + }) + .unwrap_or_default(); + let varname = VarName::from_str(field.symbol.clone()); + if let Some(sup_ctx) = sup_ctx { + if let Some(sup_vi) = sup_ctx.decls.get(&varname) { + if !self.subtype_of(sub_t, &sup_vi.t) { + invalid_fields.insert(field.clone()); + errs.push(CompileError::type_mismatch_error( + self.cfg.input.clone(), + line!() as usize, + loc, + self.caused_by(), + &field.symbol, + None, + &sup_vi.t, + sub_t, + None, + None, + )); + } + } + } + } + if errs.is_empty() { + Ok(()) + } else { + Err((invalid_fields, errs)) + } + } + fn register_instance_attrs( &self, ctx: &mut Context, rec: &Dict, call: Option<&ast::Call>, ) -> CompileResult<()> { + let mut errs = CompileErrors::empty(); let record = call.and_then(|call| { if let Some(ast::Expr::Record(record)) = call .args @@ -2248,16 +2334,20 @@ impl Context { ); // self.index().register(&vi); if let Some(_ent) = ctx.decls.insert(varname.clone(), vi) { - return Err(CompileErrors::from(CompileError::duplicate_decl_error( + errs.push(CompileError::duplicate_decl_error( self.cfg.input.clone(), line!() as usize, varname.loc(), self.caused_by(), varname.inspect(), - ))); + )); } } - Ok(()) + if errs.is_empty() { + Ok(()) + } else { + Err(errs) + } } fn gen_class_new_method( diff --git a/tests/should_err/inherit.er b/tests/should_err/inherit.er new file mode 100644 index 000000000..70ea22c53 --- /dev/null +++ b/tests/should_err/inherit.er @@ -0,0 +1,9 @@ +@Inheritable +Base = Class { + .value = Str; +} + +C = Inherit Base, Additional := { + .value = Int # ERR +} +_ = C.new { .value = 10 } diff --git a/tests/should_ok/inherit.er b/tests/should_ok/inherit.er index b980052af..2c85bfc91 100644 --- a/tests/should_ok/inherit.er +++ b/tests/should_ok/inherit.er @@ -13,3 +13,13 @@ D. d: D = D.new(1) print! d assert D.const == 0 + +@Inheritable +Base = Class { + .value = Int +} +E = Inherit Base, Additional := { + .value = Nat +} +e = E.new { .value = 10 } +discard e.value.times! diff --git a/tests/test.rs b/tests/test.rs index 596ca57a5..40c029c0f 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -644,6 +644,11 @@ fn exec_infer_union_array() -> Result<(), ()> { expect_failure("tests/should_err/infer_union_array.er", 2, 1) } +#[test] +fn exec_inherit_err() -> Result<(), ()> { + expect_failure("tests/should_err/inherit.er", 0, 1) +} + #[test] fn exec_init_del_err() -> Result<(), ()> { expect_failure("tests/should_err/init_del.er", 0, 1)