@@ -487,7 +487,6 @@ struct StreamedBatch {
487
487
/// The join key arrays of streamed batch which are used to compare with buffered batches
488
488
/// and to produce output. They are produced by evaluating `on` expressions.
489
489
pub join_arrays : Vec < ArrayRef > ,
490
-
491
490
/// Chunks of indices from buffered side (may be nulls) joined to streamed
492
491
pub output_indices : Vec < StreamedJoinedChunk > ,
493
492
/// Index of currently scanned batch from buffered data
@@ -1021,6 +1020,15 @@ impl SMJStream {
1021
1020
join_streamed = true ;
1022
1021
join_buffered = true ;
1023
1022
} ;
1023
+
1024
+ if matches ! ( self . join_type, JoinType :: LeftAnti ) && self . filter . is_some ( ) {
1025
+ join_streamed = !self
1026
+ . streamed_batch
1027
+ . join_filter_matched_idxs
1028
+ . contains ( & ( self . streamed_batch . idx as u64 ) )
1029
+ && !self . streamed_joined ;
1030
+ join_buffered = join_streamed;
1031
+ }
1024
1032
}
1025
1033
Ordering :: Greater => {
1026
1034
if matches ! ( self . join_type, JoinType :: Full ) {
@@ -1181,7 +1189,10 @@ impl SMJStream {
1181
1189
let filter_columns = if chunk. buffered_batch_idx . is_some ( ) {
1182
1190
if matches ! ( self . join_type, JoinType :: Right ) {
1183
1191
get_filter_column ( & self . filter , & buffered_columns, & streamed_columns)
1184
- } else if matches ! ( self . join_type, JoinType :: LeftSemi ) {
1192
+ } else if matches ! (
1193
+ self . join_type,
1194
+ JoinType :: LeftSemi | JoinType :: LeftAnti
1195
+ ) {
1185
1196
// unwrap is safe here as we check is_some on top of if statement
1186
1197
let buffered_columns = get_buffered_columns (
1187
1198
& self . buffered_data ,
@@ -1228,7 +1239,15 @@ impl SMJStream {
1228
1239
datafusion_common:: cast:: as_boolean_array ( & filter_result) ?;
1229
1240
1230
1241
let maybe_filtered_join_mask: Option < ( BooleanArray , Vec < u64 > ) > =
1231
- get_filtered_join_mask ( self . join_type , streamed_indices, mask) ;
1242
+ get_filtered_join_mask (
1243
+ self . join_type ,
1244
+ streamed_indices,
1245
+ mask,
1246
+ & self . streamed_batch . join_filter_matched_idxs ,
1247
+ & self . buffered_data . scanning_batch_idx ,
1248
+ & self . buffered_data . batches . len ( ) ,
1249
+ ) ;
1250
+
1232
1251
if let Some ( ref filtered_join_mask) = maybe_filtered_join_mask {
1233
1252
mask = & filtered_join_mask. 0 ;
1234
1253
self . streamed_batch
@@ -1419,51 +1438,87 @@ fn get_buffered_columns(
1419
1438
. collect :: < Result < Vec < _ > , ArrowError > > ( )
1420
1439
}
1421
1440
1422
- // Calculate join filter bit mask considering join type specifics
1423
- // `streamed_indices` - array of streamed datasource JOINED row indices
1424
- // `mask` - array booleans representing computed join filter expression eval result:
1425
- // true = the row index matches the join filter
1426
- // false = the row index doesn't match the join filter
1427
- // `streamed_indices` have the same length as `mask`
1441
+ /// Calculate join filter bit mask considering join type specifics
1442
+ /// `streamed_indices` - array of streamed datasource JOINED row indices
1443
+ /// `mask` - array booleans representing computed join filter expression eval result:
1444
+ /// true = the row index matches the join filter
1445
+ /// false = the row index doesn't match the join filter
1446
+ /// `streamed_indices` have the same length as `mask`
1447
+ /// `matched_indices` array of streaming indices that already has a join filter match
1448
+ /// `scanning_batch_idx` current buffered batch
1449
+ /// `buffered_batches_len` how many batches are in buffered data
1428
1450
fn get_filtered_join_mask (
1429
1451
join_type : JoinType ,
1430
1452
streamed_indices : UInt64Array ,
1431
1453
mask : & BooleanArray ,
1454
+ matched_indices : & HashSet < u64 > ,
1455
+ scanning_buffered_batch_idx : & usize ,
1456
+ buffered_batches_len : & usize ,
1432
1457
) -> Option < ( BooleanArray , Vec < u64 > ) > {
1433
- // for LeftSemi Join the filter mask should be calculated in its own way:
1434
- // if we find at least one matching row for specific streaming index
1435
- // we don't need to check any others for the same index
1436
- if matches ! ( join_type, JoinType :: LeftSemi ) {
1437
- // have we seen a filter match for a streaming index before
1438
- let mut seen_as_true: bool = false ;
1439
- let streamed_indices_length = streamed_indices. len ( ) ;
1440
- let mut corrected_mask: BooleanBuilder =
1441
- BooleanBuilder :: with_capacity ( streamed_indices_length) ;
1442
-
1443
- let mut filter_matched_indices: Vec < u64 > = vec ! [ ] ;
1444
-
1445
- #[ allow( clippy:: needless_range_loop) ]
1446
- for i in 0 ..streamed_indices_length {
1447
- // LeftSemi respects only first true values for specific streaming index,
1448
- // others true values for the same index must be false
1449
- if mask. value ( i) && !seen_as_true {
1450
- seen_as_true = true ;
1451
- corrected_mask. append_value ( true ) ;
1452
- filter_matched_indices. push ( streamed_indices. value ( i) ) ;
1453
- } else {
1454
- corrected_mask. append_value ( false ) ;
1458
+ let mut seen_as_true: bool = false ;
1459
+ let streamed_indices_length = streamed_indices. len ( ) ;
1460
+ let mut corrected_mask: BooleanBuilder =
1461
+ BooleanBuilder :: with_capacity ( streamed_indices_length) ;
1462
+
1463
+ let mut filter_matched_indices: Vec < u64 > = vec ! [ ] ;
1464
+
1465
+ #[ allow( clippy:: needless_range_loop) ]
1466
+ match join_type {
1467
+ // for LeftSemi Join the filter mask should be calculated in its own way:
1468
+ // if we find at least one matching row for specific streaming index
1469
+ // we don't need to check any others for the same index
1470
+ JoinType :: LeftSemi => {
1471
+ // have we seen a filter match for a streaming index before
1472
+ for i in 0 ..streamed_indices_length {
1473
+ // LeftSemi respects only first true values for specific streaming index,
1474
+ // others true values for the same index must be false
1475
+ if mask. value ( i) && !seen_as_true {
1476
+ seen_as_true = true ;
1477
+ corrected_mask. append_value ( true ) ;
1478
+ filter_matched_indices. push ( streamed_indices. value ( i) ) ;
1479
+ } else {
1480
+ corrected_mask. append_value ( false ) ;
1481
+ }
1482
+
1483
+ // if switched to next streaming index(e.g. from 0 to 1, or from 1 to 2), we reset seen_as_true flag
1484
+ if i < streamed_indices_length - 1
1485
+ && streamed_indices. value ( i) != streamed_indices. value ( i + 1 )
1486
+ {
1487
+ seen_as_true = false ;
1488
+ }
1455
1489
}
1490
+ Some ( ( corrected_mask. finish ( ) , filter_matched_indices) )
1491
+ }
1492
+ // LeftAnti semantics: return true if for every x in the collection, p(x) is false.
1493
+ // the true(if any) flag needs to be set only once per streaming index
1494
+ // to prevent duplicates in the output
1495
+ JoinType :: LeftAnti => {
1496
+ // have we seen a filter match for a streaming index before
1497
+ for i in 0 ..streamed_indices_length {
1498
+ if mask. value ( i) && !seen_as_true {
1499
+ seen_as_true = true ;
1500
+ filter_matched_indices. push ( streamed_indices. value ( i) ) ;
1501
+ }
1456
1502
1457
- // if switched to next streaming index(e.g. from 0 to 1, or from 1 to 2), we reset seen_as_true flag
1458
- if i < streamed_indices_length - 1
1459
- && streamed_indices. value ( i) != streamed_indices. value ( i + 1 )
1460
- {
1461
- seen_as_true = false ;
1503
+ // if switched to next streaming index(e.g. from 0 to 1, or from 1 to 2), we reset seen_as_true flag
1504
+ if ( i < streamed_indices_length - 1
1505
+ && streamed_indices. value ( i) != streamed_indices. value ( i + 1 ) )
1506
+ || ( i == streamed_indices_length - 1
1507
+ && * scanning_buffered_batch_idx == buffered_batches_len - 1 )
1508
+ {
1509
+ corrected_mask. append_value (
1510
+ !matched_indices. contains ( & streamed_indices. value ( i) )
1511
+ && !seen_as_true,
1512
+ ) ;
1513
+ seen_as_true = false ;
1514
+ } else {
1515
+ corrected_mask. append_value ( false ) ;
1516
+ }
1462
1517
}
1518
+
1519
+ Some ( ( corrected_mask. finish ( ) , filter_matched_indices) )
1463
1520
}
1464
- Some ( ( corrected_mask. finish ( ) , filter_matched_indices) )
1465
- } else {
1466
- None
1521
+ _ => None ,
1467
1522
}
1468
1523
}
1469
1524
@@ -1711,8 +1766,9 @@ mod tests {
1711
1766
use arrow:: datatypes:: { DataType , Field , Schema } ;
1712
1767
use arrow:: record_batch:: RecordBatch ;
1713
1768
use arrow_array:: { BooleanArray , UInt64Array } ;
1769
+ use hashbrown:: HashSet ;
1714
1770
1715
- use datafusion_common:: JoinType :: LeftSemi ;
1771
+ use datafusion_common:: JoinType :: { LeftAnti , LeftSemi } ;
1716
1772
use datafusion_common:: {
1717
1773
assert_batches_eq, assert_batches_sorted_eq, assert_contains, JoinType , Result ,
1718
1774
} ;
@@ -2754,7 +2810,10 @@ mod tests {
2754
2810
get_filtered_join_mask(
2755
2811
LeftSemi ,
2756
2812
UInt64Array :: from( vec![ 0 , 0 , 1 , 1 ] ) ,
2757
- & BooleanArray :: from( vec![ true , true , false , false ] )
2813
+ & BooleanArray :: from( vec![ true , true , false , false ] ) ,
2814
+ & HashSet :: new( ) ,
2815
+ & 0 ,
2816
+ & 0
2758
2817
) ,
2759
2818
Some ( ( BooleanArray :: from( vec![ true , false , false , false ] ) , vec![ 0 ] ) )
2760
2819
) ;
@@ -2763,7 +2822,10 @@ mod tests {
2763
2822
get_filtered_join_mask(
2764
2823
LeftSemi ,
2765
2824
UInt64Array :: from( vec![ 0 , 1 ] ) ,
2766
- & BooleanArray :: from( vec![ true , true ] )
2825
+ & BooleanArray :: from( vec![ true , true ] ) ,
2826
+ & HashSet :: new( ) ,
2827
+ & 0 ,
2828
+ & 0
2767
2829
) ,
2768
2830
Some ( ( BooleanArray :: from( vec![ true , true ] ) , vec![ 0 , 1 ] ) )
2769
2831
) ;
@@ -2772,7 +2834,10 @@ mod tests {
2772
2834
get_filtered_join_mask(
2773
2835
LeftSemi ,
2774
2836
UInt64Array :: from( vec![ 0 , 1 ] ) ,
2775
- & BooleanArray :: from( vec![ false , true ] )
2837
+ & BooleanArray :: from( vec![ false , true ] ) ,
2838
+ & HashSet :: new( ) ,
2839
+ & 0 ,
2840
+ & 0
2776
2841
) ,
2777
2842
Some ( ( BooleanArray :: from( vec![ false , true ] ) , vec![ 1 ] ) )
2778
2843
) ;
@@ -2781,7 +2846,10 @@ mod tests {
2781
2846
get_filtered_join_mask(
2782
2847
LeftSemi ,
2783
2848
UInt64Array :: from( vec![ 0 , 1 ] ) ,
2784
- & BooleanArray :: from( vec![ true , false ] )
2849
+ & BooleanArray :: from( vec![ true , false ] ) ,
2850
+ & HashSet :: new( ) ,
2851
+ & 0 ,
2852
+ & 0
2785
2853
) ,
2786
2854
Some ( ( BooleanArray :: from( vec![ true , false ] ) , vec![ 0 ] ) )
2787
2855
) ;
@@ -2790,7 +2858,10 @@ mod tests {
2790
2858
get_filtered_join_mask(
2791
2859
LeftSemi ,
2792
2860
UInt64Array :: from( vec![ 0 , 0 , 0 , 1 , 1 , 1 ] ) ,
2793
- & BooleanArray :: from( vec![ false , true , true , true , true , true ] )
2861
+ & BooleanArray :: from( vec![ false , true , true , true , true , true ] ) ,
2862
+ & HashSet :: new( ) ,
2863
+ & 0 ,
2864
+ & 0
2794
2865
) ,
2795
2866
Some ( (
2796
2867
BooleanArray :: from( vec![ false , true , false , true , false , false ] ) ,
@@ -2802,7 +2873,10 @@ mod tests {
2802
2873
get_filtered_join_mask(
2803
2874
LeftSemi ,
2804
2875
UInt64Array :: from( vec![ 0 , 0 , 0 , 1 , 1 , 1 ] ) ,
2805
- & BooleanArray :: from( vec![ false , false , false , false , false , true ] )
2876
+ & BooleanArray :: from( vec![ false , false , false , false , false , true ] ) ,
2877
+ & HashSet :: new( ) ,
2878
+ & 0 ,
2879
+ & 0
2806
2880
) ,
2807
2881
Some ( (
2808
2882
BooleanArray :: from( vec![ false , false , false , false , false , true ] ) ,
@@ -2813,6 +2887,89 @@ mod tests {
2813
2887
Ok ( ( ) )
2814
2888
}
2815
2889
2890
+ #[ tokio:: test]
2891
+ async fn left_anti_join_filtered_mask ( ) -> Result < ( ) > {
2892
+ assert_eq ! (
2893
+ get_filtered_join_mask(
2894
+ LeftAnti ,
2895
+ UInt64Array :: from( vec![ 0 , 0 , 1 , 1 ] ) ,
2896
+ & BooleanArray :: from( vec![ true , true , false , false ] ) ,
2897
+ & HashSet :: new( ) ,
2898
+ & 0 ,
2899
+ & 1
2900
+ ) ,
2901
+ Some ( ( BooleanArray :: from( vec![ false , false , false , true ] ) , vec![ 0 ] ) )
2902
+ ) ;
2903
+
2904
+ assert_eq ! (
2905
+ get_filtered_join_mask(
2906
+ LeftAnti ,
2907
+ UInt64Array :: from( vec![ 0 , 1 ] ) ,
2908
+ & BooleanArray :: from( vec![ true , true ] ) ,
2909
+ & HashSet :: new( ) ,
2910
+ & 0 ,
2911
+ & 1
2912
+ ) ,
2913
+ Some ( ( BooleanArray :: from( vec![ false , false ] ) , vec![ 0 , 1 ] ) )
2914
+ ) ;
2915
+
2916
+ assert_eq ! (
2917
+ get_filtered_join_mask(
2918
+ LeftAnti ,
2919
+ UInt64Array :: from( vec![ 0 , 1 ] ) ,
2920
+ & BooleanArray :: from( vec![ false , true ] ) ,
2921
+ & HashSet :: new( ) ,
2922
+ & 0 ,
2923
+ & 1
2924
+ ) ,
2925
+ Some ( ( BooleanArray :: from( vec![ true , false ] ) , vec![ 1 ] ) )
2926
+ ) ;
2927
+
2928
+ assert_eq ! (
2929
+ get_filtered_join_mask(
2930
+ LeftAnti ,
2931
+ UInt64Array :: from( vec![ 0 , 1 ] ) ,
2932
+ & BooleanArray :: from( vec![ true , false ] ) ,
2933
+ & HashSet :: new( ) ,
2934
+ & 0 ,
2935
+ & 1
2936
+ ) ,
2937
+ Some ( ( BooleanArray :: from( vec![ false , true ] ) , vec![ 0 ] ) )
2938
+ ) ;
2939
+
2940
+ assert_eq ! (
2941
+ get_filtered_join_mask(
2942
+ LeftAnti ,
2943
+ UInt64Array :: from( vec![ 0 , 0 , 0 , 1 , 1 , 1 ] ) ,
2944
+ & BooleanArray :: from( vec![ false , true , true , true , true , true ] ) ,
2945
+ & HashSet :: new( ) ,
2946
+ & 0 ,
2947
+ & 1
2948
+ ) ,
2949
+ Some ( (
2950
+ BooleanArray :: from( vec![ false , false , false , false , false , false ] ) ,
2951
+ vec![ 0 , 1 ]
2952
+ ) )
2953
+ ) ;
2954
+
2955
+ assert_eq ! (
2956
+ get_filtered_join_mask(
2957
+ LeftAnti ,
2958
+ UInt64Array :: from( vec![ 0 , 0 , 0 , 1 , 1 , 1 ] ) ,
2959
+ & BooleanArray :: from( vec![ false , false , false , false , false , true ] ) ,
2960
+ & HashSet :: new( ) ,
2961
+ & 0 ,
2962
+ & 1
2963
+ ) ,
2964
+ Some ( (
2965
+ BooleanArray :: from( vec![ false , false , true , false , false , false ] ) ,
2966
+ vec![ 1 ]
2967
+ ) )
2968
+ ) ;
2969
+
2970
+ Ok ( ( ) )
2971
+ }
2972
+
2816
2973
/// Returns the column names on the schema
2817
2974
fn columns ( schema : & Schema ) -> Vec < String > {
2818
2975
schema. fields ( ) . iter ( ) . map ( |f| f. name ( ) . clone ( ) ) . collect ( )
0 commit comments