Skip to content

Commit 3d00760

Browse files
authored
feature: Add a WindowUDFImpl::simplify() API (#9906)
* feature: Add a WindowUDFImpl::simplfy() API Signed-off-by: guojidan <[email protected]> * fix doc Signed-off-by: guojidan <[email protected]> * fix fmt Signed-off-by: guojidan <[email protected]> --------- Signed-off-by: guojidan <[email protected]>
1 parent 77352b2 commit 3d00760

File tree

4 files changed

+288
-4
lines changed

4 files changed

+288
-4
lines changed
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::any::Any;
19+
20+
use arrow_schema::DataType;
21+
use datafusion::execution::context::SessionContext;
22+
use datafusion::{error::Result, execution::options::CsvReadOptions};
23+
use datafusion_expr::function::WindowFunctionSimplification;
24+
use datafusion_expr::{
25+
expr::WindowFunction, simplify::SimplifyInfo, AggregateFunction, Expr,
26+
PartitionEvaluator, Signature, Volatility, WindowUDF, WindowUDFImpl,
27+
};
28+
29+
/// This UDWF will show how to use the WindowUDFImpl::simplify() API
30+
#[derive(Debug, Clone)]
31+
struct SimplifySmoothItUdf {
32+
signature: Signature,
33+
}
34+
35+
impl SimplifySmoothItUdf {
36+
fn new() -> Self {
37+
Self {
38+
signature: Signature::exact(
39+
// this function will always take one arguments of type f64
40+
vec![DataType::Float64],
41+
// this function is deterministic and will always return the same
42+
// result for the same input
43+
Volatility::Immutable,
44+
),
45+
}
46+
}
47+
}
48+
impl WindowUDFImpl for SimplifySmoothItUdf {
49+
fn as_any(&self) -> &dyn Any {
50+
self
51+
}
52+
53+
fn name(&self) -> &str {
54+
"simplify_smooth_it"
55+
}
56+
57+
fn signature(&self) -> &Signature {
58+
&self.signature
59+
}
60+
61+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
62+
Ok(DataType::Float64)
63+
}
64+
65+
fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
66+
todo!()
67+
}
68+
69+
/// this function will simplify `SimplifySmoothItUdf` to `SmoothItUdf`.
70+
fn simplify(&self) -> Option<WindowFunctionSimplification> {
71+
// Ok(ExprSimplifyResult::Simplified(Expr::WindowFunction(
72+
// WindowFunction {
73+
// fun: datafusion_expr::WindowFunctionDefinition::AggregateFunction(
74+
// AggregateFunction::Avg,
75+
// ),
76+
// args,
77+
// partition_by: partition_by.to_vec(),
78+
// order_by: order_by.to_vec(),
79+
// window_frame: window_frame.clone(),
80+
// null_treatment: *null_treatment,
81+
// },
82+
// )))
83+
let simplify = |window_function: datafusion_expr::expr::WindowFunction,
84+
_: &dyn SimplifyInfo| {
85+
Ok(Expr::WindowFunction(WindowFunction {
86+
fun: datafusion_expr::WindowFunctionDefinition::AggregateFunction(
87+
AggregateFunction::Avg,
88+
),
89+
args: window_function.args,
90+
partition_by: window_function.partition_by,
91+
order_by: window_function.order_by,
92+
window_frame: window_function.window_frame,
93+
null_treatment: window_function.null_treatment,
94+
}))
95+
};
96+
97+
Some(Box::new(simplify))
98+
}
99+
}
100+
101+
// create local execution context with `cars.csv` registered as a table named `cars`
102+
async fn create_context() -> Result<SessionContext> {
103+
// declare a new context. In spark API, this corresponds to a new spark SQL session
104+
let ctx = SessionContext::new();
105+
106+
// declare a table in memory. In spark API, this corresponds to createDataFrame(...).
107+
println!("pwd: {}", std::env::current_dir().unwrap().display());
108+
let csv_path = "../../datafusion/core/tests/data/cars.csv".to_string();
109+
let read_options = CsvReadOptions::default().has_header(true);
110+
111+
ctx.register_csv("cars", &csv_path, read_options).await?;
112+
Ok(ctx)
113+
}
114+
115+
#[tokio::main]
116+
async fn main() -> Result<()> {
117+
let ctx = create_context().await?;
118+
let simplify_smooth_it = WindowUDF::from(SimplifySmoothItUdf::new());
119+
ctx.register_udwf(simplify_smooth_it.clone());
120+
121+
// Use SQL to run the new window function
122+
let df = ctx.sql("SELECT * from cars").await?;
123+
// print the results
124+
df.show().await?;
125+
126+
let df = ctx
127+
.sql(
128+
"SELECT \
129+
car, \
130+
speed, \
131+
simplify_smooth_it(speed) OVER (PARTITION BY car ORDER BY time) AS smooth_speed,\
132+
time \
133+
from cars \
134+
ORDER BY \
135+
car",
136+
)
137+
.await?;
138+
// print the results
139+
df.show().await?;
140+
141+
Ok(())
142+
}

