Commit 53728b3

Authored by Rachelint, Dandandan, and alamb
Improve speed of median by implementing special GroupsAccumulator (#13681)
* draft of `MedianGroupAccumulator`.
* impl `state`.
* impl rest methods of `MedianGroupsAccumulator`.
* improve comments.
* use `MedianGroupsAccumulator`.
* remove unused import.
* add `group_median_table` to test group median.
* complete group median test cases in aggregate slt.
* fix type of state.
* Clippy
* Fmt
* add fuzzy tests for median.
* fix decimal.
* fix clippy.
* improve comments.
* add median cases with nulls.
* Update datafusion/functions-aggregate/src/median.rs

  Co-authored-by: Andrew Lamb <[email protected]>

* use `OffsetBuffer::new_unchecked` in `convert_to_state`.
* add todo.
* remove assert and switch to i32 try from.
* return error when try from failed.

---------

Co-authored-by: Daniël Heres <[email protected]>
Co-authored-by: Andrew Lamb <[email protected]>
1 parent 11435de commit 53728b3
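The core idea of the change, as a minimal self-contained Rust sketch (illustrative only, not the DataFusion implementation shown in the diff below): keep one `Vec` of raw values per group, append to it for every input row, and only at evaluation time sort each group's values and take the middle element. The real accumulator additionally handles nulls, filters, Arrow type dispatch, and multi-phase aggregation state.

// Minimal sketch of grouped median accumulation (illustrative, not DataFusion code).
fn update_batch(group_values: &mut Vec<Vec<f64>>, group_indices: &[usize], values: &[f64]) {
    // Make sure a Vec exists for every group seen so far
    let num_groups = group_indices.iter().copied().max().map_or(0, |g| g + 1);
    if group_values.len() < num_groups {
        group_values.resize(num_groups, Vec::new());
    }
    // Append each value to the Vec of its group
    for (&group_index, &value) in group_indices.iter().zip(values) {
        group_values[group_index].push(value);
    }
}

fn evaluate(group_values: Vec<Vec<f64>>) -> Vec<Option<f64>> {
    group_values
        .into_iter()
        .map(|mut values| {
            if values.is_empty() {
                return None; // an empty group has no median
            }
            values.sort_by(|a, b| a.total_cmp(b));
            let mid = values.len() / 2;
            Some(if values.len() % 2 == 1 {
                values[mid]
            } else {
                (values[mid - 1] + values[mid]) / 2.0
            })
        })
        .collect()
}

fn main() {
    let mut group_values: Vec<Vec<f64>> = Vec::new();
    // Rows: (group 0, 1.0), (group 1, 10.0), (group 0, 3.0), (group 0, 2.0)
    update_batch(&mut group_values, &[0, 1, 0, 0], &[1.0, 10.0, 3.0, 2.0]);
    assert_eq!(evaluate(group_values), vec![Some(2.0), Some(10.0)]);
}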

File tree

3 files changed (+541, -2 lines)

3 files changed

+541
-2
lines changed

datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs

Lines changed: 20 additions & 0 deletions
@@ -148,6 +148,26 @@ async fn test_count() {
         .await;
 }
 
+#[tokio::test(flavor = "multi_thread")]
+async fn test_median() {
+    let data_gen_config = baseline_config();
+
+    // Queries like SELECT median(a), median(distinct) FROM fuzz_table GROUP BY b
+    let query_builder = QueryBuilder::new()
+        .with_table_name("fuzz_table")
+        .with_aggregate_function("median")
+        .with_distinct_aggregate_function("median")
+        // median only works on numeric columns
+        .with_aggregate_arguments(data_gen_config.numeric_columns())
+        .set_group_by_columns(data_gen_config.all_columns());
+
+    AggregationFuzzerBuilder::from(data_gen_config)
+        .add_query_builder(query_builder)
+        .build()
+        .run()
+        .await;
+}
+
 /// Return a standard set of columns for testing data generation
 ///
 /// Includes numeric and string types

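For context, the kind of GROUP BY median query the fuzz test exercises can also be run through DataFusion's public API. A hedged sketch follows; the table name, column names, and data are illustrative and not taken from the test harness.

// Illustrative sketch: running a grouped median query through the public API.
// Table/column names and data are made up for the example.
use std::sync::Arc;

use datafusion::arrow::array::{ArrayRef, Float64Array, Int32Array};
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

#[tokio::main]
async fn main() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Float64, true),
        Field::new("b", DataType::Int32, false),
    ]));
    let batch = RecordBatch::try_new(
        schema,
        vec![
            Arc::new(Float64Array::from(vec![Some(1.0), Some(2.0), None, Some(4.0)])) as ArrayRef,
            Arc::new(Int32Array::from(vec![1, 1, 2, 2])),
        ],
    )?;

    let ctx = SessionContext::new();
    ctx.register_batch("fuzz_table", batch)?;

    // Grouped (non-distinct) median goes through the new GroupsAccumulator path
    ctx.sql("SELECT b, median(a) FROM fuzz_table GROUP BY b")
        .await?
        .show()
        .await?;

    Ok(())
}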
datafusion/functions-aggregate/src/median.rs

Lines changed: 260 additions & 2 deletions
@@ -20,7 +20,11 @@ use std::fmt::{Debug, Formatter};
 use std::mem::{size_of, size_of_val};
 use std::sync::Arc;
 
-use arrow::array::{downcast_integer, ArrowNumericType};
+use arrow::array::{
+    downcast_integer, ArrowNumericType, BooleanArray, ListArray, PrimitiveArray,
+    PrimitiveBuilder,
+};
+use arrow::buffer::{OffsetBuffer, ScalarBuffer};
 use arrow::{
     array::{ArrayRef, AsArray},
     datatypes::{
@@ -33,12 +37,17 @@ use arrow::array::Array;
 use arrow::array::ArrowNativeTypeOp;
 use arrow::datatypes::{ArrowNativeType, ArrowPrimitiveType};
 
-use datafusion_common::{DataFusionError, HashSet, Result, ScalarValue};
+use datafusion_common::{
+    internal_datafusion_err, internal_err, DataFusionError, HashSet, Result, ScalarValue,
+};
 use datafusion_expr::function::StateFieldsArgs;
 use datafusion_expr::{
     function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl,
     Documentation, Signature, Volatility,
 };
+use datafusion_expr::{EmitTo, GroupsAccumulator};
+use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate;
+use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filtered_null_mask;
 use datafusion_functions_aggregate_common::utils::Hashable;
 use datafusion_macros::user_doc;
 
@@ -165,6 +174,45 @@ impl AggregateUDFImpl for Median {
         }
     }
 
+    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
+        !args.is_distinct
+    }
+
+    fn create_groups_accumulator(
+        &self,
+        args: AccumulatorArgs,
+    ) -> Result<Box<dyn GroupsAccumulator>> {
+        let num_args = args.exprs.len();
+        if num_args != 1 {
+            return internal_err!(
+                "median should only have 1 arg, but found num args:{}",
+                args.exprs.len()
+            );
+        }
+
+        let dt = args.exprs[0].data_type(args.schema)?;
+
+        macro_rules! helper {
+            ($t:ty, $dt:expr) => {
+                Ok(Box::new(MedianGroupsAccumulator::<$t>::new($dt)))
+            };
+        }
+
+        downcast_integer! {
+            dt => (helper, dt),
+            DataType::Float16 => helper!(Float16Type, dt),
+            DataType::Float32 => helper!(Float32Type, dt),
+            DataType::Float64 => helper!(Float64Type, dt),
+            DataType::Decimal128(_, _) => helper!(Decimal128Type, dt),
+            DataType::Decimal256(_, _) => helper!(Decimal256Type, dt),
+            _ => Err(DataFusionError::NotImplemented(format!(
+                "MedianGroupsAccumulator not supported for {} with {}",
+                args.name,
+                dt,
+            ))),
+        }
+    }
+
     fn aliases(&self) -> &[String] {
         &[]
     }
@@ -230,6 +278,216 @@ impl<T: ArrowNumericType> Accumulator for MedianAccumulator<T> {
     }
 }
 
+/// The median groups accumulator accumulates the raw input values
+///
+/// For calculating the accurate medians of groups, we need to store all values
+/// of groups before final evaluation.
+/// So values in each group will be stored in a `Vec<T>`, and the total group values
+/// will actually be organized as a `Vec<Vec<T>>`.
+///
+#[derive(Debug)]
+struct MedianGroupsAccumulator<T: ArrowNumericType + Send> {
+    data_type: DataType,
+    group_values: Vec<Vec<T::Native>>,
+}
+
+impl<T: ArrowNumericType + Send> MedianGroupsAccumulator<T> {
+    pub fn new(data_type: DataType) -> Self {
+        Self {
+            data_type,
+            group_values: Vec::new(),
+        }
+    }
+}
+
+impl<T: ArrowNumericType + Send> GroupsAccumulator for MedianGroupsAccumulator<T> {
+    fn update_batch(
+        &mut self,
+        values: &[ArrayRef],
+        group_indices: &[usize],
+        opt_filter: Option<&BooleanArray>,
+        total_num_groups: usize,
+    ) -> Result<()> {
+        assert_eq!(values.len(), 1, "single argument to update_batch");
+        let values = values[0].as_primitive::<T>();
+
+        // Push the `not null + not filtered` rows into their groups
+        self.group_values.resize(total_num_groups, Vec::new());
+        accumulate(
+            group_indices,
+            values,
+            opt_filter,
+            |group_index, new_value| {
+                self.group_values[group_index].push(new_value);
+            },
+        );
+
+        Ok(())
+    }
+
+    fn merge_batch(
+        &mut self,
+        values: &[ArrayRef],
+        group_indices: &[usize],
+        // Since the aggregate filter should be applied in the partial stage, there should be no filter in the final stage
+        _opt_filter: Option<&BooleanArray>,
+        total_num_groups: usize,
+    ) -> Result<()> {
+        assert_eq!(values.len(), 1, "one argument to merge_batch");
+
+        // The merged values should be organized as a `ListArray` which is nullable
+        // (input with nulls is usually generated from `convert_to_state`), but the
+        // `inner array` of the `ListArray` is non-nullable.
+        //
+        // The following are the possible and impossible input `values`:
+        //
+        // # Possible values
+        // ```text
+        //   group 0: [1, 2, 3]
+        //   group 1: null (list array is nullable)
+        //   group 2: [6, 7, 8]
+        //   ...
+        //   group n: [...]
+        // ```
+        //
+        // # Impossible values
+        // ```text
+        //   group x: [1, 2, null] (values in the list array are non-nullable)
+        // ```
+        //
+        let input_group_values = values[0].as_list::<i32>();
+
+        // Ensure group values are big enough
+        self.group_values.resize(total_num_groups, Vec::new());
+
+        // Extend values to related groups
+        // TODO: avoid using the iterator of the `ListArray`; it leads to
+        // many calls of `slice` on its `inner array`, and `slice` is not
+        // so efficient (due to the calculation of `null_count` for each `slice`).
+        group_indices
+            .iter()
+            .zip(input_group_values.iter())
+            .for_each(|(&group_index, values_opt)| {
+                if let Some(values) = values_opt {
+                    let values = values.as_primitive::<T>();
+                    self.group_values[group_index].extend(values.values().iter());
+                }
+            });
+
+        Ok(())
+    }
+
+    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
+        // Emit values
+        let emit_group_values = emit_to.take_needed(&mut self.group_values);
+
+        // Build offsets
+        let mut offsets = Vec::with_capacity(self.group_values.len() + 1);
+        offsets.push(0);
+        let mut cur_len = 0_i32;
+        for group_value in &emit_group_values {
+            cur_len += group_value.len() as i32;
+            offsets.push(cur_len);
+        }
+        // TODO: maybe we can use `OffsetBuffer::new_unchecked` like in `convert_to_state`,
+        // but safety should be considered more carefully here (and I am not sure it brings
+        // a performance improvement once we introduce checks to keep the safety...).
+        //
+        // More details in:
+        // https://github.com/apache/datafusion/pull/13681#discussion_r1931209791
+        //
+        let offsets = OffsetBuffer::new(ScalarBuffer::from(offsets));
+
+        // Build inner array
+        let flatten_group_values =
+            emit_group_values.into_iter().flatten().collect::<Vec<_>>();
+        let group_values_array =
+            PrimitiveArray::<T>::new(ScalarBuffer::from(flatten_group_values), None)
+                .with_data_type(self.data_type.clone());
+
+        // Build the result list array
+        let result_list_array = ListArray::new(
+            Arc::new(Field::new_list_field(self.data_type.clone(), true)),
+            offsets,
+            Arc::new(group_values_array),
+            None,
+        );
+
+        Ok(vec![Arc::new(result_list_array)])
+    }
+
+    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
+        // Emit values
+        let emit_group_values = emit_to.take_needed(&mut self.group_values);
+
+        // Calculate the median for each group
+        let mut evaluate_result_builder =
+            PrimitiveBuilder::<T>::new().with_data_type(self.data_type.clone());
+        for values in emit_group_values {
+            let median = calculate_median::<T>(values);
+            evaluate_result_builder.append_option(median);
+        }
+
+        Ok(Arc::new(evaluate_result_builder.finish()))
+    }
+
+    fn convert_to_state(
+        &self,
+        values: &[ArrayRef],
+        opt_filter: Option<&BooleanArray>,
+    ) -> Result<Vec<ArrayRef>> {
+        assert_eq!(values.len(), 1, "one argument to convert_to_state");
+
+        let input_array = values[0].as_primitive::<T>();
+
+        // Directly convert the input array to states; each row will be
+        // seen as a respective group.
+        // In detail, the `input_array` will be converted to a `ListArray`.
+        // If a row is `not null + not filtered`, it will be converted to a list
+        // with only one element; otherwise, this row in the `ListArray` will be
+        // set to null.
+
+        // Reuse the values buffer in `input_array` to build `values` in the `ListArray`
+        let values = PrimitiveArray::<T>::new(input_array.values().clone(), None)
+            .with_data_type(self.data_type.clone());
+
+        // `offsets` in the `ListArray`, each row as a list element
+        let offset_end = i32::try_from(input_array.len()).map_err(|e| {
+            internal_datafusion_err!(
+                "cast array_len to i32 failed in convert_to_state of group median, err:{e:?}"
+            )
+        })?;
+        let offsets = (0..=offset_end).collect::<Vec<_>>();
+        // Safety: all checks in `OffsetBuffer::new` are ensured to pass
+        let offsets = unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(offsets)) };
+
+        // `nulls` for the converted `ListArray`
+        let nulls = filtered_null_mask(opt_filter, input_array);
+
+        let converted_list_array = ListArray::new(
+            Arc::new(Field::new_list_field(self.data_type.clone(), true)),
+            offsets,
+            Arc::new(values),
+            nulls,
+        );
+
+        Ok(vec![Arc::new(converted_list_array)])
+    }
+
+    fn supports_convert_to_state(&self) -> bool {
+        true
+    }
+
+    fn size(&self) -> usize {
+        self.group_values
+            .iter()
+            .map(|values| values.capacity() * size_of::<T>())
+            .sum::<usize>()
+            // account for the size of self.group_values too
+            + self.group_values.capacity() * size_of::<Vec<T>>()
+    }
+}
+
 /// The distinct median accumulator accumulates the raw input values
 /// as `ScalarValue`s
 ///
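A standalone sketch of the intermediate state that `state` builds above: per-group value vectors packed into one Arrow `ListArray`, with offsets derived from cumulative group lengths and a flattened child array. It is simplified to `i64` values, and the helper name is illustrative rather than taken from the diff.

// Illustrative sketch of the `Vec<Vec<T>>` -> `ListArray` state layout.
use std::sync::Arc;

use arrow::array::{Int64Array, ListArray};
use arrow::buffer::{OffsetBuffer, ScalarBuffer};
use arrow::datatypes::{DataType, Field};

fn groups_to_list_array(group_values: Vec<Vec<i64>>) -> ListArray {
    // Offsets are the cumulative group lengths: group i owns child[offsets[i]..offsets[i + 1]]
    let mut offsets = Vec::with_capacity(group_values.len() + 1);
    offsets.push(0_i32);
    let mut cur_len = 0_i32;
    for group in &group_values {
        cur_len += group.len() as i32;
        offsets.push(cur_len);
    }
    let offsets = OffsetBuffer::new(ScalarBuffer::from(offsets));

    // The child array is every group's values, flattened in group order
    let flattened: Vec<i64> = group_values.into_iter().flatten().collect();

    ListArray::new(
        Arc::new(Field::new_list_field(DataType::Int64, true)),
        offsets,
        Arc::new(Int64Array::from(flattened)),
        None, // every emitted group is non-null here
    )
}

fn main() {
    let state = groups_to_list_array(vec![vec![1, 2, 3], vec![], vec![6, 7]]);
    // Logical shape: [[1, 2, 3], [], [6, 7]]
    println!("{state:?}");
}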

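Likewise, a standalone sketch of the per-row state produced by `convert_to_state`: each input row becomes a one-element list, and rows that are null or filtered out become null list entries while their values stay in the reused child array. It is simplified to `i32` with an explicit validity mask, and the helper name is illustrative.

// Illustrative sketch of the row -> single-element-list conversion used for the partial state.
use std::sync::Arc;

use arrow::array::{Array, Int32Array, ListArray};
use arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer};
use arrow::datatypes::{DataType, Field};

fn rows_to_single_element_lists(values: Int32Array, row_is_valid: &[bool]) -> ListArray {
    // One offset per row boundary: row i owns child[i..i + 1]
    let offsets: Vec<i32> = (0..=values.len() as i32).collect();
    let offsets = OffsetBuffer::new(ScalarBuffer::from(offsets));

    ListArray::new(
        Arc::new(Field::new_list_field(DataType::Int32, true)),
        offsets,
        Arc::new(values),
        // Rows that were null or filtered out become null list entries
        Some(NullBuffer::from(row_is_valid.to_vec())),
    )
}

fn main() {
    let state = rows_to_single_element_lists(
        Int32Array::from(vec![1, 2, 3]),
        &[true, false, true], // row 1 was null or filtered out
    );
    // Logical shape: [[1], null, [3]]
    println!("{state:?}");
}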