From 245d9eee848d3a8b8c8dd943f6fcecf685cf1fb9 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Thu, 27 Apr 2023 15:51:56 +0900 Subject: [PATCH] fix: dict typing bugs --- crates/erg_compiler/context/compare.rs | 33 +++++--- crates/erg_compiler/context/eval.rs | 8 +- .../context/initialize/classes.rs | 10 +-- .../context/initialize/const_func.rs | 80 ++++++++++++++----- crates/erg_compiler/context/unify.rs | 5 +- crates/erg_compiler/ty/mod.rs | 2 +- crates/erg_compiler/ty/typaram.rs | 2 +- crates/erg_compiler/ty/value.rs | 5 ++ tests/should_err/mut_dict.er | 7 ++ tests/should_ok/mut_dict.er | 1 + tests/test.rs | 5 ++ 11 files changed, 119 insertions(+), 39 deletions(-) create mode 100644 tests/should_err/mut_dict.er diff --git a/crates/erg_compiler/context/compare.rs b/crates/erg_compiler/context/compare.rs index ac1927965..a46471685 100644 --- a/crates/erg_compiler/context/compare.rs +++ b/crates/erg_compiler/context/compare.rs @@ -1023,17 +1023,7 @@ impl Context { (Some(sub), Some(sup)) => bounded(sub.clone(), sup.clone()), _ => self.simple_union(lhs, rhs), }, - (other, or @ Or(l, r)) | (or @ Or(l, r), other) => { - if &self.union(other, l) == l.as_ref() || &self.union(other, r) == r.as_ref() { - or.clone() - } else if &self.union(other, l) == other { - self.union(other, r) - } else if &self.union(other, r) == other { - self.union(other, l) - } else { - self.simple_union(lhs, rhs) - } - } + (other, or @ Or(_, _)) | (or @ Or(_, _), other) => self.union_add(or, other), (t, Type::Never) | (Type::Never, t) => t.clone(), // Array({1, 2}, 2), Array({3, 4}, 2) ==> Array({1, 2, 3, 4}, 2) ( @@ -1094,6 +1084,27 @@ impl Context { } } + /// ```erg + /// union_add(Int or ?T(:> NoneType), Nat) == Int or ?T + /// union_add(Int or ?T(:> NoneType), Str) == Int or ?T or Str + /// ``` + fn union_add(&self, union: &Type, elem: &Type) -> Type { + let union_ts = union.union_types(); + let fixed = union_ts.into_iter().map(|t| { + if let Ok(free) = <&FreeTyVar>::try_from(&t) { + free.get_sub().unwrap_or(t) + } else { + t + } + }); + for t in fixed { + if self.supertype_of(&t, elem) { + return union.clone(); + } + } + or(union.clone(), elem.clone()) + } + /// ```erg /// simple_union(?T, ?U) == ?T or ?U /// union(Set!(?T(<: Int), 3), Set(?U(<: Nat), 3)) == Set(?T, 3) diff --git a/crates/erg_compiler/context/eval.rs b/crates/erg_compiler/context/eval.rs index 09a191b8f..92b15eb62 100644 --- a/crates/erg_compiler/context/eval.rs +++ b/crates/erg_compiler/context/eval.rs @@ -828,6 +828,12 @@ impl Context { (TyParam::Value(lhs), TyParam::Value(rhs)) => { self.eval_bin(op, lhs, rhs).map(TyParam::value) } + (TyParam::Dict(l), TyParam::Dict(r)) if op == OpKind::Add => { + Ok(TyParam::Dict(l.concat(r))) + } + (TyParam::Array(l), TyParam::Array(r)) if op == OpKind::Add => { + Ok(TyParam::Array([l, r].concat())) + } (TyParam::FreeVar(fv), r) if fv.is_linked() => { self.eval_bin_tp(op, fv.crack().clone(), r) } @@ -854,7 +860,7 @@ impl Context { (lhs @ TyParam::FreeVar(_), rhs) => Ok(TyParam::bin(op, lhs, rhs)), (lhs, rhs @ TyParam::FreeVar(_)) => Ok(TyParam::bin(op, lhs, rhs)), (e @ TyParam::Erased(_), _) | (_, e @ TyParam::Erased(_)) => Ok(e), - (l, r) => feature_error!(self, Location::Unknown, &format!("{l:?} {op} {r:?}")) + (l, r) => feature_error!(self, Location::Unknown, &format!("{l} {op} {r}")) .map_err(Into::into), } } diff --git a/crates/erg_compiler/context/initialize/classes.rs b/crates/erg_compiler/context/initialize/classes.rs index bb4b77f47..1e300052e 100644 --- a/crates/erg_compiler/context/initialize/classes.rs +++ b/crates/erg_compiler/context/initialize/classes.rs @@ -2028,7 +2028,7 @@ impl Context { ); array_mut_.register_trait(array_mut_t.clone(), array_mut_mutable); /* Dict! */ - let dict_mut_t = poly(MUT_DICT, vec![D]); + let dict_mut_t = poly(MUT_DICT, vec![D.clone()]); let mut dict_mut = Self::builtin_poly_class(MUT_DICT, vec![PS::named_nd(TY_D, mono(GENERIC_DICT))], 3); dict_mut.register_superclass(dict_t.clone(), &dict_); @@ -2037,12 +2037,10 @@ impl Context { let insert_t = pr_met( ref_mut( dict_mut_t.clone(), - // TODO: - None, - /*Some(poly( + Some(poly( MUT_DICT, - vec![D + dict!{ K.clone() => V.clone() }.into()], - )),*/ + vec![D + dict! { K.clone() => V.clone() }.into()], + )), ), vec![kw(KW_KEY, K), kw(KW_VALUE, V)], None, diff --git a/crates/erg_compiler/context/initialize/const_func.rs b/crates/erg_compiler/context/initialize/const_func.rs index 70aa666ef..56ef19d70 100644 --- a/crates/erg_compiler/context/initialize/const_func.rs +++ b/crates/erg_compiler/context/initialize/const_func.rs @@ -7,7 +7,7 @@ use crate::context::Context; use crate::feature_error; use crate::ty::constructors::{and, mono, poly, tuple_t, ty_tp}; use crate::ty::value::{EvalValueError, EvalValueResult, GenTypeObj, TypeObj, ValueObj}; -use crate::ty::{Type, ValueArgs}; +use crate::ty::{TyParam, Type, ValueArgs}; use erg_common::error::{ErrorCore, ErrorKind, Location, SubMessage}; use erg_common::style::{Color, StyledStr, StyledString, THEME}; @@ -263,27 +263,71 @@ pub(crate) fn __array_getitem__(mut args: ValueArgs, ctx: &Context) -> EvalValue } } +pub(crate) fn sub_vdict_get<'d>( + dict: &'d Dict, + key: &ValueObj, + ctx: &Context, +) -> Option<&'d ValueObj> { + let mut matches = vec![]; + for (k, v) in dict.iter() { + match (key, k) { + (ValueObj::Type(idx), ValueObj::Type(kt)) if ctx.subtype_of(idx.typ(), kt.typ()) => { + matches.push((idx, kt, v)); + } + (idx, k) if idx == k => { + return Some(v); + } + _ => {} + } + } + for (idx, kt, v) in matches.into_iter() { + match ctx.sub_unify(idx.typ(), kt.typ(), &(), None) { + Ok(_) => { + return Some(v); + } + Err(_err) => { + erg_common::log!(err "{idx} {v}"); + } + } + } + None +} + +pub(crate) fn sub_tpdict_get<'d>( + dict: &'d Dict, + key: &TyParam, + ctx: &Context, +) -> Option<&'d TyParam> { + let mut matches = vec![]; + for (k, v) in dict.iter() { + match (<&Type>::try_from(key), <&Type>::try_from(k)) { + (Ok(idx), Ok(kt)) if ctx.subtype_of(idx, kt) => { + matches.push((idx, kt, v)); + } + (_, _) if key == k => { + return Some(v); + } + _ => {} + } + } + for (idx, kt, v) in matches.into_iter() { + match ctx.sub_unify(idx, kt, &(), None) { + Ok(_) => { + return Some(v); + } + Err(_err) => { + erg_common::log!(err "{idx} {v}"); + } + } + } + None +} + pub(crate) fn __dict_getitem__(mut args: ValueArgs, ctx: &Context) -> EvalValueResult { let slf = args.remove_left_or_key("Self").unwrap(); let slf = enum_unwrap!(slf, ValueObj::Dict); let index = args.remove_left_or_key("Index").unwrap(); - if let Some(v) = slf.get(&index).or_else(|| { - for (k, v) in slf.iter() { - match (&index, k) { - (ValueObj::Type(idx), ValueObj::Type(kt)) => { - if ctx.subtype_of(idx.typ(), kt.typ()) { - return Some(v); - } - } - (idx, k) => { - if idx == k { - return Some(v); - } - } - } - } - None - }) { + if let Some(v) = slf.get(&index).or_else(|| sub_vdict_get(&slf, &index, ctx)) { Ok(v.clone()) } else { let index = if let ValueObj::Type(t) = &index { diff --git a/crates/erg_compiler/context/unify.rs b/crates/erg_compiler/context/unify.rs index 72c0bd71d..8eaaa0364 100644 --- a/crates/erg_compiler/context/unify.rs +++ b/crates/erg_compiler/context/unify.rs @@ -22,6 +22,8 @@ use Predicate as Pred; use Type::*; use ValueObj::{Inf, NegInf}; +use super::initialize::const_func::sub_tpdict_get; + impl Context { /// ```erg /// occur(?T, ?T) ==> OK @@ -444,9 +446,10 @@ impl Context { } (TyParam::Dict(ls), TyParam::Dict(rs)) => { for (lk, lv) in ls.iter() { - if let Some(rv) = rs.get(lk) { + if let Some(rv) = rs.get(lk).or_else(|| sub_tpdict_get(rs, lk, self)) { self.sub_unify_tp(lv, rv, _variance, loc, allow_divergence)?; } else { + log!(err "{rs} does not have key {lk}"); // TODO: return Err(TyCheckErrors::from(TyCheckError::unreachable( self.cfg.input.clone(), diff --git a/crates/erg_compiler/ty/mod.rs b/crates/erg_compiler/ty/mod.rs index 58a29bb5a..da75096bf 100644 --- a/crates/erg_compiler/ty/mod.rs +++ b/crates/erg_compiler/ty/mod.rs @@ -2389,7 +2389,7 @@ impl Type { pub fn has_qvar(&self) -> bool { match self { Self::FreeVar(fv) if fv.is_linked() => fv.crack().has_qvar(), - Self::FreeVar(fv) if fv.is_generalized() => true, + Self::FreeVar(fv) if fv.is_unbound() && fv.is_generalized() => true, Self::FreeVar(fv) => { if let Some((sub, sup)) = fv.get_subsup() { fv.dummy_link(); diff --git a/crates/erg_compiler/ty/typaram.rs b/crates/erg_compiler/ty/typaram.rs index 82f07f81a..2d01478f8 100644 --- a/crates/erg_compiler/ty/typaram.rs +++ b/crates/erg_compiler/ty/typaram.rs @@ -959,7 +959,7 @@ impl TyParam { pub fn has_qvar(&self) -> bool { match self { - Self::FreeVar(fv) if fv.is_generalized() => true, + Self::FreeVar(fv) if fv.is_unbound() && fv.is_generalized() => true, Self::FreeVar(fv) if fv.is_linked() => fv.crack().has_qvar(), Self::Type(t) => t.has_qvar(), Self::Proj { obj, .. } => obj.has_qvar(), diff --git a/crates/erg_compiler/ty/value.rs b/crates/erg_compiler/ty/value.rs index 02ff6a833..7ebd3b224 100644 --- a/crates/erg_compiler/ty/value.rs +++ b/crates/erg_compiler/ty/value.rs @@ -1011,6 +1011,11 @@ impl ValueObj { (Self::Nat(l), Self::Float(r)) => Some(Self::Float(l as f64 - r)), (Self::Float(l), Self::Int(r)) => Some(Self::Float(l - r as f64)), (Self::Str(l), Self::Str(r)) => Some(Self::Str(Str::from(format!("{l}{r}")))), + (Self::Array(l), Self::Array(r)) => { + let arr = Rc::from([l, r].concat()); + Some(Self::Array(arr)) + } + (Self::Dict(l), Self::Dict(r)) => Some(Self::Dict(l.concat(r))), (inf @ (Self::Inf | Self::NegInf), _) | (_, inf @ (Self::Inf | Self::NegInf)) => { Some(inf) } diff --git a/tests/should_err/mut_dict.er b/tests/should_err/mut_dict.er new file mode 100644 index 000000000..461cee210 --- /dev/null +++ b/tests/should_err/mut_dict.er @@ -0,0 +1,7 @@ +d = {"a": 1} +dict = !d + +dict.insert! "b", 2 +_ = dict.get("a") == "a" # ERR +_ = dict.get("b") == "a" # ERR +_ = dict.get("c") # ERR diff --git a/tests/should_ok/mut_dict.er b/tests/should_ok/mut_dict.er index 4e344f828..ed92d9a07 100644 --- a/tests/should_ok/mut_dict.er +++ b/tests/should_ok/mut_dict.er @@ -3,3 +3,4 @@ dict = !d dict.insert! "b", 2 assert dict.get("a") == 1 +assert dict.get("b") == 2 diff --git a/tests/test.rs b/tests/test.rs index 66f0fe6c2..d14b43e65 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -347,6 +347,11 @@ fn exec_mut_array_err() -> Result<(), ()> { expect_failure("tests/should_err/mut_array.er", 0, 4) } +#[test] +fn exec_mut_dict_err() -> Result<(), ()> { + expect_failure("tests/should_err/mut_dict.er", 0, 3) +} + #[test] fn exec_quantified_err() -> Result<(), ()> { expect_failure("tests/should_err/quantified.er", 0, 3)