datafusion/expr/src/function.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,16 @@ pub type AggregateFunctionSimplification = Box<
134134
&dyn crate::simplify::SimplifyInfo,
135135
) -> Result<Expr>,
136136
>;
137+
138+
/// [crate::udwf::WindowUDFImpl::simplify] simplifier closure
139+
/// A closure with two arguments:
140+
/// * 'window_function': [crate::expr::WindowFunction] for which simplified has been invoked
141+
/// * 'info': [crate::simplify::SimplifyInfo]
142+
///
143+
/// closure returns simplified [Expr] or an error.
144+
pub type WindowFunctionSimplification = Box<
145+
dyn Fn(
146+
crate::expr::WindowFunction,
147+
&dyn crate::simplify::SimplifyInfo,
148+
) -> Result<Expr>,
149+
>;

datafusion/expr/src/udwf.rs

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
//! [`WindowUDF`]: User Defined Window Functions
1919
2020
use crate::{
21-
Expr, PartitionEvaluator, PartitionEvaluatorFactory, ReturnTypeFunction, Signature,
22-
WindowFrame,
21+
function::WindowFunctionSimplification, Expr, PartitionEvaluator,
22+
PartitionEvaluatorFactory, ReturnTypeFunction, Signature, WindowFrame,
2323
};
2424
use arrow::datatypes::DataType;
2525
use datafusion_common::Result;
@@ -170,6 +170,13 @@ impl WindowUDF {
170170
self.inner.return_type(args)
171171
}
172172

173+
/// Do the function rewrite
174+
///
175+
/// See [`WindowUDFImpl::simplify`] for more details.
176+
pub fn simplify(&self) -> Option<WindowFunctionSimplification> {
177+
self.inner.simplify()
178+
}
179+
173180
/// Return a `PartitionEvaluator` for evaluating this window function
174181
pub fn partition_evaluator_factory(&self) -> Result<Box<dyn PartitionEvaluator>> {
175182
self.inner.partition_evaluator()
@@ -266,6 +273,29 @@ pub trait WindowUDFImpl: Debug + Send + Sync {
266273
fn aliases(&self) -> &[String] {
267274
&[]
268275
}
276+
277+
/// Optionally apply per-UDWF simplification / rewrite rules.
278+
///
279+
/// This can be used to apply function specific simplification rules during
280+
/// optimization. The default implementation does nothing.
281+
///
282+
/// Note that DataFusion handles simplifying arguments and "constant
283+
/// folding" (replacing a function call with constant arguments such as
284+
/// `my_add(1,2) --> 3` ). Thus, there is no need to implement such
285+
/// optimizations manually for specific UDFs.
286+
///
287+
/// Example:
288+
/// [`simplify_udwf_expression.rs`]: <https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simplify_udwf_expression.rs>
289+
///
290+
/// # Returns
291+
/// [None] if simplify is not defined or,
292+
///
293+
/// Or, a closure with two arguments:
294+
/// * 'window_function': [crate::expr::WindowFunction] for which simplified has been invoked
295+
/// * 'info': [crate::simplify::SimplifyInfo]
296+
fn simplify(&self) -> Option<WindowFunctionSimplification> {
297+
None
298+
}
269299
}
270300

271301
/// WindowUDF that adds an alias to the underlying function. It is better to

datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs

Lines changed: 101 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,13 @@ use datafusion_common::{
3232
tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter},
3333
};
3434
use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarValue};
35-
use datafusion_expr::expr::{AggregateFunctionDefinition, InList, InSubquery};
35+
use datafusion_expr::expr::{
36+
AggregateFunctionDefinition, InList, InSubquery, WindowFunction,
37+
};
3638
use datafusion_expr::simplify::ExprSimplifyResult;
3739
use datafusion_expr::{
3840
and, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, Operator, Volatility,
41+
WindowFunctionDefinition,
3942
};
4043
use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval};
4144
use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps};
@@ -1391,6 +1394,16 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
13911394
(_, expr) => Transformed::no(expr),
13921395
},
13931396

