From df403c3a98943be5bcfb1220557ec690e8a478d7 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Fri, 10 May 2024 14:00:33 +0200
Subject: [PATCH] use matcher for matching instead of custom logic
---
flake8_async/visitors/helpers.py | 87 ++++++++++++++++------------
tests/eval_files/async912.py | 29 +++++++++-
tests/eval_files/async912_asyncio.py | 23 ++++++++
3 files changed, 101 insertions(+), 38 deletions(-)
diff --git a/flake8_async/visitors/helpers.py b/flake8_async/visitors/helpers.py
index 7b3147e..d33e099 100644
--- a/flake8_async/visitors/helpers.py
+++ b/flake8_async/visitors/helpers.py
@@ -323,55 +323,68 @@ class AttributeCall(NamedTuple):
function: str
+# the custom __or__ in libcst breaks pyright type checking. It's possible to use
+# `Union` as a workaround ... except pyupgrade will automatically replace that.
+# So we have to resort to specifying one of the base classes.
+# See https://github.com/Instagram/LibCST/issues/1143
+def build_cst_matcher(attr: str) -> m.BaseExpression:
+ """Build a cst matcher structure with attributes&names matching a string `a.b.c`."""
+ if "." not in attr:
+ return m.Name(value=attr)
+ body, tail = attr.rsplit(".")
+ return m.Attribute(value=build_cst_matcher(body), attr=m.Name(value=tail))
+
+
+def identifier_to_string(attr: cst.Name | cst.Attribute) -> str:
+ if isinstance(attr, cst.Name):
+ return attr.value
+ assert isinstance(attr.value, (cst.Attribute, cst.Name))
+ return identifier_to_string(attr.value) + "." + attr.attr.value
+
+
def with_has_call(
node: cst.With, *names: str, base: Iterable[str] = ("trio", "anyio")
) -> list[AttributeCall]:
+ """Check if a with statement has a matching call, returning a list with matches.
+
+ `names` specify the names of functions to match, `base` specifies the
+ library/module(s) the function must be in.
+ The list elements in the return value are named tuples with the matched node,
+ base and function.
+
+ Examples_
+
+ `with_has_call(node, "bar", base="foo")` matches foo.bar.
+ `with_has_call(node, "bar", "bee", base=("foo", "a.b.c")` matches
+ `foo.bar`, `foo.bee`, `a.b.c.bar`, and `a.b.c.bee`.
+
+ """
if isinstance(base, str):
base = (base,) # pragma: no cover
- for b in base:
- if b.count(".") > 1: # pragma: no cover
- raise NotImplementedError("Does not support 3-module bases atm.")
+ # build matcher, using SaveMatchedNode to save the base and the function name.
+ matcher = m.Call(
+ func=m.Attribute(
+ value=m.SaveMatchedNode(
+ m.OneOf(*(build_cst_matcher(b) for b in base)), name="base"
+ ),
+ attr=m.SaveMatchedNode(
+ oneof_names(*names),
+ name="function",
+ ),
+ )
+ )
res_list: list[AttributeCall] = []
for item in node.items:
- if res := m.extract(
- item.item,
- m.Call(
- func=m.Attribute(
- value=m.SaveMatchedNode(m.Name() | m.Attribute(), name="library"),
- attr=m.SaveMatchedNode(
- oneof_names(*names),
- name="function",
- ),
- )
- ),
- ):
+ if res := m.extract(item.item, matcher):
assert isinstance(item.item, cst.Call)
- assert isinstance(res["library"], (cst.Name, cst.Attribute))
+ assert isinstance(res["base"], (cst.Name, cst.Attribute))
assert isinstance(res["function"], cst.Name)
- library_node = res["library"]
- for library_str in base:
- if (
- isinstance(library_node, cst.Name)
- and library_str == library_node.value
- ):
- break
- if (
- isinstance(library_node, cst.Attribute)
- and isinstance(library_node.value, cst.Name)
- and "." in library_str
- ):
- base_1, base_2 = library_str.split(".")
- if (
- library_node.attr.value == base_2
- and library_node.value.value == base_1
- ):
- break
- else:
- continue
res_list.append(
- AttributeCall(item.item, library_str, res["function"].value)
+ AttributeCall(
+ item.item, identifier_to_string(res["base"]), res["function"].value
+ )
)
return res_list
diff --git a/tests/eval_files/async912.py b/tests/eval_files/async912.py
index 3ebbc67..c2abf04 100644
--- a/tests/eval_files/async912.py
+++ b/tests/eval_files/async912.py
@@ -7,6 +7,8 @@
# of not testing both in the same file, or running with NOAUTOFIX.
# NOAUTOFIX
+from typing import TypeVar
+
import trio
@@ -27,6 +29,7 @@ async def foo():
with trio.CancelScope(0.1): # ASYNC100: 9, "trio", "CancelScope"
...
+ # conditional cases trigger ASYNC912
with trio.move_on_after(0.1): # ASYNC912: 9
if bar():
await trio.lowlevel.checkpoint()
@@ -51,16 +54,23 @@ async def foo():
with open(""):
...
+ # don't error with guaranteed checkpoint
with trio.move_on_after(0.1):
await trio.lowlevel.checkpoint()
+ with trio.move_on_after(0.1):
+ if bar():
+ await trio.lowlevel.checkpoint()
+ else:
+ await trio.lowlevel.checkpoint()
+ # both scopes error in nested cases
with trio.move_on_after(0.1): # ASYNC912: 9
with trio.move_on_after(0.1): # ASYNC912: 13
if bar():
await trio.lowlevel.checkpoint()
# We don't know which cancelscope will trigger first, so to avoid false
- # positives on tricky-but-valid cases we don't raise any error for the outer one.
+ # alarms on tricky-but-valid cases we don't raise any error for the outer one.
with trio.move_on_after(0.1):
with trio.move_on_after(0.1):
await trio.lowlevel.checkpoint()
@@ -75,6 +85,7 @@ async def foo():
await trio.lowlevel.checkpoint()
await trio.lowlevel.checkpoint()
+ # check correct line gives error
# fmt: off
with (
# a
@@ -94,6 +105,7 @@ async def foo():
await trio.lowlevel.checkpoint()
# fmt: on
+ # error on each call with multiple matching calls in the same with
with (
trio.move_on_after(0.1), # ASYNC912: 8
trio.fail_at(5), # ASYNC912: 8
@@ -101,6 +113,21 @@ async def foo():
if bar():
await trio.lowlevel.checkpoint()
+ # wrapped calls do not raise errors
+ T = TypeVar("T")
+
+ def customWrapper(a: T) -> T:
+ return a
+
+ with customWrapper(trio.fail_at(10)):
+ ...
+ with (res := trio.fail_at(10)):
+ ...
+ # but saving with `as` does
+ with trio.fail_at(10) as res: # ASYNC912: 9
+ if bar():
+ await trio.lowlevel.checkpoint()
+
# TODO: issue #240
async def livelocks():
diff --git a/tests/eval_files/async912_asyncio.py b/tests/eval_files/async912_asyncio.py
index 5a4227e..ef9200b 100644
--- a/tests/eval_files/async912_asyncio.py
+++ b/tests/eval_files/async912_asyncio.py
@@ -11,11 +11,16 @@
import asyncio
+from typing import Any
+
def bar() -> bool:
return False
+def customWrapper(a: object) -> object: ...
+
+
async def foo():
# async100
async with asyncio.timeout(10): # ASYNC100: 15, "asyncio", "timeout"
@@ -51,3 +56,21 @@ async def foo():
async with asyncio.timeouts.timeout_at(10): # ASYNC912: 15
if bar():
await foo()
+
+ # double check that helper methods used by visitor don't trigger erroneously
+ timeouts: Any
+ timeout_at: Any
+ async with asyncio.timeout_at.timeouts(10):
+ ...
+ async with timeouts.asyncio.timeout_at(10):
+ ...
+ async with timeouts.timeout_at.asyncio(10):
+ ...
+ async with timeout_at.asyncio.timeouts(10):
+ ...
+ async with timeout_at.timeouts.asyncio(10):
+ ...
+ async with foo.timeout(10):
+ ...
+ async with asyncio.timeouts(10):
+ ...