Skip to content

Commit

Permalink
Merge pull request #587 from mit-ll-responsible-ai/pydantic
Browse files Browse the repository at this point in the history
Add rudimentary support for pydantic.BaseModel and for pydantic 2.0
  • Loading branch information
rsokl authored Nov 14, 2023
2 parents 03e4508 + bce0a7a commit e31c388
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 31 deletions.
15 changes: 15 additions & 0 deletions .github/workflows/tox_run.yml
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,21 @@ jobs:
- name: Test with tox
run: tox -e pre-release

smoke-test-pydantic-v2:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python 3.10
uses: actions/setup-python@v4
with:
python-version: "3.10"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install tox tox-gh-actions
- name: Test with tox
run: tox -e pydantic-v2p0-smoketest

check-repo-format:
runs-on: ubuntu-latest
steps:
Expand Down
8 changes: 8 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,14 @@ deps = {[testenv]deps}
pydantic<2.0.0
beartype
[testenv:pydantic-v2p0-smoketest]
description = Ensures that importing pydantic 2.0 doesn't break things
install_command = pip install --upgrade --upgrade-strategy eager {opts} {packages}
basepython = python3.10
deps = {[testenv]deps}
pydantic>=2.0.0
[testenv:pyright]
description = Ensure that hydra-zen's source code and test suite scan clean
under pyright, and that hydra-zen's public API has a 100 prcnt
Expand Down
70 changes: 45 additions & 25 deletions src/hydra_zen/structured_configs/_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1094,27 +1094,33 @@ def _make_hydra_compatible(

# pydantic objects
pydantic = sys.modules.get("pydantic")
if pydantic is not None and isinstance(value, pydantic.fields.FieldInfo):
_val = (
value.default_factory() # type: ignore
if value.default_factory is not None # type: ignore
else value.default # type: ignore
)
if isinstance(_val, pydantic.fields.UndefinedType):
return MISSING

return cls._make_hydra_compatible(
_val,
allow_zen_conversion=allow_zen_conversion,
error_prefix=error_prefix,
field_name=field_name,
structured_conf_permitted=structured_conf_permitted,
convert_dataclass=convert_dataclass,
hydra_convert=hydra_convert,
hydra_recursive=hydra_recursive,
)

if isinstance(value, str):
if pydantic is not None: # pragma: no cover
if isinstance(value, pydantic.fields.FieldInfo):
_val = (
value.default_factory() # type: ignore
if value.default_factory is not None # type: ignore
else value.default # type: ignore
)
if isinstance(_val, pydantic.fields.UndefinedType):
return MISSING

return cls._make_hydra_compatible(
_val,
allow_zen_conversion=allow_zen_conversion,
error_prefix=error_prefix,
field_name=field_name,
structured_conf_permitted=structured_conf_permitted,
convert_dataclass=convert_dataclass,
hydra_convert=hydra_convert,
hydra_recursive=hydra_recursive,
)
if isinstance(value, pydantic.BaseModel):
return cls.builds(type(value), **value.__dict__)

if isinstance(value, str) or (
pydantic is not None and isinstance(value, pydantic.AnyUrl)
):
# Supports pydantic.AnyURL
_v = str(value)
if type(_v) is str: # pragma: no branch
Expand Down Expand Up @@ -2363,7 +2369,7 @@ def builds(self,target, populate_full_signature=False, **kw):
)

_sig_target = cls._get_sig_obj(target)

pydantic = sys.modules.get("pydantic")
try:
# We want to rely on `inspect.signature` logic for raising
# against an uninspectable sig, before we start inspecting
Expand Down Expand Up @@ -2395,7 +2401,14 @@ def builds(self,target, populate_full_signature=False, **kw):
# has inherited from a parent that implements __new__ and
# the target implements only __init__.

if _sig_target is not target:
if pydantic is not None and (
_sig_target is pydantic.BaseModel.__init__
# pydantic v2.0
or is_dataclass(target)
and hasattr(target, "__pydantic_config__")
):
pass
elif _sig_target is not target:
_params = tuple(inspect.signature(_sig_target).parameters.items())

