Get existing tests passing
- Re-add code regions
- Adjust some comments and docstrings.
- Remove need to set loader state in existing tests.
- Run pre-commit autoupdate.
Sachaa-Thanasius committed Sep 17, 2024
1 parent aeeef4f commit 82ce25b
Showing 7 changed files with 153 additions and 93 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -8,7 +8,7 @@ repos:
- id: check-yaml

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.4
rev: v0.6.5
hooks:
- id: ruff
args: [--fix]
2 changes: 0 additions & 2 deletions README.rst
@@ -31,8 +31,6 @@ This can be installed via pip::

python -m pip install defer-imports

It can also easily be vendored, as it has zero dependencies and has around 1,000 lines of code.


Usage
=====
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -29,7 +29,7 @@ classifiers = [
dynamic = ["version"]

[tool.hatch.version]
path = "src/defer_imports/_compiletime.py"
path = "src/defer_imports/_comptime.py"

[project.optional-dependencies]
benchmark = ["slothy"]
125 changes: 91 additions & 34 deletions src/defer_imports/_comptime.py
@@ -22,14 +22,9 @@
__version__ = "0.0.2"


StrPath: _tp.TypeAlias = "_tp.Union[str, os.PathLike[str]]"
SourceData: _tp.TypeAlias = "_tp.Union[_tp.ReadableBuffer, str, ast.Module, ast.Expression, ast.Interactive]"


TOK_NAME, TOK_OP = tokenize.NAME, tokenize.OP

BYTECODE_HEADER = f"defer_imports{__version__}".encode()
"""Custom header for defer_imports-instrumented bytecode files. Should be updated with every version release."""
# ============================================================================
# region -------- Helper functions --------
# ============================================================================


def sliding_window(iterable: _tp.Iterable[_tp.T], n: int) -> _tp.Generator[tuple[_tp.T, ...]]:
@@ -52,6 +47,24 @@ def sliding_window(iterable: _tp.Iterable[_tp.T], n: int) -> _tp.Generator[tuple
yield tuple(window)
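# Hedged usage note (not part of the diff): sliding_window appears to follow the
# itertools-docs "sliding window" recipe, yielding each run of n consecutive items:
#
#     list(sliding_window("ABCDE", 3))
#     # -> [('A', 'B', 'C'), ('B', 'C', 'D'), ('C', 'D', 'E')]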


# endregion


# ============================================================================
# region -------- Main implementation --------
# ============================================================================


StrPath: _tp.TypeAlias = "_tp.Union[str, os.PathLike[str]]"
SourceData: _tp.TypeAlias = "_tp.Union[_tp.ReadableBuffer, str, ast.Module, ast.Expression, ast.Interactive]"


TOK_NAME, TOK_OP = tokenize.NAME, tokenize.OP

BYTECODE_HEADER = f"defer_imports{__version__}".encode()
"""Custom header for defer_imports-instrumented bytecode files. Should be updated with every version release."""


class DeferredInstrumenter(ast.NodeTransformer):
"""AST transformer that instruments imports within "with defer_imports.until_use: ..." blocks so that their
results are assigned to custom keys in the global namespace.
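# Rough usage sketch (not part of this commit): driving the transformer directly, the
# same way the loader's source_to_code below does. The constructor arguments match the
# __init__ shown in this hunk; the module path and file name are illustrative.
import ast

from defer_imports._comptime import DeferredInstrumenter

source = "import defer_imports\n\nwith defer_imports.until_use:\n    import inspect\n"
tree = ast.parse(source, "<example>", "exec")
instrumenter = DeferredInstrumenter(source, "<example>", "utf-8")
new_tree = ast.fix_missing_locations(instrumenter.visit(tree))
code = compile(new_tree, "<example>", "exec")  # instrumented module code, ready for exec()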
@@ -65,7 +78,7 @@ def __init__(
self,
data: _tp.Union[_tp.ReadableBuffer, str, ast.AST],
filepath: _tp.Union[StrPath, _tp.ReadableBuffer],
encoding: str,
encoding: str = "utf-8",
*,
defer_state: bool = False,
) -> None:
@@ -109,7 +122,7 @@ def _visit_eager_import_block(self, node: ast.AST) -> ast.AST:

# endregion

# region ---- until_use-wrapped imports instrumentation ----
# region ---- Basic instrumentation ----

def _decode_source(self) -> str:
"""Get the source code corresponding to the given data."""
@@ -195,7 +208,7 @@ def _initialize_local_ns() -> ast.Assign:

return ast.Assign(
targets=[ast.Name("@local_ns", ctx=ast.Store())],
value=ast.Call(ast.Name("locals", ctx=ast.Load()), args=[], keywords=[]),
value=ast.Call(func=ast.Name("locals", ctx=ast.Load()), args=[], keywords=[]),
)

@staticmethod
@@ -216,7 +229,7 @@ def _substitute_import_keys(self, import_nodes: list[ast.stmt]) -> list[ast.stmt
If any of the given nodes are not an import or are a wildcard import.
"""

new_import_nodes: list[ast.stmt] = list(import_nodes)
new_import_nodes = list(import_nodes)

for i in range(len(import_nodes) - 1, -1, -1):
node = import_nodes[i]
@@ -265,7 +278,7 @@ def visit_With(self, node: ast.With) -> ast.AST:
"""

if not self.check_With_for_defer_usage(node):
self._visit_eager_import_block(node)
return node

if self.scope_depth > 0:
msg = "with defer_imports.until_use only allowed at module level"
@@ -336,7 +349,7 @@ def _is_defer_imports_import(node: _tp.Union[ast.Import, ast.ImportFrom]) -> boo
return node.module is not None and node.module.partition(".")[0] == "defer_imports"

def _wrap_import_stmts(self, nodes: list[ast.stmt], start: int) -> ast.With:
"""Wrap a list of consecutive import nodes from a list of statements in a "defer_imports.until_use" block and
"""Wrap a list of consecutive import nodes from a list of statements using a "defer_imports.until_use" block and
instrument them.
The first node must be guaranteed to be an import node.
@@ -366,9 +379,9 @@ def _wrap_import_stmts(self, nodes: list[ast.stmt], start: int) -> ast.With:
def generic_visit(self, node: ast.AST) -> ast.AST:
"""Called if no explicit visitor function exists for a node.
Summary
-------
Almost a copy of ast.NodeVisitor.generic_vist, but intercepts global sequences of import statements to wrap
Extended Summary
----------------
Almost a copy of ast.NodeVisitor.generic_visit, but we intercept global sequences of import statements to wrap
them in a "with defer_imports.until_use" block and instrument them.
"""

@@ -379,10 +392,10 @@ def generic_visit(self, node: ast.AST) -> ast.AST:
if (
# Only when global instrumentation is enabled.
self.defer_state
# Only with import nodes that we are prepared to handle.
and self._is_regular_import(value) # pyright: ignore [reportUnknownArgumentType]
# Only at global scope.
and self.scope_depth == 0
# Only with import nodes that we are prepared to handle.
and self._is_regular_import(value) # pyright: ignore [reportUnknownArgumentType]
# Only outside of escape hatch blocks.
and (self.escape_hatch_depth == 0 and not self._is_defer_imports_import(value))
):
@@ -444,8 +457,8 @@ class DeferredFileLoader(SourceFileLoader):
defer_state: bool

def create_module(self, spec: ModuleSpec) -> _tp.Optional[_tp.ModuleType]:
if spec.loader_state is not None:
self.defer_state = spec.loader_state["defer_state"]
# This method should always run before source_to_code in regular circumstances.
self.defer_state = spec.loader_state["defer_state"] if (spec.loader_state is not None) else False
return super().create_module(spec)

def source_to_code( # pyright: ignore [reportIncompatibleMethodOverride]
@@ -455,8 +468,7 @@ def source_to_code( # pyright: ignore [reportIncompatibleMethodOverride]
*,
_optimize: int = -1,
) -> _tp.CodeType:
# NOTE: InspectLoader is the virtual superclass of SourceFileLoader thanks to ABC registration, so typeshed
# reflects that. However, there's some mismatch in source_to_code signatures. Can it be fixed with a PR?
# NOTE: Signature of SourceFileLoader.source_to_code at runtime isn't consistent with the version in typeshed.

if not data:
return super().source_to_code(data, path, _optimize=_optimize) # pyright: ignore # See note above.
@@ -475,14 +487,14 @@ def source_to_code( # pyright: ignore [reportIncompatibleMethodOverride]
# TODO: This isn't safe when a syntax error occurs since we modify the tree but raise before fixing node
# locations. This also makes it difficult to point at the actual source code location. Find a way to
# deepcopy the tree first.
orig_ast = data
orig_tree = data
else:
orig_ast = ast.parse(data, path, "exec")
orig_tree = ast.parse(data, path, "exec")

transformer = DeferredInstrumenter(data, path, encoding, defer_state=self.defer_state)
new_ast = ast.fix_missing_locations(transformer.visit(orig_ast))
new_tree = ast.fix_missing_locations(transformer.visit(orig_tree))

return super().source_to_code(new_ast, path, _optimize=_optimize) # pyright: ignore # See note above.
return super().source_to_code(new_tree, path, _optimize=_optimize) # pyright: ignore # See note above.

def get_data(self, path: str) -> bytes:
"""Return the data from path as raw bytes.
@@ -536,6 +548,9 @@ def set_data(self, path: str, data: _tp.ReadableBuffer, *, _mode: int = 0o666) -
return super().set_data(path, data, _mode=_mode)


LOADER_DETAILS = (DeferredFileLoader, SOURCE_SUFFIXES)


class DeferredFileFinder(FileFinder):
def __init__(
self,
@@ -549,14 +564,20 @@ def __init__(
self.deferred_modules = deferred_modules

def find_spec(self, fullname: str, target: _tp.Optional[_tp.ModuleType] = None) -> _tp.Optional[ModuleSpec]:
"""Try to find a spec for 'fullname' on sys.path or 'path', with some modifications based on defer state."""
"""Try to find a spec for "fullname" on sys.path or "path", with some modifications based on defer state."""

spec = super().find_spec(fullname, target)
if spec is not None and isinstance(spec.loader, DeferredFileLoader):
# It's underdocumented, but spec.loader_state is meant for this kind of thing.
# It's under-documented, but spec.loader_state is meant for this kind of thing.
# Ref: https://docs.python.org/3/library/importlib.html#importlib.machinery.ModuleSpec.loader_state
# Ref: https://github.com/python/cpython/issues/89527
defer_state = self.defer_globally or bool(self.deferred_modules and (fullname in self.deferred_modules))
defer_state = self.defer_globally or bool(
self.deferred_modules
and (
fullname in self.deferred_modules
or any(module_name.startswith(f"{fullname}.") for module_name in self.deferred_modules)
)
)
spec.loader_state = {"defer_state": defer_state}
return spec
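# Toy restatement (not from the diff) of the deferral condition above, with
# hypothetical module names, to show which lookups it accepts:
deferred_modules = ("pkg.sub",)

def _should_defer(fullname: str, defer_globally: bool = False) -> bool:
    return defer_globally or bool(
        deferred_modules
        and (
            fullname in deferred_modules
            or any(name.startswith(f"{fullname}.") for name in deferred_modules)
        )
    )

assert _should_defer("pkg.sub")        # listed explicitly
assert _should_defer("pkg")            # parent package of a listed module
assert not _should_defer("unrelated")  # everything else stays eager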

@@ -579,7 +600,31 @@ def path_hook_for_DeferredFileFinder(path: str) -> _tp.Self:
return path_hook_for_DeferredFileFinder


# endregion


# ============================================================================
# region -------- Public API --------
# ============================================================================


@_tp.final
class ImportHookContext:
"""The context manager returned by install_import_hook.
Parameters
----------
path_hook: Callable[[str], PathEntryFinderProtocol]
A path hook to uninstall. Can be uninstalled manually with the uninstall method or automatically upon
exiting the context manager.
Attributes
----------
path_hook: Callable[[str], PathEntryFinderProtocol]
A path hook to uninstall. Can be uninstalled manually with the uninstall method or automatically upon
exiting the context manager.
"""

def __init__(self, path_hook: _tp.Callable[[str], _tp.PathEntryFinderProtocol]) -> None:
self.path_hook = path_hook

@@ -590,7 +635,7 @@ def __exit__(self, *exc_info: object) -> None:
self.uninstall()

def uninstall(self) -> None:
"""Remove the path hook if it's still in sys.path_hooks."""
"""Attempt to remove the path hook from sys.path_hooks. If successful, also invalidate path entry caches."""

try:
sys.path_hooks.remove(self.path_hook)
@@ -602,19 +647,28 @@ def uninstall(self) -> None:


def install_import_hook(*, is_global: bool = False, module_names: _tp.Sequence[str] = ()) -> ImportHookContext:
"""Insert defer_imports's path hook right before the default FileFinder one in sys.path_hooks.
"""Insert a custom defer_imports path hook in sys.path_hooks with optional configuration for instrumenting
ALL import statements, not only ones wrapped by the defer_imports.until_use context manager.
This should be run before the rest of your code. One place to put it is in __init__.py of your package.
Parameters
----------
is_global: bool, default=False
Whether to apply module-level import deferral, i.e. instrumentation of all imports, to all modules henceforth.
Mutually exclusive with and has higher priority than module_names.
module_names: Sequence[str], optional
Whether to apply module-level import deferral to a set of modules and recursively to all of their submodules.
Mutually exclusive with and has lower priority than is_global.
Returns
-------
ImportHookContext
An object that can be used to uninstall the import hook, either manually by calling its uninstall method or
automatically by using it as a context manager.
"""

loader_details = (DeferredFileLoader, SOURCE_SUFFIXES)
path_hook = DeferredFileFinder.path_hook(loader_details, defer_globally=is_global, deferred_modules=module_names)
path_hook = DeferredFileFinder.path_hook(LOADER_DETAILS, defer_globally=is_global, deferred_modules=module_names)

try:
hook_insert_index = sys.path_hooks.index(zipimport.zipimporter) + 1 # pyright: ignore [reportArgumentType]
@@ -627,3 +681,6 @@ def install_import_hook(*, is_global: bool = False, module_names: _tp.Sequence[s
sys.path_hooks.insert(hook_insert_index, path_hook)

return ImportHookContext(path_hook)
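# Hedged usage sketch (package and module names hypothetical; the import path is where
# these names are defined in this diff): install the hook early, e.g. in your package's
# __init__.py, and use the returned context to remove it later if needed.
from defer_imports._comptime import install_import_hook

hook_ctx = install_import_hook()  # instrument only "with defer_imports.until_use" blocks

# Or defer every module-level import in selected modules:
# hook_ctx = install_import_hook(module_names=("mypackage", "mypackage.utils"))

hook_ctx.uninstall()  # or use hook_ctx as a context manager to uninstall on exit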


# endregion