diff --git a/datafusion-examples/examples/simplify_udaf_expression.rs b/datafusion-examples/examples/simplify_udaf_expression.rs new file mode 100644 index 000000000000..92deb20272e4 --- /dev/null +++ b/datafusion-examples/examples/simplify_udaf_expression.rs @@ -0,0 +1,180 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_schema::{Field, Schema}; +use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility}; +use datafusion_expr::function::AggregateFunctionSimplification; +use datafusion_expr::simplify::SimplifyInfo; + +use std::{any::Any, sync::Arc}; + +use datafusion::arrow::{array::Float32Array, record_batch::RecordBatch}; +use datafusion::error::Result; +use datafusion::{assert_batches_eq, prelude::*}; +use datafusion_common::cast::as_float64_array; +use datafusion_expr::{ + expr::{AggregateFunction, AggregateFunctionDefinition}, + function::AccumulatorArgs, + Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature, +}; + +/// This example shows how to use the AggregateUDFImpl::simplify API to simplify/replace user +/// defined aggregate function with a different expression which is defined in the `simplify` method. 
+ +#[derive(Debug, Clone)] +struct BetterAvgUdaf { + signature: Signature, +} + +impl BetterAvgUdaf { + /// Create a new instance of the BetterAvgUdaf struct + fn new() -> Self { + Self { + signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for BetterAvgUdaf { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "better_avg" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + unimplemented!("should not be invoked") + } + + fn state_fields( + &self, + _name: &str, + _value_type: DataType, + _ordering_fields: Vec, + ) -> Result> { + unimplemented!("should not be invoked") + } + + fn groups_accumulator_supported(&self) -> bool { + true + } + + fn create_groups_accumulator(&self) -> Result> { + unimplemented!("should not get here"); + } + // we override this method to return a new expression which substitutes + // the user defined function call + fn simplify(&self) -> Option { + // as an example of this functionality we replace the UDF + // with a built-in aggregate function to illustrate the use + let simplify = |aggregate_function: datafusion_expr::expr::AggregateFunction, + _: &dyn SimplifyInfo| { + Ok(Expr::AggregateFunction(AggregateFunction { + func_def: AggregateFunctionDefinition::BuiltIn( + // yes it is the same Avg, `BetterAvgUdaf` was just a + // marketing pitch :) + datafusion_expr::aggregate_function::AggregateFunction::Avg, + ), + args: aggregate_function.args, + distinct: aggregate_function.distinct, + filter: aggregate_function.filter, + order_by: aggregate_function.order_by, + null_treatment: aggregate_function.null_treatment, + })) + }; + + Some(Box::new(simplify)) + } +} + +// create a local session context with an in-memory table +fn create_context() -> Result { + use datafusion::datasource::MemTable; 
+ // define a schema. + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Float32, false), + Field::new("b", DataType::Float32, false), + ])); + + // define data in two partitions + let batch1 = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0])), + Arc::new(Float32Array::from(vec![2.0, 2.0, 2.0])), + ], + )?; + let batch2 = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Float32Array::from(vec![16.0])), + Arc::new(Float32Array::from(vec![2.0])), + ], + )?; + + let ctx = SessionContext::new(); + + // declare a table in memory. In the Spark API, this corresponds to createDataFrame(...). + let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?; + ctx.register_table("t", Arc::new(provider))?; + Ok(ctx) +} + +#[tokio::main] +async fn main() -> Result<()> { + let ctx = create_context()?; + + let better_avg = AggregateUDF::from(BetterAvgUdaf::new()); + ctx.register_udaf(better_avg.clone()); + + let result = ctx + .sql("SELECT better_avg(a) FROM t group by b") + .await? + .collect() + .await?; + + let expected = [ + "+-----------------+", + "| better_avg(t.a) |", + "+-----------------+", + "| 7.5 |", + "+-----------------+", + ]; + + assert_batches_eq!(expected, &result); + + let df = ctx.table("t").await?; + let df = df.aggregate(vec![], vec![better_avg.call(vec![col("a")])])?; + + let results = df.collect().await?; + let result = as_float64_array(results[0].column(0))?; + + assert!((result.value(0) - 7.5).abs() < f64::EPSILON); + println!("The average of [2,4,8,16] is {}", result.value(0)); + + Ok(()) +} diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 7a92a50ae15d..4e4d77924a9d 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -97,3 +97,16 @@ pub type PartitionEvaluatorFactory = /// its state, given its return datatype. 
pub type StateTypeFunction = Arc Result>> + Send + Sync>; + +/// [crate::udaf::AggregateUDFImpl::simplify] simplifier closure +/// A closure with two arguments: +/// * 'aggregate_function': [crate::expr::AggregateFunction] for which simplify has been invoked +/// * 'info': [crate::simplify::SimplifyInfo] +/// +/// closure returns simplified [Expr] or an error. +pub type AggregateFunctionSimplification = Box< + dyn Fn( + crate::expr::AggregateFunction, + &dyn crate::simplify::SimplifyInfo, + ) -> Result, +>; diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index e5a47ddcd8b6..95121d78e7aa 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -17,7 +17,7 @@ //! [`AggregateUDF`]: User Defined Aggregate Functions -use crate::function::AccumulatorArgs; +use crate::function::{AccumulatorArgs, AggregateFunctionSimplification}; use crate::groups_accumulator::GroupsAccumulator; use crate::utils::format_state_name; use crate::{Accumulator, Expr}; @@ -199,6 +199,12 @@ impl AggregateUDF { pub fn coerce_types(&self, _args: &[DataType]) -> Result> { not_impl_err!("coerce_types not implemented for {:?} yet", self.name()) } + /// Do the function rewrite + /// + /// See [`AggregateUDFImpl::simplify`] for more details. + pub fn simplify(&self) -> Option { + self.inner.simplify() + } } impl From for AggregateUDF @@ -358,6 +364,31 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { fn aliases(&self) -> &[String] { &[] } + + /// Optionally apply per-UDAF simplification / rewrite rules. + /// + /// This can be used to apply function specific simplification rules during + /// optimization (e.g. `arrow_cast` --> `Expr::Cast`). The default + /// implementation does nothing. + /// + /// Note that DataFusion handles simplifying arguments and "constant + /// folding" (replacing a function call with constant arguments such as + /// `my_add(1,2) --> 3` ). Thus, there is no need to implement such + /// optimizations manually for specific UDFs. 
+ /// + /// # Returns + /// + /// [None] if simplify is not defined, + /// + /// or a closure with two arguments: + /// * 'aggregate_function': [crate::expr::AggregateFunction] for which simplify has been invoked + /// * 'info': [crate::simplify::SimplifyInfo] + /// + /// closure returns simplified [Expr] or an error. + /// + fn simplify(&self) -> Option { + None + } } /// AggregateUDF that adds an alias to the underlying function. It is better to diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 5122de4f09a7..55052542a8bf 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -32,7 +32,7 @@ use datafusion_common::{ tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, }; use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarValue}; -use datafusion_expr::expr::{InList, InSubquery}; +use datafusion_expr::expr::{AggregateFunctionDefinition, InList, InSubquery}; use datafusion_expr::simplify::ExprSimplifyResult; use datafusion_expr::{ and, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, Operator, Volatility, @@ -1382,6 +1382,16 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { } } + Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction { + func_def: AggregateFunctionDefinition::UDF(ref udaf), + .. + }) => match (udaf.simplify(), expr) { + (Some(simplify_function), Expr::AggregateFunction(af)) => { + Transformed::yes(simplify_function(af, info)?) 
+ } + (_, expr) => Transformed::no(expr), + }, + // // Rules for Between // @@ -1748,7 +1758,9 @@ fn inlist_except(mut l1: InList, l2: InList) -> Result { #[cfg(test)] mod tests { use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema}; - use datafusion_expr::{interval_arithmetic::Interval, *}; + use datafusion_expr::{ + function::AggregateFunctionSimplification, interval_arithmetic::Interval, *, + }; use std::{ collections::HashMap, ops::{BitAnd, BitOr, BitXor}, @@ -3698,4 +3710,93 @@ mod tests { assert_eq!(expr, expected); assert_eq!(num_iter, 2); } + #[test] + fn test_simplify_udaf() { + let udaf = AggregateUDF::new_from_impl(SimplifyMockUdaf::new_with_simplify()); + let aggregate_function_expr = + Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( + udaf.into(), + vec![], + false, + None, + None, + None, + )); + + let expected = col("result_column"); + assert_eq!(simplify(aggregate_function_expr), expected); + + let udaf = AggregateUDF::new_from_impl(SimplifyMockUdaf::new_without_simplify()); + let aggregate_function_expr = + Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( + udaf.into(), + vec![], + false, + None, + None, + None, + )); + + let expected = aggregate_function_expr.clone(); + assert_eq!(simplify(aggregate_function_expr), expected); + } + + /// A Mock UDAF which defines `simplify` to be used in tests + /// related to UDAF simplification + #[derive(Debug, Clone)] + struct SimplifyMockUdaf { + simplify: bool, + } + + impl SimplifyMockUdaf { + /// make the `simplify` method return a new expression + fn new_with_simplify() -> Self { + Self { simplify: true } + } + /// make the `simplify` method return `None` (no rewrite) + fn new_without_simplify() -> Self { + Self { simplify: false } + } + } + + impl AggregateUDFImpl for SimplifyMockUdaf { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "mock_simplify" + } + + fn signature(&self) -> &Signature { + unimplemented!() + } + + fn 
return_type(&self, _arg_types: &[DataType]) -> Result { + unimplemented!("not needed for tests") + } + + fn accumulator( + &self, + _acc_args: function::AccumulatorArgs, + ) -> Result> { + unimplemented!("not needed for tests") + } + + fn groups_accumulator_supported(&self) -> bool { + unimplemented!("not needed for testing") + } + + fn create_groups_accumulator(&self) -> Result> { + unimplemented!("not needed for testing") + } + + fn simplify(&self) -> Option { + if self.simplify { + Some(Box::new(|_, _| Ok(col("result_column")))) + } else { + None + } + } + } }