Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Dask Support Implementation #484

Closed
wants to merge 22 commits into from
Closed
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from narwhals._pandas_like.utils import validate_dataframe_comparand
from narwhals._pandas_like.utils import validate_indices
from narwhals.dependencies import get_cudf
from narwhals.dependencies import get_dask
from narwhals.dependencies import get_modin
from narwhals.dependencies import get_numpy
from narwhals.dependencies import get_pandas
Expand Down Expand Up @@ -66,6 +67,8 @@ def __native_namespace__(self) -> Any:
return get_modin()
if self._implementation is Implementation.CUDF: # pragma: no cover
return get_cudf()
if self._implementation == "dask": # pragma: no cover
benrutter marked this conversation as resolved.
Show resolved Hide resolved
return get_dask()
msg = f"Expected pandas/modin/cudf, got: {type(self._implementation)}" # pragma: no cover
raise AssertionError(msg)

Expand Down Expand Up @@ -312,9 +315,15 @@ def sort(

# --- convert ---
def collect(self) -> PandasLikeDataFrame:
if self._implementation is Implementation.DASK:
return_df = self._native_dataframe.compute()
return_implementation = Implementation.PANDAS
else:
return_df = self._native_dataframe
return_implementation = self._implementation
return PandasLikeDataFrame(
self._native_dataframe,
implementation=self._implementation,
return_df,
implementation=return_implementation,
backend_version=self._backend_version,
)

Expand Down Expand Up @@ -487,13 +496,17 @@ def to_numpy(self) -> Any:
import numpy as np

return np.hstack([self[col].to_numpy()[:, None] for col in self.columns])
if self._implementation is Implementation.DASK:
return self._native_dataframe.compute().to_numpy()
return self._native_dataframe.to_numpy()

def to_pandas(self) -> Any:
if self._implementation is Implementation.PANDAS:
return self._native_dataframe
if self._implementation is Implementation.MODIN: # pragma: no cover
return self._native_dataframe._to_pandas()
if self._implementation is Implementation.DASK: # pragma: no cover
return self._native_dataframe.compute()
return self._native_dataframe.to_pandas() # pragma: no cover

def write_parquet(self, file: Any) -> Any:
Expand Down
12 changes: 10 additions & 2 deletions narwhals/_pandas_like/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,13 @@ class PandasLikeGroupBy:
def __init__(self, df: PandasLikeDataFrame, keys: list[str]) -> None:
self._df = df
self._keys = list(keys)
keywords = {}
if df._implementation != "dask":
benrutter marked this conversation as resolved.
Show resolved Hide resolved
keywords |= {"as_index": True}
self._grouped = self._df._native_dataframe.groupby(
list(self._keys),
sort=False,
as_index=True,
**keywords,
)

def agg(
Expand All @@ -56,13 +59,18 @@ def agg(
raise ValueError(msg)
output_names.extend(expr._output_names)

dataframe_is_empty = (
self._df._native_dataframe.empty
if self._df._implementation != Implementation.DASK
else len(self._df._native_dataframe) == 0
)
Comment on lines +64 to +68
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we just use self._df.is_empty()?

return agg_pandas(
self._grouped,
exprs,
self._keys,
output_names,
self._from_native_dataframe,
dataframe_is_empty=self._df._native_dataframe.empty,
dataframe_is_empty=dataframe_is_empty,
implementation=implementation,
backend_version=self._df._backend_version,
)
Expand Down
8 changes: 7 additions & 1 deletion narwhals/_pandas_like/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from narwhals._pandas_like.expr import PandasLikeExpr
from narwhals._pandas_like.selectors import PandasSelectorNamespace
from narwhals._pandas_like.series import PandasLikeSeries
from narwhals._pandas_like.utils import Implementation
from narwhals._pandas_like.utils import create_native_series
from narwhals._pandas_like.utils import horizontal_concat
from narwhals._pandas_like.utils import vertical_concat
Expand Down Expand Up @@ -78,10 +79,15 @@ def _create_expr_from_callable(
def _create_series_from_scalar(
self, value: Any, series: PandasLikeSeries
) -> PandasLikeSeries:
index = (
series._native_series.index[0:1]
if self._implementation is not Implementation.DASK
else None
)
return PandasLikeSeries._from_iterable(
[value],
name=series._native_series.name,
index=series._native_series.index[0:1],
index=index,
implementation=self._implementation,
backend_version=self._backend_version,
)
Expand Down
22 changes: 21 additions & 1 deletion narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@
from narwhals._pandas_like.utils import Implementation
from narwhals._pandas_like.utils import int_dtype_mapper
from narwhals._pandas_like.utils import native_series_from_iterable
from narwhals._pandas_like.utils import not_implemented_in
from narwhals._pandas_like.utils import reverse_translate_dtype
from narwhals._pandas_like.utils import to_datetime
from narwhals._pandas_like.utils import translate_dtype
from narwhals._pandas_like.utils import validate_column_comparand
from narwhals.dependencies import get_cudf
from narwhals.dependencies import get_dask
from narwhals.dependencies import get_modin
from narwhals.dependencies import get_pandas
from narwhals.dependencies import get_pyarrow_compute
Expand Down Expand Up @@ -107,12 +109,15 @@ def __native_namespace__(self) -> Any:
return get_modin()
if self._implementation is Implementation.CUDF: # pragma: no cover
return get_cudf()
if self._implementation == "dask": # pragma: no cover
benrutter marked this conversation as resolved.
Show resolved Hide resolved
return get_dask()
msg = f"Expected pandas/modin/cudf, got: {type(self._implementation)}" # pragma: no cover
raise AssertionError(msg)

def __narwhals_series__(self) -> Self:
return self

@not_implemented_in(Implementation.DASK)
def __getitem__(self, idx: int | slice | Sequence[int]) -> Any:
if isinstance(idx, int):
return self._native_series.iloc[idx]
Expand Down Expand Up @@ -152,7 +157,7 @@ def _from_iterable(
)

def __len__(self) -> int:
return self.shape[0]
return len(self._native_series)

@property
def name(self) -> str:
Expand All @@ -174,6 +179,7 @@ def cast(
dtype = reverse_translate_dtype(dtype, ser.dtype, self._implementation)
return self._from_native_series(ser.astype(dtype))

@not_implemented_in("dask")
def item(self: Self, index: int | None = None) -> Any:
# cuDF doesn't have Series.item().
if index is None:
Expand Down Expand Up @@ -504,10 +510,13 @@ def to_pandas(self) -> Any:
return self._native_series.to_pandas()
elif self._implementation is Implementation.MODIN: # pragma: no cover
return self._native_series._to_pandas()
elif self._implementation is Implementation.DASK: # pragma: no cover
return self._native_series.compute()
msg = f"Unknown implementation: {self._implementation}" # pragma: no cover
raise AssertionError(msg)

# --- descriptive ---
@not_implemented_in("dask")
def is_duplicated(self: Self) -> Self:
return self._from_native_series(self._native_series.duplicated(keep=False))

Expand All @@ -520,9 +529,11 @@ def is_unique(self: Self) -> Self:
def null_count(self: Self) -> int:
return self._native_series.isna().sum() # type: ignore[no-any-return]

@not_implemented_in("dask")
def is_first_distinct(self: Self) -> Self:
return self._from_native_series(~self._native_series.duplicated(keep="first"))

@not_implemented_in("dask")
def is_last_distinct(self: Self) -> Self:
return self._from_native_series(~self._native_series.duplicated(keep="last"))

Expand Down Expand Up @@ -559,6 +570,15 @@ def quantile(
quantile: float,
interpolation: Literal["nearest", "higher", "lower", "midpoint", "linear"],
) -> Any:
if self._implementation is Implementation.DASK:
if interpolation == "linear":
return self._native_series.quantile(q=quantile)
message = (
"Dask performs approximate quantile calculations "
"and does not support specific interpolations methods. "
"Interpolation keywords other than 'linear' are not supported"
)
raise NotImplementedError(message)
return self._native_series.quantile(q=quantile, interpolation=interpolation)

def zip_with(self: Self, mask: Any, other: Any) -> PandasLikeSeries:
Expand Down
110 changes: 107 additions & 3 deletions narwhals/_pandas_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@
import secrets
from enum import Enum
from enum import auto
from functools import wraps
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Iterable
from typing import TypeVar

from narwhals.dependencies import get_cudf
from narwhals.dependencies import get_dask
from narwhals.dependencies import get_modin
from narwhals.dependencies import get_pandas
from narwhals.utils import isinstance_or_issubclass
Expand All @@ -27,6 +30,7 @@ class Implementation(Enum):
PANDAS = auto()
MODIN = auto()
CUDF = auto()
DASK = auto()


def validate_column_comparand(index: Any, other: Any) -> Any:
Expand All @@ -53,7 +57,10 @@ def validate_column_comparand(index: Any, other: Any) -> Any:
if other.len() == 1:
# broadcast
return other.item()
if other._native_series.index is not index:
if (
other._native_series.index is not index
and other._implementation != Implementation.DASK
):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what if other._native_series.index is not index and other._implementation is Implementation.DASK? I think we need to raise an error message in that case

return set_axis(
other._native_series,
index,
Expand All @@ -79,15 +86,65 @@ def validate_dataframe_comparand(index: Any, other: Any) -> Any:
if other.len() == 1:
# broadcast
return other._native_series.iloc[0]
if other._native_series.index is not index:
if (
other._native_series.index is not index
and other._implementation is not Implementation.DASK
):
return set_axis(
other._native_series,
index,
implementation=other._implementation,
backend_version=other._backend_version,
)
return other._native_series
msg = "Please report a bug" # pragma: no cover
return other._series
raise AssertionError("Please report a bug")


def maybe_evaluate_expr(df: PandasDataFrame, expr: Any) -> Any:
benrutter marked this conversation as resolved.
Show resolved Hide resolved
"""Evaluate `expr` if it's an expression, otherwise return it as is."""
from narwhals._pandas_like.expr import PandasExpr

if isinstance(expr, PandasExpr):
return expr._call(df)
return expr


def parse_into_expr(
implementation: str, into_expr: IntoPandasExpr | IntoArrowExpr
) -> PandasExpr:
"""Parse `into_expr` as an expression.

For example, in Polars, we can do both `df.select('a')` and `df.select(pl.col('a'))`.
We do the same in Narwhals:

- if `into_expr` is already an expression, just return it
- if it's a Series, then convert it to an expression
- if it's a numpy array, then convert it to a Series and then to an expression
- if it's a string, then convert it to an expression
- else, raise
"""
from narwhals._arrow.expr import ArrowExpr
from narwhals._arrow.namespace import ArrowNamespace
from narwhals._arrow.series import ArrowSeries
from narwhals._pandas_like.expr import PandasExpr
from narwhals._pandas_like.namespace import PandasNamespace
from narwhals._pandas_like.series import PandasSeries

if implementation == "arrow":
plx: ArrowNamespace | PandasNamespace = ArrowNamespace()
else:
plx = PandasNamespace(implementation=implementation)
if isinstance(into_expr, (PandasExpr, ArrowExpr)):
return into_expr # type: ignore[return-value]
if isinstance(into_expr, (PandasSeries, ArrowSeries)):
return plx._create_expr_from_series(into_expr) # type: ignore[arg-type, return-value]
if isinstance(into_expr, str):
return plx.col(into_expr) # type: ignore[return-value]
if (np := get_numpy()) is not None and isinstance(into_expr, np.ndarray):
series = create_native_series(into_expr, implementation=implementation)
return plx._create_expr_from_series(series) # type: ignore[arg-type, return-value]
msg = f"Expected IntoExpr, got {type(into_expr)}" # pragma: no cover
raise AssertionError(msg)


Expand Down Expand Up @@ -136,6 +193,11 @@ def horizontal_concat(
mpd = get_modin()

return mpd.concat(dfs, axis=1)
if implementation is Implementation.DASK: # pragma: no cover
dd = get_dask()
if hasattr(dfs[0], "_series"):
return dd.concat([i._series for i in dfs], axis=1)
return dd.concat(dfs, axis=1)
msg = f"Unknown implementation: {implementation}" # pragma: no cover
raise TypeError(msg) # pragma: no cover

Expand Down Expand Up @@ -171,6 +233,10 @@ def vertical_concat(
mpd = get_modin()

return mpd.concat(dfs, axis=0)
if implementation == "dask": # pragma: no cover
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same

dd = get_dask()

return dd.concat(dfs, axis=0)
msg = f"Unknown implementation: {implementation}" # pragma: no cover
raise TypeError(msg) # pragma: no cover

Expand All @@ -194,6 +260,20 @@ def native_series_from_iterable(
mpd = get_modin()

return mpd.Series(data, name=name, index=index)
if implementation == "arrow":
pa = get_pyarrow()
return pa.chunked_array([data])
if implementation is Implementation.ARROW: # pragma: no cover
benrutter marked this conversation as resolved.
Show resolved Hide resolved
dd = get_dask()
pd = get_pandas()
if hasattr(data[0], "compute"):
return dd.concat([i.to_series() for i in data])
return pd.Series(
data,
name=name,
index=index,
copy=False,
).pipe(dd.from_pandas)
msg = f"Unknown implementation: {implementation}" # pragma: no cover
raise TypeError(msg) # pragma: no cover

Expand All @@ -218,6 +298,9 @@ def set_axis(
kwargs["copy"] = False
else: # pragma: no cover
pass
if implementation == "dask":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same

msg = "Setting axis on columns is not currently supported for dask"
raise NotImplementedError(msg)
return obj.set_axis(index, axis=0, **kwargs) # type: ignore[no-any-return, attr-defined]


Expand Down Expand Up @@ -449,6 +532,8 @@ def to_datetime(implementation: Implementation) -> Any:
return get_modin().to_datetime
if implementation is Implementation.CUDF:
return get_cudf().to_datetime
if implementation == "dask":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same

return get_dask().to_datetime
raise AssertionError


Expand Down Expand Up @@ -486,3 +571,22 @@ def generate_unique_token(n_bytes: int, columns: list[str]) -> str: # pragma: n
"join operation"
)
raise AssertionError(msg)


def not_implemented_in(*implementations: list[Implementation]) -> Callable:
"""
Produces method decorator to raise not implemented warnings for given implementations
"""

def check_implementation_wrapper(func: Callable) -> Callable:
"""Wraps function to return same function + implementation check"""

@wraps(func)
def wrapped_func(self, *args, **kwargs):
if (implementation := self._implementation) in implementations:
raise NotImplementedError(f"Not implemented in {implementation}")
return func(self, *args, **kwargs)

return wrapped_func

return check_implementation_wrapper
Loading
Loading