Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use filtered_null_mask in CountGroupsAccumulator and PrimitiveGroupsAccumulator #11825

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,13 @@
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use arrow::array::{ArrayRef, AsArray, BooleanArray, PrimitiveArray};
use arrow::buffer::NullBuffer;
use arrow::compute;
use arrow::datatypes::ArrowPrimitiveType;
use crate::aggregate::groups_accumulator::accumulate::NullState;
use crate::aggregate::groups_accumulator::nulls::{filtered_null_mask, set_nulls};
use arrow::array::{ArrayRef, ArrowPrimitiveType, AsArray, BooleanArray, PrimitiveArray};
use arrow::datatypes::DataType;
use datafusion_common::{internal_datafusion_err, DataFusionError, Result};
use datafusion_common::Result;
use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator};

use super::accumulate::NullState;
use std::sync::Arc;

/// An accumulator that implements a single operation over
/// [`ArrowPrimitiveType`] where the accumulated state is the same as
Expand Down Expand Up @@ -147,44 +143,31 @@ where
values: &[ArrayRef],
opt_filter: Option<&BooleanArray>,
) -> Result<Vec<ArrayRef>> {
let values = values[0].as_primitive::<T>().clone();

// Initializing state with starting values
let initial_state =
PrimitiveArray::<T>::from_value(self.starting_value, values.len());

// Recalculate the input values if a filter is present
let values = match opt_filter {
None => values,
Some(filter) => {
let (filter_values, filter_nulls) = filter.clone().into_parts();
// Compute the filter mask as the bitwise AND of the filter's values and its validity bitmap, then convert it to a null buffer
let filter_bool = match filter_nulls {
Some(filter_nulls) => filter_nulls.inner() & &filter_values,
None => filter_values,
};
let filter_nulls = NullBuffer::from(filter_bool);

// Rebuilding input values with a new nulls mask, which is equal to
// the union of original nulls and filter mask
let (dt, values_buf, original_nulls) = values.into_parts();
let nulls_buf =
NullBuffer::union(original_nulls.as_ref(), Some(&filter_nulls));
PrimitiveArray::<T>::new(values_buf, nulls_buf).with_data_type(dt)
let values = values[0].as_primitive::<T>();

// Figure out which values will be non null in the output
let nulls = filtered_null_mask(opt_filter, values);

// Initializing state with starting value
let mut state = vec![self.starting_value; values.len()];

// update state with any non-filtered input
if nulls.is_some() {
// mask out any filtered / null input values
let values = set_nulls(values.clone(), nulls.clone());
for (state, value) in state.iter_mut().zip(values.iter()) {
if let Some(value) = value {
(self.prim_fn)(state, value);
}
}
} else {
// no nulls in input, so iterate over all values
let all_values = values.values().iter();
for (state, value) in state.iter_mut().zip(all_values) {
(self.prim_fn)(state, *value)
}
};

let state_values = compute::binary_mut(initial_state, &values, |mut x, y| {
(self.prim_fn)(&mut x, y);
x
});
let state_values = state_values
.map_err(|_| {
internal_datafusion_err!(
"initial_values underlying buffer must not be shared"
)
})?
.map_err(DataFusionError::from)?
let state_values = PrimitiveArray::<T>::new(state.into(), nulls)
.with_data_type(self.data_type.clone());

Ok(vec![Arc::new(state_values)])
Expand Down
74 changes: 22 additions & 52 deletions datafusion/functions-aggregate/src/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
// under the License.

use ahash::RandomState;
use datafusion_functions_aggregate_common::aggregate::count_distinct::BytesViewDistinctCountAccumulator;
use datafusion_functions_aggregate_common::aggregate::count_distinct::{
BytesDistinctCountAccumulator, BytesViewDistinctCountAccumulator,
FloatDistinctCountAccumulator, PrimitiveDistinctCountAccumulator,
};
use std::collections::HashSet;
use std::ops::BitAnd;
use std::{fmt::Debug, sync::Arc};
Expand Down Expand Up @@ -47,11 +50,8 @@ use datafusion_expr::{
EmitTo, GroupsAccumulator, Signature, Volatility,
};
use datafusion_expr::{Expr, ReversedUDAF, TypeSignature};
use datafusion_functions_aggregate_common::aggregate::count_distinct::{
BytesDistinctCountAccumulator, FloatDistinctCountAccumulator,
PrimitiveDistinctCountAccumulator,
};
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate_indices;
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filtered_null_mask;
use datafusion_physical_expr_common::binary_map::OutputType;

make_udaf_expr_and_func!(
Expand Down Expand Up @@ -450,59 +450,29 @@ impl GroupsAccumulator for CountGroupsAccumulator {
/// Converts an input batch directly to a state batch
///
/// The state of `COUNT` is always a single Int64Array:
/// * `1` (for non-null, non filtered values)
/// * `0` (for null values)
/// * `1` (for non-null, non-filtered values)
/// * `0` (for filtered or null values)
fn convert_to_state(
&self,
values: &[ArrayRef],
opt_filter: Option<&BooleanArray>,
) -> Result<Vec<ArrayRef>> {
let values = &values[0];

let state_array = match (values.logical_nulls(), opt_filter) {
(None, None) => {
// If there are no nulls in the input and no filter, return an array of 1s
Arc::new(Int64Array::from_value(1, values.len()))
}
(Some(nulls), None) => {
// If the input values contain nulls, cast the input array's `nulls` mask
// (true for valid values, false for nulls) to Int64
let nulls = BooleanArray::new(nulls.into_inner(), None);
compute::cast(&nulls, &DataType::Int64)?
}
(None, Some(filter)) => {
// If there is only filter
// - applying filter null mask to filter values by bitand filter values and nulls buffers
// (using buffers guarantees absence of nulls in result)
// - casting result of bitand to Int64 array
let (filter_values, filter_nulls) = filter.clone().into_parts();

let state_buf = match filter_nulls {
Some(filter_nulls) => &filter_values & filter_nulls.inner(),
None => filter_values,
};

let boolean_state = BooleanArray::new(state_buf, None);
compute::cast(&boolean_state, &DataType::Int64)?
}
(Some(nulls), Some(filter)) => {
// For both input nulls and filter
// - applying filter null mask to filter values by bitand filter values and nulls buffers
// (using buffers guarantees absence of nulls in result)
// - applying values null mask to filter buffer by another bitand on filter result and
// nulls from input values
// - casting result to Int64 array
let (filter_values, filter_nulls) = filter.clone().into_parts();

let filter_buf = match filter_nulls {
Some(filter_nulls) => &filter_values & filter_nulls.inner(),
None => filter_values,
};
let state_buf = &filter_buf & nulls.inner();

let boolean_state = BooleanArray::new(state_buf, None);
compute::cast(&boolean_state, &DataType::Int64)?
}
let nulls = filtered_null_mask(opt_filter, values);

let state_array: ArrayRef = if let Some(nulls) = nulls {
// A null (false) in the filtered mask means we should output a
// count of 0 for that value.
//
// The cast kernel performs the following conversion:
//
// * `true` -> `1`
// * `false` -> `0`
let nulls = BooleanArray::new(nulls.into_inner(), None);
compute::cast(&nulls, &DataType::Int64)?
} else {
// all input values contribute a 1
Arc::new(Int64Array::from_value(1, values.len()))
};

Ok(vec![state_array])
Expand Down