Skip to content

Commit

Permalink
improve validation
Browse files Browse the repository at this point in the history
  • Loading branch information
rsokl committed Nov 12, 2023
1 parent 70ca0d8 commit 8f88b40
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 2 deletions.
16 changes: 14 additions & 2 deletions src/hydra_zen/wrapper/_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,6 @@ def __init__(
raise TypeError(f"`{run_in_context=} is not supported for async functions.")

self.func: Callable[P, R] = __func
self._run_in_context: bool = run_in_context
try:
# Must cast to dict so that `self` is pickle-compatible.
self.parameters: Mapping[str, Parameter] = dict(
Expand All @@ -204,11 +203,18 @@ def __init__(
raise TypeError(
f"`resolve_pre_call` must be type `bool` got {resolve_pre_call}"
)

if not isinstance(run_in_context, bool): # pragma: no cover
raise TypeError(
f"`run_in_context` must be type `bool` got {run_in_context}"
)

self._resolve = resolve_pre_call
self._unpack_kwargs: bool = unpack_kwargs and any(
p.kind is p.VAR_KEYWORD for p in self.parameters.values()
)

self._run_in_context: bool = run_in_context
self._exclude: Set[str]

if exclude is None:
Expand Down Expand Up @@ -237,7 +243,13 @@ def __init__(
if _f is None:
continue

_f_params = signature(_f).parameters
if run_in_context and isinstance(_f, Zen) and _f._run_in_context:
raise HydraZenValidationError(
f"zen-wrapped pre_call function {_f!r} cannot specify "
f"`run_in_context=True` when the main wrapper specifies it as well."
)

_f_params = signature(_f).parameters # type: ignore

if (sum(p.default is p.empty for p in _f_params.values()) > 1) or len(
_f_params
Expand Down
14 changes: 14 additions & 0 deletions tests/test_zen_context_isolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pytest

from hydra_zen import zen
from hydra_zen.errors import HydraZenValidationError

config: ContextVar[Optional[Dict[str, Any]]] = ContextVar("config", default=None)
var: ContextVar[Dict[str, Any]] = ContextVar("var", default=dict())
Expand Down Expand Up @@ -71,3 +72,16 @@ def func():

zen(func, pre_call=pre_call, run_in_context=run_in_context)({})
assert var.get() == {}


def test_pre_call_run_in_its_own_context_is_forbidden():
def f(x):
...

with pytest.raises(HydraZenValidationError):
zen(f, pre_call=zen(f, run_in_context=True), run_in_context=True)


def test_validation():
with pytest.raises(TypeError, match="must be type"):
zen(lambda x: x, run_in_context=None) # type: ignore

0 comments on commit 8f88b40

Please sign in to comment.