From bb8b3f10690f6c21034cc02a3030676b110fa577 Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Thu, 11 Apr 2024 13:41:44 +0200 Subject: [PATCH] q1 --- queries/pandas/q1.py | 31 +++++++++++++------------------ queries/polars/q1.py | 8 ++++---- 2 files changed, 17 insertions(+), 22 deletions(-) diff --git a/queries/pandas/q1.py b/queries/pandas/q1.py index 6314542..f21a754 100644 --- a/queries/pandas/q1.py +++ b/queries/pandas/q1.py @@ -8,33 +8,28 @@ def q() -> None: - VAR1 = date(1998, 9, 2) - - lineitem = utils.get_line_item_ds + line_item_ds = utils.get_line_item_ds # first call one time to cache in case we don't include the IO times - lineitem() + line_item_ds() def query() -> pd.DataFrame: - nonlocal lineitem - lineitem = lineitem() + nonlocal line_item_ds + line_item_ds = line_item_ds() + + var1 = date(1998, 9, 2) - sel = lineitem.l_shipdate <= VAR1 - lineitem_filtered = lineitem[sel] + filt = line_item_ds[line_item_ds["l_shipdate"] <= var1] # This is lenient towards pandas as normally an optimizer should decide # that this could be computed before the groupby aggregation. # Other implementations don't enjoy this benefit. - lineitem_filtered["disc_price"] = lineitem_filtered.l_extendedprice * ( - 1 - lineitem_filtered.l_discount - ) - lineitem_filtered["charge"] = ( - lineitem_filtered.l_extendedprice - * (1 - lineitem_filtered.l_discount) - * (1 + lineitem_filtered.l_tax) + filt["disc_price"] = filt.l_extendedprice * (1.0 - filt.l_discount) + filt["charge"] = ( + filt.l_extendedprice * (1.0 - filt.l_discount) * (1.0 + filt.l_tax) ) - gb = lineitem_filtered.groupby(["l_returnflag", "l_linestatus"], as_index=False) - total = gb.agg( + gb = filt.groupby(["l_returnflag", "l_linestatus"], as_index=False) + agg = gb.agg( sum_qty=pd.NamedAgg(column="l_quantity", aggfunc="sum"), sum_base_price=pd.NamedAgg(column="l_extendedprice", aggfunc="sum"), sum_disc_price=pd.NamedAgg(column="disc_price", aggfunc="sum"), @@ -45,7 +40,7 @@ def query() -> pd.DataFrame: count_order=pd.NamedAgg(column="l_orderkey", aggfunc="size"), ) - result_df = total.sort_values(["l_returnflag", "l_linestatus"]) + result_df = agg.sort_values(["l_returnflag", "l_linestatus"]) return result_df # type: ignore[no-any-return] diff --git a/queries/polars/q1.py b/queries/polars/q1.py index b36205d..62755cf 100644 --- a/queries/polars/q1.py +++ b/queries/polars/q1.py @@ -8,17 +8,17 @@ def q() -> None: - var_1 = date(1998, 9, 2) + line_item_ds = utils.get_line_item_ds() - q = utils.get_line_item_ds() + var1 = date(1998, 9, 2) q_final = ( - q.filter(pl.col("l_shipdate") <= var_1) + line_item_ds.filter(pl.col("l_shipdate") <= var1) .group_by("l_returnflag", "l_linestatus") .agg( pl.sum("l_quantity").alias("sum_qty"), pl.sum("l_extendedprice").alias("sum_base_price"), - (pl.col("l_extendedprice") * (1 - pl.col("l_discount"))) + (pl.col("l_extendedprice") * (1.0 - pl.col("l_discount"))) .sum() .alias("sum_disc_price"), (