Skip to content

Commit

Permalink
Add support for run_in_context
Browse files Browse the repository at this point in the history
rsokl committed Nov 12, 2023
1 parent f64dded commit b18aa86
Showing 2 changed files with 127 additions and 10 deletions.
64 changes: 54 additions & 10 deletions src/hydra_zen/wrapper/_implementations.py
Original file line number Diff line number Diff line change
@@ -4,9 +4,10 @@

import warnings
from collections import defaultdict
from contextvars import copy_context
from copy import deepcopy
from functools import wraps
from inspect import Parameter, signature
from functools import partial, wraps
from inspect import Parameter, iscoroutinefunction, signature
from typing import (
Any,
Callable,
@@ -139,6 +140,7 @@ def __init__(
pre_call: PreCall = None,
unpack_kwargs: bool = False,
resolve_pre_call: bool = True,
run_in_context: bool = False,
) -> None:
"""
Parameters
@@ -148,23 +150,42 @@ def __init__(
unpack_kwargs: bool, optional (default=False)
If `True` a `**kwargs` field in the wrapped function's signature will be
populated by all of the input config entries that are not specified by the
rest of the signature (and that are not specified by the `exclude`
argument).
populated by all of the input config entries that are not specified by the rest
of the signature (and that are not specified by the `exclude` argument).
pre_call : Optional[Callable[[Any], Any] | Iterable[Callable[[Any], Any]]]
One or more functions that will be called with the input config prior
to the wrapped functions. An iterable of pre-call functions are called
to the wrapped function. An iterable of pre-call functions are called
from left (low-index) to right (high-index).
exclude: Optional[str | Iterable[str]]
This is useful, e.g., for seeding a RNG prior to the instantiation phase
that is triggered when calling the wrapped function.
resolve_pre_call : bool, (default=True)
If `True`, the config passed to the zen-wrapped function has its
interpolated fields resolved prior to being passed to any pre-call
functions. Otherwise, the interpolation occurs after the pre-call functions
are called.
exclude : Optional[str | Iterable[str]]
Specifies one or more parameter names in the function's signature
that will not be extracted from input configs by the zen-wrapped function.
A single string of comma-separated names can be specified.
run_in_context : bool, optional (default=False)
If `True`, the zen-wrapped function - and the `pre_call` function, if
specified - is run in a copied :py:class:`contextvars.Context`; i.e.
changes made to any :py:class:`contextvars.ContextVar` will be isolated to
the context of the function.
`run_in_context` is not supported for async functions.
"""
self.func: Callable[P, R] = __func
if run_in_context and iscoroutinefunction(__func):
raise TypeError(f"`{run_in_context=} is not supported for async functions.")

self.func: Callable[P, R] = __func
self._run_in_context: bool = run_in_context
try:
# Must cast to dict so that `self` is pickle-compatible.
self.parameters: Mapping[str, Parameter] = dict(
@@ -353,8 +374,15 @@ def __call__(self, __cfg: Union[ConfigLike, str]) -> R:
# resolves all interpolated values in-place
OmegaConf.resolve(cfg)

context = copy_context() if self._run_in_context else None

if self.pre_call is not None:
self.pre_call(cfg)
pre_call = (
self.pre_call
if context is None
else partial(context.run, self.pre_call)
)
pre_call(cfg)

args_ = list(getattr(cfg, "_args_", []))

@@ -377,7 +405,10 @@ def __call__(self, __cfg: Union[ConfigLike, str]) -> R:
and isinstance(name, str)
)
cfg_kwargs.update({name: cfg[name] for name in names})
return self.func(

func = self.func if context is None else partial(context.run, self.func)

return func(
*(self.instantiate(x) if is_instantiable(x) else x for x in args_),
**{
name: self.instantiate(val) if is_instantiable(val) else val
@@ -472,6 +503,7 @@ def zen(
pre_call: PreCall = ...,
ZenWrapper: Type[Zen[Any, Any]] = ...,
resolve_pre_call: bool = ...,
run_in_context: bool = ...,
exclude: Optional[Union[str, Iterable[str]]] = ...,
) -> Zen[P, R]:
...
@@ -485,6 +517,7 @@ def zen(
pre_call: PreCall = ...,
resolve_pre_call: bool = ...,
ZenWrapper: Type[Zen[Any, Any]] = ...,
run_in_context: bool = ...,
exclude: Optional[Union[str, Iterable[str]]] = ...,
) -> Callable[[Callable[P2, R2]], Zen[P2, R2]]:
...
@@ -497,6 +530,7 @@ def zen(
pre_call: PreCall = None,
exclude: Optional[Union[str, Iterable[str]]] = None,
resolve_pre_call: bool = True,
run_in_context: bool = False,
ZenWrapper: Type[Zen[Any, Any]] = Zen,
) -> Union[Callable[[Callable[P2, R2]], Zen[P2, R2]], Zen[P, R]]:
r"""zen(func, /, pre_call, ZenWrapper)
@@ -545,6 +579,14 @@ def zen(
A single string of comma-separated names can be specified.
run_in_context : bool, optional (default=False)
If `True`, the zen-wrapped function - and the `pre_call` function, if
specified - is run in a copied :py:class:`contextvars.Context`; i.e.
changes made to any :py:class:`contextvars.ContextVar` will be isolated to
the context of the call to the wrapped function.
`run_in_context` is not supported for async functions.
ZenWrapper : Type[hydra_zen.wrapper.Zen], optional (default=Zen)
If specified, a subclass of `Zen` that customizes the behavior of the wrapper.
@@ -746,6 +788,7 @@ def task(x: int, y: int):
exclude=exclude,
unpack_kwargs=unpack_kwargs,
resolve_pre_call=resolve_pre_call,
run_in_context=run_in_context,
),
)

@@ -758,6 +801,7 @@ def wrap(f: Callable[P2, R2]) -> Zen[P2, R2]:
exclude=exclude,
unpack_kwargs=unpack_kwargs,
resolve_pre_call=resolve_pre_call,
run_in_context=run_in_context,
),
)
return out
73 changes: 73 additions & 0 deletions tests/test_zen_context_isolation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright (c) 2023 Massachusetts Institute of Technology
# SPDX-License-Identifier: MIT

import random
from contextvars import ContextVar
from typing import Any, Dict, Optional

import pytest

from hydra_zen import zen

config: ContextVar[Optional[Dict[str, Any]]] = ContextVar("config", default=None)
var: ContextVar[Dict[str, Any]] = ContextVar("var", default=dict())


@pytest.fixture(autouse=True)
def clean_context_vars():
assert config.get() is None
assert var.get() == {}
yield
config.set(None)
var.set({})


@pytest.mark.parametrize(
"run_in_context",
[
True,
pytest.param(False, marks=pytest.mark.xfail),
],
)
def test_context_isolation(run_in_context: bool):
def foo(x: str, zen_cfg):
config.set(zen_cfg)
conf = var.get().copy()
conf[str(random.randint(1, 100))] = random.randint(1, 100)
var.set(conf)
assert len(conf) == 1

zfoo = zen(foo, run_in_context=run_in_context)

for letter in "ab":
zfoo(dict(x=letter))
assert config.get() is None
assert var.get() == dict()


async def async_func_run_in_context_not_supported():
async def foo():
...

with pytest.raises(TypeError, match="not supported"):
zen(foo, run_in_context=True)


@pytest.mark.parametrize(
"run_in_context",
[
True,
pytest.param(False, marks=pytest.mark.xfail),
],
)
def test_pre_call_shares_context_with_wrapped_func(run_in_context: bool):
assert var.get() == {}

def pre_call(cfg):
var.set({"swagger": 22})

def func():
assert var.get() == {"swagger": 22}

zen(func, pre_call=pre_call, run_in_context=run_in_context)({})
assert var.get() == {}

0 comments on commit b18aa86

Please sign in to comment.