Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement kurtosis_pop UDAF #12273

Merged
merged 12 commits into from
Sep 4, 2024
192 changes: 192 additions & 0 deletions datafusion/functions-aggregate/src/kurtosis_pop.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::array::{Array, ArrayRef, Float64Array, UInt64Array};
use arrow::compute::cast;
use arrow_schema::{DataType, Field};
use datafusion_common::cast::as_float64_array;
use datafusion_common::{downcast_value, DataFusionError, Result, ScalarValue};
use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility};
use datafusion_functions_aggregate_common::accumulator::{
AccumulatorArgs, StateFieldsArgs,
};
use std::any::Any;
use std::fmt::Debug;

make_udaf_expr_and_func!(
KurtosisPopFunction,
kurtosis_pop,
x,
"Calculates the excess kurtosis (Fisher’s definition) without bias correction.",
kurtosis_pop_udaf
);

pub struct KurtosisPopFunction {
signature: Signature,
}

impl Debug for KurtosisPopFunction {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("KurtosisPopFunction")
.field("signature", &self.signature)
.finish()
}
}

impl Default for KurtosisPopFunction {
fn default() -> Self {
Self::new()
}
}

impl KurtosisPopFunction {
pub fn new() -> Self {
Self {
signature: Signature::coercible(
vec![DataType::Float64],
Volatility::Immutable,
),
}
}
}

impl AggregateUDFImpl for KurtosisPopFunction {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
"kurtosis_pop"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Float64)
}

fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
Ok(vec![
Field::new("count", DataType::UInt64, true),
Field::new("sum", DataType::Float64, true),
Field::new("sum_sqr", DataType::Float64, true),
Field::new("sum_cub", DataType::Float64, true),
Field::new("sum_four", DataType::Float64, true),
])
}

fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(KurtosisPopAccumulator::new()))
}
}

/// Accumulator for calculating the excess kurtosis (Fisher’s definition) without bias correction.
/// This implementation follows the [DuckDB implementation]:
/// <https://github.com/duckdb/duckdb/blob/main/src/core_functions/aggregate/distributive/kurtosis.cpp>
#[derive(Debug, Default)]
pub struct KurtosisPopAccumulator {
count: u64,
sum: f64,
sum_sqr: f64,
sum_cub: f64,
sum_four: f64,
}

impl KurtosisPopAccumulator {
pub fn new() -> Self {
Self {
count: 0,
sum: 0.0,
sum_sqr: 0.0,
sum_cub: 0.0,
sum_four: 0.0,
}
}
}

impl Accumulator for KurtosisPopAccumulator {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be great to add a link to the algorithm (something like wikipedia or duckdb's implementation)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for reminding this. I have added the doc for KurtosisPopAccumulator and updated the function doc.

fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
let values = &cast(&values[0], &DataType::Float64)?;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
let values = &cast(&values[0], &DataType::Float64)?;
let array = values[0].as_primitive::<Float64Type>();
for value in array.iter().flatten() {
self.count += 1;
self.sum += value;
self.sum_sqr += value.powi(2);
self.sum_cub += value.powi(3);
self.sum_four += value.powi(4);
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can also use as_float64_array or as_primitive_opt if you prefer Result than panic.

Copy link
Contributor Author

@goldmedal goldmedal Sep 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks good. I prefer to use as_float64_array. However, I think the &cast can't be removed. We should cast from another type array to the float64 array first, then downcast to Float64Array by as_float64_array.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we don't need the cast here? 🤔
The coercion is handled in Signature::Coercible

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Amazing! Thanks for the suggestion.

let array = as_float64_array(&values)?;
for value in array.iter().flatten() {
self.count += 1;
self.sum += value;
self.sum_sqr += value.powi(2);
self.sum_cub += value.powi(3);
self.sum_four += value.powi(4);
}
Ok(())
}

fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
let counts = downcast_value!(states[0], UInt64Array);
let sums = downcast_value!(states[1], Float64Array);
let sum_sqrs = downcast_value!(states[2], Float64Array);
let sum_cubs = downcast_value!(states[3], Float64Array);
let sum_fours = downcast_value!(states[4], Float64Array);

for i in 0..counts.len() {
let c = counts.value(i);
if c == 0 {
continue;
}
self.count += c;
self.sum += sums.value(i);
self.sum_sqr += sum_sqrs.value(i);
self.sum_cub += sum_cubs.value(i);
self.sum_four += sum_fours.value(i);
}

Ok(())
}

fn evaluate(&mut self) -> Result<ScalarValue> {
if self.count < 1 {
return Ok(ScalarValue::Float64(None));
}

let count_64 = 1_f64 / self.count as f64;
let m4 = count_64
* (self.sum_four - 4.0 * self.sum_cub * self.sum * count_64
+ 6.0 * self.sum_sqr * self.sum.powi(2) * count_64.powi(2)
- 3.0 * self.sum.powi(4) * count_64.powi(3));
Comment on lines +162 to +166
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I followed the DuckDB way to get the divisor here.

https://github.com/duckdb/duckdb/blob/a706958d15a6fc7fd47d65d22de7deac63613458/src/core_functions/aggregate/distributive/kurtosis.cpp#L69

The result will same as DuckDB but it's different from Clickhouse.

I did some test to compare the behavior between DuckDB and Clickhouse:

DuckDB

D  SELECT kurtosis_pop(col) FROM VALUES (1), (10), (100), (10), (1) as tab(col);
┌─────────────────────┐
│  kurtosis_pop(col)  │
│       double        │
├─────────────────────┤
│ 0.19432323191699075 │
└─────────────────────┘

Clickhouse

:) SELECT kurtPop(value) FROM (SELECT arrayJoin([1, 10, 100, 10, 1]) AS value);

SELECT kurtPop(value)
FROM
(
    SELECT arrayJoin([1, 10, 100, 10, 1]) AS value
)

Query id: abdea377-40b1-4437-a87a-4814f11cc866

