From caac6fb36f351c5f81c23c4f677c4908309f640c Mon Sep 17 00:00:00 2001 From: Dharan Aditya Date: Fri, 27 Sep 2024 17:13:32 +0530 Subject: [PATCH] Port `kurtosis_pop` UDAF (#7) --- README.md | 1 + src/kurtosis_pop.rs | 188 ++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 9 ++- tests/main.rs | 60 ++++++++++++++ 4 files changed, 257 insertions(+), 1 deletion(-) create mode 100644 src/kurtosis_pop.rs diff --git a/README.md b/README.md index 8e63330..db01025 100644 --- a/README.md +++ b/README.md @@ -84,3 +84,4 @@ SELECT min_by(x, y) FROM VALUES (1, 10), (2, 5), (3, 15), (4, 8) as tab(x, y); - [x] `mode(expression) -> scalar` - Returns the most frequent (mode) value from a column of data. - [x] `max_by(expression1, expression2) -> scalar` - Returns the value of `expression1` associated with the maximum value of `expression2`. - [x] `min_by(expression1, expression2) -> scalar` - Returns the value of `expression1` associated with the minimum value of `expression2`. +- [x] `kurtois_pop(expression) -> scalar` - Computes the excess kurtosis (Fisher’s definition) without bias correction. diff --git a/src/kurtosis_pop.rs b/src/kurtosis_pop.rs new file mode 100644 index 0000000..6d8f32a --- /dev/null +++ b/src/kurtosis_pop.rs @@ -0,0 +1,188 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Copired from `datafusion/functions-aggregate/src/kurtosis_pop.rs` +// Originally authored by goldmedal + +use arrow::array::{Array, ArrayRef, Float64Array, UInt64Array}; +use datafusion::arrow::datatypes::{DataType, Field}; +use datafusion::common::cast::as_float64_array; +use datafusion::common::{downcast_value, DataFusionError, Result, ScalarValue}; +use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion::logical_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility}; +use std::any::Any; +use std::fmt::Debug; + +make_udaf_expr_and_func!( + KurtosisPopFunction, + kurtosis_pop, + x, + "Calculates the excess kurtosis (Fisher’s definition) without bias correction.", + kurtosis_pop_udaf +); + +pub struct KurtosisPopFunction { + signature: Signature, +} + +impl Debug for KurtosisPopFunction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("KurtosisPopFunction") + .field("signature", &self.signature) + .finish() + } +} + +impl Default for KurtosisPopFunction { + fn default() -> Self { + Self::new() + } +} + +impl KurtosisPopFunction { + pub fn new() -> Self { + Self { + signature: Signature::coercible(vec![DataType::Float64], Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for KurtosisPopFunction { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "kurtosis_pop" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + Ok(vec![ + Field::new("count", DataType::UInt64, true), + Field::new("sum", DataType::Float64, true), + Field::new("sum_sqr", DataType::Float64, true), + Field::new("sum_cub", DataType::Float64, true), + Field::new("sum_four", DataType::Float64, true), + ]) + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + Ok(Box::new(KurtosisPopAccumulator::new())) + } +} + +/// Accumulator for calculating the excess kurtosis (Fisher’s definition) without bias correction. +/// This implementation follows the [DuckDB implementation]: +/// +#[derive(Debug, Default)] +pub struct KurtosisPopAccumulator { + count: u64, + sum: f64, + sum_sqr: f64, + sum_cub: f64, + sum_four: f64, +} + +impl KurtosisPopAccumulator { + pub fn new() -> Self { + Self { + count: 0, + sum: 0.0, + sum_sqr: 0.0, + sum_cub: 0.0, + sum_four: 0.0, + } + } +} + +impl Accumulator for KurtosisPopAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let array = as_float64_array(&values[0])?; + for value in array.iter().flatten() { + self.count += 1; + self.sum += value; + self.sum_sqr += value.powi(2); + self.sum_cub += value.powi(3); + self.sum_four += value.powi(4); + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let counts = downcast_value!(states[0], UInt64Array); + let sums = downcast_value!(states[1], Float64Array); + let sum_sqrs = downcast_value!(states[2], Float64Array); + let sum_cubs = downcast_value!(states[3], Float64Array); + let sum_fours = downcast_value!(states[4], Float64Array); + + for i in 0..counts.len() { + let c = counts.value(i); + if c == 0 { + continue; + } + self.count += c; + self.sum += sums.value(i); + self.sum_sqr += sum_sqrs.value(i); + self.sum_cub += sum_cubs.value(i); + self.sum_four += sum_fours.value(i); + } + + Ok(()) + } + + fn evaluate(&mut self) -> Result { + if self.count < 1 { + return Ok(ScalarValue::Float64(None)); + } + + let count_64 = 1_f64 / self.count as f64; + let m4 = count_64 + * (self.sum_four - 4.0 * self.sum_cub * self.sum * count_64 + + 6.0 * self.sum_sqr * self.sum.powi(2) * count_64.powi(2) + - 3.0 * self.sum.powi(4) * count_64.powi(3)); + + let m2 = (self.sum_sqr - self.sum.powi(2) * count_64) * count_64; + if m2 <= 0.0 { + return Ok(ScalarValue::Float64(None)); + } + + let target = m4 / (m2.powi(2)) - 3.0; + Ok(ScalarValue::Float64(Some(target))) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } + + fn state(&mut self) -> Result> { + Ok(vec![ + ScalarValue::from(self.count), + ScalarValue::from(self.sum), + ScalarValue::from(self.sum_sqr), + ScalarValue::from(self.sum_cub), + ScalarValue::from(self.sum_four), + ]) + } +} diff --git a/src/lib.rs b/src/lib.rs index efbf090..d9dd246 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -26,16 +26,23 @@ use datafusion::logical_expr::AggregateUDF; #[macro_use] pub mod macros; pub mod common; +pub mod kurtosis_pop; pub mod max_min_by; pub mod mode; pub mod expr_extra_fn { + pub use super::kurtosis_pop::kurtosis_pop; pub use super::max_min_by::max_by; pub use super::max_min_by::min_by; pub use super::mode::mode; } pub fn all_extra_aggregate_functions() -> Vec> { - vec![mode_udaf(), max_min_by::max_by_udaf(), max_min_by::min_by_udaf()] + vec![ + mode_udaf(), + max_min_by::max_by_udaf(), + max_min_by::min_by_udaf(), + kurtosis_pop::kurtosis_pop_udaf(), + ] } /// Registers all enabled packages with a [`FunctionRegistry`] diff --git a/tests/main.rs b/tests/main.rs index a63b782..1a1bbc7 100644 --- a/tests/main.rs +++ b/tests/main.rs @@ -250,3 +250,63 @@ async fn test_max_by_and_min_by() { - +---------------------+ "###); } + +#[tokio::test] +async fn test_kurtosis_pop() { + let mut execution = TestExecution::new().await.unwrap().with_setup(TEST_TABLE).await; + + // Test with int64 + let actual = execution + .run_and_format("SELECT kurtosis_pop(int64_col) FROM test_table") + .await; + + insta::assert_yaml_snapshot!(actual, @r###" + - +------------------------------------+ + - "| kurtosis_pop(test_table.int64_col) |" + - +------------------------------------+ + - "| -0.9599999999999755 |" + - +------------------------------------+ + "###); + + // Test with float64 + let actual = execution + .run_and_format("SELECT kurtosis_pop(float64_col) FROM test_table") + .await; + + insta::assert_yaml_snapshot!(actual, @r###" + - +--------------------------------------+ + - "| kurtosis_pop(test_table.float64_col) |" + - +--------------------------------------+ + - "| -0.9599999999999755 |" + - +--------------------------------------+ +"###); + + let actual = execution + .run_and_format("SELECT kurtosis_pop(col) FROM VALUES (1.0) as tab(col)") + .await; + insta::assert_yaml_snapshot!(actual, @r###" + - +-----------------------+ + - "| kurtosis_pop(tab.col) |" + - +-----------------------+ + - "| |" + - +-----------------------+ +"###); + + let actual = execution.run_and_format("SELECT kurtosis_pop(1.0)").await; + insta::assert_yaml_snapshot!(actual, @r###" + - +--------------------------+ + - "| kurtosis_pop(Float64(1)) |" + - +--------------------------+ + - "| |" + - +--------------------------+ +"###); + + let actual = execution.run_and_format("SELECT kurtosis_pop(null)").await; + insta::assert_yaml_snapshot!(actual, @r###" +- +--------------------+ +- "| kurtosis_pop(NULL) |" +- +--------------------+ +- "| |" +- +--------------------+ +"###); +}