From fb97947c2b6c141d8e84a4c874cc961193d52f0d Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Wed, 21 Feb 2024 13:59:34 +0000 Subject: [PATCH] we can make typing...better? --- narwhals/pandas_like/dataframe.py | 45 ++++++++++++++++--------------- narwhals/spec/__init__.py | 4 +-- 2 files changed, 25 insertions(+), 24 deletions(-) diff --git a/narwhals/pandas_like/dataframe.py b/narwhals/pandas_like/dataframe.py index c9c9260dc..c01f974b3 100644 --- a/narwhals/pandas_like/dataframe.py +++ b/narwhals/pandas_like/dataframe.py @@ -43,6 +43,9 @@ def __init__( def columns(self) -> list[str]: return self.dataframe.columns.tolist() + def _dispatch_to_lazy(self, method: str, *args: Any, **kwargs: Any) -> Self: + return getattr(self.lazy(), method)(*args, **kwargs).collect() + def __repr__(self) -> str: # pragma: no cover header = f" Standard DataFrame (api_version={self.api_version}) " length = len(header) @@ -100,42 +103,40 @@ def select( self, *exprs: IntoExpr | Iterable[IntoExpr], **named_exprs: IntoExpr, - ) -> DataFrameT: - return self.lazy().select(*exprs, **named_exprs).collect() + ) -> Self: + return self._dispatch_to_lazy("select", *exprs, **named_exprs) def filter( self, *predicates: IntoExpr | Iterable[IntoExpr], - ) -> DataFrameT: - return self.lazy().filter(*predicates).collect() + ) -> Self: + return self._dispatch_to_lazy("filter", *predicates) def with_columns( self, *exprs: IntoExpr | Iterable[IntoExpr], **named_exprs: IntoExpr, - ) -> DataFrameT: - return self.lazy().with_columns(*exprs, **named_exprs).collect() + ) -> Self: + return self._dispatch_to_lazy("with_columns", *exprs, **named_exprs) def sort( self, by: str | Iterable[str], *more_by: str, descending: bool | Iterable[bool] = False, - ) -> DataFrameT: - return self.lazy().sort(by, *more_by, descending=descending).collect() + ) -> Self: + return self._dispatch_to_lazy("sort", by, *more_by, descending=descending) def join( self, - other: DataFrameT, + other: Self, *, how: Literal["left", "inner", "outer"] = "inner", left_on: str | list[str], right_on: str | list[str], - ) -> DataFrameT: - return ( - self.lazy() - .join(other.lazy(), how=how, left_on=left_on, right_on=right_on) - .collect() + ) -> Self: + return self._dispatch_to_lazy( + "join", other.lazy(), how=how, left_on=left_on, right_on=right_on ) def lazy(self) -> LazyFrame: @@ -145,14 +146,14 @@ def lazy(self) -> LazyFrame: implementation=self._implementation, ) - def head(self, n: int) -> DataFrameT: - return self.lazy().head(n).collect() + def head(self, n: int) -> Self: + return self._dispatch_to_lazy("head", n) - def unique(self, subset: list[str]) -> DataFrameT: - return self.lazy().unique(subset).collect() + def unique(self, subset: list[str]) -> Self: + return self._dispatch_to_lazy("unique", subset) - def rename(self, mapping: dict[str, str]) -> DataFrameT: - return self.lazy().rename(mapping).collect() + def rename(self, mapping: dict[str, str]) -> Self: + return self._dispatch_to_lazy("rename", mapping) def to_numpy(self) -> Any: return self.dataframe.to_numpy() @@ -301,7 +302,7 @@ def sort( # Other def join( self, - other: LazyFrameT, + other: Self, *, how: Literal["left", "inner", "outer"] = "inner", left_on: str | list[str], @@ -332,7 +333,7 @@ def join( ) # Conversion - def collect(self) -> DataFrameT: + def collect(self) -> DataFrame: return DataFrame( self.dataframe, api_version=self.api_version, diff --git a/narwhals/spec/__init__.py b/narwhals/spec/__init__.py index ab157ea55..72e57066f 100644 --- a/narwhals/spec/__init__.py +++ b/narwhals/spec/__init__.py @@ -189,7 +189,7 @@ def lazy(self) -> LazyFrame: def join( self, - other: DataFrame, + other: Self, *, how: Literal["inner"] = "inner", left_on: str | list[str], @@ -255,7 +255,7 @@ def group_by(self, *keys: str | Iterable[str]) -> LazyGroupBy: def join( self, - other: LazyFrame, + other: Self, *, how: Literal["left", "inner", "outer"] = "inner", left_on: str | list[str],