diff --git a/src/patito/validators.py b/src/patito/validators.py index b748a73..62d018f 100644 --- a/src/patito/validators.py +++ b/src/patito/validators.py @@ -2,8 +2,16 @@ from __future__ import annotations -import contextlib -from typing import TYPE_CHECKING, Any, Optional, Sequence, Type, Union, cast +from typing import ( + TYPE_CHECKING, + Any, + Optional, + Sequence, + Type, + Union, + _UnionGenericAlias, + cast, +) import polars as pl from pydantic.aliases import AliasGenerator @@ -261,10 +269,16 @@ def _find_errors( # noqa: C901 if schema.dtypes[column_name] == pl.Struct: nested_schema = schema.model_fields[column_name].annotation - with contextlib.suppress(AttributeError): - # Additional unpack required if structs column is optional + # Additional unpack required if structs column is optional + if type(nested_schema) == _UnionGenericAlias: nested_schema = nested_schema.__args__[0] + # We need to filter out any null rows, as the submodel won't + # know that all of a row's columns may be null + dataframe = dataframe.filter(pl.col(column_name).is_not_null()) + if dataframe.is_empty(): + continue + struct_errors = _find_errors( dataframe=dataframe.select(column_name).unnest(column_name), schema=nested_schema, @@ -283,10 +297,16 @@ def _find_errors( # noqa: C901 elif schema.dtypes[column_name] == pl.List(pl.Struct): nested_schema = schema.model_fields[column_name].annotation.__args__[0] - with contextlib.suppress(AttributeError): - # Additional unpack required if list of structs column is optional + # Additional unpack required if structs column is optional + if type(nested_schema) == _UnionGenericAlias: nested_schema = nested_schema.__args__[0] + # We need to filter out any null rows, as the submodel won't + # know that all of a row's columns may be null + dataframe = dataframe.filter(pl.col(column_name).is_not_null()) + if dataframe.is_empty(): + continue + list_struct_errors = _find_errors( dataframe=dataframe.select(column_name) .explode(column_name)