Skip to content

Commit

Permalink
Merge pull request #595 from mit-ll-responsible-ai/index-zen-exclude
Browse files Browse the repository at this point in the history
Support parameter-indices in zen-exclude
  • Loading branch information
rsokl authored Nov 24, 2023
2 parents 20a036e + 537956b commit 83662a1
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 16 deletions.
2 changes: 1 addition & 1 deletion docs/source/changes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ For more details and examples, see :pull:`553`.
Improvements
------------
- :class:`~hydra_zen.BuildsFn` was introduced to permit customizable auto-config and type-refinement support in config-creation functions. See :pull:`553`.
- :func:`~hydra_zen.builds` and :func:`~hydra_zen.make_custom_builds_fn` now accept a `zen_exclude` field for excluding parameters from auto-population, either by name or by pattern. See :pull:`558`.
- :func:`~hydra_zen.builds` and :func:`~hydra_zen.make_custom_builds_fn` now accept a `zen_exclude` field for excluding parameters from auto-population, either by name, position-index, or by pattern. See :pull:`558`.
- :func:`~hydra_zen.builds` and :func:`~hydra_zen.just` can now configure static methods. Previously the incorrect ``_target_`` would be resolved. See :pull:`566`
- :func:`hydra_zen.zen` now has first class support for running code in an isolated :py:class:`contextvars.Context`. This enables users to safely leverage state via :py:class:`contextvars.ContextVar` in their task functions. See :pull:`583`.
- Adds formal support for Python 3.12. See :pull:`555`
Expand Down
49 changes: 36 additions & 13 deletions src/hydra_zen/structured_configs/_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1556,9 +1556,9 @@ def builds(
This option is not available for objects with inaccessible signatures, such
as NumPy's various ufuncs.
zen_exclude : Collection[str] | Callable[[str], bool], optional (default=[])
Specifies parameter names, or a function for checking names, to exclude
those parameters from the config-creation process.
zen_exclude : Collection[str | int] | Callable[[str], bool], optional (default=[])
Specifies parameter names and/or indices, or a function for checking names,
to exclude those parameters from the config-creation process.
Note that inherited fields cannot be excluded.
zen_convert : Optional[ZenConvert]
Expand Down Expand Up @@ -1815,15 +1815,22 @@ def builds(self,target, populate_full_signature=False, **kw):
x: ???
'y': foo
`zen_exclude` can be used to either name parameter to be excluded from the
`zen_exclude` can be used to name parameter to be excluded from the
auto-population process:
>>> Conf2 = builds(bar, populate_full_signature=True, zen_exclude=["y"])
>>> pyaml(Conf2)
_target_: __main__.bar
x: ???
or specify a pattern - via a function - for excluding parameters:
to exclude parameters by index:
>>> Conf2 = builds(bar, populate_full_signature=True, zen_exclude=[-1])
>>> pyaml(Conf2)
_target_: __main__.bar
x: ???
or to specify a pattern - via a function - for excluding parameters:
>>> Conf3 = builds(bar, populate_full_signature=True,
... zen_exclude=lambda name: name.startswith("x"))
Expand Down Expand Up @@ -2027,20 +2034,32 @@ def builds(self,target, populate_full_signature=False, **kw):

del pos_args

zen_exclude: Callable[[str], bool] = kwargs_for_target.pop(
"zen_exclude", frozenset()
)
zen_exclude: Union[
Callable[[str], bool], Collection[Union[str, int]]
] = kwargs_for_target.pop("zen_exclude", frozenset())
zen_index_exclude: set[int] = set()

if (
not isinstance(zen_exclude, Collection) or isinstance(zen_exclude, str)
) and not callable(zen_exclude):
raise TypeError(
f"`zen_exclude` must be a non-string collection of strings, or "
f"callable[[str], bool]. Got {zen_exclude}"
f"`zen_exclude` must be a non-string collection of strings and/or ints"
f" or callable[[str], bool]. Got {zen_exclude}"
)

if isinstance(zen_exclude, Collection):
zen_exclude = set(zen_exclude).__contains__
_strings = []
for item in zen_exclude:
if isinstance(item, int):
zen_index_exclude.add(item)
elif isinstance(item, str):
_strings.append(item)
else:
raise TypeError(
f"`zen_exclude` must only contain ints or "
f"strings. Got {zen_exclude}"
)
zen_exclude = frozenset(_strings).__contains__

if not callable(target):
raise TypeError(
Expand All @@ -2050,7 +2069,8 @@ def builds(self,target, populate_full_signature=False, **kw):

if not isinstance(populate_full_signature, bool):
raise TypeError(
f"`populate_full_signature` must be a boolean type, got: {populate_full_signature}"
f"`populate_full_signature` must be a boolean type, got: "
f"{populate_full_signature}"
)

if zen_partial is not None and not isinstance(zen_partial, bool):
Expand Down Expand Up @@ -2621,6 +2641,9 @@ def builds(self,target, populate_full_signature=False, **kw):
if not zen_exclude(name)
}

# support negative indices
zen_index_exclude = {ind % len(signature_params) for ind in zen_index_exclude}

if populate_full_signature is True:
# Populate dataclass fields based on the target's signature.
#
Expand All @@ -2639,7 +2662,7 @@ def builds(self,target, populate_full_signature=False, **kw):
_seen: Set[str] = set()

for n, param in enumerate(signature_params.values()):
if zen_exclude(param.name):
if n in zen_index_exclude or zen_exclude(param.name):
continue

if n + 1 <= len(_pos_args):
Expand Down
13 changes: 11 additions & 2 deletions tests/test_zen_exclude.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from hydra_zen import builds, instantiate, make_custom_builds_fn


@pytest.mark.parametrize("bad_exclude", [1, "x"])
@pytest.mark.parametrize("bad_exclude", ["x", [["x"]]])
def test_validate_exclude(bad_exclude):
with pytest.raises(TypeError):
builds(dict, zen_exclude=bad_exclude)
Expand All @@ -17,7 +17,16 @@ def foo(x=1, _y=2, _z=3):

@pytest.mark.parametrize("partial", [True, False])
@pytest.mark.parametrize("custom_builds", [True, False])
@pytest.mark.parametrize("exclude", [["_y", "_z"], lambda x: x.startswith("_")])
@pytest.mark.parametrize(
"exclude",
[
["_y", "_z"],
lambda x: x.startswith("_"),
[1, 2],
["_y", -1],
["_y", 2],
],
)
def test_exclude_named(partial: bool, custom_builds: bool, exclude):
if custom_builds:
b = make_custom_builds_fn(
Expand Down

0 comments on commit 83662a1

Please sign in to comment.