From 82ce25b1f2dcaf0d7ab3e2f1624848028bdfc0d2 Mon Sep 17 00:00:00 2001 From: Sachaa-Thanasius Date: Tue, 17 Sep 2024 13:36:43 +0530 Subject: [PATCH] Get existing tests passing - Re-add code regions - Adjust some comments and docstrings. - Remove need to set loader state in existing tests. - Run pre-commit autoupdate. --- .pre-commit-config.yaml | 2 +- README.rst | 2 - pyproject.toml | 2 +- src/defer_imports/_comptime.py | 125 ++++++++++++++++++++++++--------- src/defer_imports/_runtime.py | 63 +++++++++++------ src/defer_imports/_typing.py | 4 +- tests/test_deferred.py | 48 +++++-------- 7 files changed, 153 insertions(+), 93 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1b8eb81..cf730c1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,7 +8,7 @@ repos: - id: check-yaml - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.4 + rev: v0.6.5 hooks: - id: ruff args: [--fix] diff --git a/README.rst b/README.rst index f2ed759..a424034 100644 --- a/README.rst +++ b/README.rst @@ -31,8 +31,6 @@ This can be installed via pip:: python -m pip install defer-imports -It can also easily be vendored, as it has zero dependencies and has around 1,000 lines of code. - Usage ===== diff --git a/pyproject.toml b/pyproject.toml index f0ffe9b..579a971 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ classifiers = [ dynamic = ["version"] [tool.hatch.version] -path = "src/defer_imports/_compiletime.py" +path = "src/defer_imports/_comptime.py" [project.optional-dependencies] benchmark = ["slothy"] diff --git a/src/defer_imports/_comptime.py b/src/defer_imports/_comptime.py index 049a0e7..eacaf30 100644 --- a/src/defer_imports/_comptime.py +++ b/src/defer_imports/_comptime.py @@ -22,14 +22,9 @@ __version__ = "0.0.2" -StrPath: _tp.TypeAlias = "_tp.Union[str, os.PathLike[str]]" -SourceData: _tp.TypeAlias = "_tp.Union[_tp.ReadableBuffer, str, ast.Module, ast.Expression, ast.Interactive]" - - -TOK_NAME, TOK_OP = tokenize.NAME, tokenize.OP - -BYTECODE_HEADER = f"defer_imports{__version__}".encode() -"""Custom header for defer_imports-instrumented bytecode files. Should be updated with every version release.""" +# ============================================================================ +# region -------- Helper functions -------- +# ============================================================================ def sliding_window(iterable: _tp.Iterable[_tp.T], n: int) -> _tp.Generator[tuple[_tp.T, ...]]: @@ -52,6 +47,24 @@ def sliding_window(iterable: _tp.Iterable[_tp.T], n: int) -> _tp.Generator[tuple yield tuple(window) +# endregion + + +# ============================================================================ +# region -------- Main implementation -------- +# ============================================================================ + + +StrPath: _tp.TypeAlias = "_tp.Union[str, os.PathLike[str]]" +SourceData: _tp.TypeAlias = "_tp.Union[_tp.ReadableBuffer, str, ast.Module, ast.Expression, ast.Interactive]" + + +TOK_NAME, TOK_OP = tokenize.NAME, tokenize.OP + +BYTECODE_HEADER = f"defer_imports{__version__}".encode() +"""Custom header for defer_imports-instrumented bytecode files. Should be updated with every version release.""" + + class DeferredInstrumenter(ast.NodeTransformer): """AST transformer that instruments imports within "with defer_imports.until_use: ..." blocks so that their results are assigned to custom keys in the global namespace. @@ -65,7 +78,7 @@ def __init__( self, data: _tp.Union[_tp.ReadableBuffer, str, ast.AST], filepath: _tp.Union[StrPath, _tp.ReadableBuffer], - encoding: str, + encoding: str = "utf-8", *, defer_state: bool = False, ) -> None: @@ -109,7 +122,7 @@ def _visit_eager_import_block(self, node: ast.AST) -> ast.AST: # endregion - # region ---- until_use-wrapped imports instrumentation ---- + # region ---- Basic instrumentation ---- def _decode_source(self) -> str: """Get the source code corresponding to the given data.""" @@ -195,7 +208,7 @@ def _initialize_local_ns() -> ast.Assign: return ast.Assign( targets=[ast.Name("@local_ns", ctx=ast.Store())], - value=ast.Call(ast.Name("locals", ctx=ast.Load()), args=[], keywords=[]), + value=ast.Call(func=ast.Name("locals", ctx=ast.Load()), args=[], keywords=[]), ) @staticmethod @@ -216,7 +229,7 @@ def _substitute_import_keys(self, import_nodes: list[ast.stmt]) -> list[ast.stmt If any of the given nodes are not an import or are a wildcard import. """ - new_import_nodes: list[ast.stmt] = list(import_nodes) + new_import_nodes = list(import_nodes) for i in range(len(import_nodes) - 1, -1, -1): node = import_nodes[i] @@ -265,7 +278,7 @@ def visit_With(self, node: ast.With) -> ast.AST: """ if not self.check_With_for_defer_usage(node): - self._visit_eager_import_block(node) + return node if self.scope_depth > 0: msg = "with defer_imports.until_use only allowed at module level" @@ -336,7 +349,7 @@ def _is_defer_imports_import(node: _tp.Union[ast.Import, ast.ImportFrom]) -> boo return node.module is not None and node.module.partition(".")[0] == "defer_imports" def _wrap_import_stmts(self, nodes: list[ast.stmt], start: int) -> ast.With: - """Wrap a list of consecutive import nodes from a list of statements in a "defer_imports.until_use" block and + """Wrap a list of consecutive import nodes from a list of statements using a "defer_imports.until_use" block and instrument them. The first node must be guaranteed to be an import node. @@ -366,9 +379,9 @@ def _wrap_import_stmts(self, nodes: list[ast.stmt], start: int) -> ast.With: def generic_visit(self, node: ast.AST) -> ast.AST: """Called if no explicit visitor function exists for a node. - Summary - ------- - Almost a copy of ast.NodeVisitor.generic_vist, but intercepts global sequences of import statements to wrap + Extended Summary + ---------------- + Almost a copy of ast.NodeVisitor.generic_visit, but we intercept global sequences of import statements to wrap them in a "with defer_imports.until_use" block and instrument them. """ @@ -379,10 +392,10 @@ def generic_visit(self, node: ast.AST) -> ast.AST: if ( # Only when global instrumentation is enabled. self.defer_state - # Only with import nodes that we are prepared to handle. - and self._is_regular_import(value) # pyright: ignore [reportUnknownArgumentType] # Only at global scope. and self.scope_depth == 0 + # Only with import nodes that we are prepared to handle. + and self._is_regular_import(value) # pyright: ignore [reportUnknownArgumentType] # Only outside of escape hatch blocks. and (self.escape_hatch_depth == 0 and not self._is_defer_imports_import(value)) ): @@ -444,8 +457,8 @@ class DeferredFileLoader(SourceFileLoader): defer_state: bool def create_module(self, spec: ModuleSpec) -> _tp.Optional[_tp.ModuleType]: - if spec.loader_state is not None: - self.defer_state = spec.loader_state["defer_state"] + # This method should always run before source_to_code in regular circumstances. + self.defer_state = spec.loader_state["defer_state"] if (spec.loader_state is not None) else False return super().create_module(spec) def source_to_code( # pyright: ignore [reportIncompatibleMethodOverride] @@ -455,8 +468,7 @@ def source_to_code( # pyright: ignore [reportIncompatibleMethodOverride] *, _optimize: int = -1, ) -> _tp.CodeType: - # NOTE: InspectLoader is the virtual superclass of SourceFileLoader thanks to ABC registration, so typeshed - # reflects that. However, there's some mismatch in source_to_code signatures. Can it be fixed with a PR? + # NOTE: Sigature of SourceFileLoader.source_to_code at runtime isn't consistent with the version in typeshed. if not data: return super().source_to_code(data, path, _optimize=_optimize) # pyright: ignore # See note above. @@ -475,14 +487,14 @@ def source_to_code( # pyright: ignore [reportIncompatibleMethodOverride] # TODO: This isn't safe when a syntax error occurs since we modify the tree but raise before fixing node # locations. This also makes it difficult to point at the actual source code location. Find a way to # deepcopy the tree first. - orig_ast = data + orig_tree = data else: - orig_ast = ast.parse(data, path, "exec") + orig_tree = ast.parse(data, path, "exec") transformer = DeferredInstrumenter(data, path, encoding, defer_state=self.defer_state) - new_ast = ast.fix_missing_locations(transformer.visit(orig_ast)) + new_tree = ast.fix_missing_locations(transformer.visit(orig_tree)) - return super().source_to_code(new_ast, path, _optimize=_optimize) # pyright: ignore # See note above. + return super().source_to_code(new_tree, path, _optimize=_optimize) # pyright: ignore # See note above. def get_data(self, path: str) -> bytes: """Return the data from path as raw bytes. @@ -536,6 +548,9 @@ def set_data(self, path: str, data: _tp.ReadableBuffer, *, _mode: int = 0o666) - return super().set_data(path, data, _mode=_mode) +LOADER_DETAILS = (DeferredFileLoader, SOURCE_SUFFIXES) + + class DeferredFileFinder(FileFinder): def __init__( self, @@ -549,14 +564,20 @@ def __init__( self.deferred_modules = deferred_modules def find_spec(self, fullname: str, target: _tp.Optional[_tp.ModuleType] = None) -> _tp.Optional[ModuleSpec]: - """Try to find a spec for 'fullname' on sys.path or 'path', with some modifications based on defer state.""" + """Try to find a spec for "fullname" on sys.path or "path", with some modifications based on defer state.""" spec = super().find_spec(fullname, target) if spec is not None and isinstance(spec.loader, DeferredFileLoader): - # It's underdocumented, but spec.loader_state is meant for this kind of thing. + # It's under-documented, but spec.loader_state is meant for this kind of thing. # Ref: https://docs.python.org/3/library/importlib.html#importlib.machinery.ModuleSpec.loader_state # Ref: https://github.com/python/cpython/issues/89527 - defer_state = self.defer_globally or bool(self.deferred_modules and (fullname in self.deferred_modules)) + defer_state = self.defer_globally or bool( + self.deferred_modules + and ( + fullname in self.deferred_modules + or any(module_name.startswith(f"{fullname}.") for module_name in self.deferred_modules) + ) + ) spec.loader_state = {"defer_state": defer_state} return spec @@ -579,7 +600,31 @@ def path_hook_for_DeferredFileFinder(path: str) -> _tp.Self: return path_hook_for_DeferredFileFinder +# endregion + + +# ============================================================================ +# region -------- Public API -------- +# ============================================================================ + + +@_tp.final class ImportHookContext: + """The context manager returned by install_import_hook. + + Parameters + ---------- + path_hook: Callable[[str], PathEntryFinderProtocol] + A path hook to uninstall. Can be uninstalled manually with the uninstall method or automatically upon + exiting the context manager. + + Attributes + ---------- + path_hook: Callable[[str], PathEntryFinderProtocol] + A path hook to uninstall. Can be uninstalled manually with the uninstall method or automatically upon + exiting the context manager. + """ + def __init__(self, path_hook: _tp.Callable[[str], _tp.PathEntryFinderProtocol]) -> None: self.path_hook = path_hook @@ -590,7 +635,7 @@ def __exit__(self, *exc_info: object) -> None: self.uninstall() def uninstall(self) -> None: - """Remove the path hook if it's still in sys.path_hooks.""" + """Attempt to remove the path hook from sys.path_hooks. If successful, also invalidate path entry caches.""" try: sys.path_hooks.remove(self.path_hook) @@ -602,10 +647,20 @@ def uninstall(self) -> None: def install_import_hook(*, is_global: bool = False, module_names: _tp.Sequence[str] = ()) -> ImportHookContext: - """Insert defer_imports's path hook right before the default FileFinder one in sys.path_hooks. + """Insert a custom defer_imports path hook in sys.path_hooks with optional configuration for instrumenting + ALL import statements, not only ones wrapped by the defer_imports.until_use context manager. This should be run before the rest of your code. One place to put it is in __init__.py of your package. + Parameters + ---------- + is_global: bool, default=False + Whether to apply module-level import deferral, i.e. instrumentation of all imports, to all modules henceforth. + Mutually exclusive with and has higher priority than module_names. + module_names: Sequence[str], optional + Whether to apply module-level import deferral to a set of modules and recursively to all of their submodules. + Mutually exclusive with and has lower priority than is_global. + Returns ------- ImportHookContext @@ -613,8 +668,7 @@ def install_import_hook(*, is_global: bool = False, module_names: _tp.Sequence[s automatically by using it as a context manager. """ - loader_details = (DeferredFileLoader, SOURCE_SUFFIXES) - path_hook = DeferredFileFinder.path_hook(loader_details, defer_globally=is_global, deferred_modules=module_names) + path_hook = DeferredFileFinder.path_hook(LOADER_DETAILS, defer_globally=is_global, deferred_modules=module_names) try: hook_insert_index = sys.path_hooks.index(zipimport.zipimporter) + 1 # pyright: ignore [reportArgumentType] @@ -627,3 +681,6 @@ def install_import_hook(*, is_global: bool = False, module_names: _tp.Sequence[s sys.path_hooks.insert(hook_insert_index, path_hook) return ImportHookContext(path_hook) + + +# endregion diff --git a/src/defer_imports/_runtime.py b/src/defer_imports/_runtime.py index cbaea49..f055f90 100644 --- a/src/defer_imports/_runtime.py +++ b/src/defer_imports/_runtime.py @@ -16,11 +16,9 @@ from . import _typing as _tp -original_import = contextvars.ContextVar("original_import", default=builtins.__import__) -"""What builtins.__import__ currently points to.""" - -is_deferred = contextvars.ContextVar("is_deferred", default=False) -"""Whether imports should be deferred.""" +# ============================================================================ +# region -------- Helper functions -------- +# ============================================================================ def sanity_check(name: str, package: _tp.Optional[str], level: int) -> None: @@ -107,6 +105,21 @@ def resolve_name(name: str, package: str, level: int) -> str: return f"{base}.{name}" if name else base +# endregion + + +# ============================================================================ +# region -------- Main implementation -------- +# ============================================================================ + + +original_import = contextvars.ContextVar("original_import", default=builtins.__import__) +"""What builtins.__import__ currently points to.""" + +is_deferred = contextvars.ContextVar("is_deferred", default=False) +"""Whether imports should be deferred.""" + + class DeferredImportProxy: """Proxy for a deferred __import__ call.""" @@ -155,7 +168,7 @@ def __getattr__(self, name: str, /) -> _tp.Self: from_proxy.defer_proxy_fromlist = (name,) return from_proxy - elif name == self.defer_proxy_name.rpartition(".")[2]: + elif ("." in self.defer_proxy_name) and (name == self.defer_proxy_name.rpartition(".")[2]): submodule_proxy = type(self)(*self.defer_proxy_import_args) submodule_proxy.defer_proxy_sub = name return submodule_proxy @@ -171,22 +184,20 @@ class DeferredImportKey(str): When referenced, the key will replace itself in the namespace with the resolved import or the right name from it. """ - __slots__ = ("defer_key_str", "defer_key_proxy", "is_resolving", "lock") + __slots__ = ("defer_key_proxy", "is_resolving", "lock") def __new__(cls, key: str, proxy: DeferredImportProxy, /) -> _tp.Self: return super().__new__(cls, key) def __init__(self, key: str, proxy: DeferredImportProxy, /) -> None: - self.defer_key_str = str(key) self.defer_key_proxy = proxy - self.is_resolving = False self.lock = RLock() def __eq__(self, value: object, /) -> bool: if not isinstance(value, str): return NotImplemented - if self.defer_key_str != value: + if not super().__eq__(value): return False # Only the first thread to grab the lock should resolve the deferred import. @@ -202,17 +213,17 @@ def __eq__(self, value: object, /) -> bool: return True def __hash__(self) -> int: - return hash(self.defer_key_str) + return super().__hash__() def _resolve(self) -> None: """Perform an actual import for the given proxy and bind the result to the relevant namespace.""" proxy = self.defer_key_proxy - # Perform the original __import__ and pray. + # 1. Perform the original __import__ and pray. module: _tp.ModuleType = original_import.get()(*proxy.defer_proxy_import_args) - # Transfer nested proxies over to the resolved module. + # 2. Transfer nested proxies over to the resolved module. module_vars = vars(module) for attr_key, attr_val in vars(proxy).items(): if isinstance(attr_val, DeferredImportProxy) and not hasattr(module, attr_key): @@ -223,21 +234,20 @@ def _resolve(self) -> None: # Change the namespaces as well to make sure nested proxies are replaced in the right place. attr_val.defer_proxy_global_ns = attr_val.defer_proxy_local_ns = module_vars - # Replace the proxy with the resolved module or module attribute in the relevant namespace. - - # 1. Get the regular string key and the relevant namespace. - key = self.defer_key_str + # 3. Replace the proxy with the resolved module or module attribute in the relevant namespace. + # 3.1. Get the regular string key and the relevant namespace. + key = str(self) namespace = proxy.defer_proxy_local_ns - # 2. Replace the deferred version of the key to avoid it sticking around. - # This will trigger __eq__ again, so use is_deferred to prevent recursive resolution. + # 3.2. Replace the deferred version of the key to avoid it sticking around. + # This will trigger __eq__ again, so we use is_deferred to prevent recursion. _is_def_tok = is_deferred.set(True) try: namespace[key] = namespace.pop(key) finally: is_deferred.reset(_is_def_tok) - # 3. Resolve any requested attribute access. + # 3.3. Resolve any requested attribute access. if proxy.defer_proxy_fromlist: namespace[key] = getattr(module, proxy.defer_proxy_fromlist[0]) elif proxy.defer_proxy_sub: @@ -290,6 +300,14 @@ def deferred___import__( # noqa: ANN202 return DeferredImportProxy(name, globals, locals, fromlist, level) +# endregion + + +# ============================================================================ +# region -------- Public API -------- +# ============================================================================ + + @_tp.final class DeferredContext: """The type for defer_imports.until_use.""" @@ -310,7 +328,7 @@ def __exit__(self, *exc_info: object) -> None: until_use: _tp.Final[DeferredContext] = DeferredContext() """A context manager within which imports occur lazily. Not reentrant. -This will not work correctly if install_defer_import_hook() was not called first elsewhere. +This will not work correctly if install_import_hook() was not called first elsewhere. Raises ------ @@ -324,3 +342,6 @@ def __exit__(self, *exc_info: object) -> None: ----- As part of its implementation, this temporarily replaces builtins.__import__. """ + + +# endregion diff --git a/src/defer_imports/_typing.py b/src/defer_imports/_typing.py index 6870a20..02b2714 100644 --- a/src/defer_imports/_typing.py +++ b/src/defer_imports/_typing.py @@ -163,9 +163,9 @@ def find_spec(self, fullname: str, target: ModuleType | None = ..., /) -> Module raise AttributeError(msg) -_original_global_names = tuple(globals()) +_initial_global_names = tuple(globals()) def __dir__() -> list[str]: # This will hopefully make potential debugging easier. - return [*_original_global_names, *__all__] + return [*_initial_global_names, *__all__] diff --git a/tests/test_deferred.py b/tests/test_deferred.py index 6f1bf56..e0c00f2 100644 --- a/tests/test_deferred.py +++ b/tests/test_deferred.py @@ -19,23 +19,18 @@ from defer_imports._comptime import BYTECODE_HEADER, DeferredFileLoader, DeferredInstrumenter -def create_default_defer_loader_state(): - return {"defer_globally": False, "defer_locally": True} - - def create_sample_module(path: Path, source: str, loader_type: type): - """Utility function for creating a sample module with the given path, source code, and loader.""" - - tmp_file = path / "sample.py" - tmp_file.write_text(source, encoding="utf-8") + """Create a sample module with the given path, source code, and loader.""" module_name = "sample" - module_path = tmp_file.resolve() + module_path = path / f"{module_name}.py" + module_path.write_text(source, encoding="utf-8") + module_path = module_path.resolve() loader = loader_type(module_name, str(module_path)) spec = importlib.util.spec_from_file_location(module_name, module_path, loader=loader) assert spec - spec.loader_state = create_default_defer_loader_state() + module = importlib.util.module_from_spec(spec) return spec, module, module_path @@ -57,7 +52,7 @@ def better_key_repr(monkeypatch: pytest.MonkeyPatch): """Replace defer_imports._comptime.DeferredImportKey.__repr__ with a more verbose version for all tests.""" def _verbose_repr(self) -> str: # pyright: ignore # noqa: ANN001 - return f"" # pyright: ignore [reportUnknownMemberType] + return f"" # pyright: ignore monkeypatch.setattr("defer_imports._runtime.DeferredImportKey.__repr__", _verbose_repr) # pyright: ignore [reportUnknownArgumentType] @@ -179,16 +174,13 @@ def test_instrumentation(before: str, after: str): """Test what code is generated by the instrumentation side of defer_imports.""" import ast - import io - import tokenize filename = "" - before_bytes = before.encode() - encoding, _ = tokenize.detect_encoding(io.BytesIO(before_bytes).readline) - tree = ast.parse(before_bytes, filename, "exec") - transformed_tree = ast.fix_missing_locations(DeferredInstrumenter(before_bytes, filename, encoding).visit(tree)) + orig_tree = ast.parse(before, filename, "exec") + transformer = DeferredInstrumenter(before, filename) + new_tree = ast.fix_missing_locations(transformer.visit(orig_tree)) - assert f"{ast.unparse(transformed_tree)}\n" == after + assert f"{ast.unparse(new_tree)}\n" == after @pytest.mark.parametrize( @@ -379,18 +371,16 @@ def do_the_thing(a: int) -> int: ], ) def test_global_instrumentation(before: str, after: str): + """Test what code is generated by the instrumentation side of defer_imports if applied globally.""" + import ast - import io - import tokenize filename = "" - before_bytes = before.encode() - encoding, _ = tokenize.detect_encoding(io.BytesIO(before_bytes).readline) - tree = ast.parse(before_bytes, filename, "exec") - transformer = DeferredInstrumenter(before_bytes, filename, encoding, defer_state=True) - transformed_tree = ast.fix_missing_locations(transformer.visit(tree)) + orig_tree = ast.parse(before, filename, "exec") + transformer = DeferredInstrumenter(before, filename, defer_state=True) + new_tree = ast.fix_missing_locations(transformer.visit(orig_tree)) - assert f"{ast.unparse(transformed_tree)}\n" == after + assert f"{ast.unparse(new_tree)}\n" == after def test_path_hook_installation(): @@ -894,8 +884,6 @@ def __init__(self, val: object): assert spec assert spec.loader - spec.loader_state = create_default_defer_loader_state() - module = importlib.util.module_from_spec(spec) with temp_cache_module(package_name, module): @@ -978,8 +966,6 @@ def Y2(): assert spec assert spec.loader - spec.loader_state = create_default_defer_loader_state() - module = importlib.util.module_from_spec(spec) with temp_cache_module(package_name, module): @@ -1115,8 +1101,6 @@ def test_leaking_patch(tmp_path: Path): # pragma: no cover assert spec assert spec.loader - spec.loader_state = create_default_defer_loader_state() - module = importlib.util.module_from_spec(spec) with temp_cache_module(package_name, module):