Skip to content

Commit 74022d5

Browse files
yyin-devfindepi
authored andcommitted
Convert variance sample to udaf (apache#10713)
* Without migrating tests * Should fail VAR(DISTINCT) but doesn't * Pass all other tests. * Return error for var(distinct) * Migrate tests * Fix tests * Lint * Fix tests * Fix use
1 parent 8a7b66d commit 74022d5

File tree

17 files changed

+338
-180
lines changed

17 files changed

+338
-180
lines changed

datafusion/expr/src/aggregate_function.rs

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,6 @@ pub enum AggregateFunction {
4949
ArrayAgg,
5050
/// N'th value in a group according to some ordering
5151
NthValue,
52-
/// Variance (Sample)
53-
Variance,
5452
/// Variance (Population)
5553
VariancePop,
5654
/// Standard Deviation (Sample)
@@ -111,7 +109,6 @@ impl AggregateFunction {
111109
ApproxDistinct => "APPROX_DISTINCT",
112110
ArrayAgg => "ARRAY_AGG",
113111
NthValue => "NTH_VALUE",
114-
Variance => "VAR",
115112
VariancePop => "VAR_POP",
116113
Stddev => "STDDEV",
117114
StddevPop => "STDDEV_POP",
@@ -169,9 +166,7 @@ impl FromStr for AggregateFunction {
169166
"stddev" => AggregateFunction::Stddev,
170167
"stddev_pop" => AggregateFunction::StddevPop,
171168
"stddev_samp" => AggregateFunction::Stddev,
172-
"var" => AggregateFunction::Variance,
173169
"var_pop" => AggregateFunction::VariancePop,
174-
"var_samp" => AggregateFunction::Variance,
175170
"regr_slope" => AggregateFunction::RegrSlope,
176171
"regr_intercept" => AggregateFunction::RegrIntercept,
177172
"regr_count" => AggregateFunction::RegrCount,
@@ -235,7 +230,6 @@ impl AggregateFunction {
235230
AggregateFunction::BoolAnd | AggregateFunction::BoolOr => {
236231
Ok(DataType::Boolean)
237232
}
238-
AggregateFunction::Variance => variance_return_type(&coerced_data_types[0]),
239233
AggregateFunction::VariancePop => {
240234
variance_return_type(&coerced_data_types[0])
241235
}
@@ -315,7 +309,6 @@ impl AggregateFunction {
315309
}
316310
AggregateFunction::Avg
317311
| AggregateFunction::Sum
318-
| AggregateFunction::Variance
319312
| AggregateFunction::VariancePop
320313
| AggregateFunction::Stddev
321314
| AggregateFunction::StddevPop

datafusion/expr/src/type_coercion/aggregates.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ pub fn coerce_types(
173173
}
174174
Ok(input_types.to_vec())
175175
}
176-
AggregateFunction::Variance | AggregateFunction::VariancePop => {
176+
AggregateFunction::VariancePop => {
177177
if !is_variance_support_arg_type(&input_types[0]) {
178178
return plan_err!(
179179
"The function {:?} does not support inputs of type {:?}.",

datafusion/functions-aggregate/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ pub mod covariance;
5959
pub mod first_last;
6060
pub mod median;
6161
pub mod sum;
62+
pub mod variance;
6263

6364
use datafusion_common::Result;
6465
use datafusion_execution::FunctionRegistry;
@@ -74,6 +75,7 @@ pub mod expr_fn {
7475
pub use super::first_last::last_value;
7576
pub use super::median::median;
7677
pub use super::sum::sum;
78+
pub use super::variance::var_sample;
7779
}
7880

7981
/// Returns all default aggregate functions
@@ -85,6 +87,7 @@ pub fn all_default_aggregate_functions() -> Vec<Arc<AggregateUDF>> {
8587
sum::sum_udaf(),
8688
covariance::covar_pop_udaf(),
8789
median::median_udaf(),
90+
variance::var_samp_udaf(),
8891
]
8992
}
9093

Lines changed: 263 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,263 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! [`VarianceSample`]: covariance sample aggregations.
19+
20+
use std::fmt::Debug;
21+
22+
use arrow::{
23+
array::{ArrayRef, Float64Array, UInt64Array},
24+
compute::kernels::cast,
25+
datatypes::{DataType, Field},
26+
};
27+
28+
use datafusion_common::{
29+
downcast_value, not_impl_err, plan_err, DataFusionError, Result, ScalarValue,
30+
};
31+
use datafusion_expr::{
32+
function::{AccumulatorArgs, StateFieldsArgs},
33+
utils::format_state_name,
34+
Accumulator, AggregateUDFImpl, Signature, Volatility,
35+
};
36+
use datafusion_physical_expr_common::aggregate::stats::StatsType;
37+
38+
make_udaf_expr_and_func!(
39+
VarianceSample,
40+
var_sample,
41+
expression,
42+
"Computes the sample variance.",
43+
var_samp_udaf
44+
);
45+
46+
pub struct VarianceSample {
47+
signature: Signature,
48+
aliases: Vec<String>,
49+
}
50+
51+
impl Debug for VarianceSample {
52+
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
53+
f.debug_struct("VarianceSample")
54+
.field("name", &self.name())
55+
.field("signature", &self.signature)
56+
.finish()
57+
}
58+
}
59+
60+
impl Default for VarianceSample {
61+
fn default() -> Self {
62+
Self::new()
63+
}
64+
}
65+
66+
impl VarianceSample {
67+
pub fn new() -> Self {
68+
Self {
69+
aliases: vec![String::from("var_sample"), String::from("var_samp")],
70+
signature: Signature::numeric(1, Volatility::Immutable),
71+
}
72+
}
73+
}
74+
75+
impl AggregateUDFImpl for VarianceSample {
76+
fn as_any(&self) -> &dyn std::any::Any {
77+
self
78+
}
79+
80+
fn name(&self) -> &str {
81+
"var"
82+
}
83+
84+
fn signature(&self) -> &Signature {
85+
&self.signature
86+
}
87+
88+
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
89+
if !arg_types[0].is_numeric() {
90+
return plan_err!("Variance requires numeric input types");
91+
}
92+
93+
Ok(DataType::Float64)
94+
}
95+
96+
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
97+
let name = args.name;
98+
Ok(vec![
99+
Field::new(format_state_name(name, "count"), DataType::UInt64, true),
100+
Field::new(format_state_name(name, "mean"), DataType::Float64, true),
101+
Field::new(format_state_name(name, "m2"), DataType::Float64, true),
102+
])
103+
}
104+
105+
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
106+
if acc_args.is_distinct {
107+
return not_impl_err!("VAR(DISTINCT) aggregations are not available");
108+
}
109+
110+
Ok(Box::new(VarianceAccumulator::try_new(StatsType::Sample)?))
111+
}
112+
113+
fn aliases(&self) -> &[String] {
114+
&self.aliases
115+
}
116+
}
117+
118+
/// An accumulator to compute variance
119+
/// The algrithm used is an online implementation and numerically stable. It is based on this paper:
120+
/// Welford, B. P. (1962). "Note on a method for calculating corrected sums of squares and products".
121+
/// Technometrics. 4 (3): 419–420. doi:10.2307/1266577. JSTOR 1266577.
122+
///
123+
/// The algorithm has been analyzed here:
124+
/// Ling, Robert F. (1974). "Comparison of Several Algorithms for Computing Sample Means and Variances".
125+
/// Journal of the American Statistical Association. 69 (348): 859–866. doi:10.2307/2286154. JSTOR 2286154.
126+
127+
#[derive(Debug)]
128+
pub struct VarianceAccumulator {
129+
m2: f64,
130+
mean: f64,
131+
count: u64,
132+
stats_type: StatsType,
133+
}
134+
135+
impl VarianceAccumulator {
136+
/// Creates a new `VarianceAccumulator`
137+
pub fn try_new(s_type: StatsType) -> Result<Self> {
138+
Ok(Self {
139+
m2: 0_f64,
140+
mean: 0_f64,
141+
count: 0_u64,
142+
stats_type: s_type,
143+
})
144+
}
145+
146+
pub fn get_count(&self) -> u64 {
147+
self.count
148+
}
149+
150+
pub fn get_mean(&self) -> f64 {
151+
self.mean
152+
}
153+
154+
pub fn get_m2(&self) -> f64 {
155+
self.m2
156+
}
157+
}
158+
159+
impl Accumulator for VarianceAccumulator {
160+
fn state(&mut self) -> Result<Vec<ScalarValue>> {
161+
Ok(vec![
162+
ScalarValue::from(self.count),
163+
ScalarValue::from(self.mean),
164+
ScalarValue::from(self.m2),
165+
])
166+
}
167+
168+
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
169+
let values = &cast(&values[0], &DataType::Float64)?;
170+
let arr = downcast_value!(values, Float64Array).iter().flatten();
171+
172+
for value in arr {
173+
let new_count = self.count + 1;
174+
let delta1 = value - self.mean;
175+
let new_mean = delta1 / new_count as f64 + self.mean;
176+
let delta2 = value - new_mean;
177+
let new_m2 = self.m2 + delta1 * delta2;
178+
179+
self.count += 1;
180+
self.mean = new_mean;
181+
self.m2 = new_m2;
182+
}
183+
184+
Ok(())
185+
}
186+
187+
fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
188+
let values = &cast(&values[0], &DataType::Float64)?;
189+
let arr = downcast_value!(values, Float64Array).iter().flatten();
190+
191+
for value in arr {
192+
let new_count = self.count - 1;
193+
let delta1 = self.mean - value;
194+
let new_mean = delta1 / new_count as f64 + self.mean;
195+
let delta2 = new_mean - value;
196+
let new_m2 = self.m2 - delta1 * delta2;
197+
198+
self.count -= 1;
199+
self.mean = new_mean;
200+
self.m2 = new_m2;
201+
}
202+
203+
Ok(())
204+
}
205+
206+
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
207+
let counts = downcast_value!(states[0], UInt64Array);
208+
let means = downcast_value!(states[1], Float64Array);
209+
let m2s = downcast_value!(states[2], Float64Array);
210+
211+
for i in 0..counts.len() {
212+
let c = counts.value(i);
213+
if c == 0_u64 {
214+
continue;
215+
}
216+
let new_count = self.count + c;
217+
let new_mean = self.mean * self.count as f64 / new_count as f64
218+
+ means.value(i) * c as f64 / new_count as f64;
219+
let delta = self.mean - means.value(i);
220+
let new_m2 = self.m2
221+
+ m2s.value(i)
222+
+ delta * delta * self.count as f64 * c as f64 / new_count as f64;
223+
224+
self.count = new_count;
225+
self.mean = new_mean;
226+
self.m2 = new_m2;
227+
}
228+
Ok(())
229+
}
230+
231+
fn evaluate(&mut self) -> Result<ScalarValue> {
232+
let count = match self.stats_type {
233+
StatsType::Population => self.count,
234+
StatsType::Sample => {
235+
if self.count > 0 {
236+
self.count - 1
237+
} else {
238+
self.count
239+
}
240+
}
241+
};
242+
243+
Ok(ScalarValue::Float64(match self.count {
244+
0 => None,
245+
1 => {
246+
if let StatsType::Population = self.stats_type {
247+
Some(0.0)
248+
} else {
249+
None
250+
}
251+
}
252+
_ => Some(self.m2 / count as f64),
253+
}))
254+
}
255+
256+
fn size(&self) -> usize {
257+
std::mem::size_of_val(self)
258+
}
259+
260+
fn supports_retract_batch(&self) -> bool {
261+
true
262+
}
263+
}

0 commit comments

Comments
 (0)