Skip to content

Commit de51434

Browse files
committed
add simplify method for aggregate function
1 parent 96487ea commit de51434

File tree

3 files changed

+329
-1
lines changed

3 files changed

+329
-1
lines changed
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow_schema::{Field, Schema};
19+
use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
20+
use datafusion_common::tree_node::Transformed;
21+
use datafusion_expr::simplify::SimplifyInfo;
22+
23+
use std::{any::Any, sync::Arc};
24+
25+
use datafusion::arrow::{array::Float32Array, record_batch::RecordBatch};
26+
use datafusion::error::Result;
27+
use datafusion::{assert_batches_eq, prelude::*};
28+
use datafusion_common::cast::as_float64_array;
29+
use datafusion_expr::{
30+
expr::{AggregateFunction, AggregateFunctionDefinition},
31+
function::AccumulatorArgs,
32+
Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature,
33+
};
34+
35+
/// This example shows how to use the AggregateUDFImpl::simplify API to simplify/replace user
36+
/// defined aggregate function with a different expression which is defined in the `simplify` method.
37+
38+
#[derive(Debug, Clone)]
39+
struct BetterAvgUdaf {
40+
signature: Signature,
41+
}
42+
43+
impl BetterAvgUdaf {
44+
/// Create a new instance of the GeoMeanUdaf struct
45+
fn new() -> Self {
46+
Self {
47+
signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
48+
}
49+
}
50+
}
51+
52+
impl AggregateUDFImpl for BetterAvgUdaf {
53+
fn as_any(&self) -> &dyn Any {
54+
self
55+
}
56+
57+
fn name(&self) -> &str {
58+
"better_avg"
59+
}
60+
61+
fn signature(&self) -> &Signature {
62+
&self.signature
63+
}
64+
65+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
66+
Ok(DataType::Float64)
67+
}
68+
69+
fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
70+
unimplemented!("should not be invoked")
71+
}
72+
73+
fn state_fields(
74+
&self,
75+
_name: &str,
76+
_value_type: DataType,
77+
_ordering_fields: Vec<arrow_schema::Field>,
78+
) -> Result<Vec<arrow_schema::Field>> {
79+
unimplemented!("should not be invoked")
80+
}
81+
82+
fn groups_accumulator_supported(&self) -> bool {
83+
true
84+
}
85+
86+
fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
87+
unimplemented!("should not get here");
88+
}
89+
// we override method, to return new expression which would substitute
90+
// user defined function call
91+
fn simplify(
92+
&self,
93+
aggregate_function: AggregateFunction,
94+
_info: &dyn SimplifyInfo,
95+
) -> Result<Transformed<Expr>> {
96+
// as an example for this functionality we replace UDF function
97+
// with build-in aggregate function to illustrate the use
98+
let expr = Expr::AggregateFunction(AggregateFunction {
99+
func_def: AggregateFunctionDefinition::BuiltIn(
100+
// yes it is the same Avg, `BetterAvgUdaf` was just a
101+
// marketing pitch :)
102+
datafusion_expr::aggregate_function::AggregateFunction::Avg,
103+
),
104+
args: aggregate_function.args,
105+
distinct: aggregate_function.distinct,
106+
filter: aggregate_function.filter,
107+
order_by: aggregate_function.order_by,
108+
null_treatment: aggregate_function.null_treatment,
109+
});
110+
111+
Ok(Transformed::yes(expr))
112+
}
113+
}
114+
115+
// create local session context with an in-memory table
116+
fn create_context() -> Result<SessionContext> {
117+
use datafusion::datasource::MemTable;
118+
// define a schema.
119+
let schema = Arc::new(Schema::new(vec![
120+
Field::new("a", DataType::Float32, false),
121+
Field::new("b", DataType::Float32, false),
122+
]));
123+
124+
// define data in two partitions
125+
let batch1 = RecordBatch::try_new(
126+
schema.clone(),
127+
vec![
128+
Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0])),
129+
Arc::new(Float32Array::from(vec![2.0, 2.0, 2.0])),
130+
],
131+
)?;
132+
let batch2 = RecordBatch::try_new(
133+
schema.clone(),
134+
vec![
135+
Arc::new(Float32Array::from(vec![16.0])),
136+
Arc::new(Float32Array::from(vec![2.0])),
137+
],
138+
)?;
139+
140+
let ctx = SessionContext::new();
141+
142+
// declare a table in memory. In spark API, this corresponds to createDataFrame(...).
143+
let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?;
144+
ctx.register_table("t", Arc::new(provider))?;
145+
Ok(ctx)
146+
}
147+
148+
#[tokio::main]
149+
async fn main() -> Result<()> {
150+
let ctx = create_context()?;
151+
152+
let better_avg = AggregateUDF::from(BetterAvgUdaf::new());
153+
ctx.register_udaf(better_avg.clone());
154+
155+
let result = ctx
156+
.sql("SELECT better_avg(a) FROM t group by b")
157+
.await?
158+
.collect()
159+
.await?;
160+
161+
let expected = [
162+
"+-----------------+",
163+
"| better_avg(t.a) |",
164+
"+-----------------+",
165+
"| 7.5 |",
166+
"+-----------------+",
167+
];
168+
169+
assert_batches_eq!(expected, &result);
170+
171+
let df = ctx.table("t").await?;
172+
let df = df.aggregate(vec![], vec![better_avg.call(vec![col("a")])])?;
173+
174+
let results = df.collect().await?;
175+
let result = as_float64_array(results[0].column(0))?;
176+
177+
assert!((result.value(0) - 7.5).abs() < f64::EPSILON);
178+
println!("The average of [2,4,8,16] is {}", result.value(0));
179+
180+
Ok(())
181+
}

