diff --git a/flake8_trio/runner.py b/flake8_trio/runner.py index 229979f..a1b843d 100644 --- a/flake8_trio/runner.py +++ b/flake8_trio/runner.py @@ -26,7 +26,7 @@ from libcst import Module from .base import Error, Options - from .visitors.flake8triovisitor import Flake8TrioVisitor, Flake8TrioVisitor_cst + from .visitors.flake8triovisitor import Flake8AsyncVisitor, Flake8AsyncVisitor_cst @dataclass @@ -73,7 +73,7 @@ def visit(self, node: ast.AST): """Visit a node.""" # tracks the subclasses that, from this node on, iterated through it's subfields # we need to remember it so we can restore it at the end of the function. - novisit: set[Flake8TrioVisitor] = set() + novisit: set[Flake8AsyncVisitor] = set() method = "visit_" + node.__class__.__name__ @@ -122,14 +122,14 @@ def __init__(self, options: Options, module: Module): # Could possibly enable/disable utility visitors here, if visitors declared # dependencies - self.utility_visitors: tuple[Flake8TrioVisitor_cst, ...] = tuple( + self.utility_visitors: tuple[Flake8AsyncVisitor_cst, ...] = tuple( v(self.state) for v in utility_visitors ) # sort the error classes to get predictable behaviour when multiple autofixers # are enabled sorted_error_classes_cst = sorted(ERROR_CLASSES_CST, key=lambda x: x.__name__) - self.visitors: tuple[Flake8TrioVisitor_cst, ...] = tuple( + self.visitors: tuple[Flake8AsyncVisitor_cst, ...] = tuple( v(self.state) for v in sorted_error_classes_cst if self.selected(v.error_codes) diff --git a/flake8_trio/visitors/__init__.py b/flake8_trio/visitors/__init__.py index 6aa2fe7..ab392ba 100644 --- a/flake8_trio/visitors/__init__.py +++ b/flake8_trio/visitors/__init__.py @@ -10,7 +10,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from .flake8triovisitor import Flake8TrioVisitor, Flake8TrioVisitor_cst + from .flake8triovisitor import Flake8AsyncVisitor, Flake8AsyncVisitor_cst __all__ = [ "ERROR_CLASSES", @@ -19,11 +19,11 @@ "utility_visitors", "utility_visitors_cst", ] -ERROR_CLASSES: set[type[Flake8TrioVisitor]] = set() -ERROR_CLASSES_CST: set[type[Flake8TrioVisitor_cst]] = set() +ERROR_CLASSES: set[type[Flake8AsyncVisitor]] = set() +ERROR_CLASSES_CST: set[type[Flake8AsyncVisitor_cst]] = set() default_disabled_error_codes: list[str] = [] -utility_visitors: set[type[Flake8TrioVisitor]] = set() -utility_visitors_cst: set[type[Flake8TrioVisitor_cst]] = set() +utility_visitors: set[type[Flake8AsyncVisitor]] = set() +utility_visitors_cst: set[type[Flake8AsyncVisitor_cst]] = set() # Import all visitors so their decorators run, filling the above containers # This has to be done at the end to avoid circular imports diff --git a/flake8_trio/visitors/flake8triovisitor.py b/flake8_trio/visitors/flake8triovisitor.py index 3de8c6b..ae09ddf 100644 --- a/flake8_trio/visitors/flake8triovisitor.py +++ b/flake8_trio/visitors/flake8triovisitor.py @@ -24,7 +24,7 @@ ERROR_CODE_LEN = 8 -class Flake8TrioVisitor(ast.NodeVisitor, ABC): +class Flake8AsyncVisitor(ast.NodeVisitor, ABC): # abstract attribute by not providing a value error_codes: Mapping[str, str] @@ -39,7 +39,7 @@ def __init__(self, shared_state: SharedState): # mark variables that shouldn't be saved/loaded in self.get_state self.nocopy = { - "_Flake8TrioVisitor__state", + "_Flake8AsyncVisitor__state", "error_codes", "nocopy", "novisit", @@ -158,7 +158,7 @@ def add_library(self, name: str) -> None: self.__state.library = (*self.__state.library, name) -class Flake8TrioVisitor_cst(cst.CSTTransformer, ABC): +class Flake8AsyncVisitor_cst(cst.CSTTransformer, ABC): # abstract attribute by not providing a value error_codes: Mapping[str, str] METADATA_DEPENDENCIES = (PositionProvider,) diff --git a/flake8_trio/visitors/helpers.py b/flake8_trio/visitors/helpers.py index d52c6f1..2eac570 100644 --- a/flake8_trio/visitors/helpers.py +++ b/flake8_trio/visitors/helpers.py @@ -25,12 +25,16 @@ if TYPE_CHECKING: from collections.abc import Iterable, Iterator, Sequence - from .flake8triovisitor import Flake8TrioVisitor, Flake8TrioVisitor_cst, HasLineCol + from .flake8triovisitor import ( + Flake8AsyncVisitor, + Flake8AsyncVisitor_cst, + HasLineCol, + ) - T = TypeVar("T", bound=Flake8TrioVisitor) - T_CST = TypeVar("T_CST", bound=Flake8TrioVisitor_cst) + T = TypeVar("T", bound=Flake8AsyncVisitor) + T_CST = TypeVar("T_CST", bound=Flake8AsyncVisitor_cst) T_EITHER = TypeVar( - "T_EITHER", bound=Union[Flake8TrioVisitor, Flake8TrioVisitor_cst] + "T_EITHER", bound=Union[Flake8AsyncVisitor, Flake8AsyncVisitor_cst] ) diff --git a/flake8_trio/visitors/visitor100.py b/flake8_trio/visitors/visitor100.py index 187ee9e..7284372 100644 --- a/flake8_trio/visitors/visitor100.py +++ b/flake8_trio/visitors/visitor100.py @@ -13,7 +13,7 @@ import libcst as cst import libcst.matchers as m -from .flake8triovisitor import Flake8TrioVisitor_cst +from .flake8triovisitor import Flake8AsyncVisitor_cst from .helpers import ( AttributeCall, error_class_cst, @@ -26,7 +26,7 @@ @error_class_cst -class Visitor100_libcst(Flake8TrioVisitor_cst): +class Visitor100_libcst(Flake8AsyncVisitor_cst): error_codes: Mapping[str, str] = { "ASYNC100": ( "{0}.{1} context contains no checkpoints, remove the context or add" diff --git a/flake8_trio/visitors/visitor101.py b/flake8_trio/visitors/visitor101.py index 31091a7..29cbbf7 100644 --- a/flake8_trio/visitors/visitor101.py +++ b/flake8_trio/visitors/visitor101.py @@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any -from .flake8triovisitor import Flake8TrioVisitor_cst +from .flake8triovisitor import Flake8AsyncVisitor_cst from .helpers import ( cancel_scope_names, error_class_cst, @@ -23,7 +23,7 @@ @error_class_cst -class Visitor101(Flake8TrioVisitor_cst): +class Visitor101(Flake8AsyncVisitor_cst): error_codes: Mapping[str, str] = { "ASYNC101": ( "`yield` inside a nursery or cancel scope is only safe when implementing " diff --git a/flake8_trio/visitors/visitor102.py b/flake8_trio/visitors/visitor102.py index d02725d..5afc63e 100644 --- a/flake8_trio/visitors/visitor102.py +++ b/flake8_trio/visitors/visitor102.py @@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any from ..base import Statement -from .flake8triovisitor import Flake8TrioVisitor +from .flake8triovisitor import Flake8AsyncVisitor from .helpers import cancel_scope_names, critical_except, error_class, get_matching_call if TYPE_CHECKING: @@ -17,7 +17,7 @@ @error_class -class Visitor102(Flake8TrioVisitor): +class Visitor102(Flake8AsyncVisitor): error_codes: Mapping[str, str] = { "ASYNC102": ( "await inside {0.name} on line {0.lineno} must have shielded cancel " diff --git a/flake8_trio/visitors/visitor103_104.py b/flake8_trio/visitors/visitor103_104.py index c5571fb..8ce2992 100644 --- a/flake8_trio/visitors/visitor103_104.py +++ b/flake8_trio/visitors/visitor103_104.py @@ -11,7 +11,7 @@ import ast from typing import TYPE_CHECKING, Any -from .flake8triovisitor import Flake8TrioVisitor +from .flake8triovisitor import Flake8AsyncVisitor from .helpers import critical_except, error_class, iter_guaranteed_once if TYPE_CHECKING: @@ -36,7 +36,7 @@ @error_class -class Visitor103_104(Flake8TrioVisitor): +class Visitor103_104(Flake8AsyncVisitor): error_codes: Mapping[str, str] = _error_codes def __init__(self, *args: Any, **kwargs: Any): diff --git a/flake8_trio/visitors/visitor105.py b/flake8_trio/visitors/visitor105.py index fc0138a..594ad89 100644 --- a/flake8_trio/visitors/visitor105.py +++ b/flake8_trio/visitors/visitor105.py @@ -5,7 +5,7 @@ import ast from typing import TYPE_CHECKING, Any -from .flake8triovisitor import Flake8TrioVisitor +from .flake8triovisitor import Flake8AsyncVisitor from .helpers import error_class if TYPE_CHECKING: @@ -42,7 +42,7 @@ @error_class -class Visitor105(Flake8TrioVisitor): +class Visitor105(Flake8AsyncVisitor): error_codes: Mapping[str, str] = { "ASYNC105": "{0} async {1} must be immediately awaited.", } diff --git a/flake8_trio/visitors/visitor111.py b/flake8_trio/visitors/visitor111.py index bd10951..397d0a2 100644 --- a/flake8_trio/visitors/visitor111.py +++ b/flake8_trio/visitors/visitor111.py @@ -5,7 +5,7 @@ import ast from typing import TYPE_CHECKING, Any, NamedTuple -from .flake8triovisitor import Flake8TrioVisitor +from .flake8triovisitor import Flake8AsyncVisitor from .helpers import error_class, get_matching_call if TYPE_CHECKING: @@ -13,7 +13,7 @@ @error_class -class Visitor111(Flake8TrioVisitor): +class Visitor111(Flake8AsyncVisitor): error_codes: Mapping[str, str] = { "ASYNC111": ( "variable {2} is usable within the context manager on line {0}, but that " diff --git a/flake8_trio/visitors/visitor118.py b/flake8_trio/visitors/visitor118.py index 02ce834..b8a1d04 100644 --- a/flake8_trio/visitors/visitor118.py +++ b/flake8_trio/visitors/visitor118.py @@ -10,7 +10,7 @@ import re from typing import TYPE_CHECKING -from .flake8triovisitor import Flake8TrioVisitor +from .flake8triovisitor import Flake8AsyncVisitor from .helpers import error_class if TYPE_CHECKING: @@ -18,7 +18,7 @@ @error_class -class Visitor118(Flake8TrioVisitor): +class Visitor118(Flake8AsyncVisitor): error_codes: Mapping[str, str] = { "ASYNC118": ( "Don't assign the value of `anyio.get_cancelled_exc_class()` to a variable," diff --git a/flake8_trio/visitors/visitor2xx.py b/flake8_trio/visitors/visitor2xx.py index 3203d36..a7349f7 100644 --- a/flake8_trio/visitors/visitor2xx.py +++ b/flake8_trio/visitors/visitor2xx.py @@ -14,7 +14,7 @@ import re from typing import TYPE_CHECKING, Any -from .flake8triovisitor import Flake8TrioVisitor +from .flake8triovisitor import Flake8AsyncVisitor from .helpers import error_class, fnmatch_qualified_name, get_matching_call if TYPE_CHECKING: @@ -22,7 +22,7 @@ @error_class -class Visitor200(Flake8TrioVisitor): +class Visitor200(Flake8AsyncVisitor): error_codes: Mapping[str, str] = { "ASYNC200": ( "User-configured blocking sync call {0} in async function, consider " diff --git a/flake8_trio/visitors/visitor91x.py b/flake8_trio/visitors/visitor91x.py index 9963d1c..c1af6d9 100644 --- a/flake8_trio/visitors/visitor91x.py +++ b/flake8_trio/visitors/visitor91x.py @@ -16,7 +16,7 @@ from libcst.metadata import PositionProvider from ..base import Statement -from .flake8triovisitor import Flake8TrioVisitor_cst +from .flake8triovisitor import Flake8AsyncVisitor_cst from .helpers import ( disabled_by_default, error_class_cst, @@ -226,7 +226,7 @@ def leave_Yield( @error_class_cst @disabled_by_default -class Visitor91X(Flake8TrioVisitor_cst, CommonVisitors): +class Visitor91X(Flake8AsyncVisitor_cst, CommonVisitors): error_codes: Mapping[str, str] = { "ASYNC910": ( "{0} from async function with no guaranteed checkpoint or exception " diff --git a/flake8_trio/visitors/visitor_utility.py b/flake8_trio/visitors/visitor_utility.py index 35efe59..81bdcf0 100644 --- a/flake8_trio/visitors/visitor_utility.py +++ b/flake8_trio/visitors/visitor_utility.py @@ -10,7 +10,7 @@ import libcst.matchers as m from libcst.metadata import PositionProvider -from .flake8triovisitor import Flake8TrioVisitor, Flake8TrioVisitor_cst +from .flake8triovisitor import Flake8AsyncVisitor, Flake8AsyncVisitor_cst from .helpers import utility_visitor, utility_visitor_cst if TYPE_CHECKING: @@ -21,7 +21,7 @@ @utility_visitor -class VisitorTypeTracker(Flake8TrioVisitor): +class VisitorTypeTracker(Flake8AsyncVisitor): def visit_AsyncFunctionDef( self, node: ast.AsyncFunctionDef | ast.FunctionDef | ast.Lambda ): @@ -101,7 +101,7 @@ def visit_With(self, node: ast.With | ast.AsyncWith): @utility_visitor -class VisitorAwaitModifier(Flake8TrioVisitor): +class VisitorAwaitModifier(Flake8AsyncVisitor): def visit_Await(self, node: ast.Await): if isinstance(node.value, ast.Call): # add attribute to indicate it's awaited @@ -109,7 +109,7 @@ def visit_Await(self, node: ast.Await): @utility_visitor -class VisitorLibraryHandler(Flake8TrioVisitor): +class VisitorLibraryHandler(Flake8AsyncVisitor): def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) # check whether library we're working towards has been explicitly @@ -126,7 +126,7 @@ def visit_Import(self, node: ast.Import): @utility_visitor_cst -class VisitorLibraryHandler_cst(Flake8TrioVisitor_cst): +class VisitorLibraryHandler_cst(Flake8AsyncVisitor_cst): def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) # check whether library we're working towards has been explicitly @@ -172,7 +172,7 @@ def _find_noqa(physical_line: str) -> Match[str] | None: @utility_visitor_cst -class NoqaHandler(Flake8TrioVisitor_cst): +class NoqaHandler(Flake8AsyncVisitor_cst): def visit_Comment(self, node: cst.Comment): noqa_match = _find_noqa(node.value) if noqa_match is None: diff --git a/flake8_trio/visitors/visitors.py b/flake8_trio/visitors/visitors.py index a53f34b..9a69112 100644 --- a/flake8_trio/visitors/visitors.py +++ b/flake8_trio/visitors/visitors.py @@ -5,7 +5,7 @@ import ast from typing import TYPE_CHECKING, Any, cast -from .flake8triovisitor import Flake8TrioVisitor +from .flake8triovisitor import Flake8AsyncVisitor from .helpers import disabled_by_default, error_class, get_matching_call, has_decorator if TYPE_CHECKING: @@ -13,7 +13,7 @@ @error_class -class Visitor106(Flake8TrioVisitor): +class Visitor106(Flake8AsyncVisitor): error_codes: Mapping[str, str] = { "ASYNC106": "{0} must be imported with `import {0}` for the linter to work.", } @@ -29,7 +29,7 @@ def visit_Import(self, node: ast.Import): @error_class -class Visitor109(Flake8TrioVisitor): +class Visitor109(Flake8AsyncVisitor): error_codes: Mapping[str, str] = { "ASYNC109": ( "Async function definition with a `timeout` parameter - use " @@ -50,7 +50,7 @@ def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef): @error_class -class Visitor110(Flake8TrioVisitor): +class Visitor110(Flake8AsyncVisitor): error_codes: Mapping[str, str] = { "ASYNC110": ( "`while : await {0}.sleep()` should be replaced by " @@ -69,7 +69,7 @@ def visit_While(self, node: ast.While): @error_class -class Visitor112(Flake8TrioVisitor): +class Visitor112(Flake8AsyncVisitor): error_codes: Mapping[str, str] = { "ASYNC112": ( "Redundant nursery {}, consider replacing with directly awaiting " @@ -121,7 +121,7 @@ def visit_With(self, node: ast.With | ast.AsyncWith): @error_class -class Visitor113(Flake8TrioVisitor): +class Visitor113(Flake8AsyncVisitor): error_codes: Mapping[str, str] = { "ASYNC113": ( "Dangerous `.start_soon()`, function might not be executed before" @@ -188,7 +188,7 @@ def is_nursery_call(node: ast.expr): # name, so may miss cases where functions are named the same in different modules/classes # and option names are specified including the module name. @error_class -class Visitor114(Flake8TrioVisitor): +class Visitor114(Flake8AsyncVisitor): error_codes: Mapping[str, str] = { "ASYNC114": ( "Startable function {} not in --startable-in-context-manager parameter " @@ -210,7 +210,7 @@ def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef): # Suggests replacing all `trio.sleep(0)` with the more suggestive # `trio.lowlevel.checkpoint()` @error_class -class Visitor115(Flake8TrioVisitor): +class Visitor115(Flake8AsyncVisitor): error_codes: Mapping[str, str] = { "ASYNC115": "Use `{0}.lowlevel.checkpoint()` instead of `{0}.sleep(0)`.", } @@ -227,7 +227,7 @@ def visit_Call(self, node: ast.Call): @error_class -class Visitor116(Flake8TrioVisitor): +class Visitor116(Flake8AsyncVisitor): error_codes: Mapping[str, str] = { "ASYNC116": ( "{0}.sleep() with >24 hour interval should usually be " @@ -269,7 +269,7 @@ def visit_Call(self, node: ast.Call): # anyio does not have MultiError, so this check is trio-only @error_class -class Visitor117(Flake8TrioVisitor): +class Visitor117(Flake8AsyncVisitor): error_codes: Mapping[str, str] = { "ASYNC117": "Reference to {}, prefer [exceptiongroup.]BaseExceptionGroup.", } @@ -286,7 +286,7 @@ def visit_Attribute(self, node: ast.Attribute): @error_class @disabled_by_default -class Visitor900(Flake8TrioVisitor): +class Visitor900(Flake8AsyncVisitor): error_codes: Mapping[str, str] = { "ASYNC900": "Async generator without `@asynccontextmanager` not allowed." } diff --git a/tests/test_flake8_trio.py b/tests/test_flake8_trio.py index 4715ae1..39b800c 100644 --- a/tests/test_flake8_trio.py +++ b/tests/test_flake8_trio.py @@ -30,7 +30,7 @@ if TYPE_CHECKING: from collections.abc import Iterable, Sequence - from flake8_trio.visitors.flake8triovisitor import Flake8TrioVisitor + from flake8_trio.visitors.flake8triovisitor import Flake8AsyncVisitor AUTOFIX_DIR = Path(__file__).parent / "autofix_files" @@ -75,7 +75,7 @@ def check_version(test: str): # mypy does not see that both types have error_codes -ERROR_CODES: dict[str, Flake8TrioVisitor] = { +ERROR_CODES: dict[str, Flake8AsyncVisitor] = { err_code: err_class # type: ignore[misc] for err_class in (*ERROR_CLASSES, *ERROR_CLASSES_CST) for err_code in err_class.error_codes # type: ignore[attr-defined]