Skip to content

Commit

Permalink
Merge pull request #98 from dsgibbons/feat/allow-missing-field
Browse files Browse the repository at this point in the history
feat: add allow missing field
  • Loading branch information
thomasaarholt authored Sep 3, 2024
2 parents 4448fbe + 67f1374 commit e55ddc5
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 16 deletions.
2 changes: 2 additions & 0 deletions src/patito/_pydantic/column_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ class ColumnInfo(BaseModel, arbitrary_types_allowed=True):
"""patito-side model for storing column metadata.
Args:
allow_missing (bool): Column may be missing.
        constraints (Union[polars.Expression, List[polars.Expression]]): A single
            constraint or list of constraints, expressed as polars expression objects.
All rows must satisfy the given constraint. You can refer to the given column
Expand All @@ -96,6 +97,7 @@ class ColumnInfo(BaseModel, arbitrary_types_allowed=True):
"""

allow_missing: Optional[bool] = None # noqa: UP007
dtype: Annotated[
Optional[Union[DataTypeClass, DataType]], # noqa: UP007
BeforeValidator(dtype_deserializer),
Expand Down
1 change: 1 addition & 0 deletions src/patito/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1261,6 +1261,7 @@ def Field(
can be read with the below examples.
Args:
allow_missing (bool): Column may be missing.
column_info: (Type[ColumnInfo]): ColumnInfo object to pass args to.
        constraints (Union[polars.Expression, List[polars.Expression]]): A single
            constraint or list of constraints, expressed as polars expression objects.
Expand Down
50 changes: 34 additions & 16 deletions src/patito/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,10 @@ def _find_errors( # noqa: C901
if not allow_missing_columns:
# Check if any columns are missing
for missing_column in set(schema_subset) - set(dataframe.columns):
col_info = schema.column_infos.get(missing_column)
if col_info is not None and col_info.allow_missing:
continue

errors.append(
ErrorWrapper(
MissingColumnsError("Missing column"),
Expand Down Expand Up @@ -202,15 +206,19 @@ def _find_errors( # noqa: C901
continue

polars_type = dataframe_datatypes[column_name]
if polars_type not in valid_dtypes[column_name]:
errors.append(
ErrorWrapper(
ColumnDTypeError(
f"Polars dtype {polars_type} does not match model field type."
),
loc=column_name,
if polars_type not in [
pl.Struct,
pl.List(pl.Struct),
]: # defer struct validation for recursive call to _find_errors later
if polars_type not in valid_dtypes[column_name]:
errors.append(
ErrorWrapper(
ColumnDTypeError(
f"Polars dtype {polars_type} does not match model field type."
),
loc=column_name,
)
)
)

# Test for when only specific values are accepted
e = _find_enum_errors(
Expand Down Expand Up @@ -247,7 +255,7 @@ def _find_errors( # noqa: C901
# we need to filter out any null rows as the inner model may disallow
# nulls on a particular field

# NB As of Polars 1.1, struct_col.is_null() cannot return True
# NB As of Polars 1.1, struct_col.is_null() cannot return True
                # The following code has been added to accommodate this

struct_fields = dataframe_tmp[column_name].struct.fields
Expand Down Expand Up @@ -278,23 +286,33 @@ def _find_errors( # noqa: C901
list_annotation = schema.model_fields[column_name].annotation
assert list_annotation is not None

# Additional unpack required if structs column is optional
# Handle Optional[list[pl.Struct]]
if is_optional(list_annotation):
list_annotation = unwrap_optional(list_annotation)
# An optional list means that we allow the list entry to be
# null. Since the list is optional, we need to filter out any
# null rows.

dataframe_tmp = dataframe_tmp.filter(pl.col(column_name).is_not_null())
if dataframe_tmp.is_empty():
continue

# Unpack list schema
nested_schema = list_annotation.__args__[0]

list_struct_errors = _find_errors(
dataframe=dataframe_tmp.select(column_name)
dataframe_tmp = (
dataframe_tmp.select(column_name)
.explode(column_name)
.unnest(column_name),
.unnest(column_name)
)

# Handle list[Optional[pl.Struct]]
if is_optional(nested_schema):
nested_schema = unwrap_optional(nested_schema)

dataframe_tmp = dataframe_tmp.filter(pl.all().is_not_null())
if dataframe_tmp.is_empty():
continue

list_struct_errors = _find_errors(
dataframe=dataframe_tmp,
schema=nested_schema,
)

Expand Down
113 changes: 113 additions & 0 deletions tests/test_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,119 @@ class SingleColumnModel(pt.Model):
) # kwargs are passed via model-centric validation API


def test_allow_missing_column_validation() -> None:
    """Validation should allow missing columns."""

    class SingleColumnModel(pt.Model):
        column_1: int
        column_2: str = pt.Field(allow_missing=True)

    # An empty frame must still fail: column_1 is required and carries no
    # allow_missing marker, so its absence is a validation error.
    with pytest.raises(DataFrameValidationError) as e_info:
        validate(dataframe=pl.DataFrame(), schema=SingleColumnModel)

    errors = e_info.value.errors()
    # Use the already-bound `errors` instead of recomputing
    # e_info.value.errors() a second time.
    assert len(errors) == 1
    # Only column_1 is reported; column_2 is permitted to be missing.
    assert sorted(errors, key=lambda e: e["loc"]) == [
        {
            "loc": ("column_1",),
            "msg": "Missing column",
            "type": "type_error.missingcolumns",
        },
    ]

    # Dropping only column_2 passes both the functional and the
    # model-centric validation entry points.
    df_missing_column_2 = pl.DataFrame({"column_1": [1, 2, 3]})
    validate(dataframe=df_missing_column_2, schema=SingleColumnModel)
    SingleColumnModel.validate(df_missing_column_2)


def test_allow_missing_nested_column_validation() -> None:
    """Validation should allow missing nested columns."""

    class InnerModel(pt.Model):
        column_1: int
        column_2: str = pt.Field(allow_missing=True)

    # Case 1: plain struct column — the nested column_2 may be absent.
    class OuterModel(pt.Model):
        inner: InnerModel
        other: str

    df_struct = pl.DataFrame(
        {"inner": [{"column_1": 1}, {"column_1": 2}], "other": ["a", "b"]}
    )
    validate(dataframe=df_struct, schema=OuterModel)
    OuterModel.validate(df_struct)

    # Case 2: optional struct column — null rows plus a missing nested column.
    class OuterModelWithOptionalInner(pt.Model):
        inner: Optional[InnerModel]  # noqa: UP007
        other: str

    df_opt_struct = pl.DataFrame(
        {"inner": [{"column_1": 1}, None], "other": ["a", "b"]}
    )
    validate(dataframe=df_opt_struct, schema=OuterModelWithOptionalInner)
    OuterModelWithOptionalInner.validate(df_opt_struct)

    # Case 3: list-of-struct column.
    class OuterModelWithListInner(pt.Model):
        inner: list[InnerModel]
        other: str

    df_list = pl.DataFrame(
        {
            "inner": [
                [{"column_1": 1}, {"column_1": 2}],
                [{"column_1": 3}, {"column_1": 4}],
            ],
            "other": ["a", "b"],
        }
    )
    validate(dataframe=df_list, schema=OuterModelWithListInner)
    OuterModelWithListInner.validate(df_list)

    # Case 4: optional list-of-struct column — whole list entries may be null.
    class OuterModelWithOptionalListInner(pt.Model):
        inner: Optional[list[InnerModel]]  # noqa: UP007
        other: str

    df_opt_list = pl.DataFrame(
        {"inner": [[{"column_1": 1}, {"column_1": 2}], None], "other": ["a", "b"]}
    )
    validate(dataframe=df_opt_list, schema=OuterModelWithOptionalListInner)
    OuterModelWithOptionalListInner.validate(df_opt_list)

    # Case 5: list of optional structs — individual list items may be null.
    class OuterModelWithListOptionalInner(pt.Model):
        inner: list[Optional[InnerModel]]  # noqa: UP007
        other: str

    df_list_opt = pl.DataFrame(
        {
            "inner": [[{"column_1": 1}, None], [None, {"column_1": 2}, None]],
            "other": ["a", "b"],
        }
    )
    validate(dataframe=df_list_opt, schema=OuterModelWithListOptionalInner)
    OuterModelWithListOptionalInner.validate(df_list_opt)

    # Case 6: optional list of optional structs — nulls at both levels.
    class OuterModelWithOptionalListOptionalInner(pt.Model):
        inner: Optional[list[Optional[InnerModel]]]  # noqa: UP007
        other: str

    df_opt_list_opt = pl.DataFrame(
        {
            "inner": [[{"column_1": 1}, None], [None, {"column_1": 2}, None], None],
            "other": ["a", "b", "c"],
        }
    )
    validate(
        dataframe=df_opt_list_opt,
        schema=OuterModelWithOptionalListOptionalInner,
    )
    OuterModelWithOptionalListOptionalInner.validate(df_opt_list_opt)


def test_superfluous_column_validation() -> None:
"""Validation should catch superfluous columns."""

Expand Down

0 comments on commit e55ddc5

Please sign in to comment.