Skip to content

Commit 646f40a

Browse files
authored
Implement special min/max accumulator for Strings and Binary (10% faster for Clickbench Q28) (#12792)
* Implement special min/max accumulator for Strings: `MinMaxBytesAccumulator` * fix bug * fix msrv * move code, handle filters * simplify * Add functional tests * remove unecessary test * improve docs * improve docs * cleanup * improve comments * fix diagram * fix accounting * Use correct type in memory accounting * Add TODO comment
1 parent ebfc155 commit 646f40a

File tree

5 files changed

+872
-57
lines changed

5 files changed

+872
-57
lines changed

datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ impl NullState {
9595
///
9696
/// When value_fn is called it also sets
9797
///
98-
/// 1. `self.seen_values[group_index]` to true for all rows that had a non null vale
98+
/// 1. `self.seen_values[group_index]` to true for all rows that had a non null value
9999
pub fn accumulate<T, F>(
100100
&mut self,
101101
group_indices: &[usize],

datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs

+113-2
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,22 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
//! [`set_nulls`], and [`filtered_null_mask`], utilities for working with nulls
18+
//! [`set_nulls`], other utilities for working with nulls
1919
20-
use arrow::array::{Array, ArrowNumericType, BooleanArray, PrimitiveArray};
20+
use arrow::array::{
21+
Array, ArrayRef, ArrowNumericType, AsArray, BinaryArray, BinaryViewArray,
22+
BooleanArray, LargeBinaryArray, LargeStringArray, PrimitiveArray, StringArray,
23+
StringViewArray,
24+
};
2125
use arrow::buffer::NullBuffer;
26+
use arrow::datatypes::DataType;
27+
use datafusion_common::{not_impl_err, Result};
28+
use std::sync::Arc;
2229

2330
/// Sets the validity mask for a `PrimitiveArray` to `nulls`
2431
/// replacing any existing null mask
32+
///
33+
/// See [`set_nulls_dyn`] for a version that works with `Array`
2534
pub fn set_nulls<T: ArrowNumericType + Send>(
2635
array: PrimitiveArray<T>,
2736
nulls: Option<NullBuffer>,
@@ -91,3 +100,105 @@ pub fn filtered_null_mask(
91100
let opt_filter = opt_filter.and_then(filter_to_nulls);
92101
NullBuffer::union(opt_filter.as_ref(), input.nulls())
93102
}
103+
104+
/// Applies optional filter to input, returning a new array of the same type
105+
/// with the same data, but with any values that were filtered out set to null
106+
pub fn apply_filter_as_nulls(
107+
input: &dyn Array,
108+
opt_filter: Option<&BooleanArray>,
109+
) -> Result<ArrayRef> {
110+
let nulls = filtered_null_mask(opt_filter, input);
111+
set_nulls_dyn(input, nulls)
112+
}
113+
114+
/// Replaces the nulls in the input array with the given `NullBuffer`
115+
///
116+
/// TODO: replace when upstreamed in arrow-rs: <https://github.com/apache/arrow-rs/issues/6528>
117+
pub fn set_nulls_dyn(input: &dyn Array, nulls: Option<NullBuffer>) -> Result<ArrayRef> {
118+
if let Some(nulls) = nulls.as_ref() {
119+
assert_eq!(nulls.len(), input.len());
120+
}
121+
122+
let output: ArrayRef = match input.data_type() {
123+
DataType::Utf8 => {
124+
let input = input.as_string::<i32>();
125+
// safety: values / offsets came from a valid string array, so are valid utf8
126+
// and we checked nulls has the same length as values
127+
unsafe {
128+
Arc::new(StringArray::new_unchecked(
129+
input.offsets().clone(),
130+
input.values().clone(),
131+
nulls,
132+
))
133+
}
134+
}
135+
DataType::LargeUtf8 => {
136+
let input = input.as_string::<i64>();
137+
// safety: values / offsets came from a valid string array, so are valid utf8
138+
// and we checked nulls has the same length as values
139+
unsafe {
140+
Arc::new(LargeStringArray::new_unchecked(
141+
input.offsets().clone(),
142+
input.values().clone(),
143+
nulls,
144+
))
145+
}
146+
}
147+
DataType::Utf8View => {
148+
let input = input.as_string_view();
149+
// safety: values / views came from a valid string view array, so are valid utf8
150+
// and we checked nulls has the same length as values
151+
unsafe {
152+
Arc::new(StringViewArray::new_unchecked(
153+
input.views().clone(),
154+
input.data_buffers().to_vec(),
155+
nulls,
156+
))
157+
}
158+
}
159+
160+
DataType::Binary => {
161+
let input = input.as_binary::<i32>();
162+
// safety: values / offsets came from a valid binary array
163+
// and we checked nulls has the same length as values
164+
unsafe {
165+
Arc::new(BinaryArray::new_unchecked(
166+
input.offsets().clone(),
167+
input.values().clone(),
168+
nulls,
169+
))
170+
}
171+
}
172+
DataType::LargeBinary => {
173+
let input = input.as_binary::<i64>();
174+
// safety: values / offsets came from a valid large binary array
175+
// and we checked nulls has the same length as values
176+
unsafe {
177+
Arc::new(LargeBinaryArray::new_unchecked(
178+
input.offsets().clone(),
179+
input.values().clone(),
180+
nulls,
181+
))
182+
}
183+
}
184+
DataType::BinaryView => {
185+
let input = input.as_binary_view();
186+
// safety: values / views came from a valid binary view array
187+
// and we checked nulls has the same length as values
188+
unsafe {
189+
Arc::new(BinaryViewArray::new_unchecked(
190+
input.views().clone(),
191+
input.data_buffers().to_vec(),
192+
nulls,
193+
))
194+
}
195+
}
196+
_ => {
197+
return not_impl_err!("Applying nulls {:?}", input.data_type());
198+
}
199+
};
200+
assert_eq!(input.len(), output.len());
201+
assert_eq!(input.data_type(), output.data_type());
202+
203+
Ok(output)
204+
}

datafusion/functions-aggregate/src/min_max.rs

+69-54
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
//! [`Max`] and [`MaxAccumulator`] accumulator for the `max` function
1818
//! [`Min`] and [`MinAccumulator`] accumulator for the `min` function
1919
20+
mod min_max_bytes;
21+
2022
use arrow::array::{
2123
ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array,
2224
Decimal128Array, Decimal256Array, Float16Array, Float32Array, Float64Array,
@@ -50,6 +52,7 @@ use arrow::datatypes::{
5052
TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
5153
};
5254

55+
use crate::min_max::min_max_bytes::MinMaxBytesAccumulator;
5356
use datafusion_common::ScalarValue;
5457
use datafusion_expr::{
5558
function::AccumulatorArgs, Accumulator, AggregateUDFImpl, Documentation, Signature,
@@ -104,7 +107,7 @@ impl Default for Max {
104107
/// the specified [`ArrowPrimitiveType`].
105108
///
106109
/// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType
107-
macro_rules! instantiate_max_accumulator {
110+
macro_rules! primitive_max_accumulator {
108111
($DATA_TYPE:ident, $NATIVE:ident, $PRIMTYPE:ident) => {{
109112
Ok(Box::new(
110113
PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new($DATA_TYPE, |cur, new| {
@@ -123,7 +126,7 @@ macro_rules! instantiate_max_accumulator {
123126
///
124127
///
125128
/// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType
126-
macro_rules! instantiate_min_accumulator {
129+
macro_rules! primitive_min_accumulator {
127130
($DATA_TYPE:ident, $NATIVE:ident, $PRIMTYPE:ident) => {{
128131
Ok(Box::new(
129132
PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new(&$DATA_TYPE, |cur, new| {
@@ -231,6 +234,12 @@ impl AggregateUDFImpl for Max {
231234
| Time32(_)
232235
| Time64(_)
233236
| Timestamp(_, _)
237+
| Utf8
238+
| LargeUtf8
239+
| Utf8View
240+
| Binary
241+
| LargeBinary
242+
| BinaryView
234243
)
235244
}
236245

@@ -242,58 +251,58 @@ impl AggregateUDFImpl for Max {
242251
use TimeUnit::*;
243252
let data_type = args.return_type;
244253
match data_type {
245-
Int8 => instantiate_max_accumulator!(data_type, i8, Int8Type),
246-
Int16 => instantiate_max_accumulator!(data_type, i16, Int16Type),
247-
Int32 => instantiate_max_accumulator!(data_type, i32, Int32Type),
248-
Int64 => instantiate_max_accumulator!(data_type, i64, Int64Type),
249-
UInt8 => instantiate_max_accumulator!(data_type, u8, UInt8Type),
250-
UInt16 => instantiate_max_accumulator!(data_type, u16, UInt16Type),
251-
UInt32 => instantiate_max_accumulator!(data_type, u32, UInt32Type),
252-
UInt64 => instantiate_max_accumulator!(data_type, u64, UInt64Type),
254+
Int8 => primitive_max_accumulator!(data_type, i8, Int8Type),
255+
Int16 => primitive_max_accumulator!(data_type, i16, Int16Type),
256+
Int32 => primitive_max_accumulator!(data_type, i32, Int32Type),
257+
Int64 => primitive_max_accumulator!(data_type, i64, Int64Type),
258+
UInt8 => primitive_max_accumulator!(data_type, u8, UInt8Type),
259+
UInt16 => primitive_max_accumulator!(data_type, u16, UInt16Type),
260+
UInt32 => primitive_max_accumulator!(data_type, u32, UInt32Type),
261+
UInt64 => primitive_max_accumulator!(data_type, u64, UInt64Type),
253262
Float16 => {
254-
instantiate_max_accumulator!(data_type, f16, Float16Type)
263+
primitive_max_accumulator!(data_type, f16, Float16Type)
255264
}
256265
Float32 => {
257-
instantiate_max_accumulator!(data_type, f32, Float32Type)
266+
primitive_max_accumulator!(data_type, f32, Float32Type)
258267
}
259268
Float64 => {
260-
instantiate_max_accumulator!(data_type, f64, Float64Type)
269+
primitive_max_accumulator!(data_type, f64, Float64Type)
261270
}
262-
Date32 => instantiate_max_accumulator!(data_type, i32, Date32Type),
263-
Date64 => instantiate_max_accumulator!(data_type, i64, Date64Type),
271+
Date32 => primitive_max_accumulator!(data_type, i32, Date32Type),
272+
Date64 => primitive_max_accumulator!(data_type, i64, Date64Type),
264273
Time32(Second) => {
265-
instantiate_max_accumulator!(data_type, i32, Time32SecondType)
274+
primitive_max_accumulator!(data_type, i32, Time32SecondType)
266275
}
267276
Time32(Millisecond) => {
268-
instantiate_max_accumulator!(data_type, i32, Time32MillisecondType)
277+
primitive_max_accumulator!(data_type, i32, Time32MillisecondType)
269278
}
270279
Time64(Microsecond) => {
271-
instantiate_max_accumulator!(data_type, i64, Time64MicrosecondType)
280+
primitive_max_accumulator!(data_type, i64, Time64MicrosecondType)
272281
}
273282
Time64(Nanosecond) => {
274-
instantiate_max_accumulator!(data_type, i64, Time64NanosecondType)
283+
primitive_max_accumulator!(data_type, i64, Time64NanosecondType)
275284
}
276285
Timestamp(Second, _) => {
277-
instantiate_max_accumulator!(data_type, i64, TimestampSecondType)
286+
primitive_max_accumulator!(data_type, i64, TimestampSecondType)
278287
}
279288
Timestamp(Millisecond, _) => {
280-
instantiate_max_accumulator!(data_type, i64, TimestampMillisecondType)
289+
primitive_max_accumulator!(data_type, i64, TimestampMillisecondType)
281290
}
282291
Timestamp(Microsecond, _) => {
283-
instantiate_max_accumulator!(data_type, i64, TimestampMicrosecondType)
292+
primitive_max_accumulator!(data_type, i64, TimestampMicrosecondType)
284293
}
285294
Timestamp(Nanosecond, _) => {
286-
instantiate_max_accumulator!(data_type, i64, TimestampNanosecondType)
295+
primitive_max_accumulator!(data_type, i64, TimestampNanosecondType)
287296
}
288297
Decimal128(_, _) => {
289-
instantiate_max_accumulator!(data_type, i128, Decimal128Type)
298+
primitive_max_accumulator!(data_type, i128, Decimal128Type)
290299
}
291300
Decimal256(_, _) => {
292-
instantiate_max_accumulator!(data_type, i256, Decimal256Type)
301+
primitive_max_accumulator!(data_type, i256, Decimal256Type)
302+
}
303+
Utf8 | LargeUtf8 | Utf8View | Binary | LargeBinary | BinaryView => {
304+
Ok(Box::new(MinMaxBytesAccumulator::new_max(data_type.clone())))
293305
}
294-
295-
// It would be nice to have a fast implementation for Strings as well
296-
// https://github.com/apache/datafusion/issues/6906
297306

298307
// This is only reached if groups_accumulator_supported is out of sync
299308
_ => internal_err!("GroupsAccumulator not supported for max({})", data_type),
@@ -1057,6 +1066,12 @@ impl AggregateUDFImpl for Min {
10571066
| Time32(_)
10581067
| Time64(_)
10591068
| Timestamp(_, _)
1069+
| Utf8
1070+
| LargeUtf8
1071+
| Utf8View
1072+
| Binary
1073+
| LargeBinary
1074+
| BinaryView
10601075
)
10611076
}
10621077

@@ -1068,58 +1083,58 @@ impl AggregateUDFImpl for Min {
10681083
use TimeUnit::*;
10691084
let data_type = args.return_type;
10701085
match data_type {
1071-
Int8 => instantiate_min_accumulator!(data_type, i8, Int8Type),
1072-
Int16 => instantiate_min_accumulator!(data_type, i16, Int16Type),
1073-
Int32 => instantiate_min_accumulator!(data_type, i32, Int32Type),
1074-
Int64 => instantiate_min_accumulator!(data_type, i64, Int64Type),
1075-
UInt8 => instantiate_min_accumulator!(data_type, u8, UInt8Type),
1076-
UInt16 => instantiate_min_accumulator!(data_type, u16, UInt16Type),
1077-
UInt32 => instantiate_min_accumulator!(data_type, u32, UInt32Type),
1078-
UInt64 => instantiate_min_accumulator!(data_type, u64, UInt64Type),
1086+
Int8 => primitive_min_accumulator!(data_type, i8, Int8Type),
1087+
Int16 => primitive_min_accumulator!(data_type, i16, Int16Type),
1088+
Int32 => primitive_min_accumulator!(data_type, i32, Int32Type),
1089+
Int64 => primitive_min_accumulator!(data_type, i64, Int64Type),
1090+
UInt8 => primitive_min_accumulator!(data_type, u8, UInt8Type),
1091+
UInt16 => primitive_min_accumulator!(data_type, u16, UInt16Type),
1092+
UInt32 => primitive_min_accumulator!(data_type, u32, UInt32Type),
1093+
UInt64 => primitive_min_accumulator!(data_type, u64, UInt64Type),
10791094
Float16 => {
1080-
instantiate_min_accumulator!(data_type, f16, Float16Type)
1095+
primitive_min_accumulator!(data_type, f16, Float16Type)
10811096
}
10821097
Float32 => {
1083-
instantiate_min_accumulator!(data_type, f32, Float32Type)
1098+
primitive_min_accumulator!(data_type, f32, Float32Type)
10841099
}
10851100
Float64 => {
1086-
instantiate_min_accumulator!(data_type, f64, Float64Type)
1101+
primitive_min_accumulator!(data_type, f64, Float64Type)
10871102
}
1088-
Date32 => instantiate_min_accumulator!(data_type, i32, Date32Type),
1089-
Date64 => instantiate_min_accumulator!(data_type, i64, Date64Type),
1103+
Date32 => primitive_min_accumulator!(data_type, i32, Date32Type),
1104+
Date64 => primitive_min_accumulator!(data_type, i64, Date64Type),
10901105
Time32(Second) => {
1091-
instantiate_min_accumulator!(data_type, i32, Time32SecondType)
1106+
primitive_min_accumulator!(data_type, i32, Time32SecondType)
10921107
}
10931108
Time32(Millisecond) => {
1094-
instantiate_min_accumulator!(data_type, i32, Time32MillisecondType)
1109+
primitive_min_accumulator!(data_type, i32, Time32MillisecondType)
10951110
}
10961111
Time64(Microsecond) => {
1097-
instantiate_min_accumulator!(data_type, i64, Time64MicrosecondType)
1112+
primitive_min_accumulator!(data_type, i64, Time64MicrosecondType)
10981113
}
10991114
Time64(Nanosecond) => {
1100-
instantiate_min_accumulator!(data_type, i64, Time64NanosecondType)
1115+
primitive_min_accumulator!(data_type, i64, Time64NanosecondType)
11011116
}
11021117
Timestamp(Second, _) => {
1103-
instantiate_min_accumulator!(data_type, i64, TimestampSecondType)
1118+
primitive_min_accumulator!(data_type, i64, TimestampSecondType)
11041119
}
11051120
Timestamp(Millisecond, _) => {
1106-
instantiate_min_accumulator!(data_type, i64, TimestampMillisecondType)
1121+
primitive_min_accumulator!(data_type, i64, TimestampMillisecondType)
11071122
}
11081123
Timestamp(Microsecond, _) => {
1109-
instantiate_min_accumulator!(data_type, i64, TimestampMicrosecondType)
1124+
primitive_min_accumulator!(data_type, i64, TimestampMicrosecondType)
11101125
}
11111126
Timestamp(Nanosecond, _) => {
1112-
instantiate_min_accumulator!(data_type, i64, TimestampNanosecondType)
1127+
primitive_min_accumulator!(data_type, i64, TimestampNanosecondType)
11131128
}
11141129
Decimal128(_, _) => {
1115-
instantiate_min_accumulator!(data_type, i128, Decimal128Type)
1130+
primitive_min_accumulator!(data_type, i128, Decimal128Type)
11161131
}
11171132
Decimal256(_, _) => {
1118-
instantiate_min_accumulator!(data_type, i256, Decimal256Type)
1133+
primitive_min_accumulator!(data_type, i256, Decimal256Type)
1134+
}
1135+
Utf8 | LargeUtf8 | Utf8View | Binary | LargeBinary | BinaryView => {
1136+
Ok(Box::new(MinMaxBytesAccumulator::new_min(data_type.clone())))
11191137
}
1120-
1121-
// It would be nice to have a fast implementation for Strings as well
1122-
// https://github.com/apache/datafusion/issues/6906
11231138

11241139
// This is only reached if groups_accumulator_supported is out of sync
11251140
_ => internal_err!("GroupsAccumulator not supported for min({})", data_type),

0 commit comments

Comments
 (0)