Skip to content

Commit 886e8ac

Browse files
timsaucertimsaucer-mayalamb
authored
Consistent API to set parameters of aggregate and window functions (AggregateExt --> ExprFunctionExt) (#11550)
* Moving over AggregateExt to ExprFunctionExt and adding in function settings for window functions * Switch WindowFrame to only need the window function definition and arguments. Other parameters will be set via the ExprFuncBuilder * Changing null_treatment to take an option, but this is mostly for code cleanliness and not strictly required * Moving functions in ExprFuncBuilder over to be explicitly implementing ExprFunctionExt trait so we can guarantee a consistent user experience no matter which they call on the Expr and which on the builder * Apply cargo fmt * Add deprecated trait AggregateExt so that users get a warning but still builds * Window helper functions should return Expr * Update documentation to show window function example * Add license info * Update comments that are no longer applicable * Remove first_value and last_value since these are already implemented in the aggregate functions * Update to use WindowFunction::new to set additional parameters for order_by using ExprFunctionExt * Apply cargo fmt * Fix up clippy * fix doc example * fmt * doc tweaks * more doc tweaks * fix up links * fix integration test * fix anothr doc example --------- Co-authored-by: Tim Saucer <[email protected]> Co-authored-by: Andrew Lamb <[email protected]>
1 parent 76039fa commit 886e8ac

File tree

26 files changed

+657
-444
lines changed

26 files changed

+657
-444
lines changed

datafusion-examples/examples/advanced_udwf.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -216,12 +216,12 @@ async fn main() -> Result<()> {
216216
df.show().await?;
217217

218218
// Now, run the function using the DataFrame API:
219-
let window_expr = smooth_it.call(
220-
vec![col("speed")], // smooth_it(speed)
221-
vec![col("car")], // PARTITION BY car
222-
vec![col("time").sort(true, true)], // ORDER BY time ASC
223-
WindowFrame::new(None),
224-
);
219+
let window_expr = smooth_it
220+
.call(vec![col("speed")]) // smooth_it(speed)
221+
.partition_by(vec![col("car")]) // PARTITION BY car
222+
.order_by(vec![col("time").sort(true, true)]) // ORDER BY time ASC
223+
.window_frame(WindowFrame::new(None))
224+
.build()?;
225225
let df = ctx.table("cars").await?.window(vec![window_expr])?;
226226

227227
// print the results

datafusion-examples/examples/expr_api.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ use datafusion_expr::execution_props::ExecutionProps;
3333
use datafusion_expr::expr::BinaryExpr;
3434
use datafusion_expr::interval_arithmetic::Interval;
3535
use datafusion_expr::simplify::SimplifyContext;
36-
use datafusion_expr::{AggregateExt, ColumnarValue, ExprSchemable, Operator};
36+
use datafusion_expr::{ColumnarValue, ExprFunctionExt, ExprSchemable, Operator};
3737

3838
/// This example demonstrates the DataFusion [`Expr`] API.
3939
///
@@ -95,7 +95,7 @@ fn expr_fn_demo() -> Result<()> {
9595
let agg = first_value.call(vec![col("price")]);
9696
assert_eq!(agg.to_string(), "first_value(price)");
9797

98-
// You can use the AggregateExt trait to create more complex aggregates
98+
// You can use the ExprFunctionExt trait to create more complex aggregates
9999
// such as `FIRST_VALUE(price FILTER quantity > 100 ORDER BY ts)`
100100
let agg = first_value
101101
.call(vec![col("price")])

datafusion-examples/examples/simple_udwf.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -118,12 +118,12 @@ async fn main() -> Result<()> {
118118
df.show().await?;
119119

120120
// Now, run the function using the DataFrame API:
121-
let window_expr = smooth_it.call(
122-
vec![col("speed")], // smooth_it(speed)
123-
vec![col("car")], // PARTITION BY car
124-
vec![col("time").sort(true, true)], // ORDER BY time ASC
125-
WindowFrame::new(None),
126-
);
121+
let window_expr = smooth_it
122+
.call(vec![col("speed")]) // smooth_it(speed)
123+
.partition_by(vec![col("car")]) // PARTITION BY car
124+
.order_by(vec![col("time").sort(true, true)]) // ORDER BY time ASC
125+
.window_frame(WindowFrame::new(None))
126+
.build()?;
127127
let df = ctx.table("cars").await?.window(vec![window_expr])?;
128128

129129
// print the results

datafusion/core/src/dataframe/mod.rs

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1696,8 +1696,8 @@ mod tests {
16961696
use datafusion_common::{Constraint, Constraints, ScalarValue};
16971697
use datafusion_common_runtime::SpawnedTask;
16981698
use datafusion_expr::{
1699-
cast, create_udf, expr, lit, BuiltInWindowFunction, ScalarFunctionImplementation,
1700-
Volatility, WindowFrame, WindowFunctionDefinition,
1699+
cast, create_udf, expr, lit, BuiltInWindowFunction, ExprFunctionExt,
1700+
ScalarFunctionImplementation, Volatility, WindowFunctionDefinition,
17011701
};
17021702
use datafusion_functions_aggregate::expr_fn::{array_agg, count_distinct};
17031703
use datafusion_physical_expr::expressions::Column;
@@ -1867,11 +1867,10 @@ mod tests {
18671867
BuiltInWindowFunction::FirstValue,
18681868
),
18691869
vec![col("aggregate_test_100.c1")],
1870-
vec![col("aggregate_test_100.c2")],
1871-
vec![],
1872-
WindowFrame::new(None),
1873-
None,
1874-
));
1870+
))
1871+
.partition_by(vec![col("aggregate_test_100.c2")])
1872+
.build()
1873+
.unwrap();
18751874
let t2 = t.select(vec![col("c1"), first_row])?;
18761875
let plan = t2.plan.clone();
18771876

datafusion/core/tests/dataframe/mod.rs

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@ use datafusion_expr::expr::{GroupingSet, Sort};
5555
use datafusion_expr::var_provider::{VarProvider, VarType};
5656
use datafusion_expr::{
5757
cast, col, exists, expr, in_subquery, lit, max, out_ref_col, placeholder,
58-
scalar_subquery, when, wildcard, Expr, ExprSchemable, WindowFrame, WindowFrameBound,
59-
WindowFrameUnits, WindowFunctionDefinition,
58+
scalar_subquery, when, wildcard, Expr, ExprFunctionExt, ExprSchemable, WindowFrame,
59+
WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
6060
};
6161
use datafusion_functions_aggregate::expr_fn::{array_agg, avg, count, sum};
6262

@@ -183,15 +183,15 @@ async fn test_count_wildcard_on_window() -> Result<()> {
183183
.select(vec![Expr::WindowFunction(expr::WindowFunction::new(
184184
WindowFunctionDefinition::AggregateUDF(count_udaf()),
185185
vec![wildcard()],
186-
vec![],
187-
vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))],
188-
WindowFrame::new_bounds(
189-
WindowFrameUnits::Range,
190-
WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))),
191-
WindowFrameBound::Following(ScalarValue::UInt32(Some(2))),
192-
),
193-
None,
194-
))])?
186+
))
187+
.order_by(vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))])
188+
.window_frame(WindowFrame::new_bounds(
189+
WindowFrameUnits::Range,
190+
WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))),
191+
WindowFrameBound::Following(ScalarValue::UInt32(Some(2))),
192+
))
193+
.build()
194+
.unwrap()])?
195195
.explain(false, false)?
196196
.collect()
197197
.await?;

