Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix store autoconfig issue #588

Merged
merged 7 commits into from
Nov 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions src/hydra_zen/structured_configs/_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,10 @@ def _check_instance(*target_types: str, value: "Any", module: str): # pragma: n
_check_instance, "required", module="torch.optim.optimizer"
)

_is_pydantic_BaseModel = functools.partial(
_check_instance, "BaseModel", module="pydantic"
)


def _check_for_dynamically_defined_dataclass_type(target_path: str, value: Any) -> None:
if target_path.startswith("types."):
Expand Down Expand Up @@ -1096,13 +1100,16 @@ def _make_hydra_compatible(
pydantic = sys.modules.get("pydantic")

if pydantic is not None: # pragma: no cover
if isinstance(value, pydantic.fields.FieldInfo):
if _check_instance("FieldInfo", module="pydantic.fields", value=value):
_val = (
value.default_factory() # type: ignore
if value.default_factory is not None # type: ignore
else value.default # type: ignore
)
if isinstance(_val, pydantic.fields.UndefinedType):

if _check_instance(
"UndefinedType", module="pydantic.fields", value=_val
):
return MISSING

return cls._make_hydra_compatible(
Expand All @@ -1115,11 +1122,11 @@ def _make_hydra_compatible(
hydra_convert=hydra_convert,
hydra_recursive=hydra_recursive,
)
if isinstance(value, pydantic.BaseModel):
if _is_pydantic_BaseModel(value=value):
return cls.builds(type(value), **value.__dict__)

if isinstance(value, str) or (
pydantic is not None and isinstance(value, pydantic.AnyUrl)
if isinstance(value, str) or _check_instance(
"AnyUrl", module="pydantic", value=value
):
# Supports pydantic.AnyURL
_v = str(value)
Expand Down
32 changes: 17 additions & 15 deletions src/hydra_zen/wrapper/_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from functools import partial, wraps
from inspect import Parameter, iscoroutinefunction, signature
from typing import (
TYPE_CHECKING,
Any,
Callable,
DefaultDict,
Expand Down Expand Up @@ -48,7 +49,7 @@
from hydra_zen import instantiate
from hydra_zen._compatibility import HYDRA_VERSION, Version
from hydra_zen.errors import HydraZenValidationError
from hydra_zen.structured_configs._implementations import BuildsFn, DefaultBuilds
from hydra_zen.structured_configs._implementations import DefaultBuilds
from hydra_zen.structured_configs._type_guards import safe_getattr
from hydra_zen.typing._implementations import (
DataClass_,
Expand All @@ -61,9 +62,12 @@
from ..structured_configs._type_guards import is_dataclass
from ..structured_configs._utils import safe_name

if TYPE_CHECKING:
from hydra_zen import BuildsFn


__all__ = ["zen", "store", "Zen"]

get_obj_path = BuildsFn._get_obj_path # type: ignore

R = TypeVar("R")
P = ParamSpec("P")
Expand Down Expand Up @@ -830,7 +834,7 @@ def default_to_config(
ListConfig,
DictConfig,
],
BuildsFn: Type[BuildsFn[Any]] = DefaultBuilds,
CustomBuildsFn: Type["BuildsFn[Any]"] = DefaultBuilds,
**kw: Any,
) -> Union[DataClass_, Type[DataClass_], ListConfig, DictConfig]:
"""Creates a config that describes `target`.
Expand All @@ -849,7 +853,7 @@ def default_to_config(
----------
target : Callable[..., Any] | DataClass | Type[DataClass] | list | dict

BuildsFn : Type[BuildsFn[Any]], optional (default=DefaultBuilds)
CustomBuildsFn : Type[BuildsFn[Any]], optional (default=DefaultBuilds)
Provides the config-creation functions (`builds`, `just`) used
by this function.

Expand Down Expand Up @@ -893,21 +897,20 @@ def default_to_config(
'y': ???
"""

kw = kw.copy()

if is_dataclass(target):
if isinstance(target, type):
if issubclass(target, HydraConf):
# don't auto-config HydraConf
return target

if not kw and get_obj_path(target).startswith("types."):
if not kw and CustomBuildsFn._get_obj_path(target).startswith("types."): # type: ignore
# handles dataclasses returned by make_config()
return target
return BuildsFn.builds(
target,
**kw,
populate_full_signature=True,
builds_bases=(target,),
)
kw.setdefault("populate_full_signature", True)
kw.setdefault("builds_bases", (target,))
return CustomBuildsFn.builds(target, **kw)
if kw:
raise ValueError(
"store(<dataclass-instance>, [...]) does not support specifying "
Expand All @@ -917,14 +920,13 @@ def default_to_config(

elif isinstance(target, (dict, list)):
# TODO: convert to OmegaConf containers?
return BuildsFn.just(target)
return CustomBuildsFn.just(target)
elif isinstance(target, (DictConfig, ListConfig)):
return target
else:
t = cast(Callable[..., Any], target)
return cast(
Type[DataClass_], BuildsFn.builds(t, **kw, populate_full_signature=True)
)
kw.setdefault("populate_full_signature", True)
return cast(Type[DataClass_], CustomBuildsFn.builds(t, **kw))


class _HasName(Protocol):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_BuildsFn.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def test_zen_field():

def test_default_to_config():
store = ZenStore("my store")(
to_config=partial(default_to_config, BuildsFn=MyBuildsFn)
to_config=partial(default_to_config, CustomBuildsFn=MyBuildsFn)
)
store(A, x=A(x=2), name="blah")
assert instantiate(store[None, "blah"]) == A(x=A(x=2))
17 changes: 17 additions & 0 deletions tests/test_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
just,
make_config,
store as default_store,
to_yaml,
)
from tests.custom_strategies import new_stores, store_entries

Expand Down Expand Up @@ -982,3 +983,19 @@ def test_merge():
assert s3._queue == {(None, "a"), (None, "b"), (None, "c")}
assert s3._internal_repo[None, "a"] is not s1._internal_repo[None, "a"]
assert s3._internal_repo[None, "b"] is not s2._internal_repo[None, "b"]


def test_disable_pop_sig_autoconfig():
s = ZenStore()
s(ZenStore, populate_full_signature=False, name="s")
config = s[None, "s"]
assert len(to_yaml(config).splitlines()) == 1
sout = instantiate(s[None, "s"])
assert isinstance(sout, ZenStore)

s2 = ZenStore()
s2(ZenStore, name="s")
config2 = s2[None, "s"]
assert len(to_yaml(config2).splitlines()) > 1
sout2 = instantiate(s2[None, "s"])
assert isinstance(sout2, ZenStore)
7 changes: 7 additions & 0 deletions tests/test_third_party/test_using_v1_pydantic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) 2023 Massachusetts Institute of Technology
# SPDX-License-Identifier: MIT
import dataclasses
import sys
from typing import Any, List, Optional

import hypothesis.strategies as st
Expand Down Expand Up @@ -28,6 +29,12 @@
)


def test_BaseModel():
_pydantic = sys.modules.get("pydantic")
assert _pydantic is not None
assert _pydantic.BaseModel is BaseModel


@parametrize_pydantic_fields
def test_pydantic_specific_fields_function(custom_type, good_val, bad_val):
def f(x):
Expand Down