Skip to content

Commit

Permalink
feat: add dtypes to stable api (narwhals-dev#1087)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored and akmalsoliev committed Oct 15, 2024
1 parent 6a9eca6 commit 2139f93
Show file tree
Hide file tree
Showing 40 changed files with 851 additions and 374 deletions.
5 changes: 5 additions & 0 deletions docs/how_it_works.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ from narwhals.utils import parse_version
pn = PandasLikeNamespace(
implementation=Implementation.PANDAS,
backend_version=parse_version(pd.__version__),
dtypes=nw.dtypes,
)
print(nw.col("a")._call(pn))
```
Expand All @@ -101,13 +102,15 @@ import pandas as pd
pn = PandasLikeNamespace(
implementation=Implementation.PANDAS,
backend_version=parse_version(pd.__version__),
dtypes=nw.dtypes,
)

df_pd = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
df = PandasLikeDataFrame(
df_pd,
implementation=Implementation.PANDAS,
backend_version=parse_version(pd.__version__),
dtypes=nw.dtypes,
)
expression = pn.col("a") + 1
result = expression._call(df)
Expand Down Expand Up @@ -196,6 +199,7 @@ import pandas as pd
pn = PandasLikeNamespace(
implementation=Implementation.PANDAS,
backend_version=parse_version(pd.__version__),
dtypes=nw.dtypes,
)

df_pd = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
Expand All @@ -210,6 +214,7 @@ backend, and it does so by passing a Narwhals-compliant namespace to `nw.Expr._c
pn = PandasLikeNamespace(
implementation=Implementation.PANDAS,
backend_version=parse_version(pd.__version__),
dtypes=nw.dtypes,
)
expr = (nw.col("a") + 1)._call(pn)
print(expr)
Expand Down
43 changes: 35 additions & 8 deletions narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,21 +33,27 @@
from narwhals._arrow.series import ArrowSeries
from narwhals._arrow.typing import IntoArrowExpr
from narwhals.dtypes import DType
from narwhals.typing import DTypes


class ArrowDataFrame:
# --- not in the spec ---
def __init__(
self, native_dataframe: pa.Table, *, backend_version: tuple[int, ...]
self,
native_dataframe: pa.Table,
*,
backend_version: tuple[int, ...],
dtypes: DTypes,
) -> None:
self._native_frame = native_dataframe
self._implementation = Implementation.PYARROW
self._backend_version = backend_version
self._dtypes = dtypes

def __narwhals_namespace__(self) -> ArrowNamespace:
from narwhals._arrow.namespace import ArrowNamespace

return ArrowNamespace(backend_version=self._backend_version)
return ArrowNamespace(backend_version=self._backend_version, dtypes=self._dtypes)

def __native_namespace__(self: Self) -> ModuleType:
if self._implementation is Implementation.PYARROW:
Expand All @@ -63,7 +69,9 @@ def __narwhals_lazyframe__(self) -> Self:
return self

def _from_native_frame(self, df: Any) -> Self:
return self.__class__(df, backend_version=self._backend_version)
return self.__class__(
df, backend_version=self._backend_version, dtypes=self._dtypes
)

@property
def shape(self) -> tuple[int, int]:
Expand Down Expand Up @@ -111,6 +119,7 @@ def get_column(self, name: str) -> ArrowSeries:
self._native_frame[name],
name=name,
backend_version=self._backend_version,
dtypes=self._dtypes,
)

def __array__(self, dtype: Any = None, copy: bool | None = None) -> np.ndarray:
Expand Down Expand Up @@ -151,6 +160,7 @@ def __getitem__(
self._native_frame[item],
name=item,
backend_version=self._backend_version,
dtypes=self._dtypes,
)
elif (
isinstance(item, tuple)
Expand Down Expand Up @@ -191,12 +201,14 @@ def __getitem__(
self._native_frame[col_name],
name=col_name,
backend_version=self._backend_version,
dtypes=self._dtypes,
)
selected_rows = select_rows(self._native_frame, item[0])
return ArrowSeries(
selected_rows[col_name],
name=col_name,
backend_version=self._backend_version,
dtypes=self._dtypes,
)

elif isinstance(item, slice):
Expand Down Expand Up @@ -234,7 +246,7 @@ def __getitem__(
def schema(self) -> dict[str, DType]:
schema = self._native_frame.schema
return {
name: native_to_narwhals_dtype(dtype)
name: native_to_narwhals_dtype(dtype, self._dtypes)
for name, dtype in zip(schema.names, schema.types)
}

Expand Down Expand Up @@ -410,7 +422,12 @@ def to_dict(self, *, as_series: bool) -> Any:
from narwhals._arrow.series import ArrowSeries

return {
name: ArrowSeries(col, name=name, backend_version=self._backend_version)
name: ArrowSeries(
col,
name=name,
backend_version=self._backend_version,
dtypes=self._dtypes,
)
for name, col in names_and_values
}
else:
Expand Down Expand Up @@ -471,7 +488,9 @@ def lazy(self) -> Self:
return self

def collect(self) -> ArrowDataFrame:
return ArrowDataFrame(self._native_frame, backend_version=self._backend_version)
return ArrowDataFrame(
self._native_frame, backend_version=self._backend_version, dtypes=self._dtypes
)

def clone(self) -> Self:
msg = "clone is not yet supported on PyArrow tables"
Expand Down Expand Up @@ -541,7 +560,12 @@ def is_duplicated(self: Self) -> ArrowSeries:
).column(f"{col_token}_count"),
1,
)
return ArrowSeries(is_duplicated, name="", backend_version=self._backend_version)
return ArrowSeries(
is_duplicated,
name="",
backend_version=self._backend_version,
dtypes=self._dtypes,
)

