Skip to content

Commit

Permalink
feat: Dask multiple partitions (#940)
Browse files Browse the repository at this point in the history
  • Loading branch information
FBruzzesi authored Sep 11, 2024
1 parent d6f3cd7 commit d49af40
Show file tree
Hide file tree
Showing 7 changed files with 97 additions and 21 deletions.
13 changes: 11 additions & 2 deletions narwhals/_dask/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,9 @@ def filter(
and isinstance(predicates[0], list)
and all(isinstance(x, bool) for x in predicates[0])
):
msg = "Filtering with boolean mask is not supported for `DaskLazyFrame`"
msg = (
"`LazyFrame.filter` is not supported for Dask backend with boolean masks."
)
raise NotImplementedError(msg)

from narwhals._dask.namespace import DaskNamespace
Expand Down Expand Up @@ -329,7 +331,14 @@ def group_by(self, *by: str) -> DaskLazyGroupBy:
return DaskLazyGroupBy(self, list(by))

def tail(self: Self, n: int) -> Self:
return self._from_native_frame(self._native_frame.tail(n=n, compute=False))
native_frame = self._native_frame
n_partitions = native_frame.npartitions

if n_partitions == 1:
return self._from_native_frame(self._native_frame.tail(n=n, compute=False))
else:
msg = "`LazyFrame.tail` is not supported for Dask backend with multiple partitions."
raise NotImplementedError(msg)

def gather_every(self: Self, n: int, offset: int) -> Self:
row_index_token = generate_unique_token(n_bytes=8, columns=self.columns)
Expand Down
14 changes: 13 additions & 1 deletion narwhals/_dask/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,8 +524,15 @@ def quantile(
interpolation: Literal["nearest", "higher", "lower", "midpoint", "linear"],
) -> Self:
if interpolation == "linear":

def func(_input: dask_expr.Series, _quantile: float) -> dask_expr.Series:
if _input.npartitions > 1:
msg = "`Expr.quantile` is not supported for Dask backend with multiple partitions."
raise NotImplementedError(msg)
return _input.quantile(q=_quantile, method="dask")

return self._from_call(
lambda _input, quantile: _input.quantile(q=quantile, method="dask"),
func,
"quantile",
quantile,
returns_scalar=True,
Expand Down Expand Up @@ -626,6 +633,11 @@ def func(df: DaskLazyFrame) -> list[Any]:
"`nw.col('a', 'b')`\n"
)
raise ValueError(msg)

if df._native_frame.npartitions > 1:
msg = "`Expr.over` is not supported for Dask backend with multiple partitions."
raise NotImplementedError(msg)

tmp = df.group_by(*keys).agg(self)
tmp_native = (
df.select(*keys)
Expand Down
11 changes: 8 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,14 @@ def polars_lazy_constructor(obj: Any) -> pl.LazyFrame:
return pl.LazyFrame(obj)


def dask_lazy_constructor(obj: Any) -> IntoFrame: # pragma: no cover
def dask_lazy_p1_constructor(obj: Any) -> IntoFrame: # pragma: no cover
dd = get_dask_dataframe()
return dd.from_pandas(pd.DataFrame(obj)) # type: ignore[no-any-return]
return dd.from_dict(obj, npartitions=1) # type: ignore[no-any-return]


def dask_lazy_p2_constructor(obj: Any) -> IntoFrame: # pragma: no cover
dd = get_dask_dataframe()
return dd.from_dict(obj, npartitions=2) # type: ignore[no-any-return]


def pyarrow_table_constructor(obj: Any) -> IntoDataFrame:
Expand All @@ -98,7 +103,7 @@ def pyarrow_table_constructor(obj: Any) -> IntoDataFrame:
if get_cudf() is not None:
eager_constructors.append(cudf_constructor) # pragma: no cover
if get_dask_dataframe() is not None: # pragma: no cover
lazy_constructors.append(dask_lazy_constructor) # type: ignore # noqa: PGH003
lazy_constructors.extend([dask_lazy_p1_constructor, dask_lazy_p2_constructor]) # type: ignore # noqa: PGH003


@pytest.fixture(params=eager_constructors)
Expand Down
31 changes: 27 additions & 4 deletions tests/expr_and_series/over_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from contextlib import nullcontext as does_not_raise
from typing import Any

import pytest
Expand All @@ -14,26 +15,48 @@

def test_over_single(constructor: Any) -> None:
df = nw.from_native(constructor(data))
result = df.with_columns(c_max=nw.col("c").max().over("a"))
expected = {
"a": ["a", "a", "b", "b", "b"],
"b": [1, 2, 3, 5, 3],
"c": [5, 4, 3, 2, 1],
"c_max": [5, 5, 3, 3, 3],
}
compare_dicts(result, expected)

context = (
pytest.raises(
NotImplementedError,
match="`Expr.over` is not supported for Dask backend with multiple partitions.",
)
if "dask_lazy_p2" in str(constructor)
else does_not_raise()
)

with context:
result = df.with_columns(c_max=nw.col("c").max().over("a"))
compare_dicts(result, expected)


def test_over_multiple(constructor: Any) -> None:
df = nw.from_native(constructor(data))
result = df.with_columns(c_min=nw.col("c").min().over("a", "b"))
expected = {
"a": ["a", "a", "b", "b", "b"],
"b": [1, 2, 3, 5, 3],
"c": [5, 4, 3, 2, 1],
"c_min": [5, 4, 1, 2, 1],
}
compare_dicts(result, expected)

context = (
pytest.raises(
NotImplementedError,
match="`Expr.over` is not supported for Dask backend with multiple partitions.",
)
if "dask_lazy_p2" in str(constructor)
else does_not_raise()
)

with context:
result = df.with_columns(c_min=nw.col("c").min().over("a", "b"))
compare_dicts(result, expected)


def test_over_invalid(request: Any, constructor: Any) -> None:
Expand Down
17 changes: 15 additions & 2 deletions tests/expr_and_series/quantile_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from contextlib import nullcontext as does_not_raise
from typing import Any
from typing import Literal

Expand Down Expand Up @@ -28,12 +29,24 @@ def test_quantile_expr(
) -> None:
if "dask" in str(constructor) and interpolation != "linear":
request.applymarker(pytest.mark.xfail)

q = 0.3
data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}
df_raw = constructor(data)
df = nw.from_native(df_raw)
result = df.select(nw.all().quantile(quantile=q, interpolation=interpolation))
compare_dicts(result, expected)

context = (
pytest.raises(
NotImplementedError,
match="`Expr.quantile` is not supported for Dask backend with multiple partitions.",
)
if "dask_lazy_p2" in str(constructor)
else does_not_raise()
)

with context:
result = df.select(nw.all().quantile(quantile=q, interpolation=interpolation))
compare_dicts(result, expected)


@pytest.mark.parametrize(
Expand Down
3 changes: 2 additions & 1 deletion tests/frame/filter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ def test_filter_with_boolean_list(constructor: Any) -> None:

context = (
pytest.raises(
NotImplementedError, match="Filtering with boolean mask is not supported"
NotImplementedError,
match="`LazyFrame.filter` is not supported for Dask backend with boolean masks.",
)
if "dask" in str(constructor)
else does_not_raise()
Expand Down
29 changes: 21 additions & 8 deletions tests/frame/tail_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

from contextlib import nullcontext as does_not_raise
from typing import Any

import pytest

import narwhals.stable.v1 as nw
from tests.utils import compare_dicts

Expand All @@ -13,14 +16,24 @@ def test_tail(constructor: Any) -> None:
df_raw = constructor(data)
df = nw.from_native(df_raw).lazy()

result = df.tail(2)
compare_dicts(result, expected)
context = (
pytest.raises(
NotImplementedError,
match="`LazyFrame.tail` is not supported for Dask backend with multiple partitions.",
)
if "dask_lazy_p2" in str(constructor)
else does_not_raise()
)

with context:
result = df.tail(2)
compare_dicts(result, expected)

result = df.collect().tail(2) # type: ignore[assignment]
compare_dicts(result, expected)
result = df.collect().tail(2) # type: ignore[assignment]
compare_dicts(result, expected)

result = df.collect().tail(-1) # type: ignore[assignment]
compare_dicts(result, expected)
result = df.collect().tail(-1) # type: ignore[assignment]
compare_dicts(result, expected)

result = df.collect().select(nw.col("a").tail(2)) # type: ignore[assignment]
compare_dicts(result, {"a": expected["a"]})
result = df.collect().select(nw.col("a").tail(2)) # type: ignore[assignment]
compare_dicts(result, {"a": expected["a"]})

0 comments on commit d49af40

Please sign in to comment.