Skip to content

Commit a0fccbf

Browse files
authored
Move Covariance (Sample) covar / covar_samp to be a User Defined Aggregate Function (#10372)
* introduce CovarianceSample Signed-off-by: jayzhan211 <[email protected]> * rewrite macro Signed-off-by: jayzhan211 <[email protected]> * rm old statstype Signed-off-by: jayzhan211 <[email protected]> * register Signed-off-by: jayzhan211 <[email protected]> * state field Signed-off-by: jayzhan211 <[email protected]> * rm builtin Signed-off-by: jayzhan211 <[email protected]> * addres comments Signed-off-by: jayzhan211 <[email protected]> --------- Signed-off-by: jayzhan211 <[email protected]>
1 parent c1f1370 commit a0fccbf

File tree

21 files changed

+418
-391
lines changed

21 files changed

+418
-391
lines changed

datafusion/core/src/physical_planner.rs

+1
Original file line numberDiff line numberDiff line change
@@ -1901,6 +1901,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
19011901
let ignore_nulls = null_treatment
19021902
.unwrap_or(sqlparser::ast::NullTreatment::RespectNulls)
19031903
== NullTreatment::IgnoreNulls;
1904+
19041905
let (agg_expr, filter, order_by) = match func_def {
19051906
AggregateFunctionDefinition::BuiltIn(fun) => {
19061907
let physical_sort_exprs = match order_by {

datafusion/expr/src/aggregate_function.rs

+1-10
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,6 @@ pub enum AggregateFunction {
6363
Stddev,
6464
/// Standard Deviation (Population)
6565
StddevPop,
66-
/// Covariance (Sample)
67-
Covariance,
6866
/// Covariance (Population)
6967
CovariancePop,
7068
/// Correlation
@@ -128,7 +126,6 @@ impl AggregateFunction {
128126
VariancePop => "VAR_POP",
129127
Stddev => "STDDEV",
130128
StddevPop => "STDDEV_POP",
131-
Covariance => "COVAR",
132129
CovariancePop => "COVAR_POP",
133130
Correlation => "CORR",
134131
RegrSlope => "REGR_SLOPE",
@@ -184,9 +181,7 @@ impl FromStr for AggregateFunction {
184181
"string_agg" => AggregateFunction::StringAgg,
185182
// statistical
186183
"corr" => AggregateFunction::Correlation,
187-
"covar" => AggregateFunction::Covariance,
188184
"covar_pop" => AggregateFunction::CovariancePop,
189-
"covar_samp" => AggregateFunction::Covariance,
190185
"stddev" => AggregateFunction::Stddev,
191186
"stddev_pop" => AggregateFunction::StddevPop,
192187
"stddev_samp" => AggregateFunction::Stddev,
@@ -260,9 +255,6 @@ impl AggregateFunction {
260255
AggregateFunction::VariancePop => {
261256
variance_return_type(&coerced_data_types[0])
262257
}
263-
AggregateFunction::Covariance => {
264-
covariance_return_type(&coerced_data_types[0])
265-
}
266258
AggregateFunction::CovariancePop => {
267259
covariance_return_type(&coerced_data_types[0])
268260
}
@@ -357,8 +349,7 @@ impl AggregateFunction {
357349
Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable)
358350
}
359351
AggregateFunction::NthValue => Signature::any(2, Volatility::Immutable),
360-
AggregateFunction::Covariance
361-
| AggregateFunction::CovariancePop
352+
AggregateFunction::CovariancePop
362353
| AggregateFunction::Correlation
363354
| AggregateFunction::RegrSlope
364355
| AggregateFunction::RegrIntercept

datafusion/expr/src/type_coercion/aggregates.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ pub fn coerce_types(
183183
}
184184
Ok(vec![Float64, Float64])
185185
}
186-
AggregateFunction::Covariance | AggregateFunction::CovariancePop => {
186+
AggregateFunction::CovariancePop => {
187187
if !is_covariance_support_arg_type(&input_types[0]) {
188188
return plan_err!(
189189
"The function {:?} does not support inputs of type {:?}.",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,318 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! [`CovarianceSample`]: covariance sample aggregations.
19+
20+
use std::fmt::Debug;
21+
22+
use arrow::{
23+
array::{ArrayRef, Float64Array, UInt64Array},
24+
compute::kernels::cast,
25+
datatypes::{DataType, Field},
26+
};
27+
28+
use datafusion_common::{
29+
downcast_value, plan_err, unwrap_or_internal_err, DataFusionError, Result,
30+
ScalarValue,
31+
};
32+
use datafusion_expr::{
33+
function::AccumulatorArgs, type_coercion::aggregates::NUMERICS,
34+
utils::format_state_name, Accumulator, AggregateUDFImpl, Signature, Volatility,
35+
};
36+
use datafusion_physical_expr_common::aggregate::stats::StatsType;
37+
38+
make_udaf_expr_and_func!(
39+
CovarianceSample,
40+
covar_samp,
41+
y x,
42+
"Computes the sample covariance.",
43+
covar_samp_udaf
44+
);
45+
46+
pub struct CovarianceSample {
47+
signature: Signature,
48+
aliases: Vec<String>,
49+
}
50+
51+
impl Debug for CovarianceSample {
52+
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
53+
f.debug_struct("CovarianceSample")
54+
.field("name", &self.name())
55+
.field("signature", &self.signature)
56+
.finish()
57+
}
58+
}
59+
60+
impl Default for CovarianceSample {
61+
fn default() -> Self {
62+
Self::new()
63+
}
64+
}
65+
66+
impl CovarianceSample {
67+
pub fn new() -> Self {
68+
Self {
69+
aliases: vec![String::from("covar")],
70+
signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
71+
}
72+
}
73+
}
74+
75+
impl AggregateUDFImpl for CovarianceSample {
76+
fn as_any(&self) -> &dyn std::any::Any {
77+
self
78+
}
79+
80+
fn name(&self) -> &str {
81+
"covar_samp"
82+
}
83+
84+
fn signature(&self) -> &Signature {
85+
&self.signature
86+
}
87+
88+
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
89+
if !arg_types[0].is_numeric() {
90+
return plan_err!("Covariance requires numeric input types");
91+
}
92+
93+
Ok(DataType::Float64)
94+
}
95+
96+
fn state_fields(
97+
&self,
98+
name: &str,
99+
_value_type: DataType,
100+
_ordering_fields: Vec<Field>,
101+
) -> Result<Vec<Field>> {
102+
Ok(vec![
103+
Field::new(format_state_name(name, "count"), DataType::UInt64, true),
104+
Field::new(format_state_name(name, "mean1"), DataType::Float64, true),
105+
Field::new(format_state_name(name, "mean2"), DataType::Float64, true),
106+
Field::new(
107+
format_state_name(name, "algo_const"),
108+
DataType::Float64,
109+
true,
110+
),
111+
])
112+
}
113+
114+
fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
115+
Ok(Box::new(CovarianceAccumulator::try_new(StatsType::Sample)?))
116+
}
117+
118+
fn aliases(&self) -> &[String] {
119+
&self.aliases
120+
}
121+
}
122+
123+
/// An accumulator to compute covariance
124+
/// The algorithm used is an online implementation and numerically stable. It is derived from the following paper
125+
/// for calculating variance:
126+
/// Welford, B. P. (1962). "Note on a method for calculating corrected sums of squares and products".
127+
/// Technometrics. 4 (3): 419–420. doi:10.2307/1266577. JSTOR 1266577.
128+
///
129+
/// The algorithm has been analyzed here:
130+
/// Ling, Robert F. (1974). "Comparison of Several Algorithms for Computing Sample Means and Variances".
131+
/// Journal of the American Statistical Association. 69 (348): 859–866. doi:10.2307/2286154. JSTOR 2286154.
132+
///
133+
/// Though it is not covered in the original paper but is based on the same idea, as a result the algorithm is online,
134+
/// parallelizable and numerically stable.
135+
136+
#[derive(Debug)]
137+
pub struct CovarianceAccumulator {
138+
algo_const: f64,
139+
mean1: f64,
140+
mean2: f64,
141+
count: u64,
142+
stats_type: StatsType,
143+
}
144+
145+
impl CovarianceAccumulator {
146+
/// Creates a new `CovarianceAccumulator`
147+
pub fn try_new(s_type: StatsType) -> Result<Self> {
148+
Ok(Self {
149+
algo_const: 0_f64,
150+
mean1: 0_f64,
151+
mean2: 0_f64,
152+
count: 0_u64,
153+
stats_type: s_type,
154+
})
155+
}
156+
157+
pub fn get_count(&self) -> u64 {
158+
self.count
159+
}
160+
161+
pub fn get_mean1(&self) -> f64 {
162+
self.mean1
163+
}
164+
165+
pub fn get_mean2(&self) -> f64 {
166+
self.mean2
167+
}
168+
169+
pub fn get_algo_const(&self) -> f64 {
170+
self.algo_const
171+
}
172+
}
173+
174+
impl Accumulator for CovarianceAccumulator {
175+
fn state(&mut self) -> Result<Vec<ScalarValue>> {
176+
Ok(vec![
177+
ScalarValue::from(self.count),
178+
ScalarValue::from(self.mean1),
179+
ScalarValue::from(self.mean2),
180+
ScalarValue::from(self.algo_const),
181+
])
182+
}
183+
184+
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
185+
let values1 = &cast(&values[0], &DataType::Float64)?;
186+
let values2 = &cast(&values[1], &DataType::Float64)?;
187+
188+
let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten();
189+
let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten();
190+
191+
for i in 0..values1.len() {
192+
let value1 = if values1.is_valid(i) {
193+
arr1.next()
194+
} else {
195+
None
196+
};
197+
let value2 = if values2.is_valid(i) {
198+
arr2.next()
199+
} else {
200+
None
201+
};
202+
203+
if value1.is_none() || value2.is_none() {
204+
continue;
205+
}
206+
207+
let value1 = unwrap_or_internal_err!(value1);
208+
let value2 = unwrap_or_internal_err!(value2);
209+
let new_count = self.count + 1;
210+
let delta1 = value1 - self.mean1;
211+
let new_mean1 = delta1 / new_count as f64 + self.mean1;
212+
let delta2 = value2 - self.mean2;
213+
let new_mean2 = delta2 / new_count as f64 + self.mean2;
214+
let new_c = delta1 * (value2 - new_mean2) + self.algo_const;
215+
216+
self.count += 1;
217+
self.mean1 = new_mean1;
218+
self.mean2 = new_mean2;
219+
self.algo_const = new_c;
220+
}
221+
222+
Ok(())
223+
}
224+
225+
fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
226+
let values1 = &cast(&values[0], &DataType::Float64)?;
227+
let values2 = &cast(&values[1], &DataType::Float64)?;
228+
let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten();
229+
let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten();
230+
231+
for i in 0..values1.len() {
232+
let value1 = if values1.is_valid(i) {
233+
arr1.next()
234+
} else {
235+
None
236+
};
237+
let value2 = if values2.is_valid(i) {
238+
arr2.next()
239+
} else {
240+
None
241+
};
242+
243+
if value1.is_none() || value2.is_none() {
244+
continue;
245+
}
246+
247+
let value1 = unwrap_or_internal_err!(value1);
248+
let value2 = unwrap_or_internal_err!(value2);
249+
250+
let new_count = self.count - 1;
251+
let delta1 = self.mean1 - value1;
252+
let new_mean1 = delta1 / new_count as f64 + self.mean1;
253+
let delta2 = self.mean2 - value2;
254+
let new_mean2 = delta2 / new_count as f64 + self.mean2;
255+
let new_c = self.algo_const - delta1 * (new_mean2 - value2);
256+
257+
self.count -= 1;
258+
self.mean1 = new_mean1;
259+
self.mean2 = new_mean2;
260+
self.algo_const = new_c;
261+
}
262+
263+
Ok(())
264+
}
265+
266+
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
267+
let counts = downcast_value!(states[0], UInt64Array);
268+
let means1 = downcast_value!(states[1], Float64Array);
269+
let means2 = downcast_value!(states[2], Float64Array);
270+
let cs = downcast_value!(states[3], Float64Array);
271+
272+
for i in 0..counts.len() {
273+
let c = counts.value(i);
274+
if c == 0_u64 {
275+
continue;
276+
}
277+
let new_count = self.count + c;
278+
let new_mean1 = self.mean1 * self.count as f64 / new_count as f64
279+
+ means1.value(i) * c as f64 / new_count as f64;
280+
let new_mean2 = self.mean2 * self.count as f64 / new_count as f64
281+
+ means2.value(i) * c as f64 / new_count as f64;
282+
let delta1 = self.mean1 - means1.value(i);
283+
let delta2 = self.mean2 - means2.value(i);
284+
let new_c = self.algo_const
285+
+ cs.value(i)
286+
+ delta1 * delta2 * self.count as f64 * c as f64 / new_count as f64;
287+
288+
self.count = new_count;
289+
self.mean1 = new_mean1;
290+
self.mean2 = new_mean2;
291+
self.algo_const = new_c;
292+
}
293+
Ok(())
294+
}
295+
296+
fn evaluate(&mut self) -> Result<ScalarValue> {
297+
let count = match self.stats_type {
298+
StatsType::Population => self.count,
299+
StatsType::Sample => {
300+
if self.count > 0 {
301+
self.count - 1
302+
} else {
303+
self.count
304+
}
305+
}
306+
};
307+
308+
if count == 0 {
309+
Ok(ScalarValue::Float64(None))
310+
} else {
311+
Ok(ScalarValue::Float64(Some(self.algo_const / count as f64)))
312+
}
313+
}
314+
315+
fn size(&self) -> usize {
316+
std::mem::size_of_val(self)
317+
}
318+
}

datafusion/functions-aggregate/src/first_last.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,12 @@ use datafusion_physical_expr_common::expressions;
3939
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
4040
use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
4141
use datafusion_physical_expr_common::utils::reverse_order_bys;
42-
use sqlparser::ast::NullTreatment;
42+
4343
use std::any::Any;
4444
use std::fmt::Debug;
4545
use std::sync::Arc;
4646

47-
make_udaf_function!(
47+
make_udaf_expr_and_func!(
4848
FirstValue,
4949
first_value,
5050
"Returns the first value in a group of values.",

0 commit comments

Comments
 (0)