Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

speed up pandas groupby #18

Merged
merged 1 commit into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 33 additions & 9 deletions narwhals/pandas_like/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import collections
import os
import warnings
from copy import copy
from typing import TYPE_CHECKING
from typing import Any
Expand Down Expand Up @@ -101,12 +102,21 @@ def agg_pandas(
simple_aggs = []
complex_aggs = []
for expr in exprs:
if is_simple_aggregation(expr):
if is_simple_aggregation(expr, implementation="pandas"):
simple_aggs.append(expr)
else:
complex_aggs.append(expr)
simple_aggregations = {}
for expr in simple_aggs:
if expr._depth == 0:
# e.g. agg(pl.len())
assert expr._output_names is not None
for output_name in expr._output_names:
simple_aggregations[output_name] = pd.NamedAgg(
column=keys[0], aggfunc=expr._function_name.replace("len", "size")
)
continue

assert expr._root_names is not None
assert expr._output_names is not None
for root_name, output_name in zip(expr._root_names, expr._output_names):
Expand All @@ -124,17 +134,31 @@ def func(df: Any) -> Any:
out_names.append(result_keys.name)
return pd.Series(out_group, index=out_names)

if parse(pd.__version__) < parse("2.2.0"):
result_complex = grouped.apply(func)
else:
result_complex = grouped.apply(func, include_groups=False)
if complex_aggs:
warnings.warn(
"Found complex group-by expression, which can't be expressed efficiently with the "
"pandas API. If you can, please rewrite your query such that group-by aggregations "
"are simple (e.g. mean, std, min, max, ...).",
UserWarning,
stacklevel=2,
)
if parse(pd.__version__) < parse("2.2.0"):
result_complex = grouped.apply(func)
else:
result_complex = grouped.apply(func, include_groups=False)

if result_simple is not None:
if result_simple is not None and not complex_aggs:
result = result_simple
elif result_simple is not None and complex_aggs:
result = pd.concat(
[result_simple, result_complex.drop(columns=keys)], axis=1, copy=False
[result_simple, result_complex.drop(columns=keys)],
axis=1,
copy=False,
)
else:
elif complex_aggs:
result = result_complex
else:
raise AssertionError("At least one aggregation should have been passed")
return from_dataframe(result.loc[:, output_names])


Expand All @@ -149,7 +173,7 @@ def agg_generic( # noqa: PLR0913
dfs: list[Any] = []
to_remove: list[int] = []
for i, expr in enumerate(exprs):
if is_simple_aggregation(expr):
if is_simple_aggregation(expr, implementation):
dfs.append(evaluate_simple_aggregation(expr, grouped))
to_remove.append(i)
exprs = [expr for i, expr in enumerate(exprs) if i not in to_remove]
Expand Down
12 changes: 10 additions & 2 deletions narwhals/pandas_like/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import os
from copy import copy
from typing import TYPE_CHECKING
from typing import Any
Expand Down Expand Up @@ -195,13 +196,20 @@ def item(s: Any) -> Any:
return s.iloc[0]


def is_simple_aggregation(expr: PandasExpr) -> bool:
def is_simple_aggregation(expr: PandasExpr, implementation: str) -> bool:
return (
expr._function_name is not None
and expr._depth is not None
and expr._depth < 2
# todo: avoid this one?
and expr._root_names is not None
and (
expr._root_names is not None
or (
expr._depth == 0
and implementation == "pandas"
and not os.environ.get("NARWHALS_FORCE_GENERIC")
)
)
)


Expand Down
50 changes: 25 additions & 25 deletions tests/tpch_q1_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,24 +24,24 @@ def test_q1(df_raw: Any) -> None:
df = nw.LazyFrame(df_raw)
query_result = (
df.filter(nw.col("l_shipdate") <= var_1)
.with_columns(
disc_price=nw.col("l_extendedprice") * (1 - nw.col("l_discount")),
charge=(
nw.col("l_extendedprice")
* (1.0 - nw.col("l_discount"))
* (1.0 + nw.col("l_tax"))
),
)
.group_by(["l_returnflag", "l_linestatus"])
.agg(
[
nw.col("l_quantity").sum().alias("sum_qty"),
nw.col("l_extendedprice").sum().alias("sum_base_price"),
(nw.col("l_extendedprice") * (1 - nw.col("l_discount")))
.sum()
.alias("sum_disc_price"),
(
nw.col("l_extendedprice")
* (1.0 - nw.col("l_discount"))
* (1.0 + nw.col("l_tax"))
)
.sum()
.alias("sum_charge"),
nw.col("l_quantity").mean().alias("avg_qty"),
nw.col("l_extendedprice").mean().alias("avg_price"),
nw.col("l_discount").mean().alias("avg_disc"),
nw.sum("l_quantity").alias("sum_qty"),
nw.sum("l_extendedprice").alias("sum_base_price"),
nw.sum("disc_price").alias("sum_disc_price"),
nw.col("charge").sum().alias("sum_charge"),
nw.mean("l_quantity").alias("avg_qty"),
nw.mean("l_extendedprice").alias("avg_price"),
nw.mean("l_discount").alias("avg_disc"),
nw.len().alias("count_order"),
],
)
Expand Down Expand Up @@ -85,21 +85,21 @@ def test_q1_w_pandas_agg_generic_path(df_raw: Any) -> None:
df = nw.LazyFrame(df_raw)
query_result = (
df.filter(nw.col("l_shipdate") <= var_1)
.with_columns(
disc_price=nw.col("l_extendedprice") * (1 - nw.col("l_discount")),
charge=(
nw.col("l_extendedprice")
* (1.0 - nw.col("l_discount"))
* (1.0 + nw.col("l_tax"))
),
)
.group_by(["l_returnflag", "l_linestatus"])
.agg(
[
nw.sum("l_quantity").alias("sum_qty"),
nw.sum("l_extendedprice").alias("sum_base_price"),
(nw.col("l_extendedprice") * (1 - nw.col("l_discount")))
.sum()
.alias("sum_disc_price"),
(
nw.col("l_extendedprice")
* (1.0 - nw.col("l_discount"))
* (1.0 + nw.col("l_tax"))
)
.sum()
.alias("sum_charge"),
nw.sum("disc_price").alias("sum_disc_price"),
nw.col("charge").sum().alias("sum_charge"),
nw.mean("l_quantity").alias("avg_qty"),
nw.mean("l_extendedprice").alias("avg_price"),
nw.mean("l_discount").alias("avg_disc"),
Expand Down
20 changes: 10 additions & 10 deletions tpch/q1.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,21 @@ def q1(df_raw: Any) -> Any:
df = nw.LazyFrame(df_raw)
result = (
df.filter(nw.col("l_shipdate") <= var_1)
.with_columns(
disc_price=nw.col("l_extendedprice") * (1 - nw.col("l_discount")),
charge=(
nw.col("l_extendedprice")
* (1.0 - nw.col("l_discount"))
* (1.0 + nw.col("l_tax"))
),
)
.group_by(["l_returnflag", "l_linestatus"])
.agg(
[
nw.sum("l_quantity").alias("sum_qty"),
nw.sum("l_extendedprice").alias("sum_base_price"),
(nw.col("l_extendedprice") * (1 - nw.col("l_discount")))
.sum()
.alias("sum_disc_price"),
(
nw.col("l_extendedprice")
* (1.0 - nw.col("l_discount"))
* (1.0 + nw.col("l_tax"))
)
.sum()
.alias("sum_charge"),
nw.sum("disc_price").alias("sum_disc_price"),
nw.col("charge").sum().alias("sum_charge"),
nw.mean("l_quantity").alias("avg_qty"),
nw.mean("l_extendedprice").alias("avg_price"),
nw.mean("l_discount").alias("avg_disc"),
Expand Down
Loading