def is_unique(self: Self) -> ArrowSeries:
import pyarrow.compute as pc # ignore-banned-import()
Expand All @@ -551,7 +575,10 @@ def is_unique(self: Self) -> ArrowSeries:
is_duplicated = self.is_duplicated()._native_series

return ArrowSeries(
pc.invert(is_duplicated), name="", backend_version=self._backend_version
pc.invert(is_duplicated),
name="",
backend_version=self._backend_version,
dtypes=self._dtypes,
)

def unique(
Expand Down
27 changes: 24 additions & 3 deletions narwhals/_arrow/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from narwhals._arrow.series import ArrowSeries
from narwhals._arrow.typing import IntoArrowExpr
from narwhals.dtypes import DType
from narwhals.typing import DTypes


class ArrowExpr:
Expand All @@ -29,6 +30,7 @@ def __init__(
root_names: list[str] | None,
output_names: list[str] | None,
backend_version: tuple[int, ...],
dtypes: DTypes,
) -> None:
self._call = call
self._depth = depth
Expand All @@ -38,6 +40,7 @@ def __init__(
self._output_names = output_names
self._implementation = Implementation.PYARROW
self._backend_version = backend_version
self._dtypes = dtypes

def __repr__(self) -> str: # pragma: no cover
return (
Expand All @@ -50,7 +53,10 @@ def __repr__(self) -> str: # pragma: no cover

@classmethod
def from_column_names(
cls: type[Self], *column_names: str, backend_version: tuple[int, ...]
cls: type[Self],
*column_names: str,
backend_version: tuple[int, ...],
dtypes: DTypes,
) -> Self:
from narwhals._arrow.series import ArrowSeries

Expand All @@ -60,6 +66,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
df._native_frame[column_name],
name=column_name,
backend_version=df._backend_version,
dtypes=df._dtypes,
)
for column_name in column_names
]
Expand All @@ -71,11 +78,15 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
root_names=list(column_names),
output_names=list(column_names),
backend_version=backend_version,
dtypes=dtypes,
)

@classmethod
def from_column_indices(
cls: type[Self], *column_indices: int, backend_version: tuple[int, ...]
cls: type[Self],
*column_indices: int,
backend_version: tuple[int, ...],
dtypes: DTypes,
) -> Self:
from narwhals._arrow.series import ArrowSeries

Expand All @@ -85,6 +96,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
df._native_frame[column_index],
name=df._native_frame.column_names[column_index],
backend_version=df._backend_version,
dtypes=df._dtypes,
)
for column_index in column_indices
]
Expand All @@ -96,12 +108,13 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
root_names=None,
output_names=None,
backend_version=backend_version,
dtypes=dtypes,
)

def __narwhals_namespace__(self) -> ArrowNamespace:
from narwhals._arrow.namespace import ArrowNamespace

return ArrowNamespace(backend_version=self._backend_version)
return ArrowNamespace(backend_version=self._backend_version, dtypes=self._dtypes)

def __narwhals_expr__(self) -> None: ...

Expand Down Expand Up @@ -246,6 +259,7 @@ def alias(self, name: str) -> Self:
root_names=self._root_names,
output_names=[name],
backend_version=self._backend_version,
dtypes=self._dtypes,
)

def null_count(self) -> Self:
Expand Down Expand Up @@ -352,6 +366,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
root_names=self._root_names,
output_names=self._output_names,
backend_version=self._backend_version,
dtypes=self._dtypes,
)

def mode(self: Self) -> Self:
Expand Down Expand Up @@ -573,6 +588,7 @@ def keep(self: Self) -> ArrowExpr:
root_names=root_names,
output_names=root_names,
backend_version=self._expr._backend_version,
dtypes=self._expr._dtypes,
)

def map(self: Self, function: Callable[[str], str]) -> ArrowExpr:
Expand All @@ -598,6 +614,7 @@ def map(self: Self, function: Callable[[str], str]) -> ArrowExpr:
root_names=root_names,
output_names=output_names,
backend_version=self._expr._backend_version,
dtypes=self._expr._dtypes,
)

def prefix(self: Self, prefix: str) -> ArrowExpr:
Expand All @@ -621,6 +638,7 @@ def prefix(self: Self, prefix: str) -> ArrowExpr:
root_names=root_names,
output_names=output_names,
backend_version=self._expr._backend_version,
dtypes=self._expr._dtypes,
)

def suffix(self: Self, suffix: str) -> ArrowExpr:
Expand All @@ -645,6 +663,7 @@ def suffix(self: Self, suffix: str) -> ArrowExpr:
root_names=root_names,
output_names=output_names,
backend_version=self._expr._backend_version,
dtypes=self._expr._dtypes,
)

def to_lowercase(self: Self) -> ArrowExpr:
Expand All @@ -669,6 +688,7 @@ def to_lowercase(self: Self) -> ArrowExpr:
root_names=root_names,
output_names=output_names,
backend_version=self._expr._backend_version,
dtypes=self._expr._dtypes,
)

def to_uppercase(self: Self) -> ArrowExpr:
Expand All @@ -693,4 +713,5 @@ def to_uppercase(self: Self) -> ArrowExpr:
root_names=root_names,
output_names=output_names,
backend_version=self._expr._backend_version,
dtypes=self._expr._dtypes,
)
Loading

0 comments on commit 2139f93

Please sign in to comment.