diff --git a/datafusion/physical-plan/src/aggregates/group_values/group_column.rs b/datafusion/physical-plan/src/aggregates/group_values/group_column.rs index 15c93262968e..aa246ac95b8b 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/group_column.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/group_column.rs @@ -91,15 +91,28 @@ impl GroupColumn for PrimitiveGroupValueBuilder { fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool { - // Perf: skip null check (by short circuit) if input is not ullable - let null_match = if NULLABLE { - self.nulls.is_null(lhs_row) == array.is_null(rhs_row) - } else { - true - }; + // Perf: skip null check (by short circuit) if input is not nullable + if NULLABLE { + // In nullable path, we should check if both `exist row` and `input row` + // are null/not null + let is_exist_null = self.nulls.is_null(lhs_row); + let null_match = is_exist_null == array.is_null(rhs_row); + if !null_match { + // If `is_null`s in `exist row` and `input row` don't match, return not equal to + return false; + } else if is_exist_null { + // If `is_null`s in `exist row` and `input row` match, and they are `null`s, + // return equal to + // + // NOTICE: we should not check their values when they are `null`s, because they are + // meaningless actually, and not ensured to be same + // + return true; + } + // Otherwise, we need to check their values + } - null_match - && self.group_values[lhs_row] == array.as_primitive::().value(rhs_row) + self.group_values[lhs_row] == array.as_primitive::().value(rhs_row) } fn append_val(&mut self, array: &ArrayRef, row: usize) { @@ -373,9 +386,13 @@ where mod tests { use std::sync::Arc; - use arrow_array::{ArrayRef, StringArray}; + use arrow::datatypes::Int64Type; + use arrow_array::{ArrayRef, Int64Array, StringArray}; + use arrow_buffer::{BooleanBufferBuilder, NullBuffer}; use datafusion_physical_expr::binary_map::OutputType; + use crate::aggregates::group_values::group_column::PrimitiveGroupValueBuilder; + use super::{ByteGroupValueBuilder, GroupColumn}; #[test] @@ -422,4 +439,76 @@ mod tests { ])) as ArrayRef; assert_eq!(&output, &array); } + + #[test] + fn test_nullable_primitive_equal_to() { + // Will cover such cases: + // - exist null, input not null + // - exist null, input null; values not equal + // - exist null, input null; values equal + // - exist not null, input null + // - exist not null, input not null; values not equal + // - exist not null, input not null; values equal + + // Define PrimitiveGroupValueBuilder + let mut builder = PrimitiveGroupValueBuilder::::new(); + let builder_array = Arc::new(Int64Array::from(vec![ + None, + None, + None, + Some(1), + Some(2), + Some(3), + ])) as ArrayRef; + builder.append_val(&builder_array, 0); + builder.append_val(&builder_array, 1); + builder.append_val(&builder_array, 2); + builder.append_val(&builder_array, 3); + builder.append_val(&builder_array, 4); + builder.append_val(&builder_array, 5); + + // Define input array + let (_, values, _) = + Int64Array::from(vec![Some(1), Some(2), None, None, Some(1), Some(3)]) + .into_parts(); + + let mut boolean_buffer_builder = BooleanBufferBuilder::new(6); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(false); + boolean_buffer_builder.append(false); + boolean_buffer_builder.append(false); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(true); + let nulls = NullBuffer::new(boolean_buffer_builder.finish()); + let input_array = Arc::new(Int64Array::new(values, Some(nulls))) as ArrayRef; + + // Check + assert!(!builder.equal_to(0, &input_array, 0)); + assert!(builder.equal_to(1, &input_array, 1)); + assert!(builder.equal_to(2, &input_array, 2)); + assert!(!builder.equal_to(3, &input_array, 3)); + assert!(!builder.equal_to(4, &input_array, 4)); + assert!(builder.equal_to(5, &input_array, 5)); + } + + #[test] + fn test_not_nullable_primitive_equal_to() { + // Will cover such cases: + // - values equal + // - values not equal + + // Define PrimitiveGroupValueBuilder + let mut builder = PrimitiveGroupValueBuilder::::new(); + let builder_array = + Arc::new(Int64Array::from(vec![Some(0), Some(1)])) as ArrayRef; + builder.append_val(&builder_array, 0); + builder.append_val(&builder_array, 1); + + // Define input array + let input_array = Arc::new(Int64Array::from(vec![Some(0), Some(2)])) as ArrayRef; + + // Check + assert!(builder.equal_to(0, &input_array, 0)); + assert!(!builder.equal_to(1, &input_array, 1)); + } }