Skip to content

Commit 245def0

Browse files
andygrove and alamb authored
Implement exact median, add AggregateState (#3009)
* Implement exact median
* revert some changes
* toml format
* add median to protobuf
* remove some unwraps
* remove some unwraps
* remove some unwraps
* fix
* clippy
* reduce code duplication
* reduce code duplication
* more tests
* move tests to simplify github diff
* Update datafusion/expr/src/accumulator.rs
  Co-authored-by: Andrew Lamb <[email protected]>
* refactor to make it more obvious that empty arrays are being created
* partially address feedback
* Update datafusion/physical-expr/src/aggregate/count_distinct.rs
  Co-authored-by: Andrew Lamb <[email protected]>
* add more tests
* more docs
* clippy
* avoid a clone

Co-authored-by: Andrew Lamb <[email protected]>
1 parent 581934d commit 245def0

32 files changed: +645 / -107 lines changed

datafusion-examples/examples/simple_udaf.rs

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@ use datafusion::arrow::{
2323
};
2424

2525
use datafusion::from_slice::FromSlice;
26+
use datafusion::logical_expr::AggregateState;
2627
use datafusion::{error::Result, logical_plan::create_udaf, physical_plan::Accumulator};
2728
use datafusion::{logical_expr::Volatility, prelude::*, scalar::ScalarValue};
2829
use std::sync::Arc;
@@ -107,10 +108,10 @@ impl Accumulator for GeometricMean {
107108
// This function serializes our state to `ScalarValue`, which DataFusion uses
108109
// to pass this state between execution stages.
109110
// Note that this can be arbitrary data.
110-
fn state(&self) -> Result<Vec<ScalarValue>> {
111+
fn state(&self) -> Result<Vec<AggregateState>> {
111112
Ok(vec![
112-
ScalarValue::from(self.prod),
113-
ScalarValue::from(self.n),
113+
AggregateState::Scalar(ScalarValue::from(self.prod)),
114+
AggregateState::Scalar(ScalarValue::from(self.n)),
114115
])
115116
}
116117

datafusion/common/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -45,4 +45,5 @@ object_store = { version = "0.3", optional = true }
4545
ordered-float = "3.0"
4646
parquet = { version = "19.0.0", features = ["arrow"], optional = true }
4747
pyo3 = { version = "0.16", optional = true }
48+
serde_json = "1.0"
4849
sqlparser = "0.19"

datafusion/core/src/physical_plan/aggregates/hash.rs

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -428,8 +428,10 @@ fn create_batch_from_map(
428428
AggregateMode::Partial => {
429429
let res = ScalarValue::iter_to_array(
430430
accumulators.group_states.iter().map(|group_state| {
431-
let x = group_state.accumulator_set[x].state().unwrap();
432-
x[y].clone()
431+
group_state.accumulator_set[x]
432+
.state()
433+
.and_then(|x| x[y].as_scalar().map(|v| v.clone()))
434+
.expect("unexpected accumulator state in hash aggregate")
433435
}),
434436
)?;
435437

datafusion/core/tests/sql/aggregates.rs

Lines changed: 186 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -221,7 +221,7 @@ async fn csv_query_stddev_6() -> Result<()> {
221221
}
222222

223223
#[tokio::test]
224-
async fn csv_query_median_1() -> Result<()> {
224+
async fn csv_query_approx_median_1() -> Result<()> {
225225
let ctx = SessionContext::new();
226226
register_aggregate_csv(&ctx).await?;
227227
let sql = "SELECT approx_median(c2) FROM aggregate_test_100";
@@ -232,7 +232,7 @@ async fn csv_query_median_1() -> Result<()> {
232232
}
233233

234234
#[tokio::test]
235-
async fn csv_query_median_2() -> Result<()> {
235+
async fn csv_query_approx_median_2() -> Result<()> {
236236
let ctx = SessionContext::new();
237237
register_aggregate_csv(&ctx).await?;
238238
let sql = "SELECT approx_median(c6) FROM aggregate_test_100";
@@ -243,7 +243,7 @@ async fn csv_query_median_2() -> Result<()> {
243243
}
244244

245245
#[tokio::test]
246-
async fn csv_query_median_3() -> Result<()> {
246+
async fn csv_query_approx_median_3() -> Result<()> {
247247
let ctx = SessionContext::new();
248248
register_aggregate_csv(&ctx).await?;
249249
let sql = "SELECT approx_median(c12) FROM aggregate_test_100";
@@ -253,6 +253,189 @@ async fn csv_query_median_3() -> Result<()> {
253253
Ok(())
254254
}
255255

256+
#[tokio::test]
257+
async fn csv_query_median_1() -> Result<()> {
258+
let ctx = SessionContext::new();
259+
register_aggregate_csv(&ctx).await?;
260+
let sql = "SELECT median(c2) FROM aggregate_test_100";
261+
let actual = execute(&ctx, sql).await;
262+
let expected = vec![vec!["3"]];
263+
assert_float_eq(&expected, &actual);
264+
Ok(())
265+
}
266+
267+
#[tokio::test]
268+
async fn csv_query_median_2() -> Result<()> {
269+
let ctx = SessionContext::new();
270+
register_aggregate_csv(&ctx).await?;
271+
let sql = "SELECT median(c6) FROM aggregate_test_100";
272+
let actual = execute(&ctx, sql).await;
273+
let expected = vec![vec!["1125553990140691277"]];
274+
assert_float_eq(&expected, &actual);
275+
Ok(())
276+
}
277+
278+
#[tokio::test]
279+
async fn csv_query_median_3() -> Result<()> {
280+
let ctx = SessionContext::new();
281+
register_aggregate_csv(&ctx).await?;
282+
let sql = "SELECT median(c12) FROM aggregate_test_100";
283+
let actual = execute(&ctx, sql).await;
284+
let expected = vec![vec!["0.5513900544385053"]];
285+
assert_float_eq(&expected, &actual);
286+
Ok(())
287+
}
288+
289+
#[tokio::test]
290+
async fn median_i8() -> Result<()> {
291+
median_test(
292+
"median",
293+
DataType::Int8,
294+
Arc::new(Int8Array::from(vec![i8::MIN, i8::MIN, 100, i8::MAX])),
295+
"-14",
296+
)
297+
.await
298+
}
299+
300+
#[tokio::test]
301+
async fn median_i16() -> Result<()> {
302+
median_test(
303+
"median",
304+
DataType::Int16,
305+
Arc::new(Int16Array::from(vec![i16::MIN, i16::MIN, 100, i16::MAX])),
306+
"-16334",
307+
)
308+
.await
309+
}
310+
311+
#[tokio::test]
312+
async fn median_i32() -> Result<()> {
313+
median_test(
314+
"median",
315+
DataType::Int32,
316+
Arc::new(Int32Array::from(vec![i32::MIN, i32::MIN, 100, i32::MAX])),
317+
"-1073741774",
318+
)
319+
.await
320+
}
321+
322+
#[tokio::test]
323+
async fn median_i64() -> Result<()> {
324+
median_test(
325+
"median",
326+
DataType::Int64,
327+
Arc::new(Int64Array::from(vec![i64::MIN, i64::MIN, 100, i64::MAX])),
328+
"-4611686018427388000",
329+
)
330+
.await
331+
}
332+
333+
#[tokio::test]
334+
async fn median_u8() -> Result<()> {
335+
median_test(
336+
"median",
337+
DataType::UInt8,
338+
Arc::new(UInt8Array::from(vec![u8::MIN, u8::MIN, 100, u8::MAX])),
339+
"50",
340+
)
341+
.await
342+
}
343+
344+
#[tokio::test]
345+
async fn median_u16() -> Result<()> {
346+
median_test(
347+
"median",
348+
DataType::UInt16,
349+
Arc::new(UInt16Array::from(vec![u16::MIN, u16::MIN, 100, u16::MAX])),
350+
"50",
351+
)
352+
.await
353+
}
354+
355+
#[tokio::test]
356+
async fn median_u32() -> Result<()> {
357+
median_test(
358+
"median",
359+
DataType::UInt32,
360+
Arc::new(UInt32Array::from(vec![u32::MIN, u32::MIN, 100, u32::MAX])),
361+
"50",
362+
)
363+
.await
364+
}
365+
366+
#[tokio::test]
367+
async fn median_u64() -> Result<()> {
368+
median_test(
369+
"median",
370+
DataType::UInt64,
371+
Arc::new(UInt64Array::from(vec![u64::MIN, u64::MIN, 100, u64::MAX])),
372+
"50",
373+
)
374+
.await
375+
}
376+
377+
#[tokio::test]
378+
async fn median_f32() -> Result<()> {
379+
median_test(
380+
"median",
381+
DataType::Float32,
382+
Arc::new(Float32Array::from(vec![1.1, 4.4, 5.5, 3.3, 2.2])),
383+
"3.3",
384+
)
385+
.await
386+
}
387+
388+
#[tokio::test]
389+
async fn median_f64() -> Result<()> {
390+
median_test(
391+
"median",
392+
DataType::Float64,
393+
Arc::new(Float64Array::from(vec![1.1, 4.4, 5.5, 3.3, 2.2])),
394+
"3.3",
395+
)
396+
.await
397+
}
398+
399+
#[tokio::test]
400+
async fn median_f64_nan() -> Result<()> {
401+
median_test(
402+
"median",
403+
DataType::Float64,
404+
Arc::new(Float64Array::from(vec![1.1, f64::NAN, f64::NAN, f64::NAN])),
405+
"NaN", // probably not the desired behavior? - see https://github.com/apache/arrow-datafusion/issues/3039
406+
)
407+
.await
408+
}
409+
410+
#[tokio::test]
411+
async fn approx_median_f64_nan() -> Result<()> {
412+
median_test(
413+
"approx_median",
414+
DataType::Float64,
415+
Arc::new(Float64Array::from(vec![1.1, f64::NAN, f64::NAN, f64::NAN])),
416+
"NaN", // probably not the desired behavior? - see https://github.com/apache/arrow-datafusion/issues/3039
417+
)
418+
.await
419+
}
420+
421+
async fn median_test(
422+
func: &str,
423+
data_type: DataType,
424+
values: ArrayRef,
425+
expected: &str,
426+
) -> Result<()> {
427+
let ctx = SessionContext::new();
428+
let schema = Arc::new(Schema::new(vec![Field::new("a", data_type, false)]));
429+
let batch = RecordBatch::try_new(schema.clone(), vec![values])?;
430+
let table = Arc::new(MemTable::try_new(schema, vec![vec![batch]])?);
431+
ctx.register_table("t", table)?;
432+
let sql = format!("SELECT {}(a) FROM t", func);
433+
let actual = execute(&ctx, &sql).await;
434+
let expected = vec![vec![expected.to_owned()]];
435+
assert_float_eq(&expected, &actual);
436+
Ok(())
437+
}
438+
256439
#[tokio::test]
257440
async fn csv_query_external_table_count() {
258441
let ctx = SessionContext::new();

datafusion/core/tests/sql/mod.rs

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -128,7 +128,11 @@ where
128128
l.as_ref().parse::<f64>().unwrap(),
129129
r.as_str().parse::<f64>().unwrap(),
130130
);
131-
assert!((l - r).abs() <= 2.0 * f64::EPSILON);
131+
if l.is_nan() || r.is_nan() {
132+
assert!(l.is_nan() && r.is_nan());
133+
} else if (l - r).abs() > 2.0 * f64::EPSILON {
134+
panic!("{} != {}", l, r)
135+
}
132136
});
133137
}
134138

datafusion/expr/src/accumulator.rs

Lines changed: 40 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -18,22 +18,22 @@
1818
//! Accumulator module contains the trait definition for aggregation function's accumulators.
1919
2020
use arrow::array::ArrayRef;
21-
use datafusion_common::{Result, ScalarValue};
21+
use datafusion_common::{DataFusionError, Result, ScalarValue};
2222
use std::fmt::Debug;
2323

2424
/// An accumulator represents a stateful object that lives throughout the evaluation of multiple rows and
2525
/// generically accumulates values.
2626
///
2727
/// An accumulator knows how to:
2828
/// * update its state from inputs via `update_batch`
29-
/// * convert its internal state to a vector of scalar values
29+
/// * convert its internal state to a vector of aggregate values
3030
/// * update its state from multiple accumulators' states via `merge_batch`
3131
/// * compute the final value from its internal state via `evaluate`
3232
pub trait Accumulator: Send + Sync + Debug {
3333
/// Returns the state of the accumulator at the end of the accumulation.
34-
// in the case of an average on which we track `sum` and `n`, this function should return a vector
35-
// of two values, sum and n.
36-
fn state(&self) -> Result<Vec<ScalarValue>>;
34+
/// in the case of an average on which we track `sum` and `n`, this function should return a vector
35+
/// of two values, sum and n.
36+
fn state(&self) -> Result<Vec<AggregateState>>;
3737

3838
/// updates the accumulator's state from a vector of arrays.
3939
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()>;
@@ -44,3 +44,38 @@ pub trait Accumulator: Send + Sync + Debug {
4444
/// returns its value based on its current state.
4545
fn evaluate(&self) -> Result<ScalarValue>;
4646
}
47+
48+
/// Representation of internal accumulator state. Accumulators can potentially have a mix of
49+
/// scalar and array values. It may be desirable to add custom aggregator states here as well
50+
/// in the future (perhaps `Custom(Box<dyn Any>)`?).
51+
#[derive(Debug)]
52+
pub enum AggregateState {
53+
/// Simple scalar value. Note that `ScalarValue::List` can be used to pass multiple
54+
/// values around
55+
Scalar(ScalarValue),
56+
/// Arrays can be used instead of `ScalarValue::List` and could potentially have better
57+
/// performance with large data sets, although this has not been verified. It also allows
58+
/// for use of arrow kernels with less overhead.
59+
Array(ArrayRef),
60+
}
61+
62+
impl AggregateState {
63+
/// Access the aggregate state as a scalar value. An error will occur if the
64+
/// state is not a scalar value.
65+
pub fn as_scalar(&self) -> Result<&ScalarValue> {
66+
match &self {
67+
Self::Scalar(v) => Ok(v),
68+
_ => Err(DataFusionError::Internal(
69+
"AggregateState is not a scalar aggregate".to_string(),
70+
)),
71+
}
72+
}
73+
74+
/// Access the aggregate state as an array value.
75+
pub fn to_array(&self) -> ArrayRef {
76+
match &self {
77+
Self::Scalar(v) => v.to_array(),
78+
Self::Array(array) => array.clone(),
79+
}
80+
}
81+
}

datafusion/expr/src/aggregate_function.rs

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -62,6 +62,8 @@ pub enum AggregateFunction {
6262
Max,
6363
/// avg
6464
Avg,
65+
/// median
66+
Median,
6567
/// Approximate aggregate function
6668
ApproxDistinct,
6769
/// array_agg
@@ -107,6 +109,7 @@ impl FromStr for AggregateFunction {
107109
"avg" => AggregateFunction::Avg,
108110
"mean" => AggregateFunction::Avg,
109111
"sum" => AggregateFunction::Sum,
112+
"median" => AggregateFunction::Median,
110113
"approx_distinct" => AggregateFunction::ApproxDistinct,
111114
"array_agg" => AggregateFunction::ArrayAgg,
112115
"var" => AggregateFunction::Variance,
@@ -175,7 +178,9 @@ pub fn return_type(
175178
AggregateFunction::ApproxPercentileContWithWeight => {
176179
Ok(coerced_data_types[0].clone())
177180
}
178-
AggregateFunction::ApproxMedian => Ok(coerced_data_types[0].clone()),
181+
AggregateFunction::ApproxMedian | AggregateFunction::Median => {
182+
Ok(coerced_data_types[0].clone())
183+
}
179184
AggregateFunction::Grouping => Ok(DataType::Int32),
180185
}
181186
}
@@ -330,6 +335,7 @@ pub fn coerce_types(
330335
}
331336
Ok(input_types.to_vec())
332337
}
338+
AggregateFunction::Median => Ok(input_types.to_vec()),
333339
AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]),
334340
}
335341
}
@@ -358,6 +364,7 @@ pub fn signature(fun: &AggregateFunction) -> Signature {
358364
| AggregateFunction::VariancePop
359365
| AggregateFunction::Stddev
360366
| AggregateFunction::StddevPop
367+
| AggregateFunction::Median
361368
| AggregateFunction::ApproxMedian => {
362369
Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable)
363370
}

datafusion/expr/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -53,7 +53,7 @@ pub mod utils;
5353
pub mod window_frame;
5454
pub mod window_function;
5555

56-
pub use accumulator::Accumulator;
56+
pub use accumulator::{Accumulator, AggregateState};
5757
pub use aggregate_function::AggregateFunction;
5858
pub use built_in_function::BuiltinScalarFunction;
5959
pub use columnar_value::{ColumnarValue, NullColumnarValue};

0 commit comments

Comments
 (0)