-
-
Notifications
You must be signed in to change notification settings - Fork 186
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
rewrite TypeVar defaults for Generator / AsyncGenerator
- Loading branch information
Showing
4 changed files
with
217 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
from __future__ import annotations | ||
|
||
import ast | ||
from typing import Iterable | ||
|
||
from tokenize_rt import Offset | ||
from tokenize_rt import Token | ||
|
||
from pyupgrade._ast_helpers import ast_to_offset | ||
from pyupgrade._ast_helpers import is_name_attr | ||
from pyupgrade._data import register | ||
from pyupgrade._data import State | ||
from pyupgrade._data import TokenFunc | ||
from pyupgrade._token_helpers import find_op | ||
from pyupgrade._token_helpers import parse_call_args | ||
|
||
|
||
def _fix_typevar_default(i: int, tokens: list[Token]) -> None: | ||
j = find_op(tokens, i, '[') | ||
args, end = parse_call_args(tokens, j) | ||
# remove the trailing `None` arguments | ||
del tokens[args[0][1]:args[-1][1]] | ||
|
||
|
||
def _should_rewrite(state: State) -> bool: | ||
return ( | ||
state.settings.min_version >= (3, 13) or ( | ||
not state.settings.keep_runtime_typing and | ||
state.in_annotation and | ||
'annotations' in state.from_imports['__future__'] | ||
) | ||
) | ||
|
||
|
||
def _is_none(node: ast.AST) -> bool: | ||
return isinstance(node, ast.Constant) and node.value is None | ||
|
||
|
||
@register(ast.Subscript) | ||
def visit_Subscript( | ||
state: State, | ||
node: ast.Subscript, | ||
parent: ast.AST, | ||
) -> Iterable[tuple[Offset, TokenFunc]]: | ||
if not _should_rewrite(state): | ||
return | ||
|
||
if ( | ||
is_name_attr( | ||
node.value, | ||
state.from_imports, | ||
('collections.abc', 'typing', 'typing_extensions'), | ||
('Generator',), | ||
) and | ||
isinstance(node.slice, ast.Tuple) and | ||
len(node.slice.elts) == 3 and | ||
_is_none(node.slice.elts[1]) and | ||
_is_none(node.slice.elts[2]) | ||
): | ||
yield ast_to_offset(node), _fix_typevar_default | ||
elif ( | ||
is_name_attr( | ||
node.value, | ||
state.from_imports, | ||
('collections.abc', 'typing', 'typing_extensions'), | ||
('AsyncGenerator',), | ||
) and | ||
isinstance(node.slice, ast.Tuple) and | ||
len(node.slice.elts) == 2 and | ||
_is_none(node.slice.elts[1]) | ||
): | ||
yield ast_to_offset(node), _fix_typevar_default |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
from __future__ import annotations | ||
|
||
import pytest | ||
|
||
from pyupgrade._data import Settings | ||
from pyupgrade._main import _fix_plugins | ||
|
||
|
||
@pytest.mark.parametrize( | ||
('s', 'version'), | ||
( | ||
pytest.param( | ||
'from collections.abc import Generator\n' | ||
'def f() -> Generator[int, None, None]: yield 1\n', | ||
(3, 12), | ||
id='not 3.13+, no __future__.annotations', | ||
), | ||
pytest.param( | ||
'from __future__ import annotations\n' | ||
'from collections.abc import Generator\n' | ||
'def f() -> Generator[int]: yield 1\n', | ||
(3, 12), | ||
id='already converted!', | ||
), | ||
pytest.param( | ||
'from __future__ import annotations\n' | ||
'from collections.abc import Generator\n' | ||
'def f() -> Generator[int, int, None]: yield 1\n' | ||
'def g() -> Generator[int, int, int]: yield 1\n', | ||
(3, 12), | ||
id='non-None send/return type', | ||
), | ||
), | ||
) | ||
def test_fix_pep696_noop(s, version): | ||
assert _fix_plugins(s, settings=Settings(min_version=version)) == s | ||
|
||
|
||
def test_fix_pep696_noop_keep_runtime_typing(): | ||
settings = Settings(min_version=(3, 12), keep_runtime_typing=True) | ||
s = '''\ | ||
from __future__ import annotations | ||
from collections.abc import Generator | ||
def f() -> Generator[int, None, None]: yield 1 | ||
''' | ||
assert _fix_plugins(s, settings=settings) == s | ||
|
||
|
||
@pytest.mark.parametrize( | ||
('s', 'expected'), | ||
( | ||
pytest.param( | ||
'from __future__ import annotations\n' | ||
'from typing import Generator\n' | ||
'def f() -> Generator[int, None, None]: yield 1\n', | ||
'from __future__ import annotations\n' | ||
'from collections.abc import Generator\n' | ||
'def f() -> Generator[int]: yield 1\n', | ||
id='typing.Generator', | ||
), | ||
pytest.param( | ||
'from __future__ import annotations\n' | ||
'from typing_extensions import Generator\n' | ||
'def f() -> Generator[int, None, None]: yield 1\n', | ||
'from __future__ import annotations\n' | ||
'from typing_extensions import Generator\n' | ||
'def f() -> Generator[int]: yield 1\n', | ||
id='typing_extensions.Generator', | ||
), | ||
pytest.param( | ||
'from __future__ import annotations\n' | ||
'from collections.abc import Generator\n' | ||
'def f() -> Generator[int, None, None]: yield 1\n', | ||
'from __future__ import annotations\n' | ||
'from collections.abc import Generator\n' | ||
'def f() -> Generator[int]: yield 1\n', | ||
id='collections.abc.Generator', | ||
), | ||
pytest.param( | ||
'from __future__ import annotations\n' | ||
'from collections.abc import AsyncGenerator\n' | ||
'async def f() -> AsyncGenerator[int, None]: yield 1\n', | ||
'from __future__ import annotations\n' | ||
'from collections.abc import AsyncGenerator\n' | ||
'async def f() -> AsyncGenerator[int]: yield 1\n', | ||
id='collections.abc.AsyncGenerator', | ||
), | ||
), | ||
) | ||
def test_fix_pep696_with_future_annotations(s, expected): | ||
assert _fix_plugins(s, settings=Settings(min_version=(3, 12))) == expected | ||
|
||
|
||
@pytest.mark.parametrize( | ||
('s', 'expected'), | ||
( | ||
pytest.param( | ||
'from collections.abc import Generator\n' | ||
'def f() -> Generator[int, None, None]: yield 1\n', | ||
'from collections.abc import Generator\n' | ||
'def f() -> Generator[int]: yield 1\n', | ||
id='Generator', | ||
), | ||
pytest.param( | ||
'from collections.abc import AsyncGenerator\n' | ||
'async def f() -> AsyncGenerator[int, None]: yield 1\n', | ||
'from collections.abc import AsyncGenerator\n' | ||
'async def f() -> AsyncGenerator[int]: yield 1\n', | ||
id='AsyncGenerator', | ||
), | ||
), | ||
) | ||
def test_fix_pep696_with_3_13(s, expected): | ||
assert _fix_plugins(s, settings=Settings(min_version=(3, 13))) == expected |