diff --git a/docs/source/changes.rst b/docs/source/changes.rst index 9f344f7a9..17b267d22 100644 --- a/docs/source/changes.rst +++ b/docs/source/changes.rst @@ -11,7 +11,7 @@ chronological order. All previous releases should still be available on pip. .. _v0.12.0: ---------------------- -0.12.0rc4 - 2023-11-07 +0.12.0rc5 - 2023-11-12 ---------------------- @@ -67,6 +67,7 @@ Improvements - :class:`~hydra_zen.BuildsFn` was introduced to permit customizable auto-config and type-refinement support in config-creation functions. See :pull:`553`. - :func:`~hydra_zen.builds` and :func:`~hydra_zen.make_custom_builds_fn` now accept a `zen_exclude` field for excluding parameters from auto-population, either by name or by pattern. See :pull:`558`. - :func:`~hydra_zen.builds` and :func:`~hydra_zen.just` can now configure static methods. Previously the incorrect ``_target_`` would be resolved. See :pull:`566` +- :func:`hydra_zen.zen` now has first class support for running code in an isolated :py:class:`contextvars.Context`. This enables users to safely leverage state via :py:class:`contextvars.ContextVar` in their task functions. See :pull:`583`. - Adds formal support for Python 3.12. See :pull:`555` - Several new methods were added to :class:`~hydra_zen.ZenStore`, including the abilities to copy, update, and merge stores. As well as remap the groups of a store's entries and delete individual entries. See :pull:`569` diff --git a/pyproject.toml b/pyproject.toml index b34cf8716..0dd8259c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,7 +51,7 @@ classifiers = [ ] [project.optional-dependencies] -test = ["pytest >= 3.8", "hypothesis >= 6.28.0"] +test = ["pytest >= 3.8", "hypothesis >= 6.28.0", "pytest-trio >= 0.8.0"] pydantic = ["pydantic>=1.8.2,<2.0.0"] beartype = ["beartype>=0.8.0"] @@ -104,6 +104,7 @@ exclude_lines = [ [tool.pytest.ini_options] xfail_strict = true +trio_mode = true [tool.pyright] @@ -149,6 +150,7 @@ deps = setuptools pytest hypothesis pytest-xdist + pytest-trio tzdata commands = pytest tests/ {posargs: -n auto --maxprocesses=4} diff --git a/src/hydra_zen/wrapper/_implementations.py b/src/hydra_zen/wrapper/_implementations.py index 77e1407e9..63fa6d1fe 100644 --- a/src/hydra_zen/wrapper/_implementations.py +++ b/src/hydra_zen/wrapper/_implementations.py @@ -4,9 +4,10 @@ import warnings from collections import defaultdict +from contextvars import copy_context from copy import deepcopy -from functools import wraps -from inspect import Parameter, signature +from functools import partial, wraps +from inspect import Parameter, iscoroutinefunction, signature from typing import ( Any, Callable, @@ -139,6 +140,7 @@ def __init__( pre_call: PreCall = None, unpack_kwargs: bool = False, resolve_pre_call: bool = True, + run_in_context: bool = False, ) -> None: """ Parameters @@ -148,23 +150,41 @@ def __init__( unpack_kwargs: bool, optional (default=False) If `True` a `**kwargs` field in the wrapped function's signature will be - populated by all of the input config entries that are not specified by the - rest of the signature (and that are not specified by the `exclude` - argument). + populated by all of the input config entries that are not specified by the rest + of the signature (and that are not specified by the `exclude` argument). pre_call : Optional[Callable[[Any], Any] | Iterable[Callable[[Any], Any]]] One or more functions that will be called with the input config prior - to the wrapped functions. An iterable of pre-call functions are called + to the wrapped function. An iterable of pre-call functions are called from left (low-index) to right (high-index). - exclude: Optional[str | Iterable[str]] + This is useful, e.g., for seeding a RNG prior to the instantiation phase + that is triggered when calling the wrapped function. + + resolve_pre_call : bool, (default=True) + If `True`, the config passed to the zen-wrapped function has its + interpolated fields resolved prior to being passed to any pre-call + functions. Otherwise, the interpolation occurs after the pre-call functions + are called. + + exclude : Optional[str | Iterable[str]] Specifies one or more parameter names in the function's signature that will not be extracted from input configs by the zen-wrapped function. A single string of comma-separated names can be specified. + + run_in_context : bool, optional (default=False) + If `True`, the zen-wrapped function - and the `pre_call` function, if + specified - is run in a copied :py:class:`contextvars.Context`; i.e. + changes made to any :py:class:`contextvars.ContextVar` will be isolated to + that call of the wrapped function. + + `run_in_context` is not supported for async functions. """ - self.func: Callable[P, R] = __func + if run_in_context and iscoroutinefunction(__func): + raise TypeError(f"`{run_in_context=} is not supported for async functions.") + self.func: Callable[P, R] = __func try: # Must cast to dict so that `self` is pickle-compatible. self.parameters: Mapping[str, Parameter] = dict( @@ -183,11 +203,18 @@ def __init__( raise TypeError( f"`resolve_pre_call` must be type `bool` got {resolve_pre_call}" ) + + if not isinstance(run_in_context, bool): # pragma: no cover + raise TypeError( + f"`run_in_context` must be type `bool` got {run_in_context}" + ) + self._resolve = resolve_pre_call self._unpack_kwargs: bool = unpack_kwargs and any( p.kind is p.VAR_KEYWORD for p in self.parameters.values() ) + self._run_in_context: bool = run_in_context self._exclude: Set[str] if exclude is None: @@ -216,7 +243,13 @@ def __init__( if _f is None: continue - _f_params = signature(_f).parameters + if run_in_context and isinstance(_f, Zen) and _f._run_in_context: + raise HydraZenValidationError( + f"zen-wrapped pre_call function {_f!r} cannot specify " + f"`run_in_context=True` when the main wrapper specifies it as well." + ) + + _f_params = signature(_f).parameters # type: ignore if (sum(p.default is p.empty for p in _f_params.values()) > 1) or len( _f_params @@ -353,8 +386,15 @@ def __call__(self, __cfg: Union[ConfigLike, str]) -> R: # resolves all interpolated values in-place OmegaConf.resolve(cfg) + context = copy_context() if self._run_in_context else None + if self.pre_call is not None: - self.pre_call(cfg) + pre_call = ( + self.pre_call + if context is None + else partial(context.run, self.pre_call) + ) + pre_call(cfg) args_ = list(getattr(cfg, "_args_", [])) @@ -377,7 +417,10 @@ def __call__(self, __cfg: Union[ConfigLike, str]) -> R: and isinstance(name, str) ) cfg_kwargs.update({name: cfg[name] for name in names}) - return self.func( + + func = self.func if context is None else partial(context.run, self.func) + + return func( *(self.instantiate(x) if is_instantiable(x) else x for x in args_), **{ name: self.instantiate(val) if is_instantiable(val) else val @@ -472,6 +515,7 @@ def zen( pre_call: PreCall = ..., ZenWrapper: Type[Zen[Any, Any]] = ..., resolve_pre_call: bool = ..., + run_in_context: bool = ..., exclude: Optional[Union[str, Iterable[str]]] = ..., ) -> Zen[P, R]: ... @@ -485,6 +529,7 @@ def zen( pre_call: PreCall = ..., resolve_pre_call: bool = ..., ZenWrapper: Type[Zen[Any, Any]] = ..., + run_in_context: bool = ..., exclude: Optional[Union[str, Iterable[str]]] = ..., ) -> Callable[[Callable[P2, R2]], Zen[P2, R2]]: ... @@ -497,6 +542,7 @@ def zen( pre_call: PreCall = None, exclude: Optional[Union[str, Iterable[str]]] = None, resolve_pre_call: bool = True, + run_in_context: bool = False, ZenWrapper: Type[Zen[Any, Any]] = Zen, ) -> Union[Callable[[Callable[P2, R2]], Zen[P2, R2]], Zen[P, R]]: r"""zen(func, /, pre_call, ZenWrapper) @@ -545,6 +591,14 @@ def zen( A single string of comma-separated names can be specified. + run_in_context : bool, optional (default=False) + If `True`, the zen-wrapped function - and the `pre_call` function, if + specified - is run in a copied :py:class:`contextvars.Context`; i.e. + changes made to any :py:class:`contextvars.ContextVar` will be isolated to + that call of the wrapped function. + + `run_in_context` is not supported for async functions. + ZenWrapper : Type[hydra_zen.wrapper.Zen], optional (default=Zen) If specified, a subclass of `Zen` that customizes the behavior of the wrapper. @@ -746,6 +800,7 @@ def task(x: int, y: int): exclude=exclude, unpack_kwargs=unpack_kwargs, resolve_pre_call=resolve_pre_call, + run_in_context=run_in_context, ), ) @@ -758,6 +813,7 @@ def wrap(f: Callable[P2, R2]) -> Zen[P2, R2]: exclude=exclude, unpack_kwargs=unpack_kwargs, resolve_pre_call=resolve_pre_call, + run_in_context=run_in_context, ), ) return out diff --git a/tests/test_zen.py b/tests/test_zen.py index b0b32c104..6c4e41a1a 100644 --- a/tests/test_zen.py +++ b/tests/test_zen.py @@ -625,3 +625,10 @@ def pikl(x): def test_pickle_compatible(): loaded = pickle.loads(pickle.dumps(zpikl)) assert loaded({"x": 3}) == pikl(3) + + +async def test_async_compatible(): + async def foo(x: int): + return x + + assert await zen(foo)(dict(x=builds(int, 22))) == 22 diff --git a/tests/test_zen_context_isolation.py b/tests/test_zen_context_isolation.py new file mode 100644 index 000000000..11970a46b --- /dev/null +++ b/tests/test_zen_context_isolation.py @@ -0,0 +1,87 @@ +# Copyright (c) 2023 Massachusetts Institute of Technology +# SPDX-License-Identifier: MIT + +import random +from contextvars import ContextVar +from typing import Any, Dict, Optional + +import pytest + +from hydra_zen import zen +from hydra_zen.errors import HydraZenValidationError + +config: ContextVar[Optional[Dict[str, Any]]] = ContextVar("config", default=None) +var: ContextVar[Dict[str, Any]] = ContextVar("var", default=dict()) + + +@pytest.fixture(autouse=True) +def clean_context_vars(): + assert config.get() is None + assert var.get() == {} + yield + config.set(None) + var.set({}) + + +@pytest.mark.parametrize( + "run_in_context", + [ + True, + pytest.param(False, marks=pytest.mark.xfail), + ], +) +def test_context_isolation(run_in_context: bool): + def foo(x: str, zen_cfg): + config.set(zen_cfg) + conf = var.get().copy() + conf[str(random.randint(1, 100))] = random.randint(1, 100) + var.set(conf) + assert len(conf) == 1 + + zfoo = zen(foo, run_in_context=run_in_context) + + for letter in "ab": + zfoo(dict(x=letter)) + assert config.get() is None + assert var.get() == dict() + + +def test_async_func_run_in_context_not_supported(): + async def foo(): + ... + + with pytest.raises(TypeError, match="not supported"): + zen(foo, run_in_context=True) + + +@pytest.mark.parametrize( + "run_in_context", + [ + True, + pytest.param(False, marks=pytest.mark.xfail), + ], +) +def test_pre_call_shares_context_with_wrapped_func(run_in_context: bool): + assert var.get() == {} + + def pre_call(cfg): + var.set({"swagger": 22}) + + def func(): + assert var.get() == {"swagger": 22} + + zen(func, pre_call=pre_call, run_in_context=run_in_context)({}) + assert var.get() == {} + + +def test_pre_call_run_in_its_own_context_is_forbidden(): + def f(x): + ... + + with pytest.raises(HydraZenValidationError): + zen(f, pre_call=zen(f, run_in_context=True), run_in_context=True) + + +def test_validation(): + with pytest.raises(TypeError, match="must be type"): + zen(lambda x: x, run_in_context=None) # type: ignore