Skip to content

Commit 171fa02

Browse files
korowa and alamb
authored and committed
feat: support input reordering for NestedLoopJoinExec (apache#9676)
* support input reordering for NestedLoopJoinExec
* renamed variables and struct fields
* fixed nl join filter expression in tests
* Update datafusion/physical-plan/src/joins/nested_loop_join.rs

  Co-authored-by: Andrew Lamb <[email protected]>
* typo fixed

---------

Co-authored-by: Andrew Lamb <[email protected]>
1 parent c5aec9f commit 171fa02

File tree

8 files changed

+547
-374
lines changed

8 files changed

+547
-374
lines changed

datafusion/core/src/physical_optimizer/join_selection.rs

Lines changed: 206 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ use crate::error::Result;
3030
use crate::physical_optimizer::PhysicalOptimizerRule;
3131
use crate::physical_plan::joins::utils::{ColumnIndex, JoinFilter};
3232
use crate::physical_plan::joins::{
33-
CrossJoinExec, HashJoinExec, PartitionMode, StreamJoinPartitionMode,
34-
SymmetricHashJoinExec,
33+
CrossJoinExec, HashJoinExec, NestedLoopJoinExec, PartitionMode,
34+
StreamJoinPartitionMode, SymmetricHashJoinExec,
3535
};
3636
use crate::physical_plan::projection::ProjectionExec;
3737
use crate::physical_plan::{ExecutionPlan, ExecutionPlanProperties};
@@ -199,6 +199,38 @@ fn swap_hash_join(
199199
}
200200
}
201201

202+
/// Swaps inputs of `NestedLoopJoinExec` and wraps it into `ProjectionExec` if required
203+
fn swap_nl_join(join: &NestedLoopJoinExec) -> Result<Arc<dyn ExecutionPlan>> {
204+
let new_filter = swap_join_filter(join.filter());
205+
let new_join_type = &swap_join_type(*join.join_type());
206+
207+
let new_join = NestedLoopJoinExec::try_new(
208+
Arc::clone(join.right()),
209+
Arc::clone(join.left()),
210+
new_filter,
211+
new_join_type,
212+
)?;
213+
214+
// For Semi/Anti joins, the swapped join produces the same output schema,
215+
// so there is no need to wrap it in an additional projection
216+
let plan: Arc<dyn ExecutionPlan> = if matches!(
217+
join.join_type(),
218+
JoinType::LeftSemi
219+
| JoinType::RightSemi
220+
| JoinType::LeftAnti
221+
| JoinType::RightAnti
222+
) {
223+
Arc::new(new_join)
224+
} else {
225+
let projection =
226+
swap_reverting_projection(&join.left().schema(), &join.right().schema());
227+
228+
Arc::new(ProjectionExec::try_new(projection, Arc::new(new_join))?)
229+
};
230+
231+
Ok(plan)
232+
}
233+
202234
/// When the order of the join is changed by the optimizer, the columns in
203235
/// the output should not be impacted. This function creates the expressions
204236
/// that will allow to swap back the values from the original left as the first
@@ -438,6 +470,14 @@ fn statistical_join_selection_subrule(
438470
} else {
439471
None
440472
}
473+
} else if let Some(nl_join) = plan.as_any().downcast_ref::<NestedLoopJoinExec>() {
474+
let left = nl_join.left();
475+
let right = nl_join.right();
476+
if should_swap_join_order(&**left, &**right)? {
477+
swap_nl_join(nl_join).map(Some)?
478+
} else {
479+
None
480+
}
441481
} else {
442482
None
443483
};
@@ -674,9 +714,12 @@ mod tests_statistical {
674714

675715
use arrow::datatypes::{DataType, Field, Schema};
676716
use datafusion_common::{stats::Precision, JoinType, ScalarValue};
677-
use datafusion_physical_expr::expressions::Column;
717+
use datafusion_expr::Operator;
718+
use datafusion_physical_expr::expressions::{BinaryExpr, Column};
678719
use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef};
679720

721+
use rstest::rstest;
722+
680723
/// Return statistics for empty table
681724
fn empty_statistics() -> Statistics {
682725
Statistics {
@@ -762,6 +805,35 @@ mod tests_statistical {
762805
}]
763806
}
764807