1397+
Expr::WindowFunction(WindowFunction {
1398+
fun: WindowFunctionDefinition::WindowUDF(ref udwf),
1399+
..
1400+
}) => match (udwf.simplify(), expr) {
1401+
(Some(simplify_function), Expr::WindowFunction(wf)) => {
1402+
Transformed::yes(simplify_function(wf, info)?)
1403+
}
1404+
(_, expr) => Transformed::no(expr),
1405+
},
1406+
13941407
//
13951408
// Rules for Between
13961409
//
@@ -1758,7 +1771,10 @@ fn inlist_except(mut l1: InList, l2: InList) -> Result<Expr> {
17581771
mod tests {
17591772
use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema};
17601773
use datafusion_expr::{
1761-
function::{AccumulatorArgs, AggregateFunctionSimplification},
1774+
function::{
1775+
AccumulatorArgs, AggregateFunctionSimplification,
1776+
WindowFunctionSimplification,
1777+
},
17621778
interval_arithmetic::Interval,
17631779
*,
17641780
};
@@ -3800,4 +3816,87 @@ mod tests {
38003816
}
38013817
}
38023818
}
3819+
3820+
#[test]
3821+
fn test_simplify_udwf() {
3822+
let udwf = WindowFunctionDefinition::WindowUDF(
3823+
WindowUDF::new_from_impl(SimplifyMockUdwf::new_with_simplify()).into(),
3824+
);
3825+
let window_function_expr =
3826+
Expr::WindowFunction(datafusion_expr::expr::WindowFunction::new(
3827+
udwf,
3828+
vec![],
3829+
vec![],
3830+
vec![],
3831+
WindowFrame::new(None),
3832+
None,
3833+
));
3834+
3835+
let expected = col("result_column");
3836+
assert_eq!(simplify(window_function_expr), expected);
3837+
3838+
let udwf = WindowFunctionDefinition::WindowUDF(
3839+
WindowUDF::new_from_impl(SimplifyMockUdwf::new_without_simplify()).into(),
3840+
);
3841+
let window_function_expr =
3842+
Expr::WindowFunction(datafusion_expr::expr::WindowFunction::new(
3843+
udwf,
3844+
vec![],
3845+
vec![],
3846+
vec![],
3847+
WindowFrame::new(None),
3848+
None,
3849+
));
3850+
3851+
let expected = window_function_expr.clone();
3852+
assert_eq!(simplify(window_function_expr), expected);
3853+
}
3854+
3855+
/// A Mock UDWF which defines `simplify` to be used in tests
3856+
/// related to UDWF simplification
3857+
#[derive(Debug, Clone)]
3858+
struct SimplifyMockUdwf {
3859+
simplify: bool,
3860+
}
3861+
3862+
impl SimplifyMockUdwf {
3863+
/// make simplify method return new expression
3864+
fn new_with_simplify() -> Self {
3865+
Self { simplify: true }
3866+
}
3867+
/// make simplify method return no change
3868+
fn new_without_simplify() -> Self {
3869+
Self { simplify: false }
3870+
}
3871+
}
3872+
3873+
impl WindowUDFImpl for SimplifyMockUdwf {
3874+
fn as_any(&self) -> &dyn std::any::Any {
3875+
self
3876+
}
3877+
3878+
fn name(&self) -> &str {
3879+
"mock_simplify"
3880+
}
3881+
3882+
fn signature(&self) -> &Signature {
3883+
unimplemented!()
3884+
}
3885+
3886+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
3887+
unimplemented!("not needed for tests")
3888+
}
3889+
3890+
fn simplify(&self) -> Option<WindowFunctionSimplification> {
3891+
if self.simplify {
3892+
Some(Box::new(|_, _| Ok(col("result_column"))))
3893+
} else {
3894+
None
3895+
}
3896+
}
3897+
3898+
fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
3899+
unimplemented!("not needed for tests")
3900+
}
3901+
}
38033902
}

0 commit comments

Comments
 (0)