From b21d018adfd55d04cfb9ef78cf89965a08b3e929 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Tue, 13 Aug 2024 13:02:14 +0900 Subject: [PATCH] fix: __call__ overload bug --- crates/erg_compiler/context/inquire.rs | 22 ++++++++++++++++++++-- tests/should_ok/dunder.er | 4 ++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/crates/erg_compiler/context/inquire.rs b/crates/erg_compiler/context/inquire.rs index e50bf2cda..c03923a61 100644 --- a/crates/erg_compiler/context/inquire.rs +++ b/crates/erg_compiler/context/inquire.rs @@ -1848,6 +1848,11 @@ impl Context { res } } + Type::And(_, _) => { + let instance = + self.resolve_overload(obj, instance.clone(), pos_args, kw_args, &obj)?; + self.substitute_call(obj, attr_name, &instance, pos_args, kw_args, namespace) + } Type::Failure => Ok(SubstituteResult::Ok), _ => { self.substitute_dunder_call(obj, attr_name, instance, pos_args, kw_args, namespace) @@ -2095,7 +2100,13 @@ impl Context { )?; } let instance = self.instantiate_def_type(&call_vi.t)?; - self.substitute_call(obj, attr_name, &instance, pos_args, kw_args, namespace)?; + let instance = match self + .substitute_call(obj, attr_name, &instance, pos_args, kw_args, namespace)? + { + SubstituteResult::__Call__(instance) + | SubstituteResult::Coerced(instance) => instance, + SubstituteResult::Ok => instance, + }; return Ok(SubstituteResult::__Call__(instance)); } // instance method __call__ @@ -2106,7 +2117,14 @@ impl Context { }) { let instance = self.instantiate_def_type(&call_vi.t)?; - self.substitute_call(obj, attr_name, &instance, pos_args, kw_args, namespace)?; + let instance = match self + .substitute_call(obj, attr_name, &instance, pos_args, kw_args, namespace)? + { + SubstituteResult::__Call__(instance) | SubstituteResult::Coerced(instance) => { + instance + } + SubstituteResult::Ok => instance, + }; return Ok(SubstituteResult::__Call__(instance)); } } diff --git a/tests/should_ok/dunder.er b/tests/should_ok/dunder.er index c1a94b314..fcf4b247a 100644 --- a/tests/should_ok/dunder.er +++ b/tests/should_ok/dunder.er @@ -13,3 +13,7 @@ assert func() == "foo" assert module::__file__.endswith "dunder.er" assert C.new().__file__ == "bar" assert imp.func().endswith "import.er" + +discard Str() +discard Str(1) +discard Str(bytes("aaa", "utf-8"), "utf-8")