Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

try-cudf-fastpath #25

Merged
merged 2 commits into from
Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 36 additions & 60 deletions narwhals/pandas_like/group_by.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
from __future__ import annotations

import collections
import os
import warnings
from copy import copy
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Iterable

from narwhals.pandas_like.utils import dataframe_from_dict
from narwhals.pandas_like.utils import evaluate_simple_aggregation
from narwhals.pandas_like.utils import horizontal_concat
from narwhals.pandas_like.utils import is_simple_aggregation
from narwhals.pandas_like.utils import item
from narwhals.pandas_like.utils import parse_into_exprs
Expand Down Expand Up @@ -43,7 +39,7 @@ def agg(
grouped = df.groupby(
list(self._keys),
sort=False,
as_index=False,
as_index=True,
)
implementation: str = self._df._implementation
output_names: list[str] = copy(self._keys)
Expand All @@ -57,23 +53,13 @@ def agg(
raise ValueError(msg)
output_names.extend(expr._output_names)

if implementation in ("pandas", "modin") and not os.environ.get(
"NARWHALS_FORCE_GENERIC"
):
return agg_pandas(
grouped,
exprs,
self._keys,
output_names,
self._from_dataframe,
)
return agg_generic(
return agg_pandas(
grouped,
exprs,
self._keys,
output_names,
implementation,
self._from_dataframe,
implementation,
)

def _from_dataframe(self, df: PandasDataFrame) -> PandasDataFrame:
Expand All @@ -85,12 +71,13 @@ def _from_dataframe(self, df: PandasDataFrame) -> PandasDataFrame:
)


def agg_pandas(
def agg_pandas( # noqa: PLR0913,PLR0915
grouped: Any,
exprs: list[PandasExpr],
keys: list[str],
output_names: list[str],
from_dataframe: Callable[[Any], PandasDataFrame],
implementation: Any,
) -> PandasDataFrame:
"""
This should be the fastpath, but cuDF is too far behind to use it.
Expand All @@ -100,6 +87,8 @@ def agg_pandas(
"""
import pandas as pd

from narwhals.pandas_like.namespace import PandasNamespace

simple_aggs = []
complex_aggs = []
for expr in exprs:
Expand All @@ -113,17 +102,31 @@ def agg_pandas(
# e.g. agg(pl.len())
assert expr._output_names is not None
for output_name in expr._output_names:
simple_aggregations[output_name] = pd.NamedAgg(
column=keys[0], aggfunc=expr._function_name.replace("len", "size")
simple_aggregations[output_name] = (
keys[0],
expr._function_name.replace("len", "size"),
)
continue

assert expr._root_names is not None
assert expr._output_names is not None
for root_name, output_name in zip(expr._root_names, expr._output_names):
name = remove_prefix(expr._function_name, "col->")
simple_aggregations[output_name] = pd.NamedAgg(column=root_name, aggfunc=name)
result_simple = grouped.agg(**simple_aggregations) if simple_aggregations else None
simple_aggregations[output_name] = (root_name, name)

if simple_aggregations:
aggs = collections.defaultdict(list)
name_mapping = {}
for output_name, named_agg in simple_aggregations.items():
aggs[named_agg[0]].append(named_agg[1])
name_mapping[f"{named_agg[0]}_{named_agg[1]}"] = output_name
result_simple = grouped.agg(aggs)
result_simple.columns = [f"{a}_{b}" for a, b in result_simple.columns]
result_simple = result_simple.rename(columns=name_mapping).reset_index()
else:
result_simple = None

plx = PandasNamespace(implementation=implementation)

def func(df: Any) -> Any:
out_group = []
Expand All @@ -133,7 +136,7 @@ def func(df: Any) -> Any:
for result_keys in results_keys:
out_group.append(item(result_keys._series))
out_names.append(result_keys.name)
return pd.Series(out_group, index=out_names)
return plx.make_native_series(name="", data=out_group, index=out_names)

if complex_aggs:
warnings.warn(
Expand All @@ -143,53 +146,26 @@ def func(df: Any) -> Any:
UserWarning,
stacklevel=2,
)
if parse_version(pd.__version__) < parse_version("2.2.0"):
result_complex = grouped.apply(func)
if implementation == "pandas":
import pandas as pd

if parse_version(pd.__version__) < parse_version("2.2.0"):
result_complex = grouped.apply(func)
else:
result_complex = grouped.apply(func, include_groups=False)
else:
result_complex = grouped.apply(func, include_groups=False)
result_complex = grouped.apply(func)

if result_simple is not None and not complex_aggs:
result = result_simple
elif result_simple is not None and complex_aggs:
result = pd.concat(
[result_simple, result_complex.drop(columns=keys)],
[result_simple, result_complex.reset_index(drop=True)],
axis=1,
copy=False,
)
elif complex_aggs:
result = result_complex
result = result_complex.reset_index()
else:
raise AssertionError("At least one aggregation should have been passed")
return from_dataframe(result.loc[:, output_names])


def agg_generic(  # noqa: PLR0913
    grouped: Any,
    exprs: list[PandasExpr],
    group_by_keys: list[str],
    output_names: list[str],
    implementation: str,
    from_dataframe: Callable[[Any], PandasDataFrame],
) -> PandasDataFrame:
    """
    Aggregate ``grouped`` by iterating over its groups.

    Expressions accepted by ``is_simple_aggregation`` are evaluated once on
    the whole grouped object via ``evaluate_simple_aggregation``. Every other
    expression is evaluated separately for each group, and the single item of
    each resulting series is collected under that series' name.

    Arguments:
        grouped: Object that yields ``(keys, group_frame)`` pairs when iterated.
        exprs: Expressions to aggregate.
        group_by_keys: Names of the grouping columns, paired positionally with
            the keys yielded by ``grouped``.
        output_names: Columns (and their order) selected for the final result.
        implementation: Forwarded to ``dataframe_from_dict`` and
            ``horizontal_concat``.
        from_dataframe: Wraps a native dataframe into a ``PandasDataFrame``.

    Returns:
        The wrapped dataframe restricted to ``output_names``.
    """
    # Split the expressions: simple ones are computed right away, the rest
    # are kept (in their original order) for the per-group pass below.
    simple_frames: list[Any] = []
    per_group_exprs: list[PandasExpr] = []
    for expr in exprs:
        if is_simple_aggregation(expr):
            simple_frames.append(
                evaluate_simple_aggregation(expr, grouped, group_by_keys)
            )
        else:
            per_group_exprs.append(expr)

    # Column name -> one value per group, for both key and result columns.
    columns: dict[str, list[Any]] = collections.defaultdict(list)
    for group_key, group_frame in grouped:
        for key_value, key_name in zip(group_key, group_by_keys):
            columns[key_name].append(key_value)
        for expr in per_group_exprs:
            for series in expr._call(from_dataframe(group_frame)):
                columns[series.name].append(series.item())

    keys_frame = dataframe_from_dict(columns, implementation=implementation)
    combined = horizontal_concat(
        [keys_frame, *simple_frames], implementation=implementation
    )
    return from_dataframe(combined.loc[:, output_names])
16 changes: 10 additions & 6 deletions narwhals/pandas_like/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,19 @@ class PandasNamespace:
Boolean = dtypes.Boolean
String = dtypes.String

def Series(self, name: str, data: list[Any]) -> PandasSeries: # noqa: N802
from narwhals.pandas_like.series import PandasSeries

def make_native_series(self, name: str, data: list[Any], index: Any) -> Any:
if self._implementation == "pandas":
import pandas as pd

return PandasSeries(
pd.Series(name=name, data=data), implementation=self._implementation
)
return pd.Series(name=name, data=data, index=index)
if self._implementation == "modin":
import modin.pandas as mpd

return mpd.Series(name=name, data=data, index=index)
if self._implementation == "cudf":
import cudf

return cudf.Series(name=name, data=data, index=index)
raise NotImplementedError

# --- not in spec ---
Expand Down
68 changes: 68 additions & 0 deletions tests/tpch_q1_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,74 @@ def test_q1(library: str) -> None:
compare_dicts(result, expected)


@pytest.mark.parametrize(
    "library",
    ["pandas", "polars"],
)
@pytest.mark.filterwarnings(
    "ignore:.*Passing a BlockManager.*:DeprecationWarning",
    "ignore:.*Complex.*:UserWarning",
)
def test_q1_w_generic_funcs(library: str) -> None:
    """
    Run a TPC-H Q1-style query using the top-level ``nw.sum`` / ``nw.mean`` /
    ``nw.len`` helpers and compare the collected result with fixed values.
    """
    lineitem_path = "tests/data/lineitem.parquet"
    if library == "pandas":
        df_raw = pd.read_parquet(lineitem_path)
        # Convert the ship date to datetimes before it is filtered on.
        df_raw["l_shipdate"] = pd.to_datetime(df_raw["l_shipdate"])
    elif library == "polars":
        df_raw = pl.scan_parquet(lineitem_path)

    cutoff = datetime(1998, 9, 2)
    group_keys = ["l_returnflag", "l_linestatus"]

    charge_expr = (
        nw.col("l_extendedprice")
        * (1.0 - nw.col("l_discount"))
        * (1.0 + nw.col("l_tax"))
    )
    disc_price_expr = nw.col("l_extendedprice") * (1 - nw.col("l_discount"))
    aggregations = [
        nw.sum("l_quantity").alias("sum_qty"),
        nw.sum("l_extendedprice").alias("sum_base_price"),
        disc_price_expr.sum().alias("sum_disc_price"),
        nw.col("charge").sum().alias("sum_charge"),
        nw.mean("l_quantity").alias("avg_qty"),
        nw.mean("l_extendedprice").alias("avg_price"),
        nw.mean("l_discount").alias("avg_disc"),
        nw.len().alias("count_order"),
    ]

    frame = nw.LazyFrame(df_raw)
    filtered = frame.filter(nw.col("l_shipdate") <= cutoff)
    with_charge = filtered.with_columns(charge=charge_expr)
    aggregated = with_charge.group_by(group_keys).agg(aggregations)
    query_result = aggregated.sort(group_keys)

    result = query_result.collect().to_dict(as_series=False)
    expected = {
        "l_returnflag": ["A", "N", "N", "R"],
        "l_linestatus": ["F", "F", "O", "F"],
        "sum_qty": [2109.0, 29.0, 3682.0, 1876.0],
        "sum_base_price": [3114026.44, 39824.83, 5517101.99, 2947892.16],
        "sum_disc_price": [2954950.8082, 39028.3334, 5205468.4852, 2816542.4816999994],
        "sum_charge": [
            3092840.4194289995,
            39808.900068,
            5406966.873419,
            2935797.8313019997,
        ],
        "avg_qty": [27.75, 29.0, 25.047619047619047, 26.422535211267604],
        "avg_price": [
            40974.032105263155,
            39824.83,
            37531.30605442177,
            41519.607887323946,
        ],
        "avg_disc": [0.05039473684210526, 0.02, 0.05537414965986395, 0.04507042253521127],
        "count_order": [76, 1, 147, 71],
    }
    compare_dicts(result, expected)


@mock.patch.dict(os.environ, {"NARWHALS_FORCE_GENERIC": "1"})
@pytest.mark.filterwarnings("ignore:.*Passing a BlockManager.*:DeprecationWarning")
def test_q1_w_pandas_agg_generic_path() -> None:
Expand Down
10 changes: 5 additions & 5 deletions tpch/q3.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ def q3(
var_1 = var_2 = datetime(1995, 3, 15)
var_3 = "BUILDING"

customer_ds = nw.LazyFrame(customer_ds_raw)
line_item_ds = nw.LazyFrame(line_item_ds_raw)
orders_ds = nw.LazyFrame(orders_ds_raw)
customer_ds = nw.from_native(customer_ds_raw)
line_item_ds = nw.from_native(line_item_ds_raw)
orders_ds = nw.from_native(orders_ds_raw)

q_final = (
customer_ds.filter(nw.col("c_mktsegment") == var_3)
Expand All @@ -48,7 +48,7 @@ def q3(
.head(10)
)

return nw.to_native(q_final.collect())
return nw.to_native(q_final)


customer_ds = polars.scan_parquet("../tpch-data/s1/customer.parquet")
Expand All @@ -66,5 +66,5 @@ def q3(
customer_ds,
lineitem_ds,
orders_ds,
)
).collect()
)
Loading