808+
/// Create join filter for NLJoinExec with expression `big_col > small_col`
809+
/// where both columns are 0-indexed and come from left and right inputs respectively
810+
fn nl_join_filter() -> Option<JoinFilter> {
811+
let column_indices = vec![
812+
ColumnIndex {
813+
index: 0,
814+
side: JoinSide::Left,
815+
},
816+
ColumnIndex {
817+
index: 0,
818+
side: JoinSide::Right,
819+
},
820+
];
821+
let intermediate_schema = Schema::new(vec![
822+
Field::new("big_col", DataType::Int32, false),
823+
Field::new("small_col", DataType::Int32, false),
824+
]);
825+
let expression = Arc::new(BinaryExpr::new(
826+
Arc::new(Column::new_with_schema("big_col", &intermediate_schema).unwrap()),
827+
Operator::Gt,
828+
Arc::new(Column::new_with_schema("small_col", &intermediate_schema).unwrap()),
829+
)) as _;
830+
Some(JoinFilter::new(
831+
expression,
832+
column_indices,
833+
intermediate_schema,
834+
))
835+
}
836+
765837
/// Returns three plans with statistics of (min, max, distinct_count)
766838
/// * big 100K rows @ (0, 50k, 50k)
767839
/// * medium 10K rows @ (1k, 5k, 1k)
@@ -1114,6 +1186,137 @@ mod tests_statistical {
11141186
crosscheck_plans(join).unwrap();
11151187
}
11161188

