Skip to content

Commit f514e12

Browse files
Preserve the order of right table in NestedLoopJoinExec (#12504)
* Maintain right child's order in NestedLoopJoinExec
* Format
* Refactor monotonicity check
* Update sqllogictest according to new behavior
* Check output ordering properties
* Parameterize only batch sizes for left and right tables
* Document maintains_input_order
1 parent 9781aef commit f514e12

File tree

3 files changed

+254
-20
lines changed

3 files changed

+254
-20
lines changed

datafusion/physical-plan/src/joins/nested_loop_join.rs

Lines changed: 247 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ impl NestedLoopJoinExec {
221221
right.equivalence_properties().clone(),
222222
&join_type,
223223
schema,
224-
&[false, false],
224+
&Self::maintains_input_order(join_type),
225225
None,
226226
// No on columns in nested loop join
227227
&[],
@@ -238,6 +238,31 @@ impl NestedLoopJoinExec {
238238

239239
PlanProperties::new(eq_properties, output_partitioning, mode)
240240
}
241+
242+
/// Returns a vector indicating whether the left and right inputs maintain their order.
243+
/// The first element corresponds to the left input, and the second to the right.
244+
///
245+
/// The left (build-side) input's order may change, but the right (probe-side) input's
246+
/// order is maintained for INNER, RIGHT, RIGHT ANTI, and RIGHT SEMI joins.
247+
///
248+
/// Maintaining the right input's order helps optimize the nodes down the pipeline
249+
/// (See [`ExecutionPlan::maintains_input_order`]).
250+
///
251+
/// This is a separate method because it is also called when computing properties, before
252+
/// a [`NestedLoopJoinExec`] is created. It also takes [`JoinType`] as an argument, as
253+
/// opposed to `Self`, for the same reason.
254+
fn maintains_input_order(join_type: JoinType) -> Vec<bool> {
255+
vec![
256+
false,
257+
matches!(
258+
join_type,
259+
JoinType::Inner
260+
| JoinType::Right
261+
| JoinType::RightAnti
262+
| JoinType::RightSemi
263+
),
264+
]
265+
}
241266
}
242267

243268
impl DisplayAs for NestedLoopJoinExec {
@@ -278,6 +303,10 @@ impl ExecutionPlan for NestedLoopJoinExec {
278303
]
279304
}
280305

306+
fn maintains_input_order(&self) -> Vec<bool> {
307+
Self::maintains_input_order(self.join_type)
308+
}
309+
281310
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
282311
vec![&self.left, &self.right]
283312
}
@@ -430,17 +459,17 @@ struct NestedLoopJoinStream {
430459
}
431460