   ┌─────kurtPop(value)─┐
1. │ 3.1943232319169903 │
   └────────────────────┘

1 row in set. Elapsed: 0.002 sec. 

Because DuckDB's kurtosis_pop calculates the population kurtosis using Fisher's definition, which results in the excess kurtosis, i.e., the value minus 3, ClickHouse directly provides the population kurtosis value without subtracting 3.

However, if we change the code like

Suggested change
let count_64 = 1_f64 / self.count as f64;
let m4 = count_64
* (self.sum_four - 4.0 * self.sum_cub * self.sum * count_64
+ 6.0 * self.sum_sqr * self.sum.powi(2) * count_64.powi(2)
- 3.0 * self.sum.powi(4) * count_64.powi(3));
let count_64 = self.count as f64;
let m4 =
(self.sum_four - 4.0 * self.sum_cub * self.sum / count_64
+ 6.0 * self.sum_sqr * self.sum.powi(2) / count_64.powi(2)
- 3.0 * self.sum.powi(4) / count_64.powi(3)) / count_64;

The result will same as Clikhouse, 3.1943232319169903 - 3 = 0.1943232319169903

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could follow DuckDB in this case


let m2 = (self.sum_sqr - self.sum.powi(2) * count_64) * count_64;
if m2 <= 0.0 {
return Ok(ScalarValue::Float64(None));
}

let target = m4 / (m2.powi(2)) - 3.0;
Ok(ScalarValue::Float64(Some(target)))
}

fn size(&self) -> usize {
std::mem::size_of_val(self)
}

fn state(&mut self) -> Result<Vec<ScalarValue>> {
Ok(vec![
ScalarValue::from(self.count),
ScalarValue::from(self.sum),
ScalarValue::from(self.sum_sqr),
ScalarValue::from(self.sum_cub),
ScalarValue::from(self.sum_four),
])
}
}
2 changes: 2 additions & 0 deletions datafusion/functions-aggregate/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ pub mod average;
pub mod bit_and_or_xor;
pub mod bool_and_or;
pub mod grouping;
pub mod kurtosis_pop;
pub mod nth_value;
pub mod string_agg;

Expand Down Expand Up @@ -170,6 +171,7 @@ pub fn all_default_aggregate_functions() -> Vec<Arc<AggregateUDF>> {
average::avg_udaf(),
grouping::grouping_udaf(),
nth_value::nth_value_udaf(),
kurtosis_pop::kurtosis_pop_udaf(),
]
}

