Skip to content

Commit

Permalink
patch: Improve q10, q11, q2, q3, q4, and q5 queries (#893)
Browse files: browse the repository at this point in the history
  • Loading branch information
luke396 authored Sep 1, 2024
1 parent a1c4289 commit 4a03572
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 55 deletions.
4 changes: 1 addition & 3 deletions tpch/queries/q10.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def query(
var1 = datetime(1993, 10, 1)
var2 = datetime(1994, 1, 1)

result = (
return (
customer_ds.join(orders_ds, left_on="c_custkey", right_on="o_custkey")
.join(lineitem_ds, left_on="o_orderkey", right_on="l_orderkey")
.join(nation_ds, left_on="c_nationkey", right_on="n_nationkey")
Expand Down Expand Up @@ -46,5 +46,3 @@ def query(
.sort(by="revenue", descending=True)
.head(20)
)

return nw.to_native(result)
19 changes: 4 additions & 15 deletions tpch/queries/q11.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,13 @@
from datetime import datetime

import narwhals as nw
from narwhals.typing import FrameT


@nw.narwhalify
def query(
nation_ds_raw: FrameT,
partsupp_ds_raw: FrameT,
supplier_ds_raw: FrameT,
nation_ds: FrameT,
partsupp_ds: FrameT,
supplier_ds: FrameT,
) -> FrameT:
var1 = datetime(1993, 10, 1)
var2 = datetime(1994, 1, 1)

nation_ds = nw.from_native(nation_ds_raw)
partsupp_ds = nw.from_native(partsupp_ds_raw)
supplier_ds = nw.from_native(supplier_ds_raw)

var1 = "GERMANY"
var2 = 0.0001

Expand All @@ -30,7 +21,7 @@ def query(
* var2
)

q_final = (
return (
q1.with_columns((nw.col("ps_supplycost") * nw.col("ps_availqty")).alias("value"))
.group_by("ps_partkey")
.agg(nw.sum("value"))
Expand All @@ -39,5 +30,3 @@ def query(
.select("ps_partkey", "value")
.sort("value", descending=True)
)

return nw.to_native(q_final)
15 changes: 7 additions & 8 deletions tpch/queries/q2.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
from typing import Any

import narwhals as nw
from narwhals.typing import FrameT


@nw.narwhalify
def query(
region_ds: Any,
nation_ds: Any,
supplier_ds: Any,
part_ds: Any,
part_supp_ds: Any,
) -> Any:
region_ds: FrameT,
nation_ds: FrameT,
supplier_ds: FrameT,
part_ds: FrameT,
part_supp_ds: FrameT,
) -> FrameT:
var_1 = 15
var_2 = "BRASS"
var_3 = "EUROPE"
Expand Down
10 changes: 5 additions & 5 deletions tpch/queries/q3.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from datetime import datetime
from typing import Any

import narwhals as nw
from narwhals.typing import FrameT


@nw.narwhalify
def query(
customer_ds: Any,
line_item_ds: Any,
orders_ds: Any,
) -> Any:
customer_ds: FrameT,
line_item_ds: FrameT,
orders_ds: FrameT,
) -> FrameT:
var_1 = var_2 = datetime(1995, 3, 15)
var_3 = "BUILDING"

Expand Down
11 changes: 3 additions & 8 deletions tpch/queries/q4.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,13 @@

@nw.narwhalify
def query(
line_item_ds_raw: FrameT,
orders_ds_raw: FrameT,
line_item_ds: FrameT,
orders_ds: FrameT,
) -> FrameT:
var_1 = datetime(1993, 7, 1)
var_2 = datetime(1993, 10, 1)

line_item_ds = nw.from_native(line_item_ds_raw)
orders_ds = nw.from_native(orders_ds_raw)

result = (
return (
line_item_ds.join(orders_ds, left_on="l_orderkey", right_on="o_orderkey")
.filter(
nw.col("o_orderdate").is_between(var_1, var_2, closed="left"),
Expand All @@ -27,5 +24,3 @@ def query(
.sort(by="o_orderpriority")
.with_columns(nw.col("order_count").cast(nw.Int64))
)

return nw.to_native(result)
23 changes: 7 additions & 16 deletions tpch/queries/q5.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,18 @@

@nw.narwhalify
def query(
region_ds_raw: FrameT,
nation_ds_raw: FrameT,
customer_ds_raw: FrameT,
line_item_ds_raw: FrameT,
orders_ds_raw: FrameT,
supplier_ds_raw: FrameT,
region_ds: FrameT,
nation_ds: FrameT,
customer_ds: FrameT,
line_item_ds: FrameT,
orders_ds: FrameT,
supplier_ds: FrameT,
) -> FrameT:
var_1 = "ASIA"
var_2 = datetime(1994, 1, 1)
var_3 = datetime(1995, 1, 1)

region_ds = nw.from_native(region_ds_raw)
nation_ds = nw.from_native(nation_ds_raw)
customer_ds = nw.from_native(customer_ds_raw)
line_item_ds = nw.from_native(line_item_ds_raw)
orders_ds = nw.from_native(orders_ds_raw)
supplier_ds = nw.from_native(supplier_ds_raw)

result = (
return (
region_ds.join(nation_ds, left_on="r_regionkey", right_on="n_regionkey")
.join(customer_ds, left_on="n_nationkey", right_on="c_nationkey")
.join(orders_ds, left_on="c_custkey", right_on="o_custkey")
Expand All @@ -45,5 +38,3 @@ def query(
.agg([nw.sum("revenue")])
.sort(by="revenue", descending=True)
)

return nw.to_native(result)

0 comments on commit 4a03572

Please sign in to comment.