Skip to content

Commit 62d381e

Browse files
committed
add simplify method for aggregate function
1 parent 9fd697c commit 62d381e

File tree

3 files changed

+366
-1
lines changed

3 files changed

+366
-1
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow_schema::{Field, Schema};
19+
use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
20+
use datafusion_expr::simplify::SimplifyInfo;
21+
22+
use std::{any::Any, sync::Arc};
23+
24+
use datafusion::arrow::{array::Float32Array, record_batch::RecordBatch};
25+
use datafusion::error::Result;
26+
use datafusion::{assert_batches_eq, prelude::*};
27+
use datafusion_common::cast::as_float64_array;
28+
use datafusion_expr::{
29+
expr::{AggregateFunction, AggregateFunctionDefinition},
30+
function::AccumulatorArgs,
31+
simplify::ExprSimplifyResult,
32+
Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature,
33+
};
34+
35+
/// This example shows how to use the `AggregateUDFImpl::simplify` API to replace a user
36+
/// defined aggregate function with a different expression, which is built in the `simplify` method.
37+
38+
#[derive(Debug, Clone)]
39+
struct BetterAvgUdaf {
40+
// Signature handed out by `AggregateUDFImpl::signature`; built once in `new`.
signature: Signature,
41+
}
42+
43+
impl BetterAvgUdaf {
44+
/// Create a new instance of the `BetterAvgUdaf` struct.
///
/// The signature accepts exactly one `Float64` argument and is declared
/// `Volatility::Immutable`.
45+
fn new() -> Self {
46+
Self {
47+
signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
48+
}
49+
}
50+
}
51+
52+
impl AggregateUDFImpl for BetterAvgUdaf {
53+
fn as_any(&self) -> &dyn Any {
54+
self
55+
}
56+
57+
// Name under which the function is called, e.g. `better_avg(a)` in SQL.
fn name(&self) -> &str {
58+
"better_avg"
59+
}
60+
61+
fn signature(&self) -> &Signature {
62+
&self.signature
63+
}
64+
65+
// Always reports `Float64`, regardless of the argument types passed in.
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
66+
Ok(DataType::Float64)
67+
}
68+
69+
// Stub: the example expects `simplify` below to replace every call to this
// function before an accumulator is requested, so this panics if reached.
fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
70+
unimplemented!("should not be invoked")
71+
}
72+
73+
// Stub for the same reason as `accumulator`: panics if reached.
fn state_fields(
74+
&self,
75+
_name: &str,
76+
_value_type: DataType,
77+
_ordering_fields: Vec<arrow_schema::Field>,
78+
) -> Result<Vec<arrow_schema::Field>> {
79+
unimplemented!("should not be invoked")
80+
}
81+
82+
// Reports groups-accumulator support even though `create_groups_accumulator`
// below is only a stub.
fn groups_accumulator_supported(&self) -> bool {
83+
true
84+
}
85+
86+
// Stub: panics if reached.
fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
87+
unimplemented!("should not get here");
88+
}
89+
// Override `simplify` to return a new expression that replaces the
90+
// call to this user defined function.
91+
fn simplify(
92+
&self,
93+
args: Vec<Expr>,
94+
distinct: &bool,
95+
filter: &Option<Box<Expr>>,
96+
order_by: &Option<Vec<Expr>>,
97+
null_treatment: &Option<datafusion_sql::sqlparser::ast::NullTreatment>,
98+
_info: &dyn SimplifyInfo,
99+
) -> Result<ExprSimplifyResult> {
100+
// As an example of this functionality, we replace the UDAF call with the
101+
// built-in `Avg` aggregate function, carrying over every call modifier.
102+
let expr = Expr::AggregateFunction(AggregateFunction {
103+
func_def: AggregateFunctionDefinition::BuiltIn(
104+
// yes it is the same Avg, `BetterAvgUdaf` was just a
105+
// marketing pitch :)
106+
datafusion_expr::aggregate_function::AggregateFunction::Avg,
107+
),
108+
args,
109+
distinct: *distinct,
110+
filter: filter.clone(),
111+
order_by: order_by.clone(),
112+
null_treatment: *null_treatment,
113+
});
114+
115+
Ok(ExprSimplifyResult::Simplified(expr))
116+
}
117+
}
118+
119+
// Create a local session context with an in-memory table `t` that has two
// non-nullable Float32 columns, `a` and `b`.
// NOTE(review): the columns are Float32 while `better_avg` declares a Float64
// argument; presumably the planner coerces the input — confirm.
120+
fn create_context() -> Result<SessionContext> {
121+
use datafusion::datasource::MemTable;
122+
// define a schema.
123+
let schema = Arc::new(Schema::new(vec![
124+
Field::new("a", DataType::Float32, false),
125+
Field::new("b", DataType::Float32, false),
126+
]));
127+
128+
// define the data as two batches; each batch becomes its own partition below
129+
let batch1 = RecordBatch::try_new(
130+
schema.clone(),
131+
vec![
132+
Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0])),
133+
Arc::new(Float32Array::from(vec![2.0, 2.0, 2.0])),
134+
],
135+
)?;
136+
let batch2 = RecordBatch::try_new(
137+
schema.clone(),
138+
vec![
139+
Arc::new(Float32Array::from(vec![16.0])),
140+
Arc::new(Float32Array::from(vec![2.0])),
141+
],
142+
)?;
143+
144+
let ctx = SessionContext::new();
145+
146+
// declare a table in memory. In spark API, this corresponds to createDataFrame(...).
147+
let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?;
148+
ctx.register_table("t", Arc::new(provider))?;
149+
Ok(ctx)
150+
}
151+
152+
// Entry point: registers `better_avg` and invokes it twice, once from SQL and
// once through the DataFrame API, checking that both return 7.5.
#[tokio::main]
153+
async fn main() -> Result<()> {
154+
let ctx = create_context()?;
155+
156+
// wrap the implementation in an `AggregateUDF` and register it with the context
let better_avg = AggregateUDF::from(BetterAvgUdaf::new());
157+
ctx.register_udaf(better_avg.clone());
158+
159+
// 1. invoke the function from SQL
let result = ctx
160+
.sql("SELECT better_avg(a) FROM t group by b")
161+
.await?
162+
.collect()
163+
.await?;
164+
165+
// every row of `t` has b = 2.0, so there is a single group:
// avg(2, 4, 8, 16) = 7.5
let expected = [
166+
"+-----------------+",
167+
"| better_avg(t.a) |",
168+
"+-----------------+",
169+
"| 7.5 |",
170+
"+-----------------+",
171+
];
172+
173+
assert_batches_eq!(expected, &result);
174+
175+
// 2. invoke the same function through the DataFrame API, with no grouping
// expressions so the aggregate runs over all rows
let df = ctx.table("t").await?;
176+
let df = df.aggregate(vec![], vec![better_avg.call(vec![col("a")])])?;
177+
178+
let results = df.collect().await?;
179+
// downcast the single output column to a Float64 array
let result = as_float64_array(results[0].column(0))?;
180+
181+
assert!((result.value(0) - 7.5).abs() < f64::EPSILON);
182+
println!("The average of [2,4,8,16] is {}", result.value(0));
183+
184+
Ok(())
185+
}

