Skip to content

Commit 230c68c

Browse files
authored
Add simplify method to aggregate function (#10354)
* add simplify method for aggregate function * simplify returns closure
1 parent 58cc4e1 commit 230c68c

File tree

4 files changed

+328
-3
lines changed

4 files changed

+328
-3
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow_schema::{Field, Schema};
19+
use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
20+
use datafusion_expr::function::AggregateFunctionSimplification;
21+
use datafusion_expr::simplify::SimplifyInfo;
22+
23+
use std::{any::Any, sync::Arc};
24+
25+
use datafusion::arrow::{array::Float32Array, record_batch::RecordBatch};
26+
use datafusion::error::Result;
27+
use datafusion::{assert_batches_eq, prelude::*};
28+
use datafusion_common::cast::as_float64_array;
29+
use datafusion_expr::{
30+
expr::{AggregateFunction, AggregateFunctionDefinition},
31+
function::AccumulatorArgs,
32+
Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature,
33+
};
34+
35+
/// This example shows how to use the `AggregateUDFImpl::simplify` API to replace a user
36+
/// defined aggregate function with a different expression, produced by the closure that `simplify` returns.
37+
38+
#[derive(Debug, Clone)]
39+
struct BetterAvgUdaf {
40+
signature: Signature,
41+
}
42+
43+
impl BetterAvgUdaf {
44+
/// Create a new instance of the BetterAvgUdaf struct
45+
fn new() -> Self {
46+
Self {
47+
signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
48+
}
49+
}
50+
}
51+
52+
impl AggregateUDFImpl for BetterAvgUdaf {
53+
fn as_any(&self) -> &dyn Any {
54+
self
55+
}
56+
57+
fn name(&self) -> &str {
58+
"better_avg"
59+
}
60+
61+
fn signature(&self) -> &Signature {
62+
&self.signature
63+
}
64+
65+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
66+
Ok(DataType::Float64)
67+
}
68+
69+
fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
70+
unimplemented!("should not be invoked")
71+
}
72+
73+
fn state_fields(
74+
&self,
75+
_name: &str,
76+
_value_type: DataType,
77+
_ordering_fields: Vec<arrow_schema::Field>,
78+
) -> Result<Vec<arrow_schema::Field>> {
79+
unimplemented!("should not be invoked")
80+
}
81+
82+
fn groups_accumulator_supported(&self) -> bool {
83+
true
84+
}
85+
86+
fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
87+
unimplemented!("should not get here");
88+
}
89+
// Override `simplify` to return a closure that builds the expression
90+
// which replaces the user defined function call
91+
fn simplify(&self) -> Option<AggregateFunctionSimplification> {
92+
// As an example of this functionality, we replace the UDF call
93+
// with a built-in aggregate function
94+
let simplify = |aggregate_function: datafusion_expr::expr::AggregateFunction,
95+
_: &dyn SimplifyInfo| {
96+
Ok(Expr::AggregateFunction(AggregateFunction {
97+
func_def: AggregateFunctionDefinition::BuiltIn(
98+
// yes it is the same Avg, `BetterAvgUdaf` was just a
99+
// marketing pitch :)
100+
datafusion_expr::aggregate_function::AggregateFunction::Avg,
101+
),
102+
args: aggregate_function.args,
103+
distinct: aggregate_function.distinct,
104+
filter: aggregate_function.filter,
105+
order_by: aggregate_function.order_by,
106+
null_treatment: aggregate_function.null_treatment,
107+
}))
108+
};
109+
110+
Some(Box::new(simplify))
111+
}
112+
}
113+
114+
// create local session context with an in-memory table
115+
fn create_context() -> Result<SessionContext> {
116+
use datafusion::datasource::MemTable;
117+
// define a schema.
118+
let schema = Arc::new(Schema::new(vec![
119+
Field::new("a", DataType::Float32, false),
120+
Field::new("b", DataType::Float32, false),
121+
]));
122+
123+
// define data in two partitions
124+
let batch1 = RecordBatch::try_new(
125+
schema.clone(),
126+
vec![
127+
Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0])),
128+
Arc::new(Float32Array::from(vec![2.0, 2.0, 2.0])),
129+
],
130+
)?;
131+
let batch2 = RecordBatch::try_new(
132+
schema.clone(),
133+
vec![
134+
Arc::new(Float32Array::from(vec![16.0])),
135+
Arc::new(Float32Array::from(vec![2.0])),
136+
],
137+
)?;
138+
139+
let ctx = SessionContext::new();
140+
141+
// declare a table in memory. In the Spark API, this corresponds to createDataFrame(...).
142+
let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?;
143+
ctx.register_table("t", Arc::new(provider))?;
144+
Ok(ctx)
145+
}
146+
147+
#[tokio::main]
148+
async fn main() -> Result<()> {
149+
let ctx = create_context()?;
150+
151+
let better_avg = AggregateUDF::from(BetterAvgUdaf::new());
152+
ctx.register_udaf(better_avg.clone());
153+
154+
let result = ctx
155+
.sql("SELECT better_avg(a) FROM t group by b")
156+
.await?
157+
.collect()
158+
.await?;
159+
160+
let expected = [
161+
"+-----------------+",
162+
"| better_avg(t.a) |",
163+
"+-----------------+",
164+
"| 7.5 |",
165+
"+-----------------+",
166+
];
167+
168+
assert_batches_eq!(expected, &result);
169+
170+
let df = ctx.table("t").await?;
171+
let df = df.aggregate(vec![], vec![better_avg.call(vec![col("a")])])?;
172+
173+
let results = df.collect().await?;
174+
let result = as_float64_array(results[0].column(0))?;
175+
176+
assert!((result.value(0) - 7.5).abs() < f64::EPSILON);
177+
println!("The average of [2,4,8,16] is {}", result.value(0));
178+
179+
Ok(())
180+
}

