Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix some warnings after migration #39

Merged
merged 6 commits into from
Mar 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/patito/_pydantic/dtypes/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@ def valid_dtypes_for_model(
cls: Type[ModelType],
) -> Mapping[str, FrozenSet[DataTypeClass]]:
return {
column: DtypeResolver(cls.model_fields[column].annotation).valid_polars_dtypes()
if cls.column_infos[column].dtype is None
else DataTypeGroup([cls.dtypes[column]], match_base_type=False)
column: (
DtypeResolver(cls.model_fields[column].annotation).valid_polars_dtypes()
if cls.column_infos[column].dtype is None
else DataTypeGroup([cls.dtypes[column]], match_base_type=False)
)
for column in cls.columns
}

Expand Down
115 changes: 107 additions & 8 deletions src/patito/polars.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Logic related to the wrapping of the polars data frame library."""

from __future__ import annotations

from typing import (
Expand Down Expand Up @@ -82,14 +83,57 @@ def collect(
See documentation of polars.DataFrame.collect for full description of
parameters.
"""
df = super().collect(*args, **kwargs)
background = kwargs.pop("background", False)
df = super().collect(*args, background=background, **kwargs)
if getattr(self, "model", False):
cls = DataFrame._construct_dataframe_model_class(model=self.model)
else:
cls = DataFrame
return cls._from_pydf(df._df)

