Skip to content

Commit

Permalink
- fixed not looking up a method up the class hierarchy;
Browse files Browse the repository at this point in the history
  • Loading branch information
jaltmayerpizzorno committed Aug 28, 2024
1 parent 25c6b2e commit 2030a23
Show file tree
Hide file tree
Showing 2 changed files with 181 additions and 11 deletions.
42 changes: 32 additions & 10 deletions src/coverup/codeinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import importlib.util

_debug = lambda x: None
#_debug = print


# TODO use 'ast' alternative that retains comments?
Expand Down Expand Up @@ -65,9 +64,21 @@ def _load_module(module_name: str) -> ast.Module | None:
return None


def _auto_stack(func):
"""Decorator that adds a stack of the first argument of the function being called."""
def helper(*args):
helper.stack.append(args[0])
_debug(f"{'.'.join(getattr(n, 'name', '?') for n in helper.stack)}")
retval = func(*args)
helper.stack.pop()
return retval
helper.stack = []
return helper


def _find_name_path(module: ast.Module, name: T.List[str], *, paths_seen: T.Set[Path] = None) -> T.List[ast.AST]:
"""Looks for a class or function by name, returning the "path" of ast.ClassDef modules crossed
to find it. If an `import` is found for the sought, it is returned instead.
"""Looks for a symbol's definition by its name, returning the "path" of ast.ClassDef, ast.Import, etc.,
crossed to find it.
"""
_debug(f"looking up {name} in {module.path}")

Expand All @@ -81,25 +92,36 @@ def transition(node: ast.Import | ast.ImportFrom, alias: ast.alias, mod: ast.Mod
imp.names = [alias]
return [imp, mod]

@_auto_stack
def find_name(node: ast.AST, name: T.List[str]) -> T.List[ast.AST]:
_debug(f"_find_name {name} in {ast.dump(node)}")
if isinstance(node, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)):
if node.name == name[0]:
if len(name) == 1:
return [node]

if isinstance(node, ast.ClassDef):
for c in ast.iter_child_nodes(node):
for c in node.body:
_debug(f"{node.name} checking {ast.dump(c)}")
if (path := find_name(c, name[1:])):
return [node, *path]

for base in node.bases:
if (len(find_name.stack) > 1 and
isinstance(context := find_name.stack[-2], ast.ClassDef)):
if (base_path := find_name(context, [context.name, base.id, *name[1:]])):
return base_path[1:]

if (path := find_name(module, [base.id, *name[1:]])):
return path

return []

if (isinstance(node, ast.Assign) and
any(isinstance(n, ast.Name) and n.id == name[0] for t in node.targets for n in ast.walk(t))):
return [node] if len(name) == 1 else []

if isinstance(node, ast.Import):
_debug(f"{ast.dump(node)=}")
# import N
# import N.x imports N and N.x
# import a.b as N 'a.b' is renamed 'N'
Expand All @@ -116,11 +138,10 @@ def find_name(node: ast.AST, name: T.List[str]) -> T.List[ast.AST]:
if path := _find_name_path(mod, name[common_prefix:], paths_seen=paths_seen):
return transition(node, alias, mod) + path

if isinstance(node, ast.ImportFrom):
elif isinstance(node, ast.ImportFrom):
# from a.b import N either gets symbol N out of a.b, or imports a.b.N as N
# from a.b import c as N

_debug(f"{ast.dump(node)=}")
for alias in node.names:
if (alias.asname if alias.asname else alias.name) == name[0]:
modname = _resolve_from_import(module.path, node)
Expand All @@ -134,9 +155,10 @@ def find_name(node: ast.AST, name: T.List[str]) -> T.List[ast.AST]:
(path := _find_name_path(mod, name[1:], paths_seen=paths_seen)):
return transition(node, alias, mod) + path

for c in ast.iter_child_nodes(node):
if (path := find_name(c, name)):
return path
elif not isinstance(node, (ast.Expression, ast.Expr, ast.Name)):
for c in ast.iter_child_nodes(node):
if (path := find_name(c, name)):
return path

return []

Expand Down
150 changes: 149 additions & 1 deletion tests/test_codeinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
import coverup.codeinfo as codeinfo


codeinfo._debug = print # enables debugging


@pytest.fixture
def importlib_cleanup():
import importlib
Expand Down Expand Up @@ -145,7 +148,6 @@ def d(self):

tree = codeinfo.parse_file(code)

print(codeinfo.get_info(tree, 'C'))
assert codeinfo.get_info(tree, 'C') == textwrap.dedent("""\
```python
class C:
Expand Down Expand Up @@ -220,6 +222,113 @@ def foo():
assert codeinfo.get_info(tree, 'foo') == None


def test_get_info_method_from_parent(import_fixture):
tmp_path = import_fixture

code = tmp_path / "foo.py"
code.write_text(textwrap.dedent("""\
class A:
def a(self):
return 42
class B(A):
pass
"""
))

tree = codeinfo.parse_file(code)
assert codeinfo.get_info(tree, 'B.a') == textwrap.dedent("""\
```python
class A:
...
def a(self):
return 42
```"""
)


def test_get_info_method_from_parent_imported(import_fixture):
tmp_path = import_fixture

code = tmp_path / "foo.py"
code.write_text(textwrap.dedent("""\
from bar import A
class B(A):
pass
"""
))

(tmp_path / "bar.py").write_text(textwrap.dedent("""\
class A:
def a(self):
return 42
"""
))

tree = codeinfo.parse_file(code)
assert codeinfo.get_info(tree, 'B.a') == textwrap.dedent("""\
in foo.py:
```python
from bar import A
```
in bar.py:
```python
class A:
...
def a(self):
return 42
```"""
)


def test_get_info_method_from_parent_imported_in_class(import_fixture):
tmp_path = import_fixture

code = tmp_path / "foo.py"
code.write_text(textwrap.dedent("""\
class A:
from bar import B
class C(B):
pass
class D(C):
pass
"""
))

(tmp_path / "bar.py").write_text(textwrap.dedent("""\
class B:
def b(self):
return 42
"""
))

# FIXME this would be better if it showed class C(B) and class D(C)
tree = codeinfo.parse_file(code)
assert codeinfo.get_info(tree, 'A.D.b') == textwrap.dedent("""\
in foo.py:
```python
class A:
...
from bar import B
```
in bar.py:
```python
class B:
...
def b(self):
return 42
```"""
)


def test_get_info_assignment(import_fixture):
tmp_path = import_fixture

Expand Down Expand Up @@ -594,6 +703,45 @@ class Baz:
)


@pytest.mark.xfail
def test_get_info_import_in_function(import_fixture):
tmp_path = import_fixture

code = tmp_path / "code.py"
code.write_text(textwrap.dedent("""\
import os
def something():
from foo import Foo
"""
))

(tmp_path / "foo").mkdir()
(tmp_path / "foo" / "__init__.py").write_text(textwrap.dedent("""\
class Foo:
pass
"""
))

# XXX pass context here
tree = codeinfo.parse_file(code)
assert codeinfo.get_info(tree, 'Foo') == textwrap.dedent('''\
in code.py:
```python
def something():
...
from foo import Foo
...
```
in foo/__init__.py:
```python
class Foo:
pass
```'''
)


def test_get_info_import_in_class(import_fixture):
tmp_path = import_fixture

Expand Down

0 comments on commit 2030a23

Please sign in to comment.