Skip to content

Commit 032b9c9

Browse files
authored
fix: impl ordering for serialization/deserialization for AggregateUdf (#11926)
* fix: support ordering and pencentile function ser/der * add more test case
1 parent 5251dc9 commit 032b9c9

File tree

8 files changed

+71
-22
lines changed

8 files changed

+71
-22
lines changed

datafusion/core/src/physical_optimizer/test_utils.rs

-1
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,6 @@ pub fn bounded_window_exec(
251251
"count".to_owned(),
252252
&[col(col_name, &schema).unwrap()],
253253
&[],
254-
&[],
255254
&sort_exprs,
256255
Arc::new(WindowFrame::new(Some(false))),
257256
schema.as_ref(),

datafusion/core/src/physical_planner.rs

-1
Original file line numberDiff line numberDiff line change
@@ -1510,7 +1510,6 @@ pub fn create_window_expr_with_name(
15101510
fun,
15111511
name,
15121512
&physical_args,
1513-
args,
15141513
&partition_by,
15151514
&order_by,
15161515
window_frame,

datafusion/core/tests/fuzz_cases/window_fuzz.rs

-4
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,6 @@ async fn bounded_window_causal_non_causal() -> Result<()> {
253253

254254
let partitionby_exprs = vec![];
255255
let orderby_exprs = vec![];
256-
let logical_exprs = vec![];
257256
// Window frame starts with "UNBOUNDED PRECEDING":
258257
let start_bound = WindowFrameBound::Preceding(ScalarValue::UInt64(None));
259258

@@ -285,7 +284,6 @@ async fn bounded_window_causal_non_causal() -> Result<()> {
285284
&window_fn,
286285
fn_name.to_string(),
287286
&args,
288-
&logical_exprs,
289287
&partitionby_exprs,
290288
&orderby_exprs,
291289
Arc::new(window_frame),
@@ -674,7 +672,6 @@ async fn run_window_test(
674672
&window_fn,
675673
fn_name.clone(),
676674
&args,
677-
&[],
678675
&partitionby_exprs,
679676
&orderby_exprs,
680677
Arc::new(window_frame.clone()),
@@ -693,7 +690,6 @@ async fn run_window_test(
693690
&window_fn,
694691
fn_name,
695692
&args,
696-
&[],
697693
&partitionby_exprs,
698694
&orderby_exprs,
699695
Arc::new(window_frame.clone()),

datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs

+1-5
Original file line numberDiff line numberDiff line change
@@ -1196,7 +1196,7 @@ mod tests {
11961196
RecordBatchStream, SendableRecordBatchStream, TaskContext,
11971197
};
11981198
use datafusion_expr::{
1199-
Expr, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
1199+
WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
12001200
};
12011201
use datafusion_functions_aggregate::count::count_udaf;
12021202
use datafusion_physical_expr::expressions::{col, Column, NthValue};
@@ -1303,10 +1303,7 @@ mod tests {
13031303
let window_fn = WindowFunctionDefinition::AggregateUDF(count_udaf());
13041304
let col_expr =
13051305
Arc::new(Column::new(schema.fields[0].name(), 0)) as Arc<dyn PhysicalExpr>;
1306-
let log_expr =
1307-
Expr::Column(datafusion_common::Column::from(schema.fields[0].name()));
13081306
let args = vec![col_expr];
1309-
let log_args = vec![log_expr];
13101307
let partitionby_exprs = vec![col(hash, &schema)?];
13111308
let orderby_exprs = vec![PhysicalSortExpr {
13121309
expr: col(order_by, &schema)?,
@@ -1327,7 +1324,6 @@ mod tests {
13271324
&window_fn,
13281325
fn_name,
13291326
&args,
1330-
&log_args,
13311327
&partitionby_exprs,
13321328
&orderby_exprs,
13331329
Arc::new(window_frame.clone()),

datafusion/physical-plan/src/windows/mod.rs

+2-4
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ use arrow::datatypes::Schema;
3232
use arrow_schema::{DataType, Field, SchemaRef};
3333
use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue};
3434
use datafusion_expr::{
35-
BuiltInWindowFunction, Expr, PartitionEvaluator, WindowFrame,
36-
WindowFunctionDefinition, WindowUDF,
35+
BuiltInWindowFunction, PartitionEvaluator, WindowFrame, WindowFunctionDefinition,
36+
WindowUDF,
3737
};
3838
use datafusion_physical_expr::equivalence::collapse_lex_req;
3939
use datafusion_physical_expr::{
@@ -94,7 +94,6 @@ pub fn create_window_expr(
9494
fun: &WindowFunctionDefinition,
9595
name: String,
9696
args: &[Arc<dyn PhysicalExpr>],
97-
_logical_args: &[Expr],
9897
partition_by: &[Arc<dyn PhysicalExpr>],
9998
order_by: &[PhysicalSortExpr],
10099
window_frame: Arc<WindowFrame>,
@@ -746,7 +745,6 @@ mod tests {
746745
&[col("a", &schema)?],
747746
&[],
748747
&[],
749-
&[],
750748
Arc::new(WindowFrame::new(None)),
751749
schema.as_ref(),
752750
false,

datafusion/proto/src/physical_plan/from_proto.rs

-3
Original file line numberDiff line numberDiff line change
@@ -169,13 +169,10 @@ pub fn parse_physical_window_expr(
169169
// TODO: Remove extended_schema if functions are all UDAF
170170
let extended_schema =
171171
schema_add_window_field(&window_node_expr, input_schema, &fun, &name)?;
172-
// approx_percentile_cont and approx_percentile_cont_weight are not supported for UDAF from protobuf yet.
173-
let logical_exprs = &[];
174172
create_window_expr(
175173
&fun,
176174
name,
177175
&window_node_expr,
178-
logical_exprs,
179176
&partition_by,
180177
&order_by,
181178
Arc::new(window_frame),

datafusion/proto/src/physical_plan/mod.rs

+2-4
Original file line numberDiff line numberDiff line change
@@ -477,7 +477,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode {
477477
ExprType::AggregateExpr(agg_node) => {
478478
let input_phy_expr: Vec<Arc<dyn PhysicalExpr>> = agg_node.expr.iter()
479479
.map(|e| parse_physical_expr(e, registry, &physical_schema, extension_codec)).collect::<Result<Vec<_>>>()?;
480-
let _ordering_req: Vec<PhysicalSortExpr> = agg_node.ordering_req.iter()
480+
let ordering_req: Vec<PhysicalSortExpr> = agg_node.ordering_req.iter()
481481
.map(|e| parse_physical_sort_expr(e, registry, &physical_schema, extension_codec)).collect::<Result<Vec<_>>>()?;
482482
agg_node.aggregate_function.as_ref().map(|func| {
483483
match func {
@@ -487,14 +487,12 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode {
487487
None => registry.udaf(udaf_name)?
488488
};
489489

490-
// TODO: approx_percentile_cont and approx_percentile_cont_weight are not supported for UDAF from protobuf yet.
491-
// TODO: `order by` is not supported for UDAF yet
492-
// https://github.com/apache/datafusion/issues/11804
493490
AggregateExprBuilder::new(agg_udf, input_phy_expr)
494491
.schema(Arc::clone(&physical_schema))
495492
.alias(name)
496493
.with_ignore_nulls(agg_node.ignore_nulls)
497494
.with_distinct(agg_node.distinct)
495+
.order_by(ordering_req)
498496
.build()
499497
}
500498
}

datafusion/proto/tests/cases/roundtrip_physical_plan.rs

+66
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ use std::vec;
2525
use arrow::array::RecordBatch;
2626
use arrow::csv::WriterBuilder;
2727
use datafusion::physical_expr_functions_aggregate::aggregate::AggregateExprBuilder;
28+
use datafusion_functions_aggregate::approx_percentile_cont::approx_percentile_cont_udaf;
29+
use datafusion_functions_aggregate::array_agg::array_agg_udaf;
2830
use datafusion_functions_aggregate::min_max::max_udaf;
2931
use prost::Message;
3032

@@ -412,6 +414,70 @@ fn rountrip_aggregate_with_limit() -> Result<()> {
412414
roundtrip_test(Arc::new(agg))
413415
}
414416

417+
#[test]
418+
fn rountrip_aggregate_with_approx_pencentile_cont() -> Result<()> {
419+
let field_a = Field::new("a", DataType::Int64, false);
420+
let field_b = Field::new("b", DataType::Int64, false);
421+
let schema = Arc::new(Schema::new(vec![field_a, field_b]));
422+
423+
let groups: Vec<(Arc<dyn PhysicalExpr>, String)> =
424+
vec![(col("a", &schema)?, "unused".to_string())];
425+
426+
let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![AggregateExprBuilder::new(
427+
approx_percentile_cont_udaf(),
428+
vec![col("b", &schema)?, lit(0.5)],
429+
)
430+
.schema(Arc::clone(&schema))
431+
.alias("APPROX_PERCENTILE_CONT(b, 0.5)")
432+
.build()?];
433+
434+
let agg = AggregateExec::try_new(
435+
AggregateMode::Final,
436+
PhysicalGroupBy::new_single(groups.clone()),
437+
aggregates.clone(),
438+
vec![None],
439+
Arc::new(EmptyExec::new(schema.clone())),
440+
schema,
441+
)?;
442+
roundtrip_test(Arc::new(agg))
443+
}
444+
445+
#[test]
446+
fn rountrip_aggregate_with_sort() -> Result<()> {
447+
let field_a = Field::new("a", DataType::Int64, false);
448+
let field_b = Field::new("b", DataType::Int64, false);
449+
let schema = Arc::new(Schema::new(vec![field_a, field_b]));
450+
451+
let groups: Vec<(Arc<dyn PhysicalExpr>, String)> =
452+
vec![(col("a", &schema)?, "unused".to_string())];
453+
let sort_exprs = vec![PhysicalSortExpr {
454+
expr: col("b", &schema)?,
455+
options: SortOptions {
456+
descending: false,
457+
nulls_first: true,
458+
},
459+
}];
460+
461+
let aggregates: Vec<Arc<dyn AggregateExpr>> =
462+
vec![
463+
AggregateExprBuilder::new(array_agg_udaf(), vec![col("b", &schema)?])
464+
.schema(Arc::clone(&schema))
465+
.alias("ARRAY_AGG(b)")
466+
.order_by(sort_exprs)
467+
.build()?,
468+
];
469+
470+
let agg = AggregateExec::try_new(
471+
AggregateMode::Final,
472+
PhysicalGroupBy::new_single(groups.clone()),
473+
aggregates.clone(),
474+
vec![None],
475+
Arc::new(EmptyExec::new(schema.clone())),
476+
schema,
477+
)?;
478+
roundtrip_test(Arc::new(agg))
479+
}
480+
415481
#[test]
416482
fn roundtrip_aggregate_udaf() -> Result<()> {
417483
let field_a = Field::new("a", DataType::Int64, false);

0 commit comments

Comments
 (0)