Skip to content

Commit

Permalink
fix: __call__ overload bug
Browse files Browse the repository at this point in the history
  • Loading branch information
mtshiba committed Aug 13, 2024
1 parent b87f3da commit b21d018
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 2 deletions.
22 changes: 20 additions & 2 deletions crates/erg_compiler/context/inquire.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1848,6 +1848,11 @@ impl Context {
res
}
}
Type::And(_, _) => {
let instance =
self.resolve_overload(obj, instance.clone(), pos_args, kw_args, &obj)?;
self.substitute_call(obj, attr_name, &instance, pos_args, kw_args, namespace)
}
Type::Failure => Ok(SubstituteResult::Ok),
_ => {
self.substitute_dunder_call(obj, attr_name, instance, pos_args, kw_args, namespace)
Expand Down Expand Up @@ -2095,7 +2100,13 @@ impl Context {
)?;
}
let instance = self.instantiate_def_type(&call_vi.t)?;
self.substitute_call(obj, attr_name, &instance, pos_args, kw_args, namespace)?;
let instance = match self
.substitute_call(obj, attr_name, &instance, pos_args, kw_args, namespace)?
{
SubstituteResult::__Call__(instance)
| SubstituteResult::Coerced(instance) => instance,
SubstituteResult::Ok => instance,
};
return Ok(SubstituteResult::__Call__(instance));
}
// instance method __call__
Expand All @@ -2106,7 +2117,14 @@ impl Context {
})
{
let instance = self.instantiate_def_type(&call_vi.t)?;
self.substitute_call(obj, attr_name, &instance, pos_args, kw_args, namespace)?;
let instance = match self
.substitute_call(obj, attr_name, &instance, pos_args, kw_args, namespace)?
{
SubstituteResult::__Call__(instance) | SubstituteResult::Coerced(instance) => {
instance
}
SubstituteResult::Ok => instance,
};
return Ok(SubstituteResult::__Call__(instance));
}
}
Expand Down
4 changes: 4 additions & 0 deletions tests/should_ok/dunder.er
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,7 @@ assert func() == "foo"
assert module::__file__.endswith "dunder.er"
assert C.new().__file__ == "bar"
assert imp.func().endswith "import.er"

discard Str()
discard Str(1)
discard Str(bytes("aaa", "utf-8"), "utf-8")

0 comments on commit b21d018

Please sign in to comment.