From 74e89f6d5b4e8bd49537729311d73fdf15f72e0a Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Wed, 13 Mar 2024 22:35:08 +0900 Subject: [PATCH] fix: polymorphic type spec instantiation bugs --- crates/erg_common/lib.rs | 25 + crates/erg_common/set.rs | 1 + crates/erg_compiler/context/eval.rs | 23 +- .../erg_compiler/context/instantiate_spec.rs | 499 ++++++++++++------ crates/erg_compiler/context/mod.rs | 9 + crates/erg_compiler/context/register.rs | 109 +++- crates/erg_compiler/lower.rs | 41 +- crates/erg_parser/ast.rs | 4 +- tests/should_err/poly_type_spec.er | 5 + tests/should_ok/poly_type_spec.er | 7 + tests/test.rs | 10 + 11 files changed, 553 insertions(+), 180 deletions(-) create mode 100644 tests/should_err/poly_type_spec.er create mode 100644 tests/should_ok/poly_type_spec.er diff --git a/crates/erg_common/lib.rs b/crates/erg_common/lib.rs index 4e9b5fefb..583c24ef8 100644 --- a/crates/erg_common/lib.rs +++ b/crates/erg_common/lib.rs @@ -151,6 +151,31 @@ where Ok(v) } +pub fn failable_map_mut(i: I, mut f: F) -> Result, (Vec, Vec)> +where + F: FnMut(T) -> Result, + I: Iterator, +{ + let mut v = vec![]; + let mut errs = vec![]; + for x in i { + match f(x) { + Ok(y) => { + v.push(y); + } + Err((y, e)) => { + v.push(y); + errs.push(e); + } + } + } + if errs.is_empty() { + Ok(v) + } else { + Err((v, errs)) + } +} + pub fn unique_in_place(v: &mut Vec) { let mut uniques = Set::new(); v.retain(|e| uniques.insert(e.clone())); diff --git a/crates/erg_common/set.rs b/crates/erg_common/set.rs index b5780e17f..cb97aeb2a 100644 --- a/crates/erg_common/set.rs +++ b/crates/erg_common/set.rs @@ -186,6 +186,7 @@ impl Set { self.elems.contains(value) } + /// newly inserted: true, already present: false #[inline] pub fn insert(&mut self, value: T) -> bool { self.elems.insert(value) diff --git a/crates/erg_compiler/context/eval.rs b/crates/erg_compiler/context/eval.rs index 4d5d0fc5a..8332fd8b2 100644 --- a/crates/erg_compiler/context/eval.rs +++ b/crates/erg_compiler/context/eval.rs @@ -718,12 +718,19 @@ impl Context { fn eval_const_def(&mut self, def: &Def) -> EvalResult { if def.is_const() { + let mut errs = EvalErrors::empty(); let __name__ = def.sig.ident().unwrap().inspect(); let vis = self.instantiate_vis_modifier(def.sig.vis())?; let tv_cache = match &def.sig { Signature::Subr(subr) => { let ty_cache = - self.instantiate_ty_bounds(&subr.bounds, RegistrationMode::Normal)?; + match self.instantiate_ty_bounds(&subr.bounds, RegistrationMode::Normal) { + Ok(ty_cache) => ty_cache, + Err((ty_cache, es)) => { + errs.extend(es); + ty_cache + } + }; Some(ty_cache) } Signature::Var(_) => None, @@ -740,7 +747,8 @@ impl Context { } else { None }; - let (_ctx, errs) = self.check_decls_and_pop(); + let (_ctx, es) = self.check_decls_and_pop(); + errs.extend(es); self.register_gen_const( def.sig.ident().unwrap(), obj, @@ -917,10 +925,16 @@ impl Context { /// FIXME: grow fn eval_const_lambda(&self, lambda: &Lambda) -> EvalResult { + let mut errs = EvalErrors::empty(); let mut tmp_tv_cache = - self.instantiate_ty_bounds(&lambda.sig.bounds, RegistrationMode::Normal)?; + match self.instantiate_ty_bounds(&lambda.sig.bounds, RegistrationMode::Normal) { + Ok(ty_cache) => ty_cache, + Err((ty_cache, es)) => { + errs.extend(es); + ty_cache + } + }; let mut non_default_params = Vec::with_capacity(lambda.sig.params.non_defaults.len()); - let mut errs = EvalErrors::empty(); for sig in lambda.sig.params.non_defaults.iter() { match self.instantiate_param_ty( sig, @@ -2090,6 +2104,7 @@ impl Context { TyParam::ProjCall { obj, attr, args } => Ok(proj_call(*obj, attr, args)), // TyParam::Erased(_t) => Ok(Type::Obj), TyParam::Value(v) => self.convert_value_into_type(v).map_err(TyParam::Value), + TyParam::Erased(t) if t.is_type() => Ok(Type::Obj), // TODO: Dict, Set other => Err(other), } diff --git a/crates/erg_compiler/context/instantiate_spec.rs b/crates/erg_compiler/context/instantiate_spec.rs index 4e22f569d..8a47ed5c0 100644 --- a/crates/erg_compiler/context/instantiate_spec.rs +++ b/crates/erg_compiler/context/instantiate_spec.rs @@ -1,10 +1,10 @@ use std::option::Option; // conflicting to Type::Option +use erg_common::levenshtein::get_similar_name; #[allow(unused)] use erg_common::log; use erg_common::traits::{Locational, Stream}; -use erg_common::Str; -use erg_common::{assume_unreachable, dict, set, try_map_mut}; +use erg_common::{assume_unreachable, dict, failable_map_mut, set, Str}; use ast::{ NonDefaultParamSignature, ParamTySpec, PreDeclTypeSpec, TypeBoundSpec, TypeBoundSpecs, TypeSpec, @@ -166,36 +166,43 @@ impl Context { &self, bounds: &TypeBoundSpecs, mode: RegistrationMode, - ) -> TyCheckResult { + ) -> Failable { + let mut errs = TyCheckErrors::empty(); let mut tv_cache = TyVarCache::new(self.level, self); for bound in bounds.iter() { - self.instantiate_ty_bound(bound, &mut tv_cache, mode)?; + if let Err(es) = self.instantiate_ty_bound(bound, &mut tv_cache, mode) { + errs.extend(es); + } } for tv in tv_cache.tyvar_instances.values() { if tv.constraint().map(|c| c.is_uninited()).unwrap_or(false) { - return Err(TyCheckErrors::from(TyCheckError::no_var_error( + errs.push(TyCheckError::no_var_error( self.cfg.input.clone(), line!() as usize, bounds.loc(), self.caused_by(), &tv.local_name(), self.get_similar_name(&tv.local_name()), - ))); + )); } } for tp in tv_cache.typaram_instances.values() { if tp.constraint().map(|c| c.is_uninited()).unwrap_or(false) { - return Err(TyCheckErrors::from(TyCheckError::no_var_error( + errs.push(TyCheckError::no_var_error( self.cfg.input.clone(), line!() as usize, bounds.loc(), self.caused_by(), &tp.to_string(), self.get_similar_name(&tp.to_string()), - ))); + )); } } - Ok(tv_cache) + if errs.is_empty() { + Ok(tv_cache) + } else { + Err((tv_cache, errs)) + } } pub(crate) fn instantiate_var_sig_t( @@ -237,9 +244,13 @@ impl Context { } None => None, }; - let mut tmp_tv_cache = self - .instantiate_ty_bounds(&sig.bounds, PreRegister) - .map_err(|errs| (Type::Failure, errs))?; + let mut tmp_tv_cache = match self.instantiate_ty_bounds(&sig.bounds, PreRegister) { + Ok(tv_cache) => tv_cache, + Err((tv_cache, es)) => { + errs.extend(es); + tv_cache + } + }; let mut non_defaults = vec![]; for (n, param) in sig.params.non_defaults.iter().enumerate() { let opt_decl_t = opt_decl_sig_t @@ -343,7 +354,7 @@ impl Context { } } else { // preregisterならouter scopeで型宣言(see inference.md) - let level = if mode == PreRegister { + let level = if mode.is_preregister() { self.level } else { self.level + 1 @@ -374,7 +385,7 @@ impl Context { not_found_is_qvar: bool, ) -> Failable { let gen_free_t = || { - let level = if mode == PreRegister { + let level = if mode.is_preregister() { self.level } else { self.level + 1 @@ -392,7 +403,12 @@ impl Context { .map_err(|errs| { ( opt_decl_t.map_or(Type::Failure, |pt| pt.typ().clone()), - errs, + // Ignore errors if `mode == Normal`, because the errors have already been collected. + if mode.is_normal() { + TyCheckErrors::empty() + } else { + errs + }, ) })? } else { @@ -870,61 +886,121 @@ impl Context { )); }; let mut errs = TyCheckErrors::empty(); - // FIXME: kw args let mut new_params = vec![]; for ((i, arg), (name, param_vi)) in args.pos_args().enumerate().zip(ctx.params.iter()) { - let param = self.instantiate_const_expr( + match self.instantiate_arg( &arg.expr, - Some((ctx, i)), + param_vi, + name.as_ref(), + ctx, + i, tmp_tv_cache, not_found_is_qvar, - ); - let param = param - .or_else(|e| { - if not_found_is_qvar { - let name = arg.expr.to_string(); - // FIXME: handle `::` as a right way - let name = Str::rc(name.trim_start_matches("::")); - let tp = TyParam::named_free_var( - name.clone(), - self.level, - Constraint::Uninited, - ); - let varname = VarName::from_str(name); - tmp_tv_cache.push_or_init_typaram(&varname, &tp, self)?; - Ok(tp) - } else { - Err(e) + ) { + Ok(tp) => new_params.push(tp), + Err((tp, es)) => { + errs.extend(es); + new_params.push(tp); + } + } + } + let mut missing_args = vec![]; + // Fill kw params + for (_, param_vi) in ctx.params.iter().skip(args.pos_args.len()) { + new_params.push(TyParam::erased(param_vi.t.clone())); + } + let mut passed_kw_args = set! {}; + for (i, (name, param_vi)) in ctx.params.iter().skip(args.pos_args.len()).enumerate() + { + if let Some(idx) = name.as_ref().and_then(|name| { + args.kw_args + .iter() + .position(|arg| arg.keyword.inspect() == name.inspect()) + }) { + let kw_arg = &args.kw_args[idx]; + let already_passed = + !passed_kw_args.insert(kw_arg.keyword.inspect().clone()); + if already_passed { + errs.push(TyCheckError::multiple_args_error( + self.cfg.input.clone(), + line!() as usize, + kw_arg.loc(), + other, + self.caused_by(), + name.as_ref().map_or("_", |n| &n.inspect()[..]), + )); + } + let tp = match self.instantiate_arg( + &kw_arg.expr, + param_vi, + name.as_ref(), + ctx, + i, + tmp_tv_cache, + not_found_is_qvar, + ) { + Ok(tp) => tp, + Err((tp, es)) => { + errs.extend(es); + tp } - }) - .map_err(|err| (Type::Failure, err))?; - let arg_t = self - .get_tp_t(¶m) - .map_err(|err| { - log!(err "{param}: {err}"); - err - }) - .unwrap_or(Obj); - if self.subtype_of(&arg_t, ¶m_vi.t) { - new_params.push(param); - } else { - new_params.push(TyParam::erased(param_vi.t.clone())); - errs.push(TyCheckError::type_mismatch_error( - self.cfg.input.clone(), - line!() as usize, - arg.expr.loc(), - self.caused_by(), - name.as_ref().map_or("", |n| &n.inspect()[..]), - Some(i), - ¶m_vi.t, - &arg_t, - None, - None, - )); + }; + if let Some(old) = new_params.get_mut(args.pos_args.len() + idx) { + *old = tp; + } else { + log!(err "{tp} / {} / {idx}", args.pos_args.len()); + // TODO: too many kw args + } + } else if !param_vi.kind.has_default() { + missing_args + .push(name.as_ref().map_or("_".into(), |n| n.inspect().clone())); } } + if !missing_args.is_empty() { + errs.push(TyCheckError::args_missing_error( + self.cfg.input.clone(), + line!() as usize, + args.loc(), + other, + self.caused_by(), + missing_args, + )); + } + if ctx.params.len() < args.pos_args.len() + args.kw_args.len() { + errs.push(TyCheckError::too_many_args_error( + self.cfg.input.clone(), + line!() as usize, + args.loc(), + other, + self.caused_by(), + ctx.params.len(), + args.pos_args.len(), + args.kw_args.len(), + )); + } + let param_names = ctx + .params + .iter() + .filter_map(|(n, _)| n.as_ref().map(|n| &n.inspect()[..])) + .collect::>(); + for unexpected in args + .kw_args + .iter() + .filter(|kw| !passed_kw_args.contains(&kw.keyword.inspect()[..])) + { + let kw = unexpected.keyword.inspect(); + errs.push(TyCheckError::unexpected_kw_arg_error( + self.cfg.input.clone(), + line!() as usize, + unexpected.loc(), + other, + self.caused_by(), + unexpected.keyword.inspect(), + get_similar_name(param_names.iter(), kw).copied(), + )); + } // FIXME: non-builtin let t = poly(ctx.typ.qual_name(), new_params); if errs.is_empty() { @@ -936,6 +1012,63 @@ impl Context { } } + #[allow(clippy::too_many_arguments)] + fn instantiate_arg( + &self, + arg: &ConstExpr, + param_vi: &VarInfo, + name: Option<&VarName>, + ctx: &Context, + i: usize, + tmp_tv_cache: &mut TyVarCache, + not_found_is_qvar: bool, + ) -> Failable { + let param = + self.instantiate_const_expr(arg, Some((ctx, i)), tmp_tv_cache, not_found_is_qvar); + let param = param + .or_else(|e| { + if not_found_is_qvar { + let name = arg.to_string(); + // FIXME: handle `::` as a right way + let name = Str::rc(name.trim_start_matches("::")); + let tp = + TyParam::named_free_var(name.clone(), self.level, Constraint::Uninited); + let varname = VarName::from_str(name); + tmp_tv_cache.push_or_init_typaram(&varname, &tp, self)?; + Ok(tp) + } else { + Err(e) + } + }) + .map_err(|err| (TyParam::Failure, err))?; + let arg_t = self + .get_tp_t(¶m) + .map_err(|err| { + log!(err "{param}: {err}"); + err + }) + .unwrap_or(Obj); + if self.subtype_of(&arg_t, ¶m_vi.t) { + Ok(param) + } else { + let tp = TyParam::erased(param_vi.t.clone()); + let errs = TyCheckError::type_mismatch_error( + self.cfg.input.clone(), + line!() as usize, + arg.loc(), + self.caused_by(), + name.as_ref().map_or("", |n| &n.inspect()[..]), + Some(i), + ¶m_vi.t, + &arg_t, + None, + None, + ) + .into(); + Err((tp, errs)) + } + } + fn instantiate_acc( &self, acc: &ast::ConstAccessor, @@ -1232,37 +1365,52 @@ impl Context { Ok(TyParam::Record(tp_rec)) } ast::ConstExpr::Lambda(lambda) => { - let _tmp_tv_cache = - self.instantiate_ty_bounds(&lambda.sig.bounds, RegistrationMode::Normal)?; + let mut errs = TyCheckErrors::empty(); + let _tmp_tv_cache = match self + .instantiate_ty_bounds(&lambda.sig.bounds, RegistrationMode::Normal) + { + Ok(tv_cache) => tv_cache, + Err((tv_cache, es)) => { + errs.extend(es); + tv_cache + } + }; // Since there are type variables and other variables that can be constrained within closures, // they are `merge`d once and then `purge`d of type variables that are only used internally after instantiation. tmp_tv_cache.merge(&_tmp_tv_cache); let mut nd_params = Vec::with_capacity(lambda.sig.params.non_defaults.len()); for sig in lambda.sig.params.non_defaults.iter() { - let pt = self - .instantiate_param_ty( - sig, - None, - tmp_tv_cache, - RegistrationMode::Normal, - ParamKind::NonDefault, - not_found_is_qvar, - ) - // TODO: continue - .map_err(|(_, errs)| errs)?; + let pt = match self.instantiate_param_ty( + sig, + None, + tmp_tv_cache, + RegistrationMode::Normal, + ParamKind::NonDefault, + not_found_is_qvar, + ) { + Ok(pt) => pt, + Err((pt, es)) => { + errs.extend(es); + pt + } + }; nd_params.push(pt); } let var_params = if let Some(p) = lambda.sig.params.var_params.as_ref() { - let pt = self - .instantiate_param_ty( - p, - None, - tmp_tv_cache, - RegistrationMode::Normal, - ParamKind::VarParams, - not_found_is_qvar, - ) - .map_err(|(_, errs)| errs)?; + let pt = match self.instantiate_param_ty( + p, + None, + tmp_tv_cache, + RegistrationMode::Normal, + ParamKind::VarParams, + not_found_is_qvar, + ) { + Ok(pt) => pt, + Err((pt, es)) => { + errs.extend(es); + pt + } + }; Some(pt) } else { None @@ -1270,29 +1418,37 @@ impl Context { let mut d_params = Vec::with_capacity(lambda.sig.params.defaults.len()); for sig in lambda.sig.params.defaults.iter() { let expr = self.eval_const_expr(&sig.default_val)?; - let pt = self - .instantiate_param_ty( - &sig.sig, - None, - tmp_tv_cache, - RegistrationMode::Normal, - ParamKind::Default(expr.t()), - not_found_is_qvar, - ) - .map_err(|(_, errs)| errs)?; + let pt = match self.instantiate_param_ty( + &sig.sig, + None, + tmp_tv_cache, + RegistrationMode::Normal, + ParamKind::Default(expr.t()), + not_found_is_qvar, + ) { + Ok(pt) => pt, + Err((pt, es)) => { + errs.extend(es); + pt + } + }; d_params.push(pt); } let kw_var_params = if let Some(p) = lambda.sig.params.kw_var_params.as_ref() { - let pt = self - .instantiate_param_ty( - p, - None, - tmp_tv_cache, - RegistrationMode::Normal, - ParamKind::KwVarParams, - not_found_is_qvar, - ) - .map_err(|(_, errs)| errs)?; + let pt = match self.instantiate_param_ty( + p, + None, + tmp_tv_cache, + RegistrationMode::Normal, + ParamKind::KwVarParams, + not_found_is_qvar, + ) { + Ok(pt) => pt, + Err((pt, es)) => { + errs.extend(es); + pt + } + }; Some(pt) } else { None @@ -1344,14 +1500,18 @@ impl Context { body.push(param); } tmp_tv_cache.purge(&_tmp_tv_cache); - Ok(TyParam::Lambda(TyParamLambda::new( - lambda.clone(), - nd_params, - var_params, - d_params, - kw_var_params, - body, - ))) + if errs.is_empty() { + Ok(TyParam::Lambda(TyParamLambda::new( + lambda.clone(), + nd_params, + var_params, + d_params, + kw_var_params, + body, + ))) + } else { + Err(errs) + } } ast::ConstExpr::BinOp(bin) => { let Some(op) = token_kind_to_op_kind(bin.op.kind) else { @@ -1535,31 +1695,43 @@ impl Context { tmp_tv_cache: &mut TyVarCache, mode: RegistrationMode, not_found_is_qvar: bool, - ) -> TyCheckResult { - let t = self.instantiate_typespec_full( + ) -> Failable { + let mut errs = TyCheckErrors::empty(); + let t = match self.instantiate_typespec_full( &p.ty, opt_decl_t, tmp_tv_cache, mode, not_found_is_qvar, - )?; - if let Some(default_t) = default_t { - Ok(ParamTy::kw_default( - p.name.as_ref().unwrap().inspect().to_owned(), - t, - self.instantiate_typespec_full( - default_t, - opt_decl_t, - tmp_tv_cache, - mode, - not_found_is_qvar, - )?, - )) + ) { + Ok(t) => t, + Err(es) => { + errs.extend(es); + Type::Failure + } + }; + let pt = if let Some(default_t) = default_t { + let default = match self.instantiate_typespec_full( + default_t, + opt_decl_t, + tmp_tv_cache, + mode, + not_found_is_qvar, + ) { + Ok(t) => t, + Err(es) => { + errs.extend(es); + Type::Failure + } + }; + ParamTy::kw_default(p.name.as_ref().unwrap().inspect().to_owned(), t, default) + } else { + ParamTy::pos_or_kw(p.name.as_ref().map(|t| t.inspect().to_owned()), t) + }; + if errs.is_empty() { + Ok(pt) } else { - Ok(ParamTy::pos_or_kw( - p.name.as_ref().map(|t| t.inspect().to_owned()), - t, - )) + Err((pt, errs)) } } @@ -1860,8 +2032,15 @@ impl Context { Ok(int_interval(op, l, r)) } TypeSpec::Subr(subr) => { + let mut errs = TyCheckErrors::empty(); let mut inner_tv_ctx = if !subr.bounds.is_empty() { - let tv_cache = self.instantiate_ty_bounds(&subr.bounds, mode)?; + let tv_cache = match self.instantiate_ty_bounds(&subr.bounds, mode) { + Ok(tv_cache) => tv_cache, + Err((tv_cache, es)) => { + errs.extend(es); + tv_cache + } + }; Some(tv_cache) } else { None @@ -1874,7 +2053,7 @@ impl Context { } else { tmp_tv_cache }; - let non_defaults = try_map_mut(subr.non_defaults.iter(), |p| { + let non_defaults = match failable_map_mut(subr.non_defaults.iter(), |p| { self.instantiate_func_param_spec( p, opt_decl_t, @@ -1883,7 +2062,15 @@ impl Context { mode, not_found_is_qvar, ) - })?; + }) { + Ok(v) => v, + Err((v, es)) => { + for e in es { + errs.extend(e); + } + v + } + }; let var_params = subr .var_params .as_ref() @@ -1896,9 +2083,10 @@ impl Context { mode, not_found_is_qvar, ) + .map_err(|(_, es)| es) }) .transpose()?; - let defaults = try_map_mut(subr.defaults.iter(), |p| { + let defaults = failable_map_mut(subr.defaults.iter(), |p| { self.instantiate_func_param_spec( &p.param, opt_decl_t, @@ -1907,7 +2095,13 @@ impl Context { mode, not_found_is_qvar, ) - })? + }) + .unwrap_or_else(|(pts, es)| { + for e in es { + errs.extend(e); + } + pts + }) .into_iter() .collect(); let kw_var_params = subr @@ -1922,24 +2116,35 @@ impl Context { mode, not_found_is_qvar, ) + .map_err(|(_, es)| es) }) .transpose()?; - let return_t = self.instantiate_typespec_full( + let return_t = match self.instantiate_typespec_full( &subr.return_t, opt_decl_t, tmp_tv_ctx, mode, not_found_is_qvar, - )?; + ) { + Ok(t) => t, + Err(es) => { + errs.extend(es); + Type::Failure + } + }; // no quantification at this point (in `generalize_t`) - Ok(subr_t( - SubrKind::from(subr.arrow.kind), - non_defaults, - var_params, - defaults, - kw_var_params, - return_t, - )) + if errs.is_empty() { + Ok(subr_t( + SubrKind::from(subr.arrow.kind), + non_defaults, + var_params, + defaults, + kw_var_params, + return_t, + )) + } else { + Err(errs) + } } TypeSpec::TypeApp { spec, args } => { type_feature_error!( @@ -1960,9 +2165,9 @@ impl Context { tmp_tv_cache.push_refine_var(&name, t.clone(), self); let pred = self .instantiate_pred_from_expr(&refine.pred, tmp_tv_cache) - .map_err(|err| { + .map_err(|errs| { tmp_tv_cache.remove(name.inspect()); - err + errs })?; tmp_tv_cache.remove(name.inspect()); let refine = diff --git a/crates/erg_compiler/context/mod.rs b/crates/erg_compiler/context/mod.rs index d4240d2b6..6d8ff4dcf 100644 --- a/crates/erg_compiler/context/mod.rs +++ b/crates/erg_compiler/context/mod.rs @@ -555,6 +555,15 @@ pub enum RegistrationMode { Normal, } +impl RegistrationMode { + pub const fn is_preregister(&self) -> bool { + matches!(self, Self::PreRegister) + } + pub const fn is_normal(&self) -> bool { + matches!(self, Self::Normal) + } +} + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct ContextInfo { mod_id: usize, diff --git a/crates/erg_compiler/context/register.rs b/crates/erg_compiler/context/register.rs index e30a4f8ad..5f4fb66b9 100644 --- a/crates/erg_compiler/context/register.rs +++ b/crates/erg_compiler/context/register.rs @@ -17,7 +17,7 @@ use ast::{ ConstIdentifier, Decorator, DefId, Identifier, OperationKind, PolyTypeSpec, PreDeclTypeSpec, VarName, }; -use erg_parser::ast::{self, ClassAttr, TypeSpecWithOp}; +use erg_parser::ast::{self, ClassAttr, RecordAttrOrIdent, TypeSpecWithOp}; use crate::ty::constructors::{ free_var, func, func0, func1, module, proc, py_module, ref_, ref_mut, str_dict_t, tp_enum, @@ -1059,7 +1059,7 @@ impl Context { for expr in block.iter() { match expr { ast::Expr::Def(def) => { - if let Err(errs) = self.register_const_def(def) { + if let Err(errs) = self.register_def(def) { total_errs.extend(errs); } if def.def_kind().is_import() { @@ -1069,7 +1069,7 @@ impl Context { } } ast::Expr::ClassDef(class_def) => { - if let Err(errs) = self.register_const_def(&class_def.def) { + if let Err(errs) = self.register_def(&class_def.def) { total_errs.extend(errs); } let vis = self @@ -1092,7 +1092,7 @@ impl Context { self.grow(&class.local_name(), kind, vis.clone(), None); for attr in methods.attrs.iter() { if let ClassAttr::Def(def) = attr { - if let Err(errs) = self.register_const_def(def) { + if let Err(errs) = self.register_def(def) { total_errs.extend(errs); } } @@ -1113,7 +1113,7 @@ impl Context { } } ast::Expr::PatchDef(patch_def) => { - if let Err(errs) = self.register_const_def(&patch_def.def) { + if let Err(errs) = self.register_def(&patch_def.def) { total_errs.extend(errs); } } @@ -1206,7 +1206,7 @@ impl Context { } } - pub(crate) fn register_const_def(&mut self, def: &ast::Def) -> TyCheckResult<()> { + pub(crate) fn register_def(&mut self, def: &ast::Def) -> TyCheckResult<()> { let id = Some(def.body.id); let __name__ = def.sig.ident().map(|i| i.inspect()).unwrap_or(UBAR); let call = if let Some(ast::Expr::Call(call)) = &def.body.block.first() { @@ -1217,7 +1217,9 @@ impl Context { match &def.sig { ast::Signature::Subr(sig) => { if sig.is_const() { - let tv_cache = self.instantiate_ty_bounds(&sig.bounds, PreRegister)?; + let tv_cache = self + .instantiate_ty_bounds(&sig.bounds, PreRegister) + .map_err(|(_, errs)| errs)?; let vis = self.instantiate_vis_modifier(sig.vis())?; self.grow(__name__, ContextKind::Proc, vis, Some(tv_cache)); let (obj, const_t) = match self.eval_const_block(&def.body.block) { @@ -2515,7 +2517,9 @@ impl Context { false } // TODO: - _ => false, + PreDeclTypeSpec::Subscr { namespace: ns, .. } => { + self.inc_ref_expr(ns, namespace, tmp_tv_cache) + } } } @@ -2556,6 +2560,9 @@ impl Context { for arg in poly.args.kw_args() { self.inc_ref_expr(&arg.expr.clone().downgrade(), namespace, tmp_tv_cache); } + if let Some(arg) = poly.args.kw_var.as_ref() { + self.inc_ref_expr(&arg.expr.clone().downgrade(), namespace, tmp_tv_cache); + } self.inc_ref_acc(&poly.acc.clone().downgrade(), namespace, tmp_tv_cache) } @@ -2596,6 +2603,51 @@ impl Context { res } + fn inc_ref_params( + &self, + params: &ast::Params, + namespace: &Context, + tmp_tv_cache: &TyVarCache, + ) -> bool { + let mut res = false; + for param in params.non_defaults.iter() { + if let Some(expr) = param.t_spec.as_ref().map(|ts| &ts.t_spec_as_expr) { + if self.inc_ref_expr(expr, namespace, tmp_tv_cache) { + res = true; + } + } + } + if let Some(expr) = params + .var_params + .as_ref() + .and_then(|p| p.t_spec.as_ref().map(|ts| &ts.t_spec_as_expr)) + { + if self.inc_ref_expr(expr, namespace, tmp_tv_cache) { + res = true; + } + } + for param in params.defaults.iter() { + if let Some(expr) = param.sig.t_spec.as_ref().map(|ts| &ts.t_spec_as_expr) { + if self.inc_ref_expr(expr, namespace, tmp_tv_cache) { + res = true; + } + } + if self.inc_ref_expr(¶m.default_val, namespace, tmp_tv_cache) { + res = true; + } + } + if let Some(expr) = params + .kw_var_params + .as_ref() + .and_then(|p| p.t_spec.as_ref().map(|ts| &ts.t_spec_as_expr)) + { + if self.inc_ref_expr(expr, namespace, tmp_tv_cache) { + res = true; + } + } + res + } + fn inc_ref_expr( &self, expr: &ast::Expr, @@ -2638,6 +2690,24 @@ impl Context { } res } + ast::Expr::Record(ast::Record::Mixed(rec)) => { + let mut res = false; + for val in rec.attrs.iter() { + match val { + RecordAttrOrIdent::Attr(attr) => { + if self.inc_ref_block(&attr.body.block, namespace, tmp_tv_cache) { + res = true; + } + } + RecordAttrOrIdent::Ident(ident) => { + if self.inc_ref_local(ident, namespace, tmp_tv_cache) { + res = true; + } + } + } + } + res + } ast::Expr::Array(ast::Array::Normal(arr)) => { let mut res = false; for val in arr.elems.pos_args().iter() { @@ -2705,6 +2775,29 @@ impl Context { } res } + ast::Expr::TypeAscription(ascription) => { + self.inc_ref_expr(&ascription.expr, namespace, tmp_tv_cache) + } + ast::Expr::Compound(comp) => { + let mut res = false; + for expr in comp.exprs.iter() { + if self.inc_ref_expr(expr, namespace, tmp_tv_cache) { + res = true; + } + } + res + } + ast::Expr::Lambda(lambda) => { + let mut res = false; + // FIXME: assign params + if self.inc_ref_params(&lambda.sig.params, namespace, tmp_tv_cache) { + res = true; + } + if self.inc_ref_block(&lambda.body, namespace, tmp_tv_cache) { + res = true; + } + res + } other => { log!(err "inc_ref_expr: {other}"); false diff --git a/crates/erg_compiler/lower.rs b/crates/erg_compiler/lower.rs index 9b87e5b45..96acb56b6 100644 --- a/crates/erg_compiler/lower.rs +++ b/crates/erg_compiler/lower.rs @@ -577,13 +577,10 @@ impl GenericASTLowerer { .grow("", ContextKind::Dummy, Private, None); for attr in record.attrs.iter() { if attr.sig.is_const() { - self.module - .context - .register_const_def(attr) - .map_err(|errs| { - self.pop_append_errs(); - errs - })?; + self.module.context.register_def(attr).map_err(|errs| { + self.pop_append_errs(); + errs + })?; } } for attr in record.attrs.into_iter() { @@ -1573,11 +1570,18 @@ impl GenericASTLowerer { expect: Option<&SubrType>, ) -> LowerResult { log!(info "entered {}({})", fn_name!(), params); - let mut tmp_tv_ctx = self + let mut errs = LowerErrors::empty(); + let mut tmp_tv_ctx = match self .module .context - .instantiate_ty_bounds(&bounds, RegistrationMode::Normal)?; - let mut errs = LowerErrors::empty(); + .instantiate_ty_bounds(&bounds, RegistrationMode::Normal) + { + Ok(tv_ctx) => tv_ctx, + Err((tv_ctx, es)) => { + errs.extend(es); + tv_ctx + } + }; let mut hir_non_defaults = vec![]; for non_default in params.non_defaults.into_iter() { match self.lower_non_default_param(non_default) { @@ -1681,7 +1685,8 @@ impl GenericASTLowerer { let tv_cache = self .module .context - .instantiate_ty_bounds(&lambda.sig.bounds, RegistrationMode::Normal)?; + .instantiate_ty_bounds(&lambda.sig.bounds, RegistrationMode::Normal) + .map_err(|(_, es)| es)?; if !in_statement { self.module .context @@ -1937,7 +1942,8 @@ impl GenericASTLowerer { let tv_cache = self .module .context - .instantiate_ty_bounds(&sig.bounds, RegistrationMode::Normal)?; + .instantiate_ty_bounds(&sig.bounds, RegistrationMode::Normal) + .map_err(|(_, es)| es)?; self.module.context.grow(&name, kind, vis, Some(tv_cache)); self.lower_subr_def(sig, def.body) } @@ -2394,13 +2400,10 @@ impl GenericASTLowerer { def.sig.col_begin().unwrap(), )); } - self.module - .context - .register_const_def(def) - .map_err(|errs| { - self.pop_append_errs(); - errs - })?; + self.module.context.register_def(def).map_err(|errs| { + self.pop_append_errs(); + errs + })?; } ast::ClassAttr::Decl(_) | ast::ClassAttr::Doc(_) => {} } diff --git a/crates/erg_parser/ast.rs b/crates/erg_parser/ast.rs index 56432e474..9743cbfab 100644 --- a/crates/erg_parser/ast.rs +++ b/crates/erg_parser/ast.rs @@ -2921,9 +2921,9 @@ impl ConstKwArg { #[pyclass] #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct ConstArgs { - pos_args: Vec, + pub pos_args: Vec, pub var_args: Option>, - kw_args: Vec, + pub kw_args: Vec, pub kw_var: Option>, paren: Option<(Token, Token)>, } diff --git a/tests/should_err/poly_type_spec.er b/tests/should_err/poly_type_spec.er new file mode 100644 index 000000000..8a8162ab7 --- /dev/null +++ b/tests/should_err/poly_type_spec.er @@ -0,0 +1,5 @@ +f _: Array!() = None # ERR +g _: Array!(Int, M := 1) = None # ERR +h _: Array!(Int, N := 1, N := 2) = None # ERR + +_ = f, g, h diff --git a/tests/should_ok/poly_type_spec.er b/tests/should_ok/poly_type_spec.er new file mode 100644 index 000000000..1aedeb3db --- /dev/null +++ b/tests/should_ok/poly_type_spec.er @@ -0,0 +1,7 @@ +f _: Array!(Int, _) = None +g _: Array!(Int) = None +h _: Array!(Int, N := 1) = None + +f ![1, 2] +g ![1, 2] +h ![1] diff --git a/tests/test.rs b/tests/test.rs index 0fbe1642a..bccbb4349 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -304,6 +304,11 @@ fn exec_pattern() -> Result<(), ()> { expect_success("tests/should_ok/pattern.er", 0) } +#[test] +fn exec_poly_type_spec() -> Result<(), ()> { + expect_success("tests/should_ok/poly_type_spec.er", 0) +} + #[test] fn exec_pyimport_test() -> Result<(), ()> { // HACK: When running the test with Windows, the exit code is 1 (the cause is unknown) @@ -555,6 +560,11 @@ fn exec_move_check() -> Result<(), ()> { expect_failure("examples/move_check.er", 1, 1) } +#[test] +fn exec_poly_type_spec_err() -> Result<(), ()> { + expect_failure("tests/should_err/poly_type_spec.er", 0, 3) +} + #[test] fn exec_pyimport() -> Result<(), ()> { if cfg!(unix) {