diff --git a/narwhals/_arrow/group_by.py b/narwhals/_arrow/group_by.py index 5f0e4a4ca..1b102a297 100644 --- a/narwhals/_arrow/group_by.py +++ b/narwhals/_arrow/group_by.py @@ -120,16 +120,24 @@ def agg_arrow( function_name = remove_prefix(expr._function_name, "col->") function_name = POLARS_TO_ARROW_AGGREGATIONS.get(function_name, function_name) for root_name, output_name in zip(expr._root_names, expr._output_names): - if function_name != "len": + if function_name not in ("len", "count_distinct"): simple_aggregations[output_name] = ( (root_name, function_name), f"{root_name}_{function_name}", ) - else: + elif function_name == "len": simple_aggregations[output_name] = ( (root_name, "count", pc.CountOptions(mode="all")), f"{root_name}_count", ) + elif function_name == "count_distinct": + simple_aggregations[output_name] = ( + (root_name, "count_distinct", pc.CountOptions(mode="all")), + f"{root_name}_count_distinct", + ) + else: # pragma: no cover + msg = "unreachable code" + raise RuntimeError(msg) aggs: list[Any] = [] name_mapping = {} diff --git a/tests/test_group_by.py b/tests/test_group_by.py index 0134322c2..5c69983f4 100644 --- a/tests/test_group_by.py +++ b/tests/test_group_by.py @@ -121,7 +121,7 @@ def test_group_by_n_unique_w_missing(constructor: Any) -> None: .agg( nw.col("b").n_unique(), c_n_unique=nw.col("c").n_unique(), - c_n_unique_other=nw.col("c").n_unique(), + c_n_min=nw.col("b").min(), d_n_unique=nw.col("d").n_unique(), ) .sort("a") @@ -130,7 +130,7 @@ def test_group_by_n_unique_w_missing(constructor: Any) -> None: "a": [1, 2], "b": [2, 1], "c_n_unique": [1, 1], - "c_n_unique_other": [1, 1], + "c_n_min": [4, 5], "d_n_unique": [1, 1], } compare_dicts(result, expected)