From d49af40b26001417d79a3963e76793c6f3edf4b3 Mon Sep 17 00:00:00 2001 From: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com> Date: Wed, 11 Sep 2024 15:38:01 +0200 Subject: [PATCH] feat: Dask multiple partitions (#940) --- narwhals/_dask/dataframe.py | 13 +++++++++-- narwhals/_dask/expr.py | 14 +++++++++++- tests/conftest.py | 11 ++++++--- tests/expr_and_series/over_test.py | 31 ++++++++++++++++++++++---- tests/expr_and_series/quantile_test.py | 17 ++++++++++++-- tests/frame/filter_test.py | 3 ++- tests/frame/tail_test.py | 29 +++++++++++++++++------- 7 files changed, 97 insertions(+), 21 deletions(-) diff --git a/narwhals/_dask/dataframe.py b/narwhals/_dask/dataframe.py index d4433fb39..180a897bd 100644 --- a/narwhals/_dask/dataframe.py +++ b/narwhals/_dask/dataframe.py @@ -79,7 +79,9 @@ def filter( and isinstance(predicates[0], list) and all(isinstance(x, bool) for x in predicates[0]) ): - msg = "Filtering with boolean mask is not supported for `DaskLazyFrame`" + msg = ( + "`LazyFrame.filter` is not supported for Dask backend with boolean masks." + ) raise NotImplementedError(msg) from narwhals._dask.namespace import DaskNamespace @@ -329,7 +331,14 @@ def group_by(self, *by: str) -> DaskLazyGroupBy: return DaskLazyGroupBy(self, list(by)) def tail(self: Self, n: int) -> Self: - return self._from_native_frame(self._native_frame.tail(n=n, compute=False)) + native_frame = self._native_frame + n_partitions = native_frame.npartitions + + if n_partitions == 1: + return self._from_native_frame(self._native_frame.tail(n=n, compute=False)) + else: + msg = "`LazyFrame.tail` is not supported for Dask backend with multiple partitions." + raise NotImplementedError(msg) def gather_every(self: Self, n: int, offset: int) -> Self: row_index_token = generate_unique_token(n_bytes=8, columns=self.columns) diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index b20552b44..f08af590c 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -524,8 +524,15 @@ def quantile( interpolation: Literal["nearest", "higher", "lower", "midpoint", "linear"], ) -> Self: if interpolation == "linear": + + def func(_input: dask_expr.Series, _quantile: float) -> dask_expr.Series: + if _input.npartitions > 1: + msg = "`Expr.quantile` is not supported for Dask backend with multiple partitions." + raise NotImplementedError(msg) + return _input.quantile(q=_quantile, method="dask") + return self._from_call( - lambda _input, quantile: _input.quantile(q=quantile, method="dask"), + func, "quantile", quantile, returns_scalar=True, @@ -626,6 +633,11 @@ def func(df: DaskLazyFrame) -> list[Any]: "`nw.col('a', 'b')`\n" ) raise ValueError(msg) + + if df._native_frame.npartitions > 1: + msg = "`Expr.over` is not supported for Dask backend with multiple partitions." + raise NotImplementedError(msg) + tmp = df.group_by(*keys).agg(self) tmp_native = ( df.select(*keys) diff --git a/tests/conftest.py b/tests/conftest.py index cdf4e0be6..011b83265 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -72,9 +72,14 @@ def polars_lazy_constructor(obj: Any) -> pl.LazyFrame: return pl.LazyFrame(obj) -def dask_lazy_constructor(obj: Any) -> IntoFrame: # pragma: no cover +def dask_lazy_p1_constructor(obj: Any) -> IntoFrame: # pragma: no cover dd = get_dask_dataframe() - return dd.from_pandas(pd.DataFrame(obj)) # type: ignore[no-any-return] + return dd.from_dict(obj, npartitions=1) # type: ignore[no-any-return] + + +def dask_lazy_p2_constructor(obj: Any) -> IntoFrame: # pragma: no cover + dd = get_dask_dataframe() + return dd.from_dict(obj, npartitions=2) # type: ignore[no-any-return] def pyarrow_table_constructor(obj: Any) -> IntoDataFrame: @@ -98,7 +103,7 @@ def pyarrow_table_constructor(obj: Any) -> IntoDataFrame: if get_cudf() is not None: eager_constructors.append(cudf_constructor) # pragma: no cover if get_dask_dataframe() is not None: # pragma: no cover - lazy_constructors.append(dask_lazy_constructor) # type: ignore # noqa: PGH003 + lazy_constructors.extend([dask_lazy_p1_constructor, dask_lazy_p2_constructor]) # type: ignore # noqa: PGH003 @pytest.fixture(params=eager_constructors) diff --git a/tests/expr_and_series/over_test.py b/tests/expr_and_series/over_test.py index fb01a3cfd..17b07cc1e 100644 --- a/tests/expr_and_series/over_test.py +++ b/tests/expr_and_series/over_test.py @@ -1,3 +1,4 @@ +from contextlib import nullcontext as does_not_raise from typing import Any import pytest @@ -14,26 +15,48 @@ def test_over_single(constructor: Any) -> None: df = nw.from_native(constructor(data)) - result = df.with_columns(c_max=nw.col("c").max().over("a")) expected = { "a": ["a", "a", "b", "b", "b"], "b": [1, 2, 3, 5, 3], "c": [5, 4, 3, 2, 1], "c_max": [5, 5, 3, 3, 3], } - compare_dicts(result, expected) + + context = ( + pytest.raises( + NotImplementedError, + match="`Expr.over` is not supported for Dask backend with multiple partitions.", + ) + if "dask_lazy_p2" in str(constructor) + else does_not_raise() + ) + + with context: + result = df.with_columns(c_max=nw.col("c").max().over("a")) + compare_dicts(result, expected) def test_over_multiple(constructor: Any) -> None: df = nw.from_native(constructor(data)) - result = df.with_columns(c_min=nw.col("c").min().over("a", "b")) expected = { "a": ["a", "a", "b", "b", "b"], "b": [1, 2, 3, 5, 3], "c": [5, 4, 3, 2, 1], "c_min": [5, 4, 1, 2, 1], } - compare_dicts(result, expected) + + context = ( + pytest.raises( + NotImplementedError, + match="`Expr.over` is not supported for Dask backend with multiple partitions.", + ) + if "dask_lazy_p2" in str(constructor) + else does_not_raise() + ) + + with context: + result = df.with_columns(c_min=nw.col("c").min().over("a", "b")) + compare_dicts(result, expected) def test_over_invalid(request: Any, constructor: Any) -> None: diff --git a/tests/expr_and_series/quantile_test.py b/tests/expr_and_series/quantile_test.py index d9064541f..5b8ff9334 100644 --- a/tests/expr_and_series/quantile_test.py +++ b/tests/expr_and_series/quantile_test.py @@ -1,5 +1,6 @@ from __future__ import annotations +from contextlib import nullcontext as does_not_raise from typing import Any from typing import Literal @@ -28,12 +29,24 @@ def test_quantile_expr( ) -> None: if "dask" in str(constructor) and interpolation != "linear": request.applymarker(pytest.mark.xfail) + q = 0.3 data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} df_raw = constructor(data) df = nw.from_native(df_raw) - result = df.select(nw.all().quantile(quantile=q, interpolation=interpolation)) - compare_dicts(result, expected) + + context = ( + pytest.raises( + NotImplementedError, + match="`Expr.quantile` is not supported for Dask backend with multiple partitions.", + ) + if "dask_lazy_p2" in str(constructor) + else does_not_raise() + ) + + with context: + result = df.select(nw.all().quantile(quantile=q, interpolation=interpolation)) + compare_dicts(result, expected) @pytest.mark.parametrize( diff --git a/tests/frame/filter_test.py b/tests/frame/filter_test.py index 609f8ef91..e7a289feb 100644 --- a/tests/frame/filter_test.py +++ b/tests/frame/filter_test.py @@ -21,7 +21,8 @@ def test_filter_with_boolean_list(constructor: Any) -> None: context = ( pytest.raises( - NotImplementedError, match="Filtering with boolean mask is not supported" + NotImplementedError, + match="`LazyFrame.filter` is not supported for Dask backend with boolean masks.", ) if "dask" in str(constructor) else does_not_raise() diff --git a/tests/frame/tail_test.py b/tests/frame/tail_test.py index e279caba9..b64d9fa6c 100644 --- a/tests/frame/tail_test.py +++ b/tests/frame/tail_test.py @@ -1,7 +1,10 @@ from __future__ import annotations +from contextlib import nullcontext as does_not_raise from typing import Any +import pytest + import narwhals.stable.v1 as nw from tests.utils import compare_dicts @@ -13,14 +16,24 @@ def test_tail(constructor: Any) -> None: df_raw = constructor(data) df = nw.from_native(df_raw).lazy() - result = df.tail(2) - compare_dicts(result, expected) + context = ( + pytest.raises( + NotImplementedError, + match="`LazyFrame.tail` is not supported for Dask backend with multiple partitions.", + ) + if "dask_lazy_p2" in str(constructor) + else does_not_raise() + ) + + with context: + result = df.tail(2) + compare_dicts(result, expected) - result = df.collect().tail(2) # type: ignore[assignment] - compare_dicts(result, expected) + result = df.collect().tail(2) # type: ignore[assignment] + compare_dicts(result, expected) - result = df.collect().tail(-1) # type: ignore[assignment] - compare_dicts(result, expected) + result = df.collect().tail(-1) # type: ignore[assignment] + compare_dicts(result, expected) - result = df.collect().select(nw.col("a").tail(2)) # type: ignore[assignment] - compare_dicts(result, {"a": expected["a"]}) + result = df.collect().select(nw.col("a").tail(2)) # type: ignore[assignment] + compare_dicts(result, {"a": expected["a"]})