Skip to content

Commit

Permalink
Merge pull request #121 from maxhipperson/fix-int-anno
Browse files Browse the repository at this point in the history
int model annotation should return only integer dtypes as valid
  • Loading branch information
thomasaarholt authored Nov 14, 2024
2 parents f5792d6 + 12208a1 commit c5b4d6f
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 19 deletions.
4 changes: 2 additions & 2 deletions src/patito/_pydantic/dtypes/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from collections.abc import Mapping
from functools import cache, reduce
from operator import and_
from operator import or_
from typing import TYPE_CHECKING, Any

import polars as pl
Expand Down Expand Up @@ -144,7 +144,7 @@ def _valid_polars_dtypes_for_schema(
valid_type_sets.append(
self._pydantic_subschema_to_valid_polars_types(schema)
)
return reduce(and_, valid_type_sets) if valid_type_sets else DataTypeGroup([])
return reduce(or_, valid_type_sets) if valid_type_sets else DataTypeGroup([])

def _pydantic_subschema_to_valid_polars_types(
self,
Expand Down
2 changes: 1 addition & 1 deletion src/patito/_pydantic/dtypes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def _pyd_type_to_valid_dtypes(
_validate_enum_values(pyd_type, enum)
return DataTypeGroup([pl.Enum(enum), pl.String], match_base_type=False)
if pyd_type.value == "integer":
return DataTypeGroup(INTEGER_DTYPES | FLOAT_DTYPES)
return DataTypeGroup(INTEGER_DTYPES)
elif pyd_type.value == "number":
return (
FLOAT_DTYPES
Expand Down
42 changes: 29 additions & 13 deletions tests/test_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@ def test_valids_basic_annotations() -> None:
"""Test type annotations match polars dtypes."""
# base types
assert DtypeResolver(str).valid_polars_dtypes() == STRING_DTYPES
assert DtypeResolver(int).valid_polars_dtypes() == DataTypeGroup(
INTEGER_DTYPES | FLOAT_DTYPES
)
assert DtypeResolver(int).valid_polars_dtypes() == DataTypeGroup(INTEGER_DTYPES)
assert DtypeResolver(float).valid_polars_dtypes() == FLOAT_DTYPES
assert DtypeResolver(bool).valid_polars_dtypes() == BOOLEAN_DTYPES

Expand All @@ -64,10 +62,14 @@ def test_valids_basic_annotations() -> None:
assert (
DtypeResolver(str | None | None).valid_polars_dtypes() == STRING_DTYPES
) # superfluous None is ok
assert DtypeResolver(Union[int, float]).valid_polars_dtypes() == FLOAT_DTYPES
assert (
DtypeResolver(Union[str, int]).valid_polars_dtypes() == frozenset()
) # incompatible
DtypeResolver(Union[int, float]).valid_polars_dtypes()
== FLOAT_DTYPES | INTEGER_DTYPES
)
assert (
DtypeResolver(Union[str, int]).valid_polars_dtypes()
== STRING_DTYPES | INTEGER_DTYPES
)

# invalids
assert DtypeResolver(object).valid_polars_dtypes() == frozenset()
Expand All @@ -87,13 +89,13 @@ def test_valids_nested_annotations() -> None:
pl.List(pl.String)
}
assert len(DtypeResolver(list[int]).valid_polars_dtypes()) == len(
DataTypeGroup(INTEGER_DTYPES | FLOAT_DTYPES)
DataTypeGroup(INTEGER_DTYPES)
)
assert len(DtypeResolver(list[Union[int, float]]).valid_polars_dtypes()) == len(
FLOAT_DTYPES
INTEGER_DTYPES | FLOAT_DTYPES
)
assert len(DtypeResolver(list[Optional[int]]).valid_polars_dtypes()) == len(
DataTypeGroup(INTEGER_DTYPES | FLOAT_DTYPES)
DataTypeGroup(INTEGER_DTYPES)
)
assert DtypeResolver(list[list[str]]).valid_polars_dtypes() == {
pl.List(pl.List(pl.String))
Expand All @@ -116,7 +118,10 @@ def test_valids_nested_annotations() -> None:
def test_dtype_validation() -> None:
"""Ensure python types match polars types."""
validate_polars_dtype(int, pl.Int16) # no issue
validate_polars_dtype(int, pl.Float64) # no issue

with pytest.raises(ValueError, match="Invalid dtype"):
validate_polars_dtype(int, pl.Float64)

with pytest.raises(ValueError, match="Invalid dtype"):
validate_polars_dtype(int, pl.String)

Expand Down Expand Up @@ -195,11 +200,22 @@ def test_annotation_validation() -> None:

with pytest.raises(ValueError, match="Valid dtypes are:"):
validate_annotation(Union[int, float])
with pytest.raises(ValueError, match="not compatible with any polars dtypes"):

# Unions are unsupported as actual polars dtypes but are not supported by Patito IF a default dtype is provided
# TODO: Does it make sense for Patito to support union given that the underlying dataframe cannot?
with pytest.raises(
ValueError,
match="Unable to determine default dtype",
):
validate_annotation(Union[str, int])

validate_annotation(list[Optional[int]])
with pytest.raises(ValueError, match="not compatible with any polars dtypes"):

with pytest.raises(ValueError, match="Unable to determine default dtype"):
validate_annotation(list[Union[str, int]])
with pytest.raises(ValueError, match="Valid dtypes are:"):

with pytest.raises(
ValueError,
match="Unable to determine default dtype",
):
validate_annotation(list[Union[int, float]])
6 changes: 3 additions & 3 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def test_mapping_to_polars_dtypes() -> None:

assert CompleteModel.valid_dtypes == {
"str_column": {pl.String},
"int_column": DataTypeGroup(INTEGER_DTYPES | FLOAT_DTYPES),
"int_column": DataTypeGroup(INTEGER_DTYPES),
"float_column": FLOAT_DTYPES,
"bool_column": {pl.Boolean},
"date_column": DATE_DTYPES,
Expand All @@ -235,11 +235,11 @@ def test_mapping_to_polars_dtypes() -> None:
]
),
"list_int_column": DataTypeGroup(
[pl.List(x) for x in DataTypeGroup(INTEGER_DTYPES | FLOAT_DTYPES)]
[pl.List(x) for x in DataTypeGroup(INTEGER_DTYPES)]
),
"list_str_column": DataTypeGroup([pl.List(pl.String)]),
"list_opt_column": DataTypeGroup(
[pl.List(x) for x in DataTypeGroup(INTEGER_DTYPES | FLOAT_DTYPES)]
[pl.List(x) for x in DataTypeGroup(INTEGER_DTYPES)]
),
}

Expand Down

0 comments on commit c5b4d6f

Please sign in to comment.