From 2139f934a3997e8969604a24e6eaf682bd030e30 Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Sun, 29 Sep 2024 21:23:17 +0100 Subject: [PATCH] feat: add dtypes to stable api (#1087) --- docs/how_it_works.md | 5 ++ narwhals/_arrow/dataframe.py | 43 ++++++++++++--- narwhals/_arrow/expr.py | 27 ++++++++- narwhals/_arrow/namespace.py | 79 +++++++++++++------------- narwhals/_arrow/selectors.py | 44 +++++++++------ narwhals/_arrow/series.py | 26 ++++++--- narwhals/_arrow/utils.py | 9 ++- narwhals/_dask/dataframe.py | 19 +++++-- narwhals/_dask/expr.py | 20 ++++++- narwhals/_dask/group_by.py | 3 +- narwhals/_dask/namespace.py | 80 ++++++++++++--------------- narwhals/_dask/selectors.py | 44 +++++++++------ narwhals/_dask/utils.py | 5 +- narwhals/_duckdb/dataframe.py | 18 +++--- narwhals/_duckdb/series.py | 11 +++- narwhals/_ibis/dataframe.py | 15 +++-- narwhals/_ibis/series.py | 11 +++- narwhals/_interchange/dataframe.py | 19 ++++--- narwhals/_interchange/series.py | 18 +++--- narwhals/_pandas_like/dataframe.py | 19 ++++++- narwhals/_pandas_like/expr.py | 21 ++++++- narwhals/_pandas_like/group_by.py | 1 + narwhals/_pandas_like/namespace.py | 76 ++++++++++++++----------- narwhals/_pandas_like/selectors.py | 45 +++++++++------ narwhals/_pandas_like/series.py | 15 ++++- narwhals/_pandas_like/utils.py | 20 ++++--- narwhals/_polars/dataframe.py | 68 +++++++++++++++++------ narwhals/_polars/expr.py | 8 ++- narwhals/_polars/namespace.py | 89 +++++++++++++++--------------- narwhals/_polars/series.py | 26 ++++++--- narwhals/_polars/utils.py | 11 ++-- narwhals/functions.py | 57 +++++++++++++++++-- narwhals/stable/v1/__init__.py | 73 +++++++++++++++--------- narwhals/stable/v1/dtypes.py | 47 ++++++++++++++++ narwhals/translate.py | 71 ++++++++++++++++++++---- narwhals/typing.py | 27 +++++++++ narwhals/utils.py | 3 +- tests/from_dict_test.py | 22 ++++++-- tests/new_series_test.py | 22 +++++++- tests/stable_api_test.py | 8 +++ 40 files changed, 851 insertions(+), 374 deletions(-) create mode 100644 narwhals/stable/v1/dtypes.py diff --git a/docs/how_it_works.md b/docs/how_it_works.md index cda98a2b66..cc808cc6f9 100644 --- a/docs/how_it_works.md +++ b/docs/how_it_works.md @@ -75,6 +75,7 @@ from narwhals.utils import parse_version pn = PandasLikeNamespace( implementation=Implementation.PANDAS, backend_version=parse_version(pd.__version__), + dtypes=nw.dtypes, ) print(nw.col("a")._call(pn)) ``` @@ -101,6 +102,7 @@ import pandas as pd pn = PandasLikeNamespace( implementation=Implementation.PANDAS, backend_version=parse_version(pd.__version__), + dtypes=nw.dtypes, ) df_pd = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) @@ -108,6 +110,7 @@ df = PandasLikeDataFrame( df_pd, implementation=Implementation.PANDAS, backend_version=parse_version(pd.__version__), + dtypes=nw.dtypes, ) expression = pn.col("a") + 1 result = expression._call(df) @@ -196,6 +199,7 @@ import pandas as pd pn = PandasLikeNamespace( implementation=Implementation.PANDAS, backend_version=parse_version(pd.__version__), + dtypes=nw.dtypes, ) df_pd = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) @@ -210,6 +214,7 @@ backend, and it does so by passing a Narwhals-compliant namespace to `nw.Expr._c pn = PandasLikeNamespace( implementation=Implementation.PANDAS, backend_version=parse_version(pd.__version__), + dtypes=nw.dtypes, ) expr = (nw.col("a") + 1)._call(pn) print(expr) diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 12cb16d2eb..efc3431773 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -33,21 +33,27 @@ from narwhals._arrow.series import ArrowSeries from narwhals._arrow.typing import IntoArrowExpr from narwhals.dtypes import DType + from narwhals.typing import DTypes class ArrowDataFrame: # --- not in the spec --- def __init__( - self, native_dataframe: pa.Table, *, backend_version: tuple[int, ...] + self, + native_dataframe: pa.Table, + *, + backend_version: tuple[int, ...], + dtypes: DTypes, ) -> None: self._native_frame = native_dataframe self._implementation = Implementation.PYARROW self._backend_version = backend_version + self._dtypes = dtypes def __narwhals_namespace__(self) -> ArrowNamespace: from narwhals._arrow.namespace import ArrowNamespace - return ArrowNamespace(backend_version=self._backend_version) + return ArrowNamespace(backend_version=self._backend_version, dtypes=self._dtypes) def __native_namespace__(self: Self) -> ModuleType: if self._implementation is Implementation.PYARROW: @@ -63,7 +69,9 @@ def __narwhals_lazyframe__(self) -> Self: return self def _from_native_frame(self, df: Any) -> Self: - return self.__class__(df, backend_version=self._backend_version) + return self.__class__( + df, backend_version=self._backend_version, dtypes=self._dtypes + ) @property def shape(self) -> tuple[int, int]: @@ -111,6 +119,7 @@ def get_column(self, name: str) -> ArrowSeries: self._native_frame[name], name=name, backend_version=self._backend_version, + dtypes=self._dtypes, ) def __array__(self, dtype: Any = None, copy: bool | None = None) -> np.ndarray: @@ -151,6 +160,7 @@ def __getitem__( self._native_frame[item], name=item, backend_version=self._backend_version, + dtypes=self._dtypes, ) elif ( isinstance(item, tuple) @@ -191,12 +201,14 @@ def __getitem__( self._native_frame[col_name], name=col_name, backend_version=self._backend_version, + dtypes=self._dtypes, ) selected_rows = select_rows(self._native_frame, item[0]) return ArrowSeries( selected_rows[col_name], name=col_name, backend_version=self._backend_version, + dtypes=self._dtypes, ) elif isinstance(item, slice): @@ -234,7 +246,7 @@ def __getitem__( def schema(self) -> dict[str, DType]: schema = self._native_frame.schema return { - name: native_to_narwhals_dtype(dtype) + name: native_to_narwhals_dtype(dtype, self._dtypes) for name, dtype in zip(schema.names, schema.types) } @@ -410,7 +422,12 @@ def to_dict(self, *, as_series: bool) -> Any: from narwhals._arrow.series import ArrowSeries return { - name: ArrowSeries(col, name=name, backend_version=self._backend_version) + name: ArrowSeries( + col, + name=name, + backend_version=self._backend_version, + dtypes=self._dtypes, + ) for name, col in names_and_values } else: @@ -471,7 +488,9 @@ def lazy(self) -> Self: return self def collect(self) -> ArrowDataFrame: - return ArrowDataFrame(self._native_frame, backend_version=self._backend_version) + return ArrowDataFrame( + self._native_frame, backend_version=self._backend_version, dtypes=self._dtypes + ) def clone(self) -> Self: msg = "clone is not yet supported on PyArrow tables" @@ -541,7 +560,12 @@ def is_duplicated(self: Self) -> ArrowSeries: ).column(f"{col_token}_count"), 1, ) - return ArrowSeries(is_duplicated, name="", backend_version=self._backend_version) + return ArrowSeries( + is_duplicated, + name="", + backend_version=self._backend_version, + dtypes=self._dtypes, + ) def is_unique(self: Self) -> ArrowSeries: import pyarrow.compute as pc # ignore-banned-import() @@ -551,7 +575,10 @@ def is_unique(self: Self) -> ArrowSeries: is_duplicated = self.is_duplicated()._native_series return ArrowSeries( - pc.invert(is_duplicated), name="", backend_version=self._backend_version + pc.invert(is_duplicated), + name="", + backend_version=self._backend_version, + dtypes=self._dtypes, ) def unique( diff --git a/narwhals/_arrow/expr.py b/narwhals/_arrow/expr.py index 367dc9b448..6d1001c110 100644 --- a/narwhals/_arrow/expr.py +++ b/narwhals/_arrow/expr.py @@ -17,6 +17,7 @@ from narwhals._arrow.series import ArrowSeries from narwhals._arrow.typing import IntoArrowExpr from narwhals.dtypes import DType + from narwhals.typing import DTypes class ArrowExpr: @@ -29,6 +30,7 @@ def __init__( root_names: list[str] | None, output_names: list[str] | None, backend_version: tuple[int, ...], + dtypes: DTypes, ) -> None: self._call = call self._depth = depth @@ -38,6 +40,7 @@ def __init__( self._output_names = output_names self._implementation = Implementation.PYARROW self._backend_version = backend_version + self._dtypes = dtypes def __repr__(self) -> str: # pragma: no cover return ( @@ -50,7 +53,10 @@ def __repr__(self) -> str: # pragma: no cover @classmethod def from_column_names( - cls: type[Self], *column_names: str, backend_version: tuple[int, ...] + cls: type[Self], + *column_names: str, + backend_version: tuple[int, ...], + dtypes: DTypes, ) -> Self: from narwhals._arrow.series import ArrowSeries @@ -60,6 +66,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: df._native_frame[column_name], name=column_name, backend_version=df._backend_version, + dtypes=df._dtypes, ) for column_name in column_names ] @@ -71,11 +78,15 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: root_names=list(column_names), output_names=list(column_names), backend_version=backend_version, + dtypes=dtypes, ) @classmethod def from_column_indices( - cls: type[Self], *column_indices: int, backend_version: tuple[int, ...] + cls: type[Self], + *column_indices: int, + backend_version: tuple[int, ...], + dtypes: DTypes, ) -> Self: from narwhals._arrow.series import ArrowSeries @@ -85,6 +96,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: df._native_frame[column_index], name=df._native_frame.column_names[column_index], backend_version=df._backend_version, + dtypes=df._dtypes, ) for column_index in column_indices ] @@ -96,12 +108,13 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: root_names=None, output_names=None, backend_version=backend_version, + dtypes=dtypes, ) def __narwhals_namespace__(self) -> ArrowNamespace: from narwhals._arrow.namespace import ArrowNamespace - return ArrowNamespace(backend_version=self._backend_version) + return ArrowNamespace(backend_version=self._backend_version, dtypes=self._dtypes) def __narwhals_expr__(self) -> None: ... @@ -246,6 +259,7 @@ def alias(self, name: str) -> Self: root_names=self._root_names, output_names=[name], backend_version=self._backend_version, + dtypes=self._dtypes, ) def null_count(self) -> Self: @@ -352,6 +366,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: root_names=self._root_names, output_names=self._output_names, backend_version=self._backend_version, + dtypes=self._dtypes, ) def mode(self: Self) -> Self: @@ -573,6 +588,7 @@ def keep(self: Self) -> ArrowExpr: root_names=root_names, output_names=root_names, backend_version=self._expr._backend_version, + dtypes=self._expr._dtypes, ) def map(self: Self, function: Callable[[str], str]) -> ArrowExpr: @@ -598,6 +614,7 @@ def map(self: Self, function: Callable[[str], str]) -> ArrowExpr: root_names=root_names, output_names=output_names, backend_version=self._expr._backend_version, + dtypes=self._expr._dtypes, ) def prefix(self: Self, prefix: str) -> ArrowExpr: @@ -621,6 +638,7 @@ def prefix(self: Self, prefix: str) -> ArrowExpr: root_names=root_names, output_names=output_names, backend_version=self._expr._backend_version, + dtypes=self._expr._dtypes, ) def suffix(self: Self, suffix: str) -> ArrowExpr: @@ -645,6 +663,7 @@ def suffix(self: Self, suffix: str) -> ArrowExpr: root_names=root_names, output_names=output_names, backend_version=self._expr._backend_version, + dtypes=self._expr._dtypes, ) def to_lowercase(self: Self) -> ArrowExpr: @@ -669,6 +688,7 @@ def to_lowercase(self: Self) -> ArrowExpr: root_names=root_names, output_names=output_names, backend_version=self._expr._backend_version, + dtypes=self._expr._dtypes, ) def to_uppercase(self: Self) -> ArrowExpr: @@ -693,4 +713,5 @@ def to_uppercase(self: Self) -> ArrowExpr: root_names=root_names, output_names=output_names, backend_version=self._expr._backend_version, + dtypes=self._expr._dtypes, ) diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index e1fb8c6b93..3e7f4ecc9b 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -7,7 +7,6 @@ from typing import Literal from typing import cast -from narwhals import dtypes from narwhals._arrow.dataframe import ArrowDataFrame from narwhals._arrow.expr import ArrowExpr from narwhals._arrow.selectors import ArrowSelectorNamespace @@ -23,32 +22,11 @@ from typing import Callable from narwhals._arrow.typing import IntoArrowExpr + from narwhals.dtypes import DType + from narwhals.typing import DTypes class ArrowNamespace: - Int64 = dtypes.Int64 - Int32 = dtypes.Int32 - Int16 = dtypes.Int16 - Int8 = dtypes.Int8 - UInt64 = dtypes.UInt64 - UInt32 = dtypes.UInt32 - UInt16 = dtypes.UInt16 - UInt8 = dtypes.UInt8 - Float64 = dtypes.Float64 - Float32 = dtypes.Float32 - Boolean = dtypes.Boolean - Object = dtypes.Object - Unknown = dtypes.Unknown - Categorical = dtypes.Categorical - Enum = dtypes.Enum - String = dtypes.String - Datetime = dtypes.Datetime - Duration = dtypes.Duration - Date = dtypes.Date - List = dtypes.List - Struct = dtypes.Struct - Array = dtypes.Array - def _create_expr_from_callable( self, func: Callable[[ArrowDataFrame], list[ArrowSeries]], @@ -67,6 +45,7 @@ def _create_expr_from_callable( root_names=root_names, output_names=output_names, backend_version=self._backend_version, + dtypes=self._dtypes, ) def _create_expr_from_series(self, series: ArrowSeries) -> ArrowExpr: @@ -79,6 +58,7 @@ def _create_expr_from_series(self, series: ArrowSeries) -> ArrowExpr: root_names=None, output_names=None, backend_version=self._backend_version, + dtypes=self._dtypes, ) def _create_series_from_scalar(self, value: Any, series: ArrowSeries) -> ArrowSeries: @@ -90,6 +70,7 @@ def _create_series_from_scalar(self, value: Any, series: ArrowSeries) -> ArrowSe [value], name=series.name, backend_version=self._backend_version, + dtypes=self._dtypes, ) def _create_compliant_series(self, value: Any) -> ArrowSeries: @@ -101,26 +82,28 @@ def _create_compliant_series(self, value: Any) -> ArrowSeries: native_series=pa.chunked_array([value]), name="", backend_version=self._backend_version, + dtypes=self._dtypes, ) # --- not in spec --- - def __init__(self, *, backend_version: tuple[int, ...]) -> None: + def __init__(self, *, backend_version: tuple[int, ...], dtypes: DTypes) -> None: self._backend_version = backend_version self._implementation = Implementation.PYARROW + self._dtypes = dtypes # --- selection --- def col(self, *column_names: str) -> ArrowExpr: from narwhals._arrow.expr import ArrowExpr return ArrowExpr.from_column_names( - *column_names, backend_version=self._backend_version + *column_names, backend_version=self._backend_version, dtypes=self._dtypes ) def nth(self, *column_indices: int) -> ArrowExpr: from narwhals._arrow.expr import ArrowExpr return ArrowExpr.from_column_indices( - *column_indices, backend_version=self._backend_version + *column_indices, backend_version=self._backend_version, dtypes=self._dtypes ) def len(self) -> ArrowExpr: @@ -131,6 +114,7 @@ def len(self) -> ArrowExpr: [len(df._native_frame)], name="len", backend_version=self._backend_version, + dtypes=self._dtypes, ) ], depth=0, @@ -138,6 +122,7 @@ def len(self) -> ArrowExpr: root_names=None, output_names=["len"], backend_version=self._backend_version, + dtypes=self._dtypes, ) def all(self) -> ArrowExpr: @@ -150,6 +135,7 @@ def all(self) -> ArrowExpr: df._native_frame[column_name], name=column_name, backend_version=df._backend_version, + dtypes=df._dtypes, ) for column_name in df.columns ], @@ -158,14 +144,16 @@ def all(self) -> ArrowExpr: root_names=None, output_names=None, backend_version=self._backend_version, + dtypes=self._dtypes, ) - def lit(self, value: Any, dtype: dtypes.DType | None) -> ArrowExpr: + def lit(self, value: Any, dtype: DType | None) -> ArrowExpr: def _lit_arrow_series(_: ArrowDataFrame) -> ArrowSeries: arrow_series = ArrowSeries._from_iterable( data=[value], name="lit", backend_version=self._backend_version, + dtypes=self._dtypes, ) if dtype: return arrow_series.cast(dtype) @@ -178,6 +166,7 @@ def _lit_arrow_series(_: ArrowDataFrame) -> ArrowSeries: root_names=None, output_names=["lit"], backend_version=self._backend_version, + dtypes=self._dtypes, ) def all_horizontal(self, *exprs: IntoArrowExpr) -> ArrowExpr: @@ -230,7 +219,7 @@ def mean_horizontal(self, *exprs: IntoArrowExpr) -> IntoArrowExpr: total = reduce(lambda x, y: x + y, (e.fill_null(0.0) for e in arrow_exprs)) n_non_zero = reduce( lambda x, y: x + y, - ((1 - e.is_null().cast(self.Int64())) for e in arrow_exprs), + ((1 - e.is_null().cast(self._dtypes.Int64())) for e in arrow_exprs), ) return total / n_non_zero @@ -246,54 +235,54 @@ def concat( return ArrowDataFrame( horizontal_concat(dfs), backend_version=self._backend_version, + dtypes=self._dtypes, ) if how == "vertical": return ArrowDataFrame( vertical_concat(dfs), backend_version=self._backend_version, + dtypes=self._dtypes, ) raise NotImplementedError def sum(self, *column_names: str) -> ArrowExpr: return ArrowExpr.from_column_names( - *column_names, - backend_version=self._backend_version, + *column_names, backend_version=self._backend_version, dtypes=self._dtypes ).sum() def mean(self, *column_names: str) -> ArrowExpr: return ArrowExpr.from_column_names( - *column_names, - backend_version=self._backend_version, + *column_names, backend_version=self._backend_version, dtypes=self._dtypes ).mean() def max(self, *column_names: str) -> ArrowExpr: return ArrowExpr.from_column_names( - *column_names, - backend_version=self._backend_version, + *column_names, backend_version=self._backend_version, dtypes=self._dtypes ).max() def min(self, *column_names: str) -> ArrowExpr: return ArrowExpr.from_column_names( - *column_names, - backend_version=self._backend_version, + *column_names, backend_version=self._backend_version, dtypes=self._dtypes ).min() @property def selectors(self) -> ArrowSelectorNamespace: - return ArrowSelectorNamespace(backend_version=self._backend_version) + return ArrowSelectorNamespace( + backend_version=self._backend_version, dtypes=self._dtypes + ) def when( self, *predicates: IntoArrowExpr, ) -> ArrowWhen: - plx = self.__class__(backend_version=self._backend_version) + plx = self.__class__(backend_version=self._backend_version, dtypes=self._dtypes) if predicates: condition = plx.all_horizontal(*predicates) else: msg = "at least one predicate needs to be provided" raise TypeError(msg) - return ArrowWhen(condition, self._backend_version) + return ArrowWhen(condition, self._backend_version, dtypes=self._dtypes) class ArrowWhen: @@ -303,11 +292,14 @@ def __init__( backend_version: tuple[int, ...], then_value: Any = None, otherwise_value: Any = None, + *, + dtypes: DTypes, ) -> None: self._backend_version = backend_version self._condition = condition self._then_value = then_value self._otherwise_value = otherwise_value + self._dtypes = dtypes def __call__(self, df: ArrowDataFrame) -> list[ArrowSeries]: import pyarrow as pa # ignore-banned-import @@ -316,7 +308,7 @@ def __call__(self, df: ArrowDataFrame) -> list[ArrowSeries]: from narwhals._arrow.namespace import ArrowNamespace from narwhals._expression_parsing import parse_into_expr - plx = ArrowNamespace(backend_version=self._backend_version) + plx = ArrowNamespace(backend_version=self._backend_version, dtypes=self._dtypes) condition = parse_into_expr(self._condition, namespace=plx)._call(df)[0] # type: ignore[arg-type] try: @@ -327,6 +319,7 @@ def __call__(self, df: ArrowDataFrame) -> list[ArrowSeries]: [self._then_value] * len(condition), name="literal", backend_version=self._backend_version, + dtypes=self._dtypes, ) value_series = cast(ArrowSeries, value_series) @@ -370,6 +363,7 @@ def then(self, value: ArrowExpr | ArrowSeries | Any) -> ArrowThen: root_names=None, output_names=None, backend_version=self._backend_version, + dtypes=self._dtypes, ) @@ -383,9 +377,10 @@ def __init__( root_names: list[str] | None, output_names: list[str] | None, backend_version: tuple[int, ...], + dtypes: DTypes, ) -> None: self._backend_version = backend_version - + self._dtypes = dtypes self._call = call self._depth = depth self._function_name = function_name diff --git a/narwhals/_arrow/selectors.py b/narwhals/_arrow/selectors.py index 569724c458..d5a8ccae09 100644 --- a/narwhals/_arrow/selectors.py +++ b/narwhals/_arrow/selectors.py @@ -4,7 +4,6 @@ from typing import Any from typing import NoReturn -from narwhals import dtypes from narwhals._arrow.expr import ArrowExpr from narwhals.utils import Implementation @@ -14,12 +13,14 @@ from narwhals._arrow.dataframe import ArrowDataFrame from narwhals._arrow.series import ArrowSeries from narwhals.dtypes import DType + from narwhals.typing import DTypes class ArrowSelectorNamespace: - def __init__(self: Self, *, backend_version: tuple[int, ...]) -> None: + def __init__(self: Self, *, backend_version: tuple[int, ...], dtypes: DTypes) -> None: self._backend_version = backend_version self._implementation = Implementation.PYARROW + self._dtypes = dtypes def by_dtype(self: Self, dtypes: list[DType | type[DType]]) -> ArrowSelector: def func(df: ArrowDataFrame) -> list[ArrowSeries]: @@ -32,32 +33,33 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: root_names=None, output_names=None, backend_version=self._backend_version, + dtypes=self._dtypes, ) def numeric(self: Self) -> ArrowSelector: return self.by_dtype( [ - dtypes.Int64, - dtypes.Int32, - dtypes.Int16, - dtypes.Int8, - dtypes.UInt64, - dtypes.UInt32, - dtypes.UInt16, - dtypes.UInt8, - dtypes.Float64, - dtypes.Float32, + self._dtypes.Int64, + self._dtypes.Int32, + self._dtypes.Int16, + self._dtypes.Int8, + self._dtypes.UInt64, + self._dtypes.UInt32, + self._dtypes.UInt16, + self._dtypes.UInt8, + self._dtypes.Float64, + self._dtypes.Float32, ], ) def categorical(self: Self) -> ArrowSelector: - return self.by_dtype([dtypes.Categorical]) + return self.by_dtype([self._dtypes.Categorical]) def string(self: Self) -> ArrowSelector: - return self.by_dtype([dtypes.String]) + return self.by_dtype([self._dtypes.String]) def boolean(self: Self) -> ArrowSelector: - return self.by_dtype([dtypes.Boolean]) + return self.by_dtype([self._dtypes.Boolean]) def all(self: Self) -> ArrowSelector: def func(df: ArrowDataFrame) -> list[ArrowSeries]: @@ -70,6 +72,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: root_names=None, output_names=None, backend_version=self._backend_version, + dtypes=self._dtypes, ) @@ -91,6 +94,7 @@ def _to_expr(self: Self) -> ArrowExpr: root_names=self._root_names, output_names=self._output_names, backend_version=self._backend_version, + dtypes=self._dtypes, ) def __sub__(self: Self, other: Self | Any) -> ArrowSelector | Any: @@ -108,6 +112,7 @@ def call(df: ArrowDataFrame) -> list[ArrowSeries]: root_names=None, output_names=None, backend_version=self._backend_version, + dtypes=self._dtypes, ) else: return self._to_expr() - other @@ -127,6 +132,7 @@ def call(df: ArrowDataFrame) -> list[ArrowSeries]: root_names=None, output_names=None, backend_version=self._backend_version, + dtypes=self._dtypes, ) else: return self._to_expr() | other @@ -146,12 +152,18 @@ def call(df: ArrowDataFrame) -> list[ArrowSeries]: root_names=None, output_names=None, backend_version=self._backend_version, + dtypes=self._dtypes, ) else: return self._to_expr() & other def __invert__(self: Self) -> ArrowSelector: - return ArrowSelectorNamespace(backend_version=self._backend_version).all() - self + return ( + ArrowSelectorNamespace( + backend_version=self._backend_version, dtypes=self._dtypes + ).all() + - self + ) def __rsub__(self: Self, other: Any) -> NoReturn: raise NotImplementedError diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index a3c3e89e94..183cf37b7f 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -25,6 +25,7 @@ from narwhals._arrow.dataframe import ArrowDataFrame from narwhals._arrow.namespace import ArrowNamespace from narwhals.dtypes import DType + from narwhals.typing import DTypes class ArrowSeries: @@ -34,11 +35,13 @@ def __init__( *, name: str, backend_version: tuple[int, ...], + dtypes: DTypes, ) -> None: self._name = name self._native_series = native_series self._implementation = Implementation.PYARROW self._backend_version = backend_version + self._dtypes = dtypes def _from_native_series(self, series: Any) -> Self: import pyarrow as pa # ignore-banned-import() @@ -49,6 +52,7 @@ def _from_native_series(self, series: Any) -> Self: series, name=self._name, backend_version=self._backend_version, + dtypes=self._dtypes, ) @classmethod @@ -58,6 +62,7 @@ def _from_iterable( name: str, *, backend_version: tuple[int, ...], + dtypes: DTypes, ) -> Self: import pyarrow as pa # ignore-banned-import() @@ -65,12 +70,13 @@ def _from_iterable( pa.chunked_array([data]), name=name, backend_version=backend_version, + dtypes=dtypes, ) def __narwhals_namespace__(self) -> ArrowNamespace: from narwhals._arrow.namespace import ArrowNamespace - return ArrowNamespace(backend_version=self._backend_version) + return ArrowNamespace(backend_version=self._backend_version, dtypes=self._dtypes) def __len__(self) -> int: return len(self._native_series) @@ -361,11 +367,12 @@ def alias(self, name: str) -> Self: self._native_series, name=name, backend_version=self._backend_version, + dtypes=self._dtypes, ) @property def dtype(self: Self) -> DType: - return native_to_narwhals_dtype(self._native_series.type) + return native_to_narwhals_dtype(self._native_series.type, self._dtypes) def abs(self) -> Self: import pyarrow.compute as pc # ignore-banned-import() @@ -438,7 +445,7 @@ def cast(self, dtype: DType) -> Self: import pyarrow.compute as pc # ignore-banned-import() ser = self._native_series - dtype = narwhals_to_native_dtype(dtype) + dtype = narwhals_to_native_dtype(dtype, self._dtypes) return self._from_native_series(pc.cast(ser, dtype)) def null_count(self: Self) -> int: @@ -474,7 +481,10 @@ def arg_true(self) -> Self: ser = self._native_series res = np.flatnonzero(ser) return self._from_iterable( - res, name=self.name, backend_version=self._backend_version + res, + name=self.name, + backend_version=self._backend_version, + dtypes=self._dtypes, ) def item(self: Self, index: int | None = None) -> Any: @@ -520,8 +530,7 @@ def value_counts( val_count = val_count.sort_by([(value_name_, "descending")]) return ArrowDataFrame( - val_count, - backend_version=self._backend_version, + val_count, backend_version=self._backend_version, dtypes=self._dtypes ) def zip_with(self: Self, mask: Self, other: Self) -> Self: @@ -574,7 +583,9 @@ def to_frame(self: Self) -> ArrowDataFrame: from narwhals._arrow.dataframe import ArrowDataFrame df = pa.Table.from_arrays([self._native_series], names=[self.name]) - return ArrowDataFrame(df, backend_version=self._backend_version) + return ArrowDataFrame( + df, backend_version=self._backend_version, dtypes=self._dtypes + ) def to_pandas(self: Self) -> Any: import pandas as pd # ignore-banned-import() @@ -670,6 +681,7 @@ def to_dummies( return ArrowDataFrame( pa.Table.from_arrays(columns, names=names), backend_version=self._backend_version, + dtypes=self._dtypes, ).select(*sorted(names)[int(drop_first) :]) def quantile( diff --git a/narwhals/_arrow/utils.py b/narwhals/_arrow/utils.py index 90195a386c..d51a4b25de 100644 --- a/narwhals/_arrow/utils.py +++ b/narwhals/_arrow/utils.py @@ -4,16 +4,17 @@ from typing import Any from typing import Sequence -from narwhals import dtypes from narwhals.utils import isinstance_or_issubclass if TYPE_CHECKING: import pyarrow as pa from narwhals._arrow.series import ArrowSeries + from narwhals.dtypes import DType + from narwhals.typing import DTypes -def native_to_narwhals_dtype(dtype: Any) -> dtypes.DType: +def native_to_narwhals_dtype(dtype: Any, dtypes: DTypes) -> DType: import pyarrow as pa # ignore-banned-import if pa.types.is_int64(dtype): @@ -63,11 +64,9 @@ def native_to_narwhals_dtype(dtype: Any) -> dtypes.DType: return dtypes.Unknown() # pragma: no cover -def narwhals_to_native_dtype(dtype: dtypes.DType | type[dtypes.DType]) -> Any: +def narwhals_to_native_dtype(dtype: DType | type[DType], dtypes: DTypes) -> Any: import pyarrow as pa # ignore-banned-import - from narwhals import dtypes - if isinstance_or_issubclass(dtype, dtypes.Float64): return pa.float64() if isinstance_or_issubclass(dtype, dtypes.Float32): diff --git a/narwhals/_dask/dataframe.py b/narwhals/_dask/dataframe.py index 9538d6d890..916583eaa4 100644 --- a/narwhals/_dask/dataframe.py +++ b/narwhals/_dask/dataframe.py @@ -26,15 +26,21 @@ from narwhals._dask.namespace import DaskNamespace from narwhals._dask.typing import IntoDaskExpr from narwhals.dtypes import DType + from narwhals.typing import DTypes class DaskLazyFrame: def __init__( - self, native_dataframe: dd.DataFrame, *, backend_version: tuple[int, ...] + self, + native_dataframe: dd.DataFrame, + *, + backend_version: tuple[int, ...], + dtypes: DTypes, ) -> None: self._native_frame = native_dataframe self._backend_version = backend_version self._implementation = Implementation.DASK + self._dtypes = dtypes def __native_namespace__(self: Self) -> ModuleType: if self._implementation is Implementation.DASK: @@ -46,13 +52,15 @@ def __native_namespace__(self: Self) -> ModuleType: def __narwhals_namespace__(self) -> DaskNamespace: from narwhals._dask.namespace import DaskNamespace - return DaskNamespace(backend_version=self._backend_version) + return DaskNamespace(backend_version=self._backend_version, dtypes=self._dtypes) def __narwhals_lazyframe__(self) -> Self: return self def _from_native_frame(self, df: Any) -> Self: - return self.__class__(df, backend_version=self._backend_version) + return self.__class__( + df, backend_version=self._backend_version, dtypes=self._dtypes + ) def with_columns(self, *exprs: DaskExpr, **named_exprs: DaskExpr) -> Self: df = self._native_frame @@ -70,6 +78,7 @@ def collect(self) -> Any: result, implementation=Implementation.PANDAS, backend_version=parse_version(pd.__version__), + dtypes=self._dtypes, ) @property @@ -92,7 +101,7 @@ def filter( from narwhals._dask.namespace import DaskNamespace - plx = DaskNamespace(backend_version=self._backend_version) + plx = DaskNamespace(backend_version=self._backend_version, dtypes=self._dtypes) expr = plx.all_horizontal(*predicates) # Safety: all_horizontal's expression only returns a single column. mask = expr._call(self)[0] @@ -140,7 +149,7 @@ def drop_nulls(self: Self, subset: str | list[str] | None) -> Self: @property def schema(self) -> dict[str, DType]: return { - col: native_to_narwhals_dtype(self._native_frame.loc[:, col]) + col: native_to_narwhals_dtype(self._native_frame.loc[:, col], self._dtypes) for col in self._native_frame.columns } diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index eda0fd5895..d8d86692e3 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -19,6 +19,7 @@ from narwhals._dask.dataframe import DaskLazyFrame from narwhals._dask.namespace import DaskNamespace from narwhals.dtypes import DType + from narwhals.typing import DTypes class DaskExpr: @@ -34,6 +35,7 @@ def __init__( # a reduction, such as `nw.col('a').sum()` returns_scalar: bool, backend_version: tuple[int, ...], + dtypes: DTypes, ) -> None: self._call = call self._depth = depth @@ -42,6 +44,7 @@ def __init__( self._output_names = output_names self._returns_scalar = returns_scalar self._backend_version = backend_version + self._dtypes = dtypes def __narwhals_expr__(self) -> None: ... @@ -49,13 +52,14 @@ def __narwhals_namespace__(self) -> DaskNamespace: # pragma: no cover # Unused, just for compatibility with PandasLikeExpr from narwhals._dask.namespace import DaskNamespace - return DaskNamespace(backend_version=self._backend_version) + return DaskNamespace(backend_version=self._backend_version, dtypes=self._dtypes) @classmethod def from_column_names( cls: type[Self], *column_names: str, backend_version: tuple[int, ...], + dtypes: DTypes, ) -> Self: def func(df: DaskLazyFrame) -> list[dask_expr.Series]: return [df._native_frame.loc[:, column_name] for column_name in column_names] @@ -68,6 +72,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: output_names=list(column_names), returns_scalar=False, backend_version=backend_version, + dtypes=dtypes, ) @classmethod @@ -75,6 +80,7 @@ def from_column_indices( cls: type[Self], *column_indices: int, backend_version: tuple[int, ...], + dtypes: DTypes, ) -> Self: def func(df: DaskLazyFrame) -> list[dask_expr.Series]: return [ @@ -89,6 +95,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: output_names=None, returns_scalar=False, backend_version=backend_version, + dtypes=dtypes, ) def _from_call( @@ -146,6 +153,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: output_names=output_names, returns_scalar=self._returns_scalar or returns_scalar, backend_version=self._backend_version, + dtypes=self._dtypes, ) def alias(self, name: str) -> Self: @@ -161,6 +169,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: output_names=[name], returns_scalar=self._returns_scalar, backend_version=self._backend_version, + dtypes=self._dtypes, ) def __add__(self, other: Any) -> Self: @@ -677,6 +686,7 @@ def func(df: DaskLazyFrame) -> list[Any]: output_names=self._output_names, returns_scalar=False, backend_version=self._backend_version, + dtypes=self._dtypes, ) def mode(self: Self) -> Self: @@ -700,7 +710,7 @@ def cast( dtype: DType | type[DType], ) -> Self: def func(_input: Any, dtype: DType | type[DType]) -> Any: - dtype = narwhals_to_native_dtype(dtype) + dtype = narwhals_to_native_dtype(dtype, self._dtypes) return _input.astype(dtype) return self._from_call( @@ -977,6 +987,7 @@ def keep(self: Self) -> DaskExpr: output_names=root_names, returns_scalar=self._expr._returns_scalar, backend_version=self._expr._backend_version, + dtypes=self._expr._dtypes, ) def map(self: Self, function: Callable[[str], str]) -> DaskExpr: @@ -1003,6 +1014,7 @@ def map(self: Self, function: Callable[[str], str]) -> DaskExpr: output_names=output_names, returns_scalar=self._expr._returns_scalar, backend_version=self._expr._backend_version, + dtypes=self._expr._dtypes, ) def prefix(self: Self, prefix: str) -> DaskExpr: @@ -1027,6 +1039,7 @@ def prefix(self: Self, prefix: str) -> DaskExpr: output_names=output_names, returns_scalar=self._expr._returns_scalar, backend_version=self._expr._backend_version, + dtypes=self._expr._dtypes, ) def suffix(self: Self, suffix: str) -> DaskExpr: @@ -1052,6 +1065,7 @@ def suffix(self: Self, suffix: str) -> DaskExpr: output_names=output_names, returns_scalar=self._expr._returns_scalar, backend_version=self._expr._backend_version, + dtypes=self._expr._dtypes, ) def to_lowercase(self: Self) -> DaskExpr: @@ -1077,6 +1091,7 @@ def to_lowercase(self: Self) -> DaskExpr: output_names=output_names, returns_scalar=self._expr._returns_scalar, backend_version=self._expr._backend_version, + dtypes=self._expr._dtypes, ) def to_uppercase(self: Self) -> DaskExpr: @@ -1102,4 +1117,5 @@ def to_uppercase(self: Self) -> DaskExpr: output_names=output_names, returns_scalar=self._expr._returns_scalar, backend_version=self._expr._backend_version, + dtypes=self._expr._dtypes, ) diff --git a/narwhals/_dask/group_by.py b/narwhals/_dask/group_by.py index d79c95d7b8..55ef69f468 100644 --- a/narwhals/_dask/group_by.py +++ b/narwhals/_dask/group_by.py @@ -84,8 +84,7 @@ def _from_native_frame(self, df: DaskLazyFrame) -> DaskLazyFrame: from narwhals._dask.dataframe import DaskLazyFrame return DaskLazyFrame( - df, - backend_version=self._df._backend_version, + df, backend_version=self._df._backend_version, dtypes=self._df._dtypes ) diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index 39c6471920..01d5bea489 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -9,7 +9,6 @@ from typing import NoReturn from typing import cast -from narwhals import dtypes from narwhals._dask.dataframe import DaskLazyFrame from narwhals._dask.expr import DaskExpr from narwhals._dask.selectors import DaskSelectorNamespace @@ -24,38 +23,19 @@ from narwhals._dask.typing import IntoDaskExpr from narwhals.dtypes import DType + from narwhals.typing import DTypes class DaskNamespace: - Int64 = dtypes.Int64 - Int32 = dtypes.Int32 - Int16 = dtypes.Int16 - Int8 = dtypes.Int8 - UInt64 = dtypes.UInt64 - UInt32 = dtypes.UInt32 - UInt16 = dtypes.UInt16 - UInt8 = dtypes.UInt8 - Float64 = dtypes.Float64 - Float32 = dtypes.Float32 - Boolean = dtypes.Boolean - Object = dtypes.Object - Unknown = dtypes.Unknown - Categorical = dtypes.Categorical - Enum = dtypes.Enum - String = dtypes.String - Datetime = dtypes.Datetime - Duration = dtypes.Duration - Date = dtypes.Date - List = dtypes.List - Struct = dtypes.Struct - Array = dtypes.Array - @property def selectors(self) -> DaskSelectorNamespace: - return DaskSelectorNamespace(backend_version=self._backend_version) + return DaskSelectorNamespace( + backend_version=self._backend_version, dtypes=self._dtypes + ) - def __init__(self, *, backend_version: tuple[int, ...]) -> None: + def __init__(self, *, backend_version: tuple[int, ...], dtypes: DTypes) -> None: self._backend_version = backend_version + self._dtypes = dtypes def all(self) -> DaskExpr: def func(df: DaskLazyFrame) -> list[dask_expr.Series]: @@ -69,25 +49,28 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: output_names=None, returns_scalar=False, backend_version=self._backend_version, + dtypes=self._dtypes, ) def col(self, *column_names: str) -> DaskExpr: return DaskExpr.from_column_names( - *column_names, - backend_version=self._backend_version, + *column_names, backend_version=self._backend_version, dtypes=self._dtypes ) def nth(self, *column_indices: int) -> DaskExpr: return DaskExpr.from_column_indices( - *column_indices, - backend_version=self._backend_version, + *column_indices, backend_version=self._backend_version, dtypes=self._dtypes ) - def lit(self, value: Any, dtype: dtypes.DType | None) -> DaskExpr: + def lit(self, value: Any, dtype: DType | None) -> DaskExpr: def convert_if_dtype( series: dask_expr.Series, dtype: DType | type[DType] ) -> dask_expr.Series: - return series.astype(narwhals_to_native_dtype(dtype)) if dtype else series + return ( + series.astype(narwhals_to_native_dtype(dtype, self._dtypes)) + if dtype + else series + ) return DaskExpr( lambda df: [ @@ -101,30 +84,27 @@ def convert_if_dtype( output_names=["lit"], returns_scalar=False, backend_version=self._backend_version, + dtypes=self._dtypes, ) def min(self, *column_names: str) -> DaskExpr: return DaskExpr.from_column_names( - *column_names, - backend_version=self._backend_version, + *column_names, backend_version=self._backend_version, dtypes=self._dtypes ).min() def max(self, *column_names: str) -> DaskExpr: return DaskExpr.from_column_names( - *column_names, - backend_version=self._backend_version, + *column_names, backend_version=self._backend_version, dtypes=self._dtypes ).max() def mean(self, *column_names: str) -> DaskExpr: return DaskExpr.from_column_names( - *column_names, - backend_version=self._backend_version, + *column_names, backend_version=self._backend_version, dtypes=self._dtypes ).mean() def sum(self, *column_names: str) -> DaskExpr: return DaskExpr.from_column_names( - *column_names, - backend_version=self._backend_version, + *column_names, backend_version=self._backend_version, dtypes=self._dtypes ).sum() def len(self) -> DaskExpr: @@ -150,6 +130,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: output_names=["len"], returns_scalar=True, backend_version=self._backend_version, + dtypes=self._dtypes, ) def all_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr: @@ -167,6 +148,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: output_names=reduce_output_names(parsed_exprs), returns_scalar=False, backend_version=self._backend_version, + dtypes=self._dtypes, ) def any_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr: @@ -184,6 +166,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: output_names=reduce_output_names(parsed_exprs), returns_scalar=False, backend_version=self._backend_version, + dtypes=self._dtypes, ) def sum_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr: @@ -201,6 +184,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: output_names=reduce_output_names(parsed_exprs), returns_scalar=False, backend_version=self._backend_version, + dtypes=self._dtypes, ) def concat( @@ -224,6 +208,7 @@ def concat( return DaskLazyFrame( dd.concat(native_frames, axis=0, join="inner"), backend_version=self._backend_version, + dtypes=self._dtypes, ) if how == "horizontal": all_column_names: list[str] = [ @@ -241,6 +226,7 @@ def concat( return DaskLazyFrame( dd.concat(native_frames, axis=1, join="outer"), backend_version=self._backend_version, + dtypes=self._dtypes, ) raise NotImplementedError @@ -282,14 +268,16 @@ def when( self, *predicates: IntoDaskExpr, ) -> DaskWhen: - plx = self.__class__(backend_version=self._backend_version) + plx = self.__class__(backend_version=self._backend_version, dtypes=self._dtypes) if predicates: condition = plx.all_horizontal(*predicates) else: msg = "at least one predicate needs to be provided" raise TypeError(msg) - return DaskWhen(condition, self._backend_version, returns_scalar=False) + return DaskWhen( + condition, self._backend_version, returns_scalar=False, dtypes=self._dtypes + ) class DaskWhen: @@ -301,18 +289,20 @@ def __init__( otherwise_value: Any = None, *, returns_scalar: bool, + dtypes: DTypes, ) -> None: self._backend_version = backend_version self._condition = condition self._then_value = then_value self._otherwise_value = otherwise_value self._returns_scalar = returns_scalar + self._dtypes = dtypes def __call__(self, df: DaskLazyFrame) -> list[dask_expr.Series]: from narwhals._dask.namespace import DaskNamespace from narwhals._expression_parsing import parse_into_expr - plx = DaskNamespace(backend_version=self._backend_version) + plx = DaskNamespace(backend_version=self._backend_version, dtypes=self._dtypes) condition = parse_into_expr(self._condition, namespace=plx)._call(df)[0] # type: ignore[arg-type] condition = cast("dask_expr.Series", condition) @@ -349,6 +339,7 @@ def then(self, value: DaskExpr | Any) -> DaskThen: output_names=None, returns_scalar=self._returns_scalar, backend_version=self._backend_version, + dtypes=self._dtypes, ) @@ -363,9 +354,10 @@ def __init__( output_names: list[str] | None, returns_scalar: bool, backend_version: tuple[int, ...], + dtypes: DTypes, ) -> None: self._backend_version = backend_version - + self._dtypes = dtypes self._call = call self._depth = depth self._function_name = function_name diff --git a/narwhals/_dask/selectors.py b/narwhals/_dask/selectors.py index 54131a8a59..4d9af1110e 100644 --- a/narwhals/_dask/selectors.py +++ b/narwhals/_dask/selectors.py @@ -4,7 +4,6 @@ from typing import Any from typing import NoReturn -from narwhals import dtypes from narwhals._dask.expr import DaskExpr if TYPE_CHECKING: @@ -13,11 +12,13 @@ from narwhals._dask.dataframe import DaskLazyFrame from narwhals.dtypes import DType + from narwhals.typing import DTypes class DaskSelectorNamespace: - def __init__(self: Self, *, backend_version: tuple[int, ...]) -> None: + def __init__(self: Self, *, backend_version: tuple[int, ...], dtypes: DTypes) -> None: self._backend_version = backend_version + self._dtypes = dtypes def by_dtype(self: Self, dtypes: list[DType | type[DType]]) -> DaskSelector: def func(df: DaskLazyFrame) -> list[Any]: @@ -33,32 +34,33 @@ def func(df: DaskLazyFrame) -> list[Any]: output_names=None, backend_version=self._backend_version, returns_scalar=False, + dtypes=self._dtypes, ) def numeric(self: Self) -> DaskSelector: return self.by_dtype( [ - dtypes.Int64, - dtypes.Int32, - dtypes.Int16, - dtypes.Int8, - dtypes.UInt64, - dtypes.UInt32, - dtypes.UInt16, - dtypes.UInt8, - dtypes.Float64, - dtypes.Float32, + self._dtypes.Int64, + self._dtypes.Int32, + self._dtypes.Int16, + self._dtypes.Int8, + self._dtypes.UInt64, + self._dtypes.UInt32, + self._dtypes.UInt16, + self._dtypes.UInt8, + self._dtypes.Float64, + self._dtypes.Float32, ], ) def categorical(self: Self) -> DaskSelector: - return self.by_dtype([dtypes.Categorical]) + return self.by_dtype([self._dtypes.Categorical]) def string(self: Self) -> DaskSelector: - return self.by_dtype([dtypes.String]) + return self.by_dtype([self._dtypes.String]) def boolean(self: Self) -> DaskSelector: - return self.by_dtype([dtypes.Boolean]) + return self.by_dtype([self._dtypes.Boolean]) def all(self: Self) -> DaskSelector: def func(df: DaskLazyFrame) -> list[Any]: @@ -72,6 +74,7 @@ def func(df: DaskLazyFrame) -> list[Any]: output_names=None, backend_version=self._backend_version, returns_scalar=False, + dtypes=self._dtypes, ) @@ -94,6 +97,7 @@ def _to_expr(self: Self) -> DaskExpr: output_names=self._output_names, backend_version=self._backend_version, returns_scalar=self._returns_scalar, + dtypes=self._dtypes, ) def __sub__(self: Self, other: DaskSelector | Any) -> DaskSelector | Any: @@ -112,6 +116,7 @@ def call(df: DaskLazyFrame) -> list[Any]: output_names=None, backend_version=self._backend_version, returns_scalar=self._returns_scalar, + dtypes=self._dtypes, ) else: return self._to_expr() - other @@ -132,6 +137,7 @@ def call(df: DaskLazyFrame) -> list[dask_expr.Series]: output_names=None, backend_version=self._backend_version, returns_scalar=self._returns_scalar, + dtypes=self._dtypes, ) else: return self._to_expr() | other @@ -152,12 +158,18 @@ def call(df: DaskLazyFrame) -> list[Any]: output_names=None, backend_version=self._backend_version, returns_scalar=self._returns_scalar, + dtypes=self._dtypes, ) else: return self._to_expr() & other def __invert__(self: Self) -> DaskSelector: - return DaskSelectorNamespace(backend_version=self._backend_version).all() - self + return ( + DaskSelectorNamespace( + backend_version=self._backend_version, dtypes=self._dtypes + ).all() + - self + ) def __rsub__(self: Self, other: Any) -> NoReturn: raise NotImplementedError diff --git a/narwhals/_dask/utils.py b/narwhals/_dask/utils.py index 02dedab4e2..f7636bd5fc 100644 --- a/narwhals/_dask/utils.py +++ b/narwhals/_dask/utils.py @@ -14,6 +14,7 @@ from narwhals._dask.dataframe import DaskLazyFrame from narwhals.dtypes import DType + from narwhals.typing import DTypes def maybe_evaluate(df: DaskLazyFrame, obj: Any) -> Any: @@ -83,9 +84,7 @@ def validate_comparand(lhs: dask_expr.Series, rhs: dask_expr.Series) -> None: raise RuntimeError(msg) -def narwhals_to_native_dtype(dtype: DType | type[DType]) -> Any: - from narwhals import dtypes - +def narwhals_to_native_dtype(dtype: DType | type[DType], dtypes: DTypes) -> Any: if isinstance_or_issubclass(dtype, dtypes.Float64): return "float64" if isinstance_or_issubclass(dtype, dtypes.Float32): diff --git a/narwhals/_duckdb/dataframe.py b/narwhals/_duckdb/dataframe.py index 099a91b726..5e8cb73d3f 100644 --- a/narwhals/_duckdb/dataframe.py +++ b/narwhals/_duckdb/dataframe.py @@ -4,7 +4,6 @@ from typing import TYPE_CHECKING from typing import Any -from narwhals import dtypes from narwhals.utils import parse_version if TYPE_CHECKING: @@ -13,11 +12,11 @@ from typing_extensions import Self from narwhals._duckdb.series import DuckDBInterchangeSeries + from narwhals.dtypes import DType + from narwhals.typing import DTypes -def map_duckdb_dtype_to_narwhals_dtype( - duckdb_dtype: Any, -) -> dtypes.DType: +def map_duckdb_dtype_to_narwhals_dtype(duckdb_dtype: Any, dtypes: DTypes) -> DType: duckdb_dtype = str(duckdb_dtype) if duckdb_dtype == "BIGINT": return dtypes.Int64() @@ -59,8 +58,9 @@ def map_duckdb_dtype_to_narwhals_dtype( class DuckDBInterchangeFrame: - def __init__(self, df: Any) -> None: + def __init__(self, df: Any, dtypes: DTypes) -> None: self._native_frame = df + self._dtypes = dtypes def __narwhals_dataframe__(self) -> Any: return self @@ -68,12 +68,16 @@ def __narwhals_dataframe__(self) -> Any: def __getitem__(self, item: str) -> DuckDBInterchangeSeries: from narwhals._duckdb.series import DuckDBInterchangeSeries - return DuckDBInterchangeSeries(self._native_frame.select(item)) + return DuckDBInterchangeSeries( + self._native_frame.select(item), dtypes=self._dtypes + ) def __getattr__(self, attr: str) -> Any: if attr == "schema": return { - column_name: map_duckdb_dtype_to_narwhals_dtype(duckdb_dtype) + column_name: map_duckdb_dtype_to_narwhals_dtype( + duckdb_dtype, self._dtypes + ) for column_name, duckdb_dtype in zip( self._native_frame.columns, self._native_frame.types ) diff --git a/narwhals/_duckdb/series.py b/narwhals/_duckdb/series.py index f19a6f76fd..a7dbdd549f 100644 --- a/narwhals/_duckdb/series.py +++ b/narwhals/_duckdb/series.py @@ -1,20 +1,27 @@ from __future__ import annotations +from typing import TYPE_CHECKING from typing import Any from narwhals._duckdb.dataframe import map_duckdb_dtype_to_narwhals_dtype +if TYPE_CHECKING: + from narwhals.typing import DTypes + class DuckDBInterchangeSeries: - def __init__(self, df: Any) -> None: + def __init__(self, df: Any, dtypes: DTypes) -> None: self._native_series = df + self._dtypes = dtypes def __narwhals_series__(self) -> Any: return self def __getattr__(self, attr: str) -> Any: if attr == "dtype": - return map_duckdb_dtype_to_narwhals_dtype(self._native_series.types[0]) + return map_duckdb_dtype_to_narwhals_dtype( + self._native_series.types[0], self._dtypes + ) msg = ( # pragma: no cover f"Attribute {attr} is not supported for metadata-only dataframes.\n\n" "If you would like to see this kind of object better supported in " diff --git a/narwhals/_ibis/dataframe.py b/narwhals/_ibis/dataframe.py index f0dc8f6eb0..6f53e277dc 100644 --- a/narwhals/_ibis/dataframe.py +++ b/narwhals/_ibis/dataframe.py @@ -3,19 +3,17 @@ from typing import TYPE_CHECKING from typing import Any -from narwhals import dtypes - if TYPE_CHECKING: import pandas as pd import pyarrow as pa from typing_extensions import Self from narwhals._ibis.series import IbisInterchangeSeries + from narwhals.dtypes import DType + from narwhals.typing import DTypes -def map_ibis_dtype_to_narwhals_dtype( - ibis_dtype: Any, -) -> dtypes.DType: +def map_ibis_dtype_to_narwhals_dtype(ibis_dtype: Any, dtypes: DTypes) -> DType: if ibis_dtype.is_int64(): return dtypes.Int64() if ibis_dtype.is_int32(): @@ -52,8 +50,9 @@ def map_ibis_dtype_to_narwhals_dtype( class IbisInterchangeFrame: - def __init__(self, df: Any) -> None: + def __init__(self, df: Any, dtypes: DTypes) -> None: self._native_frame = df + self._dtypes = dtypes def __narwhals_dataframe__(self) -> Any: return self @@ -61,7 +60,7 @@ def __narwhals_dataframe__(self) -> Any: def __getitem__(self, item: str) -> IbisInterchangeSeries: from narwhals._ibis.series import IbisInterchangeSeries - return IbisInterchangeSeries(self._native_frame[item]) + return IbisInterchangeSeries(self._native_frame[item], dtypes=self._dtypes) def to_pandas(self: Self) -> pd.DataFrame: return self._native_frame.to_pandas() @@ -72,7 +71,7 @@ def to_arrow(self: Self) -> pa.Table: def __getattr__(self, attr: str) -> Any: if attr == "schema": return { - column_name: map_ibis_dtype_to_narwhals_dtype(ibis_dtype) + column_name: map_ibis_dtype_to_narwhals_dtype(ibis_dtype, self._dtypes) for column_name, ibis_dtype in self._native_frame.schema().items() } msg = ( diff --git a/narwhals/_ibis/series.py b/narwhals/_ibis/series.py index 73e3b6d471..2f6cd6faa3 100644 --- a/narwhals/_ibis/series.py +++ b/narwhals/_ibis/series.py @@ -1,20 +1,27 @@ from __future__ import annotations +from typing import TYPE_CHECKING from typing import Any from narwhals._ibis.dataframe import map_ibis_dtype_to_narwhals_dtype +if TYPE_CHECKING: + from narwhals.typing import DTypes + class IbisInterchangeSeries: - def __init__(self, df: Any) -> None: + def __init__(self, df: Any, dtypes: DTypes) -> None: self._native_series = df + self._dtypes = dtypes def __narwhals_series__(self) -> Any: return self def __getattr__(self, attr: str) -> Any: if attr == "dtype": - return map_ibis_dtype_to_narwhals_dtype(self._native_series.type()) + return map_ibis_dtype_to_narwhals_dtype( + self._native_series.type(), self._dtypes + ) msg = ( f"Attribute {attr} is not supported for metadata-only dataframes.\n\n" "If you would like to see this kind of object better supported in " diff --git a/narwhals/_interchange/dataframe.py b/narwhals/_interchange/dataframe.py index 975da216fc..1dc671dc74 100644 --- a/narwhals/_interchange/dataframe.py +++ b/narwhals/_interchange/dataframe.py @@ -5,7 +5,6 @@ from typing import Any from typing import NoReturn -from narwhals import dtypes from narwhals.utils import parse_version if TYPE_CHECKING: @@ -14,6 +13,8 @@ from typing_extensions import Self from narwhals._interchange.series import InterchangeSeries + from narwhals.dtypes import DType + from narwhals.typing import DTypes class DtypeKind(enum.IntEnum): @@ -28,8 +29,8 @@ class DtypeKind(enum.IntEnum): def map_interchange_dtype_to_narwhals_dtype( - interchange_dtype: tuple[DtypeKind, int, Any, Any], -) -> dtypes.DType: + interchange_dtype: tuple[DtypeKind, int, Any, Any], dtypes: DTypes +) -> DType: if interchange_dtype[0] == DtypeKind.INT: if interchange_dtype[1] == 64: return dtypes.Int64() @@ -73,9 +74,10 @@ def map_interchange_dtype_to_narwhals_dtype( class InterchangeFrame: - def __init__(self, df: Any) -> None: + def __init__(self, df: Any, dtypes: DTypes) -> None: self._native_frame = df self._interchange_frame = df.__dataframe__() + self._dtypes = dtypes def __narwhals_dataframe__(self) -> Any: return self @@ -83,13 +85,16 @@ def __narwhals_dataframe__(self) -> Any: def __getitem__(self, item: str) -> InterchangeSeries: from narwhals._interchange.series import InterchangeSeries - return InterchangeSeries(self._interchange_frame.get_column_by_name(item)) + return InterchangeSeries( + self._interchange_frame.get_column_by_name(item), dtypes=self._dtypes + ) @property - def schema(self) -> dict[str, dtypes.DType]: + def schema(self) -> dict[str, DType]: return { column_name: map_interchange_dtype_to_narwhals_dtype( - self._interchange_frame.get_column_by_name(column_name).dtype + self._interchange_frame.get_column_by_name(column_name).dtype, + self._dtypes, ) for column_name in self._interchange_frame.column_names() } diff --git a/narwhals/_interchange/series.py b/narwhals/_interchange/series.py index 70f84d12f1..00426e6c0e 100644 --- a/narwhals/_interchange/series.py +++ b/narwhals/_interchange/series.py @@ -2,27 +2,27 @@ from typing import TYPE_CHECKING from typing import Any -from typing import NoReturn from narwhals._interchange.dataframe import map_interchange_dtype_to_narwhals_dtype if TYPE_CHECKING: - from narwhals.dtypes import DType + from narwhals.typing import DTypes class InterchangeSeries: - def __init__(self, df: Any) -> None: + def __init__(self, df: Any, dtypes: DTypes) -> None: self._native_series = df + self._dtypes = dtypes def __narwhals_series__(self) -> Any: return self - @property - def dtype(self) -> DType: - return map_interchange_dtype_to_narwhals_dtype(self._native_series.dtype) - - def __getattr__(self, attr: str) -> NoReturn: - msg = ( + def __getattr__(self, attr: str) -> Any: + if attr == "dtype": + return map_interchange_dtype_to_narwhals_dtype( + self._native_series.dtype, dtypes=self._dtypes + ) + msg = ( # pragma: no cover f"Attribute {attr} is not supported for metadata-only dataframes.\n\n" "Hint: you probably called `nw.from_native` on an object which isn't fully " "supported by Narwhals, yet implements `__dataframe__`. If you would like to " diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index 47a6bb39bb..aae86cef73 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -35,6 +35,7 @@ from narwhals._pandas_like.series import PandasLikeSeries from narwhals._pandas_like.typing import IntoPandasLikeExpr from narwhals.dtypes import DType + from narwhals.typing import DTypes class PandasLikeDataFrame: @@ -45,11 +46,13 @@ def __init__( *, implementation: Implementation, backend_version: tuple[int, ...], + dtypes: DTypes, ) -> None: self._validate_columns(native_dataframe.columns) self._native_frame = native_dataframe self._implementation = implementation self._backend_version = backend_version + self._dtypes = dtypes def __narwhals_dataframe__(self) -> Self: return self @@ -60,7 +63,9 @@ def __narwhals_lazyframe__(self) -> Self: def __narwhals_namespace__(self) -> PandasLikeNamespace: from narwhals._pandas_like.namespace import PandasLikeNamespace - return PandasLikeNamespace(self._implementation, self._backend_version) + return PandasLikeNamespace( + self._implementation, self._backend_version, dtypes=self._dtypes + ) def __native_namespace__(self: Self) -> ModuleType: if self._implementation in { @@ -92,6 +97,7 @@ def _from_native_frame(self, df: Any) -> Self: df, implementation=self._implementation, backend_version=self._backend_version, + dtypes=self._dtypes, ) def get_column(self, name: str) -> PandasLikeSeries: @@ -101,6 +107,7 @@ def get_column(self, name: str) -> PandasLikeSeries: self._native_frame.loc[:, name], implementation=self._implementation, backend_version=self._backend_version, + dtypes=self._dtypes, ) def __array__(self, dtype: Any = None, copy: bool | None = None) -> np.ndarray: @@ -153,6 +160,7 @@ def __getitem__( self._native_frame.loc[:, item], implementation=self._implementation, backend_version=self._backend_version, + dtypes=self._dtypes, ) elif ( @@ -208,6 +216,7 @@ def __getitem__( native_series, implementation=self._implementation, backend_version=self._backend_version, + dtypes=self._dtypes, ) elif is_sequence_but_not_str(item) or (is_numpy_array(item) and item.ndim == 1): @@ -265,7 +274,7 @@ def iter_rows( @property def schema(self) -> dict[str, DType]: return { - col: native_to_narwhals_dtype(self._native_frame.loc[:, col]) + col: native_to_narwhals_dtype(self._native_frame.loc[:, col], self._dtypes) for col in self._native_frame.columns } @@ -306,6 +315,7 @@ def with_row_index(self, name: str) -> Self: index=self._native_frame.index, implementation=self._implementation, backend_version=self._backend_version, + dtypes=self._dtypes, ).alias(name) return self._from_native_frame( horizontal_concat( @@ -417,6 +427,7 @@ def collect(self) -> PandasLikeDataFrame: self._native_frame, implementation=self._implementation, backend_version=self._backend_version, + dtypes=self._dtypes, ) # --- actions --- @@ -623,6 +634,7 @@ def to_dict(self, *, as_series: bool = False) -> dict[str, Any]: self._native_frame.loc[:, col], implementation=self._implementation, backend_version=self._backend_version, + dtypes=self._dtypes, ) for col in self.columns } @@ -672,6 +684,7 @@ def is_duplicated(self: Self) -> PandasLikeSeries: self._native_frame.duplicated(keep=False), implementation=self._implementation, backend_version=self._backend_version, + dtypes=self._dtypes, ) def is_empty(self: Self) -> bool: @@ -684,6 +697,7 @@ def is_unique(self: Self) -> PandasLikeSeries: ~self._native_frame.duplicated(keep=False), implementation=self._implementation, backend_version=self._backend_version, + dtypes=self._dtypes, ) def null_count(self: Self) -> PandasLikeDataFrame: @@ -691,6 +705,7 @@ def null_count(self: Self) -> PandasLikeDataFrame: self._native_frame.isna().sum(axis=0).to_frame().transpose(), implementation=self._implementation, backend_version=self._backend_version, + dtypes=self._dtypes, ) def item(self: Self, row: int | None = None, column: int | str | None = None) -> Any: diff --git a/narwhals/_pandas_like/expr.py b/narwhals/_pandas_like/expr.py index 512699515a..52c237aaad 100644 --- a/narwhals/_pandas_like/expr.py +++ b/narwhals/_pandas_like/expr.py @@ -14,6 +14,7 @@ from narwhals._pandas_like.dataframe import PandasLikeDataFrame from narwhals._pandas_like.namespace import PandasLikeNamespace + from narwhals.typing import DTypes from narwhals.utils import Implementation @@ -28,6 +29,7 @@ def __init__( output_names: list[str] | None, implementation: Implementation, backend_version: tuple[int, ...], + dtypes: DTypes, ) -> None: self._call = call self._depth = depth @@ -37,6 +39,7 @@ def __init__( self._output_names = output_names self._implementation = implementation self._backend_version = backend_version + self._dtypes = dtypes def __repr__(self) -> str: # pragma: no cover return ( @@ -50,7 +53,9 @@ def __repr__(self) -> str: # pragma: no cover def __narwhals_namespace__(self) -> PandasLikeNamespace: from narwhals._pandas_like.namespace import PandasLikeNamespace - return PandasLikeNamespace(self._implementation, self._backend_version) + return PandasLikeNamespace( + self._implementation, self._backend_version, dtypes=self._dtypes + ) def __narwhals_expr__(self) -> None: ... @@ -60,6 +65,7 @@ def from_column_names( *column_names: str, implementation: Implementation, backend_version: tuple[int, ...], + dtypes: DTypes, ) -> Self: def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: return [ @@ -67,6 +73,7 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: df._native_frame.loc[:, column_name], implementation=df._implementation, backend_version=df._backend_version, + dtypes=df._dtypes, ) for column_name in column_names ] @@ -79,6 +86,7 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: output_names=list(column_names), implementation=implementation, backend_version=backend_version, + dtypes=dtypes, ) @classmethod @@ -87,6 +95,7 @@ def from_column_indices( *column_indices: int, implementation: Implementation, backend_version: tuple[int, ...], + dtypes: DTypes, ) -> Self: def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: return [ @@ -94,6 +103,7 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: df._native_frame.iloc[:, column_index], implementation=df._implementation, backend_version=df._backend_version, + dtypes=df._dtypes, ) for column_index in column_indices ] @@ -106,6 +116,7 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: output_names=None, implementation=implementation, backend_version=backend_version, + dtypes=dtypes, ) def cast( @@ -308,6 +319,7 @@ def alias(self, name: str) -> Self: output_names=[name], implementation=self._implementation, backend_version=self._backend_version, + dtypes=self._dtypes, ) def over(self, keys: list[str]) -> Self: @@ -333,6 +345,7 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: output_names=self._output_names, implementation=self._implementation, backend_version=self._backend_version, + dtypes=self._dtypes, ) def is_duplicated(self) -> Self: @@ -586,6 +599,7 @@ def keep(self: Self) -> PandasLikeExpr: output_names=root_names, implementation=self._expr._implementation, backend_version=self._expr._backend_version, + dtypes=self._expr._dtypes, ) def map(self: Self, function: Callable[[str], str]) -> PandasLikeExpr: @@ -612,6 +626,7 @@ def map(self: Self, function: Callable[[str], str]) -> PandasLikeExpr: output_names=output_names, implementation=self._expr._implementation, backend_version=self._expr._backend_version, + dtypes=self._expr._dtypes, ) def prefix(self: Self, prefix: str) -> PandasLikeExpr: @@ -636,6 +651,7 @@ def prefix(self: Self, prefix: str) -> PandasLikeExpr: output_names=output_names, implementation=self._expr._implementation, backend_version=self._expr._backend_version, + dtypes=self._expr._dtypes, ) def suffix(self: Self, suffix: str) -> PandasLikeExpr: @@ -661,6 +677,7 @@ def suffix(self: Self, suffix: str) -> PandasLikeExpr: output_names=output_names, implementation=self._expr._implementation, backend_version=self._expr._backend_version, + dtypes=self._expr._dtypes, ) def to_lowercase(self: Self) -> PandasLikeExpr: @@ -686,6 +703,7 @@ def to_lowercase(self: Self) -> PandasLikeExpr: output_names=output_names, implementation=self._expr._implementation, backend_version=self._expr._backend_version, + dtypes=self._expr._dtypes, ) def to_uppercase(self: Self) -> PandasLikeExpr: @@ -711,4 +729,5 @@ def to_uppercase(self: Self) -> PandasLikeExpr: output_names=output_names, implementation=self._expr._implementation, backend_version=self._expr._backend_version, + dtypes=self._expr._dtypes, ) diff --git a/narwhals/_pandas_like/group_by.py b/narwhals/_pandas_like/group_by.py index 55c038f9dd..f20383460e 100644 --- a/narwhals/_pandas_like/group_by.py +++ b/narwhals/_pandas_like/group_by.py @@ -92,6 +92,7 @@ def _from_native_frame(self, df: PandasLikeDataFrame) -> PandasLikeDataFrame: df, implementation=self._df._implementation, backend_version=self._df._backend_version, + dtypes=self._df._dtypes, ) def __iter__(self) -> Iterator[tuple[Any, PandasLikeDataFrame]]: diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index 7356524d3f..6aacf2856e 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -8,7 +8,6 @@ from typing import Literal from typing import cast -from narwhals import dtypes from narwhals._expression_parsing import combine_root_names from narwhals._expression_parsing import parse_into_exprs from narwhals._expression_parsing import reduce_output_names @@ -22,45 +21,30 @@ if TYPE_CHECKING: from narwhals._pandas_like.typing import IntoPandasLikeExpr + from narwhals.dtypes import DType + from narwhals.typing import DTypes from narwhals.utils import Implementation class PandasLikeNamespace: - Int64 = dtypes.Int64 - Int32 = dtypes.Int32 - Int16 = dtypes.Int16 - Int8 = dtypes.Int8 - UInt64 = dtypes.UInt64 - UInt32 = dtypes.UInt32 - UInt16 = dtypes.UInt16 - UInt8 = dtypes.UInt8 - Float64 = dtypes.Float64 - Float32 = dtypes.Float32 - Boolean = dtypes.Boolean - Object = dtypes.Object - Unknown = dtypes.Unknown - Categorical = dtypes.Categorical - Enum = dtypes.Enum - String = dtypes.String - Datetime = dtypes.Datetime - Duration = dtypes.Duration - Date = dtypes.Date - List = dtypes.List - Struct = dtypes.Struct - Array = dtypes.Array - @property def selectors(self) -> PandasSelectorNamespace: return PandasSelectorNamespace( - implementation=self._implementation, backend_version=self._backend_version + implementation=self._implementation, + backend_version=self._backend_version, + dtypes=self._dtypes, ) # --- not in spec --- def __init__( - self, implementation: Implementation, backend_version: tuple[int, ...] + self, + implementation: Implementation, + backend_version: tuple[int, ...], + dtypes: DTypes, ) -> None: self._implementation = implementation self._backend_version = backend_version + self._dtypes = dtypes def _create_expr_from_callable( self, @@ -79,6 +63,7 @@ def _create_expr_from_callable( output_names=output_names, implementation=self._implementation, backend_version=self._backend_version, + dtypes=self._dtypes, ) def _create_series_from_scalar( @@ -90,6 +75,7 @@ def _create_series_from_scalar( index=series._native_series.index[0:1], implementation=self._implementation, backend_version=self._backend_version, + dtypes=self._dtypes, ) def _create_expr_from_series(self, series: PandasLikeSeries) -> PandasLikeExpr: @@ -101,6 +87,7 @@ def _create_expr_from_series(self, series: PandasLikeSeries) -> PandasLikeExpr: output_names=None, implementation=self._implementation, backend_version=self._backend_version, + dtypes=self._dtypes, ) def _create_compliant_series(self, value: Any) -> PandasLikeSeries: @@ -108,6 +95,7 @@ def _create_compliant_series(self, value: Any) -> PandasLikeSeries: value, implementation=self._implementation, backend_version=self._backend_version, + dtypes=self._dtypes, ) # --- selection --- @@ -116,6 +104,7 @@ def col(self, *column_names: str) -> PandasLikeExpr: *column_names, implementation=self._implementation, backend_version=self._backend_version, + dtypes=self._dtypes, ) def nth(self, *column_indices: int) -> PandasLikeExpr: @@ -123,6 +112,7 @@ def nth(self, *column_indices: int) -> PandasLikeExpr: *column_indices, implementation=self._implementation, backend_version=self._backend_version, + dtypes=self._dtypes, ) def all(self) -> PandasLikeExpr: @@ -132,6 +122,7 @@ def all(self) -> PandasLikeExpr: df._native_frame.loc[:, column_name], implementation=self._implementation, backend_version=self._backend_version, + dtypes=self._dtypes, ) for column_name in df.columns ], @@ -141,9 +132,10 @@ def all(self) -> PandasLikeExpr: output_names=None, implementation=self._implementation, backend_version=self._backend_version, + dtypes=self._dtypes, ) - def lit(self, value: Any, dtype: dtypes.DType | None) -> PandasLikeExpr: + def lit(self, value: Any, dtype: DType | None) -> PandasLikeExpr: def _lit_pandas_series(df: PandasLikeDataFrame) -> PandasLikeSeries: pandas_series = PandasLikeSeries._from_iterable( data=[value], @@ -151,6 +143,7 @@ def _lit_pandas_series(df: PandasLikeDataFrame) -> PandasLikeSeries: index=df._native_frame.index[0:1], implementation=self._implementation, backend_version=self._backend_version, + dtypes=self._dtypes, ) if dtype: return pandas_series.cast(dtype) @@ -164,6 +157,7 @@ def _lit_pandas_series(df: PandasLikeDataFrame) -> PandasLikeSeries: output_names=["lit"], implementation=self._implementation, backend_version=self._backend_version, + dtypes=self._dtypes, ) # --- reduction --- @@ -172,6 +166,7 @@ def sum(self, *column_names: str) -> PandasLikeExpr: *column_names, implementation=self._implementation, backend_version=self._backend_version, + dtypes=self._dtypes, ).sum() def mean(self, *column_names: str) -> PandasLikeExpr: @@ -179,6 +174,7 @@ def mean(self, *column_names: str) -> PandasLikeExpr: *column_names, implementation=self._implementation, backend_version=self._backend_version, + dtypes=self._dtypes, ).mean() def max(self, *column_names: str) -> PandasLikeExpr: @@ -186,6 +182,7 @@ def max(self, *column_names: str) -> PandasLikeExpr: *column_names, implementation=self._implementation, backend_version=self._backend_version, + dtypes=self._dtypes, ).max() def min(self, *column_names: str) -> PandasLikeExpr: @@ -193,6 +190,7 @@ def min(self, *column_names: str) -> PandasLikeExpr: *column_names, implementation=self._implementation, backend_version=self._backend_version, + dtypes=self._dtypes, ).min() def len(self) -> PandasLikeExpr: @@ -204,6 +202,7 @@ def len(self) -> PandasLikeExpr: index=[0], implementation=self._implementation, backend_version=self._backend_version, + dtypes=self._dtypes, ) ], depth=0, @@ -212,6 +211,7 @@ def len(self) -> PandasLikeExpr: output_names=["len"], implementation=self._implementation, backend_version=self._backend_version, + dtypes=self._dtypes, ) # --- horizontal --- @@ -284,6 +284,7 @@ def concat( ), implementation=self._implementation, backend_version=self._backend_version, + dtypes=self._dtypes, ) if how == "vertical": return PandasLikeDataFrame( @@ -294,6 +295,7 @@ def concat( ), implementation=self._implementation, backend_version=self._backend_version, + dtypes=self._dtypes, ) raise NotImplementedError @@ -301,14 +303,18 @@ def when( self, *predicates: IntoPandasLikeExpr, ) -> PandasWhen: - plx = self.__class__(self._implementation, self._backend_version) + plx = self.__class__( + self._implementation, self._backend_version, dtypes=self._dtypes + ) if predicates: condition = plx.all_horizontal(*predicates) else: msg = "at least one predicate needs to be provided" raise TypeError(msg) - return PandasWhen(condition, self._implementation, self._backend_version) + return PandasWhen( + condition, self._implementation, self._backend_version, dtypes=self._dtypes + ) class PandasWhen: @@ -319,12 +325,15 @@ def __init__( backend_version: tuple[int, ...], then_value: Any = None, otherwise_value: Any = None, + *, + dtypes: DTypes, ) -> None: self._implementation = implementation self._backend_version = backend_version self._condition = condition self._then_value = then_value self._otherwise_value = otherwise_value + self._dtypes = dtypes def __call__(self, df: PandasLikeDataFrame) -> list[PandasLikeSeries]: from narwhals._expression_parsing import parse_into_expr @@ -332,7 +341,9 @@ def __call__(self, df: PandasLikeDataFrame) -> list[PandasLikeSeries]: from narwhals._pandas_like.utils import validate_column_comparand plx = PandasLikeNamespace( - implementation=self._implementation, backend_version=self._backend_version + implementation=self._implementation, + backend_version=self._backend_version, + dtypes=self._dtypes, ) condition = parse_into_expr(self._condition, namespace=plx)._call(df)[0] # type: ignore[arg-type] @@ -346,6 +357,7 @@ def __call__(self, df: PandasLikeDataFrame) -> list[PandasLikeSeries]: index=condition._native_series.index, implementation=self._implementation, backend_version=self._backend_version, + dtypes=self._dtypes, ) value_series = cast(PandasLikeSeries, value_series) @@ -383,6 +395,7 @@ def then(self, value: PandasLikeExpr | PandasLikeSeries | Any) -> PandasThen: output_names=None, implementation=self._implementation, backend_version=self._backend_version, + dtypes=self._dtypes, ) @@ -397,10 +410,11 @@ def __init__( output_names: list[str] | None, implementation: Implementation, backend_version: tuple[int, ...], + dtypes: DTypes, ) -> None: self._implementation = implementation self._backend_version = backend_version - + self._dtypes = dtypes self._call = call self._depth = depth self._function_name = function_name diff --git a/narwhals/_pandas_like/selectors.py b/narwhals/_pandas_like/selectors.py index 1214e12fc0..74235afa58 100644 --- a/narwhals/_pandas_like/selectors.py +++ b/narwhals/_pandas_like/selectors.py @@ -4,22 +4,27 @@ from typing import Any from typing import NoReturn -from narwhals import dtypes from narwhals._pandas_like.expr import PandasLikeExpr if TYPE_CHECKING: from narwhals._pandas_like.dataframe import PandasLikeDataFrame from narwhals._pandas_like.series import PandasLikeSeries from narwhals.dtypes import DType + from narwhals.typing import DTypes from narwhals.utils import Implementation class PandasSelectorNamespace: def __init__( - self, *, implementation: Implementation, backend_version: tuple[int, ...] + self, + *, + implementation: Implementation, + backend_version: tuple[int, ...], + dtypes: DTypes, ) -> None: self._implementation = implementation self._backend_version = backend_version + self._dtypes = dtypes def by_dtype(self, dtypes: list[DType | type[DType]]) -> PandasSelector: def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: @@ -33,32 +38,33 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: output_names=None, implementation=self._implementation, backend_version=self._backend_version, + dtypes=self._dtypes, ) def numeric(self) -> PandasSelector: return self.by_dtype( [ - dtypes.Int64, - dtypes.Int32, - dtypes.Int16, - dtypes.Int8, - dtypes.UInt64, - dtypes.UInt32, - dtypes.UInt16, - dtypes.UInt8, - dtypes.Float64, - dtypes.Float32, + self._dtypes.Int64, + self._dtypes.Int32, + self._dtypes.Int16, + self._dtypes.Int8, + self._dtypes.UInt64, + self._dtypes.UInt32, + self._dtypes.UInt16, + self._dtypes.UInt8, + self._dtypes.Float64, + self._dtypes.Float32, ], ) def categorical(self) -> PandasSelector: - return self.by_dtype([dtypes.Categorical]) + return self.by_dtype([self._dtypes.Categorical]) def string(self) -> PandasSelector: - return self.by_dtype([dtypes.String]) + return self.by_dtype([self._dtypes.String]) def boolean(self) -> PandasSelector: - return self.by_dtype([dtypes.Boolean]) + return self.by_dtype([self._dtypes.Boolean]) def all(self) -> PandasSelector: def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: @@ -72,6 +78,7 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: output_names=None, implementation=self._implementation, backend_version=self._backend_version, + dtypes=self._dtypes, ) @@ -94,6 +101,7 @@ def _to_expr(self) -> PandasLikeExpr: output_names=self._output_names, implementation=self._implementation, backend_version=self._backend_version, + dtypes=self._dtypes, ) def __sub__(self, other: PandasSelector | Any) -> PandasSelector | Any: @@ -112,6 +120,7 @@ def call(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: output_names=None, implementation=self._implementation, backend_version=self._backend_version, + dtypes=self._dtypes, ) else: return self._to_expr() - other @@ -132,6 +141,7 @@ def call(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: output_names=None, implementation=self._implementation, backend_version=self._backend_version, + dtypes=self._dtypes, ) else: return self._to_expr() | other @@ -152,6 +162,7 @@ def call(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: output_names=None, implementation=self._implementation, backend_version=self._backend_version, + dtypes=self._dtypes, ) else: return self._to_expr() & other @@ -159,7 +170,9 @@ def call(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: def __invert__(self) -> PandasSelector: return ( PandasSelectorNamespace( - implementation=self._implementation, backend_version=self._backend_version + implementation=self._implementation, + backend_version=self._backend_version, + dtypes=self._dtypes, ).all() - self ) diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index 8557f8eeeb..dc9a00009d 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -24,6 +24,7 @@ from narwhals._pandas_like.dataframe import PandasLikeDataFrame from narwhals.dtypes import DType + from narwhals.typing import DTypes PANDAS_TO_NUMPY_DTYPE_NO_MISSING = { "Int64": "int64", @@ -78,11 +79,13 @@ def __init__( *, implementation: Implementation, backend_version: tuple[int, ...], + dtypes: DTypes, ) -> None: self._name = native_series.name self._native_series = native_series self._implementation = implementation self._backend_version = backend_version + self._dtypes = dtypes # In pandas, copy-on-write becomes the default in version 3. # So, before that, we need to explicitly avoid unnecessary @@ -131,6 +134,7 @@ def _from_native_series(self, series: Any) -> Self: series, implementation=self._implementation, backend_version=self._backend_version, + dtypes=self._dtypes, ) @classmethod @@ -142,6 +146,7 @@ def _from_iterable( *, implementation: Implementation, backend_version: tuple[int, ...], + dtypes: DTypes, ) -> Self: return cls( native_series_from_iterable( @@ -152,6 +157,7 @@ def _from_iterable( ), implementation=implementation, backend_version=backend_version, + dtypes=dtypes, ) def __len__(self) -> int: @@ -167,7 +173,7 @@ def shape(self) -> tuple[int]: @property def dtype(self: Self) -> DType: - return native_to_narwhals_dtype(self._native_series) + return native_to_narwhals_dtype(self._native_series, self._dtypes) def scatter(self, indices: int | Sequence[int], values: Any) -> Self: if isinstance(values, self.__class__): @@ -190,7 +196,9 @@ def cast( dtype: Any, ) -> Self: ser = self._native_series - dtype = narwhals_to_native_dtype(dtype, ser.dtype, self._implementation) + dtype = narwhals_to_native_dtype( + dtype, ser.dtype, self._implementation, self._dtypes + ) return self._from_native_series(ser.astype(dtype)) def item(self: Self, index: int | None = None) -> Any: @@ -212,6 +220,7 @@ def to_frame(self) -> PandasLikeDataFrame: self._native_series.to_frame(), implementation=self._implementation, backend_version=self._backend_version, + dtypes=self._dtypes, ) def to_list(self) -> Any: @@ -598,6 +607,7 @@ def value_counts( val_count, implementation=self._implementation, backend_version=self._backend_version, + dtypes=self._dtypes, ) def quantile( @@ -640,6 +650,7 @@ def to_dummies( ).astype(int), implementation=self._implementation, backend_version=self._backend_version, + dtypes=self._dtypes, ) def gather_every(self: Self, n: int, offset: int = 0) -> Self: diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index 286d712bfc..726a07c564 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -14,6 +14,7 @@ from narwhals._pandas_like.expr import PandasLikeExpr from narwhals._pandas_like.series import PandasLikeSeries from narwhals.dtypes import DType + from narwhals.typing import DTypes ExprT = TypeVar("ExprT", bound=PandasLikeExpr) import pandas as pd @@ -94,6 +95,7 @@ def create_native_series( *, implementation: Implementation, backend_version: tuple[int, ...], + dtypes: DTypes, ) -> PandasLikeSeries: from narwhals._pandas_like.series import PandasLikeSeries @@ -102,7 +104,10 @@ def create_native_series( iterable, index=index, name="" ) return PandasLikeSeries( - series, implementation=implementation, backend_version=backend_version + series, + implementation=implementation, + backend_version=backend_version, + dtypes=dtypes, ) else: # pragma: no cover msg = f"Expected pandas-like implementation ({PANDAS_LIKE_IMPLEMENTATION}), found {implementation}" @@ -206,9 +211,7 @@ def set_axis( return obj.set_axis(index, axis=0, **kwargs) # type: ignore[attr-defined, no-any-return] -def native_to_narwhals_dtype(column: Any) -> DType: - from narwhals import dtypes - +def native_to_narwhals_dtype(column: Any, dtypes: DTypes) -> DType: dtype = str(column.dtype) if dtype in {"int64", "Int64", "Int64[pyarrow]", "int64[pyarrow]"}: return dtypes.Int64() @@ -280,7 +283,7 @@ def native_to_narwhals_dtype(column: Any) -> DType: try: return map_interchange_dtype_to_narwhals_dtype( - df.__dataframe__().get_column(0).dtype + df.__dataframe__().get_column(0).dtype, dtypes ) except Exception: # noqa: BLE001 return dtypes.Object() @@ -308,10 +311,11 @@ def get_dtype_backend(dtype: Any, implementation: Implementation) -> str: def narwhals_to_native_dtype( # noqa: PLR0915 - dtype: DType | type[DType], starting_dtype: Any, implementation: Implementation + dtype: DType | type[DType], + starting_dtype: Any, + implementation: Implementation, + dtypes: DTypes, ) -> Any: - from narwhals import dtypes - if "polars" in str(type(dtype)): msg = ( f"Expected Narwhals object, got: {type(dtype)}.\n\n" diff --git a/narwhals/_polars/dataframe.py b/narwhals/_polars/dataframe.py index acf70778b5..a4e30ec63f 100644 --- a/narwhals/_polars/dataframe.py +++ b/narwhals/_polars/dataframe.py @@ -17,12 +17,17 @@ import numpy as np from typing_extensions import Self + from narwhals.typing import DTypes + class PolarsDataFrame: - def __init__(self, df: Any, *, backend_version: tuple[int, ...]) -> None: + def __init__( + self, df: Any, *, backend_version: tuple[int, ...], dtypes: DTypes + ) -> None: self._native_frame = df self._backend_version = backend_version self._implementation = Implementation.POLARS + self._dtypes = dtypes def __repr__(self) -> str: # pragma: no cover return "PolarsDataFrame" @@ -31,7 +36,7 @@ def __narwhals_dataframe__(self) -> Self: return self def __narwhals_namespace__(self) -> PolarsNamespace: - return PolarsNamespace(backend_version=self._backend_version) + return PolarsNamespace(backend_version=self._backend_version, dtypes=self._dtypes) def __native_namespace__(self: Self) -> ModuleType: if self._implementation is Implementation.POLARS: @@ -41,7 +46,9 @@ def __native_namespace__(self: Self) -> ModuleType: raise AssertionError(msg) def _from_native_frame(self, df: Any) -> Self: - return self.__class__(df, backend_version=self._backend_version) + return self.__class__( + df, backend_version=self._backend_version, dtypes=self._dtypes + ) def _from_native_object(self, obj: Any) -> Any: import polars as pl # ignore-banned-import() @@ -49,7 +56,9 @@ def _from_native_object(self, obj: Any) -> Any: if isinstance(obj, pl.Series): from narwhals._polars.series import PolarsSeries - return PolarsSeries(obj, backend_version=self._backend_version) + return PolarsSeries( + obj, backend_version=self._backend_version, dtypes=self._dtypes + ) if isinstance(obj, pl.DataFrame): return self._from_native_frame(obj) # scalar @@ -78,14 +87,20 @@ def __array__(self, dtype: Any | None = None, copy: bool | None = None) -> np.nd @property def schema(self) -> dict[str, Any]: schema = self._native_frame.schema - return {name: native_to_narwhals_dtype(dtype) for name, dtype in schema.items()} + return { + name: native_to_narwhals_dtype(dtype, self._dtypes) + for name, dtype in schema.items() + } def collect_schema(self) -> dict[str, Any]: if self._backend_version < (1,): # pragma: no cover schema = self._native_frame.schema else: schema = dict(self._native_frame.collect_schema()) - return {name: native_to_narwhals_dtype(dtype) for name, dtype in schema.items()} + return { + name: native_to_narwhals_dtype(dtype, self._dtypes) + for name, dtype in schema.items() + } @property def shape(self) -> tuple[int, int]: @@ -140,14 +155,18 @@ def __getitem__(self, item: Any) -> Any: if isinstance(result, pl.Series): from narwhals._polars.series import PolarsSeries - return PolarsSeries(result, backend_version=self._backend_version) + return PolarsSeries( + result, backend_version=self._backend_version, dtypes=self._dtypes + ) return self._from_native_object(result) def get_column(self, name: str) -> Any: from narwhals._polars.series import PolarsSeries return PolarsSeries( - self._native_frame.get_column(name), backend_version=self._backend_version + self._native_frame.get_column(name), + backend_version=self._backend_version, + dtypes=self._dtypes, ) def is_empty(self) -> bool: @@ -159,7 +178,9 @@ def columns(self) -> list[str]: def lazy(self) -> PolarsLazyFrame: return PolarsLazyFrame( - self._native_frame.lazy(), backend_version=self._backend_version + self._native_frame.lazy(), + backend_version=self._backend_version, + dtypes=self._dtypes, ) def to_dict(self, *, as_series: bool) -> Any: @@ -169,7 +190,9 @@ def to_dict(self, *, as_series: bool) -> Any: from narwhals._polars.series import PolarsSeries return { - name: PolarsSeries(col, backend_version=self._backend_version) + name: PolarsSeries( + col, backend_version=self._backend_version, dtypes=self._dtypes + ) for name, col in df.to_dict(as_series=True).items() } else: @@ -217,10 +240,13 @@ def unpivot( class PolarsLazyFrame: - def __init__(self, df: Any, *, backend_version: tuple[int, ...]) -> None: + def __init__( + self, df: Any, *, backend_version: tuple[int, ...], dtypes: DTypes + ) -> None: self._native_frame = df self._backend_version = backend_version self._implementation = Implementation.POLARS + self._dtypes = dtypes def __repr__(self) -> str: # pragma: no cover return "PolarsLazyFrame" @@ -229,7 +255,7 @@ def __narwhals_lazyframe__(self) -> Self: return self def __narwhals_namespace__(self) -> PolarsNamespace: - return PolarsNamespace(backend_version=self._backend_version) + return PolarsNamespace(backend_version=self._backend_version, dtypes=self._dtypes) def __native_namespace__(self: Self) -> ModuleType: if self._implementation is Implementation.POLARS: @@ -239,7 +265,9 @@ def __native_namespace__(self: Self) -> ModuleType: raise AssertionError(msg) def _from_native_frame(self, df: Any) -> Self: - return self.__class__(df, backend_version=self._backend_version) + return self.__class__( + df, backend_version=self._backend_version, dtypes=self._dtypes + ) def __getattr__(self, attr: str) -> Any: def func(*args: Any, **kwargs: Any) -> Any: @@ -257,18 +285,26 @@ def columns(self) -> list[str]: @property def schema(self) -> dict[str, Any]: schema = self._native_frame.schema - return {name: native_to_narwhals_dtype(dtype) for name, dtype in schema.items()} + return { + name: native_to_narwhals_dtype(dtype, self._dtypes) + for name, dtype in schema.items() + } def collect_schema(self) -> dict[str, Any]: if self._backend_version < (1,): # pragma: no cover schema = self._native_frame.schema else: schema = dict(self._native_frame.collect_schema()) - return {name: native_to_narwhals_dtype(dtype) for name, dtype in schema.items()} + return { + name: native_to_narwhals_dtype(dtype, self._dtypes) + for name, dtype in schema.items() + } def collect(self) -> PolarsDataFrame: return PolarsDataFrame( - self._native_frame.collect(), backend_version=self._backend_version + self._native_frame.collect(), + backend_version=self._backend_version, + dtypes=self._dtypes, ) def group_by(self, *by: str) -> Any: diff --git a/narwhals/_polars/expr.py b/narwhals/_polars/expr.py index 4f15328232..8a4c93736b 100644 --- a/narwhals/_polars/expr.py +++ b/narwhals/_polars/expr.py @@ -12,18 +12,20 @@ from typing_extensions import Self from narwhals.dtypes import DType + from narwhals.typing import DTypes class PolarsExpr: - def __init__(self, expr: Any) -> None: + def __init__(self, expr: Any, dtypes: DTypes) -> None: self._native_expr = expr self._implementation = Implementation.POLARS + self._dtypes = dtypes def __repr__(self) -> str: # pragma: no cover return "PolarsExpr" def _from_native_expr(self, expr: Any) -> Self: - return self.__class__(expr) + return self.__class__(expr, dtypes=self._dtypes) def __getattr__(self, attr: str) -> Any: def func(*args: Any, **kwargs: Any) -> Any: @@ -36,7 +38,7 @@ def func(*args: Any, **kwargs: Any) -> Any: def cast(self, dtype: DType) -> Self: expr = self._native_expr - dtype = narwhals_to_native_dtype(dtype) + dtype = narwhals_to_native_dtype(dtype, self._dtypes) return self._from_native_expr(expr.cast(dtype)) def __eq__(self, other: object) -> Self: # type: ignore[override] diff --git a/narwhals/_polars/namespace.py b/narwhals/_polars/namespace.py index 275c104fc9..21facd81fa 100644 --- a/narwhals/_polars/namespace.py +++ b/narwhals/_polars/namespace.py @@ -7,7 +7,6 @@ from typing import Literal from typing import Sequence -from narwhals import dtypes from narwhals._expression_parsing import parse_into_exprs from narwhals._polars.utils import extract_args_kwargs from narwhals._polars.utils import narwhals_to_native_dtype @@ -18,35 +17,15 @@ from narwhals._polars.dataframe import PolarsLazyFrame from narwhals._polars.expr import PolarsExpr from narwhals._polars.typing import IntoPolarsExpr + from narwhals.dtypes import DType + from narwhals.typing import DTypes class PolarsNamespace: - Int64 = dtypes.Int64 - Int32 = dtypes.Int32 - Int16 = dtypes.Int16 - Int8 = dtypes.Int8 - UInt64 = dtypes.UInt64 - UInt32 = dtypes.UInt32 - UInt16 = dtypes.UInt16 - UInt8 = dtypes.UInt8 - Float64 = dtypes.Float64 - Float32 = dtypes.Float32 - Boolean = dtypes.Boolean - Object = dtypes.Object - Unknown = dtypes.Unknown - Categorical = dtypes.Categorical - Enum = dtypes.Enum - String = dtypes.String - Datetime = dtypes.Datetime - Duration = dtypes.Duration - Date = dtypes.Date - List = dtypes.List - Struct = dtypes.Struct - Array = dtypes.Array - - def __init__(self, *, backend_version: tuple[int, ...]) -> None: + def __init__(self, *, backend_version: tuple[int, ...], dtypes: DTypes) -> None: self._backend_version = backend_version self._implementation = Implementation.POLARS + self._dtypes = dtypes def __getattr__(self, attr: str) -> Any: import polars as pl # ignore-banned-import @@ -55,7 +34,7 @@ def __getattr__(self, attr: str) -> Any: def func(*args: Any, **kwargs: Any) -> Any: args, kwargs = extract_args_kwargs(args, kwargs) # type: ignore[assignment] - return PolarsExpr(getattr(pl, attr)(*args, **kwargs)) + return PolarsExpr(getattr(pl, attr)(*args, **kwargs), dtypes=self._dtypes) return func @@ -67,7 +46,7 @@ def nth(self, *indices: int) -> PolarsExpr: if self._backend_version < (1, 0, 0): # pragma: no cover msg = "`nth` is only supported for Polars>=1.0.0. Please use `col` for columns selection instead." raise AttributeError(msg) - return PolarsExpr(pl.nth(*indices)) + return PolarsExpr(pl.nth(*indices), dtypes=self._dtypes) def len(self) -> PolarsExpr: import polars as pl # ignore-banned-import() @@ -75,8 +54,8 @@ def len(self) -> PolarsExpr: from narwhals._polars.expr import PolarsExpr if self._backend_version < (0, 20, 5): # pragma: no cover - return PolarsExpr(pl.count().alias("len")) - return PolarsExpr(pl.len()) + return PolarsExpr(pl.count().alias("len"), dtypes=self._dtypes) + return PolarsExpr(pl.len(), dtypes=self._dtypes) def concat( self, @@ -92,17 +71,24 @@ def concat( dfs: list[Any] = [item._native_frame for item in items] result = pl.concat(dfs, how=how) if isinstance(result, pl.DataFrame): - return PolarsDataFrame(result, backend_version=items[0]._backend_version) - return PolarsLazyFrame(result, backend_version=items[0]._backend_version) + return PolarsDataFrame( + result, backend_version=items[0]._backend_version, dtypes=items[0]._dtypes + ) + return PolarsLazyFrame( + result, backend_version=items[0]._backend_version, dtypes=items[0]._dtypes + ) - def lit(self, value: Any, dtype: dtypes.DType | None = None) -> PolarsExpr: + def lit(self, value: Any, dtype: DType | None = None) -> PolarsExpr: import polars as pl # ignore-banned-import() from narwhals._polars.expr import PolarsExpr if dtype is not None: - return PolarsExpr(pl.lit(value, dtype=narwhals_to_native_dtype(dtype))) - return PolarsExpr(pl.lit(value)) + return PolarsExpr( + pl.lit(value, dtype=narwhals_to_native_dtype(dtype, self._dtypes)), + dtypes=self._dtypes, + ) + return PolarsExpr(pl.lit(value), dtypes=self._dtypes) def mean(self, *column_names: str) -> PolarsExpr: import polars as pl # ignore-banned-import() @@ -110,8 +96,8 @@ def mean(self, *column_names: str) -> PolarsExpr: from narwhals._polars.expr import PolarsExpr if self._backend_version < (0, 20, 4): # pragma: no cover - return PolarsExpr(pl.mean([*column_names])) # type: ignore[arg-type] - return PolarsExpr(pl.mean(*column_names)) + return PolarsExpr(pl.mean([*column_names]), dtypes=self._dtypes) # type: ignore[arg-type] + return PolarsExpr(pl.mean(*column_names), dtypes=self._dtypes) def mean_horizontal(self, *exprs: IntoPolarsExpr) -> PolarsExpr: import polars as pl # ignore-banned-import() @@ -125,23 +111,34 @@ def mean_horizontal(self, *exprs: IntoPolarsExpr) -> PolarsExpr: n_non_zero = reduce( lambda x, y: x + y, ((1 - e.is_null()) for e in polars_exprs) ) - return PolarsExpr(total._native_expr / n_non_zero._native_expr) + return PolarsExpr( + total._native_expr / n_non_zero._native_expr, dtypes=self._dtypes + ) - return PolarsExpr(pl.mean_horizontal([e._native_expr for e in polars_exprs])) + return PolarsExpr( + pl.mean_horizontal([e._native_expr for e in polars_exprs]), + dtypes=self._dtypes, + ) @property def selectors(self) -> PolarsSelectors: - return PolarsSelectors() + return PolarsSelectors(self._dtypes) class PolarsSelectors: - def by_dtype(self, dtypes: Iterable[dtypes.DType]) -> PolarsExpr: + def __init__(self, dtypes: DTypes) -> None: + self._dtypes = dtypes + + def by_dtype(self, dtypes: Iterable[DType]) -> PolarsExpr: import polars as pl # ignore-banned-import() from narwhals._polars.expr import PolarsExpr return PolarsExpr( - pl.selectors.by_dtype([narwhals_to_native_dtype(dtype) for dtype in dtypes]) + pl.selectors.by_dtype( + [narwhals_to_native_dtype(dtype, self._dtypes) for dtype in dtypes] + ), + dtypes=self._dtypes, ) def numeric(self) -> PolarsExpr: @@ -149,32 +146,32 @@ def numeric(self) -> PolarsExpr: from narwhals._polars.expr import PolarsExpr - return PolarsExpr(pl.selectors.numeric()) + return PolarsExpr(pl.selectors.numeric(), dtypes=self._dtypes) def boolean(self) -> PolarsExpr: import polars as pl # ignore-banned-import() from narwhals._polars.expr import PolarsExpr - return PolarsExpr(pl.selectors.boolean()) + return PolarsExpr(pl.selectors.boolean(), dtypes=self._dtypes) def string(self) -> PolarsExpr: import polars as pl # ignore-banned-import() from narwhals._polars.expr import PolarsExpr - return PolarsExpr(pl.selectors.string()) + return PolarsExpr(pl.selectors.string(), dtypes=self._dtypes) def categorical(self) -> PolarsExpr: import polars as pl # ignore-banned-import() from narwhals._polars.expr import PolarsExpr - return PolarsExpr(pl.selectors.categorical()) + return PolarsExpr(pl.selectors.categorical(), dtypes=self._dtypes) def all(self) -> PolarsExpr: import polars as pl # ignore-banned-import() from narwhals._polars.expr import PolarsExpr - return PolarsExpr(pl.selectors.all()) + return PolarsExpr(pl.selectors.all(), dtypes=self._dtypes) diff --git a/narwhals/_polars/series.py b/narwhals/_polars/series.py index 7f7bf94a25..0780421959 100644 --- a/narwhals/_polars/series.py +++ b/narwhals/_polars/series.py @@ -17,16 +17,20 @@ from narwhals._polars.dataframe import PolarsDataFrame from narwhals.dtypes import DType + from narwhals.typing import DTypes from narwhals._polars.utils import narwhals_to_native_dtype from narwhals._polars.utils import native_to_narwhals_dtype class PolarsSeries: - def __init__(self, series: Any, *, backend_version: tuple[int, ...]) -> None: + def __init__( + self, series: Any, *, backend_version: tuple[int, ...], dtypes: DTypes + ) -> None: self._native_series = series self._backend_version = backend_version self._implementation = Implementation.POLARS + self._dtypes = dtypes def __repr__(self) -> str: # pragma: no cover return "PolarsSeries" @@ -42,7 +46,9 @@ def __native_namespace__(self: Self) -> ModuleType: raise AssertionError(msg) def _from_native_series(self, series: Any) -> Self: - return self.__class__(series, backend_version=self._backend_version) + return self.__class__( + series, backend_version=self._backend_version, dtypes=self._dtypes + ) def _from_native_object(self, series: Any) -> Any: import polars as pl # ignore-banned-import() @@ -52,7 +58,9 @@ def _from_native_object(self, series: Any) -> Any: if isinstance(series, pl.DataFrame): from narwhals._polars.dataframe import PolarsDataFrame - return PolarsDataFrame(series, backend_version=self._backend_version) + return PolarsDataFrame( + series, backend_version=self._backend_version, dtypes=self._dtypes + ) # scalar return series @@ -81,7 +89,7 @@ def name(self) -> str: @property def dtype(self: Self) -> DType: - return native_to_narwhals_dtype(self._native_series.dtype) + return native_to_narwhals_dtype(self._native_series.dtype, self._dtypes) @overload def __getitem__(self, item: int) -> Any: ... @@ -94,7 +102,7 @@ def __getitem__(self, item: int | slice | Sequence[int]) -> Any | Self: def cast(self, dtype: DType) -> Self: ser = self._native_series - dtype = narwhals_to_native_dtype(dtype) + dtype = narwhals_to_native_dtype(dtype, self._dtypes) return self._from_native_series(ser.cast(dtype)) def __array__(self, dtype: Any = None, copy: bool | None = None) -> np.ndarray: @@ -184,7 +192,9 @@ def to_dummies( separator=separator, drop_first=drop_first ) - return PolarsDataFrame(result, backend_version=self._backend_version) + return PolarsDataFrame( + result, backend_version=self._backend_version, dtypes=self._dtypes + ) def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self: if self._backend_version < (0, 20, 6): # pragma: no cover @@ -232,7 +242,9 @@ def value_counts( sort=sort, parallel=parallel, name=name, normalize=normalize ) - return PolarsDataFrame(result, backend_version=self._backend_version) + return PolarsDataFrame( + result, backend_version=self._backend_version, dtypes=self._dtypes + ) @property def dt(self) -> PolarsSeriesDateTimeNamespace: diff --git a/narwhals/_polars/utils.py b/narwhals/_polars/utils.py index 45c464e51f..db5a4a96bf 100644 --- a/narwhals/_polars/utils.py +++ b/narwhals/_polars/utils.py @@ -1,8 +1,11 @@ from __future__ import annotations +from typing import TYPE_CHECKING from typing import Any -from narwhals import dtypes +if TYPE_CHECKING: + from narwhals.dtypes import DType + from narwhals.typing import DTypes def extract_native(obj: Any) -> Any: @@ -26,7 +29,7 @@ def extract_args_kwargs(args: Any, kwargs: Any) -> tuple[list[Any], dict[str, An return args, kwargs -def native_to_narwhals_dtype(dtype: Any) -> dtypes.DType: +def native_to_narwhals_dtype(dtype: Any, dtypes: DTypes) -> DType: import polars as pl # ignore-banned-import() if dtype == pl.Float64: @@ -74,11 +77,9 @@ def native_to_narwhals_dtype(dtype: Any) -> dtypes.DType: return dtypes.Unknown() -def narwhals_to_native_dtype(dtype: dtypes.DType | type[dtypes.DType]) -> Any: +def narwhals_to_native_dtype(dtype: DType | type[DType], dtypes: DTypes) -> Any: import polars as pl # ignore-banned-import() - from narwhals import dtypes - if dtype == dtypes.Float64: return pl.Float64() if dtype == dtypes.Float32: diff --git a/narwhals/functions.py b/narwhals/functions.py index 430705e662..f0bf5d4ad2 100644 --- a/narwhals/functions.py +++ b/narwhals/functions.py @@ -26,6 +26,7 @@ from narwhals.dtypes import DType from narwhals.schema import Schema from narwhals.series import Series + from narwhals.typing import DTypes def concat( @@ -194,6 +195,25 @@ def new_series( 2 ] """ + from narwhals import dtypes + + return _new_series_impl( + name, + values, + dtype, + native_namespace=native_namespace, + dtypes=dtypes, # type: ignore[arg-type] + ) + + +def _new_series_impl( + name: str, + values: Any, + dtype: DType | type[DType] | None = None, + *, + native_namespace: ModuleType, + dtypes: DTypes, +) -> Series: implementation = Implementation.from_native_namespace(native_namespace) if implementation is Implementation.POLARS: @@ -202,7 +222,7 @@ def new_series( narwhals_to_native_dtype as polars_narwhals_to_native_dtype, ) - dtype = polars_narwhals_to_native_dtype(dtype) + dtype = polars_narwhals_to_native_dtype(dtype, dtypes=dtypes) native_series = native_namespace.Series(name=name, values=values, dtype=dtype) elif implementation in { @@ -215,7 +235,12 @@ def new_series( narwhals_to_native_dtype as pandas_like_narwhals_to_native_dtype, ) - dtype = pandas_like_narwhals_to_native_dtype(dtype, None, implementation) + dtype = pandas_like_narwhals_to_native_dtype( + dtype, + None, + implementation, + dtypes, + ) native_series = native_namespace.Series(values, name=name, dtype=dtype) elif implementation is Implementation.PYARROW: @@ -224,7 +249,7 @@ def new_series( narwhals_to_native_dtype as arrow_narwhals_to_native_dtype, ) - dtype = arrow_narwhals_to_native_dtype(dtype) + dtype = arrow_narwhals_to_native_dtype(dtype, dtypes=dtypes) native_series = native_namespace.chunked_array([values], type=dtype) elif implementation is Implementation.DASK: @@ -291,6 +316,23 @@ def from_dict( │ 2 ┆ 4 │ └─────┴─────┘ """ + from narwhals import dtypes + + return _from_dict_impl( + data, + schema, + native_namespace=native_namespace, + dtypes=dtypes, # type: ignore[arg-type] + ) + + +def _from_dict_impl( + data: dict[str, Any], + schema: dict[str, DType] | Schema | None = None, + *, + native_namespace: ModuleType | None = None, + dtypes: DTypes, +) -> DataFrame[Any]: from narwhals.series import Series from narwhals.translate import to_native @@ -315,7 +357,7 @@ def from_dict( ) schema = { - name: polars_narwhals_to_native_dtype(dtype) + name: polars_narwhals_to_native_dtype(dtype, dtypes=dtypes) for name, dtype in schema.items() } @@ -334,7 +376,10 @@ def from_dict( schema = { name: pandas_like_narwhals_to_native_dtype( - schema[name], native_type, implementation + schema[name], + native_type, + implementation, + dtypes, ) for name, native_type in native_frame.dtypes.items() } @@ -348,7 +393,7 @@ def from_dict( schema = native_namespace.schema( [ - (name, arrow_narwhals_to_native_dtype(dtype)) + (name, arrow_narwhals_to_native_dtype(dtype, dtypes)) for name, dtype in schema.items() ] ) diff --git a/narwhals/stable/v1/__init__.py b/narwhals/stable/v1/__init__.py index 8c4bd877c3..b542b90fa6 100644 --- a/narwhals/stable/v1/__init__.py +++ b/narwhals/stable/v1/__init__.py @@ -15,35 +15,38 @@ from narwhals import selectors from narwhals.dataframe import DataFrame as NwDataFrame from narwhals.dataframe import LazyFrame as NwLazyFrame -from narwhals.dtypes import Array -from narwhals.dtypes import Boolean -from narwhals.dtypes import Categorical -from narwhals.dtypes import Date -from narwhals.dtypes import Datetime -from narwhals.dtypes import Duration -from narwhals.dtypes import Enum -from narwhals.dtypes import Float32 -from narwhals.dtypes import Float64 -from narwhals.dtypes import Int8 -from narwhals.dtypes import Int16 -from narwhals.dtypes import Int32 -from narwhals.dtypes import Int64 -from narwhals.dtypes import List -from narwhals.dtypes import Object -from narwhals.dtypes import String -from narwhals.dtypes import Struct -from narwhals.dtypes import UInt8 -from narwhals.dtypes import UInt16 -from narwhals.dtypes import UInt32 -from narwhals.dtypes import UInt64 -from narwhals.dtypes import Unknown from narwhals.expr import Expr as NwExpr from narwhals.expr import Then as NwThen from narwhals.expr import When as NwWhen from narwhals.expr import when as nw_when +from narwhals.functions import _from_dict_impl +from narwhals.functions import _new_series_impl from narwhals.functions import show_versions from narwhals.schema import Schema as NwSchema from narwhals.series import Series as NwSeries +from narwhals.stable.v1.dtypes import Array +from narwhals.stable.v1.dtypes import Boolean +from narwhals.stable.v1.dtypes import Categorical +from narwhals.stable.v1.dtypes import Date +from narwhals.stable.v1.dtypes import Datetime +from narwhals.stable.v1.dtypes import Duration +from narwhals.stable.v1.dtypes import Enum +from narwhals.stable.v1.dtypes import Float32 +from narwhals.stable.v1.dtypes import Float64 +from narwhals.stable.v1.dtypes import Int8 +from narwhals.stable.v1.dtypes import Int16 +from narwhals.stable.v1.dtypes import Int32 +from narwhals.stable.v1.dtypes import Int64 +from narwhals.stable.v1.dtypes import List +from narwhals.stable.v1.dtypes import Object +from narwhals.stable.v1.dtypes import String +from narwhals.stable.v1.dtypes import Struct +from narwhals.stable.v1.dtypes import UInt8 +from narwhals.stable.v1.dtypes import UInt16 +from narwhals.stable.v1.dtypes import UInt32 +from narwhals.stable.v1.dtypes import UInt64 +from narwhals.stable.v1.dtypes import Unknown +from narwhals.translate import _from_native_impl from narwhals.translate import get_native_namespace as nw_get_native_namespace from narwhals.translate import to_native from narwhals.typing import IntoDataFrameT @@ -811,18 +814,21 @@ def from_native( Returns: narwhals.DataFrame or narwhals.LazyFrame or narwhals.Series """ + from narwhals.stable.v1 import dtypes + # Early returns if isinstance(native_dataframe, (DataFrame, LazyFrame)) and not series_only: return native_dataframe if isinstance(native_dataframe, Series) and (series_only or allow_series): return native_dataframe - result = nw.from_native( + result = _from_native_impl( native_dataframe, strict=strict, eager_only=eager_only, eager_or_interchange_only=eager_or_interchange_only, series_only=series_only, allow_series=allow_series, + dtypes=dtypes, # type: ignore[arg-type] ) return _stableify(result) @@ -1941,8 +1947,16 @@ def new_series( 2 ] """ + from narwhals.stable.v1 import dtypes + return _stableify( - nw.new_series(name, values, dtype, native_namespace=native_namespace) + _new_series_impl( + name, + values, + dtype, + native_namespace=native_namespace, + dtypes=dtypes, # type: ignore[arg-type] + ) ) @@ -1996,8 +2010,15 @@ def from_dict( │ 2 ┆ 4 │ └─────┴─────┘ """ - return _stableify( # type: ignore[no-any-return] - nw.from_dict(data, schema=schema, native_namespace=native_namespace) + from narwhals.stable.v1 import dtypes + + return _stableify( + _from_dict_impl( + data, + schema, + native_namespace=native_namespace, + dtypes=dtypes, # type: ignore[arg-type] + ) ) diff --git a/narwhals/stable/v1/dtypes.py b/narwhals/stable/v1/dtypes.py new file mode 100644 index 0000000000..942881ba47 --- /dev/null +++ b/narwhals/stable/v1/dtypes.py @@ -0,0 +1,47 @@ +from narwhals.dtypes import Array +from narwhals.dtypes import Boolean +from narwhals.dtypes import Categorical +from narwhals.dtypes import Date +from narwhals.dtypes import Datetime +from narwhals.dtypes import Duration +from narwhals.dtypes import Enum +from narwhals.dtypes import Float32 +from narwhals.dtypes import Float64 +from narwhals.dtypes import Int8 +from narwhals.dtypes import Int16 +from narwhals.dtypes import Int32 +from narwhals.dtypes import Int64 +from narwhals.dtypes import List +from narwhals.dtypes import Object +from narwhals.dtypes import String +from narwhals.dtypes import Struct +from narwhals.dtypes import UInt8 +from narwhals.dtypes import UInt16 +from narwhals.dtypes import UInt32 +from narwhals.dtypes import UInt64 +from narwhals.dtypes import Unknown + +__all__ = [ + "Array", + "Boolean", + "Categorical", + "Date", + "Datetime", + "Duration", + "Enum", + "Float32", + "Float64", + "Int8", + "Int16", + "Int32", + "Int64", + "List", + "Object", + "String", + "Struct", + "UInt8", + "UInt16", + "UInt32", + "UInt64", + "Unknown", +] diff --git a/narwhals/translate.py b/narwhals/translate.py index 0e7706fb7e..4c23f6d911 100644 --- a/narwhals/translate.py +++ b/narwhals/translate.py @@ -34,6 +34,7 @@ from narwhals.dataframe import DataFrame from narwhals.dataframe import LazyFrame from narwhals.series import Series + from narwhals.typing import DTypes from narwhals.typing import IntoDataFrameT from narwhals.typing import IntoFrameT @@ -296,7 +297,7 @@ def from_native( ) -> Any: ... -def from_native( # noqa: PLR0915 +def from_native( native_object: Any, *, strict: bool = True, @@ -330,6 +331,29 @@ def from_native( # noqa: PLR0915 Returns: narwhals.DataFrame or narwhals.LazyFrame or narwhals.Series """ + from narwhals import dtypes + + return _from_native_impl( + native_object, + strict=strict, + eager_only=eager_only, + eager_or_interchange_only=eager_or_interchange_only, + series_only=series_only, + allow_series=allow_series, + dtypes=dtypes, # type: ignore[arg-type] + ) + + +def _from_native_impl( # noqa: PLR0915 + native_object: Any, + *, + strict: bool = True, + eager_only: bool | None = None, + eager_or_interchange_only: bool | None = None, + series_only: bool | None = None, + allow_series: bool | None = None, + dtypes: DTypes, +) -> Any: from narwhals._arrow.dataframe import ArrowDataFrame from narwhals._arrow.series import ArrowSeries from narwhals._dask.dataframe import DaskLazyFrame @@ -398,7 +422,11 @@ def from_native( # noqa: PLR0915 raise TypeError(msg) pl = get_polars() return DataFrame( - PolarsDataFrame(native_object, backend_version=parse_version(pl.__version__)), + PolarsDataFrame( + native_object, + backend_version=parse_version(pl.__version__), + dtypes=dtypes, + ), level="full", ) elif is_polars_lazyframe(native_object): @@ -410,7 +438,11 @@ def from_native( # noqa: PLR0915 raise TypeError(msg) pl = get_polars() return LazyFrame( - PolarsLazyFrame(native_object, backend_version=parse_version(pl.__version__)), + PolarsLazyFrame( + native_object, + backend_version=parse_version(pl.__version__), + dtypes=dtypes, + ), level="full", ) elif is_polars_series(native_object): @@ -419,7 +451,11 @@ def from_native( # noqa: PLR0915 msg = "Please set `allow_series=True`" raise TypeError(msg) return Series( - PolarsSeries(native_object, backend_version=parse_version(pl.__version__)), + PolarsSeries( + native_object, + backend_version=parse_version(pl.__version__), + dtypes=dtypes, + ), level="full", ) @@ -434,6 +470,7 @@ def from_native( # noqa: PLR0915 native_object, backend_version=parse_version(pd.__version__), implementation=Implementation.PANDAS, + dtypes=dtypes, ), level="full", ) @@ -447,6 +484,7 @@ def from_native( # noqa: PLR0915 native_object, implementation=Implementation.PANDAS, backend_version=parse_version(pd.__version__), + dtypes=dtypes, ), level="full", ) @@ -462,6 +500,7 @@ def from_native( # noqa: PLR0915 native_object, implementation=Implementation.MODIN, backend_version=parse_version(mpd.__version__), + dtypes=dtypes, ), level="full", ) @@ -475,6 +514,7 @@ def from_native( # noqa: PLR0915 native_object, implementation=Implementation.MODIN, backend_version=parse_version(mpd.__version__), + dtypes=dtypes, ), level="full", ) @@ -490,6 +530,7 @@ def from_native( # noqa: PLR0915 native_object, implementation=Implementation.CUDF, backend_version=parse_version(cudf.__version__), + dtypes=dtypes, ), level="full", ) @@ -503,6 +544,7 @@ def from_native( # noqa: PLR0915 native_object, implementation=Implementation.CUDF, backend_version=parse_version(cudf.__version__), + dtypes=dtypes, ), level="full", ) @@ -514,7 +556,11 @@ def from_native( # noqa: PLR0915 msg = "Cannot only use `series_only` with arrow table" raise TypeError(msg) return DataFrame( - ArrowDataFrame(native_object, backend_version=parse_version(pa.__version__)), + ArrowDataFrame( + native_object, + backend_version=parse_version(pa.__version__), + dtypes=dtypes, + ), level="full", ) elif is_pyarrow_chunked_array(native_object): @@ -524,7 +570,10 @@ def from_native( # noqa: PLR0915 raise TypeError(msg) return Series( ArrowSeries( - native_object, backend_version=parse_version(pa.__version__), name="" + native_object, + backend_version=parse_version(pa.__version__), + name="", + dtypes=dtypes, ), level="full", ) @@ -542,7 +591,9 @@ def from_native( # noqa: PLR0915 raise ImportError(msg) return LazyFrame( DaskLazyFrame( - native_object, backend_version=parse_version(get_dask().__version__) + native_object, + backend_version=parse_version(get_dask().__version__), + dtypes=dtypes, ), level="full", ) @@ -556,7 +607,7 @@ def from_native( # noqa: PLR0915 ) raise TypeError(msg) return DataFrame( - DuckDBInterchangeFrame(native_object), + DuckDBInterchangeFrame(native_object, dtypes=dtypes), level="interchange", ) @@ -569,7 +620,7 @@ def from_native( # noqa: PLR0915 ) raise TypeError(msg) return DataFrame( - IbisInterchangeFrame(native_object), + IbisInterchangeFrame(native_object, dtypes=dtypes), level="interchange", ) @@ -582,7 +633,7 @@ def from_native( # noqa: PLR0915 ) raise TypeError(msg) return DataFrame( - InterchangeFrame(native_object), + InterchangeFrame(native_object, dtypes=dtypes), level="interchange", ) diff --git a/narwhals/typing.py b/narwhals/typing.py index ecc89a4b25..62a7ca58c7 100644 --- a/narwhals/typing.py +++ b/narwhals/typing.py @@ -14,6 +14,7 @@ else: from typing_extensions import TypeAlias + from narwhals import dtypes from narwhals.dataframe import DataFrame from narwhals.dataframe import LazyFrame from narwhals.expr import Expr @@ -52,6 +53,32 @@ def __dataframe__(self, *args: Any, **kwargs: Any) -> Any: ... FrameT = TypeVar("FrameT", "DataFrame[Any]", "LazyFrame[Any]") DataFrameT = TypeVar("DataFrameT", bound="DataFrame[Any]") + +class DTypes: + Int64: type[dtypes.Int64] + Int32: type[dtypes.Int32] + Int16: type[dtypes.Int16] + Int8: type[dtypes.Int8] + UInt64: type[dtypes.UInt64] + UInt32: type[dtypes.UInt32] + UInt16: type[dtypes.UInt16] + UInt8: type[dtypes.UInt8] + Float64: type[dtypes.Float64] + Float32: type[dtypes.Float32] + String: type[dtypes.String] + Boolean: type[dtypes.Boolean] + Object: type[dtypes.Object] + Categorical: type[dtypes.Categorical] + Enum: type[dtypes.Enum] + Datetime: type[dtypes.Datetime] + Duration: type[dtypes.Duration] + Date: type[dtypes.Date] + Struct: type[dtypes.Struct] + List: type[dtypes.List] + Array: type[dtypes.Array] + Unknown: type[dtypes.Unknown] + + __all__ = [ "IntoExpr", "IntoDataFrame", diff --git a/narwhals/utils.py b/narwhals/utils.py index 0d9503240b..62ae7730b6 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -11,7 +11,6 @@ from typing import TypeVar from typing import cast -from narwhals import dtypes from narwhals._exceptions import ColumnNotFoundError from narwhals.dependencies import get_cudf from narwhals.dependencies import get_dask_dataframe @@ -393,6 +392,8 @@ def is_ordered_categorical(series: Series) -> bool: """ from narwhals._interchange.series import InterchangeSeries + dtypes = series._compliant_series._dtypes + if ( isinstance(series._compliant_series, InterchangeSeries) and series.dtype == dtypes.Categorical diff --git a/tests/from_dict_test.py b/tests/from_dict_test.py index a1332908a3..4583b03e5b 100644 --- a/tests/from_dict_test.py +++ b/tests/from_dict_test.py @@ -1,6 +1,7 @@ import pytest -import narwhals.stable.v1 as nw +import narwhals as nw +import narwhals.stable.v1 as nw_v1 from tests.utils import Constructor from tests.utils import compare_dicts @@ -21,10 +22,10 @@ def test_from_dict_schema( ) -> None: if "dask" in str(constructor): request.applymarker(pytest.mark.xfail) - schema = {"c": nw.Int16(), "d": nw.Float32()} - df = nw.from_native(constructor({"a": [1, 2, 3], "b": [4, 5, 6]})) - native_namespace = nw.get_native_namespace(df) - result = nw.from_dict( + schema = {"c": nw_v1.Int16(), "d": nw_v1.Float32()} + df = nw_v1.from_native(constructor({"a": [1, 2, 3], "b": [4, 5, 6]})) + native_namespace = nw_v1.get_native_namespace(df) + result = nw_v1.from_dict( {"c": [1, 2], "d": [5, 6]}, native_namespace=native_namespace, schema=schema, # type: ignore[arg-type] @@ -55,6 +56,17 @@ def test_from_dict_one_native_one_narwhals( compare_dicts(result, expected) +def test_from_dict_v1(constructor: Constructor, request: pytest.FixtureRequest) -> None: + if "dask" in str(constructor): + request.applymarker(pytest.mark.xfail) + df = nw.from_native(constructor({"a": [1, 2, 3], "b": [4, 5, 6]})) + native_namespace = nw.get_native_namespace(df) + result = nw.from_dict({"c": [1, 2], "d": [5, 6]}, native_namespace=native_namespace) + expected = {"c": [1, 2], "d": [5, 6]} + compare_dicts(result, expected) + assert isinstance(result, nw.DataFrame) + + def test_from_dict_empty() -> None: with pytest.raises(ValueError, match="empty"): nw.from_dict({}) diff --git a/tests/new_series_test.py b/tests/new_series_test.py index 8ddcabd408..fad4a75363 100644 --- a/tests/new_series_test.py +++ b/tests/new_series_test.py @@ -3,7 +3,8 @@ import pandas as pd import pytest -import narwhals.stable.v1 as nw +import narwhals as nw +import narwhals.stable.v1 as nw_v1 from tests.utils import compare_dicts @@ -24,6 +25,25 @@ def test_new_series(constructor_eager: Any) -> None: compare_dicts(result.to_frame(), expected) +def test_new_series_v1(constructor_eager: Any) -> None: + s = nw_v1.from_native(constructor_eager({"a": [1, 2, 3]}), eager_only=True)["a"] + result = nw_v1.new_series( + "b", [4, 1, 2], native_namespace=nw_v1.get_native_namespace(s) + ) + expected = {"b": [4, 1, 2]} + # all supported libraries auto-infer this to be int64, we can always special-case + # something different if necessary + assert result.dtype == nw_v1.Int64 + compare_dicts(result.to_frame(), expected) + + result = nw_v1.new_series( + "b", [4, 1, 2], nw_v1.Int32, native_namespace=nw_v1.get_native_namespace(s) + ) + expected = {"b": [4, 1, 2]} + assert result.dtype == nw_v1.Int32 + compare_dicts(result.to_frame(), expected) + + def test_new_series_dask() -> None: pytest.importorskip("dask") pytest.importorskip("dask_expr", exc_type=ImportError) diff --git a/tests/stable_api_test.py b/tests/stable_api_test.py index d579b8185d..a12b20cc65 100644 --- a/tests/stable_api_test.py +++ b/tests/stable_api_test.py @@ -1,3 +1,4 @@ +from datetime import datetime from typing import Any import polars as pl @@ -135,3 +136,10 @@ def test_series_docstrings() -> None: ) == getattr(df, item).__doc__ ) + + +def test_dtypes(constructor: Constructor) -> None: + df = nw.from_native(constructor({"a": [1], "b": [datetime(2020, 1, 1)]})) + dtype = df.collect_schema()["b"] + assert dtype in {nw.Datetime} + assert isinstance(dtype, nw.Datetime)