From 769691ec010093e705da96e86ba8418f128436fa Mon Sep 17 00:00:00 2001 From: tsudol Date: Tue, 10 Oct 2023 14:59:49 -0700 Subject: [PATCH 01/11] Disable 2-iterable signature for set.union. This is causing a crash that's proving very hard to debug. PiperOrigin-RevId: 572376335 --- pytype/stubs/builtins/builtins.pytd | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/pytype/stubs/builtins/builtins.pytd b/pytype/stubs/builtins/builtins.pytd index 796c70c45..69ea7573c 100644 --- a/pytype/stubs/builtins/builtins.pytd +++ b/pytype/stubs/builtins/builtins.pytd @@ -694,8 +694,9 @@ class set(Set[_T]): self = set[Union[_T, _T2]] @overload def union(self, other: Iterable[_T2]) -> set[_T | _T2]: ... - @overload - def union(self, other1: Iterable[_T2], other2: Iterable[_T3]) -> set[_T | _T2 | _T3]: ... + # TODO(b/304591130): Re-enable this overload. + # @overload + # def union(self, other1: Iterable[_T2], other2: Iterable[_T3]) -> set[_T | _T2 | _T3]: ... @overload def union(self, *args: Iterable[Any]) -> set[Any]: ... def intersection(self, *args: Iterable[Any]) -> set[_T]: ... @@ -720,8 +721,9 @@ class frozenset(FrozenSet[_T]): def issuperset(self, y: Iterable) -> bool: ... @overload def union(self, other: Iterable[_T2]) -> frozenset[_T | _T2]: ... - @overload - def union(self, other1: Iterable[_T2], other2: Iterable[_T3]) -> frozenset[_T | _T2 | _T3]: ... + # TODO(b/304591130): Re-enable this overload. + # @overload + # def union(self, other1: Iterable[_T2], other2: Iterable[_T3]) -> frozenset[_T | _T2 | _T3]: ... @overload def union(self, *args: Iterable[Any]) -> frozenset[Any]: ... def intersection(self, *args: Iterable[Any]) -> frozenset[_T]: ... From 656dbf143702d50b0ccb446307d1213c5622dcec Mon Sep 17 00:00:00 2001 From: rechen Date: Tue, 10 Oct 2023 16:02:08 -0700 Subject: [PATCH 02/11] Add some typing.Self tests. Adds tests to capture the state of our current partial support. PiperOrigin-RevId: 572392523 --- pytype/tests/CMakeLists.txt | 9 +++ pytype/tests/test_typing_self.py | 99 ++++++++++++++++++++++++++++++++ 2 files changed, 108 insertions(+) create mode 100644 pytype/tests/test_typing_self.py diff --git a/pytype/tests/CMakeLists.txt b/pytype/tests/CMakeLists.txt index 4cbae5f0d..6c377eb26 100644 --- a/pytype/tests/CMakeLists.txt +++ b/pytype/tests/CMakeLists.txt @@ -1255,3 +1255,12 @@ py_test( DEPS .test_base ) + +py_test( + NAME + test_typing_self + SRCS + test_typing_self.py + DEPS + .test_base +) diff --git a/pytype/tests/test_typing_self.py b/pytype/tests/test_typing_self.py new file mode 100644 index 000000000..34e53b84c --- /dev/null +++ b/pytype/tests/test_typing_self.py @@ -0,0 +1,99 @@ +"""Tests for typing.Self.""" + +from pytype.tests import test_base + + +class SelfPyiTest(test_base.BaseTest): + """Tests for typing.Self usage in type stubs.""" + + def test_instance_method_return(self): + with self.DepTree([("foo.pyi", """ + from typing import Self + class A: + def f(self) -> Self: ... + """)]): + self.Check(""" + import foo + class B(foo.A): + pass + assert_type(foo.A().f(), foo.A) + assert_type(B().f(), B) + """) + + def test_classmethod_return(self): + with self.DepTree([("foo.pyi", """ + from typing import Self + class A: + @classmethod + def f(cls) -> Self: ... + """)]): + self.Check(""" + import foo + class B(foo.A): + pass + assert_type(foo.A.f(), foo.A) + assert_type(B.f(), B) + """) + + def test_new_return(self): + with self.DepTree([("foo.pyi", """ + from typing import Self + class A: + def __new__(cls) -> Self: ... + """)]): + self.Check(""" + import foo + class B(foo.A): + pass + assert_type(foo.A(), foo.A) + assert_type(B(), B) + """) + + def test_parameterized_return(self): + with self.DepTree([("foo.pyi", """ + from typing import Self + class A: + def f(self) -> list[Self]: ... + """)]): + self.Check(""" + import foo + class B(foo.A): + pass + assert_type(foo.A().f(), "List[foo.A]") + assert_type(B().f(), "List[B]") + """) + + def test_parameter(self): + with self.DepTree([("foo.pyi", """ + from typing import Self + class A: + def f(self, other: Self) -> bool: ... + """)]): + errors = self.CheckWithErrors(""" + import foo + class B(foo.A): + pass + B().f(B()) # ok + B().f(0) # wrong-arg-types[e] + """) + self.assertErrorSequences( + errors, {"e": ["Expected", "B", "Actual", "int"]}) + + def test_nested_class(self): + with self.DepTree([("foo.pyi", """ + from typing import Self + class A: + class B: + def f(self) -> Self: ... + """)]): + self.Check(""" + import foo + class C(foo.A.B): + pass + assert_type(foo.A.B().f(), foo.A.B) + assert_type(C().f(), C) + """) + + +if __name__ == "__main__": + test_base.main() From 7404b1418ff06c3febdfa8146e88fffd939fd0c1 Mon Sep 17 00:00:00 2001 From: mdemello Date: Tue, 10 Oct 2023 17:11:06 -0700 Subject: [PATCH 03/11] Move python 3.11 flow control tests to the main test suite. PiperOrigin-RevId: 572409804 --- pytype/tests/CMakeLists.txt | 9 +++ pytype/tests/test_flow3.py | 145 ++++++++++++++++++++++++++++++++++++ pytype/tests/test_py_311.py | 129 -------------------------------- 3 files changed, 154 insertions(+), 129 deletions(-) create mode 100644 pytype/tests/test_flow3.py diff --git a/pytype/tests/CMakeLists.txt b/pytype/tests/CMakeLists.txt index 6c377eb26..e68564b19 100644 --- a/pytype/tests/CMakeLists.txt +++ b/pytype/tests/CMakeLists.txt @@ -531,6 +531,15 @@ py_test( .test_base ) +py_test( + NAME + test_flow3 + SRCS + test_flow3.py + DEPS + .test_base +) + py_test( NAME test_methods1 diff --git a/pytype/tests/test_flow3.py b/pytype/tests/test_flow3.py new file mode 100644 index 000000000..9b1201d4e --- /dev/null +++ b/pytype/tests/test_flow3.py @@ -0,0 +1,145 @@ +"""Tests for control flow cases that involve the exception table in 3.11+. + +Python 3.11 changed the way exceptions and some other control structures were +compiled, and in particular some of them require examining the exception table +as well as the bytecode. +""" + +from pytype.tests import test_base + + +class TestPy311(test_base.BaseTest): + """Tests for python 3.11 support.""" + + def test_context_manager(self): + self.Check(""" + class A: + def __enter__(self): + pass + def __exit__(self, a, b, c): + pass + + lock = A() + + def f() -> str: + path = '' + with lock: + try: + pass + except: + pass + return path + """) + + def test_exception_type(self): + self.Check(""" + class FooError(Exception): + pass + try: + raise FooError() + except FooError as e: + assert_type(e, FooError) + """) + + def test_try_with(self): + self.Check(""" + def f(obj, x): + try: + with __any_object__: + obj.get(x) + except: + pass + """) + + def test_try_if_with(self): + self.Check(""" + from typing import Any + import os + pytz: Any + def f(): + tz_env = os.environ.get('TZ') + try: + if tz_env == 'localtime': + with open('localtime') as localtime: + return pytz.tzfile.build_tzinfo('', localtime) + except IOError: + return pytz.UTC + """) + + def test_try_finally(self): + self.Check(""" + import tempfile + dir_ = None + def f(): + global dir_ + try: + if dir_: + return dir_ + dir_ = tempfile.mkdtemp() + finally: + print(dir_) + """) + + def test_nested_try_in_for(self): + self.Check(""" + def f(x): + for i in x: + fd = __any_object__ + try: + try: + if __random__: + return True + except ValueError: + continue + finally: + fd.close() + """) + + def test_while_and_nested_try(self): + self.Check(""" + def f(p): + try: + while __random__: + try: + return p.communicate() + except KeyboardInterrupt: + pass + finally: + pass + """) + + def test_while_and_nested_try_2(self): + self.Check(""" + def f(): + i = j = 0 + while True: + try: + try: + i += 1 + finally: + j += 1 + except: + break + return + """) + + def test_while_and_nested_try_3(self): + self.Check(""" + import os + + def RmDirs(dir_name): + try: + parent_directory = os.path.dirname(dir_name) + while parent_directory: + try: + os.rmdir(parent_directory) + except OSError as err: + pass + parent_directory = os.path.dirname(parent_directory) + except OSError as err: + pass + """) + + +if __name__ == "__main__": + test_base.main() diff --git a/pytype/tests/test_py_311.py b/pytype/tests/test_py_311.py index dfac25587..620f8e2e5 100644 --- a/pytype/tests/test_py_311.py +++ b/pytype/tests/test_py_311.py @@ -41,26 +41,6 @@ def f(x): return any(x) """) - def test_context_manager(self): - self.Check(""" - class A: - def __enter__(self): - pass - def __exit__(self, a, b, c): - pass - - lock = A() - - def f() -> str: - path = '' - with lock: - try: - pass - except: - pass - return path - """) - def test_deref1(self): self.Check(""" def f(*args): @@ -103,115 +83,6 @@ def g(x, y): return (x, y) """) - def test_exception_type(self): - self.Check(""" - class FooError(Exception): - pass - try: - raise FooError() - except FooError as e: - assert_type(e, FooError) - """) - - def test_try_with(self): - self.Check(""" - def f(obj, x): - try: - with __any_object__: - obj.get(x) - except: - pass - """) - - def test_try_if_with(self): - self.Check(""" - from typing import Any - import os - pytz: Any - def f(): - tz_env = os.environ.get('TZ') - try: - if tz_env == 'localtime': - with open('localtime') as localtime: - return pytz.tzfile.build_tzinfo('', localtime) - except IOError: - return pytz.UTC - """) - - def test_try_finally(self): - self.Check(""" - import tempfile - dir_ = None - def f(): - global dir_ - try: - if dir_: - return dir_ - dir_ = tempfile.mkdtemp() - finally: - print(dir_) - """) - - def test_nested_try_in_for(self): - self.Check(""" - def f(x): - for i in x: - fd = __any_object__ - try: - try: - if __random__: - return True - except ValueError: - continue - finally: - fd.close() - """) - - def test_while_and_nested_try(self): - self.Check(""" - def f(p): - try: - while __random__: - try: - return p.communicate() - except KeyboardInterrupt: - pass - finally: - pass - """) - - def test_while_and_nested_try_2(self): - self.Check(""" - def f(): - i = j = 0 - while True: - try: - try: - i += 1 - finally: - j += 1 - except: - break - return - """) - - def test_while_and_nested_try_3(self): - self.Check(""" - import os - - def RmDirs(dir_name): - try: - parent_directory = os.path.dirname(dir_name) - while parent_directory: - try: - os.rmdir(parent_directory) - except OSError as err: - pass - parent_directory = os.path.dirname(parent_directory) - except OSError as err: - pass - """) - if __name__ == "__main__": test_base.main() From 252aabb7df14d685ebbc6964d79b8a04cb42bf62 Mon Sep 17 00:00:00 2001 From: rechen Date: Tue, 10 Oct 2023 17:48:18 -0700 Subject: [PATCH 04/11] Fix crash caused by misidentifying decorators. PiperOrigin-RevId: 572417428 --- pytype/tests/test_py_311.py | 10 ++++++++++ pytype/vm.py | 6 +++--- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/pytype/tests/test_py_311.py b/pytype/tests/test_py_311.py index 620f8e2e5..be38a21e8 100644 --- a/pytype/tests/test_py_311.py +++ b/pytype/tests/test_py_311.py @@ -83,6 +83,16 @@ def g(x, y): return (x, y) """) + def test_callable_parameter_in_function(self): + # Tests that we don't mis-identify the defaultdict call as a decorator. + self.Check(""" + import collections + class C: + def __init__(self): + self.x = collections.defaultdict( + lambda key: key) # pytype: disable=wrong-arg-types + """) + if __name__ == "__main__": test_base.main() diff --git a/pytype/vm.py b/pytype/vm.py index 007487525..e6460dccc 100644 --- a/pytype/vm.py +++ b/pytype/vm.py @@ -629,9 +629,9 @@ def simple_stack(self, opcode=None): def _in_3_11_decoration(self): """Are we in a Python 3.11 decorator call?""" - if self.ctx.python_version != (3, 11): - return False - if not isinstance(self.current_opcode, opcodes.CALL): + if not (self.ctx.python_version == (3, 11) and + isinstance(self.current_opcode, opcodes.CALL) and + self.current_line in self._director.decorated_functions): return False prev = self.current_opcode # Skip past the PRECALL opcode. From 663abbe9eb20d9a83ce1819eae1b784970f35696 Mon Sep 17 00:00:00 2001 From: rechen Date: Tue, 10 Oct 2023 18:44:35 -0700 Subject: [PATCH 05/11] Redirect jumps that go to deleted opcodes. PiperOrigin-RevId: 572427153 --- pytype/pyc/opcodes.py | 10 ++++++++-- pytype/tests/test_py_311.py | 10 ++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/pytype/pyc/opcodes.py b/pytype/pyc/opcodes.py index 3d88164aa..60706b108 100644 --- a/pytype/pyc/opcodes.py +++ b/pytype/pyc/opcodes.py @@ -1129,11 +1129,17 @@ def _make_opcode_list(offset_to_op, python_version): # In 3.11 `async for` is compiled into an infinite loop, relying on the # exception handler to break out. This causes the block graph to be # pruned abruptly, so we need to remove the loop opcode. - index -= 1 - continue + skip = True elif (isinstance(op, JUMP_BACKWARD_NO_INTERRUPT) and isinstance(offset_to_op[op.argval], SEND)): # Likewise, `await` is compiled into an infinite loop which we remove. + skip = True + else: + skip = False + if skip: + # We map the offset to the index of the next opcode so that jumps to + # `op` are redirected correctly. + offset_to_index[off] = index index -= 1 continue op.index = index diff --git a/pytype/tests/test_py_311.py b/pytype/tests/test_py_311.py index be38a21e8..969df851b 100644 --- a/pytype/tests/test_py_311.py +++ b/pytype/tests/test_py_311.py @@ -93,6 +93,16 @@ def __init__(self): lambda key: key) # pytype: disable=wrong-arg-types """) + def test_async_for(self): + self.Check(""" + class Client: + async def get_or_create_tensorboard(self): + response = await __any_object__ + async for page in response.pages: + if page.tensorboards: + return response.tensorboards[0].name + """) + if __name__ == "__main__": test_base.main() From 15ddcfce83c357989cdcc5b85a112d84765e706f Mon Sep 17 00:00:00 2001 From: rechen Date: Wed, 11 Oct 2023 10:24:17 -0700 Subject: [PATCH 06/11] Support enum.StrEnum. This is needed to support the etils library (specifically: https://github.com/google/etils/blob/ea1f4876517656ef86deedaa41de47c87f00a999/etils/epy/py_utils.py#L33) in 3.11. PiperOrigin-RevId: 572613882 --- pytype/overlays/enum_overlay.py | 6 +++--- pytype/stubs/stdlib/enum.pytd | 7 ++++++- pytype/tests/test_enums.py | 11 ++++++++++- 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/pytype/overlays/enum_overlay.py b/pytype/overlays/enum_overlay.py index 0255d37d7..0aa630830 100644 --- a/pytype/overlays/enum_overlay.py +++ b/pytype/overlays/enum_overlay.py @@ -44,9 +44,8 @@ # These members have been added in Python 3.11 and are not yet supported. -_unsupported = ("StrEnum", "ReprEnum", "EnumCheck", "FlagBoundary", "verify", - "property", "member", "nonmember", "global_enum", - "show_flag_values") +_unsupported = ("ReprEnum", "EnumCheck", "FlagBoundary", "verify", "property", + "member", "nonmember", "global_enum", "show_flag_values") class EnumOverlay(overlay.Overlay): @@ -59,6 +58,7 @@ def __init__(self, ctx): "EnumMeta": EnumMeta, "EnumType": EnumMeta, "IntEnum": overlay.add_name("IntEnum", EnumBuilder), + "StrEnum": overlay.add_name("StrEnum", EnumBuilder), **{name: overlay.add_name(name, overlay_utils.not_supported_yet) for name in _unsupported}, } diff --git a/pytype/stubs/stdlib/enum.pytd b/pytype/stubs/stdlib/enum.pytd index 3b96b3f96..1cc79ee32 100644 --- a/pytype/stubs/stdlib/enum.pytd +++ b/pytype/stubs/stdlib/enum.pytd @@ -66,7 +66,12 @@ class IntFlag(int, Flag): # 3.11: Not supported yet. EnumType = EnumMeta -class StrEnum: ... + +class StrEnum(str, Enum): + # Workaround for b/201603421. + @classmethod + def __iter__(cls: Type[_T]) -> Iterator[_T]: ... + class ReprEnum: ... class EnumCheck: ... class FlagBoundary: ... diff --git a/pytype/tests/test_enums.py b/pytype/tests/test_enums.py index 240dd9caa..0dfab035d 100644 --- a/pytype/tests/test_enums.py +++ b/pytype/tests/test_enums.py @@ -924,6 +924,15 @@ class IF(enum.IntFlag): A = 1 """) + def test_strenum(self): + self.Check(""" + import enum + class MyEnum(enum.StrEnum): + A = 'A' + for x in MyEnum: + assert_type(x, MyEnum) + """) + def test_unique_enum_in_dict(self): # Regression test for a recursion error in matcher.py self.assertNoCrash(self.Check, """ @@ -1358,7 +1367,7 @@ class M(enum.Enum): def test_not_supported_yet(self): self.CheckWithErrors(""" import enum - enum.StrEnum # not-supported-yet + enum.ReprEnum # not-supported-yet """) From 7454d8e79b90943d4903d1e65091db43b2061e30 Mon Sep 17 00:00:00 2001 From: tristenallen Date: Wed, 11 Oct 2023 11:03:31 -0700 Subject: [PATCH 07/11] Small fix to `parse_pyi`. `parse_pyi.py` currently raises an exception when trying to decompose the results of `parser.parse_pyi`. It only returns a single value, but the left-hand side of the assignment specifies two variables. Since the second variable is unused, all this CL does is remove that portion of the assignment statement. #tftypes PiperOrigin-RevId: 572626607 --- pytype/pyi/parse_pyi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytype/pyi/parse_pyi.py b/pytype/pyi/parse_pyi.py index ceb993ad1..bbc9316bf 100644 --- a/pytype/pyi/parse_pyi.py +++ b/pytype/pyi/parse_pyi.py @@ -18,7 +18,7 @@ module_name = module_utils.path_to_module_name(filename) try: - out, _ = parser.parse_pyi(src, filename, module_name, debug_mode=True) + out = parser.parse_pyi(src, filename, module_name, debug_mode=True) except _ParseError as e: print(e) sys.exit(1) From 14e6d0e532f895dd5c6c109f3b9f7b99f0f37fc3 Mon Sep 17 00:00:00 2001 From: rechen Date: Wed, 11 Oct 2023 11:09:53 -0700 Subject: [PATCH 08/11] Re-implement the SEND opcode. The way this is supposed to work in 3.11 is that SEND gets a value from the generator, YIELD_VALUE yields it, and JUMP_BACKWARDS_NO_INTERRUPT jumps back to the SEND, repeatedly, until there are no more values to yield. At that point, SEND pushes the generator's return value onto the stack and jumps past the backwards jump. However, we removed the jump opcode, so we can't follow this implementation. Reusing 3.10's YIELD_FROM implementation seems to work. I also renamed some of the variables in our generator-related VM code to make it easier to follow. PiperOrigin-RevId: 572628644 --- pytype/tests/test_py_311.py | 13 ++++ pytype/vm.py | 144 ++++++++++++++++++------------------ 2 files changed, 84 insertions(+), 73 deletions(-) diff --git a/pytype/tests/test_py_311.py b/pytype/tests/test_py_311.py index 969df851b..01a2a8f45 100644 --- a/pytype/tests/test_py_311.py +++ b/pytype/tests/test_py_311.py @@ -103,6 +103,19 @@ async def get_or_create_tensorboard(self): return response.tensorboards[0].name """) + def test_yield_from(self): + self.Check(""" + def f(): + yield 1 + return 'a', 'b' + def g(): + a, b = yield from f() + assert_type(a, str) + assert_type(b, str) + for x in g(): + assert_type(x, int) + """) + if __name__ == "__main__": test_base.main() diff --git a/pytype/vm.py b/pytype/vm.py index e6460dccc..380451ead 100644 --- a/pytype/vm.py +++ b/pytype/vm.py @@ -2826,27 +2826,29 @@ def byte_CALL_FUNCTION_EX(self, state, op): starstarargs=starstarargs) return state.push(ret) - def _check_frame_yield(self, state, ret): + def _check_frame_yield(self, state, yield_value): if not self.frame.check_return: return None - ret_type = self.frame.allowed_returns - assert ret_type is not None - self._check_return(state.node, ret, - ret_type.get_formal_type_parameter(abstract_utils.T)) - return ret_type + generator_type = self.frame.allowed_returns + assert generator_type is not None + self._check_return( + state.node, yield_value, + generator_type.get_formal_type_parameter(abstract_utils.T)) + return generator_type def byte_YIELD_VALUE(self, state, op): """Yield a value from a generator.""" - state, ret = state.pop() - value = self.frame.yield_variable.AssignToNewVariable(state.node) - value.PasteVariable(ret, state.node) - self.frame.yield_variable = value - ret_type = self._check_frame_yield(state, ret) if self.ctx.python_version >= (3, 11) and isinstance(op.prev, opcodes.SEND): - send_var = ret - elif ret_type: - send_var = self.init_class( - state.node, ret_type.get_formal_type_parameter(abstract_utils.T2)) + # See byte_SEND for what's happening here. + return state + state, yield_value = state.pop() + yield_variable = self.frame.yield_variable.AssignToNewVariable(state.node) + yield_variable.PasteVariable(yield_value, state.node) + self.frame.yield_variable = yield_variable + generator_type = self._check_frame_yield(state, yield_value) + if generator_type: + send_type = generator_type.get_formal_type_parameter(abstract_utils.T2) + send_var = self.init_class(state.node, send_type) else: send_var = self.ctx.new_unsolvable(state.node) return state.push(send_var) @@ -3111,52 +3113,57 @@ def byte_GET_AWAITABLE(self, state, op): ret = self.ctx.new_unsolvable(state.node) return state.push(ret) - def _yield_from_value(self, state, var, yield_variable): - """Helper function for YIELD_FROM and SEND.""" - result = self.ctx.program.NewVariable() - for b in var.bindings: - val = b.data - if val.full_name == "builtins.generator": - yield_variable.PasteVariable( - val.get_instance_type_parameter(abstract_utils.T), state.node) - if isinstance(val, (abstract.Generator, - abstract.Coroutine, abstract.Unsolvable)): - ret_var = val.get_instance_type_parameter(abstract_utils.V) - result.PasteVariable(ret_var, state.node, {b}) - elif (isinstance(val, abstract.Instance) - and isinstance(val.cls, + def _get_generator_yield(self, node, generator_var): + yield_var = self.frame.yield_variable.AssignToNewVariable(node) + for generator in generator_var.data: + if generator.full_name == "builtins.generator": + yield_value = generator.get_instance_type_parameter(abstract_utils.T) + yield_var.PasteVariable(yield_value, node) + return yield_var + + def _get_generator_return(self, node, generator_var): + """Gets generator_var's return value.""" + ret_var = self.ctx.program.NewVariable() + for b in generator_var.bindings: + generator = b.data + if isinstance(generator, (abstract.Generator, + abstract.Coroutine, abstract.Unsolvable)): + ret = generator.get_instance_type_parameter(abstract_utils.V) + ret_var.PasteVariable(ret, node, {b}) + elif (isinstance(generator, abstract.Instance) + and isinstance(generator.cls, (abstract.ParameterizedClass, abstract.PyTDClass)) - and val.cls.full_name in ("typing.Awaitable", - "builtins.coroutine", - "builtins.generator")): - if val.cls.full_name == "typing.Awaitable": - ret_var = val.get_instance_type_parameter(abstract_utils.T) + and generator.cls.full_name in ("typing.Awaitable", + "builtins.coroutine", + "builtins.generator")): + if generator.cls.full_name == "typing.Awaitable": + ret = generator.get_instance_type_parameter(abstract_utils.T) else: - ret_var = val.get_instance_type_parameter(abstract_utils.V) - if ret_var.bindings: - result.PasteVariable(ret_var, state.node, {b}) + ret = generator.get_instance_type_parameter(abstract_utils.V) + if ret.bindings: + ret_var.PasteVariable(ret, node, {b}) else: - result.AddBinding(self.ctx.convert.unsolvable, {b}, state.node) + ret_var.AddBinding(self.ctx.convert.unsolvable, {b}, node) else: - result.AddBinding(val, {b}, state.node) - return result + ret_var.AddBinding(generator, {b}, node) + if not ret_var.bindings: + ret_var.AddBinding(self.ctx.convert.unsolvable, [], node) + return ret_var + + def _yield_from(self, state): + """Helper function for YIELD_FROM and SEND.""" + state, unused_send = state.pop() + state, generator_var = state.pop() + yield_var = self._get_generator_yield(state.node, generator_var) + if yield_var.bindings: + self.frame.yield_variable = yield_var + _ = self._check_frame_yield(state, yield_var) + ret_var = self._get_generator_return(state.node, generator_var) + return state.push(ret_var) def byte_YIELD_FROM(self, state, op): """Implementation of the YIELD_FROM opcode.""" - state, unused_none_var = state.pop() - state, var = state.pop() - yield_variable = self.frame.yield_variable.AssignToNewVariable(state.node) - result = self._yield_from_value(state, var, yield_variable) - if yield_variable.bindings: - self.frame.yield_variable = yield_variable - if self.frame.check_return: - assert self.frame.allowed_returns is not None - ret_type = self.frame.allowed_returns.get_formal_type_parameter( - abstract_utils.T) - self._check_return(state.node, yield_variable, ret_type) - if not result.bindings: - result.AddBinding(self.ctx.convert.unsolvable, [], state.node) - return state.push(result) + return self._yield_from(state) def byte_LOAD_METHOD(self, state, op): """Implementation of the LOAD_METHOD opcode.""" @@ -3451,25 +3458,16 @@ def byte_BINARY_OP(self, state, op): def byte_SEND(self, state, op): """Implementation of SEND opcode.""" - state, var = state.pop() - state, recv = state.pop() - node, next_meth, _ = self._retrieve_attr(state.node, recv, "__next__") - if self._var_is_none(var) and next_meth: - state = state.change_cfg_node(node) - state, ret = self.call_function_with_state(state, next_meth, ()) - else: - yield_variable = self.frame.yield_variable.AssignToNewVariable(state.node) - ret = self._yield_from_value(state, recv, yield_variable) - if yield_variable.bindings: - self.frame.yield_variable = yield_variable - if self.frame.check_return: - assert self.frame.allowed_returns is not None - ret_type = self.frame.allowed_returns.get_formal_type_parameter( - abstract_utils.T) - self._check_return(state.node, yield_variable, ret_type) - if not ret.bindings: - ret.AddBinding(self.ctx.convert.unsolvable, [], state.node) - return state.push(ret) + # In 3.11, SEND + YIELD_VALUE + JUMP_BACKWARD_NO_INTERRUPT are used to + # implement `yield from`, which in 3.10 was implemented by the YIELD_FROM + # opcode. See + # https://github.com/python/cpython/blob/c6d5628be950bdf2c31243b4cc0d9e0b658458dd/Python/ceval.c#L2577 + # for the 3.11 CPython source. To avoid an infinite loop, we have removed + # the JUMP_BACKWARD_NO_INTERRUPT. So instead of attempting to follow the + # 3.11 implementation, we have SEND implement YIELD_FROM and YIELD_VALUE do + # nothing when it detects that the previous opcode was a SEND. + assert isinstance(op.next, opcodes.YIELD_VALUE) + return self._yield_from(state) def byte_POP_JUMP_FORWARD_IF_NOT_NONE(self, state, op): return vm_utils.jump_if(state, op, self.ctx, From 4dfd53c5f63b6d3eecd44365c4e76f954e26a4d3 Mon Sep 17 00:00:00 2001 From: tristenallen Date: Wed, 11 Oct 2023 13:54:09 -0700 Subject: [PATCH 09/11] Remove extraneous check in RenameModuleVisitor. `pytd_visitors.RenameModuleVisitor` previously skipped renaming a module if, when matching against the text of a node, it detected a `.` character in any text trailing the previous module name. This prevents the visitor from correctly removing module names added during pytype's serialization of ASTs. This change removes that check and only that check. The remaining check (which verifies that there is no text preceding the match) should be sufficient to prevent erroneous replacement in partial-match situations. #tftypes PiperOrigin-RevId: 572677520 --- pytype/pytd/pytd_visitors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytype/pytd/pytd_visitors.py b/pytype/pytd/pytd_visitors.py index f818fc848..1382cf5ce 100644 --- a/pytype/pytd/pytd_visitors.py +++ b/pytype/pytd/pytd_visitors.py @@ -157,7 +157,7 @@ def _MaybeNewName(self, name): if name == self._old[:-1]: return self._module_name before, match, after = name.partition(self._old) - if match and not before and "." not in after: + if match and not before: return self._new + after else: return name From ba60f002f14fc93f1aa7e1bfc725bd738c452378 Mon Sep 17 00:00:00 2001 From: rechen Date: Wed, 11 Oct 2023 14:40:27 -0700 Subject: [PATCH 10/11] Change SEND to follow the 3.11 runtime implementation more closely. After staring at this long enough, I think I figured out how to implement SEND properly. This (hopefully) has no user-visible effect. PiperOrigin-RevId: 572691337 --- pytype/pyc/opcodes.py | 36 ++++++++++++------------------------ pytype/vm.py | 39 +++++++++++++++++---------------------- 2 files changed, 29 insertions(+), 46 deletions(-) diff --git a/pytype/pyc/opcodes.py b/pytype/pyc/opcodes.py index 60706b108..a1594ff88 100644 --- a/pytype/pyc/opcodes.py +++ b/pytype/pyc/opcodes.py @@ -1122,26 +1122,17 @@ def _make_opcode_list(offset_to_op, python_version): op_items = sorted(offset_to_op.items()) for i, (off, op) in enumerate(op_items): index += 1 - if python_version == (3, 11): - if (isinstance(op, JUMP_BACKWARD) and - i + 1 < len(op_items) and - isinstance(op_items[i + 1][1], END_ASYNC_FOR)): - # In 3.11 `async for` is compiled into an infinite loop, relying on the - # exception handler to break out. This causes the block graph to be - # pruned abruptly, so we need to remove the loop opcode. - skip = True - elif (isinstance(op, JUMP_BACKWARD_NO_INTERRUPT) and - isinstance(offset_to_op[op.argval], SEND)): - # Likewise, `await` is compiled into an infinite loop which we remove. - skip = True - else: - skip = False - if skip: - # We map the offset to the index of the next opcode so that jumps to - # `op` are redirected correctly. - offset_to_index[off] = index - index -= 1 - continue + if (python_version == (3, 11) and isinstance(op, JUMP_BACKWARD) and + i + 1 < len(op_items) and + isinstance(op_items[i + 1][1], END_ASYNC_FOR)): + # In 3.11 `async for` is compiled into an infinite loop, relying on the + # exception handler to break out. This causes the block graph to be + # pruned abruptly, so we need to remove the loop opcode. + # We map the offset to the index of the next opcode so that jumps to + # `op` are redirected correctly. + offset_to_index[off] = index + index -= 1 + continue op.index = index offset_to_index[off] = index if prev_op: @@ -1157,10 +1148,7 @@ def _add_jump_targets(ops, offset_to_index): """Map the target of jump instructions to the opcode they jump to.""" for op in ops: op = cast(OpcodeWithArg, op) - if isinstance(op, SEND): - # This has a target in the bytecode, but is not a jump - op.target = None - elif op.target: + if op.target: # We have already set op.target, we need to fill in its index in op.arg op.arg = op.argval = op.target.index elif op.has_known_jump(): diff --git a/pytype/vm.py b/pytype/vm.py index 380451ead..867946044 100644 --- a/pytype/vm.py +++ b/pytype/vm.py @@ -2838,9 +2838,6 @@ def _check_frame_yield(self, state, yield_value): def byte_YIELD_VALUE(self, state, op): """Yield a value from a generator.""" - if self.ctx.python_version >= (3, 11) and isinstance(op.prev, opcodes.SEND): - # See byte_SEND for what's happening here. - return state state, yield_value = state.pop() yield_variable = self.frame.yield_variable.AssignToNewVariable(state.node) yield_variable.PasteVariable(yield_value, state.node) @@ -3150,21 +3147,16 @@ def _get_generator_return(self, node, generator_var): ret_var.AddBinding(self.ctx.convert.unsolvable, [], node) return ret_var - def _yield_from(self, state): - """Helper function for YIELD_FROM and SEND.""" - state, unused_send = state.pop() - state, generator_var = state.pop() - yield_var = self._get_generator_yield(state.node, generator_var) + def byte_YIELD_FROM(self, state, op): + """Implementation of the YIELD_FROM opcode.""" + state, (generator, unused_send) = state.popn(2) + yield_var = self._get_generator_yield(state.node, generator) if yield_var.bindings: self.frame.yield_variable = yield_var _ = self._check_frame_yield(state, yield_var) - ret_var = self._get_generator_return(state.node, generator_var) + ret_var = self._get_generator_return(state.node, generator) return state.push(ret_var) - def byte_YIELD_FROM(self, state, op): - """Implementation of the YIELD_FROM opcode.""" - return self._yield_from(state) - def byte_LOAD_METHOD(self, state, op): """Implementation of the LOAD_METHOD opcode.""" name = op.argval @@ -3458,16 +3450,19 @@ def byte_BINARY_OP(self, state, op): def byte_SEND(self, state, op): """Implementation of SEND opcode.""" - # In 3.11, SEND + YIELD_VALUE + JUMP_BACKWARD_NO_INTERRUPT are used to - # implement `yield from`, which in 3.10 was implemented by the YIELD_FROM - # opcode. See + # In Python 3.11, a SEND + YIELD_VALUE + JUMP_BACKWARD_NO_INTERRUPT sequence + # is used to implement `yield from` (previously implemented by the + # YIELD_FROM opcode). SEND gets a value from a generator, YIELD_VALUE yields + # the value, and JUMP_BACKWARD_NO_INTERRUPT jumps back to SEND, repeatedly, + # until the generator runs out of values. Then SEND pushes the generator's + # return value onto the stack and jumps past JUMP_BACKWARD_NO_INTERRUPT. See # https://github.com/python/cpython/blob/c6d5628be950bdf2c31243b4cc0d9e0b658458dd/Python/ceval.c#L2577 - # for the 3.11 CPython source. To avoid an infinite loop, we have removed - # the JUMP_BACKWARD_NO_INTERRUPT. So instead of attempting to follow the - # 3.11 implementation, we have SEND implement YIELD_FROM and YIELD_VALUE do - # nothing when it detects that the previous opcode was a SEND. - assert isinstance(op.next, opcodes.YIELD_VALUE) - return self._yield_from(state) + # for the CPython source. + state, (generator, unused_send) = state.popn(2) + yield_var = self._get_generator_yield(state.node, generator) + ret_var = self._get_generator_return(state.node, generator) + self.store_jump(op.target, state.push(ret_var)) + return state.push(generator).push(yield_var) def byte_POP_JUMP_FORWARD_IF_NOT_NONE(self, state, op): return vm_utils.jump_if(state, op, self.ctx, From 2058ba4a361b20db8a07625b1e48bfa3c0d82c37 Mon Sep 17 00:00:00 2001 From: rechen Date: Thu, 12 Oct 2023 19:18:56 -0700 Subject: [PATCH 11/11] 3.11: Fix crash caused by stray abstract.Splat. PiperOrigin-RevId: 573074456 --- pytype/tests/test_py_311.py | 9 +++++++++ pytype/vm.py | 9 ++++++--- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/pytype/tests/test_py_311.py b/pytype/tests/test_py_311.py index 01a2a8f45..b2320f7cf 100644 --- a/pytype/tests/test_py_311.py +++ b/pytype/tests/test_py_311.py @@ -116,6 +116,15 @@ def g(): assert_type(x, int) """) + def test_splat(self): + self.Check(""" + def f(value, g): + converted = [] + if isinstance(value, (dict, *tuple({}))): + converted.append(value) + return g(*converted) + """) + if __name__ == "__main__": test_base.main() diff --git a/pytype/vm.py b/pytype/vm.py index 867946044..8190deebb 100644 --- a/pytype/vm.py +++ b/pytype/vm.py @@ -2372,10 +2372,13 @@ def byte_LIST_EXTEND(self, state, op): # Before Python 3.9, BUILD_TUPLE_UNPACK took care of tuple unpacking. In # 3.9+, this opcode is replaced by LIST_EXTEND+LIST_TO_TUPLE+CALL_FUNCTION, # so CALL_FUNCTION needs to be considered as consuming the list. - if self.ctx.python_version >= (3, 9): - stop_classes = blocks.STORE_OPCODES + (opcodes.CALL_FUNCTION,) + if self.ctx.python_version >= (3, 11): + call_consumers = (opcodes.CALL,) + elif self.ctx.python_version >= (3, 9): + call_consumers = (opcodes.CALL_FUNCTION,) else: - stop_classes = blocks.STORE_OPCODES + call_consumers = () + stop_classes = blocks.STORE_OPCODES + call_consumers while next_op: next_op = next_op.next if isinstance(next_op, opcodes.CALL_FUNCTION_EX):