Skip to content

Commit

Permalink
fix: filter out null rows before calling _find_errors again on optional structs
Browse files Browse the repository at this point in the history
  • Loading branch information
dsgibbons committed Apr 6, 2024
1 parent 9b5e22d commit deb1493
Showing 1 changed file with 26 additions and 6 deletions.
32 changes: 26 additions & 6 deletions src/patito/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,16 @@

from __future__ import annotations

import contextlib
from typing import TYPE_CHECKING, Any, Optional, Sequence, Type, Union, cast
from typing import (
TYPE_CHECKING,
Any,
Optional,
Sequence,
Type,
Union,
_UnionGenericAlias,
cast,
)

import polars as pl
from pydantic.aliases import AliasGenerator
Expand Down Expand Up @@ -261,10 +269,16 @@ def _find_errors( # noqa: C901
if schema.dtypes[column_name] == pl.Struct:
nested_schema = schema.model_fields[column_name].annotation

with contextlib.suppress(AttributeError):
# Additional unpack required if structs column is optional
# Additional unpack required if structs column is optional
if type(nested_schema) == _UnionGenericAlias:
nested_schema = nested_schema.__args__[0]

# We need to filter out any null rows, as the submodel won't
# know that all of a row's columns may be null
dataframe = dataframe.filter(pl.col(column_name).is_not_null())
if dataframe.is_empty():
continue

struct_errors = _find_errors(
dataframe=dataframe.select(column_name).unnest(column_name),
schema=nested_schema,
Expand All @@ -283,10 +297,16 @@ def _find_errors( # noqa: C901
elif schema.dtypes[column_name] == pl.List(pl.Struct):
nested_schema = schema.model_fields[column_name].annotation.__args__[0]

with contextlib.suppress(AttributeError):
# Additional unpack required if list of structs column is optional
# Additional unpack required if list of structs column is optional
if type(nested_schema) == _UnionGenericAlias:
nested_schema = nested_schema.__args__[0]

# We need to filter out any null rows, as the submodel won't
# know that all of a row's columns may be null
dataframe = dataframe.filter(pl.col(column_name).is_not_null())
if dataframe.is_empty():
continue

list_struct_errors = _find_errors(
dataframe=dataframe.select(column_name)
.explode(column_name)
Expand Down

0 comments on commit deb1493

Please sign in to comment.