Skip to content

Commit d6ddd23

Browse files
authored
Fix SMJ Left Anti Join when the join filter is set (#10724)
* Fix: Sort Merge Join crashes on TPCH Q21 * Fix LeftAnti SMJ join when the join filter is set * rm dbg
1 parent 68f8476 commit d6ddd23

File tree

2 files changed

+306
-64
lines changed

2 files changed

+306
-64
lines changed

datafusion/physical-plan/src/joins/sort_merge_join.rs

+203-46
Original file line numberDiff line numberDiff line change
@@ -487,7 +487,6 @@ struct StreamedBatch {
487487
/// The join key arrays of streamed batch which are used to compare with buffered batches
488488
/// and to produce output. They are produced by evaluating `on` expressions.
489489
pub join_arrays: Vec<ArrayRef>,
490-
491490
/// Chunks of indices from buffered side (may be nulls) joined to streamed
492491
pub output_indices: Vec<StreamedJoinedChunk>,
493492
/// Index of currently scanned batch from buffered data
@@ -1021,6 +1020,15 @@ impl SMJStream {
10211020
join_streamed = true;
10221021
join_buffered = true;
10231022
};
1023+
1024+
if matches!(self.join_type, JoinType::LeftAnti) && self.filter.is_some() {
1025+
join_streamed = !self
1026+
.streamed_batch
1027+
.join_filter_matched_idxs
1028+
.contains(&(self.streamed_batch.idx as u64))
1029+
&& !self.streamed_joined;
1030+
join_buffered = join_streamed;
1031+
}
10241032
}
10251033
Ordering::Greater => {
10261034
if matches!(self.join_type, JoinType::Full) {
@@ -1181,7 +1189,10 @@ impl SMJStream {
11811189
let filter_columns = if chunk.buffered_batch_idx.is_some() {
11821190
if matches!(self.join_type, JoinType::Right) {
11831191
get_filter_column(&self.filter, &buffered_columns, &streamed_columns)
1184-
} else if matches!(self.join_type, JoinType::LeftSemi) {
1192+
} else if matches!(
1193+
self.join_type,
1194+
JoinType::LeftSemi | JoinType::LeftAnti
1195+
) {
11851196
// unwrap is safe here as we check is_some on top of if statement
11861197
let buffered_columns = get_buffered_columns(
11871198
&self.buffered_data,
@@ -1228,7 +1239,15 @@ impl SMJStream {
12281239
datafusion_common::cast::as_boolean_array(&filter_result)?;
12291240

12301241
let maybe_filtered_join_mask: Option<(BooleanArray, Vec<u64>)> =
1231-
get_filtered_join_mask(self.join_type, streamed_indices, mask);
1242+
get_filtered_join_mask(
1243+
self.join_type,
1244+
streamed_indices,
1245+
mask,
1246+
&self.streamed_batch.join_filter_matched_idxs,
1247+
&self.buffered_data.scanning_batch_idx,
1248+
&self.buffered_data.batches.len(),
1249+
);
1250+
12321251
if let Some(ref filtered_join_mask) = maybe_filtered_join_mask {
12331252
mask = &filtered_join_mask.0;
12341253
self.streamed_batch
@@ -1419,51 +1438,87 @@ fn get_buffered_columns(
14191438
.collect::<Result<Vec<_>, ArrowError>>()
14201439
}
14211440

1422-
// Calculate join filter bit mask considering join type specifics
1423-
// `streamed_indices` - array of streamed datasource JOINED row indices
1424-
// `mask` - array booleans representing computed join filter expression eval result:
1425-
// true = the row index matches the join filter
1426-
// false = the row index doesn't match the join filter
1427-
// `streamed_indices` have the same length as `mask`
1441+
/// Calculate join filter bit mask considering join type specifics
1442+
/// `streamed_indices` - array of streamed datasource JOINED row indices
1443+
/// `mask` - array of booleans representing the result of evaluating the join filter expression:
1444+
/// true = the row index matches the join filter
1445+
/// false = the row index doesn't match the join filter
1446+
/// `streamed_indices` have the same length as `mask`
1447+
/// `matched_indices` - set of streamed row indices that already have a join filter match
1448+
/// `scanning_buffered_batch_idx` - index of the buffered batch currently being scanned
1449+
/// `buffered_batches_len` - number of batches in the buffered data
14281450
fn get_filtered_join_mask(
14291451
join_type: JoinType,
14301452
streamed_indices: UInt64Array,
14311453
mask: &BooleanArray,
1454+
matched_indices: &HashSet<u64>,
1455+
scanning_buffered_batch_idx: &usize,
1456+
buffered_batches_len: &usize,
14321457
) -> Option<(BooleanArray, Vec<u64>)> {
1433-
// for LeftSemi Join the filter mask should be calculated in its own way:
1434-
// if we find at least one matching row for specific streaming index
1435-
// we don't need to check any others for the same index
1436-
if matches!(join_type, JoinType::LeftSemi) {
1437-
// have we seen a filter match for a streaming index before
1438-
let mut seen_as_true: bool = false;
1439-
let streamed_indices_length = streamed_indices.len();
1440-
let mut corrected_mask: BooleanBuilder =
1441-
BooleanBuilder::with_capacity(streamed_indices_length);
1442-
1443-
let mut filter_matched_indices: Vec<u64> = vec![];
1444-
1445-
#[allow(clippy::needless_range_loop)]
1446-
for i in 0..streamed_indices_length {
1447-
// LeftSemi respects only first true values for specific streaming index,
1448-
// others true values for the same index must be false
1449-
if mask.value(i) && !seen_as_true {
1450-
seen_as_true = true;
1451-
corrected_mask.append_value(true);
1452-
filter_matched_indices.push(streamed_indices.value(i));
1453-
} else {
1454-
corrected_mask.append_value(false);
1458+
let mut seen_as_true: bool = false;
1459+
let streamed_indices_length = streamed_indices.len();
1460+
let mut corrected_mask: BooleanBuilder =
1461+
BooleanBuilder::with_capacity(streamed_indices_length);
1462+
1463+
let mut filter_matched_indices: Vec<u64> = vec![];
1464+
1465+
#[allow(clippy::needless_range_loop)]
1466+
match join_type {
1467+
// for LeftSemi Join the filter mask should be calculated in its own way:
1468+
// if we find at least one matching row for specific streaming index
1469+
// we don't need to check any others for the same index
1470+
JoinType::LeftSemi => {
1471+
// have we seen a filter match for a streaming index before
1472+
for i in 0..streamed_indices_length {
1473+
// LeftSemi respects only first true values for specific streaming index,
1474+
// others true values for the same index must be false
1475+
if mask.value(i) && !seen_as_true {
1476+
seen_as_true = true;
1477+
corrected_mask.append_value(true);
1478+
filter_matched_indices.push(streamed_indices.value(i));
1479+
} else {
1480+
corrected_mask.append_value(false);
1481+
}
1482+
1483+
// if switched to next streaming index(e.g. from 0 to 1, or from 1 to 2), we reset seen_as_true flag
1484+
if i < streamed_indices_length - 1
1485+
&& streamed_indices.value(i) != streamed_indices.value(i + 1)
1486+
{
1487+
seen_as_true = false;
1488+
}
14551489
}
1490+
Some((corrected_mask.finish(), filter_matched_indices))
1491+
}
1492+
// LeftAnti semantics: return true if for every x in the collection, p(x) is false.
1493+
// the true(if any) flag needs to be set only once per streaming index
1494+
// to prevent duplicates in the output
1495+
JoinType::LeftAnti => {
1496+
// have we seen a filter match for a streaming index before
1497+
for i in 0..streamed_indices_length {
1498+
if mask.value(i) && !seen_as_true {
1499+
seen_as_true = true;
1500+
filter_matched_indices.push(streamed_indices.value(i));
1501+
}
14561502

1457-
// if switched to next streaming index(e.g. from 0 to 1, or from 1 to 2), we reset seen_as_true flag
1458-
if i < streamed_indices_length - 1
1459-
&& streamed_indices.value(i) != streamed_indices.value(i + 1)
1460-
{
1461-
seen_as_true = false;
1503+
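// on the last row of a streaming index (the next index differs, e.g. from 0 to 1), or on the final row when the last buffered batch is being scanned, we emit the anti join result for that index and reset seen_as_true flag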
1504+
if (i < streamed_indices_length - 1
1505+
&& streamed_indices.value(i) != streamed_indices.value(i + 1))
1506+
|| (i == streamed_indices_length - 1
1507+
&& *scanning_buffered_batch_idx == buffered_batches_len - 1)
1508+
{
1509+
corrected_mask.append_value(
1510+
!matched_indices.contains(&streamed_indices.value(i))
1511+
&& !seen_as_true,
1512+
);
1513+
seen_as_true = false;
1514+
} else {
1515+
corrected_mask.append_value(false);
1516+
}
14621517
}
1518+
1519+
Some((corrected_mask.finish(), filter_matched_indices))
14631520
}
1464-
Some((corrected_mask.finish(), filter_matched_indices))
1465-
} else {
1466-
None
1521+
_ => None,
14671522
}
14681523
}
14691524

@@ -1711,8 +1766,9 @@ mod tests {
17111766
use arrow::datatypes::{DataType, Field, Schema};
17121767
use arrow::record_batch::RecordBatch;
17131768
use arrow_array::{BooleanArray, UInt64Array};
1769+
use hashbrown::HashSet;
17141770

1715-
use datafusion_common::JoinType::LeftSemi;
1771+
use datafusion_common::JoinType::{LeftAnti, LeftSemi};
17161772
use datafusion_common::{
17171773
assert_batches_eq, assert_batches_sorted_eq, assert_contains, JoinType, Result,
17181774
};
@@ -2754,7 +2810,10 @@ mod tests {
27542810
get_filtered_join_mask(
27552811
LeftSemi,
27562812
UInt64Array::from(vec![0, 0, 1, 1]),
2757-
&BooleanArray::from(vec![true, true, false, false])
2813+
&BooleanArray::from(vec![true, true, false, false]),
2814+
&HashSet::new(),
2815+
&0,
2816+
&0
27582817
),
27592818
Some((BooleanArray::from(vec![true, false, false, false]), vec![0]))
27602819
);
@@ -2763,7 +2822,10 @@ mod tests {
27632822
get_filtered_join_mask(
27642823
LeftSemi,
27652824
UInt64Array::from(vec![0, 1]),
2766-
&BooleanArray::from(vec![true, true])
2825+
&BooleanArray::from(vec![true, true]),
2826+
&HashSet::new(),
2827+
&0,
2828+
&0
27672829
),
27682830
Some((BooleanArray::from(vec![true, true]), vec![0, 1]))
27692831
);
@@ -2772,7 +2834,10 @@ mod tests {
27722834
get_filtered_join_mask(
27732835
LeftSemi,
27742836
UInt64Array::from(vec![0, 1]),
2775-
&BooleanArray::from(vec![false, true])
2837+
&BooleanArray::from(vec![false, true]),
2838+
&HashSet::new(),
2839+
&0,
2840+
&0
27762841
),
27772842
Some((BooleanArray::from(vec![false, true]), vec![1]))
27782843
);
@@ -2781,7 +2846,10 @@ mod tests {
27812846
get_filtered_join_mask(
27822847
LeftSemi,
27832848
UInt64Array::from(vec![0, 1]),
2784-
&BooleanArray::from(vec![true, false])
2849+
&BooleanArray::from(vec![true, false]),
2850+
&HashSet::new(),
2851+
&0,
2852+
&0
27852853
),
27862854
Some((BooleanArray::from(vec![true, false]), vec![0]))
27872855
);
@@ -2790,7 +2858,10 @@ mod tests {
27902858
get_filtered_join_mask(
27912859
LeftSemi,
27922860
UInt64Array::from(vec![0, 0, 0, 1, 1, 1]),
2793-
&BooleanArray::from(vec![false, true, true, true, true, true])
2861+
&BooleanArray::from(vec![false, true, true, true, true, true]),
2862+
&HashSet::new(),
2863+
&0,
2864+
&0
27942865
),
27952866
Some((
27962867
BooleanArray::from(vec![false, true, false, true, false, false]),
@@ -2802,7 +2873,10 @@ mod tests {
28022873
get_filtered_join_mask(
28032874
LeftSemi,
28042875
UInt64Array::from(vec![0, 0, 0, 1, 1, 1]),
2805-
&BooleanArray::from(vec![false, false, false, false, false, true])
2876+
&BooleanArray::from(vec![false, false, false, false, false, true]),
2877+
&HashSet::new(),
2878+
&0,
2879+
&0
28062880
),
28072881
Some((
28082882
BooleanArray::from(vec![false, false, false, false, false, true]),
@@ -2813,6 +2887,89 @@ mod tests {
28132887
Ok(())
28142888
}
28152889

2890+
#[tokio::test]
2891+
async fn left_anti_join_filtered_mask() -> Result<()> {
2892+
assert_eq!(
2893+
get_filtered_join_mask(
2894+
LeftAnti,
2895+
UInt64Array::from(vec![0, 0, 1, 1]),
2896+
&BooleanArray::from(vec![true, true, false, false]),
2897+
&HashSet::new(),
2898+
&0,
2899+
&1
2900+
),
2901+
Some((BooleanArray::from(vec![false, false, false, true]), vec![0]))
2902+
);
2903+
2904+
assert_eq!(
2905+
get_filtered_join_mask(
2906+
LeftAnti,
2907+
UInt64Array::from(vec![0, 1]),
2908+
&BooleanArray::from(vec![true, true]),
2909+
&HashSet::new(),
2910+
&0,
2911+
&1
2912+
),
2913+
Some((BooleanArray::from(vec![false, false]), vec![0, 1]))
2914+
);
2915+
2916+
assert_eq!(
2917+
get_filtered_join_mask(
2918+
LeftAnti,
2919+
UInt64Array::from(vec![0, 1]),
2920+
&BooleanArray::from(vec![false, true]),
2921+
&HashSet::new(),
2922+
&0,
2923+
&1
2924+
),
2925+
Some((BooleanArray::from(vec![true, false]), vec![1]))
2926+
);
2927+
2928+
assert_eq!(
2929+
get_filtered_join_mask(
2930+
LeftAnti,
2931+
UInt64Array::from(vec![0, 1]),
2932+
&BooleanArray::from(vec![true, false]),
2933+
&HashSet::new(),
2934+
&0,
2935+
&1
2936+
),
2937+
Some((BooleanArray::from(vec![false, true]), vec![0]))
2938+
);
2939+
2940+
assert_eq!(
2941+
get_filtered_join_mask(
2942+
LeftAnti,
2943+
UInt64Array::from(vec![0, 0, 0, 1, 1, 1]),
2944+
&BooleanArray::from(vec![false, true, true, true, true, true]),
2945+
&HashSet::new(),
2946+
&0,
2947+
&1
2948+
),
2949+
Some((
2950+
BooleanArray::from(vec![false, false, false, false, false, false]),
2951+
vec![0, 1]
2952+
))
2953+
);
2954+
2955+
assert_eq!(
2956+
get_filtered_join_mask(
2957+
LeftAnti,
2958+
UInt64Array::from(vec![0, 0, 0, 1, 1, 1]),
2959+
&BooleanArray::from(vec![false, false, false, false, false, true]),
2960+
&HashSet::new(),
2961+
&0,
2962+
&1
2963+
),
2964+
Some((
2965+
BooleanArray::from(vec![false, false, true, false, false, false]),
2966+
vec![1]
2967+
))
2968+
);
2969+
2970+
Ok(())
2971+
}
2972+
28162973
/// Returns the column names on the schema
28172974
fn columns(schema: &Schema) -> Vec<String> {
28182975
schema.fields().iter().map(|f| f.name().clone()).collect()

0 commit comments

Comments
 (0)