432461
fn build_join_indices(
433-
left_row_index: usize,
434-
right_batch: &RecordBatch,
462+
right_row_index: usize,
435463
left_batch: &RecordBatch,
464+
right_batch: &RecordBatch,
436465
filter: Option<&JoinFilter>,
437466
) -> Result<(UInt64Array, UInt32Array)> {
438-
// left indices: [left_index, left_index, ...., left_index]
439-
// right indices: [0, 1, 2, 3, 4,....,right_row_count]
467+
// left indices: [0, 1, 2, 3, 4, ..., left_row_count - 1]
468+
// right indices: [right_index, right_index, ..., right_index]
440469

441-
let right_row_count = right_batch.num_rows();
442-
let left_indices = UInt64Array::from(vec![left_row_index as u64; right_row_count]);
443-
let right_indices = UInt32Array::from_iter_values(0..(right_row_count as u32));
470+
let left_row_count = left_batch.num_rows();
471+
let left_indices = UInt64Array::from_iter_values(0..(left_row_count as u64));
472+
let right_indices = UInt32Array::from(vec![right_row_index as u32; left_row_count]);
444473
// in the nested loop join, the filter can contain non-equal and equal condition.
445474
if let Some(filter) = filter {
446475
apply_join_filter_to_indices(
@@ -567,9 +596,9 @@ fn join_left_and_right_batch(
567596
schema: &Schema,
568597
visited_left_side: &SharedBitmapBuilder,
569598
) -> Result<RecordBatch> {
570-
let indices = (0..left_batch.num_rows())
571-
.map(|left_row_index| {
572-
build_join_indices(left_row_index, right_batch, left_batch, filter)
599+
let indices = (0..right_batch.num_rows())
600+
.map(|right_row_index| {
601+
build_join_indices(right_row_index, left_batch, right_batch, filter)
573602
})
574603
.collect::<Result<Vec<(UInt64Array, UInt32Array)>>>()
575604
.map_err(|e| {
@@ -601,7 +630,7 @@ fn join_left_and_right_batch(
601630
right_side,
602631
0..right_batch.num_rows(),
603632
join_type,
604-
false,
633+
true,
605634
);
606635

607636
build_batch_from_indices(
@@ -649,27 +678,68 @@ mod tests {
649678
};
650679

651680
use arrow::datatypes::{DataType, Field};
681+
use arrow_array::Int32Array;
682+
use arrow_schema::SortOptions;
652683
use datafusion_common::{assert_batches_sorted_eq, assert_contains, ScalarValue};
653684
use datafusion_execution::runtime_env::RuntimeEnvBuilder;
654685
use datafusion_expr::Operator;
655686
use datafusion_physical_expr::expressions::{BinaryExpr, Literal};
656687
use datafusion_physical_expr::{Partitioning, PhysicalExpr};
688+
use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
689+
690+
use rstest::rstest;
657691

658692
fn build_table(
659693
a: (&str, &Vec<i32>),
660694
b: (&str, &Vec<i32>),
661695
c: (&str, &Vec<i32>),
696+
batch_size: Option<usize>,
697+
sorted_column_names: Vec<&str>,
662698
) -> Arc<dyn ExecutionPlan> {
663699
let batch = build_table_i32(a, b, c);
664700
let schema = batch.schema();
665-
Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap())
701+
702+
let batches = if let Some(batch_size) = batch_size {
703+
let num_batches = batch.num_rows().div_ceil(batch_size);
704+
(0..num_batches)
705+
.map(|i| {
706+
let start = i * batch_size;
707+
let remaining_rows = batch.num_rows() - start;
708+
batch.slice(start, batch_size.min(remaining_rows))
709+
})
710+
.collect::<Vec<_>>()
711+
} else {
712+
vec![batch]
713+
};
714+
715+
let mut exec =
716+
MemoryExec::try_new(&[batches], Arc::clone(&schema), None).unwrap();
717+
if !sorted_column_names.is_empty() {
718+
let mut sort_info = Vec::new();
719+
for name in sorted_column_names {
720+
let index = schema.index_of(name).unwrap();
721+
let sort_expr = PhysicalSortExpr {
722+
expr: Arc::new(Column::new(name, index)),
723+
options: SortOptions {
724+
descending: false,
725+
nulls_first: false,
726+
},
727+
};
728+
sort_info.push(sort_expr);
729+
}
730+
exec = exec.with_sort_information(vec![sort_info]);
731+
}
732+
733+
Arc::new(exec)
666734
}
667735

668736
fn build_left_table() -> Arc<dyn ExecutionPlan> {
669737
build_table(
670738
("a1", &vec![5, 9, 11]),
671739
("b1", &vec![5, 8, 8]),
672740
("c1", &vec![50, 90, 110]),
741+
None,
742+
Vec::new(),
673743
)
674744
}
675745

@@ -678,6 +748,8 @@ mod tests {
678748
("a2", &vec![12, 2, 10]),
679749
("b2", &vec![10, 2, 10]),
680750
("c2", &vec![40, 80, 100]),
751+
None,
752+
Vec::new(),
681753
)
682754
}
683755

@@ -1005,11 +1077,15 @@ mod tests {
10051077
("a1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
10061078
("b1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
10071079
("c1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
1080+
None,
1081+
Vec::new(),
10081082
);
10091083
let right = build_table(
10101084
("a2", &vec![10, 11]),
10111085
("b2", &vec![12, 13]),
10121086
("c2", &vec![14, 15]),
1087+
None,
1088+
Vec::new(),
10131089
);
10141090
let filter = prepare_join_filter();
10151091

@@ -1050,6 +1126,164 @@ mod tests {
10501126
Ok(())
10511127
}
10521128

1129+
fn prepare_mod_join_filter() -> JoinFilter {
1130+
let column_indices = vec![
1131+
ColumnIndex {
1132+
index: 1,
1133+
side: JoinSide::Left,
1134+
},
1135+
ColumnIndex {
1136+
index: 1,
1137+
side: JoinSide::Right,
1138+
},
1139+
];
1140+
let intermediate_schema = Schema::new(vec![
1141+
Field::new("x", DataType::Int32, true),
1142+
Field::new("x", DataType::Int32, true),
1143+
]);
1144+
1145+
// left.b1 % 3
1146+
let left_mod = Arc::new(BinaryExpr::new(
1147+
Arc::new(Column::new("x", 0)),
1148+
Operator::Modulo,
1149+
Arc::new(Literal::new(ScalarValue::Int32(Some(3)))),
1150+
)) as Arc<dyn PhysicalExpr>;
1151+
// left.b1 % 3 != 0
1152+
let left_filter = Arc::new(BinaryExpr::new(
1153+
left_mod,
1154+
Operator::NotEq,
1155+
Arc::new(Literal::new(ScalarValue::Int32(Some(0)))),
1156+
)) as Arc<dyn PhysicalExpr>;
1157+
1158+
// right.b2 % 5
1159+
let right_mod = Arc::new(BinaryExpr::new(
1160+
Arc::new(Column::new("x", 1)),
1161+
Operator::Modulo,
1162+
Arc::new(Literal::new(ScalarValue::Int32(Some(5)))),
1163+
)) as Arc<dyn PhysicalExpr>;
1164+
// right.b2 % 5 != 0
1165+
let right_filter = Arc::new(BinaryExpr::new(
1166+
right_mod,
1167+
Operator::NotEq,
1168+
Arc::new(Literal::new(ScalarValue::Int32(Some(0)))),
1169+
)) as Arc<dyn PhysicalExpr>;
1170+
// filter = left.b1 % 3 != 0 and right.b2 % 5 != 0
1171+
let filter_expression =
1172+
Arc::new(BinaryExpr::new(left_filter, Operator::And, right_filter))
1173+
as Arc<dyn PhysicalExpr>;
1174+
1175+
JoinFilter::new(filter_expression, column_indices, intermediate_schema)
1176+
}
1177+
1178+
/// Returns `num_columns` identical columns, each holding the values
/// `1..=num_rows` in ascending order.
fn generate_columns(num_columns: usize, num_rows: usize) -> Vec<Vec<i32>> {
    let ascending: Vec<i32> = (1..=num_rows).map(|value| value as i32).collect();
    std::iter::repeat(ascending).take(num_columns).collect()
}
1182+
1183+
#[rstest]
1184+
#[tokio::test]
1185+
async fn join_maintains_right_order(
1186+
#[values(
1187+
JoinType::Inner,
1188+
JoinType::Right,
1189+
JoinType::RightAnti,
1190+
JoinType::RightSemi
1191+
)]
1192+
join_type: JoinType,
1193+
#[values(1, 100, 1000)] left_batch_size: usize,
1194+
#[values(1, 100, 1000)] right_batch_size: usize,
1195+
) -> Result<()> {
1196+
let left_columns = generate_columns(3, 1000);
1197+
let left = build_table(
1198+
("a1", &left_columns[0]),
1199+
("b1", &left_columns[1]),
1200+
("c1", &left_columns[2]),
1201+
Some(left_batch_size),
1202+
Vec::new(),
1203+
);
1204+
1205+
let right_columns = generate_columns(3, 1000);
1206+
let right = build_table(
1207+
("a2", &right_columns[0]),
1208+
("b2", &right_columns[1]),
1209+
("c2", &right_columns[2]),
1210+
Some(right_batch_size),
1211+
vec!["a2", "b2", "c2"],
1212+
);
1213+
1214+
let filter = prepare_mod_join_filter();
1215+
1216+
let nested_loop_join = Arc::new(NestedLoopJoinExec::try_new(
1217+
left,
1218+
Arc::clone(&right),
1219+
Some(filter),
1220+
&join_type,
1221+
)?) as Arc<dyn ExecutionPlan>;
1222+
assert_eq!(nested_loop_join.maintains_input_order(), vec![false, true]);
1223+
1224+
let right_column_indices = match join_type {
1225+
JoinType::Inner | JoinType::Right => vec![3, 4, 5],
1226+
JoinType::RightAnti | JoinType::RightSemi => vec![0, 1, 2],
1227+
_ => unreachable!(),
1228+
};
1229+
1230+
let right_ordering = right.output_ordering().unwrap();
1231+
let join_ordering = nested_loop_join.output_ordering().unwrap();
1232+
for (right, join) in right_ordering.iter().zip(join_ordering.iter()) {
1233+
let right_column = right.expr.as_any().downcast_ref::<Column>().unwrap();
1234+
let join_column = join.expr.as_any().downcast_ref::<Column>().unwrap();
1235+
assert_eq!(join_column.name(), join_column.name());
1236+
assert_eq!(
1237+
right_column_indices[right_column.index()],
1238+
join_column.index()
1239+
);
1240+
assert_eq!(right.options, join.options);
1241+
}
1242+
1243+
let batches = nested_loop_join
1244+
.execute(0, Arc::new(TaskContext::default()))?
1245+
.try_collect::<Vec<_>>()
1246+
.await?;
1247+
1248+
// Make sure that the order of the right side is maintained
1249+
let mut prev_values = [i32::MIN, i32::MIN, i32::MIN];
1250+
1251+
for (batch_index, batch) in batches.iter().enumerate() {
1252+
let columns: Vec<_> = right_column_indices
1253+
.iter()
1254+
.map(|&i| {
1255+
batch
1256+
.column(i)
1257+
.as_any()
1258+
.downcast_ref::<Int32Array>()
1259+
.unwrap()
1260+
})
1261+
.collect();
1262+
1263+
for row in 0..batch.num_rows() {
1264+
let current_values = [
1265+
columns[0].value(row),
1266+
columns[1].value(row),
1267+
columns[2].value(row),
1268+
];
1269+
assert!(
1270+
current_values
1271+
.into_iter()
1272+
.zip(prev_values)
1273+
.all(|(current, prev)| current >= prev),
1274+
"batch_index: {} row: {} current: {:?}, prev: {:?}",
1275+
batch_index,
1276+
row,
1277+
current_values,
1278+
prev_values
1279+
);
1280+
prev_values = current_values;
1281+
}
1282+
}
1283+
1284+
Ok(())
1285+
}
1286+
10531287
/// Returns the column names on the schema
10541288
fn columns(schema: &Schema) -> Vec<String> {
10551289
schema.fields().iter().map(|f| f.name().clone()).collect()

datafusion/sqllogictest/test_files/join.slt

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -838,10 +838,10 @@ LEFT JOIN department AS d
838838
ON (e.name = 'Alice' OR e.name = 'Bob');
839839
----
840840
1 Alice HR
841-
2 Bob HR
842841
1 Alice Engineering
843-
2 Bob Engineering
844842
1 Alice Sales
843+
2 Bob HR
844+
2 Bob Engineering
845845
2 Bob Sales
846846
3 Carol NULL
847847

@@ -853,10 +853,10 @@ RIGHT JOIN employees AS e
853853
ON (e.name = 'Alice' OR e.name = 'Bob');
854854
----
855855
1 Alice HR
856-
2 Bob HR
857856
1 Alice Engineering
858-
2 Bob Engineering
859857
1 Alice Sales
858+
2 Bob HR
859+
2 Bob Engineering
860860
2 Bob Sales
861861
3 Carol NULL
862862

@@ -868,10 +868,10 @@ FULL JOIN employees AS e
868868
ON (e.name = 'Alice' OR e.name = 'Bob');
869869
----
870870
1 Alice HR
871-
2 Bob HR
872871
1 Alice Engineering
873-
2 Bob Engineering
874872
1 Alice Sales
873+
2 Bob HR
874+
2 Bob Engineering
875875
2 Bob Sales
876876
3 Carol NULL
877877

0 commit comments

Comments
 (0)