Skip to content

Commit 5d424ef

Browse files
authored
Remove AggregateState wrapper (#4582)
* Remove AggregateState wrapper * Remove more unwrap * Fix logical conflicts * Remove unnecessary array
1 parent 84d3ae8 commit 5d424ef

25 files changed

+97
-164
lines changed

datafusion-examples/examples/simple_udaf.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ use datafusion::arrow::{
2121
array::ArrayRef, array::Float32Array, datatypes::DataType, record_batch::RecordBatch,
2222
};
2323
use datafusion::from_slice::FromSlice;
24-
use datafusion::logical_expr::AggregateState;
2524
use datafusion::{error::Result, physical_plan::Accumulator};
2625
use datafusion::{logical_expr::Volatility, prelude::*, scalar::ScalarValue};
2726
use datafusion_common::cast::as_float64_array;
@@ -108,10 +107,10 @@ impl Accumulator for GeometricMean {
108107
// This function serializes our state to `ScalarValue`, which DataFusion uses
109108
// to pass this state between execution stages.
110109
// Note that this can be arbitrary data.
111-
fn state(&self) -> Result<Vec<AggregateState>> {
110+
fn state(&self) -> Result<Vec<ScalarValue>> {
112111
Ok(vec![
113-
AggregateState::Scalar(ScalarValue::from(self.prod)),
114-
AggregateState::Scalar(ScalarValue::from(self.n)),
112+
ScalarValue::from(self.prod),
113+
ScalarValue::from(self.n),
115114
])
116115
}
117116

datafusion/core/src/physical_plan/aggregates/hash.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -519,7 +519,7 @@ fn create_batch_from_map(
519519
accumulators.group_states.iter().map(|group_state| {
520520
group_state.accumulator_set[x]
521521
.state()
522-
.and_then(|x| x[y].as_scalar().map(|v| v.clone()))
522+
.map(|x| x[y].clone())
523523
.expect("unexpected accumulator state in hash aggregate")
524524
}),
525525
)?;

datafusion/core/src/physical_plan/windows/mod.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ mod tests {
178178
use arrow::datatypes::{DataType, Field, SchemaRef};
179179
use arrow::record_batch::RecordBatch;
180180
use datafusion_common::cast::as_primitive_array;
181-
use datafusion_expr::{create_udaf, Accumulator, AggregateState, Volatility};
181+
use datafusion_expr::{create_udaf, Accumulator, Volatility};
182182
use futures::FutureExt;
183183

184184
fn create_test_schema(partitions: usize) -> Result<(Arc<CsvExec>, SchemaRef)> {
@@ -193,10 +193,8 @@ mod tests {
193193
struct MyCount(i64);
194194

195195
impl Accumulator for MyCount {
196-
fn state(&self) -> Result<Vec<AggregateState>> {
197-
Ok(vec![AggregateState::Scalar(ScalarValue::Int64(Some(
198-
self.0,
199-
)))])
196+
fn state(&self) -> Result<Vec<ScalarValue>> {
197+
Ok(vec![ScalarValue::Int64(Some(self.0))])
200198
}
201199

202200
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {

datafusion/core/tests/sql/udf.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ use datafusion::{
2222
physical_plan::{expressions::AvgAccumulator, functions::make_scalar_function},
2323
};
2424
use datafusion_common::{cast::as_int32_array, ScalarValue};
25-
use datafusion_expr::{create_udaf, Accumulator, AggregateState, LogicalPlanBuilder};
25+
use datafusion_expr::{create_udaf, Accumulator, LogicalPlanBuilder};
2626

2727
/// test that casting happens on udfs.
2828
/// c11 is f32, but `custom_sqrt` requires f64. Casting happens but the logical plan and
@@ -175,7 +175,7 @@ fn udaf_as_window_func() -> Result<()> {
175175
struct MyAccumulator;
176176

177177
impl Accumulator for MyAccumulator {
178-
fn state(&self) -> Result<Vec<AggregateState>> {
178+
fn state(&self) -> Result<Vec<ScalarValue>> {
179179
unimplemented!()
180180
}
181181

datafusion/core/tests/user_defined_aggregates.rs

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ use datafusion::{
2828
},
2929
assert_batches_eq,
3030
error::Result,
31-
logical_expr::AggregateState,
3231
logical_expr::{
3332
AccumulatorFunctionImplementation, AggregateUDF, ReturnTypeFunction, Signature,
3433
StateTypeFunction, TypeSignature, Volatility,
@@ -210,12 +209,8 @@ impl FirstSelector {
210209
}
211210

212211
impl Accumulator for FirstSelector {
213-
fn state(&self) -> Result<Vec<AggregateState>> {
214-
let state = self
215-
.to_state()
216-
.into_iter()
217-
.map(AggregateState::Scalar)
218-
.collect::<Vec<_>>();
212+
fn state(&self) -> Result<Vec<ScalarValue>> {
213+
let state = self.to_state().into_iter().collect::<Vec<_>>();
219214

220215
Ok(state)
221216
}

datafusion/expr/src/accumulator.rs

Lines changed: 11 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,20 @@ pub trait Accumulator: Send + Sync + Debug {
3737
/// accumulator (that ran on different partitions, for
3838
/// example).
3939
///
40-
/// The state can be a different type than the output of the
41-
/// [`Accumulator`]
40+
/// The state can be and often is a different type than the output
41+
/// type of the [`Accumulator`].
4242
///
4343
/// See [`merge_batch`] for more details on the merging process.
4444
///
45-
/// For example, in the case of an average, for which we track `sum` and `n`,
46-
/// this function should return a vector of two values, sum and n.
47-
fn state(&self) -> Result<Vec<AggregateState>>;
45+
/// Some accumulators can return multiple values for their
46+
/// intermediate states. For example, average tracks `sum` and
47+
/// `n`, and this function should return
48+
/// a vector of two values, sum and n.
49+
///
50+
/// `ScalarValue::List` can also be used to pass multiple values
51+
/// if the number of intermediate values is not known at planning
52+
/// time (e.g. median)
53+
fn state(&self) -> Result<Vec<ScalarValue>>;
4854

4955
/// Updates the accumulator's state from a vector of arrays.
5056
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()>;
@@ -80,38 +86,3 @@ pub trait Accumulator: Send + Sync + Debug {
8086
/// not the `len`
8187
fn size(&self) -> usize;
8288
}
83-
84-
/// Representation of internal accumulator state. Accumulators can potentially have a mix of
85-
/// scalar and array values. It may be desirable to add custom aggregator states here as well
86-
/// in the future (perhaps `Custom(Box<dyn Any>)`?).
87-
#[derive(Debug)]
88-
pub enum AggregateState {
89-
/// Simple scalar value. Note that `ScalarValue::List` can be used to pass multiple
90-
/// values around
91-
Scalar(ScalarValue),
92-
/// Arrays can be used instead of `ScalarValue::List` and could potentially have better
93-
/// performance with large data sets, although this has not been verified. It also allows
94-
/// for use of arrow kernels with less overhead.
95-
Array(ArrayRef),
96-
}
97-
98-
impl AggregateState {
99-
/// Access the aggregate state as a scalar value. An error will occur if the
100-
/// state is not a scalar value.
101-
pub fn as_scalar(&self) -> Result<&ScalarValue> {
102-
match &self {
103-
Self::Scalar(v) => Ok(v),
104-
_ => Err(DataFusionError::Internal(
105-
"AggregateState is not a scalar aggregate".to_string(),
106-
)),
107-
}
108-
}
109-
110-
/// Access the aggregate state as an array value.
111-
pub fn to_array(&self) -> ArrayRef {
112-
match &self {
113-
Self::Scalar(v) => v.to_array(),
114-
Self::Array(array) => array.clone(),
115-
}
116-
}
117-
}

datafusion/expr/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ pub mod utils;
5252
pub mod window_frame;
5353
pub mod window_function;
5454

55-
pub use accumulator::{Accumulator, AggregateState};
55+
pub use accumulator::Accumulator;
5656
pub use aggregate_function::AggregateFunction;
5757
pub use built_in_function::BuiltinScalarFunction;
5858
pub use columnar_value::ColumnarValue;

datafusion/physical-expr/src/aggregate/approx_distinct.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ use arrow::datatypes::{
3030
};
3131
use datafusion_common::{downcast_value, ScalarValue};
3232
use datafusion_common::{DataFusionError, Result};
33-
use datafusion_expr::{Accumulator, AggregateState};
33+
use datafusion_expr::Accumulator;
3434
use std::any::Any;
3535
use std::convert::TryFrom;
3636
use std::convert::TryInto;
@@ -231,8 +231,8 @@ macro_rules! default_accumulator_impl {
231231
Ok(())
232232
}
233233

234-
fn state(&self) -> Result<Vec<AggregateState>> {
235-
let value = AggregateState::Scalar(ScalarValue::from(&self.hll));
234+
fn state(&self) -> Result<Vec<ScalarValue>> {
235+
let value = ScalarValue::from(&self.hll);
236236
Ok(vec![value])
237237
}
238238

datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ use arrow::{
2929
use datafusion_common::DataFusionError;
3030
use datafusion_common::Result;
3131
use datafusion_common::{downcast_value, ScalarValue};
32-
use datafusion_expr::{Accumulator, AggregateState};
32+
use datafusion_expr::Accumulator;
3333
use std::{any::Any, iter, sync::Arc};
3434

3535
/// APPROX_PERCENTILE_CONT aggregate expression
@@ -357,13 +357,8 @@ impl ApproxPercentileAccumulator {
357357
}
358358

359359
impl Accumulator for ApproxPercentileAccumulator {
360-
fn state(&self) -> Result<Vec<AggregateState>> {
361-
Ok(self
362-
.digest
363-
.to_scalar_state()
364-
.into_iter()
365-
.map(AggregateState::Scalar)
366-
.collect())
360+
fn state(&self) -> Result<Vec<ScalarValue>> {
361+
Ok(self.digest.to_scalar_state().into_iter().collect())
367362
}
368363

369364
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {

datafusion/physical-expr/src/aggregate/approx_percentile_cont_with_weight.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ use arrow::{
2626

2727
use datafusion_common::Result;
2828
use datafusion_common::ScalarValue;
29-
use datafusion_expr::{Accumulator, AggregateState};
29+
use datafusion_expr::Accumulator;
3030

3131
use std::{any::Any, sync::Arc};
3232

@@ -114,7 +114,7 @@ impl ApproxPercentileWithWeightAccumulator {
114114
}
115115

116116
impl Accumulator for ApproxPercentileWithWeightAccumulator {
117-
fn state(&self) -> Result<Vec<AggregateState>> {
117+
fn state(&self) -> Result<Vec<ScalarValue>> {
118118
self.approx_percentile_cont_accumulator.state()
119119
}
120120

datafusion/physical-expr/src/aggregate/array_agg.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ use arrow::array::ArrayRef;
2323
use arrow::datatypes::{DataType, Field};
2424
use datafusion_common::ScalarValue;
2525
use datafusion_common::{DataFusionError, Result};
26-
use datafusion_expr::{Accumulator, AggregateState};
26+
use datafusion_expr::Accumulator;
2727
use std::any::Any;
2828
use std::sync::Arc;
2929

@@ -143,8 +143,8 @@ impl Accumulator for ArrayAggAccumulator {
143143
})
144144
}
145145

146-
fn state(&self) -> Result<Vec<AggregateState>> {
147-
Ok(vec![AggregateState::Scalar(self.evaluate()?)])
146+
fn state(&self) -> Result<Vec<ScalarValue>> {
147+
Ok(vec![self.evaluate()?])
148148
}
149149

150150
fn evaluate(&self) -> Result<ScalarValue> {

datafusion/physical-expr/src/aggregate/array_agg_distinct.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ use crate::expressions::format_state_name;
2929
use crate::{AggregateExpr, PhysicalExpr};
3030
use datafusion_common::Result;
3131
use datafusion_common::ScalarValue;
32-
use datafusion_expr::{Accumulator, AggregateState};
32+
use datafusion_expr::Accumulator;
3333

3434
/// Expression for a ARRAY_AGG(DISTINCT) aggregation.
3535
#[derive(Debug)]
@@ -119,11 +119,11 @@ impl DistinctArrayAggAccumulator {
119119
}
120120

121121
impl Accumulator for DistinctArrayAggAccumulator {
122-
fn state(&self) -> Result<Vec<AggregateState>> {
123-
Ok(vec![AggregateState::Scalar(ScalarValue::new_list(
122+
fn state(&self) -> Result<Vec<ScalarValue>> {
123+
Ok(vec![ScalarValue::new_list(
124124
Some(self.values.clone().into_iter().collect()),
125125
self.datatype.clone(),
126-
))])
126+
)])
127127
}
128128

129129
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {

datafusion/physical-expr/src/aggregate/average.rs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ use arrow::{
3333
};
3434
use datafusion_common::{downcast_value, ScalarValue};
3535
use datafusion_common::{DataFusionError, Result};
36-
use datafusion_expr::{Accumulator, AggregateState};
36+
use datafusion_expr::Accumulator;
3737
use datafusion_row::accessor::RowAccessor;
3838

3939
/// AVG aggregate expression
@@ -150,11 +150,8 @@ impl AvgAccumulator {
150150
}
151151

152152
impl Accumulator for AvgAccumulator {
153-
fn state(&self) -> Result<Vec<AggregateState>> {
154-
Ok(vec![
155-
AggregateState::Scalar(ScalarValue::from(self.count)),
156-
AggregateState::Scalar(self.sum.clone()),
157-
])
153+
fn state(&self) -> Result<Vec<ScalarValue>> {
154+
Ok(vec![ScalarValue::from(self.count), self.sum.clone()])
158155
}
159156

160157
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {

datafusion/physical-expr/src/aggregate/correlation.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ use crate::{AggregateExpr, PhysicalExpr};
2525
use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field};
2626
use datafusion_common::Result;
2727
use datafusion_common::ScalarValue;
28-
use datafusion_expr::{Accumulator, AggregateState};
28+
use datafusion_expr::Accumulator;
2929
use std::any::Any;
3030
use std::sync::Arc;
3131

@@ -133,14 +133,14 @@ impl CorrelationAccumulator {
133133
}
134134

135135
impl Accumulator for CorrelationAccumulator {
136-
fn state(&self) -> Result<Vec<AggregateState>> {
136+
fn state(&self) -> Result<Vec<ScalarValue>> {
137137
Ok(vec![
138-
AggregateState::Scalar(ScalarValue::from(self.covar.get_count())),
139-
AggregateState::Scalar(ScalarValue::from(self.covar.get_mean1())),
140-
AggregateState::Scalar(ScalarValue::from(self.stddev1.get_m2())),
141-
AggregateState::Scalar(ScalarValue::from(self.covar.get_mean2())),
142-
AggregateState::Scalar(ScalarValue::from(self.stddev2.get_m2())),
143-
AggregateState::Scalar(ScalarValue::from(self.covar.get_algo_const())),
138+
ScalarValue::from(self.covar.get_count()),
139+
ScalarValue::from(self.covar.get_mean1()),
140+
ScalarValue::from(self.stddev1.get_m2()),
141+
ScalarValue::from(self.covar.get_mean2()),
142+
ScalarValue::from(self.stddev2.get_m2()),
143+
ScalarValue::from(self.covar.get_algo_const()),
144144
])
145145
}
146146

datafusion/physical-expr/src/aggregate/count.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ use arrow::datatypes::DataType;
2929
use arrow::{array::ArrayRef, datatypes::Field};
3030
use datafusion_common::{downcast_value, ScalarValue};
3131
use datafusion_common::{DataFusionError, Result};
32-
use datafusion_expr::{Accumulator, AggregateState};
32+
use datafusion_expr::Accumulator;
3333
use datafusion_row::accessor::RowAccessor;
3434

3535
use crate::expressions::format_state_name;
@@ -119,10 +119,8 @@ impl CountAccumulator {
119119
}
120120

121121
impl Accumulator for CountAccumulator {
122-
fn state(&self) -> Result<Vec<AggregateState>> {
123-
Ok(vec![AggregateState::Scalar(ScalarValue::Int64(Some(
124-
self.count,
125-
)))])
122+
fn state(&self) -> Result<Vec<ScalarValue>> {
123+
Ok(vec![ScalarValue::Int64(Some(self.count))])
126124
}
127125

128126
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {

0 commit comments

Comments
 (0)