Skip to content

Commit c40690c

Browse files
committed
simplify returns closure
1 parent d13a57d commit c40690c

File tree

4 files changed

+58
-61
lines changed

4 files changed

+58
-61
lines changed

datafusion-examples/examples/simplify_udaf_expression.rs

+19-20
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
use arrow_schema::{Field, Schema};
1919
use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
20-
use datafusion_common::tree_node::Transformed;
20+
use datafusion_expr::function::AggregateFunctionSimplification;
2121
use datafusion_expr::simplify::SimplifyInfo;
2222

2323
use std::{any::Any, sync::Arc};
@@ -88,27 +88,26 @@ impl AggregateUDFImpl for BetterAvgUdaf {
8888
}
8989
// we override method, to return new expression which would substitute
9090
// user defined function call
91-
fn simplify(
92-
&self,
93-
aggregate_function: AggregateFunction,
94-
_info: &dyn SimplifyInfo,
95-
) -> Result<Transformed<Expr>> {
91+
fn simplify(&self) -> Option<AggregateFunctionSimplification> {
9692
// as an example for this functionality we replace UDF function
9793
// with build-in aggregate function to illustrate the use
98-
let expr = Expr::AggregateFunction(AggregateFunction {
99-
func_def: AggregateFunctionDefinition::BuiltIn(
100-
// yes it is the same Avg, `BetterAvgUdaf` was just a
101-
// marketing pitch :)
102-
datafusion_expr::aggregate_function::AggregateFunction::Avg,
103-
),
104-
args: aggregate_function.args,
105-
distinct: aggregate_function.distinct,
106-
filter: aggregate_function.filter,
107-
order_by: aggregate_function.order_by,
108-
null_treatment: aggregate_function.null_treatment,
109-
});
110-
111-
Ok(Transformed::yes(expr))
94+
let simplify = |aggregate_function: datafusion_expr::expr::AggregateFunction,
95+
_: &dyn SimplifyInfo| {
96+
Ok(Expr::AggregateFunction(AggregateFunction {
97+
func_def: AggregateFunctionDefinition::BuiltIn(
98+
// yes it is the same Avg, `BetterAvgUdaf` was just a
99+
// marketing pitch :)
100+
datafusion_expr::aggregate_function::AggregateFunction::Avg,
101+
),
102+
args: aggregate_function.args,
103+
distinct: aggregate_function.distinct,
104+
filter: aggregate_function.filter,
105+
order_by: aggregate_function.order_by,
106+
null_treatment: aggregate_function.null_treatment,
107+
}))
108+
};
109+
110+
Some(Box::new(simplify))
112111
}
113112
}
114113

datafusion/expr/src/function.rs

+13
Original file line numberDiff line numberDiff line change
@@ -97,3 +97,16 @@ pub type PartitionEvaluatorFactory =
9797
/// its state, given its return datatype.
9898
pub type StateTypeFunction =
9999
Arc<dyn Fn(&DataType) -> Result<Arc<Vec<DataType>>> + Send + Sync>;
100+
101+
/// [crate::udaf::AggregateUDFImpl::simplify] simplifier closure
102+
/// A closure with two arguments:
103+
/// * 'aggregate_function': [crate::expr::AggregateFunction] for which simplified has been invoked
104+
/// * 'info': [crate::simplify::SimplifyInfo]
105+
///
106+
/// closure returns simplified [Expr] or an error.
107+
pub type AggregateFunctionSimplification = Box<
108+
dyn Fn(
109+
crate::expr::AggregateFunction,
110+
&dyn crate::simplify::SimplifyInfo,
111+
) -> Result<Expr>,
112+
>;

datafusion/expr/src/udaf.rs

+15-26
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,12 @@
1717

1818
//! [`AggregateUDF`]: User Defined Aggregate Functions
1919
20-
use crate::expr::AggregateFunction;
21-
use crate::function::AccumulatorArgs;
20+
use crate::function::{AccumulatorArgs, AggregateFunctionSimplification};
2221
use crate::groups_accumulator::GroupsAccumulator;
23-
use crate::simplify::SimplifyInfo;
2422
use crate::utils::format_state_name;
2523
use crate::{Accumulator, Expr};
2624
use crate::{AccumulatorFactoryFunction, ReturnTypeFunction, Signature};
2725
use arrow::datatypes::{DataType, Field};
28-
use datafusion_common::tree_node::Transformed;
2926
use datafusion_common::{not_impl_err, Result};
3027
use std::any::Any;
3128
use std::fmt::{self, Debug, Formatter};
@@ -205,12 +202,8 @@ impl AggregateUDF {
205202
/// Do the function rewrite
206203
///
207204
/// See [`AggregateUDFImpl::simplify`] for more details.
208-
pub fn simplify(
209-
&self,
210-
aggregate_function: AggregateFunction,
211-
info: &dyn SimplifyInfo,
212-
) -> Result<Transformed<Expr>> {
213-
self.inner.simplify(aggregate_function, info)
205+
pub fn simplify(&self) -> Option<AggregateFunctionSimplification> {
206+
self.inner.simplify()
214207
}
215208
}
216209

@@ -372,7 +365,7 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
372365
&[]
373366
}
374367

