Skip to content

Commit

Permalink
fix: forward-referenced method inference bug
Browse files Browse the repository at this point in the history
  • Loading branch information
mtshiba committed Oct 29, 2023
1 parent 34a20e7 commit 6713ffe
Show file tree
Hide file tree
Showing 8 changed files with 194 additions and 69 deletions.
15 changes: 15 additions & 0 deletions crates/erg_compiler/context/inquire.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,21 @@ impl Context {
.map(|(opt_name, vi)| (opt_name.as_ref().unwrap(), vi))
}

pub fn get_method_kv(&self, name: &str) -> Option<(&VarName, &VarInfo)> {
#[cfg(feature = "py_compat")]
let name = self.erg_to_py_names.get(name).map_or(name, |s| &s[..]);
self.get_var_kv(name)
.or_else(|| {
for methods in self.methods_list.iter() {
if let Some(vi) = methods.get_method_kv(name) {
return Some(vi);
}
}
None
})
.or_else(|| self.get_outer().and_then(|ctx| ctx.get_method_kv(name)))
}

pub fn get_singular_ctxs_by_hir_expr(
&self,
obj: &hir::Expr,
Expand Down
148 changes: 99 additions & 49 deletions crates/erg_compiler/context/register.rs
Original file line number Diff line number Diff line change
Expand Up @@ -741,13 +741,99 @@ impl Context {
}
}

fn unify_params_t(
&self,
sig: &ast::SubrSignature,
registered_t: &SubrType,
params: &hir::Params,
body_t: &Type,
body_loc: &impl Locational,
) -> TyCheckResult<()> {
let name = &sig.ident.name;
let mut errs = TyCheckErrors::empty();
for (param, pt) in params
.non_defaults
.iter()
.zip(registered_t.non_default_params.iter())
{
pt.typ().lower();
if let Err(es) = self.force_sub_unify(&param.vi.t, pt.typ(), param, None) {
errs.extend(es);
}
pt.typ().lift();
}
// TODO: var_params: [Int; _], pt: Int
/*if let Some((var_params, pt)) = params.var_params.as_deref().zip(registered_t.var_params.as_ref()) {
pt.typ().lower();
if let Err(es) = self.force_sub_unify(&var_params.vi.t, pt.typ(), var_params, None) {
errs.extend(es);
}
pt.typ().lift();
}*/
for (param, pt) in params
.defaults
.iter()
.zip(registered_t.default_params.iter())
{
pt.typ().lower();
if let Err(es) = self.force_sub_unify(&param.sig.vi.t, pt.typ(), param, None) {
errs.extend(es);
}
pt.typ().lift();
}
let spec_ret_t = registered_t.return_t.as_ref();
// spec_ret_t.lower();
let unify_return_result = if let Some(t_spec) = sig.return_t_spec.as_ref() {
self.force_sub_unify(body_t, spec_ret_t, t_spec, None)
} else {
self.force_sub_unify(body_t, spec_ret_t, body_loc, None)
};
// spec_ret_t.lift();
if let Err(unify_errs) = unify_return_result {
let es = TyCheckErrors::new(
unify_errs
.into_iter()
.map(|e| {
let expect = if cfg!(feature = "debug") {
spec_ret_t.clone()
} else {
self.readable_type(spec_ret_t.clone())
};
let found = if cfg!(feature = "debug") {
body_t.clone()
} else {
self.readable_type(body_t.clone())
};
TyCheckError::return_type_error(
self.cfg.input.clone(),
line!() as usize,
e.core.get_loc_with_fallback(),
e.caused_by,
readable_name(name.inspect()),
&expect,
&found,
// e.core.get_hint().map(|s| s.to_string()),
)
})
.collect(),
);
errs.extend(es);
}
if errs.is_empty() {
Ok(())
} else {
Err(errs)
}
}

