Skip to content

Commit a678e6d

Browse files
committed
simplify returns closure
1 parent de51434 commit a678e6d

File tree

4 files changed

+58
-61
lines changed

4 files changed

+58
-61
lines changed

datafusion-examples/examples/simplify_udaf_expression.rs

+19-20
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
use arrow_schema::{Field, Schema};
1919
use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
20-
use datafusion_common::tree_node::Transformed;
20+
use datafusion_expr::function::AggregateFunctionSimplification;
2121
use datafusion_expr::simplify::SimplifyInfo;
2222

2323
use std::{any::Any, sync::Arc};
@@ -88,27 +88,26 @@ impl AggregateUDFImpl for BetterAvgUdaf {
8888
}
8989
// we override this method to return a new expression that replaces the
9090
// user defined function call
91-
fn simplify(
92-
&self,
93-
aggregate_function: AggregateFunction,
94-
_info: &dyn SimplifyInfo,
95-
) -> Result<Transformed<Expr>> {
91+
fn simplify(&self) -> Option<AggregateFunctionSimplification> {
9692
// as an example for this functionality we replace UDF function
9793
// with built-in aggregate function to illustrate the use
98-
let expr = Expr::AggregateFunction(AggregateFunction {
99-
func_def: AggregateFunctionDefinition::BuiltIn(
100-
// yes it is the same Avg, `BetterAvgUdaf` was just a
101-
// marketing pitch :)
102-
datafusion_expr::aggregate_function::AggregateFunction::Avg,
103-
),
104-
args: aggregate_function.args,
105-
distinct: aggregate_function.distinct,
106-
filter: aggregate_function.filter,
107-
order_by: aggregate_function.order_by,
108-
null_treatment: aggregate_function.null_treatment,
109-
});
110-
111-
Ok(Transformed::yes(expr))
94+
let simplify = |aggregate_function: datafusion_expr::expr::AggregateFunction,
95+
_: &dyn SimplifyInfo| {
96+
Ok(Expr::AggregateFunction(AggregateFunction {
97+
func_def: AggregateFunctionDefinition::BuiltIn(
98+
// yes it is the same Avg, `BetterAvgUdaf` was just a
99+
// marketing pitch :)
100+
datafusion_expr::aggregate_function::AggregateFunction::Avg,
101+
),
102+
args: aggregate_function.args,
103+
distinct: aggregate_function.distinct,
104+
filter: aggregate_function.filter,
105+
order_by: aggregate_function.order_by,
106+
null_treatment: aggregate_function.null_treatment,
107+
}))
108+
};
109+
110+
Some(Box::new(simplify))
112111
}
113112
}
114113

datafusion/expr/src/function.rs

+13
Original file line numberDiff line numberDiff line change
@@ -97,3 +97,16 @@ pub type PartitionEvaluatorFactory =
9797
/// its state, given its return datatype.
9898
pub type StateTypeFunction =
9999
Arc<dyn Fn(&DataType) -> Result<Arc<Vec<DataType>>> + Send + Sync>;
100+
101+
/// [crate::udaf::AggregateUDFImpl::simplify] simplifier closure
102+
/// A closure with two arguments:
103+
/// * 'aggregate_function': [crate::expr::AggregateFunction] for which simplify has been invoked
104+
/// * 'info': [crate::simplify::SimplifyInfo]
105+
///
106+
/// The closure returns the simplified [Expr] or an error.
107+
pub type AggregateFunctionSimplification = Box<
108+
dyn Fn(
109+
crate::expr::AggregateFunction,
110+
&dyn crate::simplify::SimplifyInfo,
111+
) -> Result<Expr>,
112+
>;

datafusion/expr/src/udaf.rs

+15-26
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,12 @@
1717

1818
//! [`AggregateUDF`]: User Defined Aggregate Functions
1919
20-
use crate::expr::AggregateFunction;
21-
use crate::function::AccumulatorArgs;
20+
use crate::function::{AccumulatorArgs, AggregateFunctionSimplification};
2221
use crate::groups_accumulator::GroupsAccumulator;
23-
use crate::simplify::SimplifyInfo;
2422
use crate::utils::format_state_name;
2523
use crate::{Accumulator, Expr};
2624
use crate::{AccumulatorFactoryFunction, ReturnTypeFunction, Signature};
2725
use arrow::datatypes::{DataType, Field};
28-
use datafusion_common::tree_node::Transformed;
2926
use datafusion_common::{not_impl_err, Result};
3027
use std::any::Any;
3128
use std::fmt::{self, Debug, Formatter};
@@ -201,12 +198,8 @@ impl AggregateUDF {
201198
/// Do the function rewrite
202199
///
203200
/// See [`AggregateUDFImpl::simplify`] for more details.
204-
pub fn simplify(
205-
&self,
206-
aggregate_function: AggregateFunction,
207-
info: &dyn SimplifyInfo,
208-
) -> Result<Transformed<Expr>> {
209-
self.inner.simplify(aggregate_function, info)
201+
pub fn simplify(&self) -> Option<AggregateFunctionSimplification> {
202+
self.inner.simplify()
210203
}
211204
}
212205

@@ -368,7 +361,7 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
368361
&[]
369362
}
370363