Expand Down
2 changes: 2 additions & 0 deletions datafusion/proto/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ use datafusion_functions_aggregate::expr_fn::{
approx_distinct, array_agg, avg, bit_and, bit_or, bit_xor, bool_and, bool_or, corr,
nth_value,
};
use datafusion_functions_aggregate::kurtosis_pop::kurtosis_pop;
use datafusion_functions_aggregate::string_agg::string_agg;
use datafusion_proto::bytes::{
logical_plan_from_bytes, logical_plan_from_bytes_with_extension_codec,
Expand Down Expand Up @@ -904,6 +905,7 @@ async fn roundtrip_expr_api() -> Result<()> {
vec![lit(10), lit(20), lit(30)],
),
row_number(),
kurtosis_pop(lit(1)),
nth_value(col("b"), 1, vec![]),
nth_value(
col("b"),
Expand Down
61 changes: 61 additions & 0 deletions datafusion/sqllogictest/test_files/aggregate.slt
Original file line number Diff line number Diff line change
Expand Up @@ -5863,3 +5863,64 @@ ORDER BY k;
----
1 1.8125 6.8007813 Float16 Float16
2 8.5 8.5 Float16 Float16

# The result is 0.19432323191699075 actually
query R
SELECT kurtosis_pop(col) FROM VALUES (1), (10), (100), (10), (1) as tab(col);
----
0.194323231917
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried this function with the CLI

DataFusion CLI v41.0.0
> SELECT kurtosis_pop(col) FROM VALUES (1), (10), (100), (10), (1) as tab(col);
+-----------------------+
| kurtosis_pop(tab.col) |
+-----------------------+
| 0.19432323191699075   |
+-----------------------+

I'm not sure but I guess the sqllogicttest may do some rounds for the result.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes sqllogictest will do the rounding according to https://github.com/apache/datafusion/tree/main/datafusion/sqllogictest

floating point values are rounded to the scale of "12",


# The result is -1.153061224489787 actually
query R
SELECT kurtosis_pop(col) FROM VALUES (1), (2), (3), (2), (1) as tab(col);
----
-1.15306122449
Comment on lines +5873 to +5877
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This result is different from DuckDB but I'm not sure why.

D SELECT kurtosis_pop(col) FROM VALUES (1), (2), (3), (2), (1) as tab(col);
┌────────────────────┐
│ kurtosis_pop(col)  │
│       double       │
├────────────────────┤
│ -1.153061224489769 │
└────────────────────┘


query R
SELECT kurtosis_pop(col) FROM VALUES (1.0), (10.0), (100.0), (10.0), (1.0) as tab(col);
----
0.194323231917

query R
SELECT kurtosis_pop(col) FROM VALUES ('1'), ('10'), ('100'), ('10'), ('1') as tab(col);
----
0.194323231917

query R
SELECT kurtosis_pop(col) FROM VALUES (1.0) as tab(col);
----
NULL

query R
SELECT kurtosis_pop(1)
----
NULL

query R
SELECT kurtosis_pop(1.0)
----
NULL

query R
SELECT kurtosis_pop(null)
----
NULL

statement ok
CREATE TABLE t1(c1 int);

query R
SELECT kurtosis_pop(c1) FROM t1;
----
NULL

statement ok
INSERT INTO t1 VALUES (1), (10), (100), (10), (1);

query R
SELECT kurtosis_pop(c1) FROM t1;
----
0.194323231917

statement ok
DROP TABLE t1;
14 changes: 14 additions & 0 deletions docs/source/user-guide/sql/aggregate_functions.md
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ last_value(expression [ORDER BY expression])
- [regr_sxx](#regr_sxx)
- [regr_syy](#regr_syy)
- [regr_sxy](#regr_sxy)
- [kurtosis_pop](#kurtosis_pop)

### `corr`

Expand Down Expand Up @@ -527,6 +528,19 @@ regr_sxy(expression_y, expression_x)
- **expression_x**: Independent variable.
Can be a constant, column, or function, and any combination of arithmetic operators.

### `kurtosis_pop`

Computes the excess kurtosis (Fisher’s definition) without bias correction.

```
kurtois_pop(expression)
```

#### Arguments

- **expression**: Expression to operate on.
Can be a constant, column, or function, and any combination of arithmetic operators.

## Approximate

- [approx_distinct](#approx_distinct)
Expand Down