From 5f2131c080c946afe1b62590f082fbb240400e2a Mon Sep 17 00:00:00 2001 From: Anthony Sottile Date: Sat, 8 Jun 2024 17:02:07 -0400 Subject: [PATCH] rewrite TypeVar defaults for Generator / AsyncGenerator --- README.md | 18 +++ pyupgrade/_data.py | 1 + .../typing_pep696_typevar_defaults.py | 72 ++++++++++ .../typing_pep696_typevar_defaults_test.py | 126 ++++++++++++++++++ 4 files changed, 217 insertions(+) create mode 100644 pyupgrade/_plugins/typing_pep696_typevar_defaults.py create mode 100644 tests/features/typing_pep696_typevar_defaults_test.py diff --git a/README.md b/README.md index 623a20b1..9e5ef11f 100644 --- a/README.md +++ b/README.md @@ -754,6 +754,24 @@ Availability: ... ``` +### pep 696 TypeVar defaults + +Availability: +- File imports `from __future__ import annotations` + - Unless `--keep-runtime-typing` is passed on the commandline. +- `--py313-plus` is passed on the commandline. + +```diff +-def f() -> Generator[int, None, None]: ++def f() -> Generator[int]: + yield 1 +``` + +```diff +-async def f() -> AsyncGenerator[int, None]: ++async def f() -> AsyncGenerator[int]: + yield 1 +``` ### remove quoted annotations diff --git a/pyupgrade/_data.py b/pyupgrade/_data.py index ab3a4942..eae66c57 100644 --- a/pyupgrade/_data.py +++ b/pyupgrade/_data.py @@ -40,6 +40,7 @@ class State(NamedTuple): '__future__', 'asyncio', 'collections', + 'collections.abc', 'functools', 'mmap', 'os', diff --git a/pyupgrade/_plugins/typing_pep696_typevar_defaults.py b/pyupgrade/_plugins/typing_pep696_typevar_defaults.py new file mode 100644 index 00000000..3fe7c87e --- /dev/null +++ b/pyupgrade/_plugins/typing_pep696_typevar_defaults.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +import ast +from typing import Iterable + +from tokenize_rt import Offset +from tokenize_rt import Token + +from pyupgrade._ast_helpers import ast_to_offset +from pyupgrade._ast_helpers import is_name_attr +from pyupgrade._data import register +from pyupgrade._data import State +from pyupgrade._data import TokenFunc +from pyupgrade._token_helpers import find_op +from pyupgrade._token_helpers import parse_call_args + + +def _fix_typevar_default(i: int, tokens: list[Token]) -> None: + j = find_op(tokens, i, '[') + args, end = parse_call_args(tokens, j) + # remove the trailing `None` arguments + del tokens[args[0][1]:args[-1][1]] + + +def _should_rewrite(state: State) -> bool: + return ( + state.settings.min_version >= (3, 13) or ( + not state.settings.keep_runtime_typing and + state.in_annotation and + 'annotations' in state.from_imports['__future__'] + ) + ) + + +def _is_none(node: ast.AST) -> bool: + return isinstance(node, ast.Constant) and node.value is None + + +@register(ast.Subscript) +def visit_Subscript( + state: State, + node: ast.Subscript, + parent: ast.AST, +) -> Iterable[tuple[Offset, TokenFunc]]: + if not _should_rewrite(state): + return + + if ( + is_name_attr( + node.value, + state.from_imports, + ('collections.abc', 'typing', 'typing_extensions'), + ('Generator',), + ) and + isinstance(node.slice, ast.Tuple) and + len(node.slice.elts) == 3 and + _is_none(node.slice.elts[1]) and + _is_none(node.slice.elts[2]) + ): + yield ast_to_offset(node), _fix_typevar_default + elif ( + is_name_attr( + node.value, + state.from_imports, + ('collections.abc', 'typing', 'typing_extensions'), + ('AsyncGenerator',), + ) and + isinstance(node.slice, ast.Tuple) and + len(node.slice.elts) == 2 and + _is_none(node.slice.elts[1]) + ): + yield ast_to_offset(node), _fix_typevar_default diff --git a/tests/features/typing_pep696_typevar_defaults_test.py b/tests/features/typing_pep696_typevar_defaults_test.py new file mode 100644 index 00000000..99bb5f4a --- /dev/null +++ b/tests/features/typing_pep696_typevar_defaults_test.py @@ -0,0 +1,126 @@ +from __future__ import annotations + +import pytest + +from pyupgrade._data import Settings +from pyupgrade._main import _fix_plugins + + +@pytest.mark.parametrize( + ('s', 'version'), + ( + pytest.param( + 'from collections.abc import Generator\n' + 'def f() -> Generator[int, None, None]: yield 1\n', + (3, 12), + id='not 3.13+, no __future__.annotations', + ), + pytest.param( + 'from __future__ import annotations\n' + 'from collections.abc import Generator\n' + 'def f() -> Generator[int]: yield 1\n', + (3, 12), + id='already converted!', + ), + pytest.param( + 'from __future__ import annotations\n' + 'from collections.abc import Generator\n' + 'def f() -> Generator[int, int, None]: yield 1\n' + 'def g() -> Generator[int, int, int]: yield 1\n', + (3, 12), + id='non-None send/return type', + ), + ), +) +def test_fix_pep696_noop(s, version): + assert _fix_plugins(s, settings=Settings(min_version=version)) == s + + +def test_fix_pep696_noop_keep_runtime_typing(): + settings = Settings(min_version=(3, 12), keep_runtime_typing=True) + s = '''\ +from __future__ import annotations +from collections.abc import Generator +def f() -> Generator[int, None, None]: yield 1 +''' + assert _fix_plugins(s, settings=settings) == s + + +@pytest.mark.parametrize( + ('s', 'expected'), + ( + pytest.param( + 'from __future__ import annotations\n' + 'from typing import Generator\n' + 'def f() -> Generator[int, None, None]: yield 1\n', + + 'from __future__ import annotations\n' + 'from collections.abc import Generator\n' + 'def f() -> Generator[int]: yield 1\n', + + id='typing.Generator', + ), + pytest.param( + 'from __future__ import annotations\n' + 'from typing_extensions import Generator\n' + 'def f() -> Generator[int, None, None]: yield 1\n', + + 'from __future__ import annotations\n' + 'from typing_extensions import Generator\n' + 'def f() -> Generator[int]: yield 1\n', + + id='typing_extensions.Generator', + ), + pytest.param( + 'from __future__ import annotations\n' + 'from collections.abc import Generator\n' + 'def f() -> Generator[int, None, None]: yield 1\n', + + 'from __future__ import annotations\n' + 'from collections.abc import Generator\n' + 'def f() -> Generator[int]: yield 1\n', + + id='collections.abc.Generator', + ), + pytest.param( + 'from __future__ import annotations\n' + 'from collections.abc import AsyncGenerator\n' + 'async def f() -> AsyncGenerator[int, None]: yield 1\n', + + 'from __future__ import annotations\n' + 'from collections.abc import AsyncGenerator\n' + 'async def f() -> AsyncGenerator[int]: yield 1\n', + + id='collections.abc.AsyncGenerator', + ), + ), +) +def test_fix_pep696_with_future_annotations(s, expected): + assert _fix_plugins(s, settings=Settings(min_version=(3, 12))) == expected + + +@pytest.mark.parametrize( + ('s', 'expected'), + ( + pytest.param( + 'from collections.abc import Generator\n' + 'def f() -> Generator[int, None, None]: yield 1\n', + + 'from collections.abc import Generator\n' + 'def f() -> Generator[int]: yield 1\n', + + id='Generator', + ), + pytest.param( + 'from collections.abc import AsyncGenerator\n' + 'async def f() -> AsyncGenerator[int, None]: yield 1\n', + + 'from collections.abc import AsyncGenerator\n' + 'async def f() -> AsyncGenerator[int]: yield 1\n', + + id='AsyncGenerator', + ), + ), +) +def test_fix_pep696_with_3_13(s, expected): + assert _fix_plugins(s, settings=Settings(min_version=(3, 13))) == expected