From df403c3a98943be5bcfb1220557ec690e8a478d7 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Fri, 10 May 2024 14:00:33 +0200 Subject: [PATCH] use matcher for matching instead of custom logic --- flake8_async/visitors/helpers.py | 87 ++++++++++++++++------------ tests/eval_files/async912.py | 29 +++++++++- tests/eval_files/async912_asyncio.py | 23 ++++++++ 3 files changed, 101 insertions(+), 38 deletions(-) diff --git a/flake8_async/visitors/helpers.py b/flake8_async/visitors/helpers.py index 7b3147e..d33e099 100644 --- a/flake8_async/visitors/helpers.py +++ b/flake8_async/visitors/helpers.py @@ -323,55 +323,68 @@ class AttributeCall(NamedTuple): function: str +# the custom __or__ in libcst breaks pyright type checking. It's possible to use +# `Union` as a workaround ... except pyupgrade will automatically replace that. +# So we have to resort to specifying one of the base classes. +# See https://github.com/Instagram/LibCST/issues/1143 +def build_cst_matcher(attr: str) -> m.BaseExpression: + """Build a cst matcher structure with attributes&names matching a string `a.b.c`.""" + if "." not in attr: + return m.Name(value=attr) + body, tail = attr.rsplit(".") + return m.Attribute(value=build_cst_matcher(body), attr=m.Name(value=tail)) + + +def identifier_to_string(attr: cst.Name | cst.Attribute) -> str: + if isinstance(attr, cst.Name): + return attr.value + assert isinstance(attr.value, (cst.Attribute, cst.Name)) + return identifier_to_string(attr.value) + "." + attr.attr.value + + def with_has_call( node: cst.With, *names: str, base: Iterable[str] = ("trio", "anyio") ) -> list[AttributeCall]: + """Check if a with statement has a matching call, returning a list with matches. + + `names` specify the names of functions to match, `base` specifies the + library/module(s) the function must be in. + The list elements in the return value are named tuples with the matched node, + base and function. + + Examples_ + + `with_has_call(node, "bar", base="foo")` matches foo.bar. + `with_has_call(node, "bar", "bee", base=("foo", "a.b.c")` matches + `foo.bar`, `foo.bee`, `a.b.c.bar`, and `a.b.c.bee`. + + """ if isinstance(base, str): base = (base,) # pragma: no cover - for b in base: - if b.count(".") > 1: # pragma: no cover - raise NotImplementedError("Does not support 3-module bases atm.") + # build matcher, using SaveMatchedNode to save the base and the function name. + matcher = m.Call( + func=m.Attribute( + value=m.SaveMatchedNode( + m.OneOf(*(build_cst_matcher(b) for b in base)), name="base" + ), + attr=m.SaveMatchedNode( + oneof_names(*names), + name="function", + ), + ) + ) res_list: list[AttributeCall] = [] for item in node.items: - if res := m.extract( - item.item, - m.Call( - func=m.Attribute( - value=m.SaveMatchedNode(m.Name() | m.Attribute(), name="library"), - attr=m.SaveMatchedNode( - oneof_names(*names), - name="function", - ), - ) - ), - ): + if res := m.extract(item.item, matcher): assert isinstance(item.item, cst.Call) - assert isinstance(res["library"], (cst.Name, cst.Attribute)) + assert isinstance(res["base"], (cst.Name, cst.Attribute)) assert isinstance(res["function"], cst.Name) - library_node = res["library"] - for library_str in base: - if ( - isinstance(library_node, cst.Name) - and library_str == library_node.value - ): - break - if ( - isinstance(library_node, cst.Attribute) - and isinstance(library_node.value, cst.Name) - and "." in library_str - ): - base_1, base_2 = library_str.split(".") - if ( - library_node.attr.value == base_2 - and library_node.value.value == base_1 - ): - break - else: - continue res_list.append( - AttributeCall(item.item, library_str, res["function"].value) + AttributeCall( + item.item, identifier_to_string(res["base"]), res["function"].value + ) ) return res_list diff --git a/tests/eval_files/async912.py b/tests/eval_files/async912.py index 3ebbc67..c2abf04 100644 --- a/tests/eval_files/async912.py +++ b/tests/eval_files/async912.py @@ -7,6 +7,8 @@ # of not testing both in the same file, or running with NOAUTOFIX. # NOAUTOFIX +from typing import TypeVar + import trio @@ -27,6 +29,7 @@ async def foo(): with trio.CancelScope(0.1): # ASYNC100: 9, "trio", "CancelScope" ... + # conditional cases trigger ASYNC912 with trio.move_on_after(0.1): # ASYNC912: 9 if bar(): await trio.lowlevel.checkpoint() @@ -51,16 +54,23 @@ async def foo(): with open(""): ... + # don't error with guaranteed checkpoint with trio.move_on_after(0.1): await trio.lowlevel.checkpoint() + with trio.move_on_after(0.1): + if bar(): + await trio.lowlevel.checkpoint() + else: + await trio.lowlevel.checkpoint() + # both scopes error in nested cases with trio.move_on_after(0.1): # ASYNC912: 9 with trio.move_on_after(0.1): # ASYNC912: 13 if bar(): await trio.lowlevel.checkpoint() # We don't know which cancelscope will trigger first, so to avoid false - # positives on tricky-but-valid cases we don't raise any error for the outer one. + # alarms on tricky-but-valid cases we don't raise any error for the outer one. with trio.move_on_after(0.1): with trio.move_on_after(0.1): await trio.lowlevel.checkpoint() @@ -75,6 +85,7 @@ async def foo(): await trio.lowlevel.checkpoint() await trio.lowlevel.checkpoint() + # check correct line gives error # fmt: off with ( # a @@ -94,6 +105,7 @@ async def foo(): await trio.lowlevel.checkpoint() # fmt: on + # error on each call with multiple matching calls in the same with with ( trio.move_on_after(0.1), # ASYNC912: 8 trio.fail_at(5), # ASYNC912: 8 @@ -101,6 +113,21 @@ async def foo(): if bar(): await trio.lowlevel.checkpoint() + # wrapped calls do not raise errors + T = TypeVar("T") + + def customWrapper(a: T) -> T: + return a + + with customWrapper(trio.fail_at(10)): + ... + with (res := trio.fail_at(10)): + ... + # but saving with `as` does + with trio.fail_at(10) as res: # ASYNC912: 9 + if bar(): + await trio.lowlevel.checkpoint() + # TODO: issue #240 async def livelocks(): diff --git a/tests/eval_files/async912_asyncio.py b/tests/eval_files/async912_asyncio.py index 5a4227e..ef9200b 100644 --- a/tests/eval_files/async912_asyncio.py +++ b/tests/eval_files/async912_asyncio.py @@ -11,11 +11,16 @@ import asyncio +from typing import Any + def bar() -> bool: return False +def customWrapper(a: object) -> object: ... + + async def foo(): # async100 async with asyncio.timeout(10): # ASYNC100: 15, "asyncio", "timeout" @@ -51,3 +56,21 @@ async def foo(): async with asyncio.timeouts.timeout_at(10): # ASYNC912: 15 if bar(): await foo() + + # double check that helper methods used by visitor don't trigger erroneously + timeouts: Any + timeout_at: Any + async with asyncio.timeout_at.timeouts(10): + ... + async with timeouts.asyncio.timeout_at(10): + ... + async with timeouts.timeout_at.asyncio(10): + ... + async with timeout_at.asyncio.timeouts(10): + ... + async with timeout_at.timeouts.asyncio(10): + ... + async with foo.timeout(10): + ... + async with asyncio.timeouts(10): + ...