From 8bc00f8f96d12235ed30ba6e64af3566ebf853e7 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 7 Oct 2024 08:49:16 -0400 Subject: [PATCH 01/13] Implement special min/max accumulator for Strings: `MinMaxBytesAccumulator` --- .../groups_accumulator/accumulate.rs | 2 +- datafusion/functions-aggregate/src/min_max.rs | 192 ++++-- .../src/min_max/min_max_bytes.rs | 596 ++++++++++++++++++ datafusion/sqllogictest/test_files/aal.slt | 134 ++++ 4 files changed, 855 insertions(+), 69 deletions(-) create mode 100644 datafusion/functions-aggregate/src/min_max/min_max_bytes.rs create mode 100644 datafusion/sqllogictest/test_files/aal.slt diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs index a0475fe8e446..3efd348937ed 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs @@ -95,7 +95,7 @@ impl NullState { /// /// When value_fn is called it also sets /// - /// 1. `self.seen_values[group_index]` to true for all rows that had a non null vale + /// 1. `self.seen_values[group_index]` to true for all rows that had a non null value pub fn accumulate( &mut self, group_indices: &[usize], diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index e0b029f0909d..be85c9b9dacf 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -17,20 +17,7 @@ //! [`Max`] and [`MaxAccumulator`] accumulator for the `max` function //! [`Min`] and [`MinAccumulator`] accumulator for the `min` function -// distributed with this work for additional information -// regarding copyright ownership. 
The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. +mod min_max_bytes; use arrow::array::{ ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array, @@ -64,6 +51,7 @@ use arrow::datatypes::{ TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, }; +use crate::min_max::min_max_bytes::MinMaxBytesAccumulator; use datafusion_common::ScalarValue; use datafusion_expr::{ function::AccumulatorArgs, Accumulator, AggregateUDFImpl, Signature, Volatility, @@ -116,7 +104,7 @@ impl Default for Max { /// the specified [`ArrowPrimitiveType`]. /// /// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType -macro_rules! instantiate_max_accumulator { +macro_rules! instantiate_primitive_max_accumulator { ($DATA_TYPE:ident, $NATIVE:ident, $PRIMTYPE:ident) => {{ Ok(Box::new( PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new($DATA_TYPE, |cur, new| { @@ -135,7 +123,7 @@ macro_rules! instantiate_max_accumulator { /// /// /// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType -macro_rules! instantiate_min_accumulator { +macro_rules! 
instantiate_primitive_min_accumulator { ($DATA_TYPE:ident, $NATIVE:ident, $PRIMTYPE:ident) => {{ Ok(Box::new( PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new(&$DATA_TYPE, |cur, new| { @@ -243,6 +231,12 @@ impl AggregateUDFImpl for Max { | Time32(_) | Time64(_) | Timestamp(_, _) + | Utf8 + | LargeUtf8 + | Utf8View + | Binary + | LargeBinary + | BinaryView ) } @@ -254,58 +248,86 @@ impl AggregateUDFImpl for Max { use TimeUnit::*; let data_type = args.return_type; match data_type { - Int8 => instantiate_max_accumulator!(data_type, i8, Int8Type), - Int16 => instantiate_max_accumulator!(data_type, i16, Int16Type), - Int32 => instantiate_max_accumulator!(data_type, i32, Int32Type), - Int64 => instantiate_max_accumulator!(data_type, i64, Int64Type), - UInt8 => instantiate_max_accumulator!(data_type, u8, UInt8Type), - UInt16 => instantiate_max_accumulator!(data_type, u16, UInt16Type), - UInt32 => instantiate_max_accumulator!(data_type, u32, UInt32Type), - UInt64 => instantiate_max_accumulator!(data_type, u64, UInt64Type), + Int8 => instantiate_primitive_max_accumulator!(data_type, i8, Int8Type), + Int16 => instantiate_primitive_max_accumulator!(data_type, i16, Int16Type), + Int32 => instantiate_primitive_max_accumulator!(data_type, i32, Int32Type), + Int64 => instantiate_primitive_max_accumulator!(data_type, i64, Int64Type), + UInt8 => instantiate_primitive_max_accumulator!(data_type, u8, UInt8Type), + UInt16 => instantiate_primitive_max_accumulator!(data_type, u16, UInt16Type), + UInt32 => instantiate_primitive_max_accumulator!(data_type, u32, UInt32Type), + UInt64 => instantiate_primitive_max_accumulator!(data_type, u64, UInt64Type), Float16 => { - instantiate_max_accumulator!(data_type, f16, Float16Type) + instantiate_primitive_max_accumulator!(data_type, f16, Float16Type) } Float32 => { - instantiate_max_accumulator!(data_type, f32, Float32Type) + instantiate_primitive_max_accumulator!(data_type, f32, Float32Type) } Float64 => { - 
instantiate_max_accumulator!(data_type, f64, Float64Type) + instantiate_primitive_max_accumulator!(data_type, f64, Float64Type) } - Date32 => instantiate_max_accumulator!(data_type, i32, Date32Type), - Date64 => instantiate_max_accumulator!(data_type, i64, Date64Type), + Date32 => instantiate_primitive_max_accumulator!(data_type, i32, Date32Type), + Date64 => instantiate_primitive_max_accumulator!(data_type, i64, Date64Type), Time32(Second) => { - instantiate_max_accumulator!(data_type, i32, Time32SecondType) + instantiate_primitive_max_accumulator!(data_type, i32, Time32SecondType) } Time32(Millisecond) => { - instantiate_max_accumulator!(data_type, i32, Time32MillisecondType) + instantiate_primitive_max_accumulator!( + data_type, + i32, + Time32MillisecondType + ) } Time64(Microsecond) => { - instantiate_max_accumulator!(data_type, i64, Time64MicrosecondType) + instantiate_primitive_max_accumulator!( + data_type, + i64, + Time64MicrosecondType + ) } Time64(Nanosecond) => { - instantiate_max_accumulator!(data_type, i64, Time64NanosecondType) + instantiate_primitive_max_accumulator!( + data_type, + i64, + Time64NanosecondType + ) } Timestamp(Second, _) => { - instantiate_max_accumulator!(data_type, i64, TimestampSecondType) + instantiate_primitive_max_accumulator!( + data_type, + i64, + TimestampSecondType + ) } Timestamp(Millisecond, _) => { - instantiate_max_accumulator!(data_type, i64, TimestampMillisecondType) + instantiate_primitive_max_accumulator!( + data_type, + i64, + TimestampMillisecondType + ) } Timestamp(Microsecond, _) => { - instantiate_max_accumulator!(data_type, i64, TimestampMicrosecondType) + instantiate_primitive_max_accumulator!( + data_type, + i64, + TimestampMicrosecondType + ) } Timestamp(Nanosecond, _) => { - instantiate_max_accumulator!(data_type, i64, TimestampNanosecondType) + instantiate_primitive_max_accumulator!( + data_type, + i64, + TimestampNanosecondType + ) } Decimal128(_, _) => { - instantiate_max_accumulator!(data_type, i128, 
Decimal128Type) + instantiate_primitive_max_accumulator!(data_type, i128, Decimal128Type) } Decimal256(_, _) => { - instantiate_max_accumulator!(data_type, i256, Decimal256Type) + instantiate_primitive_max_accumulator!(data_type, i256, Decimal256Type) + } + Utf8 | LargeUtf8 | Utf8View | Binary | LargeBinary | BinaryView => { + Ok(Box::new(MinMaxBytesAccumulator::new_max(data_type.clone()))) } - - // It would be nice to have a fast implementation for Strings as well - // https://github.com/apache/datafusion/issues/6906 // This is only reached if groups_accumulator_supported is out of sync _ => internal_err!("GroupsAccumulator not supported for max({})", data_type), @@ -1040,6 +1062,12 @@ impl AggregateUDFImpl for Min { | Time32(_) | Time64(_) | Timestamp(_, _) + | Utf8 + | LargeUtf8 + | Utf8View + | Binary + | LargeBinary + | BinaryView ) } @@ -1051,58 +1079,86 @@ impl AggregateUDFImpl for Min { use TimeUnit::*; let data_type = args.return_type; match data_type { - Int8 => instantiate_min_accumulator!(data_type, i8, Int8Type), - Int16 => instantiate_min_accumulator!(data_type, i16, Int16Type), - Int32 => instantiate_min_accumulator!(data_type, i32, Int32Type), - Int64 => instantiate_min_accumulator!(data_type, i64, Int64Type), - UInt8 => instantiate_min_accumulator!(data_type, u8, UInt8Type), - UInt16 => instantiate_min_accumulator!(data_type, u16, UInt16Type), - UInt32 => instantiate_min_accumulator!(data_type, u32, UInt32Type), - UInt64 => instantiate_min_accumulator!(data_type, u64, UInt64Type), + Int8 => instantiate_primitive_min_accumulator!(data_type, i8, Int8Type), + Int16 => instantiate_primitive_min_accumulator!(data_type, i16, Int16Type), + Int32 => instantiate_primitive_min_accumulator!(data_type, i32, Int32Type), + Int64 => instantiate_primitive_min_accumulator!(data_type, i64, Int64Type), + UInt8 => instantiate_primitive_min_accumulator!(data_type, u8, UInt8Type), + UInt16 => instantiate_primitive_min_accumulator!(data_type, u16, UInt16Type), + UInt32 
=> instantiate_primitive_min_accumulator!(data_type, u32, UInt32Type), + UInt64 => instantiate_primitive_min_accumulator!(data_type, u64, UInt64Type), Float16 => { - instantiate_min_accumulator!(data_type, f16, Float16Type) + instantiate_primitive_min_accumulator!(data_type, f16, Float16Type) } Float32 => { - instantiate_min_accumulator!(data_type, f32, Float32Type) + instantiate_primitive_min_accumulator!(data_type, f32, Float32Type) } Float64 => { - instantiate_min_accumulator!(data_type, f64, Float64Type) + instantiate_primitive_min_accumulator!(data_type, f64, Float64Type) } - Date32 => instantiate_min_accumulator!(data_type, i32, Date32Type), - Date64 => instantiate_min_accumulator!(data_type, i64, Date64Type), + Date32 => instantiate_primitive_min_accumulator!(data_type, i32, Date32Type), + Date64 => instantiate_primitive_min_accumulator!(data_type, i64, Date64Type), Time32(Second) => { - instantiate_min_accumulator!(data_type, i32, Time32SecondType) + instantiate_primitive_min_accumulator!(data_type, i32, Time32SecondType) } Time32(Millisecond) => { - instantiate_min_accumulator!(data_type, i32, Time32MillisecondType) + instantiate_primitive_min_accumulator!( + data_type, + i32, + Time32MillisecondType + ) } Time64(Microsecond) => { - instantiate_min_accumulator!(data_type, i64, Time64MicrosecondType) + instantiate_primitive_min_accumulator!( + data_type, + i64, + Time64MicrosecondType + ) } Time64(Nanosecond) => { - instantiate_min_accumulator!(data_type, i64, Time64NanosecondType) + instantiate_primitive_min_accumulator!( + data_type, + i64, + Time64NanosecondType + ) } Timestamp(Second, _) => { - instantiate_min_accumulator!(data_type, i64, TimestampSecondType) + instantiate_primitive_min_accumulator!( + data_type, + i64, + TimestampSecondType + ) } Timestamp(Millisecond, _) => { - instantiate_min_accumulator!(data_type, i64, TimestampMillisecondType) + instantiate_primitive_min_accumulator!( + data_type, + i64, + TimestampMillisecondType + ) } 
Timestamp(Microsecond, _) => { - instantiate_min_accumulator!(data_type, i64, TimestampMicrosecondType) + instantiate_primitive_min_accumulator!( + data_type, + i64, + TimestampMicrosecondType + ) } Timestamp(Nanosecond, _) => { - instantiate_min_accumulator!(data_type, i64, TimestampNanosecondType) + instantiate_primitive_min_accumulator!( + data_type, + i64, + TimestampNanosecondType + ) } Decimal128(_, _) => { - instantiate_min_accumulator!(data_type, i128, Decimal128Type) + instantiate_primitive_min_accumulator!(data_type, i128, Decimal128Type) } Decimal256(_, _) => { - instantiate_min_accumulator!(data_type, i256, Decimal256Type) + instantiate_primitive_min_accumulator!(data_type, i256, Decimal256Type) + } + Utf8 | LargeUtf8 | Utf8View | Binary | LargeBinary | BinaryView => { + Ok(Box::new(MinMaxBytesAccumulator::new_min(data_type.clone()))) } - - // It would be nice to have a fast implementation for Strings as well - // https://github.com/apache/datafusion/issues/6906 // This is only reached if groups_accumulator_supported is out of sync _ => internal_err!("GroupsAccumulator not supported for min({})", data_type), diff --git a/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs b/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs new file mode 100644 index 000000000000..0f01af9f692b --- /dev/null +++ b/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs @@ -0,0 +1,596 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::{
+    Array, ArrayRef, AsArray, BinaryArray, BinaryBuilder, BinaryViewArray,
+    BinaryViewBuilder, BooleanArray, LargeBinaryArray, LargeBinaryBuilder,
+    LargeStringArray, LargeStringBuilder, StringArray, StringBuilder, StringViewArray,
+    StringViewBuilder,
+};
+use arrow_schema::DataType;
+use datafusion_common::{internal_err, Result};
+use datafusion_expr::{EmitTo, GroupsAccumulator};
+use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filtered_null_mask;
+use std::sync::Arc;
+
+/// Implements Min/Max accumulators for "bytes" types ([`StringArray`], [`BinaryArray`], etc)
+///
+/// This implementation dispatches to the appropriate specialized code in
+/// [`MinMaxBytesState`] based on data type and comparison function
+pub(crate) struct MinMaxBytesAccumulator {
+    /// Per-group min/max values and the data type they were built for
+    inner: MinMaxBytesState,
+    /// if true, is `MIN` otherwise is `MAX`
+    is_min: bool,
+}
+
+impl MinMaxBytesAccumulator {
+    /// Create a new accumulator for computing min(string)
+    pub fn new_min(data_type: DataType) -> Self {
+        Self {
+            inner: MinMaxBytesState::new(data_type),
+            is_min: true,
+        }
+    }
+
+    /// Create a new accumulator for computing max(string)
+    pub fn new_max(data_type: DataType) -> Self {
+        Self {
+            inner: MinMaxBytesState::new(data_type),
+            is_min: false,
+        }
+    }
+}
+
+impl GroupsAccumulator for MinMaxBytesAccumulator {
+    /// Folds `values[0]` into the per-group min/max.
+    ///
+    /// Panics if the array length differs from `group_indices.len()` or if its
+    /// data type is not the one this accumulator was created with.
+    fn update_batch(
+        &mut self,
+        values: &[ArrayRef],
+        group_indices: &[usize],
+        opt_filter: Option<&BooleanArray>,
+        total_num_groups: usize,
+    ) -> Result<()> {
+        let array = &values[0];
+        assert_eq!(array.len(), group_indices.len());
+        // Ensure the input matches the output
+        assert_eq!(array.data_type(), &self.inner.data_type);
+
+        // dispatch to appropriate kernel / specialized implementation
+        fn string_min(a: &[u8], b: &[u8]) -> bool {
+            // safety: only called from update_batch, which ensures a and b come
+            // from a string array, and thus are valid utf8
+            unsafe {
+                let a = std::str::from_utf8_unchecked(a);
+                let b = std::str::from_utf8_unchecked(b);
+                a < b
+            }
+        }
+        fn string_max(a: &[u8], b: &[u8]) -> bool {
+            // safety: only called from update_batch, which ensures a and b come
+            // from a string array, and thus are valid utf8
+            unsafe {
+                let a = std::str::from_utf8_unchecked(a);
+                let b = std::str::from_utf8_unchecked(b);
+                a > b
+            }
+        }
+        fn binary_min(a: &[u8], b: &[u8]) -> bool {
+            a < b
+        }
+
+        fn binary_max(a: &[u8], b: &[u8]) -> bool {
+            a > b
+        }
+
+        // View each optional string as its underlying bytes (no copy)
+        fn str_to_bytes<'a>(
+            it: impl Iterator<Item = Option<&'a str>>,
+        ) -> impl Iterator<Item = Option<&'a [u8]>> {
+            it.map(|s| s.map(|s| s.as_bytes()))
+        }
+
+        match (self.is_min, &self.inner.data_type) {
+            // String Min
+            (true, &DataType::Utf8) => self.inner.update_batch(
+                str_to_bytes(array.as_string::<i32>().iter()),
+                group_indices,
+                opt_filter,
+                total_num_groups,
+                string_min,
+            ),
+            (true, &DataType::LargeUtf8) => self.inner.update_batch(
+                str_to_bytes(array.as_string::<i64>().iter()),
+                group_indices,
+                opt_filter,
+                total_num_groups,
+                string_min,
+            ),
+            (true, &DataType::Utf8View) => self.inner.update_batch(
+                str_to_bytes(array.as_string_view().iter()),
+                group_indices,
+                opt_filter,
+                total_num_groups,
+                string_min,
+            ),
+
+            // String Max
+            (false, &DataType::Utf8) => self.inner.update_batch(
+                str_to_bytes(array.as_string::<i32>().iter()),
+                group_indices,
+                opt_filter,
+                total_num_groups,
+                string_max,
+            ),
+            (false, &DataType::LargeUtf8) => self.inner.update_batch(
+                str_to_bytes(array.as_string::<i64>().iter()),
+                group_indices,
+                opt_filter,
+                total_num_groups,
+                string_max,
+            ),
+            (false, &DataType::Utf8View) => self.inner.update_batch(
+                str_to_bytes(array.as_string_view().iter()),
+                group_indices,
+                opt_filter,
+                total_num_groups,
+                string_max,
+            ),
+
+            // Binary Min
+            (true, &DataType::Binary) => self.inner.update_batch(
+                array.as_binary::<i32>().iter(),
+                group_indices,
+                opt_filter,
+                total_num_groups,
+                binary_min,
+            ),
+            (true, &DataType::LargeBinary) => self.inner.update_batch(
+                array.as_binary::<i64>().iter(),
+                group_indices,
+                opt_filter,
+                total_num_groups,
+                binary_min,
+            ),
+            (true, &DataType::BinaryView) => self.inner.update_batch(
+                array.as_binary_view().iter(),
+                group_indices,
+                opt_filter,
+                total_num_groups,
+                binary_min,
+            ),
+
+            // Binary Max
+            (false, &DataType::Binary) => self.inner.update_batch(
+                array.as_binary::<i32>().iter(),
+                group_indices,
+                opt_filter,
+                total_num_groups,
+                binary_max,
+            ),
+            (false, &DataType::LargeBinary) => self.inner.update_batch(
+                array.as_binary::<i64>().iter(),
+                group_indices,
+                opt_filter,
+                total_num_groups,
+                binary_max,
+            ),
+            (false, &DataType::BinaryView) => self.inner.update_batch(
+                array.as_binary_view().iter(),
+                group_indices,
+                opt_filter,
+                total_num_groups,
+                binary_max,
+            ),
+
+            _ => internal_err!(
+                "Unexpected combination for MinMaxBytesAccumulator: ({:?}, {:?})",
+                self.is_min,
+                self.inner.data_type
+            ),
+        }
+    }
+
+    /// Emits the requested groups as an array of the accumulator's data type;
+    /// groups that never saw a non-null value become nulls.
+    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
+        let (data_capacity, min_maxes) = self.inner.emit_to(emit_to);
+
+        // Convert the Vec of bytes to a vec of Strings (no cost)
+        fn bytes_to_str(
+            min_maxes: Vec<Option<Vec<u8>>>,
+        ) -> impl Iterator<Item = Option<String>> {
+            min_maxes.into_iter().map(|opt| {
+                opt.map(|bytes| {
+                    // Safety: only called on data added from update_batch which ensures
+                    // the input type matched the output type
+                    unsafe { String::from_utf8_unchecked(bytes) }
+                })
+            })
+        }
+
+        let result: ArrayRef = match self.inner.data_type {
+            DataType::Utf8 => {
+                let mut builder =
+                    StringBuilder::with_capacity(min_maxes.len(), data_capacity);
+                for opt in bytes_to_str(min_maxes) {
+                    match opt {
+                        None => builder.append_null(),
+                        Some(s) => builder.append_value(s.as_str()),
+                    }
+                }
+                Arc::new(builder.finish())
+            }
+            DataType::LargeUtf8 => {
+                let mut builder =
+                    LargeStringBuilder::with_capacity(min_maxes.len(), data_capacity);
+                for opt in bytes_to_str(min_maxes) {
+                    match opt {
+                        None => builder.append_null(),
+                        Some(s) => builder.append_value(s.as_str()),
+                    }
+                }
+                Arc::new(builder.finish())
+            }
+            DataType::Utf8View => {
+                let block_size = capacity_to_view_block_size(data_capacity);
+
+                let mut builder = StringViewBuilder::with_capacity(min_maxes.len())
+                    .with_fixed_block_size(block_size);
+                for opt in bytes_to_str(min_maxes) {
+                    match opt {
+                        None => builder.append_null(),
+                        Some(s) => builder.append_value(s.as_str()),
+                    }
+                }
+                Arc::new(builder.finish())
+            }
+            DataType::Binary => {
+                let mut builder =
+                    BinaryBuilder::with_capacity(min_maxes.len(), data_capacity);
+                for opt in min_maxes {
+                    match opt {
+                        None => builder.append_null(),
+                        Some(s) => builder.append_value(s.as_ref() as &[u8]),
+                    }
+                }
+                Arc::new(builder.finish())
+            }
+            DataType::LargeBinary => {
+                let mut builder =
+                    LargeBinaryBuilder::with_capacity(min_maxes.len(), data_capacity);
+                for opt in min_maxes {
+                    match opt {
+                        None => builder.append_null(),
+                        Some(s) => builder.append_value(s.as_ref() as &[u8]),
+                    }
+                }
+                Arc::new(builder.finish())
+            }
+            DataType::BinaryView => {
+                let block_size = capacity_to_view_block_size(data_capacity);
+
+                let mut builder = BinaryViewBuilder::with_capacity(min_maxes.len())
+                    .with_fixed_block_size(block_size);
+                for opt in min_maxes {
+                    match opt {
+                        None => builder.append_null(),
+                        Some(s) => builder.append_value(s.as_ref() as &[u8]),
+                    }
+                }
+                Arc::new(builder.finish())
+            }
+            _ => {
+                return internal_err!(
+                    "Unexpected data type for MinMaxBytesAccumulator: {:?}",
+                    self.inner.data_type
+                );
+            }
+        };
+
+        assert_eq!(&self.inner.data_type, result.data_type());
+        Ok(result)
+    }
+
+    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
+        // min/max are their own states (no transition needed)
+        self.evaluate(emit_to).map(|arr| vec![arr])
+    }
+
+    fn merge_batch(
+        &mut self,
+        values: &[ArrayRef],
+        group_indices: &[usize],
+        opt_filter: Option<&BooleanArray>,
+        total_num_groups: usize,
+    ) -> Result<()> {
+        // min/max are their own states (no transition needed)
+        self.update_batch(values, group_indices, opt_filter, total_num_groups)
+    }
+
+    fn convert_to_state(
+        &self,
+        values: &[ArrayRef],
+        opt_filter: Option<&BooleanArray>,
+    ) -> Result<Vec<ArrayRef>> {
+        // Min/max do not change the values as they are their own states
+        // apply the filter by combining with the null mask, if any
+        let input = &values[0];
+        let nulls = filtered_null_mask(opt_filter, input);
+        if let Some(nulls) = nulls.as_ref() {
+            assert_eq!(nulls.len(), input.len());
+        }
+
+        let output: ArrayRef = match input.data_type() {
+            // TODO it would be nice to have safe apis in arrow-rs to update the null buffers in the arrays
+            DataType::Utf8 => {
+                let input = input.as_string::<i32>();
+                // safety: values / offsets came from a valid string array, so are valid utf8
+                // and we checked nulls has the same length as values
+                unsafe {
+                    Arc::new(StringArray::new_unchecked(
+                        input.offsets().clone(),
+                        input.values().clone(),
+                        nulls,
+                    ))
+                }
+            }
+            DataType::LargeUtf8 => {
+                let input = input.as_string::<i64>();
+                // safety: values / offsets came from a valid string array, so are valid utf8
+                // and we checked nulls has the same length as values
+                unsafe {
+                    Arc::new(LargeStringArray::new_unchecked(
+                        input.offsets().clone(),
+                        input.values().clone(),
+                        nulls,
+                    ))
+                }
+            }
+            DataType::Utf8View => {
+                let input = input.as_string_view();
+                // safety: values / views came from a valid string view array, so are valid utf8
+                // and we checked nulls has the same length as values
+                unsafe {
+                    Arc::new(StringViewArray::new_unchecked(
+                        input.views().clone(),
+                        input.data_buffers().to_vec(),
+                        nulls,
+                    ))
+                }
+            }
+
+            DataType::Binary => {
+                let input = input.as_binary::<i32>();
+                // safety: values / offsets came from a valid binary array
+                // and we checked nulls has the same length as values
+                unsafe {
+                    Arc::new(BinaryArray::new_unchecked(
+                        input.offsets().clone(),
+                        input.values().clone(),
+                        nulls,
+                    ))
+                }
+            }
+            DataType::LargeBinary => {
+                let input = input.as_binary::<i64>();
+                // safety: values / offsets came from a valid large binary array
+                // and we checked nulls has the same length as values
+                unsafe {
+                    Arc::new(LargeBinaryArray::new_unchecked(
+                        input.offsets().clone(),
+                        input.values().clone(),
+                        nulls,
+                    ))
+                }
+            }
+            DataType::BinaryView => {
+                let input = input.as_binary_view();
+                // safety: values / views came from a valid binary view array
+                // and we checked nulls has the same length as values
+                unsafe {
+                    Arc::new(BinaryViewArray::new_unchecked(
+                        input.views().clone(),
+                        input.data_buffers().to_vec(),
+                        nulls,
+                    ))
+                }
+            }
+            _ => {
+                // report the type we actually matched on (the input's), not the
+                // accumulator's configured type
+                return internal_err!(
+                    "Unexpected data type for convert_to_state in MinMaxBytesAccumulator: {:?}",
+                    input.data_type()
+                );
+            }
+        };
+        assert_eq!(input.len(), output.len());
+        assert_eq!(input.data_type(), output.data_type());
+        Ok(vec![output])
+    }
+
+    fn supports_convert_to_state(&self) -> bool {
+        true
+    }
+
+    fn size(&self) -> usize {
+        self.inner.size()
+    }
+}
+
+/// Returns the block size in (contiguous buffer size) to use
+/// for a given data capacity (total string length)
+///
+/// This is a heuristic to avoid allocating too many small buffers.
+/// The result is always in `1..=2MiB`.
+fn capacity_to_view_block_size(data_capacity: usize) -> u32 {
+    let max_block_size = 2 * 1024 * 1024;
+    // The capacity is 0 when every emitted group is null or an empty string;
+    // a zero-sized block is not a usable fixed block size for the view builders.
+    if data_capacity == 0 {
+        return 1;
+    }
+    if let Ok(block_size) = u32::try_from(data_capacity) {
+        block_size.min(max_block_size)
+    } else {
+        max_block_size
+    }
+}
+
+/// Stores internal Min/Max state for "bytes" types ([`StringArray`], [`BinaryArray`], etc)
+///
+/// This implementation is general and stores the minimum/maximum for each
+/// groups in an individual byte array, which balances allocations and memory
+/// fragmentation (aka garbage).
+///
+/// Note that for StringViewArray and BinaryViewArray, there are potentially
+/// more efficient implementations by managing a string data buffer directly,
+/// but then garbage collection, memory management, and final array
+/// construction becomes more complex.
+///
+/// See discussion on <https://github.com/apache/datafusion/issues/6906>
+#[derive(Debug)]
+struct MinMaxBytesState {
+    /// The minimum/maximum value for each group
+    min_max: Vec<Option<Vec<u8>>>,
+    // todo use null state
+    // /// Have we seen any non-null values yet?
+    // null_state: NullState,
+    /// The data type of the array
+    data_type: DataType,
+    /// The total bytes of the string data (for pre-allocating the final array,
+    /// and tracking memory usage)
+    total_data_bytes: usize,
+}
+
+#[derive(Debug, Clone, Copy)]
+enum MinMaxLocation<'a> {
+    /// the min/max value is stored in the existing `min_max` array
+    ExistingMinMax,
+    /// the min/max value is this slice borrowed from the input array
+    Input(&'a [u8]),
+}
+
+/// Implement the MinMaxBytesAccumulator with a comparison function
+/// for comparing strings
+impl MinMaxBytesState {
+    /// Create a new MinMaxBytesAccumulator
+    ///
+    /// # Arguments:
+    /// * `data_type`: The data type of the arrays that will be passed to this accumulator
+    fn new(data_type: DataType) -> Self {
+        Self {
+            min_max: vec![],
+            data_type,
+            total_data_bytes: 0,
+        }
+    }
+
+    /// Set the specified group to the given value, updating memory usage appropriately
+    fn set_value(&mut self, group_index: usize, new_val: &[u8]) {
+        match self.min_max[group_index].as_mut() {
+            None => {
+                self.min_max[group_index] = Some(new_val.to_vec());
+                self.total_data_bytes += new_val.len();
+            }
+            Some(existing_val) => {
+                // Copy data over to avoid re-allocating
+                self.total_data_bytes -= existing_val.len();
+                self.total_data_bytes += new_val.len();
+                existing_val.clear();
+                existing_val.extend_from_slice(new_val);
+            }
+        }
+    }
+
+    /// Updates the min/max values for the given string values
+    ///
+    /// `iter` yields one optional value per input row, and `group_indices[row]`
+    /// names the group that row belongs to (each must be `< total_num_groups`).
+    /// Rows that are null, or that `opt_filter` does not select (false or null),
+    /// are ignored.
+    ///
+    /// `cmp` is the comparison function to use, called like `cmp(new_val, existing_val)`
+    /// returns true if the `new_val` should replace `existing_val`
+    fn update_batch<'a, F, I>(
+        &mut self,
+        iter: I,
+        group_indices: &[usize],
+        opt_filter: Option<&BooleanArray>,
+        total_num_groups: usize,
+        mut cmp: F,
+    ) -> Result<()>
+    where
+        F: FnMut(&[u8], &[u8]) -> bool + Send + Sync,
+        I: IntoIterator<Item = Option<&'a [u8]>>,
+    {
+        self.min_max.resize(total_num_groups, None);
+        // One slot per *group* (not per input row): it is indexed by group index
+        // below and walked as group indexes when the results are written back.
+        let mut locations = vec![MinMaxLocation::ExistingMinMax; total_num_groups];
+
+        // Figure out the new min value for each group
+        for (row, (new_val, group_index)) in
+            iter.into_iter().zip(group_indices.iter()).enumerate()
+        {
+            let group_index = *group_index;
+
+            // skip rows the filter does not select; a null filter entry counts
+            // as "not selected"
+            if let Some(filter) = opt_filter {
+                if !(filter.is_valid(row) && filter.value(row)) {
+                    continue;
+                }
+            }
+
+            // ignore null inputs
+            let Some(new_val) = new_val else {
+                continue;
+            };
+
+            // go through contortions to avoid copying strings too many times
+            let existing_val = match locations[group_index] {
+                // previous input value was the min/max, so compare it
+                MinMaxLocation::Input(existing_val) => existing_val,
+                MinMaxLocation::ExistingMinMax => {
+                    let Some(existing_val) = self.min_max[group_index].as_ref() else {
+                        // no existing min/max, so this is the new min/max
+                        locations[group_index] = MinMaxLocation::Input(new_val);
+                        continue;
+                    };
+                    existing_val.as_slice()
+                }
+            };
+
+            // Actually compare the new value to the existing value, replacing if necessary
+            if cmp(new_val, existing_val) {
+                locations[group_index] = MinMaxLocation::Input(new_val);
+            }
+        }
+
+        // Update self.min_max with the new min values we found in the input
+        for (group_index, location) in locations.into_iter().enumerate() {
+            if let MinMaxLocation::Input(new_val) = location {
+                self.set_value(group_index, new_val);
+            }
+        }
+        Ok(())
+    }
+
+    /// Emits the specified min_max values
+    ///
+    /// Returns (data_capacity, min_maxes), updating the current value of total_data_bytes
+    ///
+    /// - `data_capacity`: the total length of all strings and their contents,
+    /// - `min_maxes`: the actual min/max values for each group
+    fn emit_to(&mut self, emit_to: EmitTo) -> (usize, Vec<Option<Vec<u8>>>) {
+        match emit_to {
+            EmitTo::All => {
+                (
+                    std::mem::take(&mut self.total_data_bytes), // reset total bytes and min_max
+                    std::mem::take(&mut self.min_max),
+                )
+            }
+            EmitTo::First(n) => {
+                let min_maxes: Vec<_> = self.min_max.drain(..n).collect();
+                // size of the values being *emitted*; the groups that stay behind
+                // keep contributing to `total_data_bytes`
+                let data_capacity: usize = min_maxes
+                    .iter()
+                    .map(|opt| opt.as_ref().map(|s| s.len()).unwrap_or(0))
+                    .sum();
+                self.total_data_bytes -= data_capacity;
+                (data_capacity, min_maxes)
+            }
+        }
+    }
+
+    /// Approximate memory used: the stored value bytes plus one
+    /// `Option<Vec<u8>>` slot per group
+    fn size(&self) -> usize {
+        self.total_data_bytes
+            + self.min_max.len() * std::mem::size_of::<Option<Vec<u8>>>()
+    }
+}
diff --git a/datafusion/sqllogictest/test_files/aal.slt b/datafusion/sqllogictest/test_files/aal.slt
new file mode 100644
index 000000000000..23a69c90aa61
--- /dev/null
+++ b/datafusion/sqllogictest/test_files/aal.slt
@@ -0,0 +1,134 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+
+#   http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+ +# min_utf8, max_utf8 +statement ok +CREATE TABLE strings (value TEXT, id int); + +statement ok +INSERT INTO strings VALUES + ('d', 1), + ('a', 3), + ('c', 1), + ('b', 1), + ('d', 1), + ('z', 2), + ('c', 1), + ('a', 2); + +query IT +SELECT id, MIN(value) FROM strings GROUP BY id ORDER BY id; +---- +1 b +2 a +3 a + +query IT +SELECT id, MAX(value) FROM strings GROUP BY id ORDER BY id; +---- +1 d +2 z +3 a + + + +#query TT +#SELECT MIN(value), MAX(value) FROM strings +#---- +#a z + + +query ITT +SELECT id, MIN(value), MAX(value) FROM strings GROUP BY id ORDER BY id; +---- +1 b d +2 a z +3 a a + + + + + +# min_utf8view, max_utf8view +statement ok +CREATE VIEW string_views AS SELECT id, arrow_cast(value, 'Utf8View') as value FROM strings; + + +query IT +SELECT id, MIN(value) FROM string_views GROUP BY id ORDER BY id; +---- +1 b +2 a +3 a + +query IT +SELECT id, MAX(value) FROM string_views GROUP BY id ORDER BY id; +---- +1 d +2 z +3 a + + + +#query TT +#SELECT MIN(value), MAX(value) FROM string_views +#---- +#a z + + +query ITT +SELECT id, MIN(value), MAX(value) FROM string_views GROUP BY id ORDER BY id; +---- +1 b d +2 a z +3 a a + + + +# min_utf8view, max_utf8view +statement ok +CREATE VIEW binary_views AS SELECT id, arrow_cast(value, 'BinaryView') as value FROM strings; + + +query I? +SELECT id, MIN(value) FROM binary_views GROUP BY id ORDER BY id; +---- +1 62 +2 61 +3 61 + +query I? +SELECT id, MAX(value) FROM binary_views GROUP BY id ORDER BY id; +---- +1 64 +2 7a +3 61 + + + +#query TT +#SELECT MIN(value), MAX(value) FROM binary_views +#---- +#a z + + +query I?? 
+SELECT id, MIN(value), MAX(value) FROM binary_views GROUP BY id ORDER BY id; +---- +1 62 64 +2 61 7a +3 61 61 From 19b329729b3597a151c66f3237d211cae31777b0 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 8 Oct 2024 15:37:28 -0400 Subject: [PATCH 02/13] fix bug --- datafusion/functions-aggregate/src/min_max/min_max_bytes.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs b/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs index 0f01af9f692b..185661216727 100644 --- a/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs +++ b/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs @@ -521,7 +521,7 @@ impl MinMaxBytesState { I: IntoIterator>, { self.min_max.resize(total_num_groups, None); - let mut locations = vec![MinMaxLocation::ExistingMinMax; group_indices.len()]; + let mut locations = vec![MinMaxLocation::ExistingMinMax; total_num_groups]; assert!(opt_filter.is_none(), "Filtering not yet implemented"); From 4b3e625ff2472a9913ab6da6efaec5ca2dd93ac8 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 8 Oct 2024 15:49:30 -0400 Subject: [PATCH 03/13] fix msrv --- datafusion/functions-aggregate/src/min_max/min_max_bytes.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs b/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs index 185661216727..64d0c7b09b6e 100644 --- a/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs +++ b/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs @@ -591,6 +591,6 @@ impl MinMaxBytesState { } fn size(&self) -> usize { - self.total_data_bytes + self.min_max.len() * size_of::>() + self.total_data_bytes + self.min_max.len() * std::mem::size_of::>() } } From 64e986192b41ffe40a63793be9b05c2f7c37985a Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 8 Oct 2024 16:24:19 -0400 Subject: [PATCH 04/13] move code, handle filters --- 
.../src/aggregate/groups_accumulator/nulls.rs | 115 +++++++++++++++++- .../src/min_max/min_max_bytes.rs | 102 ++-------------- 2 files changed, 120 insertions(+), 97 deletions(-) diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs index 25212f7f0f5f..6f80ff9b603d 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs @@ -15,13 +15,22 @@ // specific language governing permissions and limitations // under the License. -//! [`set_nulls`], and [`filtered_null_mask`], utilities for working with nulls +//! [`set_nulls`], other utilities for working with nulls -use arrow::array::{Array, ArrowNumericType, BooleanArray, PrimitiveArray}; +use arrow::array::{ + Array, ArrayRef, ArrowNumericType, AsArray, BinaryArray, BinaryViewArray, + BooleanArray, LargeBinaryArray, LargeStringArray, PrimitiveArray, StringArray, + StringViewArray, +}; use arrow::buffer::NullBuffer; +use arrow::datatypes::DataType; +use datafusion_common::{not_impl_err, Result}; +use std::sync::Arc; /// Sets the validity mask for a `PrimitiveArray` to `nulls` /// replacing any existing null mask +/// +/// See [`set_nulls_dyn`] for a version that works with `Array` pub fn set_nulls( array: PrimitiveArray, nulls: Option, @@ -91,3 +100,105 @@ pub fn filtered_null_mask( let opt_filter = opt_filter.and_then(filter_to_nulls); NullBuffer::union(opt_filter.as_ref(), input.nulls()) } + +/// Applies optional filter to input, returning a new array of the same type +/// with the same data, but with any values that were filtered out set to null +pub fn apply_filter_as_nulls( + input: &dyn Array, + opt_filter: Option<&BooleanArray>, +) -> Result { + let nulls = filtered_null_mask(opt_filter, input); + replace_nulls(input, nulls) +} + +/// Replaces the nulls in the input array with 
the given `NullBuffer` +/// +/// Can replace when upstreamed in arrow-rs: https://github.com/apache/arrow-rs/issues/6528 +pub fn replace_nulls(input: &dyn Array, nulls: Option) -> Result { + if let Some(nulls) = nulls.as_ref() { + assert_eq!(nulls.len(), input.len()); + } + + let output: ArrayRef = match input.data_type() { + DataType::Utf8 => { + let input = input.as_string::(); + // safety: values / offsets came from a valid string array, so are valid utf8 + // and we checked nulls has the same length as values + unsafe { + Arc::new(StringArray::new_unchecked( + input.offsets().clone(), + input.values().clone(), + nulls, + )) + } + } + DataType::LargeUtf8 => { + let input = input.as_string::(); + // safety: values / offsets came from a valid string array, so are valid utf8 + // and we checked nulls has the same length as values + unsafe { + Arc::new(LargeStringArray::new_unchecked( + input.offsets().clone(), + input.values().clone(), + nulls, + )) + } + } + DataType::Utf8View => { + let input = input.as_string_view(); + // safety: values / views came from a valid string view array, so are valid utf8 + // and we checked nulls has the same length as values + unsafe { + Arc::new(StringViewArray::new_unchecked( + input.views().clone(), + input.data_buffers().to_vec(), + nulls, + )) + } + } + + DataType::Binary => { + let input = input.as_binary::(); + // safety: values / offsets came from a valid binary array + // and we checked nulls has the same length as values + unsafe { + Arc::new(BinaryArray::new_unchecked( + input.offsets().clone(), + input.values().clone(), + nulls, + )) + } + } + DataType::LargeBinary => { + let input = input.as_binary::(); + // safety: values / offsets came from a valid large binary array + // and we checked nulls has the same length as values + unsafe { + Arc::new(LargeBinaryArray::new_unchecked( + input.offsets().clone(), + input.values().clone(), + nulls, + )) + } + } + DataType::BinaryView => { + let input = input.as_binary_view(); + // 
safety: values / views came from a valid binary view array + // and we checked nulls has the same length as values + unsafe { + Arc::new(BinaryViewArray::new_unchecked( + input.views().clone(), + input.data_buffers().to_vec(), + nulls, + )) + } + } + _ => { + return not_impl_err!("Applying nulls {:?}", input.data_type()); + } + }; + assert_eq!(input.len(), output.len()); + assert_eq!(input.data_type(), output.data_type()); + + Ok(output) +} diff --git a/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs b/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs index 64d0c7b09b6e..67c2963850fb 100644 --- a/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs +++ b/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs @@ -15,15 +15,13 @@ // under the License. use arrow::array::{ - Array, ArrayRef, AsArray, BinaryArray, BinaryBuilder, BinaryViewArray, - BinaryViewBuilder, BooleanArray, LargeBinaryArray, LargeBinaryBuilder, - LargeStringArray, LargeStringBuilder, StringArray, StringBuilder, StringViewArray, - StringViewBuilder, + Array, ArrayRef, AsArray, BinaryBuilder, BinaryViewBuilder, BooleanArray, + LargeBinaryBuilder, LargeStringBuilder, StringBuilder, StringViewBuilder, }; use arrow_schema::DataType; use datafusion_common::{internal_err, Result}; use datafusion_expr::{EmitTo, GroupsAccumulator}; -use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filtered_null_mask; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::apply_filter_as_nulls; use std::sync::Arc; /// Implements Min/Max accumulators for "bytes" types ([`StringArray`], [`BinaryArray`], etc) @@ -67,6 +65,9 @@ impl GroupsAccumulator for MinMaxBytesAccumulator { // Ensure the input matches the output assert_eq!(array.data_type(), &self.inner.data_type); + // apply filter if needed + let array = apply_filter_as_nulls(array, opt_filter)?; + // dispatch to appropriate kernel / specialized implementation fn string_min(a: 
&[u8], b: &[u8]) -> bool { // safety: only called from update_batch, which ensures a and b come @@ -323,96 +324,7 @@ impl GroupsAccumulator for MinMaxBytesAccumulator { ) -> Result> { // Min/max do not change the values as they are their own states // apply the filter by combining with the null mask, if any - let input = &values[0]; - let nulls = filtered_null_mask(opt_filter, input); - if let Some(nulls) = nulls.as_ref() { - assert_eq!(nulls.len(), input.len()); - } - - let output: ArrayRef = match input.data_type() { - // TODO it would be nice to have safe apis in arrow-rs to update the null buffers in the arrays - DataType::Utf8 => { - let input = input.as_string::(); - // safety: values / offsets came from a valid string array, so are valid utf8 - // and we checked nulls has the same length as values - unsafe { - Arc::new(StringArray::new_unchecked( - input.offsets().clone(), - input.values().clone(), - nulls, - )) - } - } - DataType::LargeUtf8 => { - let input = input.as_string::(); - // safety: values / offsets came from a valid string array, so are valid utf8 - // and we checked nulls has the same length as values - unsafe { - Arc::new(LargeStringArray::new_unchecked( - input.offsets().clone(), - input.values().clone(), - nulls, - )) - } - } - DataType::Utf8View => { - let input = input.as_string_view(); - // safety: values / views came from a valid string view array, so are valid utf8 - // and we checked nulls has the same length as values - unsafe { - Arc::new(StringViewArray::new_unchecked( - input.views().clone(), - input.data_buffers().to_vec(), - nulls, - )) - } - } - - DataType::Binary => { - let input = input.as_binary::(); - // safety: values / offsets came from a valid binary array - // and we checked nulls has the same length as values - unsafe { - Arc::new(BinaryArray::new_unchecked( - input.offsets().clone(), - input.values().clone(), - nulls, - )) - } - } - DataType::LargeBinary => { - let input = input.as_binary::(); - // safety: values / 
offsets came from a valid large binary array - // and we checked nulls has the same length as values - unsafe { - Arc::new(LargeBinaryArray::new_unchecked( - input.offsets().clone(), - input.values().clone(), - nulls, - )) - } - } - DataType::BinaryView => { - let input = input.as_binary_view(); - // safety: values / views came from a valid binary view array - // and we checked nulls has the same length as values - unsafe { - Arc::new(BinaryViewArray::new_unchecked( - input.views().clone(), - input.data_buffers().to_vec(), - nulls, - )) - } - } - _ => { - return internal_err!( - "Unexpected data type for convert_to_state in MinMaxBytesAccumulator: {:?}", - self.inner.data_type - ); - } - }; - assert_eq!(input.len(), output.len()); - assert_eq!(input.data_type(), output.data_type()); + let output = apply_filter_as_nulls(&values[0], opt_filter)?; Ok(vec![output]) } From c4f62717788e9e8b5402550d966eaf819a3a4976 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 8 Oct 2024 16:27:20 -0400 Subject: [PATCH 05/13] simplify --- datafusion/functions-aggregate/src/min_max.rs | 152 ++++++------------ 1 file changed, 48 insertions(+), 104 deletions(-) diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index be85c9b9dacf..26ba97f505fd 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -104,7 +104,7 @@ impl Default for Max { /// the specified [`ArrowPrimitiveType`]. /// /// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType -macro_rules! instantiate_primitive_max_accumulator { +macro_rules! primitive_max_accumulator { ($DATA_TYPE:ident, $NATIVE:ident, $PRIMTYPE:ident) => {{ Ok(Box::new( PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new($DATA_TYPE, |cur, new| { @@ -123,7 +123,7 @@ macro_rules! instantiate_primitive_max_accumulator { /// /// /// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType -macro_rules! 
instantiate_primitive_min_accumulator { +macro_rules! primitive_min_accumulator { ($DATA_TYPE:ident, $NATIVE:ident, $PRIMTYPE:ident) => {{ Ok(Box::new( PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new(&$DATA_TYPE, |cur, new| { @@ -248,82 +248,54 @@ impl AggregateUDFImpl for Max { use TimeUnit::*; let data_type = args.return_type; match data_type { - Int8 => instantiate_primitive_max_accumulator!(data_type, i8, Int8Type), - Int16 => instantiate_primitive_max_accumulator!(data_type, i16, Int16Type), - Int32 => instantiate_primitive_max_accumulator!(data_type, i32, Int32Type), - Int64 => instantiate_primitive_max_accumulator!(data_type, i64, Int64Type), - UInt8 => instantiate_primitive_max_accumulator!(data_type, u8, UInt8Type), - UInt16 => instantiate_primitive_max_accumulator!(data_type, u16, UInt16Type), - UInt32 => instantiate_primitive_max_accumulator!(data_type, u32, UInt32Type), - UInt64 => instantiate_primitive_max_accumulator!(data_type, u64, UInt64Type), + Int8 => primitive_max_accumulator!(data_type, i8, Int8Type), + Int16 => primitive_max_accumulator!(data_type, i16, Int16Type), + Int32 => primitive_max_accumulator!(data_type, i32, Int32Type), + Int64 => primitive_max_accumulator!(data_type, i64, Int64Type), + UInt8 => primitive_max_accumulator!(data_type, u8, UInt8Type), + UInt16 => primitive_max_accumulator!(data_type, u16, UInt16Type), + UInt32 => primitive_max_accumulator!(data_type, u32, UInt32Type), + UInt64 => primitive_max_accumulator!(data_type, u64, UInt64Type), Float16 => { - instantiate_primitive_max_accumulator!(data_type, f16, Float16Type) + primitive_max_accumulator!(data_type, f16, Float16Type) } Float32 => { - instantiate_primitive_max_accumulator!(data_type, f32, Float32Type) + primitive_max_accumulator!(data_type, f32, Float32Type) } Float64 => { - instantiate_primitive_max_accumulator!(data_type, f64, Float64Type) + primitive_max_accumulator!(data_type, f64, Float64Type) } - Date32 => instantiate_primitive_max_accumulator!(data_type, 
i32, Date32Type), - Date64 => instantiate_primitive_max_accumulator!(data_type, i64, Date64Type), + Date32 => primitive_max_accumulator!(data_type, i32, Date32Type), + Date64 => primitive_max_accumulator!(data_type, i64, Date64Type), Time32(Second) => { - instantiate_primitive_max_accumulator!(data_type, i32, Time32SecondType) + primitive_max_accumulator!(data_type, i32, Time32SecondType) } Time32(Millisecond) => { - instantiate_primitive_max_accumulator!( - data_type, - i32, - Time32MillisecondType - ) + primitive_max_accumulator!(data_type, i32, Time32MillisecondType) } Time64(Microsecond) => { - instantiate_primitive_max_accumulator!( - data_type, - i64, - Time64MicrosecondType - ) + primitive_max_accumulator!(data_type, i64, Time64MicrosecondType) } Time64(Nanosecond) => { - instantiate_primitive_max_accumulator!( - data_type, - i64, - Time64NanosecondType - ) + primitive_max_accumulator!(data_type, i64, Time64NanosecondType) } Timestamp(Second, _) => { - instantiate_primitive_max_accumulator!( - data_type, - i64, - TimestampSecondType - ) + primitive_max_accumulator!(data_type, i64, TimestampSecondType) } Timestamp(Millisecond, _) => { - instantiate_primitive_max_accumulator!( - data_type, - i64, - TimestampMillisecondType - ) + primitive_max_accumulator!(data_type, i64, TimestampMillisecondType) } Timestamp(Microsecond, _) => { - instantiate_primitive_max_accumulator!( - data_type, - i64, - TimestampMicrosecondType - ) + primitive_max_accumulator!(data_type, i64, TimestampMicrosecondType) } Timestamp(Nanosecond, _) => { - instantiate_primitive_max_accumulator!( - data_type, - i64, - TimestampNanosecondType - ) + primitive_max_accumulator!(data_type, i64, TimestampNanosecondType) } Decimal128(_, _) => { - instantiate_primitive_max_accumulator!(data_type, i128, Decimal128Type) + primitive_max_accumulator!(data_type, i128, Decimal128Type) } Decimal256(_, _) => { - instantiate_primitive_max_accumulator!(data_type, i256, Decimal256Type) + 
primitive_max_accumulator!(data_type, i256, Decimal256Type) } Utf8 | LargeUtf8 | Utf8View | Binary | LargeBinary | BinaryView => { Ok(Box::new(MinMaxBytesAccumulator::new_max(data_type.clone()))) @@ -1079,82 +1051,54 @@ impl AggregateUDFImpl for Min { use TimeUnit::*; let data_type = args.return_type; match data_type { - Int8 => instantiate_primitive_min_accumulator!(data_type, i8, Int8Type), - Int16 => instantiate_primitive_min_accumulator!(data_type, i16, Int16Type), - Int32 => instantiate_primitive_min_accumulator!(data_type, i32, Int32Type), - Int64 => instantiate_primitive_min_accumulator!(data_type, i64, Int64Type), - UInt8 => instantiate_primitive_min_accumulator!(data_type, u8, UInt8Type), - UInt16 => instantiate_primitive_min_accumulator!(data_type, u16, UInt16Type), - UInt32 => instantiate_primitive_min_accumulator!(data_type, u32, UInt32Type), - UInt64 => instantiate_primitive_min_accumulator!(data_type, u64, UInt64Type), + Int8 => primitive_min_accumulator!(data_type, i8, Int8Type), + Int16 => primitive_min_accumulator!(data_type, i16, Int16Type), + Int32 => primitive_min_accumulator!(data_type, i32, Int32Type), + Int64 => primitive_min_accumulator!(data_type, i64, Int64Type), + UInt8 => primitive_min_accumulator!(data_type, u8, UInt8Type), + UInt16 => primitive_min_accumulator!(data_type, u16, UInt16Type), + UInt32 => primitive_min_accumulator!(data_type, u32, UInt32Type), + UInt64 => primitive_min_accumulator!(data_type, u64, UInt64Type), Float16 => { - instantiate_primitive_min_accumulator!(data_type, f16, Float16Type) + primitive_min_accumulator!(data_type, f16, Float16Type) } Float32 => { - instantiate_primitive_min_accumulator!(data_type, f32, Float32Type) + primitive_min_accumulator!(data_type, f32, Float32Type) } Float64 => { - instantiate_primitive_min_accumulator!(data_type, f64, Float64Type) + primitive_min_accumulator!(data_type, f64, Float64Type) } - Date32 => instantiate_primitive_min_accumulator!(data_type, i32, Date32Type), - Date64 => 
instantiate_primitive_min_accumulator!(data_type, i64, Date64Type), + Date32 => primitive_min_accumulator!(data_type, i32, Date32Type), + Date64 => primitive_min_accumulator!(data_type, i64, Date64Type), Time32(Second) => { - instantiate_primitive_min_accumulator!(data_type, i32, Time32SecondType) + primitive_min_accumulator!(data_type, i32, Time32SecondType) } Time32(Millisecond) => { - instantiate_primitive_min_accumulator!( - data_type, - i32, - Time32MillisecondType - ) + primitive_min_accumulator!(data_type, i32, Time32MillisecondType) } Time64(Microsecond) => { - instantiate_primitive_min_accumulator!( - data_type, - i64, - Time64MicrosecondType - ) + primitive_min_accumulator!(data_type, i64, Time64MicrosecondType) } Time64(Nanosecond) => { - instantiate_primitive_min_accumulator!( - data_type, - i64, - Time64NanosecondType - ) + primitive_min_accumulator!(data_type, i64, Time64NanosecondType) } Timestamp(Second, _) => { - instantiate_primitive_min_accumulator!( - data_type, - i64, - TimestampSecondType - ) + primitive_min_accumulator!(data_type, i64, TimestampSecondType) } Timestamp(Millisecond, _) => { - instantiate_primitive_min_accumulator!( - data_type, - i64, - TimestampMillisecondType - ) + primitive_min_accumulator!(data_type, i64, TimestampMillisecondType) } Timestamp(Microsecond, _) => { - instantiate_primitive_min_accumulator!( - data_type, - i64, - TimestampMicrosecondType - ) + primitive_min_accumulator!(data_type, i64, TimestampMicrosecondType) } Timestamp(Nanosecond, _) => { - instantiate_primitive_min_accumulator!( - data_type, - i64, - TimestampNanosecondType - ) + primitive_min_accumulator!(data_type, i64, TimestampNanosecondType) } Decimal128(_, _) => { - instantiate_primitive_min_accumulator!(data_type, i128, Decimal128Type) + primitive_min_accumulator!(data_type, i128, Decimal128Type) } Decimal256(_, _) => { - instantiate_primitive_min_accumulator!(data_type, i256, Decimal256Type) + primitive_min_accumulator!(data_type, i256, 
Decimal256Type) } Utf8 | LargeUtf8 | Utf8View | Binary | LargeBinary | BinaryView => { Ok(Box::new(MinMaxBytesAccumulator::new_min(data_type.clone()))) From 126a9a8865f3c2389cabdba982f4318cb7034a9b Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 10 Oct 2024 06:28:47 -0400 Subject: [PATCH 06/13] Add functional tests --- .../sqllogictest/test_files/aggregate.slt | 173 ++++++++++++++++++ 1 file changed, 173 insertions(+) diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 250fa85cddef..c207044333f5 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -3800,6 +3800,179 @@ DROP TABLE min_bool; # Min_Max End # ################# + + +################# +# min_max on strings/binary with null values and groups +################# + +statement ok +CREATE TABLE strings (value TEXT, id int); + +statement ok +INSERT INTO strings VALUES + ('d', 1), + ('a', 3), + ('c', 1), + ('b', 1), + (NULL, 1), + (NULL, 4), + ('d', 1), + ('z', 2), + ('c', 1), + ('a', 2); + +############ Utf8 ############ + +query IT +SELECT id, MIN(value) FROM strings GROUP BY id ORDER BY id; +---- +1 b +2 a +3 a +4 NULL + +query IT +SELECT id, MAX(value) FROM strings GROUP BY id ORDER BY id; +---- +1 d +2 z +3 a +4 NULL + +############ LargeUtf8 ############ + +statement ok +CREATE VIEW large_strings AS SELECT id, arrow_cast(value, 'LargeUtf8') as value FROM strings; + + +query IT +SELECT id, MIN(value) FROM large_strings GROUP BY id ORDER BY id; +---- +1 b +2 a +3 a +4 NULL + +query IT +SELECT id, MAX(value) FROM large_strings GROUP BY id ORDER BY id; +---- +1 d +2 z +3 a +4 NULL + +statement ok +DROP VIEW large_strings + +############ Utf8View ############ + +statement ok +CREATE VIEW string_views AS SELECT id, arrow_cast(value, 'Utf8View') as value FROM strings; + + +query IT +SELECT id, MIN(value) FROM string_views GROUP BY id ORDER BY id; +---- +1 b +2 a +3 a +4 NULL + 
+query IT +SELECT id, MAX(value) FROM string_views GROUP BY id ORDER BY id; +---- +1 d +2 z +3 a +4 NULL + +statement ok +DROP VIEW string_views + +############ Binary ############ + +statement ok +CREATE VIEW binary AS SELECT id, arrow_cast(value, 'Binary') as value FROM strings; + + +query I? +SELECT id, MIN(value) FROM binary GROUP BY id ORDER BY id; +---- +1 62 +2 61 +3 61 +4 NULL + +query I? +SELECT id, MAX(value) FROM binary GROUP BY id ORDER BY id; +---- +1 64 +2 7a +3 61 +4 NULL + +statement ok +DROP VIEW binary + +############ LargeBinary ############ + +statement ok +CREATE VIEW large_binary AS SELECT id, arrow_cast(value, 'LargeBinary') as value FROM strings; + + +query I? +SELECT id, MIN(value) FROM large_binary GROUP BY id ORDER BY id; +---- +1 62 +2 61 +3 61 +4 NULL + +query I? +SELECT id, MAX(value) FROM large_binary GROUP BY id ORDER BY id; +---- +1 64 +2 7a +3 61 +4 NULL + +statement ok +DROP VIEW large_binary + +############ BinaryView ############ + +statement ok +CREATE VIEW binary_views AS SELECT id, arrow_cast(value, 'BinaryView') as value FROM strings; + + +query I? +SELECT id, MIN(value) FROM binary_views GROUP BY id ORDER BY id; +---- +1 62 +2 61 +3 61 +4 NULL + +query I? 
+SELECT id, MAX(value) FROM binary_views GROUP BY id ORDER BY id; +---- +1 64 +2 7a +3 61 +4 NULL + +statement ok +DROP VIEW binary_views + +statement ok +DROP TABLE strings; + +################# +# End min_max on strings/binary with null values and groups +################# + + statement ok create table bool_aggregate_functions ( c1 boolean not null, From a7ebb56af43423b5d520e50b1ce8a1f02607f236 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 10 Oct 2024 06:29:00 -0400 Subject: [PATCH 07/13] remove unecessary test --- datafusion/sqllogictest/test_files/aal.slt | 134 --------------------- 1 file changed, 134 deletions(-) delete mode 100644 datafusion/sqllogictest/test_files/aal.slt diff --git a/datafusion/sqllogictest/test_files/aal.slt b/datafusion/sqllogictest/test_files/aal.slt deleted file mode 100644 index 23a69c90aa61..000000000000 --- a/datafusion/sqllogictest/test_files/aal.slt +++ /dev/null @@ -1,134 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -# min_utf8, max_utf8 -statement ok -CREATE TABLE strings (value TEXT, id int); - -statement ok -INSERT INTO strings VALUES - ('d', 1), - ('a', 3), - ('c', 1), - ('b', 1), - ('d', 1), - ('z', 2), - ('c', 1), - ('a', 2); - -query IT -SELECT id, MIN(value) FROM strings GROUP BY id ORDER BY id; ----- -1 b -2 a -3 a - -query IT -SELECT id, MAX(value) FROM strings GROUP BY id ORDER BY id; ----- -1 d -2 z -3 a - - - -#query TT -#SELECT MIN(value), MAX(value) FROM strings -#---- -#a z - - -query ITT -SELECT id, MIN(value), MAX(value) FROM strings GROUP BY id ORDER BY id; ----- -1 b d -2 a z -3 a a - - - - - -# min_utf8view, max_utf8view -statement ok -CREATE VIEW string_views AS SELECT id, arrow_cast(value, 'Utf8View') as value FROM strings; - - -query IT -SELECT id, MIN(value) FROM string_views GROUP BY id ORDER BY id; ----- -1 b -2 a -3 a - -query IT -SELECT id, MAX(value) FROM string_views GROUP BY id ORDER BY id; ----- -1 d -2 z -3 a - - - -#query TT -#SELECT MIN(value), MAX(value) FROM string_views -#---- -#a z - - -query ITT -SELECT id, MIN(value), MAX(value) FROM string_views GROUP BY id ORDER BY id; ----- -1 b d -2 a z -3 a a - - - -# min_utf8view, max_utf8view -statement ok -CREATE VIEW binary_views AS SELECT id, arrow_cast(value, 'BinaryView') as value FROM strings; - - -query I? -SELECT id, MIN(value) FROM binary_views GROUP BY id ORDER BY id; ----- -1 62 -2 61 -3 61 - -query I? -SELECT id, MAX(value) FROM binary_views GROUP BY id ORDER BY id; ----- -1 64 -2 7a -3 61 - - - -#query TT -#SELECT MIN(value), MAX(value) FROM binary_views -#---- -#a z - - -query I?? 
-SELECT id, MIN(value), MAX(value) FROM binary_views GROUP BY id ORDER BY id; ----- -1 62 64 -2 61 7a -3 61 61 From 2d5957df1ba04b61548ecc218d5bd91c5a41c183 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 10 Oct 2024 06:36:07 -0400 Subject: [PATCH 08/13] improve docs --- .../src/min_max/min_max_bytes.rs | 25 ++++++++++--------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs b/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs index 67c2963850fb..0af79a2b954e 100644 --- a/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs +++ b/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs @@ -24,18 +24,20 @@ use datafusion_expr::{EmitTo, GroupsAccumulator}; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::apply_filter_as_nulls; use std::sync::Arc; -/// Implements Min/Max accumulators for "bytes" types ([`StringArray`], [`BinaryArray`], etc) +/// Implements fast Min/Max [`GroupsAccumulator`] for "bytes" types ([`StringArray`], +/// [`BinaryArray`], [`StringViewArray`], etc) /// /// This implementation dispatches to the appropriate specialized code in /// [`MinMaxBytesState`] based on data type and comparison function pub(crate) struct MinMaxBytesAccumulator { + /// Inner data storage. 
inner: MinMaxBytesState, /// if true, is `MIN` otherwise is `MAX` is_min: bool, } impl MinMaxBytesAccumulator { - /// Create a new accumulator fo computing min(string) + /// Create a new accumulator for computing `min(val)` pub fn new_min(data_type: DataType) -> Self { Self { inner: MinMaxBytesState::new(data_type), @@ -43,7 +45,7 @@ impl MinMaxBytesAccumulator { } } - /// Create a new accumulator fo computing max(string) + /// Create a new accumulator fo computing `max(val)` pub fn new_max(data_type: DataType) -> Self { Self { inner: MinMaxBytesState::new(data_type), @@ -62,7 +64,6 @@ impl GroupsAccumulator for MinMaxBytesAccumulator { ) -> Result<()> { let array = &values[0]; assert_eq!(array.len(), group_indices.len()); - // Ensure the input matches the output assert_eq!(array.data_type(), &self.inner.data_type); // apply filter if needed @@ -70,8 +71,8 @@ impl GroupsAccumulator for MinMaxBytesAccumulator { // dispatch to appropriate kernel / specialized implementation fn string_min(a: &[u8], b: &[u8]) -> bool { - // safety: only called from update_batch, which ensures a and b come - // from a string array, and thus are valid utf8 + // safety: only called from this function, which ensures a and b come + // from an array with valid utf8 data unsafe { let a = std::str::from_utf8_unchecked(a); let b = std::str::from_utf8_unchecked(b); @@ -79,8 +80,8 @@ impl GroupsAccumulator for MinMaxBytesAccumulator { } } fn string_max(a: &[u8], b: &[u8]) -> bool { - // safety: only called from update_batch, which ensures a and b come - // from a string array, and thus are valid utf8 + // safety: only called from this function, which ensures a and b come + // from an array with valid utf8 data unsafe { let a = std::str::from_utf8_unchecked(a); let b = std::str::from_utf8_unchecked(b); @@ -102,7 +103,7 @@ impl GroupsAccumulator for MinMaxBytesAccumulator { } match (self.is_min, &self.inner.data_type) { - // String Min + // Utf8/LargeUtf8/Utf8View Min (true, &DataType::Utf8) => 
self.inner.update_batch( str_to_bytes(array.as_string::().iter()), group_indices, @@ -125,7 +126,7 @@ impl GroupsAccumulator for MinMaxBytesAccumulator { string_min, ), - // String Max + // Utf8/LargeUtf8/Utf8View Max (false, &DataType::Utf8) => self.inner.update_batch( str_to_bytes(array.as_string::().iter()), group_indices, @@ -148,7 +149,7 @@ impl GroupsAccumulator for MinMaxBytesAccumulator { string_max, ), - // Binary Min + // Binary/LargeBinary/BinaryView Min (true, &DataType::Binary) => self.inner.update_batch( array.as_binary::().iter(), group_indices, @@ -171,7 +172,7 @@ impl GroupsAccumulator for MinMaxBytesAccumulator { binary_min, ), - // Binary Max + // Binary/LargeBinary/BinaryView Max (false, &DataType::Binary) => self.inner.update_batch( array.as_binary::().iter(), group_indices, From 2671b2d75e4a31bcd6aac1735d75cdd5de7f4641 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 10 Oct 2024 06:48:29 -0400 Subject: [PATCH 09/13] improve docs --- .../src/aggregate/groups_accumulator/nulls.rs | 6 +-- .../src/min_max/min_max_bytes.rs | 48 +++++++++++-------- .../sqllogictest/test_files/aggregate.slt | 1 + 3 files changed, 32 insertions(+), 23 deletions(-) diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs index 6f80ff9b603d..b19624edae36 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs @@ -108,13 +108,13 @@ pub fn apply_filter_as_nulls( opt_filter: Option<&BooleanArray>, ) -> Result { let nulls = filtered_null_mask(opt_filter, input); - replace_nulls(input, nulls) + set_nulls_dyn(input, nulls) } /// Replaces the nulls in the input array with the given `NullBuffer` /// -/// Can replace when upstreamed in arrow-rs: https://github.com/apache/arrow-rs/issues/6528 -pub fn replace_nulls(input: &dyn Array, 
nulls: Option) -> Result { +/// Can replace when upstreamed in arrow-rs: <https://github.com/apache/arrow-rs/issues/6528> +pub fn set_nulls_dyn(input: &dyn Array, nulls: Option) -> Result { if let Some(nulls) = nulls.as_ref() { assert_eq!(nulls.len(), input.len()); } diff --git a/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs b/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs index 0af79a2b954e..0ae9a1f4b25f 100644 --- a/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs +++ b/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs @@ -29,6 +29,11 @@ use std::sync::Arc; /// /// This implementation dispatches to the appropriate specialized code in /// [`MinMaxBytesState`] based on data type and comparison function +/// +/// [`StringArray`]: arrow::array::StringArray +/// [`BinaryArray`]: arrow::array::BinaryArray +/// [`StringViewArray`]: arrow::array::StringViewArray +#[derive(Debug)] pub(crate) struct MinMaxBytesAccumulator { /// Inner data storage. inner: MinMaxBytesState, @@ -107,21 +112,18 @@ impl GroupsAccumulator for MinMaxBytesAccumulator { (true, &DataType::Utf8) => self.inner.update_batch( str_to_bytes(array.as_string::().iter()), group_indices, - opt_filter, total_num_groups, string_min, ), (true, &DataType::LargeUtf8) => self.inner.update_batch( str_to_bytes(array.as_string::().iter()), group_indices, - opt_filter, total_num_groups, string_min, ), (true, &DataType::Utf8View) => self.inner.update_batch( str_to_bytes(array.as_string_view().iter()), group_indices, - opt_filter, total_num_groups, string_min, ), @@ -130,21 +132,18 @@ impl GroupsAccumulator for MinMaxBytesAccumulator { (false, &DataType::Utf8) => self.inner.update_batch( str_to_bytes(array.as_string::().iter()), group_indices, - opt_filter, total_num_groups, string_max, ), (false, &DataType::LargeUtf8) => self.inner.update_batch( str_to_bytes(array.as_string::().iter()), group_indices, - opt_filter, total_num_groups, string_max, ), (false, &DataType::Utf8View) => self.inner.update_batch( 
str_to_bytes(array.as_string_view().iter()), group_indices, - opt_filter, total_num_groups, string_max, ), @@ -153,21 +152,18 @@ impl GroupsAccumulator for MinMaxBytesAccumulator { (true, &DataType::Binary) => self.inner.update_batch( array.as_binary::().iter(), group_indices, - opt_filter, total_num_groups, binary_min, ), (true, &DataType::LargeBinary) => self.inner.update_batch( array.as_binary::().iter(), group_indices, - opt_filter, total_num_groups, binary_min, ), (true, &DataType::BinaryView) => self.inner.update_batch( array.as_binary_view().iter(), group_indices, - opt_filter, total_num_groups, binary_min, ), @@ -176,21 +172,18 @@ impl GroupsAccumulator for MinMaxBytesAccumulator { (false, &DataType::Binary) => self.inner.update_batch( array.as_binary::().iter(), group_indices, - opt_filter, total_num_groups, binary_max, ), (false, &DataType::LargeBinary) => self.inner.update_batch( array.as_binary::().iter(), group_indices, - opt_filter, total_num_groups, binary_max, ), (false, &DataType::BinaryView) => self.inner.update_batch( array.as_binary_view().iter(), group_indices, - opt_filter, total_num_groups, binary_max, ), @@ -206,7 +199,7 @@ impl GroupsAccumulator for MinMaxBytesAccumulator { fn evaluate(&mut self, emit_to: EmitTo) -> Result { let (data_capacity, min_maxes) = self.inner.emit_to(emit_to); - // Convert the Vec of bytes to a vec of Strings (no cost) + // Convert the Vec of bytes to a vec of Strings (at no cost) fn bytes_to_str( min_maxes: Vec>>, ) -> impl Iterator> { @@ -351,15 +344,33 @@ fn capacity_to_view_block_size(data_capacity: usize) -> u32 { } } -/// Stores internal Min/Max state for "bytes" types ([`StringArray`], [`BinaryArray`], etc) +/// Stores internal Min/Max state for "bytes" types. /// /// This implementation is general and stores the minimum/maximum for each /// groups in an individual byte array, which balances allocations and memory /// fragmentation (aka garbage). 
/// -/// Note that for StringViewArray and BinaryViewArray, there are potentially -/// more efficient implementations by managing a string data buffer directly, -/// but then garbage collection, memory management, and final array +/// ```text +/// ┌─────────────────────────────────┐ +/// ┌─────┐ ┌─────▶│Option> (["A"]) │──────────────▶ "A" +/// │ 0 │─────┘ └─────────────────────────────────┘ +/// ├─────┤ ┌─────────────────────────────────┐ +/// │ 1 │───────────▶│Option> (["Z"]) │───────────────▶ "Z" +/// └─────┘ └─────────────────────────────────┘ ... +/// ... ... +/// ┌─────┐ ┌────────────────────────────────┐ +/// │ N-2 │ │Option> (["A"]) │────────────────▶ "A" +/// ├─────┤ └────────────────────────────────┘ +/// │ N-1 │─────┐ ┌────────────────────────────────┐ +/// └─────┘ └─────▶│Option> (["Q"]) │────────────────▶ "Q" +/// └────────────────────────────────┘ +/// +/// min_max: Vec> +/// ``` +/// +/// Note that for `StringViewArray` and `BinaryViewArray`, there are potentially +/// more efficient implementations (e.g. by managing a string data buffer +/// directly), but then garbage collection, memory management, and final array /// construction becomes more complex. 
/// /// See discussion on @@ -425,7 +436,6 @@ impl MinMaxBytesState { &mut self, iter: I, group_indices: &[usize], - opt_filter: Option<&BooleanArray>, total_num_groups: usize, mut cmp: F, ) -> Result<()> @@ -436,8 +446,6 @@ impl MinMaxBytesState { self.min_max.resize(total_num_groups, None); let mut locations = vec![MinMaxLocation::ExistingMinMax; total_num_groups]; - assert!(opt_filter.is_none(), "Filtering not yet implemented"); - // Figure out the new min value for each group for (new_val, group_index) in iter.into_iter().zip(group_indices.iter()) { let group_index = *group_index; diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index c207044333f5..f0746b37e9d6 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -3811,6 +3811,7 @@ CREATE TABLE strings (value TEXT, id int); statement ok INSERT INTO strings VALUES + ('c', 1), ('d', 1), ('a', 3), ('c', 1), From 043ac35a315b9c3b213dab8663d39cd389359d67 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 10 Oct 2024 06:52:52 -0400 Subject: [PATCH 10/13] cleanup --- datafusion/functions-aggregate/src/min_max/min_max_bytes.rs | 3 --- 1 file changed, 3 deletions(-) diff --git a/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs b/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs index 0ae9a1f4b25f..bcbd177c8b84 100644 --- a/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs +++ b/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs @@ -378,9 +378,6 @@ fn capacity_to_view_block_size(data_capacity: usize) -> u32 { struct MinMaxBytesState { /// The minimum/maximum value for each group min_max: Vec>>, - // todo use null state - // /// Have we seen any non-null values yet? 
- // null_state: NullState, /// The data type of the array data_type: DataType, /// The total bytes of the string data (for pre-allocating the final array, From e4548382fe311e6e2c8adf4507a53f20c9cf8e31 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 10 Oct 2024 06:58:06 -0400 Subject: [PATCH 11/13] improve comments --- .../functions-aggregate/src/min_max/min_max_bytes.rs | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs b/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs index bcbd177c8b84..4a1babde3d90 100644 --- a/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs +++ b/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs @@ -441,17 +441,18 @@ impl MinMaxBytesState { I: IntoIterator>, { self.min_max.resize(total_num_groups, None); + // Minimize value copies by calculating the new min/maxes for each group + // in this batch (either the existing min/max or the new input value) + // and updating the owned values in `self.min_max` at most once let mut locations = vec![MinMaxLocation::ExistingMinMax; total_num_groups]; // Figure out the new min value for each group for (new_val, group_index) in iter.into_iter().zip(group_indices.iter()) { let group_index = *group_index; - // ignore null inputs let Some(new_val) = new_val else { - continue; + continue; // skip nulls }; - // go through contortions to avoid copying strings too many times let existing_val = match locations[group_index] { // previous input value was the min/max, so compare it MinMaxLocation::Input(existing_val) => existing_val, @@ -465,13 +466,13 @@ impl MinMaxBytesState { } }; - // Actually compare the new value to the existing value, replacing if necessary + // Compare the new value to the existing value, replacing if necessary if cmp(new_val, existing_val) { locations[group_index] = MinMaxLocation::Input(new_val); } } - // Update self.min_max with the new min values we found in the 
input + // Update self.min_max with any new min/max values we found in the input for (group_index, location) in locations.iter().enumerate() { match location { MinMaxLocation::ExistingMinMax => {} From c7aa11f29267c7736da8da408988d4658ecf9aea Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 10 Oct 2024 07:00:39 -0400 Subject: [PATCH 12/13] fix diagram --- .../src/min_max/min_max_bytes.rs | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs b/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs index 4a1babde3d90..3ce1ac37625f 100644 --- a/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs +++ b/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs @@ -351,21 +351,21 @@ fn capacity_to_view_block_size(data_capacity: usize) -> u32 { /// fragmentation (aka garbage). /// /// ```text -/// ┌─────────────────────────────────┐ -/// ┌─────┐ ┌─────▶│Option> (["A"]) │──────────────▶ "A" -/// │ 0 │─────┘ └─────────────────────────────────┘ -/// ├─────┤ ┌─────────────────────────────────┐ -/// │ 1 │───────────▶│Option> (["Z"]) │───────────────▶ "Z" -/// └─────┘ └─────────────────────────────────┘ ... -/// ... ... -/// ┌─────┐ ┌────────────────────────────────┐ -/// │ N-2 │ │Option> (["A"]) │────────────────▶ "A" -/// ├─────┤ └────────────────────────────────┘ -/// │ N-1 │─────┐ ┌────────────────────────────────┐ -/// └─────┘ └─────▶│Option> (["Q"]) │────────────────▶ "Q" -/// └────────────────────────────────┘ +/// ┌─────────────────────────────────┐ +/// ┌─────┐ ┌────▶│Option> (["A"]) │───────────▶ "A" +/// │ 0 │────┘ └─────────────────────────────────┘ +/// ├─────┤ ┌─────────────────────────────────┐ +/// │ 1 │─────────▶│Option> (["Z"]) │───────────▶ "Z" +/// └─────┘ └─────────────────────────────────┘ ... +/// ... ... 
+/// ┌─────┐ ┌────────────────────────────────┐ +/// │ N-2 │─────────▶│Option> (["A"]) │────────────▶ "A" +/// ├─────┤ └────────────────────────────────┘ +/// │ N-1 │────┐ ┌────────────────────────────────┐ +/// └─────┘ └────▶│Option> (["Q"]) │────────────▶ "Q" +/// └────────────────────────────────┘ /// -/// min_max: Vec> +/// min_max: Vec> /// ``` /// /// Note that for `StringViewArray` and `BinaryViewArray`, there are potentially From 576cc3226dd8236aa103ff73255347422379d7b6 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 10 Oct 2024 07:08:09 -0400 Subject: [PATCH 13/13] x2 --- datafusion/functions-aggregate/src/min_max/min_max_bytes.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs b/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs index 3ce1ac37625f..bb5c822aaf6a 100644 --- a/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs +++ b/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs @@ -412,7 +412,11 @@ impl MinMaxBytesState { fn set_value(&mut self, group_index: usize, new_val: &[u8]) { match self.min_max[group_index].as_mut() { None => { - self.min_max[group_index] = Some(new_val.to_vec()); + // No existing value, so allocate a new one (allocate 2x the size of the input) + // to avoid re-allocating for small strings + let mut new_vec = Vec::with_capacity(new_val.len() * 2); + new_vec.extend_from_slice(new_val); + self.min_max[group_index] = Some(new_vec); self.total_data_bytes += new_val.len(); } Some(existing_val) => {