Skip to content

Commit 4a6c4d1

Browse files
refactor: dataframe join params (#912)
* refactor: dataframe join params * chore: add description for on params * fix type * chore: change join param * chore: update join params in tpch * oops * chore: final change * Add support for join_keys as a positional argument --------- Co-authored-by: Tim Saucer <[email protected]>
1 parent cbe28cb commit 4a6c4d1

25 files changed

+240
-85
lines changed

docs/source/user-guide/common-operations/joins.rst

+5-5
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ will be included in the resulting DataFrame.
5656

5757
.. ipython:: python
5858
59-
left.join(right, join_keys=(["customer_id"], ["id"]), how="inner")
59+
left.join(right, left_on="customer_id", right_on="id", how="inner")
6060
6161
The parameters ``left_on`` and ``right_on`` specify the columns from the left DataFrame and right DataFrame that contain the values
6262
that should match.
@@ -70,7 +70,7 @@ values for the corresponding columns.
7070

7171
.. ipython:: python
7272
73-
left.join(right, join_keys=(["customer_id"], ["id"]), how="left")
73+
left.join(right, left_on="customer_id", right_on="id", how="left")
7474
7575
Full Join
7676
---------
@@ -80,7 +80,7 @@ is no match. Unmatched rows will have null values.
8080

8181
.. ipython:: python
8282
83-
left.join(right, join_keys=(["customer_id"], ["id"]), how="full")
83+
left.join(right, left_on="customer_id", right_on="id", how="full")
8484
8585
Left Semi Join
8686
--------------
@@ -90,7 +90,7 @@ omitting duplicates with multiple matches in the right table.
9090

9191
.. ipython:: python
9292
93-
left.join(right, join_keys=(["customer_id"], ["id"]), how="semi")
93+
left.join(right, left_on="customer_id", right_on="id", how="semi")
9494
9595
Left Anti Join
9696
--------------
@@ -101,4 +101,4 @@ the right table.
101101

102102
.. ipython:: python
103103
104-
left.join(right, join_keys=(["customer_id"], ["id"]), how="anti")
104+
left.join(right, left_on="customer_id", right_on="id", how="anti")

examples/tpch/_tests.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import pytest
1919
from importlib import import_module
2020
import pyarrow as pa
21-
from datafusion import col, lit, functions as F
21+
from datafusion import DataFrame, col, lit, functions as F
2222
from util import get_answer_file
2323

2424

@@ -94,7 +94,7 @@ def check_q17(df):
9494
)
9595
def test_tpch_query_vs_answer_file(query_code: str, answer_file: str):
9696
module = import_module(query_code)
97-
df = module.df
97+
df: DataFrame = module.df
9898

9999
# Treat q17 as a special case. The answer file does not match the spec.
100100
# Running at scale factor 1, we have manually verified this result does
@@ -121,5 +121,5 @@ def test_tpch_query_vs_answer_file(query_code: str, answer_file: str):
121121

122122
cols = list(read_schema.names)
123123

124-
assert df.join(df_expected, (cols, cols), "anti").count() == 0
124+
assert df.join(df_expected, on=cols, how="anti").count() == 0
125125
assert df.count() == df_expected.count()

examples/tpch/q02_minimum_cost_supplier.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -80,16 +80,20 @@
8080
# Now that we have the region, find suppliers in that region. Suppliers are tied to their nation
8181
# and nations are tied to the region.
8282

83-
df_nation = df_nation.join(df_region, (["n_regionkey"], ["r_regionkey"]), how="inner")
83+
df_nation = df_nation.join(
84+
df_region, left_on=["n_regionkey"], right_on=["r_regionkey"], how="inner"
85+
)
8486
df_supplier = df_supplier.join(
85-
df_nation, (["s_nationkey"], ["n_nationkey"]), how="inner"
87+
df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner"
8688
)
8789

8890
# Now that we know who the potential suppliers are for the part, we can limit our part
8991
# supplies table down. We can further join down to the specific parts we've identified
9092
# as matching the request
9193

92-
df = df_partsupp.join(df_supplier, (["ps_suppkey"], ["s_suppkey"]), how="inner")
94+
df = df_partsupp.join(
95+
df_supplier, left_on=["ps_suppkey"], right_on=["s_suppkey"], how="inner"
96+
)
9397

9498
# Locate the minimum cost across all suppliers. There are multiple ways you could do this,
9599
# but one way is to create a window function across all suppliers, find the minimum, and
@@ -111,7 +115,7 @@
111115

112116
df = df.filter(col("min_cost") == col("ps_supplycost"))
113117

114-
df = df.join(df_part, (["ps_partkey"], ["p_partkey"]), how="inner")
118+
df = df.join(df_part, left_on=["ps_partkey"], right_on=["p_partkey"], how="inner")
115119

