diff --git a/crates/erg_compiler/context/inquire.rs b/crates/erg_compiler/context/inquire.rs index db34b68fd..85bee96ae 100644 --- a/crates/erg_compiler/context/inquire.rs +++ b/crates/erg_compiler/context/inquire.rs @@ -372,33 +372,40 @@ impl Context { kind: SubrKind, pos_args: &[hir::PosArg], kw_args: &[hir::KwArg], - ) -> TyCheckResult { + ) -> FailableOption { + let mut errs = TyCheckErrors::empty(); if !kw_args.is_empty() { // TODO: this error desc is not good - return Err(TyCheckErrors::from(TyCheckError::default_param_error( - self.cfg.input.clone(), - line!() as usize, - kw_args[0].loc(), - self.caused_by(), - "match", - ))); + return Err(( + None, + TyCheckErrors::from(TyCheckError::default_param_error( + self.cfg.input.clone(), + line!() as usize, + kw_args[0].loc(), + self.caused_by(), + "match", + )), + )); } for pos_arg in pos_args.iter().skip(1) { let t = pos_arg.expr.ref_t(); // Allow only anonymous functions to be passed as match arguments (for aesthetic reasons) if !matches!(&pos_arg.expr, hir::Expr::Lambda(_)) { - return Err(TyCheckErrors::from(TyCheckError::type_mismatch_error( - self.cfg.input.clone(), - line!() as usize, - pos_arg.loc(), - self.caused_by(), - "match", + return Err(( None, - &mono("LambdaFunc"), - t, - self.get_candidates(t), - self.get_simple_type_mismatch_hint(&mono("LambdaFunc"), t), - ))); + TyCheckErrors::from(TyCheckError::type_mismatch_error( + self.cfg.input.clone(), + line!() as usize, + pos_arg.loc(), + self.caused_by(), + "match", + None, + &mono("LambdaFunc"), + t, + self.get_candidates(t), + self.get_simple_type_mismatch_hint(&mono("LambdaFunc"), t), + )), + )); } } let match_target_expr_t = pos_args[0].expr.ref_t(); @@ -408,36 +415,48 @@ impl Context { for (i, pos_arg) in pos_args.iter().skip(1).enumerate() { let lambda = erg_common::enum_unwrap!(&pos_arg.expr, hir::Expr::Lambda); // already checked if !lambda.params.defaults.is_empty() { - return Err(TyCheckErrors::from(TyCheckError::default_param_error( - self.cfg.input.clone(), - line!() as usize, - pos_args[i + 1].loc(), - self.caused_by(), - "match", - ))); + return Err(( + None, + TyCheckErrors::from(TyCheckError::default_param_error( + self.cfg.input.clone(), + line!() as usize, + pos_args[i + 1].loc(), + self.caused_by(), + "match", + )), + )); } if lambda.params.len() != 1 { - return Err(TyCheckErrors::from(TyCheckError::param_error( - self.cfg.input.clone(), - line!() as usize, - pos_args[i + 1].loc(), - self.caused_by(), - 1, - lambda.params.len(), - ))); + return Err(( + None, + TyCheckErrors::from(TyCheckError::param_error( + self.cfg.input.clone(), + line!() as usize, + pos_args[i + 1].loc(), + self.caused_by(), + 1, + lambda.params.len(), + )), + )); } let mut dummy_tv_cache = TyVarCache::new(self.level, self); - let rhs = self - .instantiate_param_sig_t( - &lambda.params.non_defaults[0].raw, - None, - &mut dummy_tv_cache, - Normal, - ParamKind::NonDefault, - false, - ) - // TODO: continue - .map_err(|(_, errs)| errs)?; + let rhs = match self.instantiate_param_sig_t( + &lambda.params.non_defaults[0].raw, + None, + &mut dummy_tv_cache, + Normal, + ParamKind::NonDefault, + false, + ) { + Ok(ty) => ty, + Err((ty, es)) => { + errs.extend(es); + ty + } + }; + if lambda.params.non_defaults[0].raw.t_spec.is_none() && rhs.is_free_var() { + rhs.link(&Obj, None); + } union_pat_t = self.union(&union_pat_t, &rhs); arm_ts.push(rhs); } @@ -447,7 +466,7 @@ impl Context { if cfg!(feature = "debug") { eprintln!("match error: {err}"); } - return Err(TyCheckErrors::from(TyCheckError::match_error( + errs.push(TyCheckError::match_error( self.cfg.input.clone(), line!() as usize, pos_args[0].loc(), @@ -455,7 +474,8 @@ impl Context { match_target_expr_t, &union_pat_t, arm_ts, - ))); + )); + return Err((None, errs)); } let branch_ts = pos_args .iter() @@ -466,17 +486,17 @@ impl Context { .get(0) .and_then(|branch| branch.typ().return_t().cloned()) else { - return Err(TyCheckErrors::from(TyCheckError::args_missing_error( + errs.push(TyCheckError::args_missing_error( self.cfg.input.clone(), line!() as usize, pos_args[0].loc(), "match", self.caused_by(), vec![Str::ever("obj")], - ))); + )); + return Err((None, errs)); }; for arg_t in branch_ts.iter().skip(1) { - // TODO: handle unwrap errors return_t = self.union(&return_t, arg_t.typ().return_t().unwrap_or(&Type::Never)); } let param_ty = ParamTy::Pos(match_target_expr_t.clone()); @@ -486,10 +506,15 @@ impl Context { } else { proc(param_ts, None, vec![], return_t) }; - Ok(VarInfo { + let vi = VarInfo { t, ..VarInfo::default() - }) + }; + if errs.is_empty() { + Ok(vi) + } else { + Err((Some(vi), errs)) + } } pub(crate) fn rec_get_var_info( @@ -2149,14 +2174,10 @@ impl Context { if local.vis().is_private() { match &local.inspect()[..] { "match" => { - return self - .get_match_call_t(SubrKind::Func, pos_args, kw_args) - .map_err(|errs| (None, errs)); + return self.get_match_call_t(SubrKind::Func, pos_args, kw_args); } "match!" => { - return self - .get_match_call_t(SubrKind::Proc, pos_args, kw_args) - .map_err(|errs| (None, errs)); + return self.get_match_call_t(SubrKind::Proc, pos_args, kw_args); } _ => {} } diff --git a/tests/should_ok/match.er b/tests/should_ok/match.er new file mode 100644 index 000000000..23e6f6959 --- /dev/null +++ b/tests/should_ok/match.er @@ -0,0 +1,9 @@ +f x: Obj = + match x: + (s: Str) -> s + "a" + { foo; bar } -> foo + bar + a -> a + +assert f("a") == "aa" +assert f({ foo = "a"; bar = "b" }) == "ab" +assert str(f(1)) == "1" diff --git a/tests/test.rs b/tests/test.rs index a136dc7d0..08e35b96e 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -211,6 +211,11 @@ fn exec_map() -> Result<(), ()> { expect_success("tests/should_ok/map.er", 0) } +#[test] +fn exec_match() -> Result<(), ()> { + expect_success("tests/should_ok/match.er", 0) +} + #[test] fn exec_method() -> Result<(), ()> { expect_success("tests/should_ok/method.er", 0)