diff --git a/src/hydra_zen/structured_configs/_implementations.py b/src/hydra_zen/structured_configs/_implementations.py index 9217ad619..0d98e6f47 100644 --- a/src/hydra_zen/structured_configs/_implementations.py +++ b/src/hydra_zen/structured_configs/_implementations.py @@ -560,6 +560,10 @@ def _check_instance(*target_types: str, value: "Any", module: str): # pragma: n _check_instance, "required", module="torch.optim.optimizer" ) +_is_pydantic_BaseModel = functools.partial( + _check_instance, "BaseModel", module="pydantic" +) + def _check_for_dynamically_defined_dataclass_type(target_path: str, value: Any) -> None: if target_path.startswith("types."): @@ -1096,13 +1100,16 @@ def _make_hydra_compatible( pydantic = sys.modules.get("pydantic") if pydantic is not None: # pragma: no cover - if isinstance(value, pydantic.fields.FieldInfo): + if _check_instance("FieldInfo", module="pydantic.fields", value=value): _val = ( value.default_factory() # type: ignore if value.default_factory is not None # type: ignore else value.default # type: ignore ) - if isinstance(_val, pydantic.fields.UndefinedType): + + if _check_instance( + "UndefinedType", module="pydantic.fields", value=_val + ): return MISSING return cls._make_hydra_compatible( @@ -1115,11 +1122,11 @@ def _make_hydra_compatible( hydra_convert=hydra_convert, hydra_recursive=hydra_recursive, ) - if isinstance(value, pydantic.BaseModel): + if _is_pydantic_BaseModel(value=value): return cls.builds(type(value), **value.__dict__) - if isinstance(value, str) or ( - pydantic is not None and isinstance(value, pydantic.AnyUrl) + if isinstance(value, str) or _check_instance( + "AnyUrl", module="pydantic", value=value ): # Supports pydantic.AnyURL _v = str(value) diff --git a/src/hydra_zen/wrapper/_implementations.py b/src/hydra_zen/wrapper/_implementations.py index 63fa6d1fe..e94069262 100644 --- a/src/hydra_zen/wrapper/_implementations.py +++ b/src/hydra_zen/wrapper/_implementations.py @@ -9,6 +9,7 @@ from functools import partial, wraps from inspect import Parameter, iscoroutinefunction, signature from typing import ( + TYPE_CHECKING, Any, Callable, DefaultDict, @@ -48,7 +49,7 @@ from hydra_zen import instantiate from hydra_zen._compatibility import HYDRA_VERSION, Version from hydra_zen.errors import HydraZenValidationError -from hydra_zen.structured_configs._implementations import BuildsFn, DefaultBuilds +from hydra_zen.structured_configs._implementations import DefaultBuilds from hydra_zen.structured_configs._type_guards import safe_getattr from hydra_zen.typing._implementations import ( DataClass_, @@ -61,9 +62,12 @@ from ..structured_configs._type_guards import is_dataclass from ..structured_configs._utils import safe_name +if TYPE_CHECKING: + from hydra_zen import BuildsFn + + __all__ = ["zen", "store", "Zen"] -get_obj_path = BuildsFn._get_obj_path # type: ignore R = TypeVar("R") P = ParamSpec("P") @@ -830,7 +834,7 @@ def default_to_config( ListConfig, DictConfig, ], - BuildsFn: Type[BuildsFn[Any]] = DefaultBuilds, + CustomBuildsFn: Type["BuildsFn[Any]"] = DefaultBuilds, **kw: Any, ) -> Union[DataClass_, Type[DataClass_], ListConfig, DictConfig]: """Creates a config that describes `target`. @@ -849,7 +853,7 @@ def default_to_config( ---------- target : Callable[..., Any] | DataClass | Type[DataClass] | list | dict - BuildsFn : Type[BuildsFn[Any]], optional (default=DefaultBuilds) + CustomBuildsFn : Type[BuildsFn[Any]], optional (default=DefaultBuilds) Provides the config-creation functions (`builds`, `just`) used by this function. @@ -893,21 +897,20 @@ def default_to_config( 'y': ??? """ + kw = kw.copy() + if is_dataclass(target): if isinstance(target, type): if issubclass(target, HydraConf): # don't auto-config HydraConf return target - if not kw and get_obj_path(target).startswith("types."): + if not kw and CustomBuildsFn._get_obj_path(target).startswith("types."): # type: ignore # handles dataclasses returned by make_config() return target - return BuildsFn.builds( - target, - **kw, - populate_full_signature=True, - builds_bases=(target,), - ) + kw.setdefault("populate_full_signature", True) + kw.setdefault("builds_bases", (target,)) + return CustomBuildsFn.builds(target, **kw) if kw: raise ValueError( "store(, [...]) does not support specifying " @@ -917,14 +920,13 @@ def default_to_config( elif isinstance(target, (dict, list)): # TODO: convert to OmegaConf containers? - return BuildsFn.just(target) + return CustomBuildsFn.just(target) elif isinstance(target, (DictConfig, ListConfig)): return target else: t = cast(Callable[..., Any], target) - return cast( - Type[DataClass_], BuildsFn.builds(t, **kw, populate_full_signature=True) - ) + kw.setdefault("populate_full_signature", True) + return cast(Type[DataClass_], CustomBuildsFn.builds(t, **kw)) class _HasName(Protocol): diff --git a/tests/test_BuildsFn.py b/tests/test_BuildsFn.py index d6e773436..c418dcbb7 100644 --- a/tests/test_BuildsFn.py +++ b/tests/test_BuildsFn.py @@ -167,7 +167,7 @@ def test_zen_field(): def test_default_to_config(): store = ZenStore("my store")( - to_config=partial(default_to_config, BuildsFn=MyBuildsFn) + to_config=partial(default_to_config, CustomBuildsFn=MyBuildsFn) ) store(A, x=A(x=2), name="blah") assert instantiate(store[None, "blah"]) == A(x=A(x=2)) diff --git a/tests/test_store.py b/tests/test_store.py index 995428744..7ddebd4ee 100644 --- a/tests/test_store.py +++ b/tests/test_store.py @@ -24,6 +24,7 @@ just, make_config, store as default_store, + to_yaml, ) from tests.custom_strategies import new_stores, store_entries @@ -982,3 +983,19 @@ def test_merge(): assert s3._queue == {(None, "a"), (None, "b"), (None, "c")} assert s3._internal_repo[None, "a"] is not s1._internal_repo[None, "a"] assert s3._internal_repo[None, "b"] is not s2._internal_repo[None, "b"] + + +def test_disable_pop_sig_autoconfig(): + s = ZenStore() + s(ZenStore, populate_full_signature=False, name="s") + config = s[None, "s"] + assert len(to_yaml(config).splitlines()) == 1 + sout = instantiate(s[None, "s"]) + assert isinstance(sout, ZenStore) + + s2 = ZenStore() + s2(ZenStore, name="s") + config2 = s2[None, "s"] + assert len(to_yaml(config2).splitlines()) > 1 + sout2 = instantiate(s2[None, "s"]) + assert isinstance(sout2, ZenStore) diff --git a/tests/test_third_party/test_using_v1_pydantic.py b/tests/test_third_party/test_using_v1_pydantic.py index ed0a10d59..e28070425 100644 --- a/tests/test_third_party/test_using_v1_pydantic.py +++ b/tests/test_third_party/test_using_v1_pydantic.py @@ -1,6 +1,7 @@ # Copyright (c) 2023 Massachusetts Institute of Technology # SPDX-License-Identifier: MIT import dataclasses +import sys from typing import Any, List, Optional import hypothesis.strategies as st @@ -28,6 +29,12 @@ ) +def test_BaseModel(): + _pydantic = sys.modules.get("pydantic") + assert _pydantic is not None + assert _pydantic.BaseModel is BaseModel + + @parametrize_pydantic_fields def test_pydantic_specific_fields_function(custom_type, good_val, bad_val): def f(x):