diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs
index ca0276d326a4..d828d50b3f1e 100644
--- a/datafusion/functions-aggregate/src/lib.rs
+++ b/datafusion/functions-aggregate/src/lib.rs
@@ -79,6 +79,7 @@ pub mod bit_and_or_xor;
 pub mod bool_and_or;
 pub mod grouping;
 pub mod nth_value;
+pub mod skewness;
 pub mod string_agg;
 
 use crate::approx_percentile_cont::approx_percentile_cont_udaf;
@@ -170,6 +171,7 @@ pub fn all_default_aggregate_functions() -> Vec<Arc<AggregateUDF>> {
         average::avg_udaf(),
         grouping::grouping_udaf(),
         nth_value::nth_value_udaf(),
+        skewness::skewness_udaf(),
     ]
 }
 
diff --git a/datafusion/functions-aggregate/src/skewness.rs b/datafusion/functions-aggregate/src/skewness.rs
new file mode 100644
index 000000000000..25bc4c9e7c39
--- /dev/null
+++ b/datafusion/functions-aggregate/src/skewness.rs
@@ -0,0 +1,229 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! [`SkewnessFunc`], the `skewness` aggregate function, and its accumulator.
+
+use arrow::array::{ArrayRef, AsArray};
+use arrow::datatypes::{Float64Type, UInt64Type};
+use arrow_schema::{DataType, Field};
+use datafusion_common::ScalarValue;
+use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility};
+use datafusion_functions_aggregate_common::accumulator::{
+    AccumulatorArgs, StateFieldsArgs,
+};
+use std::any::Any;
+
+make_udaf_expr_and_func!(
+    SkewnessFunc,
+    skewness,
+    x,
+    "Computes the skewness value.",
+    skewness_udaf
+);
+
+/// `skewness(expr)`: the bias-corrected sample skewness of a numeric
+/// expression, returned as `Float64`.
+#[derive(Debug)]
+pub struct SkewnessFunc {
+    name: String,
+    signature: Signature,
+}
+
+impl Default for SkewnessFunc {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl SkewnessFunc {
+    /// Creates the function with a user-defined signature; arguments are
+    /// checked and coerced in `coerce_types`.
+    pub fn new() -> Self {
+        Self {
+            name: "skewness".to_string(),
+            signature: Signature::user_defined(Volatility::Immutable),
+        }
+    }
+}
+
+impl AggregateUDFImpl for SkewnessFunc {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(
+        &self,
+        _arg_types: &[DataType],
+    ) -> datafusion_common::Result<DataType> {
+        Ok(DataType::Float64)
+    }
+
+    fn accumulator(
+        &self,
+        _acc_args: AccumulatorArgs,
+    ) -> datafusion_common::Result<Box<dyn Accumulator>> {
+        Ok(Box::new(SkewnessAccumulator::new()))
+    }
+
+    /// Partial state: the non-null value count and the running sums of
+    /// `x`, `x^2` and `x^3`, in the order emitted by `Accumulator::state`.
+    fn state_fields(
+        &self,
+        _args: StateFieldsArgs,
+    ) -> datafusion_common::Result<Vec<Field>> {
+        Ok(vec![
+            Field::new("count", DataType::UInt64, true),
+            Field::new("sum", DataType::Float64, true),
+            Field::new("sum_sqr", DataType::Float64, true),
+            Field::new("sum_cub", DataType::Float64, true),
+        ])
+    }
+
+    /// Requires exactly one argument and coerces it to `Float64`, the only
+    /// input type the accumulator reads.
+    fn coerce_types(
+        &self,
+        arg_types: &[DataType],
+    ) -> datafusion_common::Result<Vec<DataType>> {
+        if arg_types.len() != 1 {
+            return datafusion_common::plan_err!(
+                "skewness expects exactly one argument, got {}",
+                arg_types.len()
+            );
+        }
+        Ok(vec![DataType::Float64])
+    }
+}
+
+/// Accumulator for the `skewness` aggregate.
+///
+/// Tracks the number of non-null inputs and the running sums of `x`, `x^2`
+/// and `x^3`. The estimator in `evaluate` follows DuckDB's `skewness`
+/// aggregate.
+#[derive(Debug)]
+pub struct SkewnessAccumulator {
+    count: u64,
+    sum: f64,
+    sum_sqr: f64,
+    sum_cub: f64,
+}
+
+impl SkewnessAccumulator {
+    fn new() -> Self {
+        Self {
+            count: 0,
+            sum: 0f64,
+            sum_sqr: 0f64,
+            sum_cub: 0f64,
+        }
+    }
+}
+
+impl Accumulator for SkewnessAccumulator {
+    /// Folds every non-null value of `values[0]` into the running sums.
+    fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion_common::Result<()> {
+        let array = values[0].as_primitive::<Float64Type>();
+        for val in array.iter().flatten() {
+            self.count += 1;
+            self.sum += val;
+            self.sum_sqr += val.powi(2);
+            self.sum_cub += val.powi(3);
+        }
+        Ok(())
+    }
+
+    /// Returns `sqrt(n * (n - 1)) / (n - 2) * m3 / m2^(3/2)`, where `m2` and
+    /// `m3` are the second and third central moments derived from the sums.
+    ///
+    /// Yields NULL for fewer than three values or when `m2^(3/2)` is zero.
+    fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
+        if self.count <= 2 {
+            return Ok(ScalarValue::Float64(None));
+        }
+        let count = self.count as f64;
+        let t1 = 1f64 / count;
+        // `m2^3`, clamped at zero so that rounding error cannot put a
+        // negative value under the square root.
+        let p = (t1 * (self.sum_sqr - self.sum * self.sum * t1))
+            .powi(3)
+            .max(0f64);
+        let div = p.sqrt();
+        if div == 0f64 {
+            return Ok(ScalarValue::Float64(None));
+        }
+        // Small-sample correction factor `sqrt(n * (n - 1)) / (n - 2)`.
+        let t2 = (count * (count - 1f64)).sqrt() / (count - 2f64);
+        let res = t2
+            * t1
+            * (self.sum_cub - 3f64 * self.sum_sqr * self.sum * t1
+                + 2f64 * self.sum.powi(3) * t1 * t1)
+            / div;
+        Ok(ScalarValue::Float64(Some(res)))
+    }
+
+    fn size(&self) -> usize {
+        std::mem::size_of_val(self)
+    }
+
+    fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
+        Ok(vec![
+            ScalarValue::from(self.count),
+            ScalarValue::from(self.sum),
+            ScalarValue::from(self.sum_sqr),
+            ScalarValue::from(self.sum_cub),
+        ])
+    }
+
+    /// Adds the partial states (`count`, `sum`, `sum_sqr`, `sum_cub`) found
+    /// in `states`, skipping rows with a NULL component or a zero count.
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion_common::Result<()> {
+        let counts = states[0].as_primitive::<UInt64Type>();
+        let sums = states[1].as_primitive::<Float64Type>();
+        let sum_sqrs = states[2].as_primitive::<Float64Type>();
+        let sum_cubs = states[3].as_primitive::<Float64Type>();
+
+        // The state fields are declared nullable, so go through the
+        // validity-aware iterators rather than the raw value buffers.
+        let rows = counts
+            .iter()
+            .zip(sums.iter())
+            .zip(sum_sqrs.iter())
+            .zip(sum_cubs.iter());
+        for (((count, sum), sum_sqr), sum_cub) in rows {
+            if let (Some(count), Some(sum), Some(sum_sqr), Some(sum_cub)) =
+                (count, sum, sum_sqr, sum_cub)
+            {
+                if count == 0 {
+                    continue;
+                }
+                self.count += count;
+                self.sum += sum;
+                self.sum_sqr += sum_sqr;
+                self.sum_cub += sum_cub;
+            }
+        }
+        Ok(())
+    }
+}
diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
index 994ed8ad2352..db9d7b7aeb16 100644
--- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -73,6 +73,7 @@ use datafusion_functions_aggregate::expr_fn::{ approx_distinct, array_agg, avg, bit_and, bit_or, bit_xor, bool_and, bool_or, corr, nth_value, }; +use datafusion_functions_aggregate::skewness::skewness; use datafusion_functions_aggregate::string_agg::string_agg; use datafusion_proto::bytes::{ logical_plan_from_bytes, logical_plan_from_bytes_with_extension_codec, @@ -916,6 +917,7 @@ async fn roundtrip_expr_api() -> Result<()> { -1, vec![col("a").sort(false, false), col("b").sort(true, false)], ), + skewness(lit(1.1)), ]; // ensure expressions created with the expr api can be round tripped diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 45cb4d4615d7..211e62958693 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -5863,3 +5863,82 @@ ORDER BY k; ---- 1 1.8125 6.8007813 Float16 Float16 2 8.5 8.5 Float16 Float16 + +query R +SELECT skewness(col) FROM VALUES (-10), (-20), (100), (1000), (1000) AS tab(col); +---- +0.574511614753 + +query R +SELECT skewness(DISTINCT col) FROM VALUES (-10), (-20), (100), (1000), (1000) AS tab(col); +---- +1.928752451203 + +query R +SELECT skewness(1); +---- +NULL + +query R +select skewness(NULL); +---- +NULL + +query error +select skewness(*); + +# out of range +query R +SELECT skewness(DISTINCT col) FROM VALUES (-2e307), (0), (2e307) AS tab(col); +---- +NaN + +statement ok +create table aggr(k int, v decimal(10,2), v2 decimal(10, 2)); + +statement ok +insert into aggr values + (1, 10, null), + (2, 10, 11), + (2, 10, 15), + (2, 10, 18), + (2, 20, 22), + (2, 20, 25), + (2, 25, null), + (2, 30, 35), + (2, 30, 40), + (2, 30, 50), + (2, 30, 51); + +query RRR +select skewness(k), skewness(v), skewness(v2) from aggr; +---- +-3.316624790355 -0.163443669352 0.365400851103 + +query R +select skewness(v2) as sv2 from aggr group by v 
ORDER BY sv2; +---- +-0.423273160268 +-0.330140951366 +NULL +NULL + +# Window Function +query R +select skewness(v2) over (partition by v) + from aggr order by v; +---- +-0.423273160268 +-0.423273160268 +-0.423273160268 +-0.423273160268 +NULL +NULL +NULL +-0.330140951366 +-0.330140951366 +-0.330140951366 +-0.330140951366 + +statement ok +drop table aggr; diff --git a/docs/source/user-guide/sql/aggregate_functions.md b/docs/source/user-guide/sql/aggregate_functions.md index edb0e1d0c9f0..a977636f1361 100644 --- a/docs/source/user-guide/sql/aggregate_functions.md +++ b/docs/source/user-guide/sql/aggregate_functions.md @@ -252,6 +252,7 @@ last_value(expression [ORDER BY expression]) - [regr_sxx](#regr_sxx) - [regr_syy](#regr_syy) - [regr_sxy](#regr_sxy) +- [skewness](#skewness) ### `corr` @@ -527,6 +528,19 @@ regr_sxy(expression_y, expression_x) - **expression_x**: Independent variable. Can be a constant, column, or function, and any combination of arithmetic operators. +### `skewness` + +Computes the sample skewness of the input values, a measure of the asymmetry of their distribution. Returns NULL when fewer than three non-null values are aggregated or when the values have zero variance. + +``` +skewness(expression) +``` + +#### Arguments + +- **expression**: Expression to operate on. + Can be a constant, column, or function, and any combination of arithmetic operators. + ## Approximate - [approx_distinct](#approx_distinct)