if (
Expand All @@ -2414,19 +2427,26 @@ def builds(self,target, populate_full_signature=False, **kw):

target_has_valid_signature: bool = True

if is_dataclass(target):
if is_dataclass(target) or (
pydantic is not None
and isinstance(target, type)
and issubclass(target, pydantic.BaseModel)
):
# Update `signature_params` so that any param with `default=<factory>`
# has its default replaced with `<factory>()`
# If this is a mutable value, `builds` will automatically re-pack
# it using a default factory
_fields = {f.name: f for f in fields(target)}
if is_dataclass(target):
_fields = {f.name: f for f in fields(target)}
else:
_fields = target.__fields__ # type: ignore
_update = {}
for name, param in signature_params.items():
if name not in _fields:
# field is InitVar
continue
f = _fields[name]
if f.default_factory is not MISSING:
if f.default_factory is not MISSING and f.default_factory is not None:
_update[name] = inspect.Parameter(
name,
param.kind,
Expand Down
7 changes: 7 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) 2023 Massachusetts Institute of Technology
# SPDX-License-Identifier: MIT
import importlib
import logging
import os
import sys
Expand Down Expand Up @@ -37,6 +38,12 @@
for _module_name in OPTIONAL_TEST_DEPENDENCIES:
if _module_name not in _installed:
collect_ignore_glob.append(f"*{_module_name}*.py")
else:
# Some of hydra-zen's logic for supporting 3rd party libraries
# depends on that library being imported. We want to ensure that
# we import these when they are available so that the full test
# suite runs against these paths being enabled
importlib.import_module(_module_name)

if sys.version_info > (3, 6):
collect_ignore_glob.append("*py36*")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,21 @@
from typing import Any, List, Optional

import hypothesis.strategies as st
import pydantic
import pytest
from hydra.errors import InstantiationException
from hypothesis import given, settings
from omegaconf import OmegaConf
from pydantic import AnyUrl, Field, PositiveFloat
from pydantic import AnyUrl, BaseModel, Field, PositiveFloat
from pydantic.dataclasses import dataclass as pyd_dataclass
from typing_extensions import Literal

from hydra_zen import builds, instantiate, just, to_yaml
from hydra_zen import builds, get_target, instantiate, just, to_yaml
from hydra_zen.third_party.pydantic import validates_with_pydantic

if pydantic.__version__.startswith("2."):
pytest.skip("These tests are for pydantic v1", allow_module_level=True)

parametrize_pydantic_fields = pytest.mark.parametrize(
"custom_type, good_val, bad_val",
[
Expand Down Expand Up @@ -86,12 +90,19 @@ class PydanticConf:
y: int = 2


class BaseModelConf(BaseModel):
x: Literal[1, 2]
y: int = 2


@pytest.mark.parametrize("Target", [PydanticConf, BaseModelConf])
@pytest.mark.parametrize("x", [1, 2])
def test_documented_example_passes(x):
HydraConf = builds(PydanticConf, populate_full_signature=True)
def test_documented_example_passes(Target, x):
HydraConf = builds(Target, populate_full_signature=True)
conf = instantiate(HydraConf, x=x)
assert isinstance(conf, PydanticConf)
assert conf == PydanticConf(x=x, y=2)
assert isinstance(conf, Target)
assert conf == Target(x=x, y=2)
assert get_target(HydraConf) is Target


@settings(max_examples=20)
Expand Down Expand Up @@ -151,13 +162,25 @@ class HasDefaultFactory:
x: Any = Field(default_factory=lambda: [1 + 2j])


class BaseModelHasDefault(BaseModel):
x: int = Field(default=1)


class BaseModelHasDefaultFactory(BaseModel):
x: Any = Field(default_factory=lambda: [1 + 2j])


@pytest.mark.parametrize(
"target,kwargs",
[
(HasDefault, {}),
(BaseModelHasDefault, {}),
(HasDefaultFactory, {}),
(BaseModelHasDefaultFactory, {}),
(HasDefault, {"x": 12}),
(BaseModelHasDefault, {"x": 12}),
(HasDefaultFactory, {"x": [[-2j, 1 + 1j]]}),
(BaseModelHasDefaultFactory, {"x": [[-2j, 1 + 1j]]}),
],
)
def test_pop_sig_with_pydantic_Field(target, kwargs):
Expand Down Expand Up @@ -186,6 +209,25 @@ def test_nested_dataclasses(via_yaml: bool):
assert instantiate(conf) == navbar


class ModelNavbarButton(BaseModel):
href: AnyUrl


class ModelNavbar(BaseModel):
button: ModelNavbarButton


@given(...)
def test_nested_base_models(via_yaml: bool):
navbar = ModelNavbar(button=ModelNavbarButton(href="https://example.com")) # type: ignore
conf = just(navbar)
if via_yaml:
# ensure serializable
assert instantiate(OmegaConf.create(to_yaml(conf))) == navbar
else:
assert instantiate(conf) == navbar


@pyd_dataclass
class PydFieldNoDefault:
btwn_0_and_3: int = Field(gt=0, lt=3)
Expand Down

0 comments on commit e31c388

Please sign in to comment.