From fed230b2dc5eae65662e7505dbe3d4732a3ec01b Mon Sep 17 00:00:00 2001 From: Thomas Aarholt Date: Sun, 18 Aug 2024 13:43:06 +0100 Subject: [PATCH 1/2] Modify metaclass to generate DataFrame directly --- src/patito/pydantic.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/patito/pydantic.py b/src/patito/pydantic.py index 05aa1d4..fcd412a 100644 --- a/src/patito/pydantic.py +++ b/src/patito/pydantic.py @@ -76,15 +76,19 @@ def __init__(cls, name: str, bases: tuple, clsdict: dict, **kwargs) -> None: """ super().__init__(name, bases, clsdict, **kwargs) - # Add a custom subclass of patito.DataFrame to the model class, - # where .set_model() has been implicitly set. - cls.DataFrame = DataFrame._construct_dataframe_model_class( - model=cls, # type: ignore + NewDataFrame = type( + f"{cls.__name__}DataFrame", + (DataFrame,), + {"model": cls}, ) - # Similarly for LazyFrame - cls.LazyFrame = LazyFrame._construct_lazyframe_model_class( - model=cls, # type: ignore + cls.DataFrame: type[DataFrame[cls]] = NewDataFrame # type: ignore + + NewLazyFrame = type( + f"{cls.__name__}LazyFrame", + (LazyFrame,), + {"model": cls}, ) + cls.LazyFrame: type[LazyFrame[cls]] = NewLazyFrame # type: ignore def __hash__(self) -> int: """Return hash of the model class.""" From 4cb4bd8401b3492e558f5d9e61c5b86763f0ad5f Mon Sep 17 00:00:00 2001 From: Thomas Aarholt Date: Sun, 18 Aug 2024 13:43:44 +0100 Subject: [PATCH 2/2] Modify frame code to support new metaclass --- src/patito/polars.py | 135 ++++++++++++++++++++++--------------------- 1 file changed, 68 insertions(+), 67 deletions(-) diff --git a/src/patito/polars.py b/src/patito/polars.py index de5643c..5d5364c 100644 --- a/src/patito/polars.py +++ b/src/patito/polars.py @@ -54,33 +54,63 @@ class LazyFrame(pl.LazyFrame, Generic[ModelType]): model: type[ModelType] - @classmethod - def _construct_lazyframe_model_class( - cls: type[LDF], model: type[ModelType] | None - ) -> type[LazyFrame[ModelType]]: - 
"""Return custom LazyFrame sub-class where LazyFrame.model is set. + def set_model(self, model: type[OtherModelType]) -> LazyFrame[OtherModelType]: + """Associate a given patito ``Model`` with the dataframe. - Can be used to construct a LazyFrame class where - DataFrame.set_model(model) is implicitly invoked at collection. + The model schema is used by methods that depend on a model being associated with + the given dataframe such as :ref:`DataFrame.validate() <DataFrame.validate>` + and :ref:`DataFrame.get() <DataFrame.get>`. + + ``LazyFrame(...).set_model(Model)`` is equivalent to ``Model.LazyFrame(...)``. Args: - model: A patito model which should be used to validate the final dataframe. - If None is provided, the regular LazyFrame class will be returned. + model (Model): Sub-class of ``patito.Model`` declaring the schema of the + dataframe. Returns: - A custom LazyFrame model class where LazyFrame.model has been correctly - "hard-coded" to the given model. + LazyFrame[Model]: Returns the same lazy frame, but with an attached model + that is required for certain model-specific dataframe methods to work. - """ - if model is None: - return cls + Examples: + >>> from typing_extensions import Literal + >>> import patito as pt + >>> import polars as pl + >>> class SchoolClass(pt.Model): + ... year: int = pt.Field(dtype=pl.UInt16) + ... letter: Literal["A", "B"] = pt.Field(dtype=pl.Categorical) + ... + >>> classes = pt.DataFrame( + ... {"year": [1, 1, 2, 2], "letter": list("ABAB")} + ...
).set_model(SchoolClass) + >>> classes + shape: (4, 2) + ┌──────┬────────┐ + │ year ┆ letter │ + │ --- ┆ --- │ + │ i64 ┆ str │ + ╞══════╪════════╡ + │ 1 ┆ A │ + │ 1 ┆ B │ + │ 2 ┆ A │ + │ 2 ┆ B │ + └──────┴────────┘ + >>> casted_classes = classes.cast() + >>> casted_classes + shape: (4, 2) + ┌──────┬────────┐ + │ year ┆ letter │ + │ --- ┆ --- │ + │ u16 ┆ cat │ + ╞══════╪════════╡ + │ 1 ┆ A │ + │ 1 ┆ B │ + │ 2 ┆ A │ + │ 2 ┆ B │ + └──────┴────────┘ + >>> casted_classes.validate() - new_class = type( - f"{model.__name__}LazyFrame", - (cls,), - {"model": model}, - ) - return new_class + """ + return model.LazyFrame._from_pyldf(self._ldf) # type: ignore def collect( self, @@ -93,12 +123,11 @@ def collect( parameters. """ background = kwargs.pop("background", False) - df = super().collect(*args, background=background, **kwargs) + df: pl.DataFrame = super().collect(*args, background=background, **kwargs) + df = DataFrame(df) if getattr(self, "model", False): - cls = DataFrame._construct_dataframe_model_class(model=self.model) - else: - cls = DataFrame - return cls._from_pydf(df._df) + df = df.set_model(self.model) + return df def derive(self: LDF, columns: list[str] | None = None) -> LDF: """Populate columns which have ``pt.Field(derived_from=...)`` definitions. 
@@ -307,7 +336,10 @@ def cast( @classmethod def from_existing(cls: type[LDF], lf: pl.LazyFrame) -> LDF: """Construct a patito.DataFrame object from an existing polars.DataFrame object.""" - return cls.model.LazyFrame._from_pyldf(lf._ldf).cast() + if getattr(cls, "model", False): + return cls.model.LazyFrame._from_pyldf(lf._ldf).cast() # type: ignore + + return LazyFrame._from_pyldf(lf._ldf) # type: ignore class DataFrame(pl.DataFrame, Generic[ModelType]): @@ -341,30 +373,6 @@ class DataFrame(pl.DataFrame, Generic[ModelType]): model: type[ModelType] - @classmethod - def _construct_dataframe_model_class( - cls: type[DF], model: type[OtherModelType] - ) -> type[DataFrame[OtherModelType]]: - """Return custom DataFrame sub-class where DataFrame.model is set. - - Can be used to construct a DataFrame class where - DataFrame.set_model(model) is implicitly invoked at instantiation. - - Args: - model: A patito model which should be used to validate the dataframe. - - Returns: - A custom DataFrame model class where DataFrame._model has been correctly - "hard-coded" to the given model. - - """ - new_class = type( - f"{model.model_json_schema()['title']}DataFrame", - (cls,), - {"model": model}, - ) - return new_class - def lazy(self: DataFrame[ModelType]) -> LazyFrame[ModelType]: """Convert DataFrame into LazyFrame. @@ -374,15 +382,12 @@ def lazy(self: DataFrame[ModelType]) -> LazyFrame[ModelType]: A new LazyFrame object.
""" - lazyframe_class: LazyFrame[ModelType] = ( - LazyFrame._construct_lazyframe_model_class( - model=getattr(self, "model", None) - ) - ) # type: ignore - ldf = lazyframe_class._from_pyldf(super().lazy()._ldf) - return ldf + if getattr(self, "model", False): + return self.model.LazyFrame._from_pyldf(super().lazy()._ldf) # type: ignore + + return LazyFrame._from_pyldf(super().lazy()._ldf) # type: ignore - def set_model(self, model): # type: ignore[no-untyped-def] # noqa: ANN001, ANN201 + def set_model(self, model: type[OtherModelType]) -> DataFrame[OtherModelType]: """Associate a given patito ``Model`` with the dataframe. The model schema is used by methods that depend on a model being associated with @@ -438,11 +443,7 @@ def set_model(self, model): # type: ignore[no-untyped-def] # noqa: ANN001, ANN2 >>> casted_classes.validate() """ - cls = self._construct_dataframe_model_class(model=model) - return cast( - DataFrame[model], - cls._from_pydf(self._df), - ) + return model.DataFrame(self._df) def unalias(self: DF) -> DF: """Un-aliases column names using information from pydantic validation_alias. @@ -503,7 +504,6 @@ def cast( def drop( self: DF, columns: str | Collection[str] | None = None, - *more_columns: str, ) -> DF: """Drop one or more columns from the dataframe. @@ -515,7 +515,6 @@ def drop( columns: A single column string name, or list of strings, indicating which columns to drop. If not specified, all columns *not* specified by the associated dataframe model will be dropped. - more_columns: Additional named columns to drop. Returns: DataFrame[Model]: New dataframe without the specified columns. @@ -538,7 +537,9 @@ def drop( """ if columns is not None: - return self._from_pydf(super().drop(columns)._df) + # I get a single null row if I try to use super() here, so go via + # pl.DataFrame instead. 
+ return self._from_pydf(pl.DataFrame(self._df).drop(columns)._df) else: return self.drop(list(set(self.columns) - set(self.model.columns))) @@ -706,7 +707,7 @@ def fill_null( # else pl.lit(default_value, self.model.dtypes[column]).alias(column) for column, default_value in self.model.defaults.items() ] - ).set_model(self.model) + ).set_model(self.model) # type: ignore def get(self, predicate: pl.Expr | None = None) -> ModelType: """Fetch the single row that matches the given polars predicate.