diff --git a/docs/concepts/function-modifiers.rst b/docs/concepts/function-modifiers.rst
index e74bbcbe1..c3b38c502 100644
--- a/docs/concepts/function-modifiers.rst
+++ b/docs/concepts/function-modifiers.rst
@@ -176,6 +176,12 @@ pandera support
Hamilton has a pandera plugin for data validation that you can install with ``pip install sf-hamilton[pandera]``. Then, you can pass a pandera schema (for DataFrame or Series) to ``@check_output(schema=...)``.
+pydantic support
+~~~~~~~~~~~~~~~~
+
+Hamilton also supports data validation of pydantic models, which you can enable with ``pip install sf-hamilton[pydantic]``. With pydantic installed, you can pass any subclass of ``pydantic.BaseModel`` to ``@check_output(model=...)``. Pydantic validation is performed in strict mode, meaning that raw values will not be coerced to the model's types. For more information on strict mode, see the `pydantic docs <https://docs.pydantic.dev/latest/concepts/strict_mode/>`_.
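+
+As a minimal sketch (``MyModel`` is a hypothetical model defined only for illustration):
+
+.. code-block:: python
+
+    from typing import Any, Dict
+
+    from pydantic import BaseModel
+
+    from hamilton.function_modifiers import check_output
+
+    class MyModel(BaseModel):
+        a: int
+        b: str
+
+    @check_output(model=MyModel, importance="warn")
+    def my_node() -> Dict[str, Any]:
+        # Strict mode: returning {"a": "1", "b": "hello"} would fail,
+        # since "1" is not coerced to int.
+        return {"a": 1, "b": "hello"}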
+
+
Split node output into *n* nodes
--------------------------------
diff --git a/docs/reference/decorators/check_output.rst b/docs/reference/decorators/check_output.rst
index 9dd86bbb3..5f53ea622 100644
--- a/docs/reference/decorators/check_output.rst
+++ b/docs/reference/decorators/check_output.rst
@@ -27,9 +27,10 @@ Note that you can also specify custom decorators using the ``@check_output_custo
See `data_quality `_ for more information on
available validators and how to build custom ones.
-Note we also have a plugin that allows you to use pandera. There are two ways to access it:
-1. `@check_output(schema=pandera_schema)`
-2. `@h_pandera.check_output()` on a function that declares a typed pandera dataframe as an output
+Note we also have plugins that allow for validation with the pandera and pydantic libraries. There are two ways to access these (see the sketch after this list):
+
+1. ``@check_output(schema=pandera_schema)`` or ``@check_output(model=pydantic_model)``
+2. ``@h_pandera.check_output()`` or ``@h_pydantic.check_output()`` on a function that declares a typed pandera dataframe or a pydantic model as its output.
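+
+For example, a minimal sketch of the second form (``MyModel`` is hypothetical):
+
+.. code-block:: python
+
+    from pydantic import BaseModel
+
+    from hamilton.plugins import h_pydantic
+
+    class MyModel(BaseModel):
+        a: int
+
+    @h_pydantic.check_output()  # validates the output against the declared return type
+    def my_node() -> MyModel:
+        return MyModel(a=1)
+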
----
@@ -43,3 +44,6 @@ Note we also have a plugin that allows you to use pandera. There are two ways to
.. autoclass:: hamilton.plugins.h_pandera.check_output
:special-members: __init__
+
+.. autoclass:: hamilton.plugins.h_pydantic.check_output
+ :special-members: __init__
diff --git a/hamilton/data_quality/default_validators.py b/hamilton/data_quality/default_validators.py
index ec4b7780d..54c550cde 100644
--- a/hamilton/data_quality/default_validators.py
+++ b/hamilton/data_quality/default_validators.py
@@ -508,6 +508,23 @@ def _append_pandera_to_default_validators():
_append_pandera_to_default_validators()
+def _append_pydantic_to_default_validators():
+ """Utility method to append pydantic validators as needed"""
+ try:
+ import pydantic # noqa: F401
+ except ModuleNotFoundError:
+ logger.debug(
+ "Cannot import pydantic from pydantic_validators. Run pip install sf-hamilton[pydantic] if needed."
+ )
+ return
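+    # Deferred import: pydantic_validators itself imports pydantic, so we only
+    # load it once we know pydantic is installed.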
+ from hamilton.data_quality import pydantic_validators
+
+ AVAILABLE_DEFAULT_VALIDATORS.extend(pydantic_validators.PYDANTIC_VALIDATORS)
+
+
+_append_pydantic_to_default_validators()
+
+
def resolve_default_validators(
output_type: Type[Type],
importance: str,
diff --git a/hamilton/data_quality/pydantic_validators.py b/hamilton/data_quality/pydantic_validators.py
new file mode 100644
index 000000000..a37a1eea2
--- /dev/null
+++ b/hamilton/data_quality/pydantic_validators.py
@@ -0,0 +1,60 @@
+from typing import Any, Type
+
+from pydantic import BaseModel, TypeAdapter, ValidationError
+
+from hamilton.data_quality import base
+from hamilton.htypes import custom_subclass_check
+
+
+class PydanticModelValidator(base.BaseDefaultValidator):
+ """Pydantic model compatibility validator
+
+ Note that this validator uses pydantic's strict mode, which does not allow for
+ coercion of data. This means that if an object does not exactly match the reference
+ type, it will fail validation, regardless of whether it could be coerced into the
+ correct type.
+
+ :param model: Pydantic model to validate against
+ :param importance: Importance of the validator, possible values "warn" and "fail"
+ """
+
+ def __init__(self, model: Type[BaseModel], importance: str):
+ super(PydanticModelValidator, self).__init__(importance)
+ self.model = model
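+        # TypeAdapter lets raw python objects (e.g. dicts) be validated against the model.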
+ self._model_adapter = TypeAdapter(model)
+
+ @classmethod
+ def applies_to(cls, datatype: Type[Type]) -> bool:
+ # In addition to checking for a subclass of BaseModel, we also check for dict
+ # as this is the standard 'de-serialized' format of pydantic models in python
+ return custom_subclass_check(datatype, BaseModel) or custom_subclass_check(datatype, dict)
+
+ def description(self) -> str:
+ return "Validates that the returned object is compatible with the specified pydantic model"
+
+ def validate(self, data: Any) -> base.ValidationResult:
+ try:
+ # Currently, validate can not alter the output data, so we must use
+ # strict=True. The downside to this is that data that could be coerced
+ # into the correct type will fail validation.
+ self._model_adapter.validate_python(data, strict=True)
+ except ValidationError as e:
+ return base.ValidationResult(
+ passes=False, message=str(e), diagnostics={"model_errors": e.errors()}
+ )
+ return base.ValidationResult(
+ passes=True,
+ message=f"Data passes pydantic check for model {str(self.model)}",
+ )
+
+ @classmethod
+ def arg(cls) -> str:
+ return "model"
+
+ @classmethod
+ def name(cls) -> str:
+ return "pydantic_validator"
+
+
+PYDANTIC_VALIDATORS = [PydanticModelValidator]
diff --git a/hamilton/plugins/h_pydantic.py b/hamilton/plugins/h_pydantic.py
new file mode 100644
index 000000000..45d2e131f
--- /dev/null
+++ b/hamilton/plugins/h_pydantic.py
@@ -0,0 +1,111 @@
+from typing import List
+
+from pydantic import BaseModel
+
+from hamilton import node
+from hamilton.data_quality import base as dq_base
+from hamilton.function_modifiers import InvalidDecoratorException
+from hamilton.function_modifiers import base as fm_base
+from hamilton.function_modifiers import check_output as base_check_output
+from hamilton.function_modifiers.validation import BaseDataValidationDecorator
+from hamilton.htypes import custom_subclass_check
+
+
+class check_output(BaseDataValidationDecorator):
+ def __init__(
+ self,
+ importance: str = dq_base.DataValidationLevel.WARN.value,
+ target: fm_base.TargetType = None,
+ ):
+ """Specific output-checker for pydantic models. This decorator utilizes the output type of
+ the function, which can be any subclass of pydantic.BaseModel. The function output must
+ be declared with a type hint.
+
+ :param importance: Importance level (either "warn" or "fail") -- see documentation for check_output for more details.
+ :param target: The target of the decorator -- see documentation for check_output for more details.
+
+ Here is an example of how to use this decorator with a function that returns a pydantic model:
+
+ .. code-block:: python
+
+ from hamilton.plugins import h_pydantic
+ from pydantic import BaseModel
+
+ class MyModel(BaseModel):
+ a: int
+ b: float
+ c: str
+
+ @h_pydantic.check_output()
+ def foo() -> MyModel:
+ return MyModel(a=1, b=2.0, c="hello")
+
+ Alternatively, you can return a dictionary from the function (type checkers will probably
+ complain about this):
+
+ .. code-block:: python
+
+ from hamilton.plugins import h_pydantic
+ from pydantic import BaseModel
+
+ class MyModel(BaseModel):
+ a: int
+ b: float
+ c: str
+
+ @h_pydantic.check_output()
+ def foo() -> MyModel:
+ return {"a": 1, "b": 2.0, "c": "hello"}
+
+ You can also use pydantic validation through ``function_modifiers.check_output`` by
+ providing the model as an argument:
+
+ .. code-block:: python
+
+            from typing import Any, Dict
+
+ from hamilton import function_modifiers
+ from pydantic import BaseModel
+
+ class MyModel(BaseModel):
+ a: int
+ b: float
+ c: str
+
+ @function_modifiers.check_output(model=MyModel)
+            def foo() -> Dict[str, Any]:
+ return {"a": 1, "b": 2.0, "c": "hello"}
+
+        Note that because we do not (yet) support modification of the output, validation is
+ performed in strict mode, meaning that no data coercion is performed. For example, the
+ following function will *fail* validation:
+
+ .. code-block:: python
+
+ from hamilton.plugins import h_pydantic
+ from pydantic import BaseModel
+
+ class MyModel(BaseModel):
+ a: int # Defined as an int
+
+ @h_pydantic.check_output() # This will fail validation!
+ def foo() -> MyModel:
+ return MyModel(a="1") # Assigned as a string
+
+ For more information about strict mode see the pydantic docs: https://docs.pydantic.dev/latest/concepts/strict_mode/
+
+ """
+ super(check_output, self).__init__(target)
+ self.importance = importance
+ self.target = target
+
+ def get_validators(self, node_to_validate: node.Node) -> List[dq_base.DataValidator]:
+ output_type = node_to_validate.type
+ if not custom_subclass_check(output_type, BaseModel):
+ raise InvalidDecoratorException(
+ f"Output of function {node_to_validate.name} must be a Pydantic model"
+ )
+ return base_check_output(
+ importance=self.importance, model=output_type, target_=self.target
+ ).get_validators(node_to_validate)
diff --git a/pyproject.toml b/pyproject.toml
index a43f4b79d..5660d30a7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -73,6 +73,7 @@ docs = [
"pillow",
"polars",
"pyarrow >= 1.0.0",
+ "pydantic >=2.0",
"pyspark",
"openlineage-python",
"PyYAML",
@@ -99,6 +100,7 @@ packaging = [
"build",
]
pandera = ["pandera"]
+pydantic = ["pydantic>=2.0"]
pyspark = [
# we have to run these dependencies because Spark does not check to ensure the right target was called
"pyspark[pandas_on_spark,sql]"
@@ -129,6 +131,7 @@ test = [
"plotly",
"polars",
"pyarrow",
+ "pydantic >=2.0",
"pyreadstat", # for SPSS data loader
"pytest",
"pytest-asyncio",
@@ -144,10 +147,7 @@ test = [
]
tqdm = ["tqdm"]
ui = ["sf-hamilton-ui"]
-vaex = [
- "pydantic<2.0", # because of https://github.com/vaexio/vaex/issues/2384
- "vaex"
-]
+vaex = ["vaex"]
visualization = ["graphviz", "networkx"]
[project.entry-points.console_scripts]
diff --git a/tests/integrations/pydantic/requirements.txt b/tests/integrations/pydantic/requirements.txt
new file mode 100644
index 000000000..7b8df944a
--- /dev/null
+++ b/tests/integrations/pydantic/requirements.txt
@@ -0,0 +1 @@
+# Additional requirements on top of hamilton...pydantic
diff --git a/tests/integrations/pydantic/test_pydantic_data_quality.py b/tests/integrations/pydantic/test_pydantic_data_quality.py
new file mode 100644
index 000000000..426e6c01a
--- /dev/null
+++ b/tests/integrations/pydantic/test_pydantic_data_quality.py
@@ -0,0 +1,271 @@
+from typing import Any, Dict, List
+
+import pytest
+from pydantic import BaseModel, ValidationError
+
+from hamilton.data_quality.pydantic_validators import PydanticModelValidator
+from hamilton.function_modifiers import check_output
+from hamilton.node import Node
+from hamilton.plugins import h_pydantic
+
+
+def test_basic_pydantic_validator_passes():
+ class DummyModel(BaseModel):
+ value: float
+
+ validator = PydanticModelValidator(model=DummyModel, importance="warn")
+ validation_result = validator.validate({"value": 15.0})
+ assert validation_result.passes
+
+
+def test_basic_pydantic_check_output_passes():
+ class DummyModel(BaseModel):
+ value: float
+
+ @check_output(model=DummyModel, importance="warn")
+ def dummy() -> Dict[str, float]:
+ return {"value": 15.0}
+
+ node = Node.from_fn(dummy)
+ validators = check_output(model=DummyModel).get_validators(node)
+ assert len(validators) == 1
+ validator = validators[0]
+ result_success = validator.validate(node())
+ assert result_success.passes
+
+
+def test_basic_pydantic_validator_fails():
+ class DummyModel(BaseModel):
+ value: float
+
+ validator = PydanticModelValidator(model=DummyModel, importance="warn")
+ validation_result = validator.validate({"value": "15.0"})
+ assert not validation_result.passes
+ assert "value" in validation_result.diagnostics["model_errors"][0]["loc"]
+
+
+def test_basic_pydantic_check_output_fails():
+ class DummyModel(BaseModel):
+ value: float
+
+ @check_output(model=DummyModel, importance="warn")
+ def dummy() -> Dict[str, float]:
+ return {"value": "fifteen"} # type: ignore
+
+ node = Node.from_fn(dummy)
+ validators = check_output(model=DummyModel).get_validators(node)
+ assert len(validators) == 1
+ validator = validators[0]
+ result = validator.validate(node())
+ assert not result.passes
+
+
+def test_pydantic_validator_is_strict():
+ class DummyModel(BaseModel):
+ value: float
+
+ validator = PydanticModelValidator(model=DummyModel, importance="warn")
+ validation_result = validator.validate({"value": "15"})
+ assert not validation_result.passes
+
+
+def test_complex_pydantic_validator_passes():
+ class Owner(BaseModel):
+ name: str
+
+ class Version(BaseModel):
+ name: str
+ id: int
+
+ class Repo(BaseModel):
+ name: str
+ owner: Owner
+ versions: List[Version]
+
+ data = {
+ "name": "hamilton",
+ "owner": {"name": "DAGWorks-Inc"},
+ "versions": [{"name": "0.1.0", "id": 1}, {"name": "0.2.0", "id": 2}],
+ }
+
+ validator = PydanticModelValidator(model=Repo, importance="warn")
+ validation_result = validator.validate(data)
+ assert validation_result.passes
+
+
+def test_complex_pydantic_validator_fails():
+ class Owner(BaseModel):
+ name: str
+
+ class Version(BaseModel):
+ name: str
+ id: int
+
+ class Repo(BaseModel):
+ name: str
+ owner: Owner
+ versions: List[Version]
+
+ data = {
+ "name": "hamilton",
+ "owner": {"name": "DAGWorks-Inc"},
+ "versions": [{"name": "0.1.0", "id": 1}, {"name": "0.2.0", "id": "2"}],
+ }
+
+ validator = PydanticModelValidator(model=Repo, importance="warn")
+ validation_result = validator.validate(data)
+ assert not validation_result.passes
+
+
+def test_complex_pydantic_check_output_passes():
+ class Owner(BaseModel):
+ name: str
+
+ class Version(BaseModel):
+ name: str
+ id: int
+
+ class Repo(BaseModel):
+ name: str
+ owner: Owner
+ versions: List[Version]
+
+ @check_output(model=Repo, importance="warn")
+ def dummy() -> Dict[str, Any]:
+ return {
+ "name": "hamilton",
+ "owner": {"name": "DAGWorks-Inc"},
+ "versions": [{"name": "0.1.0", "id": 1}, {"name": "0.2.0", "id": 2}],
+ }
+
+ node = Node.from_fn(dummy)
+ validators = check_output(model=Repo).get_validators(node)
+ assert len(validators) == 1
+ validator = validators[0]
+ result_success = validator.validate(node())
+ assert result_success.passes
+
+
+def test_complex_pydantic_check_output_fails():
+ class Owner(BaseModel):
+ name: str
+
+ class Version(BaseModel):
+ name: str
+ id: int
+
+ class Repo(BaseModel):
+ name: str
+ owner: Owner
+ versions: List[Version]
+
+ @check_output(model=Repo, importance="warn")
+ def dummy() -> Dict[str, Any]:
+ return {
+ "name": "hamilton",
+ "owner": {"name": "DAGWorks-Inc"},
+ "versions": [
+ {"name": "0.1.0", "id": "one"}, # id should be an int
+ {"name": "0.2.0", "id": "two"}, # id should be an int
+ ],
+ }
+
+ node = Node.from_fn(dummy)
+ validators = check_output(model=Repo).get_validators(node)
+ assert len(validators) == 1
+ validator = validators[0]
+ result = validator.validate(node())
+ assert not result.passes
+
+
+def test_basic_pydantic_plugin_check_output_passes():
+ class DummyModel(BaseModel):
+ value: float
+
+ def dummy() -> DummyModel:
+ return DummyModel(value=15.0)
+
+ node = Node.from_fn(dummy)
+ validators = h_pydantic.check_output().get_validators(node)
+ assert len(validators) == 1
+ validator = validators[0]
+ result_success = validator.validate(node())
+ assert result_success.passes
+
+
+def test_basic_pydantic_plugin_check_output_fails():
+ class DummyModel(BaseModel):
+ value: float
+
+ def dummy() -> DummyModel:
+ return DummyModel(value="fifteen") # type: ignore
+
+ node = Node.from_fn(dummy)
+ validators = h_pydantic.check_output().get_validators(node)
+ assert len(validators) == 1
+ validator = validators[0]
+
+    # DummyModel(value="fifteen") raises at construction, so node() fails before
+    # the validator runs; an assertion after the call would be unreachable.
+    with pytest.raises(ValidationError):
+        validator.validate(node())
+
+
+def test_complex_pydantic_plugin_check_output_passes():
+ class Owner(BaseModel):
+ name: str
+
+ class Version(BaseModel):
+ name: str
+ id: int
+
+ class Repo(BaseModel):
+ name: str
+ owner: Owner
+ versions: List[Version]
+
+ def dummy() -> Repo:
+ return Repo(
+ name="hamilton",
+ owner=Owner(name="DAGWorks-Inc"),
+ versions=[Version(name="0.1.0", id=1), Version(name="0.2.0", id=2)],
+ )
+
+ node = Node.from_fn(dummy)
+ validators = h_pydantic.check_output().get_validators(node)
+ assert len(validators) == 1
+ validator = validators[0]
+ result_success = validator.validate(node())
+ assert result_success.passes
+
+
+def test_complex_pydantic_plugin_check_output_fails():
+ class Owner(BaseModel):
+ name: str
+
+ class Version(BaseModel):
+ name: str
+ id: int
+
+ class Repo(BaseModel):
+ name: str
+ owner: Owner
+ versions: List[Version]
+
+ def dummy() -> Repo:
+ return Repo(
+ name="hamilton",
+ owner=Owner(name="DAGWorks-Inc"),
+ versions=[
+ Version(name="0.1.0", id=1),
+ Version(name="0.2.0", id="two"), # type: ignore
+ ],
+ )
+
+ node = Node.from_fn(dummy)
+ validators = h_pydantic.check_output().get_validators(node)
+ assert len(validators) == 1
+ validator = validators[0]
+
+    # Version(id="two") raises at model construction, so node() fails before
+    # the validator runs; an assertion after the call would be unreachable.
+    with pytest.raises(ValidationError):
+        validator.validate(node())