Skip to content

Commit 00ef820

Browse files
authored
Support convert_to_state for AVG accumulator (#11734)
* Support `convert_to_state` for `AVG` accumulator * Update datafusion/physical-expr-common/src/aggregate/groups_accumulator/nulls.rs * fix documentation * Fix after merge * fix for change in location
1 parent 140f7ce commit 00ef820

File tree

4 files changed

+154
-1
lines changed

4 files changed

+154
-1
lines changed

datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
2121
pub mod accumulate;
2222
pub mod bool_op;
23+
pub mod nulls;
2324
pub mod prim_op;
2425

2526
use arrow::{
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! [`set_nulls`], and [`filtered_null_mask`], utilities for working with nulls
19+
20+
use arrow::array::{Array, ArrowNumericType, BooleanArray, PrimitiveArray};
21+
use arrow::buffer::NullBuffer;
22+
23+
/// Sets the validity mask for a `PrimitiveArray` to `nulls`
24+
/// replacing any existing null mask
25+
pub fn set_nulls<T: ArrowNumericType + Send>(
26+
array: PrimitiveArray<T>,
27+
nulls: Option<NullBuffer>,
28+
) -> PrimitiveArray<T> {
29+
let (dt, values, _old_nulls) = array.into_parts();
30+
PrimitiveArray::<T>::new(values, nulls).with_data_type(dt)
31+
}
32+
33+
/// Converts a `BooleanBuffer` representing a filter to a `NullBuffer.
34+
///
35+
/// The `NullBuffer` is
36+
/// * `true` (representing valid) for values that were `true` in filter
37+
/// * `false` (representing null) for values that were `false` or `null` in filter
38+
fn filter_to_nulls(filter: &BooleanArray) -> Option<NullBuffer> {
39+
let (filter_bools, filter_nulls) = filter.clone().into_parts();
40+
let filter_bools = NullBuffer::from(filter_bools);
41+
NullBuffer::union(Some(&filter_bools), filter_nulls.as_ref())
42+
}
43+
44+
/// Compute an output validity mask for an array that has been filtered
45+
///
46+
/// This can be used to compute nulls for the output of
47+
/// [`GroupsAccumulator::convert_to_state`], which quickly applies an optional
48+
/// filter to the input rows by setting any filtered rows to NULL in the output.
49+
/// Subsequent applications of aggregate functions that ignore NULLs (most of
50+
/// them) will thus ignore the filtered rows as well.
51+
///
52+
/// # Output element is `true` (and thus output is non-null)
53+
///
54+
/// A `true` in the output represents non null output for all values that were *both*:
55+
///
56+
/// * `true` in any `opt_filter` (aka values that passed the filter)
57+
///
58+
/// * `non null` in `input`
59+
///
60+
/// # Output element is `false` (and thus output is null)
61+
///
62+
/// A `false` in the output represents an input that was *either*:
63+
///
64+
/// * `null`
65+
///
66+
/// * filtered (aka the value was `false` or `null` in the filter)
67+
///
68+
/// # Example
69+
///
70+
/// ```text
71+
/// ┌─────┐ ┌─────┐ ┌─────┐
72+
/// │true │ │NULL │ │false│
73+
/// │true │ │ │true │ │true │
74+
/// │true │ ───┼─── │false│ ────────▶ │false│ filtered_nulls
75+
/// │false│ │ │NULL │ │false│
76+
/// │false│ │true │ │false│
77+
/// └─────┘ └─────┘ └─────┘
78+
/// array opt_filter output
79+
/// .nulls()
80+
///
81+
/// false = NULL true = pass false = NULL Meanings
82+
/// true = valid false = filter true = valid
83+
/// NULL = filter
84+
/// ```
85+
///
86+
/// [`GroupsAccumulator::convert_to_state`]: datafusion_expr_common::groups_accumulator::GroupsAccumulator
87+
pub fn filtered_null_mask(
88+
opt_filter: Option<&BooleanArray>,
89+
input: &dyn Array,
90+
) -> Option<NullBuffer> {
91+
let opt_filter = opt_filter.and_then(filter_to_nulls);
92+
NullBuffer::union(opt_filter.as_ref(), input.nulls())
93+
}

datafusion/functions-aggregate/src/average.rs

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@
1919
2020
use arrow::array::{
2121
self, Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType,
22-
AsArray, PrimitiveArray, PrimitiveBuilder, UInt64Array,
22+
AsArray, BooleanArray, PrimitiveArray, PrimitiveBuilder, UInt64Array,
2323
};
24+
2425
use arrow::compute::sum;
2526
use arrow::datatypes::{
2627
i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, DecimalType, Field,
@@ -34,7 +35,12 @@ use datafusion_expr::Volatility::Immutable;
3435
use datafusion_expr::{
3536
Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF, Signature,
3637
};
38+
3739
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::NullState;
40+
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::{
41+
filtered_null_mask, set_nulls,
42+
};
43+
3844
use datafusion_functions_aggregate_common::utils::DecimalAverager;
3945
use log::debug;
4046
use std::any::Any;
@@ -551,6 +557,30 @@ where
551557
Ok(())
552558
}
553559

560+
fn convert_to_state(
561+
&self,
562+
values: &[ArrayRef],
563+
opt_filter: Option<&BooleanArray>,
564+
) -> Result<Vec<ArrayRef>> {
565+
let sums = values[0]
566+
.as_primitive::<T>()
567+
.clone()
568+
.with_data_type(self.sum_data_type.clone());
569+
let counts = UInt64Array::from_value(1, sums.len());
570+
571+
let nulls = filtered_null_mask(opt_filter, &sums);
572+
573+
// set nulls on the arrays
574+
let counts = set_nulls(counts, nulls.clone());
575+
let sums = set_nulls(sums, nulls);
576+
577+
Ok(vec![Arc::new(counts) as ArrayRef, Arc::new(sums)])
578+
}
579+
580+
fn supports_convert_to_state(&self) -> bool {
581+
true
582+
}
583+
554584
fn size(&self) -> usize {
555585
self.counts.capacity() * std::mem::size_of::<u64>()
556586
+ self.sums.capacity() * std::mem::size_of::<T>()

datafusion/sqllogictest/test_files/aggregate_skip_partial.slt

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,21 @@ SELECT c2, sum(c3), sum(c11) FROM aggregate_test_100 GROUP BY c2 ORDER BY c2;
209209
4 29 9.531112968922
210210
5 -194 7.074412226677
211211

212+
# Test avg for tinyint / float
213+
query TRR
214+
SELECT
215+
c1,
216+
avg(c2),
217+
avg(c11)
218+
FROM aggregate_test_100 GROUP BY c1 ORDER BY c1;
219+
----
220+
a 2.857142857143 0.438223421574
221+
b 3.263157894737 0.496481208425
222+
c 2.666666666667 0.425241138254
223+
d 2.444444444444 0.541519476308
224+
e 3 0.505440263521
225+
226+
212227
# Enabling PG dialect for filtered aggregates tests
213228
statement ok
214229
set datafusion.sql_parser.dialect = 'Postgres';
@@ -267,6 +282,20 @@ FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2;
267282
4 11 14
268283
5 8 7
269284

285+
# Test avg for tinyint / float
286+
query TRR
287+
SELECT
288+
c1,
289+
avg(c2) FILTER (WHERE c2 != 5),
290+
avg(c11) FILTER (WHERE c2 != 5)
291+
FROM aggregate_test_100 GROUP BY c1 ORDER BY c1;
292+
----
293+
a 2.5 0.449071887467
294+
b 2.642857142857 0.445486298629
295+
c 2.421052631579 0.422882117723
296+
d 2.125 0.518706191331
297+
e 2.789473684211 0.536785323369
298+
270299
# Test count with nullable fields and nullable filter
271300
query III
272301
SELECT c2,

0 commit comments

Comments
 (0)