@@ -30,8 +30,8 @@ use crate::error::Result;
30
30
use crate :: physical_optimizer:: PhysicalOptimizerRule ;
31
31
use crate :: physical_plan:: joins:: utils:: { ColumnIndex , JoinFilter } ;
32
32
use crate :: physical_plan:: joins:: {
33
- CrossJoinExec , HashJoinExec , PartitionMode , StreamJoinPartitionMode ,
34
- SymmetricHashJoinExec ,
33
+ CrossJoinExec , HashJoinExec , NestedLoopJoinExec , PartitionMode ,
34
+ StreamJoinPartitionMode , SymmetricHashJoinExec ,
35
35
} ;
36
36
use crate :: physical_plan:: projection:: ProjectionExec ;
37
37
use crate :: physical_plan:: { ExecutionPlan , ExecutionPlanProperties } ;
@@ -199,6 +199,38 @@ fn swap_hash_join(
199
199
}
200
200
}
201
201
202
+ /// Swaps inputs of `NestedLoopJoinExec` and wraps it into `ProjectionExec` is required
203
+ fn swap_nl_join ( join : & NestedLoopJoinExec ) -> Result < Arc < dyn ExecutionPlan > > {
204
+ let new_filter = swap_join_filter ( join. filter ( ) ) ;
205
+ let new_join_type = & swap_join_type ( * join. join_type ( ) ) ;
206
+
207
+ let new_join = NestedLoopJoinExec :: try_new (
208
+ Arc :: clone ( join. right ( ) ) ,
209
+ Arc :: clone ( join. left ( ) ) ,
210
+ new_filter,
211
+ new_join_type,
212
+ ) ?;
213
+
214
+ // For Semi/Anti joins, swap result will produce same output schema,
215
+ // no need to wrap them into additional projection
216
+ let plan: Arc < dyn ExecutionPlan > = if matches ! (
217
+ join. join_type( ) ,
218
+ JoinType :: LeftSemi
219
+ | JoinType :: RightSemi
220
+ | JoinType :: LeftAnti
221
+ | JoinType :: RightAnti
222
+ ) {
223
+ Arc :: new ( new_join)
224
+ } else {
225
+ let projection =
226
+ swap_reverting_projection ( & join. left ( ) . schema ( ) , & join. right ( ) . schema ( ) ) ;
227
+
228
+ Arc :: new ( ProjectionExec :: try_new ( projection, Arc :: new ( new_join) ) ?)
229
+ } ;
230
+
231
+ Ok ( plan)
232
+ }
233
+
202
234
/// When the order of the join is changed by the optimizer, the columns in
203
235
/// the output should not be impacted. This function creates the expressions
204
236
/// that will allow to swap back the values from the original left as the first
@@ -438,6 +470,14 @@ fn statistical_join_selection_subrule(
438
470
} else {
439
471
None
440
472
}
473
+ } else if let Some ( nl_join) = plan. as_any ( ) . downcast_ref :: < NestedLoopJoinExec > ( ) {
474
+ let left = nl_join. left ( ) ;
475
+ let right = nl_join. right ( ) ;
476
+ if should_swap_join_order ( & * * left, & * * right) ? {
477
+ swap_nl_join ( nl_join) . map ( Some ) ?
478
+ } else {
479
+ None
480
+ }
441
481
} else {
442
482
None
443
483
} ;
@@ -674,9 +714,12 @@ mod tests_statistical {
674
714
675
715
use arrow:: datatypes:: { DataType , Field , Schema } ;
676
716
use datafusion_common:: { stats:: Precision , JoinType , ScalarValue } ;
677
- use datafusion_physical_expr:: expressions:: Column ;
717
+ use datafusion_expr:: Operator ;
718
+ use datafusion_physical_expr:: expressions:: { BinaryExpr , Column } ;
678
719
use datafusion_physical_expr:: { PhysicalExpr , PhysicalExprRef } ;
679
720
721
+ use rstest:: rstest;
722
+
680
723
/// Return statistics for empty table
681
724
fn empty_statistics ( ) -> Statistics {
682
725
Statistics {
@@ -762,6 +805,35 @@ mod tests_statistical {
762
805
} ]
763
806
}
764
807
808
+ /// Create join filter for NLJoinExec with expression `big_col > small_col`
809
+ /// where both columns are 0-indexed and come from left and right inputs respectively
810
+ fn nl_join_filter ( ) -> Option < JoinFilter > {
811
+ let column_indices = vec ! [
812
+ ColumnIndex {
813
+ index: 0 ,
814
+ side: JoinSide :: Left ,
815
+ } ,
816
+ ColumnIndex {
817
+ index: 0 ,
818
+ side: JoinSide :: Right ,
819
+ } ,
820
+ ] ;
821
+ let intermediate_schema = Schema :: new ( vec ! [
822
+ Field :: new( "big_col" , DataType :: Int32 , false ) ,
823
+ Field :: new( "small_col" , DataType :: Int32 , false ) ,
824
+ ] ) ;
825
+ let expression = Arc :: new ( BinaryExpr :: new (
826
+ Arc :: new ( Column :: new_with_schema ( "big_col" , & intermediate_schema) . unwrap ( ) ) ,
827
+ Operator :: Gt ,
828
+ Arc :: new ( Column :: new_with_schema ( "small_col" , & intermediate_schema) . unwrap ( ) ) ,
829
+ ) ) as _ ;
830
+ Some ( JoinFilter :: new (
831
+ expression,
832
+ column_indices,
833
+ intermediate_schema,
834
+ ) )
835
+ }
836
+
765
837
/// Returns three plans with statistics of (min, max, distinct_count)
766
838
/// * big 100K rows @ (0, 50k, 50k)
767
839
/// * medium 10K rows @ (1k, 5k, 1k)
@@ -1114,6 +1186,137 @@ mod tests_statistical {
1114
1186
crosscheck_plans ( join) . unwrap ( ) ;
1115
1187
}
1116
1188
1189
+ #[ rstest(
1190
+ join_type,
1191
+ case:: inner( JoinType :: Inner ) ,
1192
+ case:: left( JoinType :: Left ) ,
1193
+ case:: right( JoinType :: Right ) ,
1194
+ case:: full( JoinType :: Full )
1195
+ ) ]
1196
+ #[ tokio:: test]
1197
+ async fn test_nl_join_with_swap ( join_type : JoinType ) {
1198
+ let ( big, small) = create_big_and_small ( ) ;
1199
+
1200
+ let join = Arc :: new (
1201
+ NestedLoopJoinExec :: try_new (
1202
+ Arc :: clone ( & big) ,
1203
+ Arc :: clone ( & small) ,
1204
+ nl_join_filter ( ) ,
1205
+ & join_type,
1206
+ )
1207
+ . unwrap ( ) ,
1208
+ ) ;
1209
+
1210
+ let optimized_join = JoinSelection :: new ( )
1211
+ . optimize ( join. clone ( ) , & ConfigOptions :: new ( ) )
1212
+ . unwrap ( ) ;
1213
+
1214
+ let swapping_projection = optimized_join
1215
+ . as_any ( )
1216
+ . downcast_ref :: < ProjectionExec > ( )
1217
+ . expect ( "A proj is required to swap columns back to their original order" ) ;
1218
+
1219
+ assert_eq ! ( swapping_projection. expr( ) . len( ) , 2 ) ;
1220
+ let ( col, name) = & swapping_projection. expr ( ) [ 0 ] ;
1221
+ assert_eq ! ( name, "big_col" ) ;
1222
+ assert_col_expr ( col, "big_col" , 1 ) ;
1223
+ let ( col, name) = & swapping_projection. expr ( ) [ 1 ] ;
1224
+ assert_eq ! ( name, "small_col" ) ;
1225
+ assert_col_expr ( col, "small_col" , 0 ) ;
1226
+
1227
+ let swapped_join = swapping_projection
1228
+ . input ( )
1229
+ . as_any ( )
1230
+ . downcast_ref :: < NestedLoopJoinExec > ( )
1231
+ . expect ( "The type of the plan should not be changed" ) ;
1232
+
1233
+ // Assert join side of big_col swapped in filter expression
1234
+ let swapped_filter = swapped_join. filter ( ) . unwrap ( ) ;
1235
+ let swapped_big_col_idx = swapped_filter. schema ( ) . index_of ( "big_col" ) . unwrap ( ) ;
1236
+ let swapped_big_col_side = swapped_filter
1237
+ . column_indices ( )
1238
+ . get ( swapped_big_col_idx)
1239
+ . unwrap ( )
1240
+ . side ;
1241
+ assert_eq ! (
1242
+ swapped_big_col_side,
1243
+ JoinSide :: Right ,
1244
+ "Filter column side should be swapped"
1245
+ ) ;
1246
+
1247
+ assert_eq ! (
1248
+ swapped_join. left( ) . statistics( ) . unwrap( ) . total_byte_size,
1249
+ Precision :: Inexact ( 8192 )
1250
+ ) ;
1251
+ assert_eq ! (
1252
+ swapped_join. right( ) . statistics( ) . unwrap( ) . total_byte_size,
1253
+ Precision :: Inexact ( 2097152 )
1254
+ ) ;
1255
+ crosscheck_plans ( join. clone ( ) ) . unwrap ( ) ;
1256
+ }
1257
+
1258
+ #[ rstest(
1259
+ join_type,
1260
+ case:: left_semi( JoinType :: LeftSemi ) ,
1261
+ case:: left_anti( JoinType :: LeftAnti ) ,
1262
+ case:: right_semi( JoinType :: RightSemi ) ,
1263
+ case:: right_anti( JoinType :: RightAnti )
1264
+ ) ]
1265
+ #[ tokio:: test]
1266
+ async fn test_nl_join_with_swap_no_proj ( join_type : JoinType ) {
1267
+ let ( big, small) = create_big_and_small ( ) ;
1268
+
1269
+ let join = Arc :: new (
1270
+ NestedLoopJoinExec :: try_new (
1271
+ Arc :: clone ( & big) ,
1272
+ Arc :: clone ( & small) ,
1273
+ nl_join_filter ( ) ,
1274
+ & join_type,
1275
+ )
1276
+ . unwrap ( ) ,
1277
+ ) ;
1278
+
1279
+ let optimized_join = JoinSelection :: new ( )
1280
+ . optimize ( join. clone ( ) , & ConfigOptions :: new ( ) )
1281
+ . unwrap ( ) ;
1282
+
1283
+ let swapped_join = optimized_join
1284
+ . as_any ( )
1285
+ . downcast_ref :: < NestedLoopJoinExec > ( )
1286
+ . expect ( "The type of the plan should not be changed" ) ;
1287
+
1288
+ // Assert before/after schemas are equal
1289
+ assert_eq ! (
1290
+ join. schema( ) ,
1291
+ swapped_join. schema( ) ,
1292
+ "Join schema should not be modified while optimization"
1293
+ ) ;
1294
+
1295
+ // Assert join side of big_col swapped in filter expression
1296
+ let swapped_filter = swapped_join. filter ( ) . unwrap ( ) ;
1297
+ let swapped_big_col_idx = swapped_filter. schema ( ) . index_of ( "big_col" ) . unwrap ( ) ;
1298
+ let swapped_big_col_side = swapped_filter
1299
+ . column_indices ( )
1300
+ . get ( swapped_big_col_idx)
1301
+ . unwrap ( )
1302
+ . side ;
1303
+ assert_eq ! (
1304
+ swapped_big_col_side,
1305
+ JoinSide :: Right ,
1306
+ "Filter column side should be swapped"
1307
+ ) ;
1308
+
1309
+ assert_eq ! (
1310
+ swapped_join. left( ) . statistics( ) . unwrap( ) . total_byte_size,
1311
+ Precision :: Inexact ( 8192 )
1312
+ ) ;
1313
+ assert_eq ! (
1314
+ swapped_join. right( ) . statistics( ) . unwrap( ) . total_byte_size,
1315
+ Precision :: Inexact ( 2097152 )
1316
+ ) ;
1317
+ crosscheck_plans ( join. clone ( ) ) . unwrap ( ) ;
1318
+ }
1319
+
1117
1320
#[ tokio:: test]
1118
1321
async fn test_swap_reverting_projection ( ) {
1119
1322
let left_schema = Schema :: new ( vec ! [
0 commit comments