datafusion/expr/src/udaf.rs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,15 @@
1717

1818
//! [`AggregateUDF`]: User Defined Aggregate Functions
1919
20+
use crate::expr::AggregateFunction;
2021
use crate::function::AccumulatorArgs;
2122
use crate::groups_accumulator::GroupsAccumulator;
23+
use crate::simplify::SimplifyInfo;
2224
use crate::utils::format_state_name;
2325
use crate::{Accumulator, Expr};
2426
use crate::{AccumulatorFactoryFunction, ReturnTypeFunction, Signature};
2527
use arrow::datatypes::{DataType, Field};
28+
use datafusion_common::tree_node::Transformed;
2629
use datafusion_common::{not_impl_err, Result};
2730
use std::any::Any;
2831
use std::fmt::{self, Debug, Formatter};
@@ -195,6 +198,16 @@ impl AggregateUDF {
195198
pub fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
196199
self.inner.create_groups_accumulator()
197200
}
201+
/// Do the function rewrite
202+
///
203+
/// See [`AggregateUDFImpl::simplify`] for more details.
204+
pub fn simplify(
205+
&self,
206+
aggregate_function: AggregateFunction,
207+
info: &dyn SimplifyInfo,
208+
) -> Result<Transformed<Expr>> {
209+
self.inner.simplify(aggregate_function, info)
210+
}
198211
}
199212

200213
impl<F> From<F> for AggregateUDF
@@ -354,6 +367,35 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
354367
fn aliases(&self) -> &[String] {
355368
&[]
356369
}
370+
371+
/// Optionally apply per-UDF simplification / rewrite rules.
372+
///
373+
/// This can be used to apply function specific simplification rules during
374+
/// optimization (e.g. `arrow_cast` --> `Expr::Cast`). The default
375+
/// implementation does nothing.
376+
///
377+
/// Note that DataFusion handles simplifying arguments and "constant
378+
/// folding" (replacing a function call with constant arguments such as
379+
/// `my_add(1,2) --> 3` ). Thus, there is no need to implement such
380+
/// optimizations manually for specific UDFs.
381+
///
382+
/// # Arguments
383+
/// * 'aggregate_function': Aggregate function to be simplified
384+
/// * 'info': Simplification information
385+
///
386+
/// # Returns
387+
/// [`Transformed`] indicating the result of the simplification NOTE
388+
/// if the function cannot be simplified, [Expr::AggregateFunction] with unmodified [AggregateFunction]
389+
/// should be returned
390+
fn simplify(
391+
&self,
392+
aggregate_function: AggregateFunction,
393+
_info: &dyn SimplifyInfo,
394+
) -> Result<Transformed<Expr>> {
395+
Ok(Transformed::yes(Expr::AggregateFunction(
396+
aggregate_function,
397+
)))
398+
}
357399
}
358400