def derive(self: LDF, columns: list[str] | None = None) -> LDF:
"""Populate columns which have ``pt.Field(derived_from=...)`` definitions.

If a column field on the data frame model has ``patito.Field(derived_from=...)``
specified, the given value will be used to define the column. If
``derived_from`` is set to a string, the column will be derived from the given
column name. Alternatively, an arbitrary polars expression can be given, the
result of which will be used to populate the column values.

Args:
----
columns: Optionally, a list of column names to derive. If not provided, all
columns are used.

Returns:
-------
DataFrame[Model]: A new dataframe where all derivable columns are provided.

Raises:
------
TypeError: If the ``derived_from`` parameter of ``patito.Field`` is given
as something else than a string or polars expression.

Examples:
--------
>>> import patito as pt
>>> import polars as pl
>>> class Foo(pt.Model):
... bar: int = pt.Field(derived_from="foo")
... double_bar: int = pt.Field(derived_from=2 * pl.col("bar"))
...
>>> Foo.DataFrame({"foo": [1, 2]}).derive()
shape: (2, 3)
┌─────┬────────────┬─────┐
│ bar ┆ double_bar ┆ foo │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ i64 │
╞═════╪════════════╪═════╡
│ 1 ┆ 2 ┆ 1 │
│ 2 ┆ 4 ┆ 2 │
└─────┴────────────┴─────┘

"""
derived_columns = []
props = self.model._schema_properties()
original_columns = set(self.columns)
Expand Down Expand Up @@ -139,6 +183,17 @@ def _derive_column(
return df, derived_columns

def unalias(self: LDF) -> LDF:
"""Un-aliases column names using information from pydantic validation_alias.

In order of preference - model field name then validation_aliases in order of occurrence

limitation - AliasChoice validation type only supports selecting a single element of an array

Returns
-------
DataFrame[Model]: A dataframe with columns normalized to model names.

"""
if not any(fi.validation_alias for fi in self.model.model_fields.values()):
return self
exprs = []
Expand Down Expand Up @@ -189,6 +244,47 @@ def to_expr(va: str | AliasPath | AliasChoices) -> Optional[pl.Expr]:
def cast(
self: LDF, strict: bool = False, columns: Optional[Sequence[str]] = None
) -> LDF:
"""Cast columns to `dtypes` specified by the associated Patito model.

Args:
----
strict: If set to ``False``, columns which are technically compliant with
the specified field type, will not be casted. For example, a column
annotated with ``int`` is technically compliant with ``pl.UInt8``, even
if ``pl.Int64`` is the default dtype associated with ``int``-annotated
fields. If ``strict`` is set to ``True``, the resulting dtypes will
be forced to the default dtype associated with each python type.
columns: Optionally, a list of column names to cast. If not provided, all
columns are casted.

Returns:
-------
LazyFrame[Model]: A dataframe with columns casted to the correct dtypes.

Examples:
--------
Create a simple model:

>>> import patito as pt
>>> import polars as pl
>>> class Product(pt.Model):
... name: str
... cent_price: int = pt.Field(dtype=pl.UInt16)
...

Now we can use this model to cast some simple data:

>>> Product.LazyFrame({"name": ["apple"], "cent_price": ["8"]}).cast().collect()
shape: (1, 2)
┌───────┬────────────┐
│ name ┆ cent_price │
│ --- ┆ --- │
│ str ┆ u16 │
╞═══════╪════════════╡
│ apple ┆ 8 │
└───────┴────────────┘

"""
properties = self.model._schema_properties()
valid_dtypes = self.model.valid_dtypes
default_dtypes = self.model.dtypes
Expand All @@ -207,7 +303,7 @@ def cast(

@classmethod
def from_existing(cls: Type[LDF], lf: pl.LazyFrame) -> LDF:
    """Construct a patito ``LazyFrame`` from an existing ``polars.LazyFrame``.

    The given lazy frame is re-wrapped as ``cls.model.LazyFrame`` and
    ``cast()`` is applied so columns take the dtypes declared on the model.

    Args:
    ----
        lf: The polars lazy frame to wrap.

    Returns:
    -------
        LazyFrame[Model]: A model-aware lazy frame backed by the same data.

    """
    return cls.model.LazyFrame._from_pyldf(lf._ldf).cast()


Expand Down Expand Up @@ -378,6 +474,8 @@ def cast(
if ``pl.Int64`` is the default dtype associated with ``int``-annotated
fields. If ``strict`` is set to ``True``, the resulting dtypes will
be forced to the default dtype associated with each python type.
columns: Optionally, a list of column names to cast. If not provided, all
columns are casted.

Returns:
-------
Expand Down Expand Up @@ -620,12 +718,12 @@ def fill_null(
)
return self.with_columns(
[
pl.col(column).fill_null(
pl.lit(default_value, self.model.dtypes[column])
)
if column in self.columns
else pl.Series(
column, [default_value], self.model.dtypes[column]
(
pl.col(column).fill_null(
pl.lit(default_value, self.model.dtypes[column])
)
if column in self.columns
else pl.Series(column, [default_value], self.model.dtypes[column])
) # NOTE: hack to get around polars bug https://github.com/pola-rs/polars/issues/13602
# else pl.lit(default_value, self.model.dtypes[column]).alias(column)
for column, default_value in self.model.defaults.items()
Expand Down Expand Up @@ -740,6 +838,7 @@ def _pydantic_model(self) -> Type[Model]:
)

def as_polars(self) -> pl.DataFrame:
    """Convert this patito dataframe to a plain ``polars.DataFrame``.

    The result is built from the same internal ``_df`` object, dropping the
    patito model association but keeping the frame contents unchanged.

    Returns:
    -------
        pl.DataFrame: A vanilla polars dataframe with the same data.

    """
    return pl.DataFrame._from_pydf(self._df)

@classmethod
Expand Down
77 changes: 22 additions & 55 deletions src/patito/pydantic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Logic related to wrapping logic around the pydantic library."""

from __future__ import annotations

import itertools
Expand Down Expand Up @@ -68,7 +69,7 @@


class ModelMetaclass(PydanticModelMetaclass, Generic[CI]):
"""Metclass used by patito.Model.
"""Metaclass used by patito.Model.

Responsible for setting any relevant model-dependent class properties.
"""
Expand Down Expand Up @@ -101,10 +102,12 @@ def __init__(cls, name: str, bases: tuple, clsdict: dict, **kwargs) -> None:
)

def __hash__(self) -> int:
    """Return the hash of the model class.

    Simply delegates to the parent metaclass implementation.
    NOTE(review): this override appears to exist only to carry an explicit
    docstring (the PR fixes lint warnings) — confirm no behavior change
    was intended.
    """
    return super().__hash__()

@property
def column_infos(cls: Type[ModelType]) -> Mapping[str, ColumnInfo]:
    """Return per-column ``ColumnInfo`` metadata for the model.

    Returns:
    -------
        A mapping from each column name to its ``ColumnInfo``.

    """
    return column_infos_for_model(cls)

@property
Expand All @@ -126,7 +129,7 @@ def model_schema(cls: Type[ModelType]) -> Mapping[str, Mapping[str, Any]]:
return schema_for_model(cls)

@property
def columns(cls: Type[ModelType]) -> List[str]: # type: ignore
def columns(cls: Type[ModelType]) -> List[str]:
"""Return the name of the dataframe columns specified by the fields of the model.

Returns:
Expand All @@ -147,9 +150,7 @@ def columns(cls: Type[ModelType]) -> List[str]: # type: ignore
return list(cls.model_fields.keys())

@property
def dtypes( # type: ignore
cls: Type[ModelType], # pyright: ignore
) -> dict[str, DataTypeClass | DataType]:
def dtypes(cls: Type[ModelType]) -> dict[str, DataTypeClass | DataType]:
"""Return the polars dtypes of the dataframe.

Unless Field(dtype=...) is specified, the highest signed column dtype
Expand All @@ -174,55 +175,27 @@ def dtypes( # type: ignore
return default_dtypes_for_model(cls)

@property
def valid_dtypes( # type: ignore
cls: Type[ModelType], # pyright: ignore
def valid_dtypes(
cls: Type[ModelType],
) -> Mapping[str, FrozenSet[DataTypeClass | DataType]]:
"""Return a list of polars dtypes which Patito considers valid for each field.

The first item of each list is the default dtype chosen by Patito.

Returns:
Returns
-------
A dictionary mapping each column string name to a list of valid dtypes.

Raises:
Raises
------
NotImplementedError: If one or more model fields are annotated with types
not compatible with polars.

Example:
-------
>>> from pprint import pprint
>>> import patito as pt

>>> class MyModel(pt.Model):
... bool_column: bool
... str_column: str
... int_column: int
... float_column: float
...
>>> pprint(MyModel.valid_dtypes)
{'bool_column': DataTypeGroup({Boolean}),
'float_column': DataTypeGroup({Float32, Float64}),
'int_column': DataTypeGroup({Int8,
Int16,
Int32,
Int64,
UInt8,
UInt16,
UInt32,
UInt64,
Float32,
Float64}),
'str_column': DataTypeGroup({String})}

"""
return valid_dtypes_for_model(cls)

@property
def defaults( # type: ignore
cls: Type[ModelType], # pyright: ignore
) -> dict[str, Any]:
def defaults(cls: Type[ModelType]) -> dict[str, Any]:
"""Return default field values specified on the model.

Returns:
Expand All @@ -249,9 +222,7 @@ def defaults( # type: ignore
}

@property
def non_nullable_columns( # type: ignore
cls: Type[ModelType], # pyright: ignore
) -> set[str]:
def non_nullable_columns(cls: Type[ModelType]) -> set[str]:
"""Return names of those columns that are non-nullable in the schema.

Returns:
Expand Down Expand Up @@ -282,9 +253,7 @@ def non_nullable_columns( # type: ignore
)

@property
def nullable_columns( # type: ignore
cls: Type[ModelType], # pyright: ignore
) -> set[str]:
def nullable_columns(cls: Type[ModelType]) -> set[str]:
"""Return names of those columns that are nullable in the schema.

Returns:
Expand All @@ -308,9 +277,7 @@ def nullable_columns( # type: ignore
return set(cls.columns) - cls.non_nullable_columns

@property
def unique_columns( # type: ignore
cls: Type[ModelType], # pyright: ignore
) -> set[str]:
def unique_columns(cls: Type[ModelType]) -> set[str]:
"""Return columns with uniqueness constraint.

Returns:
Expand All @@ -335,9 +302,8 @@ def unique_columns( # type: ignore
return {column for column in cls.columns if infos[column].unique}

@property
def derived_columns(
cls: Type[ModelType], # type: ignore[misc]
) -> set[str]:
def derived_columns(cls: Type[ModelType]) -> set[str]:
"""Return set of columns which are derived from other columns."""
infos = cls.column_infos
return {
column for column in cls.columns if infos[column].derived_from is not None
Expand All @@ -362,7 +328,7 @@ def validate_schema(cls: Type[ModelType]):

@classmethod
def from_row(
cls: Type[ModelType], # type: ignore[misc]
cls: Type[ModelType],
row: Union["pd.DataFrame", pl.DataFrame],
validate: bool = True,
) -> ModelType:
Expand Down Expand Up @@ -406,7 +372,7 @@ def from_row(
dataframe = row
elif _PANDAS_AVAILABLE and isinstance(row, pd.DataFrame):
dataframe = pl.DataFrame._from_pandas(row)
elif _PANDAS_AVAILABLE and isinstance(row, pd.Series): # type: ignore[unreachable]
elif _PANDAS_AVAILABLE and isinstance(row, pd.Series):
return cls(**dict(row.items())) # type: ignore[unreachable]
else:
raise TypeError(f"{cls.__name__}.from_row not implemented for {type(row)}.")
Expand Down Expand Up @@ -490,6 +456,7 @@ def validate(
dataframe: Polars DataFrame to be validated.
columns: Optional list of columns to validate. If not provided, all columns
of the dataframe will be validated.
**kwargs: Additional keyword arguments to be passed to the validation

Raises:
------
Expand Down Expand Up @@ -688,10 +655,10 @@ def example_value( # noqa: C901
try:
props_o = cls.model_schema["$defs"][properties["title"]]["properties"]
return {f: cls.example_value(properties=props_o[f]) for f in props_o}
except AttributeError:
except AttributeError as err:
raise NotImplementedError(
"Nested example generation only supported for nested pt.Model classes."
)
) from err

elif field_type == "array":
return [cls.example_value(properties=properties["items"])]
Expand Down Expand Up @@ -1327,7 +1294,7 @@ def _derive_field(


def FieldCI(
column_info: CI, *args: Any, **kwargs: Any
column_info: Type[ColumnInfo], *args: Any, **kwargs: Any
) -> Any: # annotate with Any to make the downstream type annotations happy
ci = column_info(**kwargs)
for field in ci.model_fields_set:
Expand Down
Loading