Skip to content

Commit

Permalink
try-cudf-fastpath
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Mar 22, 2024
1 parent 9ab3d4d commit 653d903
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 69 deletions.
124 changes: 66 additions & 58 deletions narwhals/pandas_like/group_by.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
from __future__ import annotations

import collections
import os
import warnings
from copy import copy
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Iterable

from narwhals.pandas_like.utils import dataframe_from_dict
from narwhals.pandas_like.utils import evaluate_simple_aggregation
from narwhals.pandas_like.utils import horizontal_concat
from narwhals.pandas_like.utils import is_simple_aggregation
from narwhals.pandas_like.utils import item
from narwhals.pandas_like.utils import parse_into_exprs
Expand Down Expand Up @@ -43,7 +39,7 @@ def agg(
grouped = df.groupby(
list(self._keys),
sort=False,
as_index=False,
as_index=True,
)
implementation: str = self._df._implementation
output_names: list[str] = copy(self._keys)
Expand All @@ -57,23 +53,13 @@ def agg(
raise ValueError(msg)
output_names.extend(expr._output_names)

if implementation in ("pandas", "modin") and not os.environ.get(
"NARWHALS_FORCE_GENERIC"
):
return agg_pandas(
grouped,
exprs,
self._keys,
output_names,
self._from_dataframe,
)
return agg_generic(
return agg_pandas(
grouped,
exprs,
self._keys,
output_names,
implementation,
self._from_dataframe,
implementation,
)

def _from_dataframe(self, df: PandasDataFrame) -> PandasDataFrame:
Expand All @@ -85,12 +71,13 @@ def _from_dataframe(self, df: PandasDataFrame) -> PandasDataFrame:
)


def agg_pandas(
def agg_pandas( # noqa: PLR0913,PLR0915
grouped: Any,
exprs: list[PandasExpr],
keys: list[str],
output_names: list[str],
from_dataframe: Callable[[Any], PandasDataFrame],
implementation: Any,
) -> PandasDataFrame:
"""
This should be the fastpath, but cuDF is too far behind to use it.
Expand All @@ -100,6 +87,8 @@ def agg_pandas(
"""
import pandas as pd

from narwhals.pandas_like.namespace import PandasNamespace

simple_aggs = []
complex_aggs = []
for expr in exprs:
Expand All @@ -113,17 +102,31 @@ def agg_pandas(
# e.g. agg(pl.len())
assert expr._output_names is not None
for output_name in expr._output_names:
simple_aggregations[output_name] = pd.NamedAgg(
column=keys[0], aggfunc=expr._function_name.replace("len", "size")
simple_aggregations[output_name] = (
keys[0],
expr._function_name.replace("len", "size"),
)
continue

assert expr._root_names is not None
assert expr._output_names is not None
for root_name, output_name in zip(expr._root_names, expr._output_names):
name = remove_prefix(expr._function_name, "col->")
simple_aggregations[output_name] = pd.NamedAgg(column=root_name, aggfunc=name)
result_simple = grouped.agg(**simple_aggregations) if simple_aggregations else None
simple_aggregations[output_name] = (root_name, name)

if simple_aggregations:
aggs = collections.defaultdict(list)
name_mapping = {}
for output_name, named_agg in simple_aggregations.items():
aggs[named_agg[0]].append(named_agg[1])
name_mapping[f"{named_agg[0]}_{named_agg[1]}"] = output_name
result_simple = grouped.agg(aggs)
result_simple.columns = [f"{a}_{b}" for a, b in result_simple.columns]
result_simple = result_simple.rename(columns=name_mapping).reset_index()
else:
result_simple = None

plx = PandasNamespace(implementation=implementation)

def func(df: Any) -> Any:
out_group = []
Expand All @@ -133,7 +136,7 @@ def func(df: Any) -> Any:
for result_keys in results_keys:
out_group.append(item(result_keys._series))
out_names.append(result_keys.name)
return pd.Series(out_group, index=out_names)
return plx.make_native_series(name="", data=out_group, index=out_names)

if complex_aggs:
warnings.warn(
Expand All @@ -143,53 +146,58 @@ def func(df: Any) -> Any:
UserWarning,
stacklevel=2,
)
if parse_version(pd.__version__) < parse_version("2.2.0"):
result_complex = grouped.apply(func)
if implementation == "pandas":
import pandas as pd

if parse_version(pd.__version__) < parse_version("2.2.0"):
result_complex = grouped.apply(func)
else:
result_complex = grouped.apply(func, include_groups=False)
else:
result_complex = grouped.apply(func, include_groups=False)
result_complex = grouped.apply(func)

if result_simple is not None and not complex_aggs:
result = result_simple
elif result_simple is not None and complex_aggs:
result = pd.concat(
[result_simple, result_complex.drop(columns=keys)],
[result_simple, result_complex.reset_index(drop=True)],
axis=1,
copy=False,
)
elif complex_aggs:
result = result_complex
result = result_complex.reset_index()
else:
raise AssertionError("At least one aggregation should have been passed")
return from_dataframe(result.loc[:, output_names])


def agg_generic( # noqa: PLR0913
grouped: Any,
exprs: list[PandasExpr],
group_by_keys: list[str],
output_names: list[str],
implementation: str,
from_dataframe: Callable[[Any], PandasDataFrame],
) -> PandasDataFrame:
dfs: list[Any] = []
to_remove: list[int] = []
for i, expr in enumerate(exprs):
if is_simple_aggregation(expr):
dfs.append(evaluate_simple_aggregation(expr, grouped, group_by_keys))
to_remove.append(i)
exprs = [expr for i, expr in enumerate(exprs) if i not in to_remove]

out: dict[str, list[Any]] = collections.defaultdict(list)
for keys, df_keys in grouped:
for key, name in zip(keys, group_by_keys):
out[name].append(key)
for expr in exprs:
results_keys = expr._call(from_dataframe(df_keys))
for result_keys in results_keys:
out[result_keys.name].append(result_keys.item())

results_keys = dataframe_from_dict(out, implementation=implementation)
results_df = horizontal_concat(
[results_keys, *dfs], implementation=implementation
).loc[:, output_names]
return from_dataframe(results_df)
# def agg_generic(
# grouped: Any,
# exprs: list[PandasExpr],
# group_by_keys: list[str],
# output_names: list[str],
# implementation: str,
# from_dataframe: Callable[[Any], PandasDataFrame],
# ) -> PandasDataFrame:
# dfs: list[Any] = []
# to_remove: list[int] = []
# for i, expr in enumerate(exprs):
# if is_simple_aggregation(expr):
# dfs.append(evaluate_simple_aggregation(expr, grouped, group_by_keys))
# to_remove.append(i)
# exprs = [expr for i, expr in enumerate(exprs) if i not in to_remove]

# out: dict[str, list[Any]] = collections.defaultdict(list)
# for keys, df_keys in grouped:
# for key, name in zip(keys, group_by_keys):
# out[name].append(key)
# for expr in exprs:
# results_keys = expr._call(from_dataframe(df_keys))
# for result_keys in results_keys:
# out[result_keys.name].append(result_keys.item())

# results_keys = dataframe_from_dict(out, implementation=implementation)
# results_df = horizontal_concat(
# [results_keys, *dfs], implementation=implementation
# ).loc[:, output_names]
# return from_dataframe(results_df)
16 changes: 10 additions & 6 deletions narwhals/pandas_like/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,19 @@ class PandasNamespace:
Boolean = dtypes.Boolean
String = dtypes.String

def Series(self, name: str, data: list[Any]) -> PandasSeries: # noqa: N802
from narwhals.pandas_like.series import PandasSeries

def make_native_series(self, name: str, data: list[Any], index: Any) -> Any:
if self._implementation == "pandas":
import pandas as pd

return PandasSeries(
pd.Series(name=name, data=data), implementation=self._implementation
)
return pd.Series(name=name, data=data, index=index)
if self._implementation == "modin":
import modin.pandas as mpd

return mpd.Series(name=name, data=data, index=index)
if self._implementation == "cudf":
import cudf

return cudf.Series(name=name, data=data, index=index)
raise NotImplementedError

# --- not in spec ---
Expand Down
68 changes: 68 additions & 0 deletions tests/tpch_q1_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,74 @@ def test_q1(library: str) -> None:
compare_dicts(result, expected)


@pytest.mark.parametrize(
"library",
["pandas", "polars"],
)
@pytest.mark.filterwarnings(
"ignore:.*Passing a BlockManager.*:DeprecationWarning",
"ignore:.*Complex.*:UserWarning",
)
def test_q1_w_generic_funcs(library: str) -> None:
if library == "pandas":
df_raw = pd.read_parquet("tests/data/lineitem.parquet")
df_raw["l_shipdate"] = pd.to_datetime(df_raw["l_shipdate"])
elif library == "polars":
df_raw = pl.scan_parquet("tests/data/lineitem.parquet")
var_1 = datetime(1998, 9, 2)
df = nw.LazyFrame(df_raw)
query_result = (
df.filter(nw.col("l_shipdate") <= var_1)
.with_columns(
charge=(
nw.col("l_extendedprice")
* (1.0 - nw.col("l_discount"))
* (1.0 + nw.col("l_tax"))
),
)
.group_by(["l_returnflag", "l_linestatus"])
.agg(
[
nw.sum("l_quantity").alias("sum_qty"),
nw.sum("l_extendedprice").alias("sum_base_price"),
(nw.col("l_extendedprice") * (1 - nw.col("l_discount")))
.sum()
.alias("sum_disc_price"),
nw.col("charge").sum().alias("sum_charge"),
nw.mean("l_quantity").alias("avg_qty"),
nw.mean("l_extendedprice").alias("avg_price"),
nw.mean("l_discount").alias("avg_disc"),
nw.len().alias("count_order"),
],
)
.sort(["l_returnflag", "l_linestatus"])
)
result = query_result.collect().to_dict(as_series=False)
expected = {
"l_returnflag": ["A", "N", "N", "R"],
"l_linestatus": ["F", "F", "O", "F"],
"sum_qty": [2109.0, 29.0, 3682.0, 1876.0],
"sum_base_price": [3114026.44, 39824.83, 5517101.99, 2947892.16],
"sum_disc_price": [2954950.8082, 39028.3334, 5205468.4852, 2816542.4816999994],
"sum_charge": [
3092840.4194289995,
39808.900068,
5406966.873419,
2935797.8313019997,
],
"avg_qty": [27.75, 29.0, 25.047619047619047, 26.422535211267604],
"avg_price": [
40974.032105263155,
39824.83,
37531.30605442177,
41519.607887323946,
],
"avg_disc": [0.05039473684210526, 0.02, 0.05537414965986395, 0.04507042253521127],
"count_order": [76, 1, 147, 71],
}
compare_dicts(result, expected)


@mock.patch.dict(os.environ, {"NARWHALS_FORCE_GENERIC": "1"})
@pytest.mark.filterwarnings("ignore:.*Passing a BlockManager.*:DeprecationWarning")
def test_q1_w_pandas_agg_generic_path() -> None:
Expand Down
10 changes: 5 additions & 5 deletions tpch/q3.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ def q3(
var_1 = var_2 = datetime(1995, 3, 15)
var_3 = "BUILDING"

customer_ds = nw.LazyFrame(customer_ds_raw)
line_item_ds = nw.LazyFrame(line_item_ds_raw)
orders_ds = nw.LazyFrame(orders_ds_raw)
customer_ds = nw.from_native(customer_ds_raw)
line_item_ds = nw.from_native(line_item_ds_raw)
orders_ds = nw.from_native(orders_ds_raw)

q_final = (
customer_ds.filter(nw.col("c_mktsegment") == var_3)
Expand All @@ -48,7 +48,7 @@ def q3(
.head(10)
)

return nw.to_native(q_final.collect())
return nw.to_native(q_final)


customer_ds = polars.scan_parquet("../tpch-data/s1/customer.parquet")
Expand All @@ -66,5 +66,5 @@ def q3(
customer_ds,
lineitem_ds,
orders_ds,
)
).collect()
)

0 comments on commit 653d903

Please sign in to comment.