From 250f9493b3f40fd2a388f68ed96db6efdfa5c775 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 6 Sep 2024 18:29:25 +0100 Subject: [PATCH] resolve dask case --- narwhals/_dask/group_by.py | 12 +++++++++--- tests/test_group_by.py | 7 +------ 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/narwhals/_dask/group_by.py b/narwhals/_dask/group_by.py index 7aeff6220..463d6fc58 100644 --- a/narwhals/_dask/group_by.py +++ b/narwhals/_dask/group_by.py @@ -11,6 +11,7 @@ if TYPE_CHECKING: import dask.dataframe as dd + import pandas as pd from narwhals._dask.dataframe import DaskLazyFrame from narwhals._dask.expr import DaskExpr @@ -20,11 +21,16 @@ def n_unique() -> dd.Aggregation: import dask.dataframe as dd # ignore-banned-import + def chunk(s: pd.core.groupby.generic.SeriesGroupBy) -> int: + return s.nunique(dropna=False) # type: ignore[no-any-return] + + def agg(s0: pd.core.groupby.generic.SeriesGroupBy) -> int: + return s0.sum() # type: ignore[no-any-return] + return dd.Aggregation( name="nunique", - chunk=lambda s: s.apply(lambda x: list(set(x))), - agg=lambda s0: s0.obj.groupby(level=list(range(s0.obj.index.nlevels))).sum(), - finalize=lambda s1: s1.apply(lambda final: len(set(final))), + chunk=chunk, + agg=agg, ) diff --git a/tests/test_group_by.py b/tests/test_group_by.py index da5f7e52c..5c69983f4 100644 --- a/tests/test_group_by.py +++ b/tests/test_group_by.py @@ -113,12 +113,7 @@ def test_group_by_n_unique(constructor: Any) -> None: compare_dicts(result, expected) -def test_group_by_n_unique_w_missing( - constructor: Any, request: pytest.FixtureRequest -) -> None: - if "dask" in str(constructor): - # temporary: let's fix this before merging - request.applymarker(pytest.mark.xfail) +def test_group_by_n_unique_w_missing(constructor: Any) -> None: data = {"a": [1, 1, 2], "b": [4, None, 5], "c": [None, None, 7], "d": [1, 1, 3]} result = ( nw.from_native(constructor(data))