diff --git a/flake8_trio.py b/flake8_trio.py index 4b94b70..d51655a 100644 --- a/flake8_trio.py +++ b/flake8_trio.py @@ -11,12 +11,11 @@ from __future__ import annotations -import argparse import ast import keyword import tokenize -from argparse import Namespace -from collections.abc import Iterable, Sequence +from argparse import ArgumentTypeError, Namespace +from collections.abc import Iterable from fnmatch import fnmatch from typing import Any, NamedTuple, Union, cast @@ -1427,46 +1426,33 @@ def visit_Call(self, node: ast.Call): self.error(node, key, blocking_calls[key]) -class ListOfIdentifiers(argparse.Action): - def __call__( - self, - parser: argparse.ArgumentParser, - namespace: argparse.Namespace, - values: Sequence[str] | None, - option_string: str | None = None, - ): - assert values is not None - assert option_string is not None - for value in values: - if keyword.iskeyword(value) or not value.isidentifier(): - raise argparse.ArgumentError( - self, f"{value!r} is not a valid method identifier" - ) - setattr(namespace, self.dest, values) - - -class ParseDict(argparse.Action): - def __call__( - self, - parser: argparse.ArgumentParser, - namespace: argparse.Namespace, - values: Sequence[str] | None, - option_string: str | None = None, - ): - res: dict[str, str] = {} - splitter = "->" # avoid ":" because it's part of .ini file syntax - assert values is not None - for value in values: - split_values = list(map(str.strip, value.split(splitter))) - if len(split_values) != 2: - raise argparse.ArgumentError( - self, - f"Invalid number ({len(split_values)-1}) of splitter " - f"tokens {splitter!r} in {value!r}", - ) - res[split_values[0]] = split_values[1] - - setattr(namespace, self.dest, res) +# flake8 ignores type parameters if using comma_separated_list +# so we need to reimplement that ourselves if we want to use "type" +# to check values +def parse_trio114_identifiers(raw_value: str) -> list[str]: + values = [s.strip() for s in raw_value.split(",") if s.strip()] + for value in values: + if keyword.iskeyword(value) or not value.isidentifier(): + raise ArgumentTypeError(f"{value!r} is not a valid method identifier") + return values + + +def parse_trio200_dict(raw_value: str) -> dict[str, str]: + res: dict[str, str] = {} + splitter = "->" # avoid ":" because it's part of .ini file syntax + values = [s.strip() for s in raw_value.split(",") if s.strip()] + + for value in values: + split_values = list(map(str.strip, value.split(splitter))) + if len(split_values) != 2: + # argparse will eat this error message and spit out it's own + # if we raise it as ValueError + raise ArgumentTypeError( + f"Invalid number ({len(split_values)-1}) of splitter " + f"tokens {splitter!r} in {value!r}", + ) + res[split_values[0]] = split_values[1] + return res class Plugin: @@ -1484,15 +1470,6 @@ def from_filename(cls, filename: str) -> Plugin: return cls(ast.parse(source)) def run(self) -> Iterable[Error]: - # temporary workaround, since the Action does not seem to be called properly - # by flake8 when parsing from config - if isinstance(self.options.trio200_blocking_calls, list): - ParseDict([""], dest="trio200_blocking_calls")( - None, # type: ignore - self.options, - self.options.trio200_blocking_calls, # type: ignore - None, - ) yield from Flake8TrioRunner.run(self._tree, self.options) @staticmethod @@ -1513,11 +1490,10 @@ def add_options(option_manager: OptionManager): ) option_manager.add_option( "--startable-in-context-manager", + type=parse_trio114_identifiers, default="", parse_from_config=True, required=False, - comma_separated_list=True, - action=ListOfIdentifiers, help=( "Comma-separated list of method calls to additionally enable TRIO113 " "warnings for. Will also check for the pattern inside function calls. " @@ -1529,11 +1505,10 @@ def add_options(option_manager: OptionManager): ) option_manager.add_option( "--trio200-blocking-calls", + type=parse_trio200_dict, default={}, parse_from_config=True, required=False, - comma_separated_list=True, - action=ParseDict, help=( "Comma-separated list of key:value pairs, where key is a [dotted] " "function that if found inside an async function will raise TRIO200, " diff --git a/tests/test_flake8_trio.py b/tests/test_flake8_trio.py index c2c99d7..33cd06b 100644 --- a/tests/test_flake8_trio.py +++ b/tests/test_flake8_trio.py @@ -80,6 +80,9 @@ def test_eval(test: str, path: str): lines = file.readlines() for lineno, line in enumerate(lines, start=1): + # interpret '\n' in comments as actual newlines + line = line.replace("\\n", "\n") + line = line.strip() # add other error codes to check if #INCLUDE is specified @@ -90,9 +93,7 @@ def test_eval(test: str, path: str): # add command-line args if specified with #ARGS elif reg_match := re.search(r"(?<=ARGS).*", line): - for arg in reg_match.group().split(" "): - if arg.strip(): - parsed_args.append(arg.strip()) + parsed_args.append(reg_match.group().strip()) # skip commented out lines if not line or line[0] == "#": @@ -444,16 +445,17 @@ def test_200_options(capsys: pytest.CaptureFixture[str]): om.parse_args(args=[f"--trio200-blocking-calls={arg}"]) ) out, err = capsys.readouterr() - assert not out + assert "" == out assert all(word in err for word in (str(i), arg, "->")) -def test_from_config_file(tmp_path: Path): +def _test_trio200_from_config_common(tmp_path: Path) -> str: tmp_path.joinpath(".flake8").write_text( """ [flake8] trio200-blocking-calls = - sync_fns.*->the_async_equivalent, + other -> async, + sync_fns.* -> the_async_equivalent, select = TRIO200 """ ) @@ -465,12 +467,42 @@ async def foo(): sync_fns.takes_a_long_time() """ ) + return ( + "./example.py:5:5: TRIO200: User-configured blocking sync call sync_fns.* " + "in async function, consider replacing with the_async_equivalent.\n" + ) + + +def test_200_from_config_flake8_internals( + tmp_path: Path, capsys: pytest.CaptureFixture[str] +): + # abuse flake8 internals to avoid having to use subprocess + # which breaks breakpoints and hinders debugging + # TODO: fixture (?) to change working directory + + err_msg = _test_trio200_from_config_common(tmp_path) + # replace ./ with tmp_path/ + err_msg = str(tmp_path) + err_msg[1:] + + from flake8.main.cli import main + + main( + argv=[ + str(tmp_path / "example.py"), + "--append-config", + str(tmp_path / ".flake8"), + ] + ) + out, err = capsys.readouterr() + assert not err + assert err_msg == out + + +def test_200_from_config_subprocess(tmp_path: Path): + err_msg = _test_trio200_from_config_common(tmp_path) res = subprocess.run(["flake8"], cwd=tmp_path, capture_output=True) assert not res.stderr - assert res.stdout == ( - b"./example.py:5:5: TRIO200: User-configured blocking sync call sync_fns.* " - b"in async function, consider replacing with the_async_equivalent.\n" - ) + assert res.stdout == bytes(err_msg, "ascii") @pytest.mark.fuzz diff --git a/tests/trio200.py b/tests/trio200.py index 35cb48b..1462662 100644 --- a/tests/trio200.py +++ b/tests/trio200.py @@ -1,5 +1,7 @@ # specify command-line arguments to be used when testing this file. -# ARGS --trio200-blocking-calls=bar->BAR,bee->SHOULD_NOT_BE_PRINTED,bonnet->SHOULD_NOT_BE_PRINTED,bee.bonnet->BEEBONNET,*.postwild->POSTWILD,prewild.*->PREWILD,*.*.*->TRIPLEDOT +# Test spaces in options, and trailing comma +# Cannot test newlines, since argparse splits on those if passed on the CLI +# ARGS --trio200-blocking-calls=bar -> BAR, bee-> SHOULD_NOT_BE_PRINTED,bonnet ->SHOULD_NOT_BE_PRINTED,bee.bonnet->BEEBONNET,*.postwild->POSTWILD,prewild.*->PREWILD,*.*.*->TRIPLEDOT, # don't error in sync function def foo(): diff --git a/typings/flake8/main/cli.pyi b/typings/flake8/main/cli.pyi new file mode 100644 index 0000000..f454868 --- /dev/null +++ b/typings/flake8/main/cli.pyi @@ -0,0 +1,18 @@ +""" +This type stub file was generated by pyright. +""" + +from collections.abc import Sequence + +"""Command-line implementation of flake8.""" + +def main(argv: Sequence[str] | None = ...) -> int: + """Execute the main bit of the application. + + This handles the creation of an instance of :class:`Application`, runs it, + and then exits the application. + + :param argv: + The arguments to be passed to the application for parsing. + """ + ...