datafusion/expr/src/function.rs

+13
Original file line numberDiff line numberDiff line change
@@ -97,3 +97,16 @@ pub type PartitionEvaluatorFactory =
9797
/// its state, given its return datatype.
9898
pub type StateTypeFunction =
9999
Arc<dyn Fn(&DataType) -> Result<Arc<Vec<DataType>>> + Send + Sync>;
100+
101+
/// [crate::udaf::AggregateUDFImpl::simplify] simplifier closure
102+
/// A closure with two arguments:
103+
/// * 'aggregate_function': [crate::expr::AggregateFunction] for which simplify has been invoked
104+
/// * 'info': [crate::simplify::SimplifyInfo]
105+
///
106+
/// closure returns simplified [Expr] or an error.
107+
pub type AggregateFunctionSimplification = Box<
108+
dyn Fn(
109+
crate::expr::AggregateFunction,
110+
&dyn crate::simplify::SimplifyInfo,
111+
) -> Result<Expr>,
112+
>;

datafusion/expr/src/udaf.rs

+32-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
//! [`AggregateUDF`]: User Defined Aggregate Functions
1919
20-
use crate::function::AccumulatorArgs;
20+
use crate::function::{AccumulatorArgs, AggregateFunctionSimplification};
2121
use crate::groups_accumulator::GroupsAccumulator;
2222
use crate::utils::format_state_name;
2323
use crate::{Accumulator, Expr};
@@ -199,6 +199,12 @@ impl AggregateUDF {
199199
pub fn coerce_types(&self, _args: &[DataType]) -> Result<Vec<DataType>> {
200200
not_impl_err!("coerce_types not implemented for {:?} yet", self.name())
201201
}
202+
/// Returns the simplification closure of the underlying function, if it defines one
203+
///
204+
/// See [`AggregateUDFImpl::simplify`] for more details.
205+
pub fn simplify(&self) -> Option<AggregateFunctionSimplification> {
206+
self.inner.simplify()
207+
}
202208
}
203209

204210
impl<F> From<F> for AggregateUDF
@@ -358,6 +364,31 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
358364
fn aliases(&self) -> &[String] {
359365
&[]
360366
}
367+
368+
/// Optionally apply per-UDAF simplification / rewrite rules.
369+
///
370+
/// This can be used to apply function specific simplification rules during
371+
/// optimization (e.g. `arrow_cast` --> `Expr::Cast`). The default
372+
/// implementation does nothing.
373+
///
374+
/// Note that DataFusion handles simplifying arguments and "constant
375+
/// folding" (replacing a function call with constant arguments such as
376+
/// `my_add(1,2) --> 3` ). Thus, there is no need to implement such
377+
/// optimizations manually for specific UDFs.
378+
///
379+
/// # Returns
380+
///
381+
/// [None] if simplify is not defined,
382+
///
383+
/// or a closure with two arguments:
384+
/// * 'aggregate_function': [crate::expr::AggregateFunction] for which simplify has been invoked
385+
/// * 'info': [crate::simplify::SimplifyInfo]
386+
///
387+
/// closure returns simplified [Expr] or an error.
388+
///
389+
fn simplify(&self) -> Option<AggregateFunctionSimplification> {
390+
None
391+
}
361392
}
362393

363394
/// AggregateUDF that adds an alias to the underlying function. It is better to

datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs

+103-2
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ use datafusion_common::{
3232
tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter},
3333
};
3434
use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarValue};
35-
use datafusion_expr::expr::{InList, InSubquery};
35+
use datafusion_expr::expr::{AggregateFunctionDefinition, InList, InSubquery};
3636
use datafusion_expr::simplify::ExprSimplifyResult;
3737
use datafusion_expr::{
3838
and, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, Operator, Volatility,
@@ -1382,6 +1382,16 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
13821382
}
13831383
}
13841384

1385+
Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction {
1386+
func_def: AggregateFunctionDefinition::UDF(ref udaf),
1387+
..
1388+
}) => match (udaf.simplify(), expr) {
1389+
(Some(simplify_function), Expr::AggregateFunction(af)) => {
1390+
Transformed::yes(simplify_function(af, info)?)
1391+
}
1392+
(_, expr) => Transformed::no(expr),
1393+
},
1394+
13851395
//
13861396
// Rules for Between
13871397
//
@@ -1748,7 +1758,9 @@ fn inlist_except(mut l1: InList, l2: InList) -> Result<Expr> {
17481758
#[cfg(test)]
17491759
mod tests {
17501760
use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema};
1751-
use datafusion_expr::{interval_arithmetic::Interval, *};
1761+
use datafusion_expr::{
1762+
function::AggregateFunctionSimplification, interval_arithmetic::Interval, *,
1763+
};
17521764
use std::{
17531765
collections::HashMap,
17541766
ops::{BitAnd, BitOr, BitXor},
@@ -3698,4 +3710,93 @@ mod tests {
36983710
assert_eq!(expr, expected);
36993711
assert_eq!(num_iter, 2);
37003712
}
3713+
#[test]
3714+
fn test_simplify_udaf() {
3715+
let udaf = AggregateUDF::new_from_impl(SimplifyMockUdaf::new_with_simplify());
3716+
let aggregate_function_expr =
3717+
Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
3718+
udaf.into(),
3719+
vec![],
3720+
false,
3721+
None,
3722+
None,
3723+
None,
3724+
));
3725+
3726+
let expected = col("result_column");
3727+
assert_eq!(simplify(aggregate_function_expr), expected);
3728+
3729+
let udaf = AggregateUDF::new_from_impl(SimplifyMockUdaf::new_without_simplify());
3730+
let aggregate_function_expr =
3731+
Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
3732+
udaf.into(),
3733+
vec![],
3734+
false,
3735+
None,
3736+
None,
3737+
None,
3738+
));
3739+
3740+
let expected = aggregate_function_expr.clone();
3741+
assert_eq!(simplify(aggregate_function_expr), expected);
3742+
}
3743+
3744+
/// A Mock UDAF which defines `simplify` to be used in tests
3745+
/// related to UDAF simplification
3746+
#[derive(Debug, Clone)]
3747+
struct SimplifyMockUdaf {
3748+
simplify: bool,
3749+
}
3750+
3751+
impl SimplifyMockUdaf {
3752+
/// make simplify method return new expression
3753+
fn new_with_simplify() -> Self {
3754+
Self { simplify: true }
3755+
}
3756+
/// make simplify method return no change
3757+
fn new_without_simplify() -> Self {
3758+
Self { simplify: false }
3759+
}
3760+
}
3761+
3762+
impl AggregateUDFImpl for SimplifyMockUdaf {
3763+
fn as_any(&self) -> &dyn std::any::Any {
3764+
self
3765+
}
3766+
3767+
fn name(&self) -> &str {
3768+
"mock_simplify"
3769+
}
3770+
3771+
fn signature(&self) -> &Signature {
3772+
unimplemented!()
3773+
}
3774+
3775+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
3776+
unimplemented!("not needed for tests")
3777+
}
3778+
3779+
fn accumulator(
3780+
&self,
3781+
_acc_args: function::AccumulatorArgs,
3782+
) -> Result<Box<dyn Accumulator>> {
3783+
unimplemented!("not needed for tests")
3784+
}
3785+
3786+
fn groups_accumulator_supported(&self) -> bool {
3787+
unimplemented!("not needed for testing")
3788+
}
3789+
3790+
fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
3791+
unimplemented!("not needed for testing")
3792+
}
3793+
3794+
fn simplify(&self) -> Option<AggregateFunctionSimplification> {
3795+
if self.simplify {
3796+
Some(Box::new(|_, _| Ok(col("result_column"))))
3797+
} else {
3798+
None
3799+
}
3800+
}
3801+
}
37013802
}

0 commit comments

Comments
 (0)