Skip to content

Commit ee5ac1c

Browse files
committed
Support convert_to_state for AVG accumulator
1 parent 0d994a6 commit ee5ac1c

File tree

1 file changed

+45
-1
lines changed

1 file changed

+45
-1
lines changed

datafusion/functions-aggregate/src/average.rs

+45-1
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@
1919
2020
use arrow::array::{
2121
self, Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType,
22-
AsArray, PrimitiveArray, PrimitiveBuilder, UInt64Array,
22+
AsArray, BooleanArray, Int64Array, PrimitiveArray, PrimitiveBuilder, UInt64Array,
2323
};
24+
use arrow::buffer::NullBuffer;
2425
use arrow::compute::sum;
2526
use arrow::datatypes::{
2627
i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, DecimalType, Field,
@@ -554,8 +555,51 @@ where
554555
Ok(())
555556
}
556557

558+
fn convert_to_state(
559+
&self,
560+
values: &[ArrayRef],
561+
opt_filter: Option<&BooleanArray>,
562+
) -> Result<Vec<ArrayRef>> {
563+
let counts = Arc::new(Int64Array::from_value(1, values.len()));
564+
let sums = values[0].as_primitive::<T>();
565+
566+
let nulls = filtered_null_mask(opt_filter, sums);
567+
let sums = PrimitiveArray::<T>::new(sums.values().clone(), nulls)
568+
.with_data_type(self.sum_data_type.clone());
569+
570+
Ok(vec![counts, Arc::new(sums)])
571+
}
572+
573+
fn convert_to_state_supported(&self) -> bool {
574+
true
575+
}
576+
557577
fn size(&self) -> usize {
558578
self.counts.capacity() * std::mem::size_of::<u64>()
559579
+ self.sums.capacity() * std::mem::size_of::<T>()
560580
}
561581
}
582+
583+
/// Converts a `BooleanBuffer` representing a filter to a `NullBuffer`
584+
/// where the NullBuffer is true for all values that were true
585+
/// in the filter and `null` for any values that were false or null
586+
fn filter_to_nulls(filter: &BooleanArray) -> Option<NullBuffer> {
587+
let (filter_bools, filter_nulls) = filter.clone().into_parts();
588+
// Only keep values where the filter was true
589+
// convert all false to null
590+
let filter_bools = NullBuffer::from(filter_bools);
591+
NullBuffer::union(Some(&filter_bools), filter_nulls.as_ref())
592+
}
593+
594+
/// Compute the final null mask for an array
595+
///
596+
/// The output null mask :
597+
/// * is true (non null) for all values that were true in the filter and non null in the input
598+
/// * is false (null) for all values that were false in the filter or null in the input
599+
fn filtered_null_mask(
600+
opt_filter: Option<&BooleanArray>,
601+
input: &dyn Array,
602+
) -> Option<NullBuffer> {
603+
let opt_filter = opt_filter.and_then(filter_to_nulls);
604+
NullBuffer::union(opt_filter.as_ref(), input.nulls())
605+
}

0 commit comments

Comments
 (0)