Skip to content

Commit

Permalink
speed up pandas groupby
Browse files (browse the repository at this point in the history)
  • Loading branch information
MarcoGorelli committed Mar 18, 2024
1 parent 3065018 commit 7844555
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 46 deletions.
42 changes: 33 additions & 9 deletions narwhals/pandas_like/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import collections
import os
import warnings
from copy import copy
from typing import TYPE_CHECKING
from typing import Any
Expand Down Expand Up @@ -101,12 +102,21 @@ def agg_pandas(
simple_aggs = []
complex_aggs = []
for expr in exprs:
if is_simple_aggregation(expr):
if is_simple_aggregation(expr, implementation="pandas"):
simple_aggs.append(expr)
else:
complex_aggs.append(expr)
simple_aggregations = {}
for expr in simple_aggs:
if expr._depth == 0:
# e.g. agg(pl.len())
assert expr._output_names is not None
for output_name in expr._output_names:
simple_aggregations[output_name] = pd.NamedAgg(
column=keys[0], aggfunc=expr._function_name.replace("len", "size")
)
continue

assert expr._root_names is not None
assert expr._output_names is not None
for root_name, output_name in zip(expr._root_names, expr._output_names):
Expand All @@ -124,17 +134,31 @@ def func(df: Any) -> Any:
out_names.append(result_keys.name)
return pd.Series(out_group, index=out_names)

if parse(pd.__version__) < parse("2.2.0"):
result_complex = grouped.apply(func)
else:
result_complex = grouped.apply(func, include_groups=False)
if complex_aggs:
warnings.warn(
"Found complex group-by expression, which can't be expressed efficiently with the "
"pandas API. If you can, please rewrite your query such that group-by aggregations "
"are simple (e.g. mean, std, min, max, ...).",
UserWarning,
stacklevel=2,
)
if parse(pd.__version__) < parse("2.2.0"):
result_complex = grouped.apply(func)
else:
result_complex = grouped.apply(func, include_groups=False)

if result_simple is not None:
if result_simple is not None and not complex_aggs:
result = result_simple
elif result_simple is not None and complex_aggs:
result = pd.concat(
[result_simple, result_complex.drop(columns=keys)], axis=1, copy=False
[result_simple, result_complex.drop(columns=keys)],
axis=1,
copy=False,
)
else:
elif complex_aggs:
result = result_complex
else:
raise AssertionError("At least one aggregation should have been passed")
return from_dataframe(result.loc[:, output_names])


Expand All @@ -149,7 +173,7 @@ def agg_generic( # noqa: PLR0913
dfs: list[Any] = []
to_remove: list[int] = []
for i, expr in enumerate(exprs):
if is_simple_aggregation(expr):
if is_simple_aggregation(expr, implementation):
dfs.append(evaluate_simple_aggregation(expr, grouped))
to_remove.append(i)
exprs = [expr for i, expr in enumerate(exprs) if i not in to_remove]
Expand Down
12 changes: 10 additions & 2 deletions narwhals/pandas_like/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import os
from copy import copy
from typing import TYPE_CHECKING
from typing import Any
Expand Down Expand Up @@ -195,13 +196,20 @@ def item(s: Any) -> Any:
return s.iloc[0]


def is_simple_aggregation(expr: PandasExpr) -> bool:
def is_simple_aggregation(expr: PandasExpr, implementation: str) -> bool:
return (
expr._function_name is not None
and expr._depth is not None
and expr._depth < 2
# todo: avoid this one?
and expr._root_names is not None
and (
expr._root_names is not None
or (
expr._depth == 0
and implementation == "pandas"
and not os.environ.get("NARWHALS_FORCE_GENERIC")
)
)
)


Expand Down
50 changes: 25 additions & 25 deletions tests/tpch_q1_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,24 +24,24 @@ def test_q1(df_raw: Any) -> None:
df = nw.LazyFrame(df_raw)
query_result = (
df.filter(nw.col("l_shipdate") <= var_1)
.with_columns(
disc_price=nw.col("l_extendedprice") * (1 - nw.col("l_discount")),
charge=(
nw.col("l_extendedprice")
* (1.0 - nw.col("l_discount"))
* (1.0 + nw.col("l_tax"))
),
)
.group_by(["l_returnflag", "l_linestatus"])
.agg(
[
nw.col("l_quantity").sum().alias("sum_qty"),
nw.col("l_extendedprice").sum().alias("sum_base_price"),
(nw.col("l_extendedprice") * (1 - nw.col("l_discount")))
.sum()
.alias("sum_disc_price"),
(
nw.col("l_extendedprice")
* (1.0 - nw.col("l_discount"))
* (1.0 + nw.col("l_tax"))
)
.sum()
.alias("sum_charge"),
nw.col("l_quantity").mean().alias("avg_qty"),
nw.col("l_extendedprice").mean().alias("avg_price"),
nw.col("l_discount").mean().alias("avg_disc"),
nw.sum("l_quantity").alias("sum_qty"),
nw.sum("l_extendedprice").alias("sum_base_price"),
nw.sum("disc_price").alias("sum_disc_price"),
nw.col("charge").sum().alias("sum_charge"),
nw.mean("l_quantity").alias("avg_qty"),
nw.mean("l_extendedprice").alias("avg_price"),
nw.mean("l_discount").alias("avg_disc"),
nw.len().alias("count_order"),
],
)
Expand Down Expand Up @@ -85,21 +85,21 @@ def test_q1_w_pandas_agg_generic_path(df_raw: Any) -> None:
df = nw.LazyFrame(df_raw)
query_result = (
df.filter(nw.col("l_shipdate") <= var_1)
.with_columns(
disc_price=nw.col("l_extendedprice") * (1 - nw.col("l_discount")),
charge=(
nw.col("l_extendedprice")
* (1.0 - nw.col("l_discount"))
* (1.0 + nw.col("l_tax"))
),
)
.group_by(["l_returnflag", "l_linestatus"])
.agg(
[
nw.sum("l_quantity").alias("sum_qty"),
nw.sum("l_extendedprice").alias("sum_base_price"),
(nw.col("l_extendedprice") * (1 - nw.col("l_discount")))
.sum()
.alias("sum_disc_price"),
(
nw.col("l_extendedprice")
* (1.0 - nw.col("l_discount"))
* (1.0 + nw.col("l_tax"))
)
.sum()
.alias("sum_charge"),
nw.sum("disc_price").alias("sum_disc_price"),
nw.col("charge").sum().alias("sum_charge"),
nw.mean("l_quantity").alias("avg_qty"),
nw.mean("l_extendedprice").alias("avg_price"),
nw.mean("l_discount").alias("avg_disc"),
Expand Down
20 changes: 10 additions & 10 deletions tpch/q1.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,21 @@ def q1(df_raw: Any) -> Any:
df = nw.LazyFrame(df_raw)
result = (
df.filter(nw.col("l_shipdate") <= var_1)
.with_columns(
disc_price=nw.col("l_extendedprice") * (1 - nw.col("l_discount")),
charge=(
nw.col("l_extendedprice")
* (1.0 - nw.col("l_discount"))
* (1.0 + nw.col("l_tax"))
),
)
.group_by(["l_returnflag", "l_linestatus"])
.agg(
[
nw.sum("l_quantity").alias("sum_qty"),
nw.sum("l_extendedprice").alias("sum_base_price"),
(nw.col("l_extendedprice") * (1 - nw.col("l_discount")))
.sum()
.alias("sum_disc_price"),
(
nw.col("l_extendedprice")
* (1.0 - nw.col("l_discount"))
* (1.0 + nw.col("l_tax"))
)
.sum()
.alias("sum_charge"),
nw.sum("disc_price").alias("sum_disc_price"),
nw.col("charge").sum().alias("sum_charge"),
nw.mean("l_quantity").alias("avg_qty"),
nw.mean("l_extendedprice").alias("avg_price"),
nw.mean("l_discount").alias("avg_disc"),
Expand Down

0 comments on commit 7844555

Please sign in to comment.