Skip to content

Commit

Permalink
rewrite TypeVar defaults for Generator / AsyncGenerator
Browse files Browse the repository at this point in the history
  • Loading branch information
asottile committed Jun 8, 2024
1 parent 049b8e3 commit 5f2131c
Show file tree
Hide file tree
Showing 4 changed files with 217 additions and 0 deletions.
18 changes: 18 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -754,6 +754,24 @@ Availability:
...
```

### pep 696 TypeVar defaults

Availability:
- File imports `from __future__ import annotations`
- Unless `--keep-runtime-typing` is passed on the commandline.
- `--py313-plus` is passed on the commandline.

```diff
-def f() -> Generator[int, None, None]:
+def f() -> Generator[int]:
yield 1
```

```diff
-async def f() -> AsyncGenerator[int, None]:
+async def f() -> AsyncGenerator[int]:
yield 1
```

### remove quoted annotations

Expand Down
1 change: 1 addition & 0 deletions pyupgrade/_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class State(NamedTuple):
'__future__',
'asyncio',
'collections',
'collections.abc',
'functools',
'mmap',
'os',
Expand Down
72 changes: 72 additions & 0 deletions pyupgrade/_plugins/typing_pep696_typevar_defaults.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from __future__ import annotations

import ast
from typing import Iterable

from tokenize_rt import Offset
from tokenize_rt import Token

from pyupgrade._ast_helpers import ast_to_offset
from pyupgrade._ast_helpers import is_name_attr
from pyupgrade._data import register
from pyupgrade._data import State
from pyupgrade._data import TokenFunc
from pyupgrade._token_helpers import find_op
from pyupgrade._token_helpers import parse_call_args


def _fix_typevar_default(i: int, tokens: list[Token]) -> None:
j = find_op(tokens, i, '[')
args, end = parse_call_args(tokens, j)
# remove the trailing `None` arguments
del tokens[args[0][1]:args[-1][1]]


def _should_rewrite(state: State) -> bool:
return (
state.settings.min_version >= (3, 13) or (
not state.settings.keep_runtime_typing and
state.in_annotation and
'annotations' in state.from_imports['__future__']
)
)


def _is_none(node: ast.AST) -> bool:
return isinstance(node, ast.Constant) and node.value is None


@register(ast.Subscript)
def visit_Subscript(
state: State,
node: ast.Subscript,
parent: ast.AST,
) -> Iterable[tuple[Offset, TokenFunc]]:
if not _should_rewrite(state):
return

if (
is_name_attr(
node.value,
state.from_imports,
('collections.abc', 'typing', 'typing_extensions'),
('Generator',),
) and
isinstance(node.slice, ast.Tuple) and
len(node.slice.elts) == 3 and
_is_none(node.slice.elts[1]) and
_is_none(node.slice.elts[2])
):
yield ast_to_offset(node), _fix_typevar_default
elif (
is_name_attr(
node.value,
state.from_imports,
('collections.abc', 'typing', 'typing_extensions'),
('AsyncGenerator',),
) and
isinstance(node.slice, ast.Tuple) and
len(node.slice.elts) == 2 and
_is_none(node.slice.elts[1])
):
yield ast_to_offset(node), _fix_typevar_default
126 changes: 126 additions & 0 deletions tests/features/typing_pep696_typevar_defaults_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
from __future__ import annotations

import pytest

from pyupgrade._data import Settings
from pyupgrade._main import _fix_plugins


@pytest.mark.parametrize(
('s', 'version'),
(
pytest.param(
'from collections.abc import Generator\n'
'def f() -> Generator[int, None, None]: yield 1\n',
(3, 12),
id='not 3.13+, no __future__.annotations',
),
pytest.param(
'from __future__ import annotations\n'
'from collections.abc import Generator\n'
'def f() -> Generator[int]: yield 1\n',
(3, 12),
id='already converted!',
),
pytest.param(
'from __future__ import annotations\n'
'from collections.abc import Generator\n'
'def f() -> Generator[int, int, None]: yield 1\n'
'def g() -> Generator[int, int, int]: yield 1\n',
(3, 12),
id='non-None send/return type',
),
),
)
def test_fix_pep696_noop(s, version):
assert _fix_plugins(s, settings=Settings(min_version=version)) == s


def test_fix_pep696_noop_keep_runtime_typing():
settings = Settings(min_version=(3, 12), keep_runtime_typing=True)
s = '''\
from __future__ import annotations
from collections.abc import Generator
def f() -> Generator[int, None, None]: yield 1
'''
assert _fix_plugins(s, settings=settings) == s


@pytest.mark.parametrize(
('s', 'expected'),
(
pytest.param(
'from __future__ import annotations\n'
'from typing import Generator\n'
'def f() -> Generator[int, None, None]: yield 1\n',
'from __future__ import annotations\n'
'from collections.abc import Generator\n'
'def f() -> Generator[int]: yield 1\n',
id='typing.Generator',
),
pytest.param(
'from __future__ import annotations\n'
'from typing_extensions import Generator\n'
'def f() -> Generator[int, None, None]: yield 1\n',
'from __future__ import annotations\n'
'from typing_extensions import Generator\n'
'def f() -> Generator[int]: yield 1\n',
id='typing_extensions.Generator',
),
pytest.param(
'from __future__ import annotations\n'
'from collections.abc import Generator\n'
'def f() -> Generator[int, None, None]: yield 1\n',
'from __future__ import annotations\n'
'from collections.abc import Generator\n'
'def f() -> Generator[int]: yield 1\n',
id='collections.abc.Generator',
),
pytest.param(
'from __future__ import annotations\n'
'from collections.abc import AsyncGenerator\n'
'async def f() -> AsyncGenerator[int, None]: yield 1\n',
'from __future__ import annotations\n'
'from collections.abc import AsyncGenerator\n'
'async def f() -> AsyncGenerator[int]: yield 1\n',
id='collections.abc.AsyncGenerator',
),
),
)
def test_fix_pep696_with_future_annotations(s, expected):
assert _fix_plugins(s, settings=Settings(min_version=(3, 12))) == expected


@pytest.mark.parametrize(
('s', 'expected'),
(
pytest.param(
'from collections.abc import Generator\n'
'def f() -> Generator[int, None, None]: yield 1\n',
'from collections.abc import Generator\n'
'def f() -> Generator[int]: yield 1\n',
id='Generator',
),
pytest.param(
'from collections.abc import AsyncGenerator\n'
'async def f() -> AsyncGenerator[int, None]: yield 1\n',
'from collections.abc import AsyncGenerator\n'
'async def f() -> AsyncGenerator[int]: yield 1\n',
id='AsyncGenerator',
),
),
)
def test_fix_pep696_with_3_13(s, expected):
assert _fix_plugins(s, settings=Settings(min_version=(3, 13))) == expected

0 comments on commit 5f2131c

Please sign in to comment.