371-
/// Optionally apply per-UDF simplification / rewrite rules.
364+
/// Optionally apply per-UDAF simplification / rewrite rules.
372365
///
373366
/// This can be used to apply function specific simplification rules during
374367
/// optimization (e.g. `arrow_cast` --> `Expr::Cast`). The default
@@ -379,22 +372,18 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
379372
/// `my_add(1,2) --> 3` ). Thus, there is no need to implement such
380373
/// optimizations manually for specific UDFs.
381374
///
382-
/// # Arguments
383-
/// * 'aggregate_function': Aggregate function to be simplified
384-
/// * 'info': Simplification information
385-
///
386375
/// # Returns
387-
/// [`Transformed`] indicating the result of the simplification NOTE
388-
/// if the function cannot be simplified, [Expr::AggregateFunction] with unmodified [AggregateFunction]
389-
/// should be returned
390-
fn simplify(
391-
&self,
392-
aggregate_function: AggregateFunction,
393-
_info: &dyn SimplifyInfo,
394-
) -> Result<Transformed<Expr>> {
395-
Ok(Transformed::yes(Expr::AggregateFunction(
396-
aggregate_function,
397-
)))
376+
///
377+
/// [None] if simplify is not defined.
378+
///
379+
/// Or, a closure with two arguments:
380+
/// * 'aggregate_function': [crate::expr::AggregateFunction] for which simplify has been invoked
381+
/// * 'info': [crate::simplify::SimplifyInfo]
382+
///
383+
/// closure returns simplified [Expr] or an error.
384+
///
385+
fn simplify(&self) -> Option<AggregateFunctionSimplification> {
386+
None
398387
}
399388
}
400389

datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs

+11-15
Original file line numberDiff line numberDiff line change
@@ -1385,14 +1385,12 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
13851385
Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction {
13861386
func_def: AggregateFunctionDefinition::UDF(ref udaf),
13871387
..
1388-
}) => {
1389-
let udaf = udaf.clone();
1390-
if let Expr::AggregateFunction(aggregate_function) = expr {
1391-
udaf.simplify(aggregate_function, info)?
1392-
} else {
1393-
unreachable!("this branch should be unreachable")
1388+
}) => match (udaf.simplify(), expr) {
1389+
(Some(simplify_function), Expr::AggregateFunction(af)) => {
1390+
Transformed::yes(simplify_function(af, info)?)
13941391
}
1395-
}
1392+
(_, expr) => Transformed::no(expr),
1393+
},
13961394

13971395
//
13981396
// Rules for Between
@@ -1760,7 +1758,9 @@ fn inlist_except(mut l1: InList, l2: InList) -> Result<Expr> {
17601758
#[cfg(test)]
17611759
mod tests {
17621760
use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema};
1763-
use datafusion_expr::{interval_arithmetic::Interval, *};
1761+
use datafusion_expr::{
1762+
function::AggregateFunctionSimplification, interval_arithmetic::Interval, *,
1763+
};
17641764
use std::{
17651765
collections::HashMap,
17661766
ops::{BitAnd, BitOr, BitXor},
@@ -3791,15 +3791,11 @@ mod tests {
37913791
unimplemented!("not needed for testing")
37923792
}
37933793

3794-
fn simplify(
3795-
&self,
3796-
aggregate_function: datafusion_expr::expr::AggregateFunction,
3797-
_info: &dyn SimplifyInfo,
3798-
) -> Result<Transformed<Expr>> {
3794+
fn simplify(&self) -> Option<AggregateFunctionSimplification> {
37993795
if self.simplify {
3800-
Ok(Transformed::yes(col("result_column")))
3796+
Some(Box::new(|_, _| Ok(col("result_column"))))
38013797
} else {
3802-
Ok(Transformed::no(Expr::AggregateFunction(aggregate_function)))
3798+
None
38033799
}
38043800
}
38053801
}

0 commit comments

Comments
 (0)