diff --git a/docs/concepts/function-modifiers.rst b/docs/concepts/function-modifiers.rst index e74bbcbe1..c3b38c502 100644 --- a/docs/concepts/function-modifiers.rst +++ b/docs/concepts/function-modifiers.rst @@ -176,6 +176,12 @@ pandera support Hamilton has a pandera plugin for data validation that you can install with ``pip install sf-hamilton[pandera]``. Then, you can pass a pandera schema (for DataFrame or Series) to ``@check_output(schema=...)``. +pydantic support +~~~~~~~~~~~~~~~~ + +Hamilton also supports data validation of pydantic models, which can be enabled with ``pip install sf-hamilton[pydantic]``. With pydantic installed, you can pass any subclass of the pydantic base model to ``@check_output(model=...)``. Pydantic validation is performed in strict mode, meaning that raw values will not be coerced to the model's types. For more information on strict mode see the `pydantic docs `_. + + Split node output into *n* nodes -------------------------------- diff --git a/docs/reference/decorators/check_output.rst b/docs/reference/decorators/check_output.rst index 9dd86bbb3..5f53ea622 100644 --- a/docs/reference/decorators/check_output.rst +++ b/docs/reference/decorators/check_output.rst @@ -27,9 +27,10 @@ Note that you can also specify custom decorators using the ``@check_output_custo See `data_quality `_ for more information on available validators and how to build custom ones. -Note we also have a plugin that allows you to use pandera. There are two ways to access it: -1. `@check_output(schema=pandera_schema)` -2. `@h_pandera.check_output()` on a function that declares a typed pandera dataframe as an output +Note we also have a plugins that allow for validation with the pandera and pydantic libraries. There are two ways to access these: + +1. ``@check_output(schema=pandera_schema)`` or ``@check_output(model=pydantic_model)`` +2. ``@h_pandera.check_output()`` or ``@h_pydantic.check_output()`` on the function that declares either a typed dataframe or a pydantic model. ---- @@ -43,3 +44,6 @@ Note we also have a plugin that allows you to use pandera. There are two ways to .. autoclass:: hamilton.plugins.h_pandera.check_output :special-members: __init__ + +.. autoclass:: hamilton.plugins.h_pydantic.check_output + :special-members: __init__ diff --git a/hamilton/data_quality/default_validators.py b/hamilton/data_quality/default_validators.py index ec4b7780d..54c550cde 100644 --- a/hamilton/data_quality/default_validators.py +++ b/hamilton/data_quality/default_validators.py @@ -508,6 +508,23 @@ def _append_pandera_to_default_validators(): _append_pandera_to_default_validators() +def _append_pydantic_to_default_validators(): + """Utility method to append pydantic validators as needed""" + try: + import pydantic # noqa: F401 + except ModuleNotFoundError: + logger.debug( + "Cannot import pydantic from pydantic_validators. Run pip install sf-hamilton[pydantic] if needed." + ) + return + from hamilton.data_quality import pydantic_validators + + AVAILABLE_DEFAULT_VALIDATORS.extend(pydantic_validators.PYDANTIC_VALIDATORS) + + +_append_pydantic_to_default_validators() + + def resolve_default_validators( output_type: Type[Type], importance: str, diff --git a/hamilton/data_quality/pydantic_validators.py b/hamilton/data_quality/pydantic_validators.py new file mode 100644 index 000000000..a37a1eea2 --- /dev/null +++ b/hamilton/data_quality/pydantic_validators.py @@ -0,0 +1,60 @@ +from typing import Any, Type + +from pydantic import BaseModel, TypeAdapter, ValidationError + +from hamilton.data_quality import base +from hamilton.htypes import custom_subclass_check + + +class PydanticModelValidator(base.BaseDefaultValidator): + """Pydantic model compatibility validator + + Note that this validator uses pydantic's strict mode, which does not allow for + coercion of data. This means that if an object does not exactly match the reference + type, it will fail validation, regardless of whether it could be coerced into the + correct type. + + :param model: Pydantic model to validate against + :param importance: Importance of the validator, possible values "warn" and "fail" + :param arbitrary_types_allowed: Whether arbitrary types are allowed in the model + """ + + def __init__(self, model: Type[BaseModel], importance: str): + super(PydanticModelValidator, self).__init__(importance) + self.model = model + self._model_adapter = TypeAdapter(model) + + @classmethod + def applies_to(cls, datatype: Type[Type]) -> bool: + # In addition to checking for a subclass of BaseModel, we also check for dict + # as this is the standard 'de-serialized' format of pydantic models in python + return custom_subclass_check(datatype, BaseModel) or custom_subclass_check(datatype, dict) + + def description(self) -> str: + return "Validates that the returned object is compatible with the specified pydantic model" + + def validate(self, data: Any) -> base.ValidationResult: + try: + # Currently, validate can not alter the output data, so we must use + # strict=True. The downside to this is that data that could be coerced + # into the correct type will fail validation. + self._model_adapter.validate_python(data, strict=True) + except ValidationError as e: + return base.ValidationResult( + passes=False, message=str(e), diagnostics={"model_errors": e.errors()} + ) + return base.ValidationResult( + passes=True, + message=f"Data passes pydantic check for model {str(self.model)}", + ) + + @classmethod + def arg(cls) -> str: + return "model" + + @classmethod + def name(cls) -> str: + return "pydantic_validator" + + +PYDANTIC_VALIDATORS = [PydanticModelValidator] diff --git a/hamilton/plugins/h_pydantic.py b/hamilton/plugins/h_pydantic.py new file mode 100644 index 000000000..45d2e131f --- /dev/null +++ b/hamilton/plugins/h_pydantic.py @@ -0,0 +1,111 @@ +from typing import List + +from pydantic import BaseModel + +from hamilton import node +from hamilton.data_quality import base as dq_base +from hamilton.function_modifiers import InvalidDecoratorException +from hamilton.function_modifiers import base as fm_base +from hamilton.function_modifiers import check_output as base_check_output +from hamilton.function_modifiers.validation import BaseDataValidationDecorator +from hamilton.htypes import custom_subclass_check + + +class check_output(BaseDataValidationDecorator): + def __init__( + self, + importance: str = dq_base.DataValidationLevel.WARN.value, + target: fm_base.TargetType = None, + ): + """Specific output-checker for pydantic models. This decorator utilizes the output type of + the function, which can be any subclass of pydantic.BaseModel. The function output must + be declared with a type hint. + + :param model: The pydantic model to use for validation. If this is not provided, then the output type of the function is used. + :param importance: Importance level (either "warn" or "fail") -- see documentation for check_output for more details. + :param target: The target of the decorator -- see documentation for check_output for more details. + + Here is an example of how to use this decorator with a function that returns a pydantic model: + + .. code-block:: python + + from hamilton.plugins import h_pydantic + from pydantic import BaseModel + + class MyModel(BaseModel): + a: int + b: float + c: str + + @h_pydantic.check_output() + def foo() -> MyModel: + return MyModel(a=1, b=2.0, c="hello") + + Alternatively, you can return a dictionary from the function (type checkers will probably + complain about this): + + .. code-block:: python + + from hamilton.plugins import h_pydantic + from pydantic import BaseModel + + class MyModel(BaseModel): + a: int + b: float + c: str + + @h_pydantic.check_output() + def foo() -> MyModel: + return {"a": 1, "b": 2.0, "c": "hello"} + + You can also use pydantic validation through ``function_modifiers.check_output`` by + providing the model as an argument: + + .. code-block:: python + + from typing import Any + + from hamilton import function_modifiers + from pydantic import BaseModel + + class MyModel(BaseModel): + a: int + b: float + c: str + + @function_modifiers.check_output(model=MyModel) + def foo() -> dict[str, Any]: + return {"a": 1, "b": 2.0, "c": "hello"} + + Note, that because we do not (yet) support modification of the output, the validation is + performed in strict mode, meaning that no data coercion is performed. For example, the + following function will *fail* validation: + + .. code-block:: python + + from hamilton.plugins import h_pydantic + from pydantic import BaseModel + + class MyModel(BaseModel): + a: int # Defined as an int + + @h_pydantic.check_output() # This will fail validation! + def foo() -> MyModel: + return MyModel(a="1") # Assigned as a string + + For more information about strict mode see the pydantic docs: https://docs.pydantic.dev/latest/concepts/strict_mode/ + + """ + super(check_output, self).__init__(target) + self.importance = importance + self.target = target + + def get_validators(self, node_to_validate: node.Node) -> List[dq_base.DataValidator]: + output_type = node_to_validate.type + if not custom_subclass_check(output_type, BaseModel): + raise InvalidDecoratorException( + f"Output of function {node_to_validate.name} must be a Pydantic model" + ) + return base_check_output( + importance=self.importance, model=output_type, target_=self.target + ).get_validators(node_to_validate) diff --git a/pyproject.toml b/pyproject.toml index a43f4b79d..5660d30a7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,6 +73,7 @@ docs = [ "pillow", "polars", "pyarrow >= 1.0.0", + "pydantic >=2.0", "pyspark", "openlineage-python", "PyYAML", @@ -99,6 +100,7 @@ packaging = [ "build", ] pandera = ["pandera"] +pydantic = ["pydantic>=2.0"] pyspark = [ # we have to run these dependencies because Spark does not check to ensure the right target was called "pyspark[pandas_on_spark,sql]" @@ -129,6 +131,7 @@ test = [ "plotly", "polars", "pyarrow", + "pydantic >=2.0", "pyreadstat", # for SPSS data loader "pytest", "pytest-asyncio", @@ -144,10 +147,7 @@ test = [ ] tqdm = ["tqdm"] ui = ["sf-hamilton-ui"] -vaex = [ - "pydantic<2.0", # because of https://github.com/vaexio/vaex/issues/2384 - "vaex" -] +vaex = ["vaex"] visualization = ["graphviz", "networkx"] [project.entry-points.console_scripts] diff --git a/tests/integrations/pydantic/requirements.txt b/tests/integrations/pydantic/requirements.txt new file mode 100644 index 000000000..7b8df944a --- /dev/null +++ b/tests/integrations/pydantic/requirements.txt @@ -0,0 +1 @@ +# Additional requirements on top of hamilton...pydantic diff --git a/tests/integrations/pydantic/test_pydantic_data_quality.py b/tests/integrations/pydantic/test_pydantic_data_quality.py new file mode 100644 index 000000000..426e6c01a --- /dev/null +++ b/tests/integrations/pydantic/test_pydantic_data_quality.py @@ -0,0 +1,271 @@ +from typing import Any, Dict, List + +import pytest +from pydantic import BaseModel, ValidationError + +from hamilton.data_quality.pydantic_validators import PydanticModelValidator +from hamilton.function_modifiers import check_output +from hamilton.node import Node +from hamilton.plugins import h_pydantic + + +def test_basic_pydantic_validator_passes(): + class DummyModel(BaseModel): + value: float + + validator = PydanticModelValidator(model=DummyModel, importance="warn") + validation_result = validator.validate({"value": 15.0}) + assert validation_result.passes + + +def test_basic_pydantic_check_output_passes(): + class DummyModel(BaseModel): + value: float + + @check_output(model=DummyModel, importance="warn") + def dummy() -> Dict[str, float]: + return {"value": 15.0} + + node = Node.from_fn(dummy) + validators = check_output(model=DummyModel).get_validators(node) + assert len(validators) == 1 + validator = validators[0] + result_success = validator.validate(node()) + assert result_success.passes + + +def test_basic_pydantic_validator_fails(): + class DummyModel(BaseModel): + value: float + + validator = PydanticModelValidator(model=DummyModel, importance="warn") + validation_result = validator.validate({"value": "15.0"}) + assert not validation_result.passes + assert "value" in validation_result.diagnostics["model_errors"][0]["loc"] + + +def test_basic_pydantic_check_output_fails(): + class DummyModel(BaseModel): + value: float + + @check_output(model=DummyModel, importance="warn") + def dummy() -> Dict[str, float]: + return {"value": "fifteen"} # type: ignore + + node = Node.from_fn(dummy) + validators = check_output(model=DummyModel).get_validators(node) + assert len(validators) == 1 + validator = validators[0] + result = validator.validate(node()) + assert not result.passes + + +def test_pydantic_validator_is_strict(): + class DummyModel(BaseModel): + value: float + + validator = PydanticModelValidator(model=DummyModel, importance="warn") + validation_result = validator.validate({"value": "15"}) + assert not validation_result.passes + + +def test_complex_pydantic_validator_passes(): + class Owner(BaseModel): + name: str + + class Version(BaseModel): + name: str + id: int + + class Repo(BaseModel): + name: str + owner: Owner + versions: List[Version] + + data = { + "name": "hamilton", + "owner": {"name": "DAGWorks-Inc"}, + "versions": [{"name": "0.1.0", "id": 1}, {"name": "0.2.0", "id": 2}], + } + + validator = PydanticModelValidator(model=Repo, importance="warn") + validation_result = validator.validate(data) + assert validation_result.passes + + +def test_complex_pydantic_validator_fails(): + class Owner(BaseModel): + name: str + + class Version(BaseModel): + name: str + id: int + + class Repo(BaseModel): + name: str + owner: Owner + versions: List[Version] + + data = { + "name": "hamilton", + "owner": {"name": "DAGWorks-Inc"}, + "versions": [{"name": "0.1.0", "id": 1}, {"name": "0.2.0", "id": "2"}], + } + + validator = PydanticModelValidator(model=Repo, importance="warn") + validation_result = validator.validate(data) + assert not validation_result.passes + + +def test_complex_pydantic_check_output_passes(): + class Owner(BaseModel): + name: str + + class Version(BaseModel): + name: str + id: int + + class Repo(BaseModel): + name: str + owner: Owner + versions: List[Version] + + @check_output(model=Repo, importance="warn") + def dummy() -> Dict[str, Any]: + return { + "name": "hamilton", + "owner": {"name": "DAGWorks-Inc"}, + "versions": [{"name": "0.1.0", "id": 1}, {"name": "0.2.0", "id": 2}], + } + + node = Node.from_fn(dummy) + validators = check_output(model=Repo).get_validators(node) + assert len(validators) == 1 + validator = validators[0] + result_success = validator.validate(node()) + assert result_success.passes + + +def test_complex_pydantic_check_output_fails(): + class Owner(BaseModel): + name: str + + class Version(BaseModel): + name: str + id: int + + class Repo(BaseModel): + name: str + owner: Owner + versions: List[Version] + + @check_output(model=Repo, importance="warn") + def dummy() -> Dict[str, Any]: + return { + "name": "hamilton", + "owner": {"name": "DAGWorks-Inc"}, + "versions": [ + {"name": "0.1.0", "id": "one"}, # id should be an int + {"name": "0.2.0", "id": "two"}, # id should be an int + ], + } + + node = Node.from_fn(dummy) + validators = check_output(model=Repo).get_validators(node) + assert len(validators) == 1 + validator = validators[0] + result = validator.validate(node()) + assert not result.passes + + +def test_basic_pydantic_plugin_check_output_passes(): + class DummyModel(BaseModel): + value: float + + def dummy() -> DummyModel: + return DummyModel(value=15.0) + + node = Node.from_fn(dummy) + validators = h_pydantic.check_output().get_validators(node) + assert len(validators) == 1 + validator = validators[0] + result_success = validator.validate(node()) + assert result_success.passes + + +def test_basic_pydantic_plugin_check_output_fails(): + class DummyModel(BaseModel): + value: float + + def dummy() -> DummyModel: + return DummyModel(value="fifteen") # type: ignore + + node = Node.from_fn(dummy) + validators = h_pydantic.check_output().get_validators(node) + assert len(validators) == 1 + validator = validators[0] + + with pytest.raises(ValidationError): + result = validator.validate(node()) + assert not result.passes + + +def test_complex_pydantic_plugin_check_output_passes(): + class Owner(BaseModel): + name: str + + class Version(BaseModel): + name: str + id: int + + class Repo(BaseModel): + name: str + owner: Owner + versions: List[Version] + + def dummy() -> Repo: + return Repo( + name="hamilton", + owner=Owner(name="DAGWorks-Inc"), + versions=[Version(name="0.1.0", id=1), Version(name="0.2.0", id=2)], + ) + + node = Node.from_fn(dummy) + validators = h_pydantic.check_output().get_validators(node) + assert len(validators) == 1 + validator = validators[0] + result_success = validator.validate(node()) + assert result_success.passes + + +def test_complex_pydantic_plugin_check_output_fails(): + class Owner(BaseModel): + name: str + + class Version(BaseModel): + name: str + id: int + + class Repo(BaseModel): + name: str + owner: Owner + versions: List[Version] + + def dummy() -> Repo: + return Repo( + name="hamilton", + owner=Owner(name="DAGWorks-Inc"), + versions=[ + Version(name="0.1.0", id=1), + Version(name="0.2.0", id="two"), # type: ignore + ], + ) + + node = Node.from_fn(dummy) + validators = h_pydantic.check_output().get_validators(node) + assert len(validators) == 1 + validator = validators[0] + + with pytest.raises(ValidationError): + result = validator.validate(node()) + assert not result.passes