Skip to content

Commit

Permalink
fix: closure bug
Browse files Browse the repository at this point in the history
  • Loading branch information
mtshiba committed Dec 28, 2023
1 parent a128719 commit b5f5876
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 21 deletions.
22 changes: 11 additions & 11 deletions crates/erg_compiler/context/generalize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,7 @@ impl<'c, 'q, 'l, L: Locational> Dereferencer<'c, 'q, 'l, L> {
match res {
Ok(ty) => {
// TODO: T(:> Nat <: Int) -> T(:> Nat, <: Int) ==> Int -> Nat
// fv.link(&ty);
// Type::FreeVar(fv).destructive_link(&ty);
Ok(ty)
}
Err(errs) => {
Expand Down Expand Up @@ -842,16 +842,8 @@ impl<'c, 'q, 'l, L: Locational> Dereferencer<'c, 'q, 'l, L> {
.check_trait_impl(&sub_t, &super_t, self.qnames, self.loc)?;
}
let is_subtype = self.ctx.subtype_of(&sub_t, &super_t);
let sub_t = if DEBUG_MODE {
sub_t
} else {
self.deref_tyvar(sub_t)?
};
let super_t = if DEBUG_MODE {
super_t
} else {
self.deref_tyvar(super_t)?
};
let sub_t = self.deref_tyvar(sub_t)?;
let super_t = self.deref_tyvar(super_t)?;
if sub_t == super_t {
Ok(sub_t)
} else if is_subtype {
Expand Down Expand Up @@ -1396,6 +1388,7 @@ impl Context {
Type::Or(l, r) => {
let l = self.squash_tyvar(*l);
let r = self.squash_tyvar(*r);
// REVIEW:
if l.is_unnamed_unbound_var() && r.is_unnamed_unbound_var() {
match (self.subtype_of(&l, &r), self.subtype_of(&r, &l)) {
(true, true) | (true, false) => {
Expand All @@ -1409,6 +1402,13 @@ impl Context {
}
self.union(&l, &r)
}
Type::FreeVar(ref fv) if fv.constraint_is_sandwiched() => {
let (sub_t, super_t) = fv.get_subsup().unwrap();
let sub_t = self.squash_tyvar(sub_t);
let super_t = self.squash_tyvar(super_t);
typ.update_tyvar(sub_t, super_t, None, false);
typ
}
other => other,
}
}
Expand Down
10 changes: 6 additions & 4 deletions crates/erg_compiler/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1380,7 +1380,7 @@ impl Context {
&self.shared().promises
}

pub fn current_caller(&self) -> Option<ControlKind> {
pub fn current_control_flow(&self) -> Option<ControlKind> {
self.higher_order_caller
.last()
.and_then(|name| ControlKind::try_from(&name[..]).ok())
Expand All @@ -1395,11 +1395,13 @@ impl Context {
None
}

pub fn current_function_ctx(&self) -> Option<&Context> {
if self.kind.is_subr() {
/// Context of the function that actually creates the scope.
/// Control flow function blocks do not create actual scopes.
pub fn current_true_function_ctx(&self) -> Option<&Context> {
if self.kind.is_subr() && self.current_control_flow().is_none() {
Some(self)
} else if let Some(outer) = self.get_outer() {
outer.current_function_ctx()
outer.current_true_function_ctx()
} else {
None
}
Expand Down
8 changes: 2 additions & 6 deletions crates/erg_compiler/lower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -951,11 +951,7 @@ impl<A: ASTBuildable> GenericASTLowerer<A> {
if !ident.vi.is_toplevel()
&& ident.vi.def_namespace() != &self.module.context.name
&& ident.vi.kind.can_capture()
&& self
.module
.context
.current_function_ctx()
.is_some_and(|ctx| ctx.control_kind().is_none())
&& self.module.context.current_true_function_ctx().is_some()
{
self.module.context.captured_names.push(ident.clone());
}
Expand Down Expand Up @@ -1178,7 +1174,7 @@ impl<A: ASTBuildable> GenericASTLowerer<A> {
if self
.module
.context
.current_caller()
.current_control_flow()
.map_or(true, |kind| !kind.is_if())
&& expect.is_some_and(|subr| !subr.essential_qnames().is_empty())
{
Expand Down
17 changes: 17 additions & 0 deletions tests/should_ok/closure.er
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,20 @@ for! [1], _ =>

push! "a", "b"
assert result == "| a | b |\n"

{SemVer;} = import "semver"

Versions! = Class Dict! { Str: Array!(SemVer) }
Versions!.
new() = Versions!::__new__ !{:}
insert!(ref! self, name: Str, version: SemVer) =
if! self::base.get(name) == None:
do!:
self::base.insert! name, ![version]
do!:
if! all(map(v -> v != version, self::base[name])), do!:
self::base[name].push! version

vs = Versions!.new()
_ = vs.insert! "foo", SemVer.from_str "1.0.0"
_ = vs.insert! "foo", SemVer.from_str "1.0.1"

0 comments on commit b5f5876

Please sign in to comment.