Skip to content

Commit e8fdc09

Browse files
authored
Convert VariancePopulation to UDAF (#10836)
1 parent 9503456 commit e8fdc09

File tree

15 files changed

+105
-265
lines changed

15 files changed

+105
-265
lines changed

datafusion/expr/src/aggregate_function.rs

+1-10
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,6 @@ pub enum AggregateFunction {
4747
ArrayAgg,
4848
/// N'th value in a group according to some ordering
4949
NthValue,
50-
/// Variance (Population)
51-
VariancePop,
5250
/// Correlation
5351
Correlation,
5452
/// Slope from linear regression
@@ -102,7 +100,6 @@ impl AggregateFunction {
102100
ApproxDistinct => "APPROX_DISTINCT",
103101
ArrayAgg => "ARRAY_AGG",
104102
NthValue => "NTH_VALUE",
105-
VariancePop => "VAR_POP",
106103
Correlation => "CORR",
107104
RegrSlope => "REGR_SLOPE",
108105
RegrIntercept => "REGR_INTERCEPT",
@@ -153,7 +150,6 @@ impl FromStr for AggregateFunction {
153150
"string_agg" => AggregateFunction::StringAgg,
154151
// statistical
155152
"corr" => AggregateFunction::Correlation,
156-
"var_pop" => AggregateFunction::VariancePop,
157153
"regr_slope" => AggregateFunction::RegrSlope,
158154
"regr_intercept" => AggregateFunction::RegrIntercept,
159155
"regr_count" => AggregateFunction::RegrCount,
@@ -216,9 +212,6 @@ impl AggregateFunction {
216212
AggregateFunction::BoolAnd | AggregateFunction::BoolOr => {
217213
Ok(DataType::Boolean)
218214
}
219-
AggregateFunction::VariancePop => {
220-
variance_return_type(&coerced_data_types[0])
221-
}
222215
AggregateFunction::Correlation => {
223216
correlation_return_type(&coerced_data_types[0])
224217
}
@@ -291,9 +284,7 @@ impl AggregateFunction {
291284
AggregateFunction::BoolAnd | AggregateFunction::BoolOr => {
292285
Signature::uniform(1, vec![DataType::Boolean], Volatility::Immutable)
293286
}
294-
AggregateFunction::Avg
295-
| AggregateFunction::VariancePop
296-
| AggregateFunction::ApproxMedian => {
287+
AggregateFunction::Avg | AggregateFunction::ApproxMedian => {
297288
Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable)
298289
}
299290
AggregateFunction::NthValue => Signature::any(2, Volatility::Immutable),

datafusion/expr/src/type_coercion/aggregates.rs

-10
Original file line numberDiff line numberDiff line change
@@ -151,16 +151,6 @@ pub fn coerce_types(
151151
}
152152
Ok(input_types.to_vec())
153153
}
154-
AggregateFunction::VariancePop => {
155-
if !is_variance_support_arg_type(&input_types[0]) {
156-
return plan_err!(
157-
"The function {:?} does not support inputs of type {:?}.",
158-
agg_fun,
159-
input_types[0]
160-
);
161-
}
162-
Ok(vec![Float64, Float64])
163-
}
164154
AggregateFunction::Correlation => {
165155
if !is_correlation_support_arg_type(&input_types[0]) {
166156
return plan_err!(

datafusion/functions-aggregate/src/lib.rs

+2
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ pub mod expr_fn {
7878
pub use super::stddev::stddev;
7979
pub use super::stddev::stddev_pop;
8080
pub use super::sum::sum;
81+
pub use super::variance::var_pop;
8182
pub use super::variance::var_sample;
8283
}
8384

@@ -91,6 +92,7 @@ pub fn all_default_aggregate_functions() -> Vec<Arc<AggregateUDF>> {
9192
covariance::covar_pop_udaf(),
9293
median::median_udaf(),
9394
variance::var_samp_udaf(),
95+
variance::var_pop_udaf(),
9496
stddev::stddev_udaf(),
9597
stddev::stddev_pop_udaf(),
9698
]

datafusion/functions-aggregate/src/variance.rs

+84-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
//! [`VarianceSample`]: covariance sample aggregations.
18+
//! [`VarianceSample`]: variance sample aggregations.
19+
//! [`VariancePopulation`]: variance population aggregations.
1920
2021
use std::fmt::Debug;
2122

@@ -43,6 +44,14 @@ make_udaf_expr_and_func!(
4344
var_samp_udaf
4445
);
4546

47+
make_udaf_expr_and_func!(
48+
VariancePopulation,
49+
var_pop,
50+
expression,
51+
"Computes the population variance.",
52+
var_pop_udaf
53+
);
54+
4655
pub struct VarianceSample {
4756
signature: Signature,
4857
aliases: Vec<String>,
@@ -115,6 +124,80 @@ impl AggregateUDFImpl for VarianceSample {
115124
}
116125
}
117126

127+
pub struct VariancePopulation {
128+
signature: Signature,
129+
aliases: Vec<String>,
130+
}
131+
132+
impl Debug for VariancePopulation {
133+
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
134+
f.debug_struct("VariancePopulation")
135+
.field("name", &self.name())
136+
.field("signature", &self.signature)
137+
.finish()
138+
}
139+
}
140+
141+
impl Default for VariancePopulation {
142+
fn default() -> Self {
143+
Self::new()
144+
}
145+
}
146+
147+
impl VariancePopulation {
148+
pub fn new() -> Self {
149+
Self {
150+
aliases: vec![String::from("var_population")],
151+
signature: Signature::numeric(1, Volatility::Immutable),
152+
}
153+
}
154+
}
155+
156+
impl AggregateUDFImpl for VariancePopulation {
157+
fn as_any(&self) -> &dyn std::any::Any {
158+
self
159+
}
160+
161+
fn name(&self) -> &str {
162+
"var_pop"
163+
}
164+
165+
fn signature(&self) -> &Signature {
166+
&self.signature
167+
}
168+
169+
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
170+
if !arg_types[0].is_numeric() {
171+
return plan_err!("Variance requires numeric input types");
172+
}
173+
174+
Ok(DataType::Float64)
175+
}
176+
177+
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
178+
let name = args.name;
179+
Ok(vec![
180+
Field::new(format_state_name(name, "count"), DataType::UInt64, true),
181+
Field::new(format_state_name(name, "mean"), DataType::Float64, true),
182+
Field::new(format_state_name(name, "m2"), DataType::Float64, true),
183+
])
184+
}
185+
186+
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
187+
if acc_args.is_distinct {
188+
return not_impl_err!("VAR_POP(DISTINCT) aggregations are not available");
189+
}
190+
191+
Ok(Box::new(VarianceAccumulator::try_new(
192+
StatsType::Population,
193+
)?))
194+
}
195+
196+
fn aliases(&self) -> &[String] {
197+
&self.aliases
198+
}
199+
}
200+
118201
/// An accumulator to compute variance
119202
/// The algrithm used is an online implementation and numerically stable. It is based on this paper:
120203
/// Welford, B. P. (1962). "Note on a method for calculating corrected sums of squares and products".

datafusion/physical-expr/src/aggregate/build_in.rs

-45
Original file line numberDiff line numberDiff line change
@@ -157,12 +157,6 @@ pub fn create_aggregate_expr(
157157
(AggregateFunction::Avg, true) => {
158158
return not_impl_err!("AVG(DISTINCT) aggregations are not available");
159159
}
160-
(AggregateFunction::VariancePop, false) => Arc::new(
161-
expressions::VariancePop::new(input_phy_exprs[0].clone(), name, data_type),
162-
),
163-
(AggregateFunction::VariancePop, true) => {
164-
return not_impl_err!("VAR_POP(DISTINCT) aggregations are not available");
165-
}
166160
(AggregateFunction::Correlation, false) => {
167161
Arc::new(expressions::Correlation::new(
168162
input_phy_exprs[0].clone(),
@@ -340,7 +334,6 @@ pub fn create_aggregate_expr(
340334
#[cfg(test)]
341335
mod tests {
342336
use arrow::datatypes::{DataType, Field};
343-
use expressions::VariancePop;
344337

345338
use super::*;
346339
use crate::expressions::{
@@ -693,44 +686,6 @@ mod tests {
693686
Ok(())
694687
}
695688

696-
#[test]
697-
fn test_var_pop_expr() -> Result<()> {
698-
let funcs = vec![AggregateFunction::VariancePop];
699-
let data_types = vec![
700-
DataType::UInt32,
701-
DataType::UInt64,
702-
DataType::Int32,
703-
DataType::Int64,
704-
DataType::Float32,
705-
DataType::Float64,
706-
];
707-
for fun in funcs {
708-
for data_type in &data_types {
709-
let input_schema =
710-
Schema::new(vec![Field::new("c1", data_type.clone(), true)]);
711-
let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![Arc::new(
712-
expressions::Column::new_with_schema("c1", &input_schema).unwrap(),
713-
)];
714-
let result_agg_phy_exprs = create_physical_agg_expr_for_test(
715-
&fun,
716-
false,
717-
&input_phy_exprs[0..1],
718-
&input_schema,
719-
"c1",
720-
)?;
721-
if fun == AggregateFunction::VariancePop {
722-
assert!(result_agg_phy_exprs.as_any().is::<VariancePop>());
723-
assert_eq!("c1", result_agg_phy_exprs.name());
724-
assert_eq!(
725-
Field::new("c1", DataType::Float64, true),
726-
result_agg_phy_exprs.field().unwrap()
727-
)
728-
}
729-
}
730-
}
731-
Ok(())
732-
}
733-
734689
#[test]
735690
fn test_median_expr() -> Result<()> {
736691
let funcs = vec![AggregateFunction::ApproxMedian];

0 commit comments

Comments
 (0)