diff --git a/src/coverup/codeinfo.py b/src/coverup/codeinfo.py index 18b90ec..fb64f4a 100644 --- a/src/coverup/codeinfo.py +++ b/src/coverup/codeinfo.py @@ -55,7 +55,8 @@ def _resolve_from_import(file: Path, imp: ast.ImportFrom) -> str: def _load_module(module_name: str) -> ast.Module | None: try: - if (spec := importlib.util.find_spec(module_name)) and spec.origin and spec.origin != 'frozen': + if ((spec := importlib.util.find_spec(module_name)) + and spec.origin and spec.origin not in ('frozen', 'built-in')): return parse_file(Path(spec.origin)) except ModuleNotFoundError: @@ -124,7 +125,7 @@ def _find_name_path(module: ast.Module, name: T.List[str], *, paths_seen: T.Set[ crossed to find it. """ if not module: return None - if not name: return None # TODO return module? + if not name: return [module] _debug(f"looking up {name} in {module.path}") @@ -198,6 +199,12 @@ def _summarize(path: T.List[ast.AST]) -> ast.AST: # Leave "__init__" unmodified as it's likely to contain important member information c.body = [ast.Expr(ast.Constant(value=ast.literal_eval("...")))] + elif isinstance(path[-1], ast.Module): + path[-1] = copy.deepcopy(path[-1]) + for c in ast.iter_child_nodes(path[-1]): + if isinstance(c, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)): + c.body = [ast.Expr(ast.Constant(value=ast.literal_eval("...")))] + # now Class objects for i in reversed(range(len(path)-1)): if isinstance(path[i], ast.ClassDef): @@ -334,10 +341,12 @@ def any_import_as_or_import_in_class() -> bool: if any(isinstance(n, ast.Module) for n in path): result = "" - for i in range(len(path)): - if isinstance(path[i], ast.Module): - mod, content = path[i:i+2] - imports = get_global_imports(mod, content) + for i, mod in enumerate(path): + if isinstance(mod, ast.Module): + content = path[i+1] if i < len(path)-1 else mod + # When a module itself is the content, all imports are retained, + # so there's no need to look for them. + imports = get_global_imports(mod, content) if mod != content else [] if result: result += "\n\n" result += f"""\ in {_package_path(mod.path)}: diff --git a/tests/test_codeinfo.py b/tests/test_codeinfo.py index 349e522..44a2e17 100644 --- a/tests/test_codeinfo.py +++ b/tests/test_codeinfo.py @@ -1059,22 +1059,62 @@ def test_get_info_module(import_fixture): """ )) (tmp_path / "bar.py").write_text(textwrap.dedent("""\ + try: + import os + except ImportError: + import os2 + + class A: + def __init__(self): + pass + + class B: + pass + + def foo(a, b): + return a+b + answer = 42 + + if __name__ == '__main__': + import sys + sys.exit(0) """ )) tree = codeinfo.parse_file(code) - assert codeinfo.get_info(tree, 'bar') == None + assert codeinfo.get_info(tree, 'bar') == textwrap.dedent("""\ + in bar.py: + ```python + try: + import os + except ImportError: + import os2 + + class A: + ... + + def foo(a, b): + ... + answer = 42 + if __name__ == '__main__': + import sys + sys.exit(0) + ```""" + ) -def test_get_info_frozen_module(import_fixture): +def test_get_info_frozen_or_builtin_module(import_fixture): tmp_path = import_fixture code = tmp_path / "foo.py" code.write_text(textwrap.dedent("""\ import os + import sys """ )) tree = codeinfo.parse_file(code) assert codeinfo.get_info(tree, 'os.path.join') == None + + assert codeinfo.get_info(tree, 'sys.path') == None