diff --git a/src/hydra_zen/wrapper/_implementations.py b/src/hydra_zen/wrapper/_implementations.py index 77e1407e9..446e85713 100644 --- a/src/hydra_zen/wrapper/_implementations.py +++ b/src/hydra_zen/wrapper/_implementations.py @@ -4,9 +4,10 @@ import warnings from collections import defaultdict +from contextvars import copy_context from copy import deepcopy -from functools import wraps -from inspect import Parameter, signature +from functools import partial, wraps +from inspect import Parameter, iscoroutinefunction, signature from typing import ( Any, Callable, @@ -139,6 +140,7 @@ def __init__( pre_call: PreCall = None, unpack_kwargs: bool = False, resolve_pre_call: bool = True, + run_in_context: bool = False, ) -> None: """ Parameters @@ -148,23 +150,42 @@ def __init__( unpack_kwargs: bool, optional (default=False) If `True` a `**kwargs` field in the wrapped function's signature will be - populated by all of the input config entries that are not specified by the - rest of the signature (and that are not specified by the `exclude` - argument). + populated by all of the input config entries that are not specified by the rest + of the signature (and that are not specified by the `exclude` argument). pre_call : Optional[Callable[[Any], Any] | Iterable[Callable[[Any], Any]]] One or more functions that will be called with the input config prior - to the wrapped functions. An iterable of pre-call functions are called + to the wrapped function. An iterable of pre-call functions are called from left (low-index) to right (high-index). - exclude: Optional[str | Iterable[str]] + This is useful, e.g., for seeding a RNG prior to the instantiation phase + that is triggered when calling the wrapped function. + + resolve_pre_call : bool, (default=True) + If `True`, the config passed to the zen-wrapped function has its + interpolated fields resolved prior to being passed to any pre-call + functions. Otherwise, the interpolation occurs after the pre-call functions + are called. + + exclude : Optional[str | Iterable[str]] Specifies one or more parameter names in the function's signature that will not be extracted from input configs by the zen-wrapped function. A single string of comma-separated names can be specified. + + run_in_context : bool, optional (default=False) + If `True`, the zen-wrapped function - and the `pre_call` function, if + specified - is run in a copied :py:class:`contextvars.Context`; i.e. + changes made to any :py:class:`contextvars.ContextVar` will be isolated to + the context of the function. + + `run_in_context` is not supported for async functions. """ - self.func: Callable[P, R] = __func + if run_in_context and iscoroutinefunction(__func): + raise TypeError(f"`{run_in_context=} is not supported for async functions.") + self.func: Callable[P, R] = __func + self._run_in_context: bool = run_in_context try: # Must cast to dict so that `self` is pickle-compatible. self.parameters: Mapping[str, Parameter] = dict( @@ -353,8 +374,15 @@ def __call__(self, __cfg: Union[ConfigLike, str]) -> R: # resolves all interpolated values in-place OmegaConf.resolve(cfg) + context = copy_context() if self._run_in_context else None + if self.pre_call is not None: - self.pre_call(cfg) + pre_call = ( + self.pre_call + if context is None + else partial(context.run, self.pre_call) + ) + pre_call(cfg) args_ = list(getattr(cfg, "_args_", [])) @@ -377,7 +405,10 @@ def __call__(self, __cfg: Union[ConfigLike, str]) -> R: and isinstance(name, str) ) cfg_kwargs.update({name: cfg[name] for name in names}) - return self.func( + + func = self.func if context is None else partial(context.run, self.func) + + return func( *(self.instantiate(x) if is_instantiable(x) else x for x in args_), **{ name: self.instantiate(val) if is_instantiable(val) else val @@ -472,6 +503,7 @@ def zen( pre_call: PreCall = ..., ZenWrapper: Type[Zen[Any, Any]] = ..., resolve_pre_call: bool = ..., + run_in_context: bool = ..., exclude: Optional[Union[str, Iterable[str]]] = ..., ) -> Zen[P, R]: ... @@ -485,6 +517,7 @@ def zen( pre_call: PreCall = ..., resolve_pre_call: bool = ..., ZenWrapper: Type[Zen[Any, Any]] = ..., + run_in_context: bool = ..., exclude: Optional[Union[str, Iterable[str]]] = ..., ) -> Callable[[Callable[P2, R2]], Zen[P2, R2]]: ... @@ -497,6 +530,7 @@ def zen( pre_call: PreCall = None, exclude: Optional[Union[str, Iterable[str]]] = None, resolve_pre_call: bool = True, + run_in_context: bool = False, ZenWrapper: Type[Zen[Any, Any]] = Zen, ) -> Union[Callable[[Callable[P2, R2]], Zen[P2, R2]], Zen[P, R]]: r"""zen(func, /, pre_call, ZenWrapper) @@ -545,6 +579,14 @@ def zen( A single string of comma-separated names can be specified. + run_in_context : bool, optional (default=False) + If `True`, the zen-wrapped function - and the `pre_call` function, if + specified - is run in a copied :py:class:`contextvars.Context`; i.e. + changes made to any :py:class:`contextvars.ContextVar` will be isolated to + the context of the call to the wrapped function. + + `run_in_context` is not supported for async functions. + ZenWrapper : Type[hydra_zen.wrapper.Zen], optional (default=Zen) If specified, a subclass of `Zen` that customizes the behavior of the wrapper. @@ -746,6 +788,7 @@ def task(x: int, y: int): exclude=exclude, unpack_kwargs=unpack_kwargs, resolve_pre_call=resolve_pre_call, + run_in_context=run_in_context, ), ) @@ -758,6 +801,7 @@ def wrap(f: Callable[P2, R2]) -> Zen[P2, R2]: exclude=exclude, unpack_kwargs=unpack_kwargs, resolve_pre_call=resolve_pre_call, + run_in_context=run_in_context, ), ) return out diff --git a/tests/test_zen_context_isolation.py b/tests/test_zen_context_isolation.py new file mode 100644 index 000000000..1a2c32987 --- /dev/null +++ b/tests/test_zen_context_isolation.py @@ -0,0 +1,73 @@ +# Copyright (c) 2023 Massachusetts Institute of Technology +# SPDX-License-Identifier: MIT + +import random +from contextvars import ContextVar +from typing import Any, Dict, Optional + +import pytest + +from hydra_zen import zen + +config: ContextVar[Optional[Dict[str, Any]]] = ContextVar("config", default=None) +var: ContextVar[Dict[str, Any]] = ContextVar("var", default=dict()) + + +@pytest.fixture(autouse=True) +def clean_context_vars(): + assert config.get() is None + assert var.get() == {} + yield + config.set(None) + var.set({}) + + +@pytest.mark.parametrize( + "run_in_context", + [ + True, + pytest.param(False, marks=pytest.mark.xfail), + ], +) +def test_context_isolation(run_in_context: bool): + def foo(x: str, zen_cfg): + config.set(zen_cfg) + conf = var.get().copy() + conf[str(random.randint(1, 100))] = random.randint(1, 100) + var.set(conf) + assert len(conf) == 1 + + zfoo = zen(foo, run_in_context=run_in_context) + + for letter in "ab": + zfoo(dict(x=letter)) + assert config.get() is None + assert var.get() == dict() + + +async def async_func_run_in_context_not_supported(): + async def foo(): + ... + + with pytest.raises(TypeError, match="not supported"): + zen(foo, run_in_context=True) + + +@pytest.mark.parametrize( + "run_in_context", + [ + True, + pytest.param(False, marks=pytest.mark.xfail), + ], +) +def test_pre_call_shares_context_with_wrapped_func(run_in_context: bool): + assert var.get() == {} + + def pre_call(cfg): + var.set({"swagger": 22}) + + def func(): + assert var.get() == {"swagger": 22} + + zen(func, pre_call=pre_call, run_in_context=run_in_context)({}) + assert var.get() == {}