375-
/// Optionally apply per-UDF simplification / rewrite rules.
368+
/// Optionally apply per-UDaF simplification / rewrite rules.
376369
///
377370
/// This can be used to apply function specific simplification rules during
378371
/// optimization (e.g. `arrow_cast` --> `Expr::Cast`). The default
@@ -383,22 +376,18 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
383376
/// `my_add(1,2) --> 3` ). Thus, there is no need to implement such
384377
/// optimizations manually for specific UDFs.
385378
///
386-
/// # Arguments
387-
/// * 'aggregate_function': Aggregate function to be simplified
388-
/// * 'info': Simplification information
389-
///
390379
/// # Returns
391-
/// [`Transformed`] indicating the result of the simplification NOTE
392-
/// if the function cannot be simplified, [Expr::AggregateFunction] with unmodified [AggregateFunction]
393-
/// should be returned
394-
fn simplify(
395-
&self,
396-
aggregate_function: AggregateFunction,
397-
_info: &dyn SimplifyInfo,
398-
) -> Result<Transformed<Expr>> {
399-
Ok(Transformed::yes(Expr::AggregateFunction(
400-
aggregate_function,
401-
)))
380+
///
381+
/// [None] if simplify is not defined or,
382+
///
383+
/// Or, a closure with two arguments:
384+
/// * 'aggregate_function': [crate::expr::AggregateFunction] for which simplified has been invoked
385+
/// * 'info': [crate::simplify::SimplifyInfo]
386+
///
387+
/// closure returns simplified [Expr] or an error.
388+
///
389+
fn simplify(&self) -> Option<AggregateFunctionSimplification> {
390+
None
402391
}
403392
}
404393

datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs

+11-15
Original file line numberDiff line numberDiff line change
@@ -1385,14 +1385,12 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
13851385
Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction {
13861386
func_def: AggregateFunctionDefinition::UDF(ref udaf),
13871387
..
1388-
}) => {
1389-
let udaf = udaf.clone();
1390-
if let Expr::AggregateFunction(aggregate_function) = expr {
1391-
udaf.simplify(aggregate_function, info)?
1392-
} else {
1393-
unreachable!("this branch should be unreachable")
1388+
}) => match (udaf.simplify(), expr) {
1389+
(Some(simplify_function), Expr::AggregateFunction(af)) => {
1390+
Transformed::yes(simplify_function(af, info)?)
13941391
}
1395-
}
1392+
(_, expr) => Transformed::no(expr),
1393+
},
13961394

13971395
//
13981396
// Rules for Between
@@ -1760,7 +1758,9 @@ fn inlist_except(mut l1: InList, l2: InList) -> Result<Expr> {
17601758
#[cfg(test)]
17611759
mod tests {
17621760
use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema};
1763-
use datafusion_expr::{interval_arithmetic::Interval, *};
1761+
use datafusion_expr::{
1762+
function::AggregateFunctionSimplification, interval_arithmetic::Interval, *,
1763+
};
17641764
use std::{
17651765
collections::HashMap,
17661766
ops::{BitAnd, BitOr, BitXor},
@@ -3791,15 +3791,11 @@ mod tests {
37913791
unimplemented!("not needed for testing")
37923792
}
37933793

3794-
fn simplify(
3795-
&self,
3796-
aggregate_function: datafusion_expr::expr::AggregateFunction,
3797-
_info: &dyn SimplifyInfo,
3798-
) -> Result<Transformed<Expr>> {
3794+
fn simplify(&self) -> Option<AggregateFunctionSimplification> {
37993795
if self.simplify {
3800-
Ok(Transformed::yes(col("result_column")))
3796+
Some(Box::new(|_, _| Ok(col("result_column"))))
38013797
} else {
3802-
Ok(Transformed::no(Expr::AggregateFunction(aggregate_function)))
3798+
None
38033799
}
38043800
}
38053801
}

0 commit comments

Comments
 (0)