From fae18d3c153491b4a564a0bc741f700a23e37646 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Fri, 16 Feb 2024 17:00:04 +0900 Subject: [PATCH] Update inquire.rs --- crates/erg_compiler/context/inquire.rs | 374 +++++++++++++------------ 1 file changed, 190 insertions(+), 184 deletions(-) diff --git a/crates/erg_compiler/context/inquire.rs b/crates/erg_compiler/context/inquire.rs index 416d8b9f4..096bd5674 100644 --- a/crates/erg_compiler/context/inquire.rs +++ b/crates/erg_compiler/context/inquire.rs @@ -1751,203 +1751,209 @@ impl Context { // instance must be instantiated Type::Quantified(_) => unreachable_error!(TyCheckErrors, TyCheckError, self), Type::Subr(subr) => { - let mut errs = TyCheckErrors::empty(); - // method: obj: 1, subr: (self: Int, other: Int) -> Int - // non-method: obj: Int, subr: (self: Int, other: Int) -> Int - // FIXME: staticmethod - let is_method = subr - .self_t() - .map_or(false, |self_t| self.subtype_of(obj.ref_t(), self_t)); - let callee = if let Some(ident) = attr_name { - if is_method { - obj.clone() - } else { - let attr = - hir::Attribute::new(obj.clone(), hir::Identifier::bare(ident.clone())); - hir::Expr::Accessor(hir::Accessor::Attr(attr)) - } - } else { - obj.clone() - }; - let params_len = if is_method { - subr.non_default_params.len().saturating_sub(1) + subr.default_params.len() + let res = self.substitute_subr_call(obj, attr_name, subr, pos_args, kw_args); + // TODO: change polymorphic type syntax + if res.is_err() && subr.return_t.is_class_type() { + self.substitute_dunder_call( + obj, attr_name, instance, pos_args, kw_args, namespace, + ) + .or(res) } else { - subr.non_default_params.len() + subr.default_params.len() - }; - if (params_len < pos_args.len() || params_len < pos_args.len() + kw_args.len()) - && subr.is_no_var() - { - return Err( - self.gen_too_many_args_error(&callee, subr, is_method, pos_args, kw_args) - ); + res } - let mut passed_params = set! {}; - let non_default_params = if is_method { - let mut non_default_params = subr.non_default_params.iter(); - let self_pt = non_default_params.next().unwrap(); + } + Type::Failure => Ok(SubstituteResult::Ok), + _ => { + self.substitute_dunder_call(obj, attr_name, instance, pos_args, kw_args, namespace) + } + } + } + + fn substitute_subr_call( + &self, + obj: &hir::Expr, + attr_name: &Option, + subr: &SubrType, + pos_args: &[hir::PosArg], + kw_args: &[hir::KwArg], + ) -> TyCheckResult { + let mut errs = TyCheckErrors::empty(); + // method: obj: 1, subr: (self: Int, other: Int) -> Int + // non-method: obj: Int, subr: (self: Int, other: Int) -> Int + // FIXME: staticmethod + let is_method = subr + .self_t() + .map_or(false, |self_t| self.subtype_of(obj.ref_t(), self_t)); + let callee = if let Some(ident) = attr_name { + if is_method { + obj.clone() + } else { + let attr = hir::Attribute::new(obj.clone(), hir::Identifier::bare(ident.clone())); + hir::Expr::Accessor(hir::Accessor::Attr(attr)) + } + } else { + obj.clone() + }; + let params_len = if is_method { + subr.non_default_params.len().saturating_sub(1) + subr.default_params.len() + } else { + subr.non_default_params.len() + subr.default_params.len() + }; + if (params_len < pos_args.len() || params_len < pos_args.len() + kw_args.len()) + && subr.is_no_var() + { + return Err(self.gen_too_many_args_error(&callee, subr, is_method, pos_args, kw_args)); + } + let mut passed_params = set! {}; + let non_default_params = if is_method { + let mut non_default_params = subr.non_default_params.iter(); + let self_pt = non_default_params.next().unwrap(); + if let Err(mut es) = self.sub_unify(obj.ref_t(), self_pt.typ(), obj, self_pt.name()) { + errs.append(&mut es); + } + passed_params.insert("self".into()); + non_default_params + } else { + subr.non_default_params.iter() + }; + let non_default_params_len = non_default_params.len(); + if pos_args.len() >= non_default_params_len { + let (non_default_args, var_args) = pos_args.split_at(non_default_params_len); + let mut args = non_default_args + .iter() + .zip(non_default_params) + .enumerate() + .collect::>(); + // TODO: remove `obj.local_name() != Some("__contains__")` + if obj.local_name() != Some("__contains__") && !subr.essential_qnames().is_empty() { + args.sort_by(|(_, (l, _)), (_, (r, _))| { + l.expr.complexity().cmp(&r.expr.complexity()) + }); + } + for (i, (nd_arg, nd_param)) in args { + if let Err(mut es) = self.substitute_pos_arg( + &callee, + attr_name, + &nd_arg.expr, + i + 1, + nd_param, + &mut passed_params, + ) { + errs.append(&mut es); + } + } + let mut nth = 1 + non_default_params_len; + if let Some(var_param) = subr.var_params.as_ref() { + for var_arg in var_args.iter() { if let Err(mut es) = - self.sub_unify(obj.ref_t(), self_pt.typ(), obj, self_pt.name()) + self.substitute_var_arg(&callee, attr_name, &var_arg.expr, nth, var_param) { errs.append(&mut es); } - passed_params.insert("self".into()); - non_default_params - } else { - subr.non_default_params.iter() - }; - let non_default_params_len = non_default_params.len(); - if pos_args.len() >= non_default_params_len { - let (non_default_args, var_args) = pos_args.split_at(non_default_params_len); - let mut args = non_default_args - .iter() - .zip(non_default_params) - .enumerate() - .collect::>(); - // TODO: remove `obj.local_name() != Some("__contains__")` - if obj.local_name() != Some("__contains__") - && !subr.essential_qnames().is_empty() - { - args.sort_by(|(_, (l, _)), (_, (r, _))| { - l.expr.complexity().cmp(&r.expr.complexity()) - }); - } - for (i, (nd_arg, nd_param)) in args { - if let Err(mut es) = self.substitute_pos_arg( - &callee, - attr_name, - &nd_arg.expr, - i + 1, - nd_param, - &mut passed_params, - ) { - errs.append(&mut es); - } - } - let mut nth = 1 + non_default_params_len; - if let Some(var_param) = subr.var_params.as_ref() { - for var_arg in var_args.iter() { - if let Err(mut es) = self.substitute_var_arg( - &callee, - attr_name, - &var_arg.expr, - nth, - var_param, - ) { - errs.append(&mut es); - } - nth += 1; - } - } else { - for (arg, pt) in var_args.iter().zip(subr.default_params.iter()) { - if let Err(mut es) = self.substitute_pos_arg( - &callee, - attr_name, - &arg.expr, - nth, - pt, - &mut passed_params, - ) { - errs.append(&mut es); - } - nth += 1; - } - } - for kw_arg in kw_args.iter() { - if let Err(mut es) = self.substitute_kw_arg( - &callee, - attr_name, - kw_arg, - nth, - subr, - &mut passed_params, - ) { - errs.append(&mut es); - } - nth += 1; - } - for not_passed in subr - .default_params - .iter() - .filter(|pt| !passed_params.contains(pt.name().unwrap())) - { - if let ParamTy::KwWithDefault { ty, default, .. } = ¬_passed { - if let Err(mut es) = self.sub_unify(default, ty, obj, not_passed.name()) - { - errs.append(&mut es); - } - } - } - } else { - let mut nth = 1; - // pos_args.len() < non_default_params_len - // don't use `zip` - let mut params = non_default_params.chain(subr.default_params.iter()); - for pos_arg in pos_args.iter() { - if let Err(mut es) = self.substitute_pos_arg( - &callee, - attr_name, - &pos_arg.expr, - nth, - params.next().unwrap(), - &mut passed_params, - ) { - errs.append(&mut es); - } - nth += 1; - } - for kw_arg in kw_args.iter() { - if let Err(mut es) = self.substitute_kw_arg( - &callee, - attr_name, - kw_arg, - nth, - subr, - &mut passed_params, - ) { - errs.append(&mut es); - } - nth += 1; + nth += 1; + } + } else { + for (arg, pt) in var_args.iter().zip(subr.default_params.iter()) { + if let Err(mut es) = self.substitute_pos_arg( + &callee, + attr_name, + &arg.expr, + nth, + pt, + &mut passed_params, + ) { + errs.append(&mut es); } - let missing_params = subr - .non_default_params - .iter() - .enumerate() - .filter(|(_, pt)| { - pt.name().map_or(true, |name| !passed_params.contains(name)) - }) - .map(|(i, pt)| { - let n = if is_method { i } else { i + 1 }; - let nth = format!("({} param)", ordinal_num(n)); - pt.name() - .map_or(nth.clone(), |name| format!("{name} {nth}")) - .into() - }) - .collect::>(); - if !missing_params.is_empty() { - return Err(TyCheckErrors::from(TyCheckError::args_missing_error( - self.cfg.input.clone(), - line!() as usize, - callee.loc(), - &callee.to_string(), - self.caused_by(), - missing_params, - ))); + nth += 1; + } + } + for kw_arg in kw_args.iter() { + if let Err(mut es) = self.substitute_kw_arg( + &callee, + attr_name, + kw_arg, + nth, + subr, + &mut passed_params, + ) { + errs.append(&mut es); + } + nth += 1; + } + for not_passed in subr + .default_params + .iter() + .filter(|pt| !passed_params.contains(pt.name().unwrap())) + { + if let ParamTy::KwWithDefault { ty, default, .. } = ¬_passed { + if let Err(mut es) = self.sub_unify(default, ty, obj, not_passed.name()) { + errs.append(&mut es); } } - if errs.is_empty() { - /*if subr.has_qvar() { - panic!("{subr} has qvar"); - }*/ - Ok(SubstituteResult::Ok) - } else { - Err(errs) + } + } else { + let mut nth = 1; + // pos_args.len() < non_default_params_len + // don't use `zip` + let mut params = non_default_params.chain(subr.default_params.iter()); + for pos_arg in pos_args.iter() { + if let Err(mut es) = self.substitute_pos_arg( + &callee, + attr_name, + &pos_arg.expr, + nth, + params.next().unwrap(), + &mut passed_params, + ) { + errs.append(&mut es); } + nth += 1; } - Type::Failure => Ok(SubstituteResult::Ok), - _ => { - self.substitute_dunder_call(obj, attr_name, instance, pos_args, kw_args, namespace) + for kw_arg in kw_args.iter() { + if let Err(mut es) = self.substitute_kw_arg( + &callee, + attr_name, + kw_arg, + nth, + subr, + &mut passed_params, + ) { + errs.append(&mut es); + } + nth += 1; + } + let missing_params = subr + .non_default_params + .iter() + .enumerate() + .filter(|(_, pt)| pt.name().map_or(true, |name| !passed_params.contains(name))) + .map(|(i, pt)| { + let n = if is_method { i } else { i + 1 }; + let nth = format!("({} param)", ordinal_num(n)); + pt.name() + .map_or(nth.clone(), |name| format!("{name} {nth}")) + .into() + }) + .collect::>(); + if !missing_params.is_empty() { + return Err(TyCheckErrors::from(TyCheckError::args_missing_error( + self.cfg.input.clone(), + line!() as usize, + callee.loc(), + &callee.to_string(), + self.caused_by(), + missing_params, + ))); } } + if errs.is_empty() { + /*if subr.has_qvar() { + panic!("{subr} has qvar"); + }*/ + Ok(SubstituteResult::Ok) + } else { + Err(errs) + } } fn substitute_dunder_call(