Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Modify metaclass to allow DataFrame[Foo] type propagation #99

Merged
merged 2 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 68 additions & 67 deletions src/patito/polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,33 +54,63 @@ class LazyFrame(pl.LazyFrame, Generic[ModelType]):

model: type[ModelType]

@classmethod
def _construct_lazyframe_model_class(
cls: type[LDF], model: type[ModelType] | None
) -> type[LazyFrame[ModelType]]:
"""Return custom LazyFrame sub-class where LazyFrame.model is set.
def set_model(self, model: type[OtherModelType]) -> LazyFrame[OtherModelType]:
"""Associate a given patito ``Model`` with the dataframe.

Can be used to construct a LazyFrame class where
DataFrame.set_model(model) is implicitly invoked at collection.
The model schema is used by methods that depend on a model being associated with
the given dataframe such as :ref:`DataFrame.validate() <DataFrame.validate>`
and :ref:`DataFrame.get() <DataFrame.get>`.

``DataFrame(...).set_model(Model)`` is equivalent to ``Model.DataFrame(...)``.

Args:
model: A patito model which should be used to validate the final dataframe.
If None is provided, the regular LazyFrame class will be returned.
model (Model): Sub-class of ``patito.Model`` declaring the schema of the
dataframe.

Returns:
A custom LazyFrame model class where LazyFrame.model has been correctly
"hard-coded" to the given model.
DataFrame[Model]: Returns the same dataframe, but with an attached model
that is required for certain model-specific dataframe methods to work.

"""
if model is None:
return cls
Examples:
>>> from typing_extensions import Literal
>>> import patito as pt
>>> import polars as pl
>>> class SchoolClass(pt.Model):
... year: int = pt.Field(dtype=pl.UInt16)
... letter: Literal["A", "B"] = pt.Field(dtype=pl.Categorical)
...
>>> classes = pt.DataFrame(
... {"year": [1, 1, 2, 2], "letter": list("ABAB")}
... ).set_model(SchoolClass)
>>> classes
shape: (4, 2)
┌──────┬────────┐
│ year ┆ letter │
│ --- ┆ --- │
│ i64 ┆ str │
╞══════╪════════╡
│ 1 ┆ A │
│ 1 ┆ B │
│ 2 ┆ A │
│ 2 ┆ B │
└──────┴────────┘
>>> casted_classes = classes.cast()
>>> casted_classes
shape: (4, 2)
┌──────┬────────┐
│ year ┆ letter │
│ --- ┆ --- │
│ u16 ┆ cat │
╞══════╪════════╡
│ 1 ┆ A │
│ 1 ┆ B │
│ 2 ┆ A │
│ 2 ┆ B │
└──────┴────────┘
>>> casted_classes.validate()

new_class = type(
f"{model.__name__}LazyFrame",
(cls,),
{"model": model},
)
return new_class
"""
return model.LazyFrame._from_pyldf(self._ldf) # type: ignore

