Skip to content

Commit

Permalink
Add some more comments and code regions. Also rename generated files …
Browse files Browse the repository at this point in the history
…for benchmarks.
  • Loading branch information
Sachaa-Thanasius committed Sep 16, 2024
1 parent 239eb2f commit 7caefbb
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 49 deletions.
32 changes: 18 additions & 14 deletions benchmark/bench_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,25 +44,29 @@ def bench_regular() -> float:
return ct.elapsed


def bench_defer_imports() -> float:
# with defer_imports.install_import_hook(), CatchTime() as ct:
# import benchmark.sample_defer_imports
# print(len(dir(benchmark.sample_defer_imports)))
with defer_imports.install_import_hook(is_global=True), CatchTime() as ct:
import benchmark.sample_regular_defer
return ct.elapsed


def bench_slothy() -> float:
with CatchTime() as ct:
import benchmark.sample_slothy
return ct.elapsed


def bench_defer_imports_local() -> float:
with defer_imports.install_import_hook(), CatchTime() as ct:
import benchmark.sample_defer_local
return ct.elapsed


def bench_defer_imports_global() -> float:
with defer_imports.install_import_hook(is_global=True), CatchTime() as ct:
import benchmark.sample_defer_global
return ct.elapsed


BENCH_FUNCS = {
"regular": bench_regular,
"slothy": bench_slothy,
"defer_imports": bench_defer_imports,
"defer_imports (local)": bench_defer_imports_local,
"defer_imports (global)": bench_defer_imports_global,
}


