From ab36082da2af849121fce1688b47132ec5ce3fac Mon Sep 17 00:00:00 2001 From: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com> Date: Sat, 2 Nov 2024 10:49:49 +0100 Subject: [PATCH] chore: update join params in tpch --- examples/tpch/_tests.py | 6 +++--- examples/tpch/q02_minimum_cost_supplier.py | 12 ++++++++---- examples/tpch/q03_shipping_priority.py | 6 +++--- examples/tpch/q04_order_priority_checking.py | 4 +++- examples/tpch/q05_local_supplier_volume.py | 13 ++++++++----- examples/tpch/q07_volume_shipping.py | 12 +++++++----- examples/tpch/q08_market_share.py | 14 +++++++------- examples/tpch/q09_product_type_profit_measure.py | 13 ++++++++----- examples/tpch/q10_returned_item_reporting.py | 6 +++--- .../tpch/q11_important_stock_identification.py | 6 ++++-- examples/tpch/q12_ship_mode_order_priority.py | 2 +- examples/tpch/q13_customer_distribution.py | 4 +++- examples/tpch/q14_promotion_effect.py | 4 +++- examples/tpch/q15_top_supplier.py | 2 +- examples/tpch/q16_part_supplier_relationship.py | 2 +- examples/tpch/q17_small_quantity_order.py | 2 +- examples/tpch/q18_large_volume_customer.py | 4 ++-- examples/tpch/q19_discounted_revenue.py | 2 +- examples/tpch/q20_potential_part_promotion.py | 11 +++++++---- examples/tpch/q21_suppliers_kept_orders_waiting.py | 4 ++-- examples/tpch/q22_global_sales_opportunity.py | 2 +- 21 files changed, 77 insertions(+), 54 deletions(-) diff --git a/examples/tpch/_tests.py b/examples/tpch/_tests.py index 903b5354..13144ae9 100644 --- a/examples/tpch/_tests.py +++ b/examples/tpch/_tests.py @@ -18,7 +18,7 @@ import pytest from importlib import import_module import pyarrow as pa -from datafusion import col, lit, functions as F +from datafusion import DataFrame, col, lit, functions as F from util import get_answer_file @@ -94,7 +94,7 @@ def check_q17(df): ) def test_tpch_query_vs_answer_file(query_code: str, answer_file: str): module = import_module(query_code) - df = module.df + df: DataFrame = module.df # Treat q17 as a special case. The answer file does not match the spec. # Running at scale factor 1, we have manually verified this result does @@ -121,5 +121,5 @@ def test_tpch_query_vs_answer_file(query_code: str, answer_file: str): cols = list(read_schema.names) - assert df.join(df_expected, (cols, cols), "anti").count() == 0 + assert df.join(df_expected, on=cols, how="anti").count() == 0 assert df.count() == df_expected.count() diff --git a/examples/tpch/q02_minimum_cost_supplier.py b/examples/tpch/q02_minimum_cost_supplier.py index 2440fdad..c4ccf8ad 100644 --- a/examples/tpch/q02_minimum_cost_supplier.py +++ b/examples/tpch/q02_minimum_cost_supplier.py @@ -80,16 +80,20 @@ # Now that we have the region, find suppliers in that region. Suppliers are tied to their nation # and nations are tied to the region. -df_nation = df_nation.join(df_region, (["n_regionkey"], ["r_regionkey"]), how="inner") +df_nation = df_nation.join( + df_region, left_on=["n_regionkey"], right_on=["r_regionkey"], how="inner" +) df_supplier = df_supplier.join( - df_nation, (["s_nationkey"], ["n_nationkey"]), how="inner" + df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner" ) # Now that we know who the potential suppliers are for the part, we can limit out part # supplies table down. We can further join down to the specific parts we've identified # as matching the request -df = df_partsupp.join(df_supplier, (["ps_suppkey"], ["s_suppkey"]), how="inner") +df = df_partsupp.join( + df_supplier, left_on=["ps_suppkey"], right_on=["s_suppkey"], how="inner" +) # Locate the minimum cost across all suppliers. There are multiple ways you could do this, # but one way is to create a window function across all suppliers, find the minimum, and @@ -111,7 +115,7 @@ df = df.filter(col("min_cost") == col("ps_supplycost")) -df = df.join(df_part, (["ps_partkey"], ["p_partkey"]), how="inner") +df = df.join(df_part, left_on=["ps_partkey"], right_on=["p_partkey"], how="inner") # From the problem statement, these are the values we wish to output diff --git a/examples/tpch/q03_shipping_priority.py b/examples/tpch/q03_shipping_priority.py index c4e8f461..5ebab13c 100644 --- a/examples/tpch/q03_shipping_priority.py +++ b/examples/tpch/q03_shipping_priority.py @@ -55,9 +55,9 @@ # Join all 3 dataframes -df = df_customer.join(df_orders, (["c_custkey"], ["o_custkey"]), how="inner").join( - df_lineitem, (["o_orderkey"], ["l_orderkey"]), how="inner" -) +df = df_customer.join( + df_orders, left_on=["c_custkey"], right_on=["o_custkey"], how="inner" +).join(df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner") # Compute the revenue diff --git a/examples/tpch/q04_order_priority_checking.py b/examples/tpch/q04_order_priority_checking.py index f10b74d9..8bf02cb8 100644 --- a/examples/tpch/q04_order_priority_checking.py +++ b/examples/tpch/q04_order_priority_checking.py @@ -66,7 +66,9 @@ ) # Perform the join to find only orders for which there are lineitems outside of expected range -df = df_orders.join(df_lineitem, (["o_orderkey"], ["l_orderkey"]), how="inner") +df = df_orders.join( + df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner" +) # Based on priority, find the number of entries df = df.aggregate( diff --git a/examples/tpch/q05_local_supplier_volume.py b/examples/tpch/q05_local_supplier_volume.py index 2a83d2d1..413a4acb 100644 --- a/examples/tpch/q05_local_supplier_volume.py +++ b/examples/tpch/q05_local_supplier_volume.py @@ -76,15 +76,18 @@ # Join all the dataframes df = ( - df_customer.join(df_orders, (["c_custkey"], ["o_custkey"]), how="inner") - .join(df_lineitem, (["o_orderkey"], ["l_orderkey"]), how="inner") + df_customer.join( + df_orders, left_on=["c_custkey"], right_on=["o_custkey"], how="inner" + ) + .join(df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner") .join( df_supplier, - (["l_suppkey", "c_nationkey"], ["s_suppkey", "s_nationkey"]), + left_on=["l_suppkey", "c_nationkey"], + right_on=["s_suppkey", "s_nationkey"], how="inner", ) - .join(df_nation, (["s_nationkey"], ["n_nationkey"]), how="inner") - .join(df_region, (["n_regionkey"], ["r_regionkey"]), how="inner") + .join(df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner") + .join(df_region, left_on=["n_regionkey"], right_on=["r_regionkey"], how="inner") ) # Compute the final result diff --git a/examples/tpch/q07_volume_shipping.py b/examples/tpch/q07_volume_shipping.py index a1d7d81a..18c290d9 100644 --- a/examples/tpch/q07_volume_shipping.py +++ b/examples/tpch/q07_volume_shipping.py @@ -90,20 +90,22 @@ # Limit suppliers to either nation df_supplier = df_supplier.join( - df_nation, (["s_nationkey"], ["n_nationkey"]), how="inner" + df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner" ).select(col("s_suppkey"), col("n_name").alias("supp_nation")) # Limit customers to either nation df_customer = df_customer.join( - df_nation, (["c_nationkey"], ["n_nationkey"]), how="inner" + df_nation, left_on=["c_nationkey"], right_on=["n_nationkey"], how="inner" ).select(col("c_custkey"), col("n_name").alias("cust_nation")) # Join up all the data frames from line items, and make sure the supplier and customer are in # different nations. df = ( - df_lineitem.join(df_orders, (["l_orderkey"], ["o_orderkey"]), how="inner") - .join(df_customer, (["o_custkey"], ["c_custkey"]), how="inner") - .join(df_supplier, (["l_suppkey"], ["s_suppkey"]), how="inner") + df_lineitem.join( + df_orders, left_on=["l_orderkey"], right_on=["o_orderkey"], how="inner" + ) + .join(df_customer, left_on=["o_custkey"], right_on=["c_custkey"], how="inner") + .join(df_supplier, left_on=["l_suppkey"], right_on=["s_suppkey"], how="inner") .filter(col("cust_nation") != col("supp_nation")) ) diff --git a/examples/tpch/q08_market_share.py b/examples/tpch/q08_market_share.py index 95fc0a87..7138ab65 100644 --- a/examples/tpch/q08_market_share.py +++ b/examples/tpch/q08_market_share.py @@ -89,27 +89,27 @@ # After this join we have all of the possible sales nations df_regional_customers = df_regional_customers.join( - df_nation, (["r_regionkey"], ["n_regionkey"]), how="inner" + df_nation, left_on=["r_regionkey"], right_on=["n_regionkey"], how="inner" ) # Now find the possible customers df_regional_customers = df_regional_customers.join( - df_customer, (["n_nationkey"], ["c_nationkey"]), how="inner" + df_customer, left_on=["n_nationkey"], right_on=["c_nationkey"], how="inner" ) # Next find orders for these customers df_regional_customers = df_regional_customers.join( - df_orders, (["c_custkey"], ["o_custkey"]), how="inner" + df_orders, left_on=["c_custkey"], right_on=["o_custkey"], how="inner" ) # Find all line items from these orders df_regional_customers = df_regional_customers.join( - df_lineitem, (["o_orderkey"], ["l_orderkey"]), how="inner" + df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner" ) # Limit to the part of interest df_regional_customers = df_regional_customers.join( - df_part, (["l_partkey"], ["p_partkey"]), how="inner" + df_part, left_on=["l_partkey"], right_on=["p_partkey"], how="inner" ) # Compute the volume for each line item @@ -126,7 +126,7 @@ # Determine the suppliers by the limited nation key we have in our single row df above df_national_suppliers = df_national_suppliers.join( - df_supplier, (["n_nationkey"], ["s_nationkey"]), how="inner" + df_supplier, left_on=["n_nationkey"], right_on=["s_nationkey"], how="inner" ) # When we join to the customer dataframe, we don't want to confuse other columns, so only @@ -141,7 +141,7 @@ # column only from suppliers in the nation we are evaluating. df = df_regional_customers.join( - df_national_suppliers, (["l_suppkey"], ["s_suppkey"]), how="left" + df_national_suppliers, left_on=["l_suppkey"], right_on=["s_suppkey"], how="left" ) # Use a case statement to compute the volume sold by suppliers in the nation of interest diff --git a/examples/tpch/q09_product_type_profit_measure.py b/examples/tpch/q09_product_type_profit_measure.py index 0295d302..aa47d76c 100644 --- a/examples/tpch/q09_product_type_profit_measure.py +++ b/examples/tpch/q09_product_type_profit_measure.py @@ -65,13 +65,16 @@ df = df_part.filter(F.strpos(col("p_name"), part_color) > lit(0)) # We have a series of joins that get us to limit down to the line items we need -df = df.join(df_lineitem, (["p_partkey"], ["l_partkey"]), how="inner") -df = df.join(df_supplier, (["l_suppkey"], ["s_suppkey"]), how="inner") -df = df.join(df_orders, (["l_orderkey"], ["o_orderkey"]), how="inner") +df = df.join(df_lineitem, left_on=["p_partkey"], right_on=["l_partkey"], how="inner") +df = df.join(df_supplier, left_on=["l_suppkey"], right_on=["s_suppkey"], how="inner") +df = df.join(df_orders, left_on=["l_orderkey"], right_on=["o_orderkey"], how="inner") df = df.join( - df_partsupp, (["l_suppkey", "l_partkey"], ["ps_suppkey", "ps_partkey"]), how="inner" + df_partsupp, + left_on=["l_suppkey", "l_partkey"], + right_on=["ps_suppkey", "ps_partkey"], + how="inner", ) -df = df.join(df_nation, (["s_nationkey"], ["n_nationkey"]), how="inner") +df = df.join(df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner") # Compute the intermediate values and limit down to the expressions we need df = df.select( diff --git a/examples/tpch/q10_returned_item_reporting.py b/examples/tpch/q10_returned_item_reporting.py index 25f81b2f..94b398c1 100644 --- a/examples/tpch/q10_returned_item_reporting.py +++ b/examples/tpch/q10_returned_item_reporting.py @@ -74,7 +74,7 @@ col("o_orderdate") < date_start_of_quarter + interval_one_quarter ) -df = df.join(df_lineitem, (["o_orderkey"], ["l_orderkey"]), how="inner") +df = df.join(df_lineitem, left_on=["o_orderkey"], right_on=["l_orderkey"], how="inner") # Compute the revenue df = df.aggregate( @@ -83,8 +83,8 @@ ) # Now join in the customer data -df = df.join(df_customer, (["o_custkey"], ["c_custkey"]), how="inner") -df = df.join(df_nation, (["c_nationkey"], ["n_nationkey"]), how="inner") +df = df.join(df_customer, left_on=["o_custkey"], right_on=["c_custkey"], how="inner") +df = df.join(df_nation, left_on=["c_nationkey"], right_on=["n_nationkey"], how="inner") # These are the columns the problem statement requires df = df.select( diff --git a/examples/tpch/q11_important_stock_identification.py b/examples/tpch/q11_important_stock_identification.py index 86ff2296..707265e1 100644 --- a/examples/tpch/q11_important_stock_identification.py +++ b/examples/tpch/q11_important_stock_identification.py @@ -52,9 +52,11 @@ # Find part supplies of within this target nation -df = df_nation.join(df_supplier, (["n_nationkey"], ["s_nationkey"]), how="inner") +df = df_nation.join( + df_supplier, left_on=["n_nationkey"], right_on=["s_nationkey"], how="inner" +) -df = df.join(df_partsupp, (["s_suppkey"], ["ps_suppkey"]), how="inner") +df = df.join(df_partsupp, left_on=["s_suppkey"], right_on=["ps_suppkey"], how="inner") # Compute the value of individual parts diff --git a/examples/tpch/q12_ship_mode_order_priority.py b/examples/tpch/q12_ship_mode_order_priority.py index c3fc0d2e..def2a6c3 100644 --- a/examples/tpch/q12_ship_mode_order_priority.py +++ b/examples/tpch/q12_ship_mode_order_priority.py @@ -75,7 +75,7 @@ # We need order priority, so join order df to line item -df = df.join(df_orders, (["l_orderkey"], ["o_orderkey"]), how="inner") +df = df.join(df_orders, left_on=["l_orderkey"], right_on=["o_orderkey"], how="inner") # Restrict to line items we care about based on the problem statement. df = df.filter(col("l_commitdate") < col("l_receiptdate")) diff --git a/examples/tpch/q13_customer_distribution.py b/examples/tpch/q13_customer_distribution.py index f8b6c139..67365a96 100644 --- a/examples/tpch/q13_customer_distribution.py +++ b/examples/tpch/q13_customer_distribution.py @@ -49,7 +49,9 @@ ) # Since we may have customers with no orders we must do a left join -df = df_customer.join(df_orders, (["c_custkey"], ["o_custkey"]), how="left") +df = df_customer.join( + df_orders, left_on=["c_custkey"], right_on=["o_custkey"], how="left" +) # Find the number of orders for each customer df = df.aggregate([col("c_custkey")], [F.count(col("o_custkey")).alias("c_count")]) diff --git a/examples/tpch/q14_promotion_effect.py b/examples/tpch/q14_promotion_effect.py index 8224136a..cd26ee2b 100644 --- a/examples/tpch/q14_promotion_effect.py +++ b/examples/tpch/q14_promotion_effect.py @@ -57,7 +57,9 @@ ) # Left join so we can sum up the promo parts different from other parts -df = df_lineitem.join(df_part, (["l_partkey"], ["p_partkey"]), "left") +df = df_lineitem.join( + df_part, left_on=["l_partkey"], right_on=["p_partkey"], how="left" +) # Make a factor of 1.0 if it is a promotion, 0.0 otherwise df = df.with_column("promo_factor", F.coalesce(col("promo_factor"), lit(0.0))) diff --git a/examples/tpch/q15_top_supplier.py b/examples/tpch/q15_top_supplier.py index 44d5dd99..0bc316f7 100644 --- a/examples/tpch/q15_top_supplier.py +++ b/examples/tpch/q15_top_supplier.py @@ -76,7 +76,7 @@ # Now that we know the supplier(s) with maximum revenue, get the rest of their information # from the supplier table -df = df.join(df_supplier, (["l_suppkey"], ["s_suppkey"]), "inner") +df = df.join(df_supplier, left_on=["l_suppkey"], right_on=["s_suppkey"], how="inner") # Return only the columns requested df = df.select("s_suppkey", "s_name", "s_address", "s_phone", "total_revenue") diff --git a/examples/tpch/q16_part_supplier_relationship.py b/examples/tpch/q16_part_supplier_relationship.py index cbdd9989..dabebaed 100644 --- a/examples/tpch/q16_part_supplier_relationship.py +++ b/examples/tpch/q16_part_supplier_relationship.py @@ -56,7 +56,7 @@ # Remove unwanted suppliers df_partsupp = df_partsupp.join( - df_unwanted_suppliers, (["ps_suppkey"], ["s_suppkey"]), "anti" + df_unwanted_suppliers, left_on=["ps_suppkey"], right_on=["s_suppkey"], how="anti" ) # Select the parts we are interested in diff --git a/examples/tpch/q17_small_quantity_order.py b/examples/tpch/q17_small_quantity_order.py index ff494279..d7b43d49 100644 --- a/examples/tpch/q17_small_quantity_order.py +++ b/examples/tpch/q17_small_quantity_order.py @@ -51,7 +51,7 @@ ) # Combine data -df = df.join(df_lineitem, (["p_partkey"], ["l_partkey"]), "inner") +df = df.join(df_lineitem, left_on=["p_partkey"], right_on=["l_partkey"], how="inner") # Find the average quantity window_frame = WindowFrame("rows", None, None) diff --git a/examples/tpch/q18_large_volume_customer.py b/examples/tpch/q18_large_volume_customer.py index 49761549..165fce03 100644 --- a/examples/tpch/q18_large_volume_customer.py +++ b/examples/tpch/q18_large_volume_customer.py @@ -54,8 +54,8 @@ # We've identified the orders of interest, now join the additional data # we are required to report on -df = df.join(df_orders, (["l_orderkey"], ["o_orderkey"]), "inner") -df = df.join(df_customer, (["o_custkey"], ["c_custkey"]), "inner") +df = df.join(df_orders, left_on=["l_orderkey"], right_on=["o_orderkey"], how="inner") +df = df.join(df_customer, left_on=["o_custkey"], right_on=["c_custkey"], how="inner") df = df.select( "c_name", "c_custkey", "o_orderkey", "o_orderdate", "o_totalprice", "total_quantity" diff --git a/examples/tpch/q19_discounted_revenue.py b/examples/tpch/q19_discounted_revenue.py index c2fe2570..4aed0cba 100644 --- a/examples/tpch/q19_discounted_revenue.py +++ b/examples/tpch/q19_discounted_revenue.py @@ -72,7 +72,7 @@ (col("l_shipmode") == lit("AIR")) | (col("l_shipmode") == lit("AIR REG")) ) -df = df.join(df_part, (["l_partkey"], ["p_partkey"]), "inner") +df = df.join(df_part, left_on=["l_partkey"], right_on=["p_partkey"], how="inner") # Create the user defined function (UDF) definition that does the work diff --git a/examples/tpch/q20_potential_part_promotion.py b/examples/tpch/q20_potential_part_promotion.py index 3a0edb1e..d720cdce 100644 --- a/examples/tpch/q20_potential_part_promotion.py +++ b/examples/tpch/q20_potential_part_promotion.py @@ -70,7 +70,7 @@ ) # This will filter down the line items to the parts of interest -df = df.join(df_part, (["l_partkey"], ["p_partkey"]), "inner") +df = df.join(df_part, left_on="l_partkey", right_on="p_partkey", how="inner") # Compute the total sold and limit ourselves to individual supplier/part combinations df = df.aggregate( @@ -78,15 +78,18 @@ ) df = df.join( - df_partsupp, (["l_partkey", "l_suppkey"], ["ps_partkey", "ps_suppkey"]), "inner" + df_partsupp, + left_on=["l_partkey", "l_suppkey"], + right_on=["ps_partkey", "ps_suppkey"], + how="inner", ) # Find cases of excess quantity df.filter(col("ps_availqty") > lit(0.5) * col("total_sold")) # We could do these joins earlier, but now limit to the nation of interest suppliers -df = df.join(df_supplier, (["ps_suppkey"], ["s_suppkey"]), "inner") -df = df.join(df_nation, (["s_nationkey"], ["n_nationkey"]), "inner") +df = df.join(df_supplier, left_on=["ps_suppkey"], right_on=["s_suppkey"], how="inner") +df = df.join(df_nation, left_on=["s_nationkey"], right_on=["n_nationkey"], how="inner") # Restrict to the requested data per the problem statement df = df.select("s_name", "s_address").distinct() diff --git a/examples/tpch/q21_suppliers_kept_orders_waiting.py b/examples/tpch/q21_suppliers_kept_orders_waiting.py index d3d57ace..991b88eb 100644 --- a/examples/tpch/q21_suppliers_kept_orders_waiting.py +++ b/examples/tpch/q21_suppliers_kept_orders_waiting.py @@ -52,13 +52,13 @@ df_suppliers_of_interest = df_nation.filter(col("n_name") == lit(NATION_OF_INTEREST)) df_suppliers_of_interest = df_suppliers_of_interest.join( - df_supplier, (["n_nationkey"], ["s_nationkey"]), "inner" + df_supplier, left_on="n_nationkey", right_on="s_nationkey", how="inner" ) # Find the failed orders and all their line items df = df_orders.filter(col("o_orderstatus") == lit("F")) -df = df_lineitem.join(df, (["l_orderkey"], ["o_orderkey"]), "inner") +df = df_lineitem.join(df, left_on="l_orderkey", right_on="o_orderkey", how="inner") # Identify the line items for which the order is failed due to. df = df.with_column( diff --git a/examples/tpch/q22_global_sales_opportunity.py b/examples/tpch/q22_global_sales_opportunity.py index e6660e60..72dce528 100644 --- a/examples/tpch/q22_global_sales_opportunity.py +++ b/examples/tpch/q22_global_sales_opportunity.py @@ -62,7 +62,7 @@ df = df.filter(col("c_acctbal") > col("avg_balance")) # Limit results to customers with no orders -df = df.join(df_orders, (["c_custkey"], ["o_custkey"]), "anti") +df = df.join(df_orders, left_on="c_custkey", right_on="o_custkey", how="anti") # Count up the customers and the balances df = df.aggregate(