Skip to content

Commit

Permalink
use matcher for matching instead of custom logic
Browse files Browse the repository at this point in the history
  • Loading branch information
jakkdl committed May 10, 2024
1 parent 4720f85 commit df403c3
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 38 deletions.
87 changes: 50 additions & 37 deletions flake8_async/visitors/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,55 +323,68 @@ class AttributeCall(NamedTuple):
function: str


# the custom __or__ in libcst breaks pyright type checking. It's possible to use
# `Union` as a workaround ... except pyupgrade will automatically replace that.
# So we have to resort to specifying one of the base classes.
# See https://github.com/Instagram/LibCST/issues/1143
def build_cst_matcher(attr: str) -> m.BaseExpression:
"""Build a cst matcher structure with attributes&names matching a string `a.b.c`."""
if "." not in attr:
return m.Name(value=attr)
body, tail = attr.rsplit(".")
return m.Attribute(value=build_cst_matcher(body), attr=m.Name(value=tail))


def identifier_to_string(attr: cst.Name | cst.Attribute) -> str:
if isinstance(attr, cst.Name):
return attr.value
assert isinstance(attr.value, (cst.Attribute, cst.Name))
return identifier_to_string(attr.value) + "." + attr.attr.value


def with_has_call(
node: cst.With, *names: str, base: Iterable[str] = ("trio", "anyio")
) -> list[AttributeCall]:
"""Check if a with statement has a matching call, returning a list with matches.
`names` specify the names of functions to match, `base` specifies the
library/module(s) the function must be in.
The list elements in the return value are named tuples with the matched node,
base and function.
Examples_
`with_has_call(node, "bar", base="foo")` matches foo.bar.
`with_has_call(node, "bar", "bee", base=("foo", "a.b.c")` matches
`foo.bar`, `foo.bee`, `a.b.c.bar`, and `a.b.c.bee`.
"""
if isinstance(base, str):
base = (base,) # pragma: no cover

for b in base:
if b.count(".") > 1: # pragma: no cover
raise NotImplementedError("Does not support 3-module bases atm.")
# build matcher, using SaveMatchedNode to save the base and the function name.
matcher = m.Call(
func=m.Attribute(
value=m.SaveMatchedNode(
m.OneOf(*(build_cst_matcher(b) for b in base)), name="base"
),
attr=m.SaveMatchedNode(
oneof_names(*names),
name="function",
),
)
)

res_list: list[AttributeCall] = []
for item in node.items:
if res := m.extract(
item.item,
m.Call(
func=m.Attribute(
value=m.SaveMatchedNode(m.Name() | m.Attribute(), name="library"),
attr=m.SaveMatchedNode(
oneof_names(*names),
name="function",
),
)
),
):
if res := m.extract(item.item, matcher):
assert isinstance(item.item, cst.Call)
assert isinstance(res["library"], (cst.Name, cst.Attribute))
assert isinstance(res["base"], (cst.Name, cst.Attribute))
assert isinstance(res["function"], cst.Name)
library_node = res["library"]
for library_str in base:
if (
isinstance(library_node, cst.Name)
and library_str == library_node.value
):
break
if (
isinstance(library_node, cst.Attribute)
and isinstance(library_node.value, cst.Name)
and "." in library_str
):
base_1, base_2 = library_str.split(".")
if (
library_node.attr.value == base_2
and library_node.value.value == base_1
):
break
else:
continue
res_list.append(
AttributeCall(item.item, library_str, res["function"].value)
AttributeCall(
item.item, identifier_to_string(res["base"]), res["function"].value
)
)
return res_list

Expand Down
29 changes: 28 additions & 1 deletion tests/eval_files/async912.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
# of not testing both in the same file, or running with NOAUTOFIX.
# NOAUTOFIX

from typing import TypeVar

import trio


Expand All @@ -27,6 +29,7 @@ async def foo():
with trio.CancelScope(0.1): # ASYNC100: 9, "trio", "CancelScope"
...

# conditional cases trigger ASYNC912
with trio.move_on_after(0.1): # ASYNC912: 9
if bar():
await trio.lowlevel.checkpoint()
Expand All @@ -51,16 +54,23 @@ async def foo():
with open(""):
...

# don't error with guaranteed checkpoint
with trio.move_on_after(0.1):
await trio.lowlevel.checkpoint()
with trio.move_on_after(0.1):
if bar():
await trio.lowlevel.checkpoint()
else:
await trio.lowlevel.checkpoint()

# both scopes error in nested cases
with trio.move_on_after(0.1): # ASYNC912: 9
with trio.move_on_after(0.1): # ASYNC912: 13
if bar():
await trio.lowlevel.checkpoint()

# We don't know which cancelscope will trigger first, so to avoid false
# positives on tricky-but-valid cases we don't raise any error for the outer one.
# alarms on tricky-but-valid cases we don't raise any error for the outer one.
with trio.move_on_after(0.1):
with trio.move_on_after(0.1):
await trio.lowlevel.checkpoint()
Expand All @@ -75,6 +85,7 @@ async def foo():
await trio.lowlevel.checkpoint()
await trio.lowlevel.checkpoint()

# check correct line gives error
# fmt: off
with (
# a
Expand All @@ -94,13 +105,29 @@ async def foo():
await trio.lowlevel.checkpoint()
# fmt: on

# error on each call with multiple matching calls in the same with
with (
trio.move_on_after(0.1), # ASYNC912: 8
trio.fail_at(5), # ASYNC912: 8
):
if bar():
await trio.lowlevel.checkpoint()

# wrapped calls do not raise errors
T = TypeVar("T")

def customWrapper(a: T) -> T:
return a

with customWrapper(trio.fail_at(10)):
...
with (res := trio.fail_at(10)):
...
# but saving with `as` does
with trio.fail_at(10) as res: # ASYNC912: 9
if bar():
await trio.lowlevel.checkpoint()


# TODO: issue #240
async def livelocks():
Expand Down
23 changes: 23 additions & 0 deletions tests/eval_files/async912_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,16 @@

import asyncio

from typing import Any


def bar() -> bool:
return False


def customWrapper(a: object) -> object: ...


async def foo():
# async100
async with asyncio.timeout(10): # ASYNC100: 15, "asyncio", "timeout"
Expand Down Expand Up @@ -51,3 +56,21 @@ async def foo():
async with asyncio.timeouts.timeout_at(10): # ASYNC912: 15
if bar():
await foo()

# double check that helper methods used by visitor don't trigger erroneously
timeouts: Any
timeout_at: Any
async with asyncio.timeout_at.timeouts(10):
...
async with timeouts.asyncio.timeout_at(10):
...
async with timeouts.timeout_at.asyncio(10):
...
async with timeout_at.asyncio.timeouts(10):
...
async with timeout_at.timeouts.asyncio(10):
...
async with foo.timeout(10):
...
async with asyncio.timeouts(10):
...

0 comments on commit df403c3

Please sign in to comment.