From d88e4c6da76348c218eca9a8e62ac0e8c735e81b Mon Sep 17 00:00:00 2001 From: jatin Date: Wed, 25 Sep 2024 08:17:24 +0530 Subject: [PATCH 1/2] implement kurtosis udaf --- .../functions-aggregate/src/kurtosis.rs | 195 ++++++++++++++++++ datafusion/functions-aggregate/src/lib.rs | 2 + .../tests/cases/roundtrip_logical_plan.rs | 2 + .../sqllogictest/test_files/aggregate.slt | 42 ++++ .../user-guide/sql/aggregate_functions.md | 15 ++ 5 files changed, 256 insertions(+) create mode 100644 datafusion/functions-aggregate/src/kurtosis.rs diff --git a/datafusion/functions-aggregate/src/kurtosis.rs b/datafusion/functions-aggregate/src/kurtosis.rs new file mode 100644 index 000000000000..9a9f142b38ad --- /dev/null +++ b/datafusion/functions-aggregate/src/kurtosis.rs @@ -0,0 +1,195 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::any::Any; +use std::fmt::Debug; +use arrow::array::{ArrayRef, Float64Array, UInt64Array}; +use arrow_schema::{DataType, Field}; +use datafusion_common::cast::as_float64_array; +use datafusion_common::{downcast_value, DataFusionError,ScalarValue}; +use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility}; +use datafusion_functions_aggregate_common::accumulator::{AccumulatorArgs, StateFieldsArgs}; + +make_udaf_expr_and_func!( + KurtosisFunction, + kurtosis, + x, + "Calculates the excess kurtosis (Fisher’s definition) with bias correction according to the sample size.", + kurtosis_udaf +); + +pub struct KurtosisFunction { + signature: Signature, +} + +impl Debug for KurtosisFunction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("KurtosisFunction") + .field("signature", &self.signature) + .finish() + } +} + +impl Default for KurtosisFunction { + fn default() -> Self { + Self::new() + } +} + +impl KurtosisFunction { + pub fn new() -> Self { + Self { + signature: Signature::coercible( + vec![DataType::Float64], + Volatility::Immutable, + ), + } + } +} + +impl AggregateUDFImpl for KurtosisFunction { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "kurtosis" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> datafusion_common::Result { + Ok(DataType::Float64) + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> datafusion_common::Result> { + Ok(Box::new(KurtosisAccumulator::new())) + } + + // TODO + // check from here + fn state_fields(&self, _args: StateFieldsArgs) -> datafusion_common::Result> { + Ok(vec![ + Field::new("count", DataType::UInt64, true), + Field::new("sum", DataType::Float64, true), + Field::new("sum_sqr", DataType::Float64, true), + Field::new("sum_cub", DataType::Float64, true), + Field::new("sum_four", DataType::Float64, true), + ]) + } +} + +/// Accumulator for calculating the 
excess kurtosis (Fisher’s definition) with bias correction according to the sample size. +/// This implementation follows the [DuckDB implementation]: +/// +#[derive(Debug, Default)] +pub struct KurtosisAccumulator { + count: u64, + sum: f64, + sum_sqr: f64, + sum_cub: f64, + sum_four: f64, +} + +impl KurtosisAccumulator { + pub fn new() -> Self { + Self { + count: 0, + sum: 0.0, + sum_sqr: 0.0, + sum_cub: 0.0, + sum_four: 0.0, + } + } +} + +impl Accumulator for KurtosisAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion_common::Result<()> { + let array = as_float64_array(&values[0])?; + for value in array.iter().flatten() { + self.count += 1; + self.sum += value; + self.sum_sqr += value.powi(2); + self.sum_cub += value.powi(3); + self.sum_four += value.powi(4); + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion_common::Result<()> { + let counts = downcast_value!(states[0], UInt64Array); + let sums = downcast_value!(states[1], Float64Array); + let sum_sqrs = downcast_value!(states[2], Float64Array); + let sum_cubs = downcast_value!(states[3], Float64Array); + let sum_fours = downcast_value!(states[4], Float64Array); + + for i in 0..counts.len() { + let c = counts.value(i); + if c == 0 { + continue; + } + self.count += c; + self.sum += sums.value(i); + self.sum_sqr += sum_sqrs.value(i); + self.sum_cub += sum_cubs.value(i); + self.sum_four += sum_fours.value(i); + } + + Ok(()) + } + + fn evaluate(&mut self) -> datafusion_common::Result { + if self.count < 3 { + return Ok(ScalarValue::Float64(None)); + } + + let count_64 = 1_f64 / self.count as f64; + let m4 = count_64 + * (self.sum_four - 4.0 * self.sum_cub * self.sum * count_64 + + 6.0 * self.sum_sqr * self.sum.powi(2) * count_64.powi(2) + - 3.0 * self.sum.powi(4) * count_64.powi(3)); + + let m2 = (self.sum_sqr - self.sum.powi(2) * count_64) * count_64; + if m2 <= 0.0 { + return Ok(ScalarValue::Float64(None)); + } + + let count = self.count as f64; + let 
numerator = (count - 1.0) * ((count + 1.0) * m4 / m2.powi(2) - 3.0 * (count - 1.0)); + let denominator = (count - 2.0 ) * (count - 3.0); + + let target = numerator/denominator; + + Ok(ScalarValue::Float64(Some(target))) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } + + fn state(&mut self) -> datafusion_common::Result> { + Ok(vec![ + ScalarValue::from(self.count), + ScalarValue::from(self.sum), + ScalarValue::from(self.sum_sqr), + ScalarValue::from(self.sum_cub), + ScalarValue::from(self.sum_four), + ]) + } +} diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index 60e2602eb6ed..03c4a4372e67 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -79,6 +79,7 @@ pub mod bit_and_or_xor; pub mod bool_and_or; pub mod grouping; pub mod kurtosis_pop; +pub mod kurtosis; pub mod nth_value; pub mod string_agg; @@ -172,6 +173,7 @@ pub fn all_default_aggregate_functions() -> Vec> { grouping::grouping_udaf(), nth_value::nth_value_udaf(), kurtosis_pop::kurtosis_pop_udaf(), + kurtosis:: kurtosis_udaf(), ] } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 1f1426164d39..6bbc0d6d7fb1 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -74,6 +74,7 @@ use datafusion_functions_aggregate::expr_fn::{ nth_value, }; use datafusion_functions_aggregate::kurtosis_pop::kurtosis_pop; +use datafusion_functions_aggregate::kurtosis::kurtosis; use datafusion_functions_aggregate::string_agg::string_agg; use datafusion_functions_window_common::field::WindowUDFFieldArgs; use datafusion_proto::bytes::{ @@ -913,6 +914,7 @@ async fn roundtrip_expr_api() -> Result<()> { ), row_number(), kurtosis_pop(lit(1)), + kurtosis(lit(1)), nth_value(col("b"), 1, vec![]), nth_value( col("b"), diff --git 
a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 576abe5c6f5a..e82f685ef7e1 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -5923,3 +5923,45 @@ SELECT kurtosis_pop(c1) FROM t1; statement ok DROP TABLE t1; + +# Test for kurtosis + +query R +SELECT kurtosis(col) FROM VALUES (1), (10), (100), (10), (1) as tab(col); +---- +4.777292927668 + +query R +SELECT kurtosis(col) FROM VALUES (1), (2), (3), (2), (1) as tab(col); +---- +-0.615384615385 + +query R +SELECT kurtosis(col) FROM VALUES (1.0), (10.0), (100.0), (10.0), (1.0) as tab(col); +---- +4.777292927668 + +query R +SELECT kurtosis(col) FROM VALUES ('1'), ('10'), ('100'), ('10'), ('1') as tab(col); +---- +4.777292927668 + +query R +SELECT kurtosis(col) FROM VALUES (1.0) as tab(col); +---- +NULL + +query R +SELECT kurtosis(1); +---- +NULL + +query R +SELECT kurtosis(1.0); +---- +NULL + +query R +SELECT kurtosis(null); +---- +NULL \ No newline at end of file diff --git a/docs/source/user-guide/sql/aggregate_functions.md b/docs/source/user-guide/sql/aggregate_functions.md index 1c214084b3fa..bf20f7bf5a26 100644 --- a/docs/source/user-guide/sql/aggregate_functions.md +++ b/docs/source/user-guide/sql/aggregate_functions.md @@ -253,6 +253,7 @@ last_value(expression [ORDER BY expression]) - [regr_syy](#regr_syy) - [regr_sxy](#regr_sxy) - [kurtosis_pop](#kurtosis_pop) +- [kurtosis](#kurtosis) ### `corr` @@ -538,6 +539,20 @@ kurtois_pop(expression) #### Arguments +- **expression**: Expression to operate on. + Can be a constant, column, or function, and any combination of arithmetic operators. + + +### `kurtosis` + +Computes the excess kurtosis (Fisher's definition) with bias correction according to the sample size. + +``` +kurtosis(expression) +``` + +#### Arguments + - **expression**: Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. 
From d925061d68d055f890290fc234d680dd5ca2ddc4 Mon Sep 17 00:00:00 2001 From: jatin Date: Wed, 25 Sep 2024 08:33:10 +0530 Subject: [PATCH 2/2] updated the evaluate logic and added slt test --- datafusion/functions-aggregate/src/kurtosis.rs | 2 +- datafusion/sqllogictest/test_files/aggregate.slt | 12 +++++++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/datafusion/functions-aggregate/src/kurtosis.rs b/datafusion/functions-aggregate/src/kurtosis.rs index 9a9f142b38ad..f18f9c3921ad 100644 --- a/datafusion/functions-aggregate/src/kurtosis.rs +++ b/datafusion/functions-aggregate/src/kurtosis.rs @@ -155,7 +155,7 @@ impl Accumulator for KurtosisAccumulator { } fn evaluate(&mut self) -> datafusion_common::Result { - if self.count < 3 { + if self.count <= 3 { return Ok(ScalarValue::Float64(None)); } diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index e82f685ef7e1..7b095f369e95 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -5934,7 +5934,7 @@ SELECT kurtosis(col) FROM VALUES (1), (10), (100), (10), (1) as tab(col); query R SELECT kurtosis(col) FROM VALUES (1), (2), (3), (2), (1) as tab(col); ---- --0.615384615385 +-0.612244897959 query R SELECT kurtosis(col) FROM VALUES (1.0), (10.0), (100.0), (10.0), (1.0) as tab(col); @@ -5951,6 +5951,16 @@ SELECT kurtosis(col) FROM VALUES (1.0) as tab(col); ---- NULL +query R +SELECT kurtosis(col) FROM VALUES (1.0), (2.0) as tab(col); +---- +NULL + +query R +SELECT kurtosis(col) FROM VALUES (1.0), (2.0), (3.0) as tab(col); +---- +NULL + query R SELECT kurtosis(1); ----