@@ -221,7 +221,7 @@ impl NestedLoopJoinExec {
221
221
right. equivalence_properties ( ) . clone ( ) ,
222
222
& join_type,
223
223
schema,
224
- & [ false , false ] ,
224
+ & Self :: maintains_input_order ( join_type ) ,
225
225
None ,
226
226
// No on columns in nested loop join
227
227
& [ ] ,
@@ -238,6 +238,31 @@ impl NestedLoopJoinExec {
238
238
239
239
PlanProperties :: new ( eq_properties, output_partitioning, mode)
240
240
}
241
+
242
+ /// Returns a vector indicating whether the left and right inputs maintain their order.
243
+ /// The first element corresponds to the left input, and the second to the right.
244
+ ///
245
+ /// The left (build-side) input's order may change, but the right (probe-side) input's
246
+ /// order is maintained for INNER, RIGHT, RIGHT ANTI, and RIGHT SEMI joins.
247
+ ///
248
+ /// Maintaining the right input's order helps optimize the nodes down the pipeline
249
+ /// (See [`ExecutionPlan::maintains_input_order`]).
250
+ ///
251
+ /// This is a separate method because it is also called when computing properties, before
252
+ /// a [`NestedLoopJoinExec`] is created. It also takes [`JoinType`] as an argument, as
253
+ /// opposed to `Self`, for the same reason.
254
+ fn maintains_input_order ( join_type : JoinType ) -> Vec < bool > {
255
+ vec ! [
256
+ false ,
257
+ matches!(
258
+ join_type,
259
+ JoinType :: Inner
260
+ | JoinType :: Right
261
+ | JoinType :: RightAnti
262
+ | JoinType :: RightSemi
263
+ ) ,
264
+ ]
265
+ }
241
266
}
242
267
243
268
impl DisplayAs for NestedLoopJoinExec {
@@ -278,6 +303,10 @@ impl ExecutionPlan for NestedLoopJoinExec {
278
303
]
279
304
}
280
305
306
+ fn maintains_input_order ( & self ) -> Vec < bool > {
307
+ Self :: maintains_input_order ( self . join_type )
308
+ }
309
+
281
310
fn children ( & self ) -> Vec < & Arc < dyn ExecutionPlan > > {
282
311
vec ! [ & self . left, & self . right]
283
312
}
@@ -430,17 +459,17 @@ struct NestedLoopJoinStream {
430
459
}
431
460
432
461
fn build_join_indices (
433
- left_row_index : usize ,
434
- right_batch : & RecordBatch ,
462
+ right_row_index : usize ,
435
463
left_batch : & RecordBatch ,
464
+ right_batch : & RecordBatch ,
436
465
filter : Option < & JoinFilter > ,
437
466
) -> Result < ( UInt64Array , UInt32Array ) > {
438
- // left indices: [left_index, left_index, ...., left_index ]
439
- // right indices: [0, 1, 2, 3, 4, ....,right_row_count ]
467
+ // left indices: [0, 1, 2, 3, 4, ..., left_row_count ]
468
+ // right indices: [right_index, right_index, ..., right_index ]
440
469
441
- let right_row_count = right_batch . num_rows ( ) ;
442
- let left_indices = UInt64Array :: from ( vec ! [ left_row_index as u64 ; right_row_count ] ) ;
443
- let right_indices = UInt32Array :: from_iter_values ( 0 .. ( right_row_count as u32 ) ) ;
470
+ let left_row_count = left_batch . num_rows ( ) ;
471
+ let left_indices = UInt64Array :: from_iter_values ( 0 .. ( left_row_count as u64 ) ) ;
472
+ let right_indices = UInt32Array :: from ( vec ! [ right_row_index as u32 ; left_row_count ] ) ;
444
473
// in the nested loop join, the filter can contain non-equal and equal condition.
445
474
if let Some ( filter) = filter {
446
475
apply_join_filter_to_indices (
@@ -567,9 +596,9 @@ fn join_left_and_right_batch(
567
596
schema : & Schema ,
568
597
visited_left_side : & SharedBitmapBuilder ,
569
598
) -> Result < RecordBatch > {
570
- let indices = ( 0 ..left_batch . num_rows ( ) )
571
- . map ( |left_row_index | {
572
- build_join_indices ( left_row_index , right_batch , left_batch , filter)
599
+ let indices = ( 0 ..right_batch . num_rows ( ) )
600
+ . map ( |right_row_index | {
601
+ build_join_indices ( right_row_index , left_batch , right_batch , filter)
573
602
} )
574
603
. collect :: < Result < Vec < ( UInt64Array , UInt32Array ) > > > ( )
575
604
. map_err ( |e| {
@@ -601,7 +630,7 @@ fn join_left_and_right_batch(
601
630
right_side,
602
631
0 ..right_batch. num_rows ( ) ,
603
632
join_type,
604
- false ,
633
+ true ,
605
634
) ;
606
635
607
636
build_batch_from_indices (
@@ -649,27 +678,68 @@ mod tests {
649
678
} ;
650
679
651
680
use arrow:: datatypes:: { DataType , Field } ;
681
+ use arrow_array:: Int32Array ;
682
+ use arrow_schema:: SortOptions ;
652
683
use datafusion_common:: { assert_batches_sorted_eq, assert_contains, ScalarValue } ;
653
684
use datafusion_execution:: runtime_env:: RuntimeEnvBuilder ;
654
685
use datafusion_expr:: Operator ;
655
686
use datafusion_physical_expr:: expressions:: { BinaryExpr , Literal } ;
656
687
use datafusion_physical_expr:: { Partitioning , PhysicalExpr } ;
688
+ use datafusion_physical_expr_common:: sort_expr:: PhysicalSortExpr ;
689
+
690
+ use rstest:: rstest;
657
691
658
692
fn build_table (
659
693
a : ( & str , & Vec < i32 > ) ,
660
694
b : ( & str , & Vec < i32 > ) ,
661
695
c : ( & str , & Vec < i32 > ) ,
696
+ batch_size : Option < usize > ,
697
+ sorted_column_names : Vec < & str > ,
662
698
) -> Arc < dyn ExecutionPlan > {
663
699
let batch = build_table_i32 ( a, b, c) ;
664
700
let schema = batch. schema ( ) ;
665
- Arc :: new ( MemoryExec :: try_new ( & [ vec ! [ batch] ] , schema, None ) . unwrap ( ) )
701
+
702
+ let batches = if let Some ( batch_size) = batch_size {
703
+ let num_batches = batch. num_rows ( ) . div_ceil ( batch_size) ;
704
+ ( 0 ..num_batches)
705
+ . map ( |i| {
706
+ let start = i * batch_size;
707
+ let remaining_rows = batch. num_rows ( ) - start;
708
+ batch. slice ( start, batch_size. min ( remaining_rows) )
709
+ } )
710
+ . collect :: < Vec < _ > > ( )
711
+ } else {
712
+ vec ! [ batch]
713
+ } ;
714
+
715
+ let mut exec =
716
+ MemoryExec :: try_new ( & [ batches] , Arc :: clone ( & schema) , None ) . unwrap ( ) ;
717
+ if !sorted_column_names. is_empty ( ) {
718
+ let mut sort_info = Vec :: new ( ) ;
719
+ for name in sorted_column_names {
720
+ let index = schema. index_of ( name) . unwrap ( ) ;
721
+ let sort_expr = PhysicalSortExpr {
722
+ expr : Arc :: new ( Column :: new ( name, index) ) ,
723
+ options : SortOptions {
724
+ descending : false ,
725
+ nulls_first : false ,
726
+ } ,
727
+ } ;
728
+ sort_info. push ( sort_expr) ;
729
+ }
730
+ exec = exec. with_sort_information ( vec ! [ sort_info] ) ;
731
+ }
732
+
733
+ Arc :: new ( exec)
666
734
}
667
735
668
736
fn build_left_table ( ) -> Arc < dyn ExecutionPlan > {
669
737
build_table (
670
738
( "a1" , & vec ! [ 5 , 9 , 11 ] ) ,
671
739
( "b1" , & vec ! [ 5 , 8 , 8 ] ) ,
672
740
( "c1" , & vec ! [ 50 , 90 , 110 ] ) ,
741
+ None ,
742
+ Vec :: new ( ) ,
673
743
)
674
744
}
675
745
@@ -678,6 +748,8 @@ mod tests {
678
748
( "a2" , & vec ! [ 12 , 2 , 10 ] ) ,
679
749
( "b2" , & vec ! [ 10 , 2 , 10 ] ) ,
680
750
( "c2" , & vec ! [ 40 , 80 , 100 ] ) ,
751
+ None ,
752
+ Vec :: new ( ) ,
681
753
)
682
754
}
683
755
@@ -1005,11 +1077,15 @@ mod tests {
1005
1077
( "a1" , & vec ! [ 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 0 ] ) ,
1006
1078
( "b1" , & vec ! [ 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 0 ] ) ,
1007
1079
( "c1" , & vec ! [ 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 0 ] ) ,
1080
+ None ,
1081
+ Vec :: new ( ) ,
1008
1082
) ;
1009
1083
let right = build_table (
1010
1084
( "a2" , & vec ! [ 10 , 11 ] ) ,
1011
1085
( "b2" , & vec ! [ 12 , 13 ] ) ,
1012
1086
( "c2" , & vec ! [ 14 , 15 ] ) ,
1087
+ None ,
1088
+ Vec :: new ( ) ,
1013
1089
) ;
1014
1090
let filter = prepare_join_filter ( ) ;
1015
1091
@@ -1050,6 +1126,164 @@ mod tests {
1050
1126
Ok ( ( ) )
1051
1127
}
1052
1128
1129
+ fn prepare_mod_join_filter ( ) -> JoinFilter {
1130
+ let column_indices = vec ! [
1131
+ ColumnIndex {
1132
+ index: 1 ,
1133
+ side: JoinSide :: Left ,
1134
+ } ,
1135
+ ColumnIndex {
1136
+ index: 1 ,
1137
+ side: JoinSide :: Right ,
1138
+ } ,
1139
+ ] ;
1140
+ let intermediate_schema = Schema :: new ( vec ! [
1141
+ Field :: new( "x" , DataType :: Int32 , true ) ,
1142
+ Field :: new( "x" , DataType :: Int32 , true ) ,
1143
+ ] ) ;
1144
+
1145
+ // left.b1 % 3
1146
+ let left_mod = Arc :: new ( BinaryExpr :: new (
1147
+ Arc :: new ( Column :: new ( "x" , 0 ) ) ,
1148
+ Operator :: Modulo ,
1149
+ Arc :: new ( Literal :: new ( ScalarValue :: Int32 ( Some ( 3 ) ) ) ) ,
1150
+ ) ) as Arc < dyn PhysicalExpr > ;
1151
+ // left.b1 % 3 != 0
1152
+ let left_filter = Arc :: new ( BinaryExpr :: new (
1153
+ left_mod,
1154
+ Operator :: NotEq ,
1155
+ Arc :: new ( Literal :: new ( ScalarValue :: Int32 ( Some ( 0 ) ) ) ) ,
1156
+ ) ) as Arc < dyn PhysicalExpr > ;
1157
+
1158
+ // right.b2 % 5
1159
+ let right_mod = Arc :: new ( BinaryExpr :: new (
1160
+ Arc :: new ( Column :: new ( "x" , 1 ) ) ,
1161
+ Operator :: Modulo ,
1162
+ Arc :: new ( Literal :: new ( ScalarValue :: Int32 ( Some ( 5 ) ) ) ) ,
1163
+ ) ) as Arc < dyn PhysicalExpr > ;
1164
+ // right.b2 % 5 != 0
1165
+ let right_filter = Arc :: new ( BinaryExpr :: new (
1166
+ right_mod,
1167
+ Operator :: NotEq ,
1168
+ Arc :: new ( Literal :: new ( ScalarValue :: Int32 ( Some ( 0 ) ) ) ) ,
1169
+ ) ) as Arc < dyn PhysicalExpr > ;
1170
+ // filter = left.b1 % 3 != 0 and right.b2 % 5 != 0
1171
+ let filter_expression =
1172
+ Arc :: new ( BinaryExpr :: new ( left_filter, Operator :: And , right_filter) )
1173
+ as Arc < dyn PhysicalExpr > ;
1174
+
1175
+ JoinFilter :: new ( filter_expression, column_indices, intermediate_schema)
1176
+ }
1177
+
1178
/// Produces `num_columns` identical columns, each holding the ascending
/// sequence `1, 2, ..., num_rows` as `i32` values.
fn generate_columns(num_columns: usize, num_rows: usize) -> Vec<Vec<i32>> {
    let template: Vec<i32> = (1..=num_rows).map(|row| row as i32).collect();
    (0..num_columns).map(|_| template.clone()).collect()
}
1182
+
1183
+ #[ rstest]
1184
+ #[ tokio:: test]
1185
+ async fn join_maintains_right_order (
1186
+ #[ values(
1187
+ JoinType :: Inner ,
1188
+ JoinType :: Right ,
1189
+ JoinType :: RightAnti ,
1190
+ JoinType :: RightSemi
1191
+ ) ]
1192
+ join_type : JoinType ,
1193
+ #[ values( 1 , 100 , 1000 ) ] left_batch_size : usize ,
1194
+ #[ values( 1 , 100 , 1000 ) ] right_batch_size : usize ,
1195
+ ) -> Result < ( ) > {
1196
+ let left_columns = generate_columns ( 3 , 1000 ) ;
1197
+ let left = build_table (
1198
+ ( "a1" , & left_columns[ 0 ] ) ,
1199
+ ( "b1" , & left_columns[ 1 ] ) ,
1200
+ ( "c1" , & left_columns[ 2 ] ) ,
1201
+ Some ( left_batch_size) ,
1202
+ Vec :: new ( ) ,
1203
+ ) ;
1204
+
1205
+ let right_columns = generate_columns ( 3 , 1000 ) ;
1206
+ let right = build_table (
1207
+ ( "a2" , & right_columns[ 0 ] ) ,
1208
+ ( "b2" , & right_columns[ 1 ] ) ,
1209
+ ( "c2" , & right_columns[ 2 ] ) ,
1210
+ Some ( right_batch_size) ,
1211
+ vec ! [ "a2" , "b2" , "c2" ] ,
1212
+ ) ;
1213
+
1214
+ let filter = prepare_mod_join_filter ( ) ;
1215
+
1216
+ let nested_loop_join = Arc :: new ( NestedLoopJoinExec :: try_new (
1217
+ left,
1218
+ Arc :: clone ( & right) ,
1219
+ Some ( filter) ,
1220
+ & join_type,
1221
+ ) ?) as Arc < dyn ExecutionPlan > ;
1222
+ assert_eq ! ( nested_loop_join. maintains_input_order( ) , vec![ false , true ] ) ;
1223
+
1224
+ let right_column_indices = match join_type {
1225
+ JoinType :: Inner | JoinType :: Right => vec ! [ 3 , 4 , 5 ] ,
1226
+ JoinType :: RightAnti | JoinType :: RightSemi => vec ! [ 0 , 1 , 2 ] ,
1227
+ _ => unreachable ! ( ) ,
1228
+ } ;
1229
+
1230
+ let right_ordering = right. output_ordering ( ) . unwrap ( ) ;
1231
+ let join_ordering = nested_loop_join. output_ordering ( ) . unwrap ( ) ;
1232
+ for ( right, join) in right_ordering. iter ( ) . zip ( join_ordering. iter ( ) ) {
1233
+ let right_column = right. expr . as_any ( ) . downcast_ref :: < Column > ( ) . unwrap ( ) ;
1234
+ let join_column = join. expr . as_any ( ) . downcast_ref :: < Column > ( ) . unwrap ( ) ;
1235
+ assert_eq ! ( join_column. name( ) , join_column. name( ) ) ;
1236
+ assert_eq ! (
1237
+ right_column_indices[ right_column. index( ) ] ,
1238
+ join_column. index( )
1239
+ ) ;
1240
+ assert_eq ! ( right. options, join. options) ;
1241
+ }
1242
+
1243
+ let batches = nested_loop_join
1244
+ . execute ( 0 , Arc :: new ( TaskContext :: default ( ) ) ) ?
1245
+ . try_collect :: < Vec < _ > > ( )
1246
+ . await ?;
1247
+
1248
+ // Make sure that the order of the right side is maintained
1249
+ let mut prev_values = [ i32:: MIN , i32:: MIN , i32:: MIN ] ;
1250
+
1251
+ for ( batch_index, batch) in batches. iter ( ) . enumerate ( ) {
1252
+ let columns: Vec < _ > = right_column_indices
1253
+ . iter ( )
1254
+ . map ( |& i| {
1255
+ batch
1256
+ . column ( i)
1257
+ . as_any ( )
1258
+ . downcast_ref :: < Int32Array > ( )
1259
+ . unwrap ( )
1260
+ } )
1261
+ . collect ( ) ;
1262
+
1263
+ for row in 0 ..batch. num_rows ( ) {
1264
+ let current_values = [
1265
+ columns[ 0 ] . value ( row) ,
1266
+ columns[ 1 ] . value ( row) ,
1267
+ columns[ 2 ] . value ( row) ,
1268
+ ] ;
1269
+ assert ! (
1270
+ current_values
1271
+ . into_iter( )
1272
+ . zip( prev_values)
1273
+ . all( |( current, prev) | current >= prev) ,
1274
+ "batch_index: {} row: {} current: {:?}, prev: {:?}" ,
1275
+ batch_index,
1276
+ row,
1277
+ current_values,
1278
+ prev_values
1279
+ ) ;
1280
+ prev_values = current_values;
1281
+ }
1282
+ }
1283
+
1284
+ Ok ( ( ) )
1285
+ }
1286
+
1053
1287
/// Returns the column names on the schema
1054
1288
fn columns ( schema : & Schema ) -> Vec < String > {
1055
1289
schema. fields ( ) . iter ( ) . map ( |f| f. name ( ) . clone ( ) ) . collect ( )
0 commit comments