diff --git a/src/pymend/file_parser.py b/src/pymend/file_parser.py index c62997f..3dc4a39 100644 --- a/src/pymend/file_parser.py +++ b/src/pymend/file_parser.py @@ -74,7 +74,7 @@ def __init__( start_node : Union[ast.FunctionDef, ast.AsyncFunctionDef] Node to start traversal from. """ - self.name = start_node.name + self.start_node = start_node self.returns: set[tuple[str, ...]] = set() self.returns_value = False self.yields: set[tuple[str, ...]] = set() @@ -108,10 +108,9 @@ def _visit_FunctionDef(self, node: ast.FunctionDef) -> None: # noqa: N802 # py node : ast.FunctionDef Current node in the traversal. """ - nested_function = self._inside_nested_function - self._inside_nested_function += int(nested_function) + self._inside_nested_function += 0 if node is self.start_node else 1 self._generic_visit(node) - self._inside_nested_function -= int(nested_function) + self._inside_nested_function -= 0 if node is self.start_node else 1 def _visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None: # noqa: N802 # pylint: disable=invalid-name """Keep track of nested function depth. @@ -121,10 +120,9 @@ def _visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None: # noqa: node : ast.AsyncFunctionDef Current node in the traversal. """ - nested_function = self._inside_nested_function - self._inside_nested_function += int(nested_function) + self._inside_nested_function += 0 if node is self.start_node else 1 self._generic_visit(node) - self._inside_nested_function -= int(nested_function) + self._inside_nested_function -= 0 if node is self.start_node else 1 def _visit_Return(self, node: ast.Return) -> None: # noqa: N802 # pylint: disable=invalid-name """Do not process returns from nested functions. diff --git a/tests/test_pymend/refs/returns.py b/tests/test_pymend/refs/returns.py index 9718b1b..7c384b8 100644 --- a/tests/test_pymend/refs/returns.py +++ b/tests/test_pymend/refs/returns.py @@ -76,3 +76,11 @@ def my_multi_return_func() -> Tuple[int, str, bool]: Some bool """ pass + +def nested_function(): + """_summary_.""" + def nested_function1(): + """_summary_.""" + def nested_function2(): + """_summary_.""" + return 3 diff --git a/tests/test_pymend/refs/returns.py.patch.numpydoc.expected b/tests/test_pymend/refs/returns.py.patch.numpydoc.expected index 1891963..3942860 100644 --- a/tests/test_pymend/refs/returns.py.patch.numpydoc.expected +++ b/tests/test_pymend/refs/returns.py.patch.numpydoc.expected @@ -1,7 +1,7 @@ -# Patch generated by Pymend v0.4.0dev +# Patch generated by Pymend v1.1.0 ---- a/returns.py -+++ b/returns.py +--- a/tests\test_pymend\refs\returns.py ++++ b/tests\test_pymend\refs\returns.py @@ -2,11 +2,16 @@ def my_func(param0, param01: int, param1: str = "Some value", param2: List[str] = {}): """_summary_. @@ -91,3 +91,16 @@ Some integer y : str Some string +@@ -82,5 +103,11 @@ + def nested_function1(): + """_summary_.""" + def nested_function2(): +- """_summary_.""" ++ """_summary_. ++ ++ Returns ++ ------- ++ _type_ ++ _description_ ++ """ + return 3 diff --git a/tests/test_pymend/test_numpyoutput.py b/tests/test_pymend/test_numpyoutput.py index 16775dc..9f62bad 100644 --- a/tests/test_pymend/test_numpyoutput.py +++ b/tests/test_pymend/test_numpyoutput.py @@ -77,67 +77,77 @@ def check_expected_diff(test_name: str) -> None: assert remove_diff_header(result) == remove_diff_header(expected) -class TestNumpyOutput: - """Integration tests for numpy style output.""" - - def test_positional_only_identifier(self) -> None: - """Make sure that '/' is parsed correctly in signatures.""" - check_expected_diff("positional_only") - - def test_keyword_only_identifier(self) -> None: - """Make sure that '*' is parsed correctly in signatures.""" - check_expected_diff("keyword_only") - - def test_returns(self) -> None: - """Make sure single and multi return values are parsed/produced correctly.""" - check_expected_diff("returns") - - def test_star_args(self) -> None: - """Make sure that *args are treated correctly.""" - check_expected_diff("star_args") - - def test_starstar_kwargs(self) -> None: - """Make sure that **kwargs are treated correctly.""" - check_expected_diff("star_star_kwargs") - - def test_module_doc_dot(self) -> None: - """Make sure missing '.' are added to the first line of module docstring.""" - check_expected_diff("module_dot_missing") - - def test_ast_ref(self) -> None: - """Bunch of different stuff.""" - check_expected_diff("ast_ref") - - def test_yields(self) -> None: - """Make sure yields are handled correctly from body.""" - check_expected_diff("yields") - - def test_raises(self) -> None: - """Make sure raises are handled correctly from body.""" - check_expected_diff("raises") - - def test_skip_overload(self) -> None: - """Function annotated with @overload should be skipped for DS creation.""" - check_expected_diff("skip_overload_decorator") - - def test_class_body(self) -> None: - """Correctly parse and compose class from body information.""" - check_expected_diff("class_body") - - def test_quote_default(self) -> None: - """Test that default values of triple quotes do not cause issues.""" - check_expected_diff("quote_default") - - def test_blank_lines(self) -> None: - """Test that blank lines are set correctly.""" - expected = get_expected_patch("blank_lines.py.patch.numpydoc.expected") - comment = pym.PyComment( - absdir("refs/blank_lines.py"), - fixer_settings=FixerSettings(force_params=False), - ) - result = "".join(comment._docstring_diff()) - assert remove_diff_header(result) == remove_diff_header(expected) - - def test_comments_after_docstring(self) -> None: - """Test that comments after the last line are not removed.""" - check_expected_diff("comments_after_docstring") +def test_positional_only_identifier() -> None: + """Make sure that '/' is parsed correctly in signatures.""" + check_expected_diff("positional_only") + + +def test_keyword_only_identifier() -> None: + """Make sure that '*' is parsed correctly in signatures.""" + check_expected_diff("keyword_only") + + +def test_returns() -> None: + """Make sure single and multi return values are parsed/produced correctly.""" + check_expected_diff("returns") + + +def test_star_args() -> None: + """Make sure that *args are treated correctly.""" + check_expected_diff("star_args") + + +def test_starstar_kwargs() -> None: + """Make sure that **kwargs are treated correctly.""" + check_expected_diff("star_star_kwargs") + + +def test_module_doc_dot() -> None: + """Make sure missing '.' are added to the first line of module docstring.""" + check_expected_diff("module_dot_missing") + + +def test_ast_ref() -> None: + """Bunch of different stuff.""" + check_expected_diff("ast_ref") + + +def test_yields() -> None: + """Make sure yields are handled correctly from body.""" + check_expected_diff("yields") + + +def test_raises() -> None: + """Make sure raises are handled correctly from body.""" + check_expected_diff("raises") + + +def test_skip_overload() -> None: + """Function annotated with @overload should be skipped for DS creation.""" + check_expected_diff("skip_overload_decorator") + + +def test_class_body() -> None: + """Correctly parse and compose class from body information.""" + check_expected_diff("class_body") + + +def test_quote_default() -> None: + """Test that default values of triple quotes do not cause issues.""" + check_expected_diff("quote_default") + + +def test_blank_lines() -> None: + """Test that blank lines are set correctly.""" + expected = get_expected_patch("blank_lines.py.patch.numpydoc.expected") + comment = pym.PyComment( + absdir("refs/blank_lines.py"), + fixer_settings=FixerSettings(force_params=False), + ) + result = "".join(comment._docstring_diff()) + assert remove_diff_header(result) == remove_diff_header(expected) + + +def test_comments_after_docstring() -> None: + """Test that comments after the last line are not removed.""" + check_expected_diff("comments_after_docstring")