Skip to content

Commit

Permalink
fix: match type check
Browse files Browse the repository at this point in the history
  • Loading branch information
mtshiba committed Oct 12, 2023
1 parent 3f4520d commit e6b56d4
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 58 deletions.
137 changes: 79 additions & 58 deletions crates/erg_compiler/context/inquire.rs
Original file line number Diff line number Diff line change
Expand Up @@ -372,33 +372,40 @@ impl Context {
kind: SubrKind,
pos_args: &[hir::PosArg],
kw_args: &[hir::KwArg],
) -> TyCheckResult<VarInfo> {
) -> FailableOption<VarInfo> {
let mut errs = TyCheckErrors::empty();
if !kw_args.is_empty() {
// TODO: this error desc is not good
return Err(TyCheckErrors::from(TyCheckError::default_param_error(
self.cfg.input.clone(),
line!() as usize,
kw_args[0].loc(),
self.caused_by(),
"match",
)));
return Err((
None,
TyCheckErrors::from(TyCheckError::default_param_error(
self.cfg.input.clone(),
line!() as usize,
kw_args[0].loc(),
self.caused_by(),
"match",
)),
));
}
for pos_arg in pos_args.iter().skip(1) {
let t = pos_arg.expr.ref_t();
// Allow only anonymous functions to be passed as match arguments (for aesthetic reasons)
if !matches!(&pos_arg.expr, hir::Expr::Lambda(_)) {
return Err(TyCheckErrors::from(TyCheckError::type_mismatch_error(
self.cfg.input.clone(),
line!() as usize,
pos_arg.loc(),
self.caused_by(),
"match",
return Err((
None,
&mono("LambdaFunc"),
t,
self.get_candidates(t),
self.get_simple_type_mismatch_hint(&mono("LambdaFunc"), t),
)));
TyCheckErrors::from(TyCheckError::type_mismatch_error(
self.cfg.input.clone(),
line!() as usize,
pos_arg.loc(),
self.caused_by(),
"match",
None,
&mono("LambdaFunc"),
t,
self.get_candidates(t),
self.get_simple_type_mismatch_hint(&mono("LambdaFunc"), t),
)),
));
}
}
let match_target_expr_t = pos_args[0].expr.ref_t();
Expand All @@ -408,36 +415,48 @@ impl Context {
for (i, pos_arg) in pos_args.iter().skip(1).enumerate() {
let lambda = erg_common::enum_unwrap!(&pos_arg.expr, hir::Expr::Lambda); // already checked
if !lambda.params.defaults.is_empty() {
return Err(TyCheckErrors::from(TyCheckError::default_param_error(
self.cfg.input.clone(),
line!() as usize,
pos_args[i + 1].loc(),
self.caused_by(),
"match",
)));
return Err((
None,
TyCheckErrors::from(TyCheckError::default_param_error(
self.cfg.input.clone(),
line!() as usize,
pos_args[i + 1].loc(),
self.caused_by(),
"match",
)),
));
}
if lambda.params.len() != 1 {
return Err(TyCheckErrors::from(TyCheckError::param_error(
self.cfg.input.clone(),
line!() as usize,
pos_args[i + 1].loc(),
self.caused_by(),
1,
lambda.params.len(),
)));
return Err((
None,
TyCheckErrors::from(TyCheckError::param_error(
self.cfg.input.clone(),
line!() as usize,
pos_args[i + 1].loc(),
self.caused_by(),
1,
lambda.params.len(),
)),
));
}
let mut dummy_tv_cache = TyVarCache::new(self.level, self);
let rhs = self
.instantiate_param_sig_t(
&lambda.params.non_defaults[0].raw,
None,
&mut dummy_tv_cache,
Normal,
ParamKind::NonDefault,
false,
)
// TODO: continue
.map_err(|(_, errs)| errs)?;
let rhs = match self.instantiate_param_sig_t(
&lambda.params.non_defaults[0].raw,
None,
&mut dummy_tv_cache,
Normal,
ParamKind::NonDefault,
false,
) {
Ok(ty) => ty,
Err((ty, es)) => {
errs.extend(es);
ty
}
};
if lambda.params.non_defaults[0].raw.t_spec.is_none() && rhs.is_free_var() {
rhs.link(&Obj, None);
}
union_pat_t = self.union(&union_pat_t, &rhs);
arm_ts.push(rhs);
}
Expand All @@ -447,15 +466,16 @@ impl Context {
if cfg!(feature = "debug") {
eprintln!("match error: {err}");
}
return Err(TyCheckErrors::from(TyCheckError::match_error(
errs.push(TyCheckError::match_error(
self.cfg.input.clone(),
line!() as usize,
pos_args[0].loc(),
self.caused_by(),
match_target_expr_t,
&union_pat_t,
arm_ts,
)));
));
return Err((None, errs));
}
let branch_ts = pos_args
.iter()
Expand All @@ -466,17 +486,17 @@ impl Context {
.get(0)
.and_then(|branch| branch.typ().return_t().cloned())
else {
return Err(TyCheckErrors::from(TyCheckError::args_missing_error(
errs.push(TyCheckError::args_missing_error(
self.cfg.input.clone(),
line!() as usize,
pos_args[0].loc(),
"match",
self.caused_by(),
vec![Str::ever("obj")],
)));
));
return Err((None, errs));
};
for arg_t in branch_ts.iter().skip(1) {
// TODO: handle unwrap errors
return_t = self.union(&return_t, arg_t.typ().return_t().unwrap_or(&Type::Never));
}
let param_ty = ParamTy::Pos(match_target_expr_t.clone());
Expand All @@ -486,10 +506,15 @@ impl Context {
} else {
proc(param_ts, None, vec![], return_t)
};
Ok(VarInfo {
let vi = VarInfo {
t,
..VarInfo::default()
})
};
if errs.is_empty() {
Ok(vi)
} else {
Err((Some(vi), errs))
}
}

pub(crate) fn rec_get_var_info(
Expand Down Expand Up @@ -2149,14 +2174,10 @@ impl Context {
if local.vis().is_private() {
match &local.inspect()[..] {
"match" => {
return self
.get_match_call_t(SubrKind::Func, pos_args, kw_args)
.map_err(|errs| (None, errs));
return self.get_match_call_t(SubrKind::Func, pos_args, kw_args);
}
"match!" => {
return self
.get_match_call_t(SubrKind::Proc, pos_args, kw_args)
.map_err(|errs| (None, errs));
return self.get_match_call_t(SubrKind::Proc, pos_args, kw_args);
}
_ => {}
}
Expand Down
9 changes: 9 additions & 0 deletions tests/should_ok/match.er
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
f x: Obj =
match x:
(s: Str) -> s + "a"
{ foo; bar } -> foo + bar
a -> a

assert f("a") == "aa"
assert f({ foo = "a"; bar = "b" }) == "ab"
assert str(f(1)) == "1"
5 changes: 5 additions & 0 deletions tests/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,11 @@ fn exec_map() -> Result<(), ()> {
expect_success("tests/should_ok/map.er", 0)
}

#[test]
fn exec_match() -> Result<(), ()> {
expect_success("tests/should_ok/match.er", 0)
}

#[test]
fn exec_method() -> Result<(), ()> {
expect_success("tests/should_ok/method.er", 0)
Expand Down

0 comments on commit e6b56d4

Please sign in to comment.