Skip to content

Commit 18f9201

Browse files
authored
Fix equal_to in ByteGroupValueBuilder (#12770)
* Fix `equal_to` in `ByteGroupValueBuilder`
* refactor null_equal_to
* Update datafusion/physical-plan/src/aggregates/group_values/group_column.rs
1 parent 6f8c74c commit 18f9201

File tree

1 file changed

+88
-20
lines changed

1 file changed

+88
-20
lines changed

datafusion/physical-plan/src/aggregates/group_values/group_column.rs

+88-20
Original file line numberDiff line numberDiff line change
@@ -93,21 +93,10 @@ impl<T: ArrowPrimitiveType, const NULLABLE: bool> GroupColumn
9393
fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool {
9494
// Perf: skip null check (by short circuit) if input is not nullable
9595
if NULLABLE {
96-
// In nullable path, we should check if both `exist row` and `input row`
97-
// are null/not null
98-
let is_exist_null = self.nulls.is_null(lhs_row);
99-
let null_match = is_exist_null == array.is_null(rhs_row);
100-
if !null_match {
101-
// If `is_null`s in `exist row` and `input row` don't match, return not equal to
102-
return false;
103-
} else if is_exist_null {
104-
// If `is_null`s in `exist row` and `input row` match, and they are `null`s,
105-
// return equal to
106-
//
107-
// NOTICE: we should not check their values when they are `null`s, because they are
108-
// meaningless actually, and not ensured to be same
109-
//
110-
return true;
96+
let exist_null = self.nulls.is_null(lhs_row);
97+
let input_null = array.is_null(rhs_row);
98+
if let Some(result) = nulls_equal_to(exist_null, input_null) {
99+
return result;
111100
}
112101
// Otherwise, we need to check their values
113102
}
@@ -224,9 +213,14 @@ where
224213
where
225214
B: ByteArrayType,
226215
{
227-
let arr = array.as_bytes::<B>();
228-
self.nulls.is_null(lhs_row) == arr.is_null(rhs_row)
229-
&& self.value(lhs_row) == (arr.value(rhs_row).as_ref() as &[u8])
216+
let array = array.as_bytes::<B>();
217+
let exist_null = self.nulls.is_null(lhs_row);
218+
let input_null = array.is_null(rhs_row);
219+
if let Some(result) = nulls_equal_to(exist_null, input_null) {
220+
return result;
221+
}
222+
// Otherwise, we need to check their values
223+
self.value(lhs_row) == (array.value(rhs_row).as_ref() as &[u8])
230224
}
231225

232226
/// return the current value of the specified row irrespective of null
@@ -382,6 +376,20 @@ where
382376
}
383377
}
384378

379+
/// Decides, from null flags alone, whether an existing row (`lhs_null`) and an input
380+
/// row (`rhs_null`) are equal, so callers can skip comparing the stored values.
381+
///
382+
/// Returns `Some(true)` when both rows are null, `Some(false)` when exactly one
383+
/// of them is null, and `None` when neither is null, in which case the caller must
384+
/// compare the values themselves.
385+
fn nulls_equal_to(lhs_null: bool, rhs_null: bool) -> Option<bool> {
386+
match (lhs_null, rhs_null) {
387+
(true, true) => Some(true), // both null: equal, the stored values are not consulted
388+
(false, true) | (true, false) => Some(false), // exactly one null: never equal
389+
_ => None, // neither null: undecided here, the caller compares the values
390+
}
391+
}
392+
385393
#[cfg(test)]
386394
mod tests {
387395
use std::sync::Arc;
@@ -468,13 +476,14 @@ mod tests {
468476
builder.append_val(&builder_array, 5);
469477

470478
// Define input array
471-
let (_, values, _) =
479+
let (_nulls, values, _) =
472480
Int64Array::from(vec![Some(1), Some(2), None, None, Some(1), Some(3)])
473481
.into_parts();
474482

483+
// explicitly build a boolean buffer where one of the null values also happens to match
475484
let mut boolean_buffer_builder = BooleanBufferBuilder::new(6);
476485
boolean_buffer_builder.append(true);
477-
boolean_buffer_builder.append(false);
486+
boolean_buffer_builder.append(false); // this sets Some(2) to null above
478487
boolean_buffer_builder.append(false);
479488
boolean_buffer_builder.append(false);
480489
boolean_buffer_builder.append(true);
@@ -511,4 +520,63 @@ mod tests {
511520
assert!(builder.equal_to(0, &input_array, 0));
512521
assert!(!builder.equal_to(1, &input_array, 1));
513522
}
523+
524+
#[test]
525+
fn test_byte_array_equal_to() {
526+
// Covers the following cases, one per row index 0..=5, in this order:
527+
// - exist null, input not null
528+
// - exist null, input null; values not equal
529+
// - exist null, input null; values equal
530+
// - exist not null, input null
531+
// - exist not null, input not null; values not equal
532+
// - exist not null, input not null; values equal
533+
534+
// Define ByteGroupValueBuilder holding rows: null, null, null, "foo", "bar", "baz"
535+
let mut builder = ByteGroupValueBuilder::<i32>::new(OutputType::Utf8);
536+
let builder_array = Arc::new(StringArray::from(vec![
537+
None,
538+
None,
539+
None,
540+
Some("foo"),
541+
Some("bar"),
542+
Some("baz"),
543+
])) as ArrayRef;
544+
builder.append_val(&builder_array, 0);
545+
builder.append_val(&builder_array, 1);
546+
builder.append_val(&builder_array, 2);
547+
builder.append_val(&builder_array, 3);
548+
builder.append_val(&builder_array, 4);
549+
builder.append_val(&builder_array, 5);
550+
551+
// Define input array; only its offsets and value buffer are kept, the null buffer is rebuilt below
552+
let (offsets, buffer, _nulls) = StringArray::from(vec![
553+
Some("foo"),
554+
Some("bar"),
555+
None,
556+
None,
557+
Some("foo"),
558+
Some("baz"),
559+
])
560+
.into_parts();
561+
562+
// explicitly build a validity buffer (false = null) where one null row (row 1) still carries non-empty underlying bytes
563+
let mut boolean_buffer_builder = BooleanBufferBuilder::new(6);
564+
boolean_buffer_builder.append(true);
565+
boolean_buffer_builder.append(false); // this sets Some("bar") to null above
566+
boolean_buffer_builder.append(false);
567+
boolean_buffer_builder.append(false);
568+
boolean_buffer_builder.append(true);
569+
boolean_buffer_builder.append(true);
570+
let nulls = NullBuffer::new(boolean_buffer_builder.finish());
571+
let input_array =
572+
Arc::new(StringArray::new(offsets, buffer, Some(nulls))) as ArrayRef;
573+
574+
// Check: one assertion per case listed at the top, in the same order
575+
assert!(!builder.equal_to(0, &input_array, 0)); // exist null, input "foo"
576+
assert!(builder.equal_to(1, &input_array, 1)); // both null; input's underlying bytes are "bar" and must be ignored
577+
assert!(builder.equal_to(2, &input_array, 2)); // both null
578+
assert!(!builder.equal_to(3, &input_array, 3)); // exist "foo", input null
579+
assert!(!builder.equal_to(4, &input_array, 4)); // exist "bar", input "foo"
580+
assert!(builder.equal_to(5, &input_array, 5)); // exist "baz", input "baz"
581+
}
514582
}

0 commit comments

Comments
 (0)