Skip to content

Commit

Permalink
q1
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed Apr 11, 2024
1 parent c553a99 commit 0bb7759
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 22 deletions.
31 changes: 13 additions & 18 deletions queries/pandas/q1.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,33 +8,28 @@


def q() -> None:
VAR1 = date(1998, 9, 2)

lineitem = utils.get_line_item_ds
line_item_ds = utils.get_line_item_ds
# first call one time to cache in case we don't include the IO times
lineitem()
line_item_ds()

def query() -> pd.DataFrame:
nonlocal lineitem
lineitem = lineitem()
nonlocal line_item_ds
line_item_ds = line_item_ds()

var1 = date(1998, 9, 2)

sel = lineitem.l_shipdate <= VAR1
lineitem_filtered = lineitem[sel]
filt = line_item_ds[line_item_ds["l_shipdate"] <= var1]

# This is lenient towards pandas as normally an optimizer should decide
# that this could be computed before the groupby aggregation.
# Other implementations don't enjoy this benefit.
lineitem_filtered["disc_price"] = lineitem_filtered.l_extendedprice * (
1 - lineitem_filtered.l_discount
)
lineitem_filtered["charge"] = (
lineitem_filtered.l_extendedprice
* (1 - lineitem_filtered.l_discount)
* (1 + lineitem_filtered.l_tax)
filt["disc_price"] = filt.l_extendedprice * (1.0 - filt.l_discount)
filt["charge"] = (
filt.l_extendedprice * (1.0 - filt.l_discount) * (1.0 + filt.l_tax)
)
gb = lineitem_filtered.groupby(["l_returnflag", "l_linestatus"], as_index=False)

total = gb.agg(
gb = filt.groupby(["l_returnflag", "l_linestatus"], as_index=False)
agg = gb.agg(
sum_qty=pd.NamedAgg(column="l_quantity", aggfunc="sum"),
sum_base_price=pd.NamedAgg(column="l_extendedprice", aggfunc="sum"),
sum_disc_price=pd.NamedAgg(column="disc_price", aggfunc="sum"),
Expand All @@ -45,7 +40,7 @@ def query() -> pd.DataFrame:
count_order=pd.NamedAgg(column="l_orderkey", aggfunc="size"),
)

result_df = total.sort_values(["l_returnflag", "l_linestatus"])
result_df = agg.sort_values(["l_returnflag", "l_linestatus"])

return result_df # type: ignore[no-any-return]

Expand Down
8 changes: 4 additions & 4 deletions queries/polars/q1.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,17 @@


def q() -> None:
var_1 = date(1998, 9, 2)
line_item_ds = utils.get_line_item_ds()

q = utils.get_line_item_ds()
var1 = date(1998, 9, 2)

q_final = (
q.filter(pl.col("l_shipdate") <= var_1)
line_item_ds.filter(pl.col("l_shipdate") <= var1)
.group_by("l_returnflag", "l_linestatus")
.agg(
pl.sum("l_quantity").alias("sum_qty"),
pl.sum("l_extendedprice").alias("sum_base_price"),
(pl.col("l_extendedprice") * (1 - pl.col("l_discount")))
(pl.col("l_extendedprice") * (1.0 - pl.col("l_discount")))
.sum()
.alias("sum_disc_price"),
(
Expand Down

0 comments on commit 0bb7759

Please sign in to comment.