Skip to content

Commit 5bc0051

Browse files
authored
improve filter pushdown to join (#5770)
1 parent 533bb5c commit 5bc0051

File tree

9 files changed

+399
-344
lines changed

9 files changed

+399
-344
lines changed

benchmarks/expected-plans/q17.txt

+49-55
Large diffs are not rendered by default.

benchmarks/expected-plans/q19.txt

+34-39
Large diffs are not rendered by default.

benchmarks/expected-plans/q20.txt

+80-85
Large diffs are not rendered by default.

benchmarks/expected-plans/q7.txt

+88-93
Large diffs are not rendered by default.

datafusion/core/src/physical_plan/planner.rs

+4-5
Original file line numberDiff line numberDiff line change
@@ -923,21 +923,21 @@ impl DefaultPhysicalPlanner {
923923

924924
let join_filter = match filter {
925925
Some(expr) => {
926-
// Extract columns from filter expression
926+
// Extract the columns referenced by the filter expression into a HashSet
927927
let cols = expr.to_columns()?;
928928

929-
// Collect left & right field indices
929+
// Collect left & right field indices; the indices are sorted in ascending order
930930
let left_field_indices = cols.iter()
931931
.filter_map(|c| match left_df_schema.index_of_column(c) {
932932
Ok(idx) => Some(idx),
933933
_ => None,
934-
})
934+
}).sorted()
935935
.collect::<Vec<_>>();
936936
let right_field_indices = cols.iter()
937937
.filter_map(|c| match right_df_schema.index_of_column(c) {
938938
Ok(idx) => Some(idx),
939939
_ => None,
940-
})
940+
}).sorted()
941941
.collect::<Vec<_>>();
942942

943943
// Collect DFFields and Fields required for intermediate schemas
@@ -957,7 +957,6 @@ impl DefaultPhysicalPlanner {
957957
)
958958
.unzip();
959959

960-
961960
// Construct intermediate schemas used for filtering data and
962961
// convert logical expression to physical according to filter schema
963962
let filter_df_schema = DFSchema::new_with_metadata(filter_df_fields, HashMap::new())?;

datafusion/core/tests/sql/joins.rs

