Skip to content

Commit

Permalink
Merge pull request #52 from dsgibbons/fix/validate-fields-on-structs
Browse files Browse the repository at this point in the history
Add field validation on structs
  • Loading branch information
thomasaarholt authored Apr 26, 2024
2 parents ae2fd82 + 11d388a commit ef2d95c
Show file tree
Hide file tree
Showing 2 changed files with 173 additions and 1 deletion.
69 changes: 68 additions & 1 deletion src/patito/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,16 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Optional, Sequence, Type, Union, cast
from typing import (
TYPE_CHECKING,
Any,
Optional,
Sequence,
Type,
Union,
_UnionGenericAlias,
cast,
)

import polars as pl
from pydantic.aliases import AliasGenerator
Expand Down Expand Up @@ -252,6 +261,64 @@ def _find_errors( # noqa: C901
)
)

# Intercept struct columns, and process errors separately
if schema.dtypes[column_name] == pl.Struct:
nested_schema = schema.model_fields[column_name].annotation

# Additional unpack required if structs column is optional
if type(nested_schema) == _UnionGenericAlias:
nested_schema = nested_schema.__args__[0]

# We need to filter out any null rows as the submodel won't know
# that all of a row's columns may be null
dataframe = dataframe.filter(pl.col(column_name).is_not_null())
if dataframe.is_empty():
continue

struct_errors = _find_errors(
dataframe=dataframe.select(column_name).unnest(column_name),
schema=nested_schema,
)

# Format nested errors
for error in struct_errors:
error._loc = f"{column_name}.{error._loc}"

errors.extend(struct_errors)

# No need to do any more checks
continue

# Intercept list of structs columns, and process errors separately
elif schema.dtypes[column_name] == pl.List(pl.Struct):
nested_schema = schema.model_fields[column_name].annotation.__args__[0]

# Additional unpack required if structs column is optional
if type(nested_schema) == _UnionGenericAlias:
nested_schema = nested_schema.__args__[0]

# We need to filter out any null rows as the submodel won't know
# that all of a row's columns may be null
dataframe = dataframe.filter(pl.col(column_name).is_not_null())
if dataframe.is_empty():
continue

list_struct_errors = _find_errors(
dataframe=dataframe.select(column_name)
.explode(column_name)
.unnest(column_name),
schema=nested_schema,
)

# Format nested errors
for error in list_struct_errors:
error._loc = f"{column_name}.{error._loc}"

errors.extend(list_struct_errors)

# No need to do any more checks
continue

# Check for bounded value fields
col = pl.col(column_name)
filters = {
Expand Down
105 changes: 105 additions & 0 deletions tests/test_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,111 @@ class ListEnumModel(pt.Model):
assert errors[0] == error_expected


class _PositiveStruct(pt.Model):
x: int = pt.Field(gt=0)


class _PositiveStructModel(pt.Model):
positive_struct: _PositiveStruct


def test_simple_struct_validation() -> None:
"""Test validation of model with struct column."""
valid_df = pl.DataFrame({"positive_struct": [{"x": 1}, {"x": 2}, {"x": 3}]})
_PositiveStructModel.validate(valid_df)

bad_df = pl.DataFrame({"positive_struct": [{"x": -1}, {"x": 2}, {"x": 3}]})
with pytest.raises(DataFrameValidationError):
_PositiveStructModel.validate(bad_df)


def test_nested_struct_validation() -> None:
"""Test validation of model with nested struct column."""

class NestedPositiveStructModel(pt.Model):
positive_struct_model: _PositiveStructModel

valid_df = pl.DataFrame(
{
"positive_struct_model": [
{"positive_struct": {"x": 1}},
{"positive_struct": {"x": 2}},
{"positive_struct": {"x": 3}},
]
}
)
NestedPositiveStructModel.validate(valid_df)

bad_df = pl.DataFrame(
{
"positive_struct_model": [
{"positive_struct": {"x": -1}},
{"positive_struct": {"x": 2}},
{"positive_struct": {"x": 3}},
]
}
)
with pytest.raises(DataFrameValidationError):
NestedPositiveStructModel.validate(bad_df)


def test_list_struct_validation() -> None:
"""Test validation of model with list of structs column."""

class ListPositiveStructModel(pt.Model):
list_positive_struct: list[_PositiveStruct]

valid_df = pl.DataFrame(
{"list_positive_struct": [[{"x": 1}, {"x": 2}], [{"x": 3}, {"x": 4}, {"x": 5}]]}
)
ListPositiveStructModel.validate(valid_df)

bad_df = pl.DataFrame(
{
"list_positive_struct": [
[{"x": 1}, {"x": 2}],
[{"x": 3}, {"x": -4}, {"x": 5}],
]
}
)
with pytest.raises(DataFrameValidationError):
ListPositiveStructModel.validate(bad_df)


def test_struct_validation_with_polars_constraint() -> None:
"""Test validation of models with constrained struct column."""

class Interval(pt.Model):
x_min: int
x_max: int = pt.Field(constraints=pt.col("x_min") <= pt.col("x_max"))

class IntervalModel(pt.Model):
interval: Interval

valid_df = pl.DataFrame(
{
"interval": [
{"x_min": 0, "x_max": 1},
{"x_min": 0, "x_max": 0},
{"x_min": -1, "x_max": 1},
]
}
)
IntervalModel.validate(valid_df)

bad_df = pl.DataFrame(
{
"interval": [
{"x_min": 0, "x_max": 1},
{"x_min": 1, "x_max": 0},
{"x_min": -1, "x_max": 1},
]
}
)
with pytest.raises(DataFrameValidationError):
IntervalModel.validate(bad_df)


def test_uniqueness_constraint_validation() -> None:
"""Uniqueness constraints should be validated."""

Expand Down

0 comments on commit ef2d95c

Please sign in to comment.