datafusion/core/tests/expr_api/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray, StructArray};
2121
use arrow_schema::{DataType, Field};
2222
use datafusion::prelude::*;
2323
use datafusion_common::{assert_contains, DFSchema, ScalarValue};
24-
use datafusion_expr::AggregateExt;
24+
use datafusion_expr::ExprFunctionExt;
2525
use datafusion_functions::core::expr_ext::FieldAccessor;
2626
use datafusion_functions_aggregate::first_last::first_value_udaf;
2727
use datafusion_functions_aggregate::sum::sum_udaf;

datafusion/expr/src/expr.rs

Lines changed: 65 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ use crate::expr_fn::binary_expr;
2828
use crate::logical_plan::Subquery;
2929
use crate::utils::expr_to_columns;
3030
use crate::{
31-
aggregate_function, built_in_window_function, udaf, ExprSchemable, Operator,
32-
Signature,
31+
aggregate_function, built_in_window_function, udaf, BuiltInWindowFunction,
32+
ExprSchemable, Operator, Signature, WindowFrame, WindowUDF,
3333
};
3434
use crate::{window_frame, Volatility};
3535

@@ -60,6 +60,10 @@ use sqlparser::ast::NullTreatment;
6060
/// use the fluent APIs in [`crate::expr_fn`] such as [`col`] and [`lit`], or
6161
/// methods such as [`Expr::alias`], [`Expr::cast_to`], and [`Expr::Like`]).
6262
///
63+
/// See also [`ExprFunctionExt`] for creating aggregate and window functions.
64+
///
65+
/// [`ExprFunctionExt`]: crate::expr_fn::ExprFunctionExt
66+
///
6367
/// # Schema Access
6468
///
6569
/// See [`ExprSchemable::get_type`] to access the [`DataType`] and nullability
@@ -283,15 +287,17 @@ pub enum Expr {
283287
/// This expression is guaranteed to have a fixed type.
284288
TryCast(TryCast),
285289
/// A sort expression, that can be used to sort values.
290+
///
291+
/// See [`Expr::sort`] for more details
286292
Sort(Sort),
287293
/// Represents the call of a scalar function with a set of arguments.
288294
ScalarFunction(ScalarFunction),
289295
/// Calls an aggregate function with arguments, and optional
290296
/// `ORDER BY`, `FILTER`, `DISTINCT` and `NULL TREATMENT`.
291297
///
292-
/// See also [`AggregateExt`] to set these fields.
298+
/// See also [`ExprFunctionExt`] to set these fields.
293299
///
294-
/// [`AggregateExt`]: crate::udaf::AggregateExt
300+
/// [`ExprFunctionExt`]: crate::expr_fn::ExprFunctionExt
295301
AggregateFunction(AggregateFunction),
296302
/// Represents the call of a window function with arguments.
297303
WindowFunction(WindowFunction),
@@ -641,9 +647,9 @@ impl AggregateFunctionDefinition {
641647

642648
/// Aggregate function
643649
///
644-
/// See also [`AggregateExt`] to set these fields on `Expr`
650+
/// See also [`ExprFunctionExt`] to set these fields on `Expr`
645651
///
646-
/// [`AggregateExt`]: crate::udaf::AggregateExt
652+
/// [`ExprFunctionExt`]: crate::expr_fn::ExprFunctionExt
647653
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
648654
pub struct AggregateFunction {
649655
/// Name of the function
@@ -769,7 +775,52 @@ impl fmt::Display for WindowFunctionDefinition {
769775
}
770776
}
771777

778+
impl From<aggregate_function::AggregateFunction> for WindowFunctionDefinition {
779+
fn from(value: aggregate_function::AggregateFunction) -> Self {
780+
Self::AggregateFunction(value)
781+
}
782+
}
783+
784+
impl From<BuiltInWindowFunction> for WindowFunctionDefinition {
785+
fn from(value: BuiltInWindowFunction) -> Self {
786+
Self::BuiltInWindowFunction(value)
787+
}
788+
}
789+
790+
impl From<Arc<crate::AggregateUDF>> for WindowFunctionDefinition {
791+
fn from(value: Arc<crate::AggregateUDF>) -> Self {
792+
Self::AggregateUDF(value)
793+
}
794+
}
795+
796+
impl From<Arc<WindowUDF>> for WindowFunctionDefinition {
797+
fn from(value: Arc<WindowUDF>) -> Self {
798+
Self::WindowUDF(value)
799+
}
800+
}
801+
772802
/// Window function
803+
///
804+
/// Holds the actual function to call [`WindowFunction`] as well as its
805+
/// arguments (`args`) and the contents of the `OVER` clause:
806+
///
807+
/// 1. `PARTITION BY`
808+
/// 2. `ORDER BY`
809+
/// 3. Window frame (e.g. `ROWS 1 PRECEDING AND 1 FOLLOWING`)
810+
///
811+
/// # Example
812+
/// ```
813+
/// # use datafusion_expr::{Expr, BuiltInWindowFunction, col, ExprFunctionExt};
814+
/// # use datafusion_expr::expr::WindowFunction;
815+
/// // Create FIRST_VALUE(a) OVER (PARTITION BY b ORDER BY c)
816+
/// let expr = Expr::WindowFunction(
817+
/// WindowFunction::new(BuiltInWindowFunction::FirstValue, vec![col("a")])
818+
/// )
819+
/// .partition_by(vec![col("b")])
820+
/// .order_by(vec![col("c").sort(true, true)])
821+
/// .build()
822+
/// .unwrap();
823+
/// ```
773824
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
774825
pub struct WindowFunction {
775826
/// Name of the function
@@ -787,22 +838,16 @@ pub struct WindowFunction {
787838
}
788839

789840
impl WindowFunction {
790-
/// Create a new Window expression
791-
pub fn new(
792-
fun: WindowFunctionDefinition,
793-
args: Vec<Expr>,
794-
partition_by: Vec<Expr>,
795-
order_by: Vec<Expr>,
796-
window_frame: window_frame::WindowFrame,
797-
null_treatment: Option<NullTreatment>,
798-
) -> Self {
841+
/// Create a new Window expression with the specified arguments and an
842+
/// empty `OVER` clause
843+
pub fn new(fun: impl Into<WindowFunctionDefinition>, args: Vec<Expr>) -> Self {
799844
Self {
800-
fun,
845+
fun: fun.into(),
801846
args,
802-
partition_by,
803-
order_by,
804-
window_frame,
805-
null_treatment,
847+
partition_by: Vec::default(),
848+
order_by: Vec::default(),
849+
window_frame: WindowFrame::new(None),
850+
null_treatment: None,
806851
}
807852
}
808853
}

0 commit comments

Comments
 (0)