diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index 93148a1..8cf662a 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -10,6 +10,14 @@ this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm ## [Unreleased] +### Fixed + +- [`pass` statements are removed from `orelse` parent nodes causing syntax errors by @hadialqattan](https://github.com/hadialqattan/pycln/pull/100) + +### Changed + +- [In case of `(async)func`/`class` contains docstring, keep only one `pass` statement instead of none by @hadialqattan](https://github.com/hadialqattan/pycln/pull/100) + ## [1.2.1] - 2022-02-24 ### Fixed diff --git a/pycln/utils/refactor.py b/pycln/utils/refactor.py index ed46740..de7bd0c 100644 --- a/pycln/utils/refactor.py +++ b/pycln/utils/refactor.py @@ -4,7 +4,7 @@ from functools import lru_cache from importlib import import_module from pathlib import Path -from typing import List, Optional, Set, Tuple, Union, cast +from typing import Iterable, List, Optional, Set, Tuple, Union, cast from . import iou, pathu, regexu, scan from ._exceptions import ( @@ -84,16 +84,48 @@ def remove_useless_passes(source_lines: List[str]) -> List[str]: :param source_lines: source code lines. :returns: clean source code lines. """ + + def remove_from_children( + parent: ast.AST, children: Iterable, body_len: int, wl: Set[ast.AST] + ): + #: Remove any `ast.Pass` node + #: that is both useless and not in the `wl` (white list). + #: + #: The below case is not going to be touched: + #: + #: >>> (async) (def) (class) foo: + #: >>> """DOCString""" + #: >>> pass + #: + for child in children: + if isinstance(child, ast.Pass): + if child not in wl: + if isinstance( + parent, + (ast.AsyncFunctionDef, ast.FunctionDef, ast.ClassDef), + ): + if body_len == 2 and ast.get_docstring(parent): + break + if body_len > 1: + body_len -= 1 + source_lines[child.lineno - 1] = "" + tree = ast.parse("".join(source_lines)) for parent in ast.walk(tree): + body = getattr(parent, "body", None) if body and hasattr(body, "__len__"): body_len = len(body) - for child in ast.iter_child_nodes(parent): - if isinstance(child, ast.Pass): - if body_len > 1: - body_len -= 1 - source_lines[child.lineno - 1] = "" + white_list: Set[ast.AST] = set() + + if hasattr(parent, "orelse"): + orelse = getattr(parent, "orelse") + remove_from_children(parent, orelse, len(orelse), white_list) + white_list = set(orelse) + + children = ast.iter_child_nodes(parent) + remove_from_children(parent, children, body_len, white_list) + return "".join(source_lines).splitlines(True) def session(self, path: Path) -> None: diff --git a/tests/test_refactor.py b/tests/test_refactor.py index 04c5c74..283a64c 100644 --- a/tests/test_refactor.py +++ b/tests/test_refactor.py @@ -113,15 +113,31 @@ def setup_method(self, method): ), pytest.param( [ + "class Foo:\n", + " '''docs'''\n", + " pass\n", + " pass\n", + "async def foo():\n", + " '''docs'''\n", + " pass\n", + " pass\n", "def foo():\n", " '''docs'''\n", " pass\n", + " pass\n", ], [ + "class Foo:\n", + " '''docs'''\n", + " pass\n", + "async def foo():\n", + " '''docs'''\n", + " pass\n", "def foo():\n", " '''docs'''\n", + " pass\n", ], - id="useless with docs", + id="useless with docstring", ), pytest.param( [ @@ -132,6 +148,52 @@ def setup_method(self, method): ], id="TypeError", ), + pytest.param( + [ + "if True:\n", + " pass\n", + " pass\n", + "else:\n", + " pass\n", + " pass\n", + ], + [ + "if True:\n", + " pass\n", + "else:\n", + " pass\n", + ], + id="orelse parent - else - useless", + ), + pytest.param( + [ + "if True:\n", + " print()\n", + " print()\n", + "else:\n", + " pass\n", + ], + [ + "if True:\n", + " print()\n", + " print()\n", + "else:\n", + " pass\n", + ], + id="orelse parent - else - useful", + ), + pytest.param( + [ + "if True:\n", + " print()\n", + " pass\n", + ], + [ + "if True:\n", + " print()\n", + ], + id="orelse parent - no-else", + ), ], ) def test_remove_useless_passes(self, source_lines, expec_lines):