Skip to content

Commit

Permalink
have both a fastpath and a slowpath in groupby.agg
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Feb 22, 2024
1 parent 94bf242 commit 65f8457
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 2 deletions.
4 changes: 2 additions & 2 deletions narwhals/pandas_like/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def func(df: Any) -> Any:
def agg_generic( # noqa: PLR0913
grouped: Any,
exprs: list[Expr],
keys: list[str],
group_by_keys: list[str],
output_names: list[str],
implementation: str,
from_dataframe: Callable[[Any], LazyFrame],
Expand All @@ -174,7 +174,7 @@ def agg_generic( # noqa: PLR0913

out: dict[str, list[Any]] = collections.defaultdict(list)
for keys, df_keys in grouped:
for key, name in zip(keys, keys):
for key, name in zip(keys, group_by_keys):
out[name].append(key)
for expr in exprs:
results_keys = expr._call(from_dataframe(df_keys))
Expand Down
63 changes: 63 additions & 0 deletions tests/tpch_q1_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import os
from datetime import datetime
from typing import Any
from unittest import mock

import polars
import pytest
Expand Down Expand Up @@ -69,3 +71,64 @@ def test_q1(df_raw: Any) -> None:
"count_order": [76, 1, 147, 71],
}
compare_dicts(result, expected)


@pytest.mark.parametrize(
"df_raw",
[
(polars.read_parquet("tests/data/lineitem.parquet").to_pandas()),
],
)
@mock.patch.dict(os.environ, {"NARWHALS_FORCE_GENERIC": "1"})
def test_q1_w_pandas_agg_generic_path(df_raw: Any) -> None:
var_1 = datetime(1998, 9, 2)
df, pl = to_polars_api(df_raw, version="0.20")
query_result = (
df.filter(pl.col("l_shipdate") <= var_1)
.group_by(["l_returnflag", "l_linestatus"])
.agg(
[
pl.sum("l_quantity").alias("sum_qty"),
pl.sum("l_extendedprice").alias("sum_base_price"),
(pl.col("l_extendedprice") * (1 - pl.col("l_discount")))
.sum()
.alias("sum_disc_price"),
(
pl.col("l_extendedprice")
* (1.0 - pl.col("l_discount"))
* (1.0 + pl.col("l_tax"))
)
.sum()
.alias("sum_charge"),
pl.mean("l_quantity").alias("avg_qty"),
pl.mean("l_extendedprice").alias("avg_price"),
pl.mean("l_discount").alias("avg_disc"),
pl.len().alias("count_order"),
],
)
.sort(["l_returnflag", "l_linestatus"])
)
result = query_result.collect().to_dict(as_series=False)
expected = {
"l_returnflag": ["A", "N", "N", "R"],
"l_linestatus": ["F", "F", "O", "F"],
"sum_qty": [2109.0, 29.0, 3682.0, 1876.0],
"sum_base_price": [3114026.44, 39824.83, 5517101.99, 2947892.16],
"sum_disc_price": [2954950.8082, 39028.3334, 5205468.4852, 2816542.4816999994],
"sum_charge": [
3092840.4194289995,
39808.900068,
5406966.873419,
2935797.8313019997,
],
"avg_qty": [27.75, 29.0, 25.047619047619047, 26.422535211267604],
"avg_price": [
40974.032105263155,
39824.83,
37531.30605442177,
41519.607887323946,
],
"avg_disc": [0.05039473684210526, 0.02, 0.05537414965986395, 0.04507042253521127],
"count_order": [76, 1, 147, 71],
}
compare_dicts(result, expected)

0 comments on commit 65f8457

Please sign in to comment.