From 5ff5a6c924c0d3e2f2c2959d1a348be7913431c6 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Wed, 4 Sep 2024 15:19:38 +0800 Subject: [PATCH] Implement `kurtosis_pop` UDAF (#12273) * implement kurtosis_pop udaf * add tests * add empty end line * fix MSRV check * fix the null input and enhance tests * refactor the aggregation * address the review comments * add the doc for kurtois_pop * fix the doc style * use coercible signature * remove unused cast --- .../functions-aggregate/src/kurtosis_pop.rs | 190 ++++++++++++++++++ datafusion/functions-aggregate/src/lib.rs | 2 + .../tests/cases/roundtrip_logical_plan.rs | 2 + .../sqllogictest/test_files/aggregate.slt | 61 ++++++ .../user-guide/sql/aggregate_functions.md | 14 ++ 5 files changed, 269 insertions(+) create mode 100644 datafusion/functions-aggregate/src/kurtosis_pop.rs diff --git a/datafusion/functions-aggregate/src/kurtosis_pop.rs b/datafusion/functions-aggregate/src/kurtosis_pop.rs new file mode 100644 index 000000000000..ac173a0ee579 --- /dev/null +++ b/datafusion/functions-aggregate/src/kurtosis_pop.rs @@ -0,0 +1,190 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{Array, ArrayRef, Float64Array, UInt64Array}; +use arrow_schema::{DataType, Field}; +use datafusion_common::cast::as_float64_array; +use datafusion_common::{downcast_value, DataFusionError, Result, ScalarValue}; +use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility}; +use datafusion_functions_aggregate_common::accumulator::{ + AccumulatorArgs, StateFieldsArgs, +}; +use std::any::Any; +use std::fmt::Debug; + +make_udaf_expr_and_func!( + KurtosisPopFunction, + kurtosis_pop, + x, + "Calculates the excess kurtosis (Fisher’s definition) without bias correction.", + kurtosis_pop_udaf +); + +pub struct KurtosisPopFunction { + signature: Signature, +} + +impl Debug for KurtosisPopFunction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("KurtosisPopFunction") + .field("signature", &self.signature) + .finish() + } +} + +impl Default for KurtosisPopFunction { + fn default() -> Self { + Self::new() + } +} + +impl KurtosisPopFunction { + pub fn new() -> Self { + Self { + signature: Signature::coercible( + vec![DataType::Float64], + Volatility::Immutable, + ), + } + } +} + +impl AggregateUDFImpl for KurtosisPopFunction { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "kurtosis_pop" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + Ok(vec![ + Field::new("count", DataType::UInt64, true), + Field::new("sum", DataType::Float64, true), + Field::new("sum_sqr", DataType::Float64, true), + Field::new("sum_cub", DataType::Float64, true), + Field::new("sum_four", DataType::Float64, true), + ]) + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + Ok(Box::new(KurtosisPopAccumulator::new())) + } +} + +/// Accumulator for calculating the excess kurtosis (Fisher’s definition) without bias correction. +/// This implementation follows the [DuckDB implementation]: +/// +#[derive(Debug, Default)] +pub struct KurtosisPopAccumulator { + count: u64, + sum: f64, + sum_sqr: f64, + sum_cub: f64, + sum_four: f64, +} + +impl KurtosisPopAccumulator { + pub fn new() -> Self { + Self { + count: 0, + sum: 0.0, + sum_sqr: 0.0, + sum_cub: 0.0, + sum_four: 0.0, + } + } +} + +impl Accumulator for KurtosisPopAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let array = as_float64_array(&values[0])?; + for value in array.iter().flatten() { + self.count += 1; + self.sum += value; + self.sum_sqr += value.powi(2); + self.sum_cub += value.powi(3); + self.sum_four += value.powi(4); + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let counts = downcast_value!(states[0], UInt64Array); + let sums = downcast_value!(states[1], Float64Array); + let sum_sqrs = downcast_value!(states[2], Float64Array); + let sum_cubs = downcast_value!(states[3], Float64Array); + let sum_fours = downcast_value!(states[4], Float64Array); + + for i in 0..counts.len() { + let c = counts.value(i); + if c == 0 { + continue; + } + self.count += c; + self.sum += sums.value(i); + self.sum_sqr += sum_sqrs.value(i); + self.sum_cub += sum_cubs.value(i); + self.sum_four += sum_fours.value(i); + } + + Ok(()) + } + + fn evaluate(&mut self) -> Result { + if self.count < 1 { + return Ok(ScalarValue::Float64(None)); + } + + let count_64 = 1_f64 / self.count as f64; + let m4 = count_64 + * (self.sum_four - 4.0 * self.sum_cub * self.sum * count_64 + + 6.0 * self.sum_sqr * self.sum.powi(2) * count_64.powi(2) + - 3.0 * self.sum.powi(4) * count_64.powi(3)); + + let m2 = (self.sum_sqr - self.sum.powi(2) * count_64) * count_64; + if m2 <= 0.0 { + return Ok(ScalarValue::Float64(None)); + } + + let target = m4 / (m2.powi(2)) - 3.0; + Ok(ScalarValue::Float64(Some(target))) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } + + fn state(&mut self) -> Result> { + Ok(vec![ + ScalarValue::from(self.count), + ScalarValue::from(self.sum), + ScalarValue::from(self.sum_sqr), + ScalarValue::from(self.sum_cub), + ScalarValue::from(self.sum_four), + ]) + } +} diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index ca0276d326a4..60e2602eb6ed 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -78,6 +78,7 @@ pub mod average; pub mod bit_and_or_xor; pub mod bool_and_or; pub mod grouping; +pub mod kurtosis_pop; pub mod nth_value; pub mod string_agg; @@ -170,6 +171,7 @@ pub fn all_default_aggregate_functions() -> Vec> { average::avg_udaf(), grouping::grouping_udaf(), nth_value::nth_value_udaf(), + kurtosis_pop::kurtosis_pop_udaf(), ] } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 994ed8ad2352..dd3b99b0768b 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -73,6 +73,7 @@ use datafusion_functions_aggregate::expr_fn::{ approx_distinct, array_agg, avg, bit_and, bit_or, bit_xor, bool_and, bool_or, corr, nth_value, }; +use datafusion_functions_aggregate::kurtosis_pop::kurtosis_pop; use datafusion_functions_aggregate::string_agg::string_agg; use datafusion_proto::bytes::{ logical_plan_from_bytes, logical_plan_from_bytes_with_extension_codec, @@ -904,6 +905,7 @@ async fn roundtrip_expr_api() -> Result<()> { vec![lit(10), lit(20), lit(30)], ), row_number(), + kurtosis_pop(lit(1)), nth_value(col("b"), 1, vec![]), nth_value( col("b"), diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 45cb4d4615d7..c52445c561ee 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -5863,3 +5863,64 @@ ORDER BY k; ---- 1 1.8125 6.8007813 Float16 Float16 2 8.5 8.5 Float16 Float16 + +# The result is 0.19432323191699075 actually +query R +SELECT kurtosis_pop(col) FROM VALUES (1), (10), (100), (10), (1) as tab(col); +---- +0.194323231917 + +# The result is -1.153061224489787 actually +query R +SELECT kurtosis_pop(col) FROM VALUES (1), (2), (3), (2), (1) as tab(col); +---- +-1.15306122449 + +query R +SELECT kurtosis_pop(col) FROM VALUES (1.0), (10.0), (100.0), (10.0), (1.0) as tab(col); +---- +0.194323231917 + +query R +SELECT kurtosis_pop(col) FROM VALUES ('1'), ('10'), ('100'), ('10'), ('1') as tab(col); +---- +0.194323231917 + +query R +SELECT kurtosis_pop(col) FROM VALUES (1.0) as tab(col); +---- +NULL + +query R +SELECT kurtosis_pop(1) +---- +NULL + +query R +SELECT kurtosis_pop(1.0) +---- +NULL + +query R +SELECT kurtosis_pop(null) +---- +NULL + +statement ok +CREATE TABLE t1(c1 int); + +query R +SELECT kurtosis_pop(c1) FROM t1; +---- +NULL + +statement ok +INSERT INTO t1 VALUES (1), (10), (100), (10), (1); + +query R +SELECT kurtosis_pop(c1) FROM t1; +---- +0.194323231917 + +statement ok +DROP TABLE t1; diff --git a/docs/source/user-guide/sql/aggregate_functions.md b/docs/source/user-guide/sql/aggregate_functions.md index edb0e1d0c9f0..1c214084b3fa 100644 --- a/docs/source/user-guide/sql/aggregate_functions.md +++ b/docs/source/user-guide/sql/aggregate_functions.md @@ -252,6 +252,7 @@ last_value(expression [ORDER BY expression]) - [regr_sxx](#regr_sxx) - [regr_syy](#regr_syy) - [regr_sxy](#regr_sxy) +- [kurtosis_pop](#kurtosis_pop) ### `corr` @@ -527,6 +528,19 @@ regr_sxy(expression_y, expression_x) - **expression_x**: Independent variable. Can be a constant, column, or function, and any combination of arithmetic operators. +### `kurtosis_pop` + +Computes the excess kurtosis (Fisher’s definition) without bias correction. + +``` +kurtois_pop(expression) +``` + +#### Arguments + +- **expression**: Expression to operate on. + Can be a constant, column, or function, and any combination of arithmetic operators. + ## Approximate - [approx_distinct](#approx_distinct)