def collect(
self,
Expand All @@ -93,12 +123,11 @@ def collect(
parameters.
"""
background = kwargs.pop("background", False)
df = super().collect(*args, background=background, **kwargs)
df: pl.DataFrame = super().collect(*args, background=background, **kwargs)
df = DataFrame(df)
if getattr(self, "model", False):
cls = DataFrame._construct_dataframe_model_class(model=self.model)
else:
cls = DataFrame
return cls._from_pydf(df._df)
df = df.set_model(self.model)
return df

def derive(self: LDF, columns: list[str] | None = None) -> LDF:
"""Populate columns which have ``pt.Field(derived_from=...)`` definitions.
Expand Down Expand Up @@ -307,7 +336,10 @@ def cast(
@classmethod
def from_existing(cls: type[LDF], lf: pl.LazyFrame) -> LDF:
"""Construct a patito.LazyFrame object from an existing polars.LazyFrame object."""
return cls.model.LazyFrame._from_pyldf(lf._ldf).cast()
if getattr(cls, "model", False):
return cls.model.LazyFrame._from_pyldf(super().lazy()._ldf) # type: ignore

return LazyFrame._from_pyldf(lf._ldf) # type: ignore


class DataFrame(pl.DataFrame, Generic[ModelType]):
Expand Down Expand Up @@ -341,30 +373,6 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):

model: type[ModelType]

@classmethod
def _construct_dataframe_model_class(
cls: type[DF], model: type[OtherModelType]
) -> type[DataFrame[OtherModelType]]:
"""Return custom DataFrame sub-class where DataFrame.model is set.

Can be used to construct a DataFrame class where
DataFrame.set_model(model) is implicitly invoked at instantiation.

Args:
model: A patito model which should be used to validate the dataframe.

Returns:
A custom DataFrame model class where DataFrame._model has been correctly
"hard-coded" to the given model.

"""
new_class = type(
f"{model.model_json_schema()['title']}DataFrame",
(cls,),
{"model": model},
)
return new_class

def lazy(self: DataFrame[ModelType]) -> LazyFrame[ModelType]:
"""Convert DataFrame into LazyFrame.

Expand All @@ -374,15 +382,12 @@ def lazy(self: DataFrame[ModelType]) -> LazyFrame[ModelType]:
A new LazyFrame object.

"""
lazyframe_class: LazyFrame[ModelType] = (
LazyFrame._construct_lazyframe_model_class(
model=getattr(self, "model", None)
)
) # type: ignore
ldf = lazyframe_class._from_pyldf(super().lazy()._ldf)
return ldf
if getattr(self, "model", False):
return self.model.LazyFrame._from_pyldf(super().lazy()._ldf) # type: ignore

return LazyFrame._from_pyldf(super().lazy()._ldf) # type: ignore

def set_model(self, model): # type: ignore[no-untyped-def] # noqa: ANN001, ANN201
def set_model(self, model: type[OtherModelType]) -> DataFrame[OtherModelType]:
"""Associate a given patito ``Model`` with the dataframe.

The model schema is used by methods that depend on a model being associated with
Expand Down Expand Up @@ -438,11 +443,7 @@ def set_model(self, model): # type: ignore[no-untyped-def] # noqa: ANN001, ANN2
>>> casted_classes.validate()

"""
cls = self._construct_dataframe_model_class(model=model)
return cast(
DataFrame[model],
cls._from_pydf(self._df),
)
return model.DataFrame(self._df)

def unalias(self: DF) -> DF:
"""Un-aliases column names using information from pydantic validation_alias.
Expand Down Expand Up @@ -503,7 +504,6 @@ def cast(
def drop(
self: DF,
columns: str | Collection[str] | None = None,
*more_columns: str,
) -> DF:
"""Drop one or more columns from the dataframe.

Expand All @@ -515,7 +515,6 @@ def drop(
columns: A single column string name, or list of strings, indicating
which columns to drop. If not specified, all columns *not*
specified by the associated dataframe model will be dropped.
more_columns: Additional named columns to drop.

Returns:
DataFrame[Model]: New dataframe without the specified columns.
Expand All @@ -538,7 +537,9 @@ def drop(

"""
if columns is not None:
return self._from_pydf(super().drop(columns)._df)
# I get a single null row if I try to use super() here, so go via
# pl.DataFrame instead.
return self._from_pydf(pl.DataFrame(self._df).drop(columns)._df)
else:
return self.drop(list(set(self.columns) - set(self.model.columns)))

Expand Down Expand Up @@ -706,7 +707,7 @@ def fill_null(
# else pl.lit(default_value, self.model.dtypes[column]).alias(column)
for column, default_value in self.model.defaults.items()
]
).set_model(self.model)
).set_model(self.model) # type: ignore

def get(self, predicate: pl.Expr | None = None) -> ModelType:
"""Fetch the single row that matches the given polars predicate.
Expand Down
18 changes: 11 additions & 7 deletions src/patito/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,19 @@ def __init__(cls, name: str, bases: tuple, clsdict: dict, **kwargs) -> None:

"""
super().__init__(name, bases, clsdict, **kwargs)
# Add a custom subclass of patito.DataFrame to the model class,
# where .set_model() has been implicitly set.
cls.DataFrame = DataFrame._construct_dataframe_model_class(
model=cls, # type: ignore
NewDataFrame = type(
f"{cls.__name__}DataFrame",
(DataFrame,),
{"model": cls},
)
# Similarly for LazyFrame
cls.LazyFrame = LazyFrame._construct_lazyframe_model_class(
model=cls, # type: ignore
cls.DataFrame: type[DataFrame[cls]] = NewDataFrame # type: ignore

NewLazyFrame = type(
f"{cls.__name__}LazyFrame",
(LazyFrame,),
{"model": cls},
)
cls.LazyFrame: type[LazyFrame[cls]] = NewLazyFrame # type: ignore

def __hash__(self) -> int:
"""Return hash of the model class."""
Expand Down
Loading