Skip to content

Commit 3e56ed2

Browse files
authored
Enforce JOIN plan to require condition (#15334)
* add check for missing join condition * modify the tests * handle the cross join case * remove invald sql case * fix sqllogictests * fix for cross join case * revert the change for limit.slt * add cross join test
1 parent a0a063d commit 3e56ed2

File tree

8 files changed

+47
-185
lines changed

8 files changed

+47
-185
lines changed

datafusion/expr/src/logical_plan/builder.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1074,14 +1074,19 @@ impl LogicalPlanBuilder {
10741074
let left_keys = left_keys.into_iter().collect::<Result<Vec<Column>>>()?;
10751075
let right_keys = right_keys.into_iter().collect::<Result<Vec<Column>>>()?;
10761076

1077-
let on = left_keys
1077+
let on: Vec<_> = left_keys
10781078
.into_iter()
10791079
.zip(right_keys)
10801080
.map(|(l, r)| (Expr::Column(l), Expr::Column(r)))
10811081
.collect();
10821082
let join_schema =
10831083
build_join_schema(self.plan.schema(), right.schema(), &join_type)?;
10841084

1085+
// Inner type without join condition is cross join
1086+
if join_type != JoinType::Inner && on.is_empty() && filter.is_none() {
1087+
return plan_err!("join condition should not be empty");
1088+
}
1089+
10851090
Ok(Self::new(LogicalPlan::Join(Join {
10861091
left: self.plan,
10871092
right: Arc::new(right),

datafusion/optimizer/src/push_down_limit.rs

Lines changed: 5 additions & 169 deletions
Original file line numberDiff line numberDiff line change
@@ -242,16 +242,13 @@ fn transformed_limit(
242242
fn push_down_join(mut join: Join, limit: usize) -> Transformed<Join> {
243243
use JoinType::*;
244244

245-
fn is_no_join_condition(join: &Join) -> bool {
246-
join.on.is_empty() && join.filter.is_none()
245+
// Cross join is the special case of inner join where there is no join condition. see [LogicalPlanBuilder::cross_join]
246+
fn is_cross_join(join: &Join) -> bool {
247+
join.join_type == Inner && join.on.is_empty() && join.filter.is_none()
247248
}
248249

249-
let (left_limit, right_limit) = if is_no_join_condition(&join) {
250-
match join.join_type {
251-
Left | Right | Full | Inner => (Some(limit), Some(limit)),
252-
LeftAnti | LeftSemi | LeftMark => (Some(limit), None),
253-
RightAnti | RightSemi => (None, Some(limit)),
254-
}
250+
let (left_limit, right_limit) = if is_cross_join(&join) {
251+
(Some(limit), Some(limit))
255252
} else {
256253
match join.join_type {
257254
Left => (Some(limit), None),
@@ -861,167 +858,6 @@ mod test {
861858
assert_optimized_plan_equal(outer_query, expected)
862859
}
863860

864-
#[test]
865-
fn limit_should_push_down_join_without_condition() -> Result<()> {
866-
let table_scan_1 = test_table_scan()?;
867-
let table_scan_2 = test_table_scan_with_name("test2")?;
868-
let left_keys: Vec<&str> = Vec::new();
869-
let right_keys: Vec<&str> = Vec::new();
870-
let plan = LogicalPlanBuilder::from(table_scan_1.clone())
871-
.join(
872-
LogicalPlanBuilder::from(table_scan_2.clone()).build()?,
873-
JoinType::Left,
874-
(left_keys.clone(), right_keys.clone()),
875-
None,
876-
)?
877-
.limit(0, Some(1000))?
878-
.build()?;
879-
880-
let expected = "Limit: skip=0, fetch=1000\
881-
\n Left Join: \
882-
\n Limit: skip=0, fetch=1000\
883-
\n TableScan: test, fetch=1000\
884-
\n Limit: skip=0, fetch=1000\
885-
\n TableScan: test2, fetch=1000";
886-
887-
assert_optimized_plan_equal(plan, expected)?;
888-
889-
let plan = LogicalPlanBuilder::from(table_scan_1.clone())
890-
.join(
891-
LogicalPlanBuilder::from(table_scan_2.clone()).build()?,
892-
JoinType::Right,
893-
(left_keys.clone(), right_keys.clone()),
894-
None,
895-
)?
896-
.limit(0, Some(1000))?
897-
.build()?;
898-
899-
let expected = "Limit: skip=0, fetch=1000\
900-
\n Right Join: \
901-
\n Limit: skip=0, fetch=1000\
902-
\n TableScan: test, fetch=1000\
903-
\n Limit: skip=0, fetch=1000\
904-
\n TableScan: test2, fetch=1000";
905-
906-
assert_optimized_plan_equal(plan, expected)?;
907-
908-
let plan = LogicalPlanBuilder::from(table_scan_1.clone())
909-
.join(
910-
LogicalPlanBuilder::from(table_scan_2.clone()).build()?,
911-
JoinType::Full,
912-
(left_keys.clone(), right_keys.clone()),
913-
None,
914-
)?
915-
.limit(0, Some(1000))?
916-
.build()?;
917-
918-
let expected = "Limit: skip=0, fetch=1000\
919-
\n Full Join: \
920-
\n Limit: skip=0, fetch=1000\
921-
\n TableScan: test, fetch=1000\
922-
\n Limit: skip=0, fetch=1000\
923-
\n TableScan: test2, fetch=1000";
924-
925-
assert_optimized_plan_equal(plan, expected)?;
926-
927-
let plan = LogicalPlanBuilder::from(table_scan_1.clone())
928-
.join(
929-
LogicalPlanBuilder::from(table_scan_2.clone()).build()?,
930-
JoinType::LeftSemi,
931-
(left_keys.clone(), right_keys.clone()),
932-
None,
933-
)?
934-
.limit(0, Some(1000))?
935-
.build()?;
936-
937-
let expected = "Limit: skip=0, fetch=1000\
938-
\n LeftSemi Join: \
939-
\n Limit: skip=0, fetch=1000\
940-
\n TableScan: test, fetch=1000\
941-
\n TableScan: test2";
942-
943-
assert_optimized_plan_equal(plan, expected)?;
944-
945-
let plan = LogicalPlanBuilder::from(table_scan_1.clone())
946-
.join(
947-
LogicalPlanBuilder::from(table_scan_2.clone()).build()?,
948-
JoinType::LeftAnti,
949-
(left_keys.clone(), right_keys.clone()),
950-
None,
951-
)?
952-
.limit(0, Some(1000))?
953-
.build()?;
954-
955-
let expected = "Limit: skip=0, fetch=1000\
956-
\n LeftAnti Join: \
957-
\n Limit: skip=0, fetch=1000\
958-
\n TableScan: test, fetch=1000\
959-
\n TableScan: test2";
960-
961-
assert_optimized_plan_equal(plan, expected)?;
962-
963-
let plan = LogicalPlanBuilder::from(table_scan_1.clone())
964-
.join(
965-
LogicalPlanBuilder::from(table_scan_2.clone()).build()?,
966-
JoinType::RightSemi,
967-
(left_keys.clone(), right_keys.clone()),
968-
None,
969-
)?
970-
.limit(0, Some(1000))?
971-
.build()?;
972-
973-
let expected = "Limit: skip=0, fetch=1000\
974-
\n RightSemi Join: \
975-
\n TableScan: test\
976-
\n Limit: skip=0, fetch=1000\
977-
\n TableScan: test2, fetch=1000";
978-
979-
assert_optimized_plan_equal(plan, expected)?;
980-
981-
let plan = LogicalPlanBuilder::from(table_scan_1)
982-
.join(
983-
LogicalPlanBuilder::from(table_scan_2).build()?,
984-
JoinType::RightAnti,
985-
(left_keys, right_keys),
986-
None,
987-
)?
988-
.limit(0, Some(1000))?
989-
.build()?;
990-
991-
let expected = "Limit: skip=0, fetch=1000\
992-
\n RightAnti Join: \
993-
\n TableScan: test\
994-
\n Limit: skip=0, fetch=1000\
995-
\n TableScan: test2, fetch=1000";
996-
997-
assert_optimized_plan_equal(plan, expected)
998-
}
999-
1000-
#[test]
1001-
fn limit_should_push_down_left_outer_join() -> Result<()> {
1002-
let table_scan_1 = test_table_scan()?;
1003-
let table_scan_2 = test_table_scan_with_name("test2")?;
1004-
1005-
let plan = LogicalPlanBuilder::from(table_scan_1)
1006-
.join(
1007-
LogicalPlanBuilder::from(table_scan_2).build()?,
1008-
JoinType::Left,
1009-
(vec!["a"], vec!["a"]),
1010-
None,
1011-
)?
1012-
.limit(0, Some(1000))?
1013-
.build()?;
1014-
1015-
// Limit pushdown Not supported in Join
1016-
let expected = "Limit: skip=0, fetch=1000\
1017-
\n Left Join: test.a = test2.a\
1018-
\n Limit: skip=0, fetch=1000\
1019-
\n TableScan: test, fetch=1000\
1020-
\n TableScan: test2";
1021-
1022-
assert_optimized_plan_equal(plan, expected)
1023-
}
1024-
1025861
#[test]
1026862
fn limit_should_push_down_left_outer_join_with_offset() -> Result<()> {
1027863
let table_scan_1 = test_table_scan()?;

datafusion/optimizer/src/scalar_subquery_to_join.rs

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,11 @@ fn build_join(
331331
_ => {
332332
// if not correlated, group down to 1 row and left join on that (preserving row count)
333333
LogicalPlanBuilder::from(filter_input.clone())
334-
.join_on(sub_query_alias, JoinType::Left, None)?
334+
.join_on(
335+
sub_query_alias,
336+
JoinType::Left,
337+
vec![Expr::Literal(ScalarValue::Boolean(Some(true)))],
338+
)?
335339
.build()?
336340
}
337341
}
@@ -557,7 +561,7 @@ mod tests {
557561
// it will optimize, but fail for the same reason the unoptimized query would
558562
let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
559563
\n Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\
560-
\n Left Join: [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\
564+
\n Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\
561565
\n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
562566
\n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]\
563567
\n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\
@@ -589,7 +593,7 @@ mod tests {
589593

590594
let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
591595
\n Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\
592-
\n Left Join: [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\
596+
\n Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\
593597
\n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
594598
\n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]\
595599
\n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\
@@ -965,7 +969,7 @@ mod tests {
965969

966970
let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
967971
\n Filter: customer.c_custkey < __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\
968-
\n Left Join: [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\
972+
\n Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\
969973
\n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
970974
\n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]\
971975
\n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\
@@ -996,7 +1000,7 @@ mod tests {
9961000

9971001
let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
9981002
\n Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\
999-
\n Left Join: [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\
1003+
\n Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\
10001004
\n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
10011005
\n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]\
10021006
\n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\
@@ -1097,8 +1101,8 @@ mod tests {
10971101

10981102
let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
10991103
\n Filter: customer.c_custkey BETWEEN __scalar_sq_1.min(orders.o_custkey) AND __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, max(orders.o_custkey):Int64;N]\
1100-
\n Left Join: [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, max(orders.o_custkey):Int64;N]\
1101-
\n Left Join: [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N]\
1104+
\n Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, max(orders.o_custkey):Int64;N]\
1105+
\n Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N]\
11021106
\n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
11031107
\n SubqueryAlias: __scalar_sq_1 [min(orders.o_custkey):Int64;N]\
11041108
\n Projection: min(orders.o_custkey) [min(orders.o_custkey):Int64;N]\

datafusion/sql/tests/cases/plan_to_sql.rs

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1458,11 +1458,6 @@ fn test_unnest_to_sql() {
14581458

14591459
#[test]
14601460
fn test_join_with_no_conditions() {
1461-
sql_round_trip(
1462-
GenericDialect {},
1463-
"SELECT j1.j1_id, j1.j1_string FROM j1 JOIN j2",
1464-
"SELECT j1.j1_id, j1.j1_string FROM j1 CROSS JOIN j2",
1465-
);
14661461
sql_round_trip(
14671462
GenericDialect {},
14681463
"SELECT j1.j1_id, j1.j1_string FROM j1 CROSS JOIN j2",

datafusion/sqllogictest/test_files/join.slt.part

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -625,6 +625,28 @@ FROM t1
625625
----
626626
11 11 11
627627

628+
# join condition is required
629+
# TODO: query error join condition should not be empty
630+
# related to: https://github.com/apache/datafusion/issues/13486
631+
statement ok
632+
SELECT * FROM t1 JOIN t2
633+
634+
# join condition is required
635+
query error join condition should not be empty
636+
SELECT * FROM t1 LEFT JOIN t2
637+
638+
# join condition is required
639+
query error join condition should not be empty
640+
SELECT * FROM t1 RIGHT JOIN t2
641+
642+
# join condition is required
643+
query error join condition should not be empty
644+
SELECT * FROM t1 FULL JOIN t2
645+
646+
# cross join no need for join condition
647+
statement ok
648+
SELECT * FROM t1 CROSS JOIN t2
649+
628650
# multiple inner joins with mixed ON clause and filter
629651
query III rowsort
630652
SELECT t1.t1_id, t2.t2_id, t3.t3_id

datafusion/sqllogictest/test_files/tpch/plans/q11.slt.part

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ logical_plan
5555
06)----------Inner Join: supplier.s_nationkey = nation.n_nationkey
5656
07)------------Projection: partsupp.ps_partkey, partsupp.ps_availqty, partsupp.ps_supplycost, supplier.s_nationkey
5757
08)--------------Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
58-
09)----------------TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost]
58+
09)----------------TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost], partial_filters=[Boolean(true)]
5959
10)----------------TableScan: supplier projection=[s_suppkey, s_nationkey]
6060
11)------------Projection: nation.n_nationkey
6161
12)--------------Filter: nation.n_name = Utf8("GERMANY")

datafusion/sqllogictest/test_files/tpch/plans/q15.slt.part

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ logical_plan
5555
03)----Inner Join: revenue0.total_revenue = __scalar_sq_1.max(revenue0.total_revenue)
5656
04)------Projection: supplier.s_suppkey, supplier.s_name, supplier.s_address, supplier.s_phone, revenue0.total_revenue
5757
05)--------Inner Join: supplier.s_suppkey = revenue0.supplier_no
58-
06)----------TableScan: supplier projection=[s_suppkey, s_name, s_address, s_phone]
58+
06)----------TableScan: supplier projection=[s_suppkey, s_name, s_address, s_phone], partial_filters=[Boolean(true)]
5959
07)----------SubqueryAlias: revenue0
6060
08)------------Projection: lineitem.l_suppkey AS supplier_no, sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS total_revenue
6161
09)--------------Aggregate: groupBy=[[lineitem.l_suppkey]], aggr=[[sum(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]]

datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ logical_plan
6565
07)------------Projection: customer.c_phone, customer.c_acctbal
6666
08)--------------LeftAnti Join: customer.c_custkey = __correlated_sq_1.o_custkey
6767
09)----------------Filter: substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")])
68-
10)------------------TableScan: customer projection=[c_custkey, c_phone, c_acctbal], partial_filters=[substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")])]
68+
10)------------------TableScan: customer projection=[c_custkey, c_phone, c_acctbal], partial_filters=[substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")]), Boolean(true)]
6969
11)----------------SubqueryAlias: __correlated_sq_1
7070
12)------------------TableScan: orders projection=[o_custkey]
7171
13)------------SubqueryAlias: __scalar_sq_2

0 commit comments

Comments
 (0)