/// ## Errors
/// * TypeError: if `return_t` != typeof `body`
/// * AssignError: if `name` has already been registered
pub(crate) fn assign_subr(
&mut self,
sig: &ast::SubrSignature,
id: DefId,
params: &hir::Params,
body_t: &Type,
body_loc: &impl Locational,
) -> Result<VarInfo, (TyCheckErrors, VarInfo)> {
Expand All @@ -772,63 +858,27 @@ impl Context {
};
let name = &sig.ident.name;
// FIXME: constでない関数
let t = self.get_current_scope_var(name).map(|vi| &vi.t).unwrap();
debug_assert!(t.is_subr(), "{t} is not subr");
let empty = vec![];
let non_default_params = t.non_default_params().unwrap_or(&empty);
let var_args = t.var_params();
let default_params = t.default_params().unwrap_or(&empty);
if let Some(spec_ret_t) = t.return_t() {
let unify_result = if let Some(t_spec) = sig.return_t_spec.as_ref() {
self.sub_unify(body_t, spec_ret_t, t_spec, None)
} else {
self.sub_unify(body_t, spec_ret_t, body_loc, None)
};
if let Err(unify_errs) = unify_result {
let es = TyCheckErrors::new(
unify_errs
.into_iter()
.map(|e| {
let expect = if cfg!(feature = "debug") {
spec_ret_t.clone()
} else {
self.readable_type(spec_ret_t.clone())
};
let found = if cfg!(feature = "debug") {
body_t.clone()
} else {
self.readable_type(body_t.clone())
};
TyCheckError::return_type_error(
self.cfg.input.clone(),
line!() as usize,
e.core.get_loc_with_fallback(),
e.caused_by,
readable_name(name.inspect()),
&expect,
&found,
// e.core.get_hint().map(|s| s.to_string()),
)
})
.collect(),
);
errs.extend(es);
}
let subr_t = self.get_current_scope_var(name).map(|vi| &vi.t).unwrap();
let Ok(subr_t) = <&SubrType>::try_from(subr_t) else {
panic!("{subr_t} is not subr");
};
if let Err(es) = self.unify_params_t(sig, subr_t, params, body_t, body_loc) {
errs.extend(es);
}
// NOTE: not `body_t.clone()` because the body may contain `return`
let return_t = t.return_t().unwrap().clone();
let return_t = subr_t.return_t.as_ref().clone();
let sub_t = if sig.ident.is_procedural() {
proc(
non_default_params.clone(),
var_args.cloned(),
default_params.clone(),
subr_t.non_default_params.clone(),
subr_t.var_params.as_deref().cloned(),
subr_t.default_params.clone(),
return_t,
)
} else {
func(
non_default_params.clone(),
var_args.cloned(),
default_params.clone(),
subr_t.non_default_params.clone(),
subr_t.var_params.as_deref().cloned(),
subr_t.default_params.clone(),
return_t,
)
};
Expand Down
16 changes: 16 additions & 0 deletions crates/erg_compiler/context/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,22 @@ impl Context {
}
}

pub fn assert_attr_type(&self, receiver_t: &Type, attr: &str, ty: &Type) -> Result<(), ()> {
let Some(ctx) = self.get_nominal_type_ctx(receiver_t) else {
panic!("type not found: {receiver_t}");
};
let Some((_, vi)) = ctx.get_method_kv(attr) else {
panic!("attribute not found: {attr}");
};
println!("{attr}: {}", vi.t);
if vi.t.structural_eq(ty) {
Ok(())
} else {
println!("{attr} is not the type of {ty}");
Err(())
}
}

pub fn test_refinement_subtyping(&self) -> Result<(), ()> {
// Nat :> {I: Int | I >= 1} ?
let lhs = Nat;
Expand Down
54 changes: 39 additions & 15 deletions crates/erg_compiler/context/unify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ pub struct Unifier<'c, 'l, 'u, L: Locational> {
ctx: &'c Context,
loc: &'l L,
undoable: Option<&'u UndoableLinkedList>,
change_generalized: bool,
param_name: Option<Str>,
}

Expand All @@ -39,12 +40,14 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> {
ctx: &'c Context,
loc: &'l L,
undoable: Option<&'u UndoableLinkedList>,
change_generalized: bool,
param_name: Option<Str>,
) -> Self {
Self {
ctx,
loc,
undoable,
change_generalized,
param_name,
}
}
Expand Down Expand Up @@ -326,7 +329,11 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> {
}
Ok(())
}
(TyParam::FreeVar(sub_fv), _) if sub_fv.is_generalized() => Ok(()),
(TyParam::FreeVar(sub_fv), _)
if !self.change_generalized && sub_fv.is_generalized() =>
{
Ok(())
}
(TyParam::FreeVar(sub_fv), sup_tp) => {
match &*sub_fv.borrow() {
FreeKind::Linked(l) | FreeKind::UndoableLinked { t: l, .. } => {
Expand Down Expand Up @@ -366,7 +373,11 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> {
)))
}
}
(_, TyParam::FreeVar(sup_fv)) if sup_fv.is_generalized() => Ok(()),
(_, TyParam::FreeVar(sup_fv))
if !self.change_generalized && sup_fv.is_generalized() =>
{
Ok(())
}
(sub_tp, TyParam::FreeVar(sup_fv)) => {
match &*sup_fv.borrow() {
FreeKind::Linked(l) | FreeKind::UndoableLinked { t: l, .. } => {
Expand Down Expand Up @@ -760,7 +771,8 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> {
(FreeVar(sub_fv), FreeVar(sup_fv))
if sub_fv.constraint_is_sandwiched() && sup_fv.constraint_is_sandwiched() =>
{
if sub_fv.is_generalized() || sup_fv.is_generalized() {
if !self.change_generalized && (sub_fv.is_generalized() || sup_fv.is_generalized())
{
log!(info "generalized:\nmaybe_sub: {maybe_sub}\nmaybe_sup: {maybe_sup}");
return Ok(());
}
Expand Down Expand Up @@ -860,7 +872,7 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> {
},
FreeVar(sup_fv),
) if sup_fv.constraint_is_sandwiched() => {
if sup_fv.is_generalized() {
if !self.change_generalized && sup_fv.is_generalized() {
log!(info "generalized:\nmaybe_sub: {maybe_sub}\nmaybe_sup: {maybe_sup}");
return Ok(());
}
Expand Down Expand Up @@ -958,7 +970,7 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> {
// e.g. Structural({ .method = (self: T) -> Int })/T
(Structural(sub), FreeVar(sup_fv))
if sup_fv.is_unbound() && sub.contains_tvar(sup_fv) => {}
(_, FreeVar(sup_fv)) if sup_fv.is_generalized() => {}
(_, FreeVar(sup_fv)) if !self.change_generalized && sup_fv.is_generalized() => {}
(_, FreeVar(sup_fv)) if sup_fv.is_unbound() => {
// * sub_unify(Nat, ?E(<: Eq(?E)))
// sub !<: l => OK (sub will widen)
Expand Down Expand Up @@ -1037,7 +1049,7 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> {
(FreeVar(sub_fv), Ref(sup)) if sub_fv.is_unbound() => {
self.sub_unify(maybe_sub, sup)?;
}
(FreeVar(sub_fv), _) if sub_fv.is_generalized() => {}
(FreeVar(sub_fv), _) if !self.change_generalized && sub_fv.is_generalized() => {}
(FreeVar(sub_fv), _) if sub_fv.is_unbound() => {
// sub !<: r => Error
// * sub_unify(?T(:> Int, <: _), Nat): (/* Error */)
Expand Down Expand Up @@ -1165,7 +1177,7 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> {
.iter()
.zip(sup_subr.non_default_params.iter())
.try_for_each(|(sub, sup)| {
if sub.typ().is_generalized() {
if !self.change_generalized && sub.typ().is_generalized() {
Ok(())
}
// contravariant
Expand All @@ -1179,7 +1191,7 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> {
.iter()
.find(|sub_pt| sub_pt.name() == sup_pt.name())
{
if sup_pt.typ().is_generalized() {
if !self.change_generalized && sup_pt.typ().is_generalized() {
continue;
}
// contravariant
Expand All @@ -1203,7 +1215,7 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> {
.zip(sup_subr.non_default_params.iter())
.try_for_each(|(sub, sup)| {
// contravariant
if sup.typ().is_generalized() {
if !self.change_generalized && sup.typ().is_generalized() {
Ok(())
} else {
self.sub_unify(sup.typ(), sub.typ())
Expand All @@ -1216,7 +1228,7 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> {
.find(|sub_pt| sub_pt.name() == sup_pt.name())
{
// contravariant
if sup_pt.typ().is_generalized() {
if !self.change_generalized && sup_pt.typ().is_generalized() {
continue;
}
self.sub_unify(sup_pt.typ(), sub_pt.typ())?;
Expand Down Expand Up @@ -1555,7 +1567,7 @@ impl Context {
maybe_sup: &Type,
loc: &impl Locational,
) -> TyCheckResult<()> {
let unifier = Unifier::new(self, loc, None, None);
let unifier = Unifier::new(self, loc, None, false, None);
unifier.occur(maybe_sub, maybe_sup)
}

Expand All @@ -1567,7 +1579,7 @@ impl Context {
loc: &impl Locational,
is_structural: bool,
) -> TyCheckResult<()> {
let unifier = Unifier::new(self, loc, None, None);
let unifier = Unifier::new(self, loc, None, false, None);
unifier.sub_unify_tp(maybe_sub, maybe_sup, variance, is_structural)
}

Expand All @@ -1579,7 +1591,19 @@ impl Context {
loc: &impl Locational,
param_name: Option<&Str>,
) -> TyCheckResult<()> {
let unifier = Unifier::new(self, loc, None, param_name.cloned());
let unifier = Unifier::new(self, loc, None, false, param_name.cloned());
unifier.sub_unify(maybe_sub, maybe_sup)
}

/// This will rewrite generalized type variables.
pub(crate) fn force_sub_unify(
&self,
maybe_sub: &Type,
maybe_sup: &Type,
loc: &impl Locational,
param_name: Option<&Str>,
) -> TyCheckResult<()> {
let unifier = Unifier::new(self, loc, None, true, param_name.cloned());
unifier.sub_unify(maybe_sub, maybe_sup)
}

Expand All @@ -1591,12 +1615,12 @@ impl Context {
list: &UndoableLinkedList,
param_name: Option<&Str>,
) -> TyCheckResult<()> {
let unifier = Unifier::new(self, loc, Some(list), param_name.cloned());
let unifier = Unifier::new(self, loc, Some(list), false, param_name.cloned());
unifier.sub_unify(maybe_sub, maybe_sup)
}

pub(crate) fn unify(&self, lhs: &Type, rhs: &Type) -> Option<Type> {
let unifier = Unifier::new(self, &(), None, None);
let unifier = Unifier::new(self, &(), None, false, None);
unifier.unify(lhs, rhs)
}
}
Loading

0 comments on commit 6713ffe

Please sign in to comment.