Skip to content

Commit

Permalink
resolve dask case
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Sep 6, 2024
1 parent 3774fe0 commit 250f949
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 9 deletions.
12 changes: 9 additions & 3 deletions narwhals/_dask/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

if TYPE_CHECKING:
import dask.dataframe as dd
import pandas as pd

from narwhals._dask.dataframe import DaskLazyFrame
from narwhals._dask.expr import DaskExpr
Expand All @@ -20,11 +21,16 @@
def n_unique() -> dd.Aggregation:
    """Build the custom dask aggregation used for a grouped ``n_unique``.

    Returns:
        A ``dd.Aggregation`` named ``"nunique"``. Its ``chunk`` step counts
        the distinct values of each group (missing values included, because
        of ``dropna=False``) and its ``agg`` step sums those counts.
    """
    import dask.dataframe as dd  # ignore-banned-import

    def chunk(s: pd.core.groupby.generic.SeriesGroupBy) -> int:
        # dropna=False: a missing value is counted as a distinct value
        # instead of being ignored.
        # NOTE(review): pandas documents SeriesGroupBy.nunique as returning
        # a Series, so the ``int`` annotation looks too narrow - confirm.
        return s.nunique(dropna=False)  # type: ignore[no-any-return]

    def agg(s0: pd.core.groupby.generic.SeriesGroupBy) -> int:
        # NOTE(review): this adds up the counts produced by ``chunk``. If the
        # same value of a group can reach more than one chunk (e.g. several
        # partitions), it would presumably be counted once per chunk -
        # confirm against multi-partition input.
        return s0.sum()  # type: ignore[no-any-return]

    # No ``finalize`` step is passed: the summed counts are already the
    # result, so there is nothing left to post-process.
    return dd.Aggregation(
        name="nunique",
        chunk=chunk,
        agg=agg,
    )


Expand Down
7 changes: 1 addition & 6 deletions tests/test_group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,7 @@ def test_group_by_n_unique(constructor: Any) -> None:
compare_dicts(result, expected)


def test_group_by_n_unique_w_missing(
constructor: Any, request: pytest.FixtureRequest
) -> None:
if "dask" in str(constructor):
# temporary: let's fix this before merging
request.applymarker(pytest.mark.xfail)
def test_group_by_n_unique_w_missing(constructor: Any) -> None:
data = {"a": [1, 1, 2], "b": [4, None, 5], "c": [None, None, 7], "d": [1, 1, 3]}
result = (
nw.from_native(constructor(data))
Expand Down

0 comments on commit 250f949

Please sign in to comment.