From da25360d6f9289f9d0394d0140d38734ebea6e2c Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Tue, 10 Sep 2024 23:51:11 +0200 Subject: [PATCH] fix: group by no aggregation --- narwhals/_dask/group_by.py | 6 ++++++ narwhals/_pandas_like/group_by.py | 8 +++++--- tests/test_group_by.py | 7 +++++++ 3 files changed, 18 insertions(+), 3 deletions(-) diff --git a/narwhals/_dask/group_by.py b/narwhals/_dask/group_by.py index 463d6fc58..d5fbaaf94 100644 --- a/narwhals/_dask/group_by.py +++ b/narwhals/_dask/group_by.py @@ -72,6 +72,7 @@ def agg( output_names.extend(expr._output_names) return agg_dask( + self._df, self._grouped, exprs, self._keys, @@ -88,6 +89,7 @@ def _from_native_frame(self, df: DaskLazyFrame) -> DaskLazyFrame: def agg_dask( + df: DaskLazyFrame, grouped: Any, exprs: list[DaskExpr], keys: list[str], @@ -99,6 +101,10 @@ def agg_dask( - https://github.com/rapidsai/cudf/issues/15118 - https://github.com/rapidsai/cudf/issues/15084 """ + if not exprs: + # No aggregation provided + return df.select(*keys).unique(subset=keys) + all_simple_aggs = True for expr in exprs: if not is_simple_aggregation(expr): diff --git a/narwhals/_pandas_like/group_by.py b/narwhals/_pandas_like/group_by.py index 97a477dc4..892291d57 100644 --- a/narwhals/_pandas_like/group_by.py +++ b/narwhals/_pandas_like/group_by.py @@ -80,6 +80,7 @@ def agg( dataframe_is_empty=self._df._native_frame.empty, implementation=implementation, backend_version=self._df._backend_version, + native_namespace=self._df.__native_namespace__(), ) def _from_native_frame(self, df: PandasLikeDataFrame) -> PandasLikeDataFrame: @@ -114,6 +115,7 @@ def agg_pandas( # noqa: PLR0915 implementation: Any, backend_version: tuple[int, ...], dataframe_is_empty: bool, + native_namespace: Any, ) -> PandasLikeDataFrame: """ This should be the fastpath, but cuDF is too far behind to use it. @@ -204,9 +206,9 @@ def agg_pandas( # noqa: PLR0915 result_aggs = result_nunique_aggs elif simple_aggs and not nunique_aggs: result_aggs = result_simple_aggs - else: # pragma: no cover - msg = "Congrats, you entered unreachable code. Please report a bug to https://github.com/narwhals-dev/narwhals/issues." - raise RuntimeError(msg) + else: + # No aggregation provided + result_aggs = native_namespace.DataFrame(grouped.groups.keys(), columns=keys) return from_dataframe(result_aggs.loc[:, output_names]) if dataframe_is_empty: diff --git a/tests/test_group_by.py b/tests/test_group_by.py index 4bd3427a5..6f12d06b1 100644 --- a/tests/test_group_by.py +++ b/tests/test_group_by.py @@ -246,3 +246,10 @@ def test_key_with_nulls(constructor: Any, request: Any) -> None: ) expected = {"b": [4.0, 5, float("nan")], "len": [1, 1, 1], "a": [1, 2, 3]} compare_dicts(result, expected) + + +def test_no_agg(constructor: Any) -> None: + result = nw.from_native(constructor(data)).group_by(["a", "b"]).agg().sort("a", "b") + + expected = {"a": [1, 3], "b": [4, 6]} + compare_dicts(result, expected)