Skip to content

Commit

Permalink
fixup pyarrow path
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Sep 6, 2024
1 parent d37d2a3 commit 35a6226
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
12 changes: 10 additions & 2 deletions narwhals/_arrow/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,16 +120,24 @@ def agg_arrow(
function_name = remove_prefix(expr._function_name, "col->")
function_name = POLARS_TO_ARROW_AGGREGATIONS.get(function_name, function_name)
for root_name, output_name in zip(expr._root_names, expr._output_names):
if function_name != "len":
if function_name not in ("len", "count_distinct"):
simple_aggregations[output_name] = (
(root_name, function_name),
f"{root_name}_{function_name}",
)
else:
elif function_name == "len":
simple_aggregations[output_name] = (
(root_name, "count", pc.CountOptions(mode="all")),
f"{root_name}_count",
)
elif function_name == "count_distinct":
simple_aggregations[output_name] = (
(root_name, "count_distinct", pc.CountOptions(mode="all")),
f"{root_name}_count_distinct",
)
else: # pragma: no cover
msg = "unreachable code"
raise RuntimeError(msg)

aggs: list[Any] = []
name_mapping = {}
Expand Down
4 changes: 2 additions & 2 deletions tests/test_group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def test_group_by_n_unique_w_missing(constructor: Any) -> None:
.agg(
nw.col("b").n_unique(),
c_n_unique=nw.col("c").n_unique(),
c_n_unique_other=nw.col("c").n_unique(),
c_n_min=nw.col("b").min(),
d_n_unique=nw.col("d").n_unique(),
)
.sort("a")
Expand All @@ -130,7 +130,7 @@ def test_group_by_n_unique_w_missing(constructor: Any) -> None:
"a": [1, 2],
"b": [2, 1],
"c_n_unique": [1, 1],
"c_n_unique_other": [1, 1],
"c_n_min": [4, 5],
"d_n_unique": [1, 1],
}
compare_dicts(result, expected)
Expand Down

0 comments on commit 35a6226

Please sign in to comment.