359401
/// AggregateUDF that adds an alias to the underlying function. It is better to

datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs

Lines changed: 106 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ use datafusion_common::{
3232
tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter},
3333
};
3434
use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarValue};
35-
use datafusion_expr::expr::{InList, InSubquery};
35+
use datafusion_expr::expr::{AggregateFunctionDefinition, InList, InSubquery};
3636
use datafusion_expr::simplify::ExprSimplifyResult;
3737
use datafusion_expr::{
3838
and, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, Operator, Volatility,
@@ -1382,6 +1382,18 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
13821382
}
13831383
}
13841384

1385+
Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction {
1386+
func_def: AggregateFunctionDefinition::UDF(ref udaf),
1387+
..
1388+
}) => {
1389+
let udaf = udaf.clone();
1390+
if let Expr::AggregateFunction(aggregate_function) = expr {
1391+
udaf.simplify(aggregate_function, info)?
1392+
} else {
1393+
unreachable!("this branch should be unreachable")
1394+
}
1395+
}
1396+
13851397
//
13861398
// Rules for Between
13871399
//
@@ -3698,4 +3710,97 @@ mod tests {
36983710
assert_eq!(expr, expected);
36993711
assert_eq!(num_iter, 2);
37003712
}
3713+
#[test]
fn test_simplify_udaf() {
    // Wraps a mock UDAF into an aggregate function call without arguments,
    // filter, ordering or null treatment.
    let call_of = |mock: SimplifyMockUdaf| {
        Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
            AggregateUDF::new_from_impl(mock).into(),
            vec![],
            false,
            None,
            None,
            None,
        ))
    };

    // A mock that simplifies is replaced by the expression it returns.
    let rewritten = simplify(call_of(SimplifyMockUdaf::new_with_simplify()));
    assert_eq!(rewritten, col("result_column"));

    // A mock that does not simplify leaves the call as it was.
    let untouched = call_of(SimplifyMockUdaf::new_without_simplify());
    let expected = untouched.clone();
    assert_eq!(simplify(untouched), expected);
}
3743+
3744+
/// Mock UDAF with a configurable `simplify`, used by the tests that cover
/// UDAF simplification.
#[derive(Debug, Clone)]
struct SimplifyMockUdaf {
    /// Whether `simplify` returns a new expression (`true`) or reports no
    /// change (`false`).
    simplify: bool,
}

impl SimplifyMockUdaf {
    /// Mock whose `simplify` method returns a new expression.
    fn new_with_simplify() -> Self {
        let simplify = true;
        Self { simplify }
    }

    /// Mock whose `simplify` method returns no change.
    fn new_without_simplify() -> Self {
        let simplify = false;
        Self { simplify }
    }
}
3761+
3762+
impl AggregateUDFImpl for SimplifyMockUdaf {
3763+
fn as_any(&self) -> &dyn std::any::Any {
3764+
self
3765+
}
3766+
3767+
fn name(&self) -> &str {
3768+
"mock_simplify"
3769+
}
3770+
3771+
fn signature(&self) -> &Signature {
3772+
unimplemented!()
3773+
}
3774+
3775+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
3776+
unimplemented!("not needed for tests")
3777+
}
3778+
3779+
fn accumulator(
3780+
&self,
3781+
_acc_args: function::AccumulatorArgs,
3782+
) -> Result<Box<dyn Accumulator>> {
3783+
unimplemented!("not needed for tests")
3784+
}
3785+
3786+
fn groups_accumulator_supported(&self) -> bool {
3787+
unimplemented!("not needed for testing")
3788+
}
3789+
3790+
fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
3791+
unimplemented!("not needed for testing")
3792+
}
3793+
3794+
fn simplify(
3795+
&self,
3796+
aggregate_function: datafusion_expr::expr::AggregateFunction,
3797+
_info: &dyn SimplifyInfo,
3798+
) -> Result<Transformed<Expr>> {
3799+
if self.simplify {
3800+
Ok(Transformed::yes(col("result_column")))
3801+
} else {
3802+
Ok(Transformed::no(Expr::AggregateFunction(aggregate_function)))
3803+
}
3804+
}
3805+
}
37013806
}

0 commit comments

Comments
 (0)