From 279b8249b7140d666a5440135e8260ffa5225e51 Mon Sep 17 00:00:00 2001
From: Thomas Aarholt
Date: Tue, 27 Feb 2024 15:52:59 +0100
Subject: [PATCH 1/6] Add missing docstring

---
 src/patito/polars.py | 55 +++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 49 insertions(+), 6 deletions(-)

diff --git a/src/patito/polars.py b/src/patito/polars.py
index 108e6b3..7cea599 100644
--- a/src/patito/polars.py
+++ b/src/patito/polars.py
@@ -1,4 +1,5 @@
 """Logic related to the wrapping of the polars data frame library."""
+
 from __future__ import annotations
 
 from typing import (
@@ -90,6 +91,48 @@ def collect(
         return cls._from_pydf(df._df)
 
     def derive(self: LDF, columns: list[str] | None = None) -> LDF:
+        """Populate columns which have ``pt.Field(derived_from=...)`` definitions.
+
+        If a column field on the data frame model has ``patito.Field(derived_from=...)``
+        specified, the given value will be used to define the column. If
+        ``derived_from`` is set to a string, the column will be derived from the given
+        column name. Alternatively, an arbitrary polars expression can be given, the
+        result of which will be used to populate the column values.
+
+        Args:
+        ----
+            columns: Optionally, a list of column names to derive. If not provided, all
+                columns are used.
+
+        Returns:
+        -------
+            DataFrame[Model]: A new dataframe where all derivable columns are provided.
+
+        Raises:
+        ------
+            TypeError: If the ``derived_from`` parameter of ``patito.Field`` is given
+                as something other than a string or polars expression.
+
+        Examples:
+        --------
+        >>> import patito as pt
+        >>> import polars as pl
+        >>> class Foo(pt.Model):
+        ...     bar: int = pt.Field(derived_from="foo")
+        ...     double_bar: int = pt.Field(derived_from=2 * pl.col("bar"))
+        ...
+        >>> Foo.DataFrame({"foo": [1, 2]}).derive()
+        shape: (2, 3)
+        ┌─────┬────────────┬─────┐
+        │ bar ┆ double_bar ┆ foo │
+        │ --- ┆ ---        ┆ --- │
+        │ i64 ┆ i64        ┆ i64 │
+        ╞═════╪════════════╪═════╡
+        │ 1   ┆ 2          ┆ 1   │
+        │ 2   ┆ 4          ┆ 2   │
+        └─────┴────────────┴─────┘
+
+        """
         derived_columns = []
         props = self.model._schema_properties()
         original_columns = set(self.columns)
@@ -620,12 +663,12 @@ def fill_null(
             )
         return self.with_columns(
             [
-                pl.col(column).fill_null(
-                    pl.lit(default_value, self.model.dtypes[column])
-                )
-                if column in self.columns
-                else pl.Series(
-                    column, [default_value], self.model.dtypes[column]
+                (
+                    pl.col(column).fill_null(
+                        pl.lit(default_value, self.model.dtypes[column])
+                    )
+                    if column in self.columns
+                    else pl.Series(column, [default_value], self.model.dtypes[column])
                 )  # NOTE: hack to get around polars bug https://github.com/pola-rs/polars/issues/13602
                 # else pl.lit(default_value, self.model.dtypes[column]).alias(column)
                 for column, default_value in self.model.defaults.items()

From 116093eb2f60593fb59f965cc278e3c2f7d84c1a Mon Sep 17 00:00:00 2001
From: Thomas Aarholt
Date: Tue, 27 Feb 2024 15:55:22 +0100
Subject: [PATCH 2/6] fix type error by explicitly setting background=False unless set in kwargs

---
 src/patito/polars.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/patito/polars.py b/src/patito/polars.py
index 7cea599..ddba9de 100644
--- a/src/patito/polars.py
+++ b/src/patito/polars.py
@@ -83,7 +83,8 @@ def collect(
         See documentation of polars.DataFrame.collect for full description of
         parameters.
""" - df = super().collect(*args, **kwargs) + background = kwargs.pop("background", False) + df = super().collect(*args, background=background, **kwargs) if getattr(self, "model", False): cls = DataFrame._construct_dataframe_model_class(model=self.model) else: From b3cf8adbfbcb6a2f8110351a11da0d34b95c1677 Mon Sep 17 00:00:00 2001 From: Thomas Aarholt Date: Wed, 28 Feb 2024 15:49:21 +0100 Subject: [PATCH 3/6] add docstrings and fix error raising warning --- src/patito/polars.py | 57 +++++++++++++++++++++++++++++++++++++++++- src/patito/pydantic.py | 11 +++++--- 2 files changed, 64 insertions(+), 4 deletions(-) diff --git a/src/patito/polars.py b/src/patito/polars.py index ddba9de..f752b18 100644 --- a/src/patito/polars.py +++ b/src/patito/polars.py @@ -183,6 +183,17 @@ def _derive_column( return df, derived_columns def unalias(self: LDF) -> LDF: + """Un-aliases column names using information from pydantic validation_alias. + + In order of preference - model field name then validation_aliases in order of occurrence + + limitation - AliasChoice validation type only supports selecting a single element of an array + + Returns + ------- + DataFrame[Model]: A dataframe with columns normalized to model names. + + """ if not any(fi.validation_alias for fi in self.model.model_fields.values()): return self exprs = [] @@ -233,6 +244,47 @@ def to_expr(va: str | AliasPath | AliasChoices) -> Optional[pl.Expr]: def cast( self: LDF, strict: bool = False, columns: Optional[Sequence[str]] = None ) -> LDF: + """Cast columns to `dtypes` specified by the associated Patito model. + + Args: + ---- + strict: If set to ``False``, columns which are technically compliant with + the specified field type, will not be casted. For example, a column + annotated with ``int`` is technically compliant with ``pl.UInt8``, even + if ``pl.Int64`` is the default dtype associated with ``int``-annotated + fields. If ``strict`` is set to ``True``, the resulting dtypes will + be forced to the default dtype associated with each python type. + columns: Optionally, a list of column names to cast. If not provided, all + columns are casted. + + Returns: + ------- + LazyFrame[Model]: A dataframe with columns casted to the correct dtypes. + + Examples: + -------- + Create a simple model: + + >>> import patito as pt + >>> import polars as pl + >>> class Product(pt.Model): + ... name: str + ... cent_price: int = pt.Field(dtype=pl.UInt16) + ... + + Now we can use this model to cast some simple data: + + >>> Product.LazyFrame({"name": ["apple"], "cent_price": ["8"]}).cast().collect() + shape: (1, 2) + ┌───────┬────────────┐ + │ name ┆ cent_price │ + │ --- ┆ --- │ + │ str ┆ u16 │ + ╞═══════╪════════════╡ + │ apple ┆ 8 │ + └───────┴────────────┘ + + """ properties = self.model._schema_properties() valid_dtypes = self.model.valid_dtypes default_dtypes = self.model.dtypes @@ -251,7 +303,7 @@ def cast( @classmethod def from_existing(cls: Type[LDF], lf: pl.LazyFrame) -> LDF: - """Constructs a patito.DataFrame object from an existing polars.DataFrame object""" + """Construct a patito.DataFrame object from an existing polars.DataFrame object.""" return cls.model.LazyFrame._from_pyldf(lf._ldf).cast() @@ -422,6 +474,8 @@ def cast( if ``pl.Int64`` is the default dtype associated with ``int``-annotated fields. If ``strict`` is set to ``True``, the resulting dtypes will be forced to the default dtype associated with each python type. + columns: Optionally, a list of column names to cast. If not provided, all + columns are casted. 
 
         Returns:
         -------
@@ -784,6 +838,7 @@ def _pydantic_model(self) -> Type[Model]:
         )
 
     def as_polars(self) -> pl.DataFrame:
+        """Convert patito dataframe to polars dataframe."""
         return pl.DataFrame._from_pydf(self._df)
 
     @classmethod
diff --git a/src/patito/pydantic.py b/src/patito/pydantic.py
index 170891a..c44c3f3 100644
--- a/src/patito/pydantic.py
+++ b/src/patito/pydantic.py
@@ -1,4 +1,5 @@
 """Logic related to wrapping logic around the pydantic library."""
+
 from __future__ import annotations
 
 import itertools
@@ -101,10 +102,12 @@ def __init__(cls, name: str, bases: tuple, clsdict: dict, **kwargs) -> None:
         )
 
     def __hash__(self) -> int:
+        """Return hash of the model class."""
         return super().__hash__()
 
     @property
     def column_infos(cls: Type[ModelType]) -> Mapping[str, ColumnInfo]:
+        """Return column information for the model."""
         return column_infos_for_model(cls)
 
     @property
@@ -203,7 +206,7 @@ def valid_dtypes(  # type: ignore
         ...
         >>> pprint(MyModel.valid_dtypes)
         {'bool_column': DataTypeGroup({Boolean}),
-         'float_column': DataTypeGroup({Float32, Float64}),
+         'float_column': DataTypeGroup({Float64, Float32}),
          'int_column': DataTypeGroup({Int8,
                                       Int16,
                                       Int32,
@@ -338,6 +341,7 @@ def unique_columns(  # type: ignore
     def derived_columns(
        cls: Type[ModelType],  # type: ignore[misc]
     ) -> set[str]:
+        """Return set of columns which are derived from other columns."""
         infos = cls.column_infos
         return {
             column for column in cls.columns if infos[column].derived_from is not None
@@ -490,6 +494,7 @@ def validate(
             dataframe: Polars DataFrame to be validated.
             columns: Optional list of columns to validate. If not provided,
                 all columns of the dataframe will be validated.
+            **kwargs: Additional keyword arguments to be passed to the validation
 
         Raises:
        ------
@@ -688,10 +693,10 @@ def example_value(  # noqa: C901
             try:
                 props_o = cls.model_schema["$defs"][properties["title"]]["properties"]
                 return {f: cls.example_value(properties=props_o[f]) for f in props_o}
-            except AttributeError:
+            except AttributeError as err:
                 raise NotImplementedError(
                     "Nested example generation only supported for nested pt.Model classes."
-                )
+                ) from err
 
         elif field_type == "array":
             return [cls.example_value(properties=properties["items"])]

From efc447217b9551572ad8ca7b19057fe56157abe5 Mon Sep 17 00:00:00 2001
From: Thomas Aarholt
Date: Sun, 3 Mar 2024 12:26:18 +0100
Subject: [PATCH 4/6] Remove unneeded type ignores

---
 src/patito/pydantic.py | 20 +++++++-------------
 1 file changed, 7 insertions(+), 13 deletions(-)

diff --git a/src/patito/pydantic.py b/src/patito/pydantic.py
index c44c3f3..30da384 100644
--- a/src/patito/pydantic.py
+++ b/src/patito/pydantic.py
@@ -69,7 +69,7 @@
 
 
 class ModelMetaclass(PydanticModelMetaclass, Generic[CI]):
-    """Metclass used by patito.Model.
+    """Metaclass used by patito.Model.
 
     Responsible for setting any relevant model-dependent class properties.
     """
@@ -129,7 +129,7 @@ def model_schema(cls: Type[ModelType]) -> Mapping[str, Mapping[str, Any]]:
         return schema_for_model(cls)
 
     @property
-    def columns(cls: Type[ModelType]) -> List[str]:  # type: ignore
+    def columns(cls: Type[ModelType]) -> List[str]:
         """Return the name of the dataframe columns specified by the fields of the model.
         Returns:
         -------
@@ -150,9 +150,7 @@ def columns(cls: Type[ModelType]) -> List[str]:  # type: ignore
         return list(cls.model_fields.keys())
 
     @property
-    def dtypes(  # type: ignore
-        cls: Type[ModelType],  # pyright: ignore
-    ) -> dict[str, DataTypeClass | DataType]:
+    def dtypes(cls: Type[ModelType]) -> dict[str, DataTypeClass | DataType]:
         """Return the polars dtypes of the dataframe.
 
         Unless Field(dtype=...) is specified, the highest signed column dtype
@@ -177,8 +175,8 @@ def dtypes(  # type: ignore
         return default_dtypes_for_model(cls)
 
     @property
-    def valid_dtypes(  # type: ignore
-        cls: Type[ModelType],  # pyright: ignore
+    def valid_dtypes(
+        cls: Type[ModelType],
     ) -> Mapping[str, FrozenSet[DataTypeClass | DataType]]:
         """Return a list of polars dtypes which Patito considers valid for each field.
 
@@ -223,9 +221,7 @@ def valid_dtypes(  # type: ignore
         return valid_dtypes_for_model(cls)
 
     @property
-    def defaults(  # type: ignore
-        cls: Type[ModelType],  # pyright: ignore
-    ) -> dict[str, Any]:
+    def defaults(cls: Type[ModelType]) -> dict[str, Any]:
         """Return default field values specified on the model.
         Returns:
         -------
@@ -252,9 +248,7 @@ def defaults(  # type: ignore
         }
 
     @property
-    def non_nullable_columns(  # type: ignore
-        cls: Type[ModelType],  # pyright: ignore
-    ) -> set[str]:
+    def non_nullable_columns(cls: Type[ModelType]) -> set[str]:
         """Return names of those columns that are non-nullable in the schema.
 
         Returns:

From 9eff95b5f18e1b70ed667051d6bb770e004f6a79 Mon Sep 17 00:00:00 2001
From: Thomas Aarholt
Date: Sun, 3 Mar 2024 12:49:31 +0100
Subject: [PATCH 5/6] Some more unnecessary type: ignore, and follow pyright recommendation

---
 src/patito/_pydantic/dtypes/dtypes.py |  8 +++++---
 src/patito/pydantic.py                | 18 ++++++------------
 2 files changed, 11 insertions(+), 15 deletions(-)

diff --git a/src/patito/_pydantic/dtypes/dtypes.py b/src/patito/_pydantic/dtypes/dtypes.py
index 1348fef..d42d43c 100644
--- a/src/patito/_pydantic/dtypes/dtypes.py
+++ b/src/patito/_pydantic/dtypes/dtypes.py
@@ -27,9 +27,11 @@ def valid_dtypes_for_model(
     cls: Type[ModelType],
 ) -> Mapping[str, FrozenSet[DataTypeClass]]:
     return {
-        column: DtypeResolver(cls.model_fields[column].annotation).valid_polars_dtypes()
-        if cls.column_infos[column].dtype is None
-        else DataTypeGroup([cls.dtypes[column]], match_base_type=False)
+        column: (
+            DtypeResolver(cls.model_fields[column].annotation).valid_polars_dtypes()
+            if cls.column_infos[column].dtype is None
+            else DataTypeGroup([cls.dtypes[column]], match_base_type=False)
+        )
         for column in cls.columns
     }
 
diff --git a/src/patito/pydantic.py b/src/patito/pydantic.py
index 30da384..e2e610f 100644
--- a/src/patito/pydantic.py
+++ b/src/patito/pydantic.py
@@ -279,9 +279,7 @@ def non_nullable_columns(  # type: ignore
         )
 
     @property
-    def nullable_columns(  # type: ignore
-        cls: Type[ModelType],  # pyright: ignore
-    ) -> set[str]:
+    def nullable_columns(cls: Type[ModelType]) -> set[str]:
         """Return names of those columns that are nullable in the schema.
 
         Returns:
@@ -305,9 +303,7 @@ def nullable_columns(  # type: ignore
         return set(cls.columns) - cls.non_nullable_columns
 
     @property
-    def unique_columns(  # type: ignore
-        cls: Type[ModelType],  # pyright: ignore
-    ) -> set[str]:
+    def unique_columns(cls: Type[ModelType]) -> set[str]:
         """Return columns with uniqueness constraint.
 
         Returns:
@@ -332,9 +328,7 @@ def unique_columns(  # type: ignore
         return {column for column in cls.columns if infos[column].unique}
 
     @property
-    def derived_columns(
-        cls: Type[ModelType],  # type: ignore[misc]
-    ) -> set[str]:
+    def derived_columns(cls: Type[ModelType]) -> set[str]:
         """Return set of columns which are derived from other columns."""
         infos = cls.column_infos
         return {
@@ -360,7 +354,7 @@ def validate_schema(cls: Type[ModelType]):
 
     @classmethod
     def from_row(
-        cls: Type[ModelType],  # type: ignore[misc]
+        cls: Type[ModelType],
         row: Union["pd.DataFrame", pl.DataFrame],
         validate: bool = True,
     ) -> ModelType:
@@ -404,7 +398,7 @@ def from_row(
             dataframe = row
         elif _PANDAS_AVAILABLE and isinstance(row, pd.DataFrame):
             dataframe = pl.DataFrame._from_pandas(row)
-        elif _PANDAS_AVAILABLE and isinstance(row, pd.Series):  # type: ignore[unreachable]
+        elif _PANDAS_AVAILABLE and isinstance(row, pd.Series):
             return cls(**dict(row.items()))  # type: ignore[unreachable]
         else:
             raise TypeError(f"{cls.__name__}.from_row not implemented for {type(row)}.")
@@ -1326,7 +1320,7 @@ def _derive_field(
 
 
 def FieldCI(
-    column_info: CI, *args: Any, **kwargs: Any
+    column_info: Type[ColumnInfo], *args: Any, **kwargs: Any
 ) -> Any:  # annotate with Any to make the downstream type annotations happy
     ci = column_info(**kwargs)
     for field in ci.model_fields_set:

From 03bff6bcdc17914082b300d0aa9994fe62f7bb38 Mon Sep 17 00:00:00 2001
From: Thomas Aarholt
Date: Sun, 3 Mar 2024 12:50:06 +0100
Subject: [PATCH 6/6] Remove valid_dtypes docstring example since it is giving a different sort sometimes on py39

---
 src/patito/pydantic.py | 30 ++----------------------------
 1 file changed, 2 insertions(+), 28 deletions(-)

diff --git a/src/patito/pydantic.py b/src/patito/pydantic.py
index e2e610f..c59a8f6 100644
--- a/src/patito/pydantic.py
+++ b/src/patito/pydantic.py
@@ -182,41 +182,15 @@ def valid_dtypes(
 
         The first item of each list is the default dtype chosen by Patito.
 
-        Returns:
+        Returns
         -------
         A dictionary mapping each column string name to a list of valid dtypes.
 
-        Raises:
+        Raises
         ------
         NotImplementedError: If one or more model fields are annotated with types
             not compatible with polars.
 
-        Example:
-        -------
-        >>> from pprint import pprint
-        >>> import patito as pt
-
-        >>> class MyModel(pt.Model):
-        ...     bool_column: bool
-        ...     str_column: str
-        ...     int_column: int
-        ...     float_column: float
-        ...
-        >>> pprint(MyModel.valid_dtypes)
-        {'bool_column': DataTypeGroup({Boolean}),
-         'float_column': DataTypeGroup({Float64, Float32}),
-         'int_column': DataTypeGroup({Int8,
-                                      Int16,
-                                      Int32,
-                                      Int64,
-                                      UInt8,
-                                      UInt16,
-                                      UInt32,
-                                      UInt64,
-                                      Float32,
-                                      Float64}),
-         'str_column': DataTypeGroup({String})}
-
         """
         return valid_dtypes_for_model(cls)
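
Editor's note: the keyword-popping pattern used by PATCH 2/6 to satisfy the overloaded `collect` signature can be illustrated in isolation. The snippet below is a minimal, self-contained sketch; the `collect` function in it is a hypothetical stand-in and not patito's actual method.

    # Sketch of the PATCH 2/6 pattern: pop an optional keyword with an explicit
    # default so it can always be forwarded as `background=...`, which keeps type
    # checkers happy about overloads that key on that parameter.
    # Hypothetical stand-in function, not patito's own `collect`.
    from typing import Any


    def collect(*args: Any, **kwargs: Any) -> str:
        # Default to background=False unless the caller explicitly supplied it.
        background = kwargs.pop("background", False)
        return f"background={background}, remaining kwargs={kwargs}"


    print(collect())                 # background=False, remaining kwargs={}
    print(collect(background=True))  # background=True, remaining kwargs={}
    print(collect(streaming=True))   # background=False, remaining kwargs={'streaming': True}

The remaining kwargs are still forwarded untouched, so callers that rely on other `polars.LazyFrame.collect` options are unaffected by the change.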