116120
# From the problem statement, these are the values we wish to output
117121

examples/tpch/q03_shipping_priority.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,9 @@
5555

5656
# Join all 3 dataframes
5757

58-
df = df_customer.join(df_orders, (["c_custkey"], ["o_custkey"]), how="inner").join(
59-
df_lineitem, (["o_orderkey"], ["l_orderkey"]), how="inner"
60-
)
58+
df = df_customer.join(
59+
df_orders, left_on=["c_custkey"], right_on=["o_custkey"], how="inner"
60+
).join(df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner")
6161

6262
# Compute the revenue
6363

examples/tpch/q04_order_priority_checking.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,9 @@
6666
)
6767

6868
# Perform the join to find only orders for which there are lineitems outside of expected range
69-
df = df_orders.join(df_lineitem, (["o_orderkey"], ["l_orderkey"]), how="inner")
69+
df = df_orders.join(
70+
df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner"
71+
)
7072

7173
# Based on priority, find the number of entries
7274
df = df.aggregate(

examples/tpch/q05_local_supplier_volume.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -76,15 +76,18 @@
7676
# Join all the dataframes
7777

7878
df = (
79-
df_customer.join(df_orders, (["c_custkey"], ["o_custkey"]), how="inner")
80-
.join(df_lineitem, (["o_orderkey"], ["l_orderkey"]), how="inner")
79+
df_customer.join(
80+
df_orders, left_on=["c_custkey"], right_on=["o_custkey"], how="inner"
81+
)
82+
.join(df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner")
8183
.join(
8284
df_supplier,
83-
(["l_suppkey", "c_nationkey"], ["s_suppkey", "s_nationkey"]),
85+
left_on=["l_suppkey", "c_nationkey"],
86+
right_on=["s_suppkey", "s_nationkey"],
8487
how="inner",
8588
)
86-
.join(df_nation, (["s_nationkey"], ["n_nationkey"]), how="inner")
87-
.join(df_region, (["n_regionkey"], ["r_regionkey"]), how="inner")
89+
.join(df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner")
90+
.join(df_region, left_on=["n_regionkey"], right_on=["r_regionkey"], how="inner")
8891
)
8992

9093
# Compute the final result

examples/tpch/q07_volume_shipping.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -90,20 +90,22 @@
9090

9191
# Limit suppliers to either nation
9292
df_supplier = df_supplier.join(
93-
df_nation, (["s_nationkey"], ["n_nationkey"]), how="inner"
93+
df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner"
9494
).select(col("s_suppkey"), col("n_name").alias("supp_nation"))
9595

9696
# Limit customers to either nation
9797
df_customer = df_customer.join(
98-
df_nation, (["c_nationkey"], ["n_nationkey"]), how="inner"
98+
df_nation, left_on=["c_nationkey"], right_on=["n_nationkey"], how="inner"
9999
).select(col("c_custkey"), col("n_name").alias("cust_nation"))
100100

101101
# Join up all the data frames from line items, and make sure the supplier and customer are in
102102
# different nations.
103103
df = (
104-
df_lineitem.join(df_orders, (["l_orderkey"], ["o_orderkey"]), how="inner")
105-
.join(df_customer, (["o_custkey"], ["c_custkey"]), how="inner")
106-
.join(df_supplier, (["l_suppkey"], ["s_suppkey"]), how="inner")
104+
df_lineitem.join(
105+
df_orders, left_on=["l_orderkey"], right_on=["o_orderkey"], how="inner"
106+
)
107+
.join(df_customer, left_on=["o_custkey"], right_on=["c_custkey"], how="inner")
108+
.join(df_supplier, left_on=["l_suppkey"], right_on=["s_suppkey"], how="inner")
107109
.filter(col("cust_nation") != col("supp_nation"))
108110
)
109111

examples/tpch/q08_market_share.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -89,27 +89,27 @@
8989

9090
# After this join we have all of the possible sales nations
9191
df_regional_customers = df_regional_customers.join(
92-
df_nation, (["r_regionkey"], ["n_regionkey"]), how="inner"
92+
df_nation, left_on=["r_regionkey"], right_on=["n_regionkey"], how="inner"
9393
)
9494

9595
# Now find the possible customers
9696
df_regional_customers = df_regional_customers.join(
97-
df_customer, (["n_nationkey"], ["c_nationkey"]), how="inner"
97+
df_customer, left_on=["n_nationkey"], right_on=["c_nationkey"], how="inner"
9898
)
9999

100100
# Next find orders for these customers
101101
df_regional_customers = df_regional_customers.join(
102-
df_orders, (["c_custkey"], ["o_custkey"]), how="inner"
102+
df_orders, left_on=["c_custkey"], right_on=["o_custkey"], how="inner"
103103
)
104104

105105
# Find all line items from these orders
106106
df_regional_customers = df_regional_customers.join(
107-
df_lineitem, (["o_orderkey"], ["l_orderkey"]), how="inner"
107+
df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner"
108108
)
109109

110110
# Limit to the part of interest
111111
df_regional_customers = df_regional_customers.join(
112-
df_part, (["l_partkey"], ["p_partkey"]), how="inner"
112+
df_part, left_on=["l_partkey"], right_on=["p_partkey"], how="inner"
113113
)
114114

115115
# Compute the volume for each line item
@@ -126,7 +126,7 @@
126126

127127
# Determine the suppliers by the limited nation key we have in our single row df above
128128
df_national_suppliers = df_national_suppliers.join(
129-
df_supplier, (["n_nationkey"], ["s_nationkey"]), how="inner"
129+
df_supplier, left_on=["n_nationkey"], right_on=["s_nationkey"], how="inner"
130130
)
131131

132132
# When we join to the customer dataframe, we don't want to confuse other columns, so only
@@ -141,7 +141,7 @@
141141
# column only from suppliers in the nation we are evaluating.
142142

143143
df = df_regional_customers.join(
144-
df_national_suppliers, (["l_suppkey"], ["s_suppkey"]), how="left"
144+
df_national_suppliers, left_on=["l_suppkey"], right_on=["s_suppkey"], how="left"
145145
)
146146

147147
# Use a case statement to compute the volume sold by suppliers in the nation of interest

examples/tpch/q09_product_type_profit_measure.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,16 @@
6565
df = df_part.filter(F.strpos(col("p_name"), part_color) > lit(0))
6666

6767
# We have a series of joins that get us to limit down to the line items we need
68-
df = df.join(df_lineitem, (["p_partkey"], ["l_partkey"]), how="inner")
69-
df = df.join(df_supplier, (["l_suppkey"], ["s_suppkey"]), how="inner")
70-
df = df.join(df_orders, (["l_orderkey"], ["o_orderkey"]), how="inner")
68+
df = df.join(df_lineitem, left_on=["p_partkey"], right_on=["l_partkey"], how="inner")
69+
df = df.join(df_supplier, left_on=["l_suppkey"], right_on=["s_suppkey"], how="inner")
70+
df = df.join(df_orders, left_on=["l_orderkey"], right_on=["o_orderkey"], how="inner")
7171
df = df.join(
72-
df_partsupp, (["l_suppkey", "l_partkey"], ["ps_suppkey", "ps_partkey"]), how="inner"
72+
df_partsupp,
73+
left_on=["l_suppkey", "l_partkey"],
74+
right_on=["ps_suppkey", "ps_partkey"],
75+
how="inner",
7376
)
74-
df = df.join(df_nation, (["s_nationkey"], ["n_nationkey"]), how="inner")
77+
df = df.join(df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner")
7578

7679
# Compute the intermediate values and limit down to the expressions we need
7780
df = df.select(

examples/tpch/q10_returned_item_reporting.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@
7474
col("o_orderdate") < date_start_of_quarter + interval_one_quarter
7575
)
7676

77-
df = df.join(df_lineitem, (["o_orderkey"], ["l_orderkey"]), how="inner")
77+
df = df.join(df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner")
7878

7979
# Compute the revenue
8080
df = df.aggregate(
@@ -83,8 +83,8 @@
8383
)
8484

8585
# Now join in the customer data
86-
df = df.join(df_customer, (["o_custkey"], ["c_custkey"]), how="inner")
87-
df = df.join(df_nation, (["c_nationkey"], ["n_nationkey"]), how="inner")
86+
df = df.join(df_customer, left_on=["o_custkey"], right_on=["c_custkey"], how="inner")
87+
df = df.join(df_nation, left_on=["c_nationkey"], right_on=["n_nationkey"], how="inner")
8888

8989
# These are the columns the problem statement requires
9090
df = df.select(

examples/tpch/q11_important_stock_identification.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,11 @@
5252

5353
# Find part supplies within this target nation
5454

55-
df = df_nation.join(df_supplier, (["n_nationkey"], ["s_nationkey"]), how="inner")
55+
df = df_nation.join(
56+
df_supplier, left_on=["n_nationkey"], right_on=["s_nationkey"], how="inner"
57+
)
5658

57-
df = df.join(df_partsupp, (["s_suppkey"], ["ps_suppkey"]), how="inner")
59+
df = df.join(df_partsupp, left_on=["s_suppkey"], right_on=["ps_suppkey"], how="inner")
5860

5961

6062
# Compute the value of individual parts

examples/tpch/q12_ship_mode_order_priority.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@
7575

7676

7777
# We need order priority, so join order df to line item
78-
df = df.join(df_orders, (["l_orderkey"], ["o_orderkey"]), how="inner")
78+
df = df.join(df_orders, left_on=["l_orderkey"], right_on=["o_orderkey"], how="inner")
7979

8080
# Restrict to line items we care about based on the problem statement.
8181
df = df.filter(col("l_commitdate") < col("l_receiptdate"))

examples/tpch/q13_customer_distribution.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,9 @@
4949
)
5050

5151
# Since we may have customers with no orders we must do a left join
52-
df = df_customer.join(df_orders, (["c_custkey"], ["o_custkey"]), how="left")
52+
df = df_customer.join(
53+
df_orders, left_on=["c_custkey"], right_on=["o_custkey"], how="left"
54+
)
5355

5456
# Find the number of orders for each customer
5557
df = df.aggregate([col("c_custkey")], [F.count(col("o_custkey")).alias("c_count")])

examples/tpch/q14_promotion_effect.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,9 @@
5757
)
5858

5959
# Left join so we can sum up the promo parts different from other parts
60-
df = df_lineitem.join(df_part, (["l_partkey"], ["p_partkey"]), "left")
60+
df = df_lineitem.join(
61+
df_part, left_on=["l_partkey"], right_on=["p_partkey"], how="left"
62+
)
6163

6264
# Make a factor of 1.0 if it is a promotion, 0.0 otherwise
6365
df = df.with_column("promo_factor", F.coalesce(col("promo_factor"), lit(0.0)))

examples/tpch/q15_top_supplier.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@
7676

7777
# Now that we know the supplier(s) with maximum revenue, get the rest of their information
7878
# from the supplier table
79-
df = df.join(df_supplier, (["l_suppkey"], ["s_suppkey"]), "inner")
79+
df = df.join(df_supplier, left_on=["l_suppkey"], right_on=["s_suppkey"], how="inner")
8080

8181
# Return only the columns requested
8282
df = df.select("s_suppkey", "s_name", "s_address", "s_phone", "total_revenue")

examples/tpch/q16_part_supplier_relationship.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656

5757
# Remove unwanted suppliers
5858
df_partsupp = df_partsupp.join(
59-
df_unwanted_suppliers, (["ps_suppkey"], ["s_suppkey"]), "anti"
59+
df_unwanted_suppliers, left_on=["ps_suppkey"], right_on=["s_suppkey"], how="anti"
6060
)
6161

6262
# Select the parts we are interested in
@@ -73,7 +73,9 @@
7373
p_sizes = F.make_array(*[lit(s).cast(pa.int32()) for s in SIZES_OF_INTEREST])
7474
df_part = df_part.filter(~F.array_position(p_sizes, col("p_size")).is_null())
7575

76-
df = df_part.join(df_partsupp, (["p_partkey"], ["ps_partkey"]), "inner")
76+
df = df_part.join(
77+
df_partsupp, left_on=["p_partkey"], right_on=["ps_partkey"], how="inner"
78+
)
7779

7880
df = df.select("p_brand", "p_type", "p_size", "ps_suppkey").distinct()
7981

examples/tpch/q17_small_quantity_order.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
)
5252

5353
# Combine data
54-
df = df.join(df_lineitem, (["p_partkey"], ["l_partkey"]), "inner")
54+
df = df.join(df_lineitem, left_on=["p_partkey"], right_on=["l_partkey"], how="inner")
5555

5656
# Find the average quantity
5757
window_frame = WindowFrame("rows", None, None)

examples/tpch/q18_large_volume_customer.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@
5454

5555
# We've identified the orders of interest, now join the additional data
5656
# we are required to report on
57-
df = df.join(df_orders, (["l_orderkey"], ["o_orderkey"]), "inner")
58-
df = df.join(df_customer, (["o_custkey"], ["c_custkey"]), "inner")
57+
df = df.join(df_orders, left_on=["l_orderkey"], right_on=["o_orderkey"], how="inner")
58+
df = df.join(df_customer, left_on=["o_custkey"], right_on=["c_custkey"], how="inner")
5959

6060
df = df.select(
6161
"c_name", "c_custkey", "o_orderkey", "o_orderdate", "o_totalprice", "total_quantity"

examples/tpch/q19_discounted_revenue.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@
7272
(col("l_shipmode") == lit("AIR")) | (col("l_shipmode") == lit("AIR REG"))
7373
)
7474

75-
df = df.join(df_part, (["l_partkey"], ["p_partkey"]), "inner")
75+
df = df.join(df_part, left_on=["l_partkey"], right_on=["p_partkey"], how="inner")
7676

7777

7878
# Create the user defined function (UDF) definition that does the work

0 commit comments

Comments
 (0)