Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle parsing of return,yield,raises in nested functions correctly #152

Merged
merged 1 commit into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions src/pymend/file_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(
start_node : Union[ast.FunctionDef, ast.AsyncFunctionDef]
Node to start traversal from.
"""
self.name = start_node.name
self.start_node = start_node
self.returns: set[tuple[str, ...]] = set()
self.returns_value = False
self.yields: set[tuple[str, ...]] = set()
Expand Down Expand Up @@ -108,10 +108,9 @@ def _visit_FunctionDef(self, node: ast.FunctionDef) -> None: # noqa: N802 # py
node : ast.FunctionDef
Current node in the traversal.
"""
nested_function = self._inside_nested_function
self._inside_nested_function += int(nested_function)
self._inside_nested_function += 0 if node is self.start_node else 1
self._generic_visit(node)
self._inside_nested_function -= int(nested_function)
self._inside_nested_function -= 0 if node is self.start_node else 1

def _visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None: # noqa: N802 # pylint: disable=invalid-name
"""Keep track of nested function depth.
Expand All @@ -121,10 +120,9 @@ def _visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None: # noqa:
node : ast.AsyncFunctionDef
Current node in the traversal.
"""
nested_function = self._inside_nested_function
self._inside_nested_function += int(nested_function)
self._inside_nested_function += 0 if node is self.start_node else 1
self._generic_visit(node)
self._inside_nested_function -= int(nested_function)
self._inside_nested_function -= 0 if node is self.start_node else 1

def _visit_Return(self, node: ast.Return) -> None: # noqa: N802 # pylint: disable=invalid-name
"""Do not process returns from nested functions.
Expand Down
8 changes: 8 additions & 0 deletions tests/test_pymend/refs/returns.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,11 @@ def my_multi_return_func() -> Tuple[int, str, bool]:
Some bool
"""
pass

def nested_function():
"""_summary_."""
def nested_function1():
"""_summary_."""
def nested_function2():
"""_summary_."""
return 3
19 changes: 16 additions & 3 deletions tests/test_pymend/refs/returns.py.patch.numpydoc.expected
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Patch generated by Pymend v0.4.0dev
# Patch generated by Pymend v1.1.0

--- a/returns.py
+++ b/returns.py
--- a/tests\test_pymend\refs\returns.py
+++ b/tests\test_pymend\refs\returns.py
@@ -2,11 +2,16 @@
def my_func(param0, param01: int, param1: str = "Some value", param2: List[str] = {}):
"""_summary_.
Expand Down Expand Up @@ -91,3 +91,16 @@
Some integer
y : str
Some string
@@ -82,5 +103,11 @@
def nested_function1():
"""_summary_."""
def nested_function2():
- """_summary_."""
+ """_summary_.
+
+ Returns
+ -------
+ _type_
+ _description_
+ """
return 3
138 changes: 74 additions & 64 deletions tests/test_pymend/test_numpyoutput.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,67 +77,77 @@ def check_expected_diff(test_name: str) -> None:
assert remove_diff_header(result) == remove_diff_header(expected)


class TestNumpyOutput:
"""Integration tests for numpy style output."""

def test_positional_only_identifier(self) -> None:
"""Make sure that '/' is parsed correctly in signatures."""
check_expected_diff("positional_only")

def test_keyword_only_identifier(self) -> None:
"""Make sure that '*' is parsed correctly in signatures."""
check_expected_diff("keyword_only")

def test_returns(self) -> None:
"""Make sure single and multi return values are parsed/produced correctly."""
check_expected_diff("returns")

def test_star_args(self) -> None:
"""Make sure that *args are treated correctly."""
check_expected_diff("star_args")

def test_starstar_kwargs(self) -> None:
"""Make sure that **kwargs are treated correctly."""
check_expected_diff("star_star_kwargs")

def test_module_doc_dot(self) -> None:
"""Make sure missing '.' are added to the first line of module docstring."""
check_expected_diff("module_dot_missing")

def test_ast_ref(self) -> None:
"""Bunch of different stuff."""
check_expected_diff("ast_ref")

def test_yields(self) -> None:
"""Make sure yields are handled correctly from body."""
check_expected_diff("yields")

def test_raises(self) -> None:
"""Make sure raises are handled correctly from body."""
check_expected_diff("raises")

def test_skip_overload(self) -> None:
"""Function annotated with @overload should be skipped for DS creation."""
check_expected_diff("skip_overload_decorator")

def test_class_body(self) -> None:
"""Correctly parse and compose class from body information."""
check_expected_diff("class_body")

def test_quote_default(self) -> None:
"""Test that default values of triple quotes do not cause issues."""
check_expected_diff("quote_default")

def test_blank_lines(self) -> None:
"""Test that blank lines are set correctly."""
expected = get_expected_patch("blank_lines.py.patch.numpydoc.expected")
comment = pym.PyComment(
absdir("refs/blank_lines.py"),
fixer_settings=FixerSettings(force_params=False),
)
result = "".join(comment._docstring_diff())
assert remove_diff_header(result) == remove_diff_header(expected)

def test_comments_after_docstring(self) -> None:
"""Test that comments after the last line are not removed."""
check_expected_diff("comments_after_docstring")
def test_positional_only_identifier() -> None:
"""Make sure that '/' is parsed correctly in signatures."""
check_expected_diff("positional_only")


def test_keyword_only_identifier() -> None:
"""Make sure that '*' is parsed correctly in signatures."""
check_expected_diff("keyword_only")


def test_returns() -> None:
"""Make sure single and multi return values are parsed/produced correctly."""
check_expected_diff("returns")


def test_star_args() -> None:
"""Make sure that *args are treated correctly."""
check_expected_diff("star_args")


def test_starstar_kwargs() -> None:
"""Make sure that **kwargs are treated correctly."""
check_expected_diff("star_star_kwargs")


def test_module_doc_dot() -> None:
"""Make sure missing '.' are added to the first line of module docstring."""
check_expected_diff("module_dot_missing")


def test_ast_ref() -> None:
"""Bunch of different stuff."""
check_expected_diff("ast_ref")


def test_yields() -> None:
"""Make sure yields are handled correctly from body."""
check_expected_diff("yields")


def test_raises() -> None:
"""Make sure raises are handled correctly from body."""
check_expected_diff("raises")


def test_skip_overload() -> None:
"""Function annotated with @overload should be skipped for DS creation."""
check_expected_diff("skip_overload_decorator")


def test_class_body() -> None:
"""Correctly parse and compose class from body information."""
check_expected_diff("class_body")


def test_quote_default() -> None:
"""Test that default values of triple quotes do not cause issues."""
check_expected_diff("quote_default")


def test_blank_lines() -> None:
"""Test that blank lines are set correctly."""
expected = get_expected_patch("blank_lines.py.patch.numpydoc.expected")
comment = pym.PyComment(
absdir("refs/blank_lines.py"),
fixer_settings=FixerSettings(force_params=False),
)
result = "".join(comment._docstring_diff())
assert remove_diff_header(result) == remove_diff_header(expected)


def test_comments_after_docstring() -> None:
"""Test that comments after the last line are not removed."""
check_expected_diff("comments_after_docstring")
Loading