datafusion/expr/src/udaf.rs

+48
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,13 @@
1919
2020
use crate::function::AccumulatorArgs;
2121
use crate::groups_accumulator::GroupsAccumulator;
22+
use crate::simplify::{ExprSimplifyResult, SimplifyInfo};
2223
use crate::utils::format_state_name;
2324
use crate::{Accumulator, Expr};
2425
use crate::{AccumulatorFactoryFunction, ReturnTypeFunction, Signature};
2526
use arrow::datatypes::{DataType, Field};
2627
use datafusion_common::{not_impl_err, Result};
28+
use sqlparser::ast::NullTreatment;
2729
use std::any::Any;
2830
use std::fmt::{self, Debug, Formatter};
2931
use std::sync::Arc;
@@ -195,6 +197,21 @@ impl AggregateUDF {
195197
pub fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
196198
self.inner.create_groups_accumulator()
197199
}
200+
/// Apply this function's simplification / rewrite rule.
201+
///
/// All arguments are forwarded unchanged to the wrapped implementation.
///
202+
/// See [`AggregateUDFImpl::simplify`] for more details.
203+
pub fn simplify(
204+
&self,
205+
args: Vec<Expr>,
206+
distinct: &bool,
207+
filter: &Option<Box<Expr>>,
208+
order_by: &Option<Vec<Expr>>,
209+
null_treatment: &Option<NullTreatment>,
210+
info: &dyn SimplifyInfo,
211+
) -> Result<ExprSimplifyResult> {
212+
self.inner
213+
.simplify(args, distinct, filter, order_by, null_treatment, info)
214+
}
198215
}
199216

200217
impl<F> From<F> for AggregateUDF
@@ -354,6 +371,37 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
354371
fn aliases(&self) -> &[String] {
355372
&[]
356373
}
374+
375+
/// Optionally apply per-UDAF simplification / rewrite rules.
376+
///
377+
/// This can be used to apply function specific simplification rules during
378+
/// optimization (e.g. replacing a call to the UDAF with an equivalent
/// built-in aggregate expression). The default
379+
/// implementation does nothing and returns the arguments as [`ExprSimplifyResult::Original`].
380+
///
381+
/// Note that DataFusion handles simplifying arguments and "constant
382+
/// folding" (replacing a function call with constant arguments such as
383+
/// `my_add(1,2) --> 3` ). Thus, there is no need to implement such
384+
/// optimizations manually for specific UDFs.
385+
///
386+
/// # Arguments
387+
/// * `args`: The arguments of the function call
388+
/// * `distinct`: The `distinct` flag of the aggregate call
/// * `filter`: The optional filter expression of the aggregate call
/// * `order_by`: The optional ordering expressions of the aggregate call
/// * `null_treatment`: The optional null treatment of the aggregate call
/// * `info`: The [`SimplifyInfo`] made available to the rewrite rule
389+
///
390+
/// # Returns
391+
/// [`ExprSimplifyResult`] indicating the result of the simplification.
392+
/// NOTE: if the function cannot be simplified, the arguments *MUST* be returned
393+
/// unmodified, i.e. as `ExprSimplifyResult::Original(args)`.
394+
fn simplify(
395+
&self,
396+
args: Vec<Expr>,
397+
_distinct: &bool,
398+
_filter: &Option<Box<Expr>>,
399+
_order_by: &Option<Vec<Expr>>,
400+
_null_treatment: &Option<NullTreatment>,
401+
_info: &dyn SimplifyInfo,
402+
) -> Result<ExprSimplifyResult> {
403+
Ok(ExprSimplifyResult::Original(args))
404+
}
357405
}
358406

359407
/// AggregateUDF that adds an alias to the underlying function. It is better to

0 commit comments

Comments
 (0)