Skip to content

Commit

Permalink
Merge pull request #100 from dsgibbons/fix/optional-list
Browse files Browse the repository at this point in the history
fix: optional list of structs
  • Loading branch information
thomasaarholt authored Sep 2, 2024
2 parents 477a242 + d4b6536 commit 4448fbe
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 17 deletions.
25 changes: 8 additions & 17 deletions src/patito/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,29 +277,20 @@ def _find_errors( # noqa: C901
elif schema.dtypes[column_name] == pl.List(pl.Struct):
list_annotation = schema.model_fields[column_name].annotation
assert list_annotation is not None
nested_schema = list_annotation.__args__[0]

# Additional unpack required if structs column is optional
if is_optional(nested_schema):
nested_schema = unwrap_optional(nested_schema)
# An optional struct means that we allow the struct entry to be
# null. It is the inner model that is responsible for determining
# whether its fields are optional or not. Since the struct is optional,
# we need to filter out any null rows as the inner model may disallow
# nulls on a particular field
if is_optional(list_annotation):
list_annotation = unwrap_optional(list_annotation)
# An optional list means that we allow the list entry to be
# null. Since the list is optional, we need to filter out any
# null rows.

# NB As of Polars 1.1, struct_col.is_null() cannot return True
# The following code has been added to accommodate this

struct_fields = dataframe_tmp[column_name].struct.fields
col_struct = pl.col(column_name).struct
only_non_null_expr = ~pl.all_horizontal(
[col_struct.field(name).is_null() for name in struct_fields]
)
dataframe_tmp = dataframe_tmp.filter(only_non_null_expr)
dataframe_tmp = dataframe_tmp.filter(pl.col(column_name).is_not_null())
if dataframe_tmp.is_empty():
continue

nested_schema = list_annotation.__args__[0]

list_struct_errors = _find_errors(
dataframe=dataframe_tmp.select(column_name)
.explode(column_name)
Expand Down
46 changes: 46 additions & 0 deletions tests/test_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,6 +862,52 @@ class ListModel(pt.Model):
ListModel.validate(valid_df.with_columns(pl.col(old).alias(new)))


def test_optional_nested_list() -> None:
    """Validate an optional list-of-structs column.

    A null list entry must be tolerated because the column is annotated as
    ``Optional[list[Inner]]``, while a struct that lacks a required field of
    the inner model must still be reported as a validation error.
    """

    class Inner(pt.Model):
        name: str
        reliability: bool
        level: int

    class Outer(pt.Model):
        id: str
        code: str
        label: str
        inner_types: Optional[list[Inner]]  # noqa: UP007

    def build_frame(first_entry: dict) -> pl.DataFrame:
        # Both scenarios share every value except the struct in the first row,
        # so only that struct is passed in. The third row's list is null.
        return pl.DataFrame(
            {
                "id": [1, 2, 3],
                "code": ["A", "B", "C"],
                "label": ["a", "b", "c"],
                "inner_types": [
                    [first_entry],
                    [{"name": "b", "reliability": False, "level": 2}],
                    None,
                ],
            }
        )

    # Fully populated first struct: validation is expected to pass silently.
    complete_entry = {"name": "a", "reliability": True, "level": 1}
    valid_frame = Outer.DataFrame(build_frame(complete_entry)).cast().derive()
    valid_frame.validate()

    # First struct is missing "reliability": validation is expected to raise.
    incomplete_entry = {"name": "a", "level": 1}
    invalid_frame = Outer.DataFrame(build_frame(incomplete_entry)).cast().derive()
    with pytest.raises(DataFrameValidationError):
        invalid_frame.validate()


def test_nested_field_attrs() -> None:
"""Ensure that constraints are respected even when embedded inside 'anyOf'."""

Expand Down

0 comments on commit 4448fbe

Please sign in to comment.