Expand All @@ -74,7 +78,7 @@ def main() -> None:
parser.add_argument(
"--exec-order",
action="extend",
nargs=3,
nargs=4,
choices=BENCH_FUNCS.keys(),
type=str,
help="The order in which the influenced (or not influenced) imports are run",
Expand All @@ -100,11 +104,11 @@ def main() -> None:
version_len = len(version_header)
version_divider = "=" * version_len

benchmark_len = 14
benchmark_len = 22
benchmark_header = "Benchmark".ljust(benchmark_len)
benchmark_divider = "=" * benchmark_len

time_len = 23
time_len = 19
time_header = "Time".ljust(time_len)
time_divider = "=" * time_len

Expand All @@ -125,7 +129,7 @@ def main() -> None:

for bench_type, result in results.items():
fmt_bench_type = bench_type.ljust(benchmark_len)
fmt_result = f"{result:.7f}s ({result / minimum:.2f}x)".ljust(time_len)
fmt_result = f"{result:.5f}s ({result / minimum:.2f}x)".ljust(time_len)

print(impl, version, fmt_bench_type, fmt_result, sep=" ")

Expand Down
13 changes: 8 additions & 5 deletions benchmark/generate_samples.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Generate sample scripts with the same set of imports but influenced by different libraries, e.g. defer_imports."""

import shutil
from pathlib import Path


Expand Down Expand Up @@ -563,8 +564,11 @@ def main() -> None:
regular_contents = "\n".join((PYRIGHT_IGNORE_DIRECTIVES, GENERATED_BY_COMMENT, STDLIB_IMPORTS))
regular_path.write_text(regular_contents, encoding="utf-8")

# defer_imports-instrumented and defer_imports-hooked imports
defer_imports_path = bench_path / "sample_defer_imports.py"
# defer_imports-instrumented and defer_imports-hooked imports (global)
shutil.copy(regular_path, regular_path.with_name("sample_defer_global.py"))

# defer_imports-instrumented and defer_imports-hooked imports (local)
defer_imports_path = bench_path / "sample_defer_local.py"
defer_imports_contents = (
f"{PYRIGHT_IGNORE_DIRECTIVES}\n"
f"{GENERATED_BY_COMMENT}\n"
Expand All @@ -576,9 +580,8 @@ def main() -> None:
)
defer_imports_path.write_text(defer_imports_contents, encoding="utf-8")

# Same defer_imports-influenced imports, but for a test in the tests directory
tests_path = Path().resolve() / "tests" / "stdlib_imports.py"
tests_path.write_text(defer_imports_contents, encoding="utf-8")
# defer_imports-influenced imports (local), but for a test in the tests directory
shutil.copy(defer_imports_path, bench_path.parent / "tests" / "stdlib_imports.py")

# slothy-hooked imports
slothy_path = bench_path / "sample_slothy.py"
Expand Down
File renamed without changes.
File renamed without changes.
84 changes: 58 additions & 26 deletions src/defer_imports/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@


# ============================================================================
# region -------- Vendored helpers
# region -------- Vendored helpers --------
#
# The helper functions should reflect the behavior of the corresponding functions in all supported CPython versions.
# ============================================================================
Expand Down Expand Up @@ -139,15 +139,15 @@ def resolve_name(name: str, package: str, level: int) -> str:


# ============================================================================
# region -------- Compile-time hook
# region -------- Compile-time hook --------
# ============================================================================


StrPath: _tp.TypeAlias = "_tp.Union[str, _tp.PathLike[str]]"
SourceData: _tp.TypeAlias = "_tp.Union[_tp.ReadableBuffer, str, ast.Module, ast.Expression, ast.Interactive]"


should_apply_globally = contextvars.ContextVar("should_instrument_globally", default=False)
should_instrument_globally = contextvars.ContextVar("should_instrument_globally", default=False)
"""Whether the instrumentation should apply globally."""


Expand All @@ -158,6 +158,10 @@ def resolve_name(name: str, package: str, level: int) -> str:
class DeferredInstrumenter(ast.NodeTransformer):
"""AST transformer that instruments imports within "with defer_imports.until_use: ..." blocks so that their
results are assigned to custom keys in the global namespace.
Notes
-----
This assumes the module is not empty and "with defer_imports.until_use" is used somewhere in it.
"""

def __init__(
Expand All @@ -173,6 +177,8 @@ def __init__(
self.scope_depth = 0
self.escape_hatch_depth = 0

# region ---- Scope tracking ----

def _visit_scope(self, node: ast.AST) -> ast.AST:
"""Track Python scope changes. Used to determine if defer_imports.until_use usage is global."""

Expand Down Expand Up @@ -201,6 +207,10 @@ def _visit_eager_import_block(self, node: ast.AST) -> ast.AST:
if sys.version_info >= (3, 11):
visit_TryStar = _visit_eager_import_block

# endregion

# region ---- Until_use-wrapped imports instrumentation ----

def _decode_source(self) -> str:
"""Get the source code corresponding to the given data."""

Expand Down Expand Up @@ -355,26 +365,19 @@ def visit_With(self, node: ast.With) -> ast.AST:
"""

if not self.check_With_for_defer_usage(node):
self.escape_hatch_depth += 1
try:
return self.generic_visit(node)
finally:
self.escape_hatch_depth -= 1
self._visit_eager_import_block(node)

if self.scope_depth != 0:
if self.scope_depth > 0:
msg = "with defer_imports.until_use only allowed at module level"
raise SyntaxError(msg, self._get_node_context(node))

node.body = self._substitute_import_keys(node.body)
return node

def visit_Module(self, node: ast.Module) -> ast.AST:
"""Insert imports necessary to make defer_imports.until_use work properly. The import is placed after the
module docstring and after __future__ imports.
"""Insert imports necessary to make defer_imports.until_use work properly.
Notes
-----
This assumes the module is not empty and "with defer_imports.until_use" is used somewhere in it.
The imports are placed after the module docstring and after __future__ imports.
"""

expect_docstring = True
Expand All @@ -396,7 +399,7 @@ def visit_Module(self, node: ast.Module) -> ast.AST:
position += 1

# Import defer classes.
if should_apply_globally.get():
if should_instrument_globally.get():
top_level_import = ast.Import(names=[ast.alias(name="defer_imports")])
node.body.insert(position, top_level_import)
position += 1
Expand All @@ -413,6 +416,10 @@ def visit_Module(self, node: ast.Module) -> ast.AST:

return self.generic_visit(node)

# endregion

# region ---- Global imports instrumentation ----

@staticmethod
def _identify_regular_import(obj: object) -> _tp.TypeGuard[_tp.Union[ast.Import, ast.ImportFrom]]:
"""Check if a given object is an import AST without wildcards."""
Expand All @@ -421,12 +428,20 @@ def _identify_regular_import(obj: object) -> _tp.TypeGuard[_tp.Union[ast.Import,

@staticmethod
def _is_defer_imports_import(node: _tp.Union[ast.Import, ast.ImportFrom]) -> bool:
"""Check if the given import node imports from defer_imports."""

if isinstance(node, ast.Import):
return any(alias.name.partition(".")[0] == "defer_imports" for alias in node.names)
else:
return node.module is not None and node.module.partition(".")[0] == "defer_imports"

def _wrap_import_stmts(self, nodes: list[ast.stmt], start: int) -> ast.With:
"""Wrap a list of consecutive import nodes from a list of statements in a "defer_imports.until_use" block and
instrument them.
The node at index ``start`` must be an import node.
"""

import_range = tuple(takewhile(lambda i: self._identify_regular_import(nodes[i]), range(start, len(nodes))))
import_slice = slice(import_range[0], import_range[-1] + 1)
import_nodes = nodes[import_slice]
Expand All @@ -449,14 +464,22 @@ def _wrap_import_stmts(self, nodes: list[ast.stmt], start: int) -> ast.With:
return wrapper_node

def generic_visit(self, node: ast.AST) -> ast.AST:
"""Called if no explicit visitor function exists for a node.
Summary
-------
Almost a copy of ast.NodeTransformer.generic_visit, but intercepts global sequences of import statements to wrap
them in a "with defer_imports.until_use" block and instrument them.
"""

for field, old_value in ast.iter_fields(node):
if isinstance(old_value, list):
new_values: list[_tp.Any] = []
for i, value in enumerate(old_value): # pyright: ignore
# This if block is the only difference from NodeTransformer.generic_visit.
if (
should_apply_globally.get()
# Only instrument import nodes, specifically ones we are prepared to handle.
# Only do this when the user has enabled global instrumentation.
should_instrument_globally.get()
# Only instrument import nodes that we are prepared to handle.
and self._identify_regular_import(value) # pyright: ignore [reportUnknownArgumentType]
# Only instrument imports in global scopes.
and self.scope_depth == 0
Expand All @@ -482,6 +505,8 @@ def generic_visit(self, node: ast.AST) -> ast.AST:
setattr(node, field, new_node)
return node

# endregion


def check_source_for_defer_usage(data: _tp.Union[_tp.ReadableBuffer, str]) -> tuple[str, bool]:
"""Get the encoding of the given code and also check if it uses "with defer_imports.until_use"."""
Expand Down Expand Up @@ -614,7 +639,7 @@ def set_data(self, path: str, data: _tp.ReadableBuffer, *, _mode: int = 0o666) -


# ============================================================================
# region -------- Runtime hook
# region -------- Runtime hook --------
# ============================================================================


Expand Down Expand Up @@ -812,17 +837,23 @@ def deferred___import__( # noqa: ANN202


# ============================================================================
# region -------- Public API
# region -------- Public API --------
# ============================================================================


def install_import_hook(*, is_global: bool = False) -> ImportHookContext:
"""Insert defer_imports's path hook right before the default FileFinder one in sys.path_hooks.
This should be run before the rest of your code. One place to put it is in __init__.py of your package.
Returns
-------
ImportHookContext
An object that can be used to uninstall the import hook, either manually by calling its uninstall method or
automatically by using it as a context manager.
"""

global_apply_tok = should_apply_globally.set(is_global)
global_apply_tok = should_instrument_globally.set(is_global)

if DEFERRED_PATH_HOOK not in sys.path_hooks:
# NOTE: PathFinder.invalidate_caches() is expensive because it imports importlib.metadata, but we have to just bear
Expand All @@ -837,7 +868,7 @@ def install_import_hook(*, is_global: bool = False) -> ImportHookContext:


class ImportHookContext:
def __init__(self, tok: contextvars.Token[bool]):
def __init__(self, tok: contextvars.Token[bool]) -> None:
self._tok = tok

def __enter__(self) -> _tp.Self:
Expand All @@ -849,16 +880,17 @@ def __exit__(self, *exc_info: object) -> None:
def uninstall(self) -> None:
"""Remove defer_imports's path hook if it's in sys.path_hooks."""

if hasattr(self, "_tok"):
should_apply_globally.reset(self._tok)
del self._tok
# Ensure the token is only used once.
if self._tok is not None:
should_instrument_globally.reset(self._tok)
self._tok = None

try:
sys.path_hooks.remove(DEFERRED_PATH_HOOK)
except ValueError:
pass
else:
# NOTE: Use the same invalidation mechanism as install_defer_import_hook() does.
# NOTE: Use the same invalidation mechanism as install_import_hook() does.
PathFinder.invalidate_caches()


Expand Down
2 changes: 1 addition & 1 deletion src/defer_imports/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ class TypeGuard:
raise AttributeError(msg)


_original_global_names = list(globals())
_original_global_names = tuple(globals())


def __dir__() -> list[str]:
Expand Down
6 changes: 3 additions & 3 deletions tests/test_deferred.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
DeferredFileLoader,
DeferredInstrumenter,
install_import_hook,
should_apply_globally,
should_instrument_globally,
)


Expand Down Expand Up @@ -67,11 +67,11 @@ def _verbose_repr(self) -> str: # pyright: ignore # noqa: ANN001
def global_instrumentation_on():
"""Turn on the global instrumentation aspect of defer_imports temporarily."""

_tok = should_apply_globally.set(True)
_tok = should_instrument_globally.set(True)
try:
yield
finally:
should_apply_globally.reset(_tok)
should_instrument_globally.reset(_tok)


@pytest.mark.parametrize(
Expand Down

0 comments on commit 7caefbb

Please sign in to comment.