Skip to content

Commit 1d0171a

Browse files
authored
Make builtin window function output datatype to be derived from schema (#9686)
* Make builtin window function output datatype to be derived from schema
1 parent 89efc4a commit 1d0171a

File tree

3 files changed

+72
-36
lines changed

3 files changed

+72
-36
lines changed

datafusion/core/src/physical_planner.rs

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -742,13 +742,13 @@ impl DefaultPhysicalPlanner {
742742
);
743743
}
744744

745-
let logical_input_schema = input.schema();
745+
let logical_schema = logical_plan.schema();
746746
let window_expr = window_expr
747747
.iter()
748748
.map(|e| {
749749
create_window_expr(
750750
e,
751-
logical_input_schema,
751+
logical_schema,
752752
session_state.execution_props(),
753753
)
754754
})
@@ -1578,11 +1578,11 @@ pub fn is_window_frame_bound_valid(window_frame: &WindowFrame) -> bool {
15781578
pub fn create_window_expr_with_name(
15791579
e: &Expr,
15801580
name: impl Into<String>,
1581-
logical_input_schema: &DFSchema,
1581+
logical_schema: &DFSchema,
15821582
execution_props: &ExecutionProps,
15831583
) -> Result<Arc<dyn WindowExpr>> {
15841584
let name = name.into();
1585-
let physical_input_schema: &Schema = &logical_input_schema.into();
1585+
let physical_schema: &Schema = &logical_schema.into();
15861586
match e {
15871587
Expr::WindowFunction(WindowFunction {
15881588
fun,
@@ -1594,17 +1594,15 @@ pub fn create_window_expr_with_name(
15941594
}) => {
15951595
let args = args
15961596
.iter()
1597-
.map(|e| create_physical_expr(e, logical_input_schema, execution_props))
1597+
.map(|e| create_physical_expr(e, logical_schema, execution_props))
15981598
.collect::<Result<Vec<_>>>()?;
15991599
let partition_by = partition_by
16001600
.iter()
1601-
.map(|e| create_physical_expr(e, logical_input_schema, execution_props))
1601+
.map(|e| create_physical_expr(e, logical_schema, execution_props))
16021602
.collect::<Result<Vec<_>>>()?;
16031603
let order_by = order_by
16041604
.iter()
1605-
.map(|e| {
1606-
create_physical_sort_expr(e, logical_input_schema, execution_props)
1607-
})
1605+
.map(|e| create_physical_sort_expr(e, logical_schema, execution_props))
16081606
.collect::<Result<Vec<_>>>()?;
16091607

16101608
if !is_window_frame_bound_valid(window_frame) {
@@ -1625,7 +1623,7 @@ pub fn create_window_expr_with_name(
16251623
&partition_by,
16261624
&order_by,
16271625
window_frame,
1628-
physical_input_schema,
1626+
physical_schema,
16291627
ignore_nulls,
16301628
)
16311629
}
@@ -1636,15 +1634,15 @@ pub fn create_window_expr_with_name(
16361634
/// Create a window expression from a logical expression or an alias
16371635
pub fn create_window_expr(
16381636
e: &Expr,
1639-
logical_input_schema: &DFSchema,
1637+
logical_schema: &DFSchema,
16401638
execution_props: &ExecutionProps,
16411639
) -> Result<Arc<dyn WindowExpr>> {
16421640
// unpack aliased logical expressions, e.g. "sum(col) over () as total"
16431641
let (name, e) = match e {
16441642
Expr::Alias(Alias { expr, name, .. }) => (name.clone(), expr.as_ref()),
16451643
_ => (e.display_name()?, e),
16461644
};
1647-
create_window_expr_with_name(e, name, logical_input_schema, execution_props)
1645+
create_window_expr_with_name(e, name, logical_schema, execution_props)
16481646
}
16491647

16501648
type AggregateExprWithOptionalArgs = (

datafusion/core/tests/fuzz_cases/window_fuzz.rs

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ use arrow::compute::{concat_batches, SortOptions};
2222
use arrow::datatypes::SchemaRef;
2323
use arrow::record_batch::RecordBatch;
2424
use arrow::util::pretty::pretty_format_batches;
25+
use arrow_schema::{Field, Schema};
2526
use datafusion::physical_plan::memory::MemoryExec;
2627
use datafusion::physical_plan::sorts::sort::SortExec;
2728
use datafusion::physical_plan::windows::{
@@ -39,6 +40,7 @@ use datafusion_expr::{
3940
};
4041
use datafusion_physical_expr::expressions::{cast, col, lit};
4142
use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
43+
use itertools::Itertools;
4244
use test_utils::add_empty_batches;
4345

4446
use hashbrown::HashMap;
@@ -273,14 +275,17 @@ async fn bounded_window_causal_non_causal() -> Result<()> {
273275
window_frame.is_causal()
274276
};
275277

278+
let extended_schema =
279+
schema_add_window_fields(&args, &schema, &window_fn, fn_name)?;
280+
276281
let window_expr = create_window_expr(
277282
&window_fn,
278283
fn_name.to_string(),
279284
&args,
280285
&partitionby_exprs,
281286
&orderby_exprs,
282287
Arc::new(window_frame),
283-
schema.as_ref(),
288+
&extended_schema,
284289
false,
285290
)?;
286291
let running_window_exec = Arc::new(BoundedWindowAggExec::try_new(
@@ -678,6 +683,8 @@ async fn run_window_test(
678683
exec1 = Arc::new(SortExec::new(sort_keys, exec1)) as _;
679684
}
680685

686+
let extended_schema = schema_add_window_fields(&args, &schema, &window_fn, &fn_name)?;
687+
681688
let usual_window_exec = Arc::new(WindowAggExec::try_new(
682689
vec![create_window_expr(
683690
&window_fn,
@@ -686,7 +693,7 @@ async fn run_window_test(
686693
&partitionby_exprs,
687694
&orderby_exprs,
688695
Arc::new(window_frame.clone()),
689-
schema.as_ref(),
696+
&extended_schema,
690697
false,
691698
)?],
692699
exec1,
@@ -704,7 +711,7 @@ async fn run_window_test(
704711
&partitionby_exprs,
705712
&orderby_exprs,
706713
Arc::new(window_frame.clone()),
707-
schema.as_ref(),
714+
&extended_schema,
708715
false,
709716
)?],
710717
exec2,
@@ -747,6 +754,32 @@ async fn run_window_test(
747754
Ok(())
748755
}
749756

757+
// The planner has fully updated schema before calling the `create_window_expr`
758+
// Replicate the same for this test
759+
fn schema_add_window_fields(
760+
args: &[Arc<dyn PhysicalExpr>],
761+
schema: &Arc<Schema>,
762+
window_fn: &WindowFunctionDefinition,
763+
fn_name: &str,
764+
) -> Result<Arc<Schema>> {
765+
let data_types = args
766+
.iter()
767+
.map(|e| e.clone().as_ref().data_type(schema))
768+
.collect::<Result<Vec<_>>>()?;
769+
let window_expr_return_type = window_fn.return_type(&data_types)?;
770+
let mut window_fields = schema
771+
.fields()
772+
.iter()
773+
.map(|f| f.as_ref().clone())
774+
.collect_vec();
775+
window_fields.extend_from_slice(&[Field::new(
776+
fn_name,
777+
window_expr_return_type,
778+
true,
779+
)]);
780+
Ok(Arc::new(Schema::new(window_fields)))
781+
}
782+
750783
/// Return randomly sized record batches with:
751784
/// three sorted int32 columns 'a', 'b', 'c' ranged from 0..DISTINCT as columns
752785
/// one random int32 column x

datafusion/physical-plan/src/windows/mod.rs

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -174,20 +174,15 @@ fn create_built_in_window_expr(
174174
name: String,
175175
ignore_nulls: bool,
176176
) -> Result<Arc<dyn BuiltInWindowFunctionExpr>> {
177-
// need to get the types into an owned vec for some reason
178-
let input_types: Vec<_> = args
179-
.iter()
180-
.map(|arg| arg.data_type(input_schema))
181-
.collect::<Result<_>>()?;
177+
// Derive the output data type from the field named `name` in the incoming schema
178+
let out_data_type: &DataType = input_schema.field_with_name(&name)?.data_type();
182179

183-
// figure out the output type
184-
let data_type = &fun.return_type(&input_types)?;
185180
Ok(match fun {
186-
BuiltInWindowFunction::RowNumber => Arc::new(RowNumber::new(name, data_type)),
187-
BuiltInWindowFunction::Rank => Arc::new(rank(name, data_type)),
188-
BuiltInWindowFunction::DenseRank => Arc::new(dense_rank(name, data_type)),
189-
BuiltInWindowFunction::PercentRank => Arc::new(percent_rank(name, data_type)),
190-
BuiltInWindowFunction::CumeDist => Arc::new(cume_dist(name, data_type)),
181+
BuiltInWindowFunction::RowNumber => Arc::new(RowNumber::new(name, out_data_type)),
182+
BuiltInWindowFunction::Rank => Arc::new(rank(name, out_data_type)),
183+
BuiltInWindowFunction::DenseRank => Arc::new(dense_rank(name, out_data_type)),
184+
BuiltInWindowFunction::PercentRank => Arc::new(percent_rank(name, out_data_type)),
185+
BuiltInWindowFunction::CumeDist => Arc::new(cume_dist(name, out_data_type)),
191186
BuiltInWindowFunction::Ntile => {
192187
let n = get_scalar_value_from_args(args, 0)?.ok_or_else(|| {
193188
DataFusionError::Execution(
@@ -201,13 +196,13 @@ fn create_built_in_window_expr(
201196

202197
if n.is_unsigned() {
203198
let n: u64 = n.try_into()?;
204-
Arc::new(Ntile::new(name, n, data_type))
199+
Arc::new(Ntile::new(name, n, out_data_type))
205200
} else {
206201
let n: i64 = n.try_into()?;
207202
if n <= 0 {
208203
return exec_err!("NTILE requires a positive integer");
209204
}
210-
Arc::new(Ntile::new(name, n as u64, data_type))
205+
Arc::new(Ntile::new(name, n as u64, out_data_type))
211206
}
212207
}
213208
BuiltInWindowFunction::Lag => {
@@ -216,10 +211,10 @@ fn create_built_in_window_expr(
216211
.map(|v| v.try_into())
217212
.and_then(|v| v.ok());
218213
let default_value =
219-
get_casted_value(get_scalar_value_from_args(args, 2)?, data_type)?;
214+
get_casted_value(get_scalar_value_from_args(args, 2)?, out_data_type)?;
220215
Arc::new(lag(
221216
name,
222-
data_type.clone(),
217+
out_data_type.clone(),
223218
arg,
224219
shift_offset,
225220
default_value,
@@ -232,10 +227,10 @@ fn create_built_in_window_expr(
232227
.map(|v| v.try_into())
233228
.and_then(|v| v.ok());
234229
let default_value =
235-
get_casted_value(get_scalar_value_from_args(args, 2)?, data_type)?;
230+
get_casted_value(get_scalar_value_from_args(args, 2)?, out_data_type)?;
236231
Arc::new(lead(
237232
name,
238-
data_type.clone(),
233+
out_data_type.clone(),
239234
arg,
240235
shift_offset,
241236
default_value,
@@ -252,18 +247,28 @@ fn create_built_in_window_expr(
252247
Arc::new(NthValue::nth(
253248
name,
254249
arg,
255-
data_type.clone(),
250+
out_data_type.clone(),
256251
n,
257252
ignore_nulls,
258253
)?)
259254
}
260255
BuiltInWindowFunction::FirstValue => {
261256
let arg = args[0].clone();
262-
Arc::new(NthValue::first(name, arg, data_type.clone(), ignore_nulls))
257+
Arc::new(NthValue::first(
258+
name,
259+
arg,
260+
out_data_type.clone(),
261+
ignore_nulls,
262+
))
263263
}
264264
BuiltInWindowFunction::LastValue => {
265265
let arg = args[0].clone();
266-
Arc::new(NthValue::last(name, arg, data_type.clone(), ignore_nulls))
266+
Arc::new(NthValue::last(
267+
name,
268+
arg,
269+
out_data_type.clone(),
270+
ignore_nulls,
271+
))
267272
}
268273
})
269274
}

0 commit comments

Comments
 (0)