1189+
#[rstest(
1190+
join_type,
1191+
case::inner(JoinType::Inner),
1192+
case::left(JoinType::Left),
1193+
case::right(JoinType::Right),
1194+
case::full(JoinType::Full)
1195+
)]
1196+
#[tokio::test]
1197+
async fn test_nl_join_with_swap(join_type: JoinType) {
1198+
let (big, small) = create_big_and_small();
1199+
1200+
let join = Arc::new(
1201+
NestedLoopJoinExec::try_new(
1202+
Arc::clone(&big),
1203+
Arc::clone(&small),
1204+
nl_join_filter(),
1205+
&join_type,
1206+
)
1207+
.unwrap(),
1208+
);
1209+
1210+
let optimized_join = JoinSelection::new()
1211+
.optimize(join.clone(), &ConfigOptions::new())
1212+
.unwrap();
1213+
1214+
let swapping_projection = optimized_join
1215+
.as_any()
1216+
.downcast_ref::<ProjectionExec>()
1217+
.expect("A proj is required to swap columns back to their original order");
1218+
1219+
assert_eq!(swapping_projection.expr().len(), 2);
1220+
let (col, name) = &swapping_projection.expr()[0];
1221+
assert_eq!(name, "big_col");
1222+
assert_col_expr(col, "big_col", 1);
1223+
let (col, name) = &swapping_projection.expr()[1];
1224+
assert_eq!(name, "small_col");
1225+
assert_col_expr(col, "small_col", 0);
1226+
1227+
let swapped_join = swapping_projection
1228+
.input()
1229+
.as_any()
1230+
.downcast_ref::<NestedLoopJoinExec>()
1231+
.expect("The type of the plan should not be changed");
1232+
1233+
// Assert that the join side of big_col is swapped in the filter expression
1234+
let swapped_filter = swapped_join.filter().unwrap();
1235+
let swapped_big_col_idx = swapped_filter.schema().index_of("big_col").unwrap();
1236+
let swapped_big_col_side = swapped_filter
1237+
.column_indices()
1238+
.get(swapped_big_col_idx)
1239+
.unwrap()
1240+
.side;
1241+
assert_eq!(
1242+
swapped_big_col_side,
1243+
JoinSide::Right,
1244+
"Filter column side should be swapped"
1245+
);
1246+
1247+
assert_eq!(
1248+
swapped_join.left().statistics().unwrap().total_byte_size,
1249+
Precision::Inexact(8192)
1250+
);
1251+
assert_eq!(
1252+
swapped_join.right().statistics().unwrap().total_byte_size,
1253+
Precision::Inexact(2097152)
1254+
);
1255+
crosscheck_plans(join.clone()).unwrap();
1256+
}
1257+
1258+
#[rstest(
1259+
join_type,
1260+
case::left_semi(JoinType::LeftSemi),
1261+
case::left_anti(JoinType::LeftAnti),
1262+
case::right_semi(JoinType::RightSemi),
1263+
case::right_anti(JoinType::RightAnti)
1264+
)]
1265+
#[tokio::test]
1266+
async fn test_nl_join_with_swap_no_proj(join_type: JoinType) {
1267+
let (big, small) = create_big_and_small();
1268+
1269+
let join = Arc::new(
1270+
NestedLoopJoinExec::try_new(
1271+
Arc::clone(&big),
1272+
Arc::clone(&small),
1273+
nl_join_filter(),
1274+
&join_type,
1275+
)
1276+
.unwrap(),
1277+
);
1278+
1279+
let optimized_join = JoinSelection::new()
1280+
.optimize(join.clone(), &ConfigOptions::new())
1281+
.unwrap();
1282+
1283+
let swapped_join = optimized_join
1284+
.as_any()
1285+
.downcast_ref::<NestedLoopJoinExec>()
1286+
.expect("The type of the plan should not be changed");
1287+
1288+
// Assert before/after schemas are equal
1289+
assert_eq!(
1290+
join.schema(),
1291+
swapped_join.schema(),
1292+
"Join schema should not be modified while optimization"
1293+
);
1294+
1295+
// Assert that the join side of big_col is swapped in the filter expression
1296+
let swapped_filter = swapped_join.filter().unwrap();
1297+
let swapped_big_col_idx = swapped_filter.schema().index_of("big_col").unwrap();
1298+
let swapped_big_col_side = swapped_filter
1299+
.column_indices()
1300+
.get(swapped_big_col_idx)
1301+
.unwrap()
1302+
.side;
1303+
assert_eq!(
1304+
swapped_big_col_side,
1305+
JoinSide::Right,
1306+
"Filter column side should be swapped"
1307+
);
1308+
1309+
assert_eq!(
1310+
swapped_join.left().statistics().unwrap().total_byte_size,
1311+
Precision::Inexact(8192)
1312+
);
1313+
assert_eq!(
1314+
swapped_join.right().statistics().unwrap().total_byte_size,
1315+
Precision::Inexact(2097152)
1316+
);
1317+
crosscheck_plans(join.clone()).unwrap();
1318+
}
1319+
11171320
#[tokio::test]
11181321
async fn test_swap_reverting_projection() {
11191322
let left_schema = Schema::new(vec![

datafusion/core/tests/fuzz_cases/join_fuzz.rs

Lines changed: 84 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,19 @@ use arrow::array::{ArrayRef, Int32Array};
2121
use arrow::compute::SortOptions;
2222
use arrow::record_batch::RecordBatch;
2323
use arrow::util::pretty::pretty_format_batches;
24+
use arrow_schema::Schema;
2425
use rand::Rng;
2526

27+
use datafusion::common::JoinSide;
28+
use datafusion::logical_expr::{JoinType, Operator};
29+
use datafusion::physical_expr::expressions::BinaryExpr;
2630
use datafusion::physical_plan::collect;
2731
use datafusion::physical_plan::expressions::Column;
28-
use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode, SortMergeJoinExec};
32+
use datafusion::physical_plan::joins::utils::{ColumnIndex, JoinFilter};
33+
use datafusion::physical_plan::joins::{
34+
HashJoinExec, NestedLoopJoinExec, PartitionMode, SortMergeJoinExec,
35+
};
2936
use datafusion::physical_plan::memory::MemoryExec;
30-
use datafusion_expr::JoinType;
3137

3238
use datafusion::prelude::{SessionConfig, SessionContext};
3339
use test_utils::stagger_batch_with_seed;
@@ -73,7 +79,7 @@ async fn test_full_join_1k() {
7379
}
7480

7581
#[tokio::test]
76-
async fn test_semi_join_1k() {
82+
async fn test_semi_join_10k() {
7783
run_join_test(
7884
make_staggered_batches(10000),
7985
make_staggered_batches(10000),
@@ -83,7 +89,7 @@ async fn test_semi_join_1k() {
8389
}
8490

8591
#[tokio::test]
86-
async fn test_anti_join_1k() {
92+
async fn test_anti_join_10k() {
8793
run_join_test(
8894
make_staggered_batches(10000),
8995
make_staggered_batches(10000),
@@ -118,6 +124,46 @@ async fn run_join_test(
118124
),
119125
];
120126

127+
// Nested loop join uses a filter to join records
128+
let column_indices = vec![
129+
ColumnIndex {
130+
index: 0,
131+
side: JoinSide::Left,
132+
},
133+
ColumnIndex {
134+
index: 1,
135+
side: JoinSide::Left,
136+
},
137+
ColumnIndex {
138+
index: 0,
139+
side: JoinSide::Right,
140+
},
141+
ColumnIndex {
142+
index: 1,
143+
side: JoinSide::Right,
144+
},
145+
];
146+
let intermediate_schema = Schema::new(vec![
147+
schema1.field_with_name("a").unwrap().to_owned(),
148+
schema1.field_with_name("b").unwrap().to_owned(),
149+
schema2.field_with_name("a").unwrap().to_owned(),
150+
schema2.field_with_name("b").unwrap().to_owned(),
151+
]);
152+
153+
let equal_a = Arc::new(BinaryExpr::new(
154+
Arc::new(Column::new("a", 0)),
155+
Operator::Eq,
156+
Arc::new(Column::new("a", 2)),
157+
)) as _;
158+
let equal_b = Arc::new(BinaryExpr::new(
159+
Arc::new(Column::new("b", 1)),
160+
Operator::Eq,
161+
Arc::new(Column::new("b", 3)),
162+
)) as _;
163+
let expression = Arc::new(BinaryExpr::new(equal_a, Operator::And, equal_b)) as _;
164+
165+
let on_filter = JoinFilter::new(expression, column_indices, intermediate_schema);
166+
121167
// sort-merge join
122168
let left = Arc::new(
123169
MemoryExec::try_new(&[input1.clone()], schema1.clone(), None).unwrap(),
@@ -161,22 +207,55 @@ async fn run_join_test(
161207
);
162208
let hj_collected = collect(hj, task_ctx.clone()).await.unwrap();
163209

210+
// nested loop join
211+
let left = Arc::new(
212+
MemoryExec::try_new(&[input1.clone()], schema1.clone(), None).unwrap(),
213+
);
214+
let right = Arc::new(
215+
MemoryExec::try_new(&[input2.clone()], schema2.clone(), None).unwrap(),
216+
);
217+
let nlj = Arc::new(
218+
NestedLoopJoinExec::try_new(left, right, Some(on_filter), &join_type)
219+
.unwrap(),
220+
);
221+
let nlj_collected = collect(nlj, task_ctx.clone()).await.unwrap();
222+
164223
// compare
165224
let smj_formatted = pretty_format_batches(&smj_collected).unwrap().to_string();
166225
let hj_formatted = pretty_format_batches(&hj_collected).unwrap().to_string();
226+
let nlj_formatted = pretty_format_batches(&nlj_collected).unwrap().to_string();
167227

168228
let mut smj_formatted_sorted: Vec<&str> = smj_formatted.trim().lines().collect();
169229
smj_formatted_sorted.sort_unstable();
170230

171231
let mut hj_formatted_sorted: Vec<&str> = hj_formatted.trim().lines().collect();
172232
hj_formatted_sorted.sort_unstable();
173233

234+
let mut nlj_formatted_sorted: Vec<&str> = nlj_formatted.trim().lines().collect();
235+
nlj_formatted_sorted.sort_unstable();
236+
174237
for (i, (smj_line, hj_line)) in smj_formatted_sorted
175238
.iter()
176239
.zip(&hj_formatted_sorted)
177240
.enumerate()
178241
{
179-
assert_eq!((i, smj_line), (i, hj_line));
242+
assert_eq!(
243+
(i, smj_line),
244+
(i, hj_line),
245+
"SortMergeJoinExec and HashJoinExec produced different results"
246+
);
247+
}
248+
249+
for (i, (nlj_line, hj_line)) in nlj_formatted_sorted
250+
.iter()
251+
.zip(&hj_formatted_sorted)
252+
.enumerate()
253+
{
254+
assert_eq!(
255+
(i, nlj_line),
256+
(i, hj_line),
257+
"NestedLoopJoinExec and HashJoinExec produced different results"
258+
);
180259
}
181260
}
182261
}

0 commit comments

Comments
 (0)