+9-11
Original file line numberDiff line numberDiff line change
@@ -1103,12 +1103,11 @@ async fn reduce_left_join_2() -> Result<()> {
11031103
// the right part `(t2.t2_name != 'w' or t2.t2_int < 10)` could be pushed down to the left join side and removed from the filter.
11041104

11051105
let expected = vec![
1106-
"Explain [plan_type:Utf8, plan:Utf8]",
1107-
" Filter: t2.t2_int < UInt32(10) OR t1.t1_int > UInt32(2) AND t2.t2_name != Utf8(\"w\") [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
1108-
" Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
1109-
" TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
1110-
" Filter: t2.t2_int < UInt32(10) OR t2.t2_name != Utf8(\"w\") [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
1111-
" TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
1106+
"Explain [plan_type:Utf8, plan:Utf8]",
1107+
" Inner Join: t1.t1_id = t2.t2_id Filter: t2.t2_int < UInt32(10) OR t1.t1_int > UInt32(2) AND t2.t2_name != Utf8(\"w\") [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
1108+
" TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
1109+
" Filter: t2.t2_int < UInt32(10) OR t2.t2_name != Utf8(\"w\") [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
1110+
" TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
11121111
];
11131112
let formatted = plan.display_indent_schema().to_string();
11141113
let actual: Vec<&str> = formatted.trim().lines().collect();
@@ -1188,11 +1187,10 @@ async fn reduce_right_join_2() -> Result<()> {
11881187
let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
11891188
let plan = dataframe.into_optimized_plan()?;
11901189
let expected = vec![
1191-
"Explain [plan_type:Utf8, plan:Utf8]",
1192-
" Filter: t1.t1_int != t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
1193-
" Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
1194-
" TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
1195-
" TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
1190+
"Explain [plan_type:Utf8, plan:Utf8]",
1191+
" Inner Join: t1.t1_id = t2.t2_id Filter: t1.t1_int != t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
1192+
" TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
1193+
" TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
11961194
];
11971195
let formatted = plan.display_indent_schema().to_string();
11981196
let actual: Vec<&str> = formatted.trim().lines().collect();

datafusion/core/tests/sql/predicates.rs

+5-7
Original file line numberDiff line numberDiff line change
@@ -100,13 +100,11 @@ async fn multiple_or_predicates() -> Result<()> {
100100
let expected = vec![
101101
"Explain [plan_type:Utf8, plan:Utf8]",
102102
" Projection: lineitem.l_partkey [l_partkey:Int64]",
103-
" Filter: part.p_brand = Utf8(\"Brand#12\") AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8(\"Brand#23\") AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = Utf8(\"Brand#34\") AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND part.p_size <= Int32(15) [l_partkey:Int64, l_quantity:Decimal128(15, 2), p_brand:Utf8, p_size:Int32]",
104-
" Projection: lineitem.l_partkey, lineitem.l_quantity, part.p_brand, part.p_size [l_partkey:Int64, l_quantity:Decimal128(15, 2), p_brand:Utf8, p_size:Int32]",
105-
" Inner Join: lineitem.l_partkey = part.p_partkey [l_partkey:Int64, l_quantity:Decimal128(15, 2), p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
106-
" Filter: lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) [l_partkey:Int64, l_quantity:Decimal128(15, 2)]",
107-
" TableScan: lineitem projection=[l_partkey, l_quantity], partial_filters=[lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)] [l_partkey:Int64, l_quantity:Decimal128(15, 2)]",
108-
" Filter: (part.p_brand = Utf8(\"Brand#12\") AND part.p_size <= Int32(5) OR part.p_brand = Utf8(\"Brand#23\") AND part.p_size <= Int32(10) OR part.p_brand = Utf8(\"Brand#34\") AND part.p_size <= Int32(15)) AND part.p_size >= Int32(1) [p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
109-
" TableScan: part projection=[p_partkey, p_brand, p_size], partial_filters=[part.p_size >= Int32(1), part.p_brand = Utf8(\"Brand#12\") AND part.p_size <= Int32(5) OR part.p_brand = Utf8(\"Brand#23\") AND part.p_size <= Int32(10) OR part.p_brand = Utf8(\"Brand#34\") AND part.p_size <= Int32(15)] [p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
103+
" Inner Join: lineitem.l_partkey = part.p_partkey Filter: part.p_brand = Utf8(\"Brand#12\") AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8(\"Brand#23\") AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = Utf8(\"Brand#34\") AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND part.p_size <= Int32(15) [l_partkey:Int64, l_quantity:Decimal128(15, 2), p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
104+
" Filter: lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) [l_partkey:Int64, l_quantity:Decimal128(15, 2)]",
105+
" TableScan: lineitem projection=[l_partkey, l_quantity], partial_filters=[lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)] [l_partkey:Int64, l_quantity:Decimal128(15, 2)]",
106+
" Filter: (part.p_brand = Utf8(\"Brand#12\") AND part.p_size <= Int32(5) OR part.p_brand = Utf8(\"Brand#23\") AND part.p_size <= Int32(10) OR part.p_brand = Utf8(\"Brand#34\") AND part.p_size <= Int32(15)) AND part.p_size >= Int32(1) [p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
107+
" TableScan: part projection=[p_partkey, p_brand, p_size], partial_filters=[part.p_size >= Int32(1), part.p_brand = Utf8(\"Brand#12\") AND part.p_size <= Int32(5) OR part.p_brand = Utf8(\"Brand#23\") AND part.p_size <= Int32(10) OR part.p_brand = Utf8(\"Brand#34\") AND part.p_size <= Int32(15)] [p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
110108
];
111109
let formatted = plan.display_indent_schema().to_string();
112110
let actual: Vec<&str> = formatted.trim().lines().collect();

datafusion/core/tests/sql/subqueries.rs

+12-16
Original file line numberDiff line numberDiff line change
@@ -52,22 +52,18 @@ where c_acctbal < (
5252
let actual = format!("{}", plan.display_indent());
5353
let expected = "Sort: customer.c_custkey ASC NULLS LAST\
5454
\n Projection: customer.c_custkey\
55-
\n Filter: CAST(customer.c_acctbal AS Decimal128(25, 2)) < __scalar_sq_1.__value\
56-
\n Projection: customer.c_custkey, customer.c_acctbal, __scalar_sq_1.__value\
57-
\n Inner Join: customer.c_custkey = __scalar_sq_1.o_custkey\
58-
\n TableScan: customer projection=[c_custkey, c_acctbal]\
59-
\n SubqueryAlias: __scalar_sq_1\
60-
\n Projection: orders.o_custkey, SUM(orders.o_totalprice) AS __value\
61-
\n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[SUM(orders.o_totalprice)]]\
62-
\n Projection: orders.o_custkey, orders.o_totalprice\
63-
\n Filter: CAST(orders.o_totalprice AS Decimal128(25, 2)) < __scalar_sq_2.__value\
64-
\n Projection: orders.o_custkey, orders.o_totalprice, __scalar_sq_2.__value\
65-
\n Inner Join: orders.o_orderkey = __scalar_sq_2.l_orderkey\
66-
\n TableScan: orders projection=[o_orderkey, o_custkey, o_totalprice]\
67-
\n SubqueryAlias: __scalar_sq_2\
68-
\n Projection: lineitem.l_orderkey, SUM(lineitem.l_extendedprice) AS price AS __value\
69-
\n Aggregate: groupBy=[[lineitem.l_orderkey]], aggr=[[SUM(lineitem.l_extendedprice)]]\
70-
\n TableScan: lineitem projection=[l_orderkey, l_extendedprice]";
55+
\n Inner Join: customer.c_custkey = __scalar_sq_1.o_custkey Filter: CAST(customer.c_acctbal AS Decimal128(25, 2)) < __scalar_sq_1.__value\
56+
\n TableScan: customer projection=[c_custkey, c_acctbal]\
57+
\n SubqueryAlias: __scalar_sq_1\
58+
\n Projection: orders.o_custkey, SUM(orders.o_totalprice) AS __value\
59+
\n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[SUM(orders.o_totalprice)]]\
60+
\n Projection: orders.o_custkey, orders.o_totalprice\
61+
\n Inner Join: orders.o_orderkey = __scalar_sq_2.l_orderkey Filter: CAST(orders.o_totalprice AS Decimal128(25, 2)) < __scalar_sq_2.__value\
62+
\n TableScan: orders projection=[o_orderkey, o_custkey, o_totalprice]\
63+
\n SubqueryAlias: __scalar_sq_2\
64+
\n Projection: lineitem.l_orderkey, SUM(lineitem.l_extendedprice) AS price AS __value\
65+
\n Aggregate: groupBy=[[lineitem.l_orderkey]], aggr=[[SUM(lineitem.l_extendedprice)]]\
66+
\n TableScan: lineitem projection=[l_orderkey, l_extendedprice]";
7167
assert_eq!(actual, expected);
7268

7369
Ok(())

0 commit comments

Comments
 (0)