Skip to content

Commit a851ecf

Browse files
Support IGNORE NULLS for LAG window function (#9221)
* WIP lag/lead ignore nulls * Support IGNORE NULLS for LAG function * fmt * comments * remove comments * Add new tests, minor changes, trigger evalaute_all * Make algorithm pruning friendly --------- Co-authored-by: Mustafa Akur <[email protected]>
1 parent 02c948d commit a851ecf

File tree

22 files changed

+272
-14
lines changed

22 files changed

+272
-14
lines changed

datafusion/core/src/dataframe/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -1685,6 +1685,7 @@ mod tests {
16851685
vec![col("aggregate_test_100.c2")],
16861686
vec![],
16871687
WindowFrame::new(None),
1688+
None,
16881689
));
16891690
let t2 = t.select(vec![col("c1"), first_row])?;
16901691
let plan = t2.plan.clone();

datafusion/core/src/physical_optimizer/test_utils.rs

+1
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,7 @@ pub fn bounded_window_exec(
245245
&sort_exprs,
246246
Arc::new(WindowFrame::new(Some(false))),
247247
schema.as_ref(),
248+
false,
248249
)
249250
.unwrap()],
250251
input.clone(),

datafusion/core/src/physical_planner.rs

+6
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ use futures::future::BoxFuture;
100100
use futures::{FutureExt, StreamExt, TryStreamExt};
101101
use itertools::{multiunzip, Itertools};
102102
use log::{debug, trace};
103+
use sqlparser::ast::NullTreatment;
103104

104105
fn create_function_physical_name(
105106
fun: &str,
@@ -1581,6 +1582,7 @@ pub fn create_window_expr_with_name(
15811582
partition_by,
15821583
order_by,
15831584
window_frame,
1585+
null_treatment,
15841586
}) => {
15851587
let args = args
15861588
.iter()
@@ -1605,6 +1607,9 @@ pub fn create_window_expr_with_name(
16051607
}
16061608

16071609
let window_frame = Arc::new(window_frame.clone());
1610+
let ignore_nulls = null_treatment
1611+
.unwrap_or(sqlparser::ast::NullTreatment::RespectNulls)
1612+
== NullTreatment::IgnoreNulls;
16081613
windows::create_window_expr(
16091614
fun,
16101615
name,
@@ -1613,6 +1618,7 @@ pub fn create_window_expr_with_name(
16131618
&order_by,
16141619
window_frame,
16151620
physical_input_schema,
1621+
ignore_nulls,
16161622
)
16171623
}
16181624
other => plan_err!("Invalid window expression '{other:?}'"),

datafusion/core/tests/dataframe/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ async fn test_count_wildcard_on_window() -> Result<()> {
182182
WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))),
183183
WindowFrameBound::Following(ScalarValue::UInt32(Some(2))),
184184
),
185+
None,
185186
))])?
186187
.explain(false, false)?
187188
.collect()

datafusion/core/tests/fuzz_cases/window_fuzz.rs

+3
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> {
281281
&orderby_exprs,
282282
Arc::new(window_frame),
283283
schema.as_ref(),
284+
false,
284285
)?;
285286
let running_window_exec = Arc::new(BoundedWindowAggExec::try_new(
286287
vec![window_expr],
@@ -642,6 +643,7 @@ async fn run_window_test(
642643
&orderby_exprs,
643644
Arc::new(window_frame.clone()),
644645
schema.as_ref(),
646+
false,
645647
)
646648
.unwrap()],
647649
exec1,
@@ -664,6 +666,7 @@ async fn run_window_test(
664666
&orderby_exprs,
665667
Arc::new(window_frame.clone()),
666668
schema.as_ref(),
669+
false,
667670
)
668671
.unwrap()],
669672
exec2,

datafusion/expr/src/expr.rs

+18
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ use arrow::datatypes::DataType;
3030
use datafusion_common::tree_node::{Transformed, TreeNode};
3131
use datafusion_common::{internal_err, DFSchema, OwnedTableReference};
3232
use datafusion_common::{plan_err, Column, DataFusionError, Result, ScalarValue};
33+
use sqlparser::ast::NullTreatment;
3334
use std::collections::HashSet;
3435
use std::fmt;
3536
use std::fmt::{Display, Formatter, Write};
@@ -646,6 +647,7 @@ pub struct WindowFunction {
646647
pub order_by: Vec<Expr>,
647648
/// Window frame
648649
pub window_frame: window_frame::WindowFrame,
650+
pub null_treatment: Option<NullTreatment>,
649651
}
650652

651653
impl WindowFunction {
@@ -656,13 +658,15 @@ impl WindowFunction {
656658
partition_by: Vec<Expr>,
657659
order_by: Vec<Expr>,
658660
window_frame: window_frame::WindowFrame,
661+
null_treatment: Option<NullTreatment>,
659662
) -> Self {
660663
Self {
661664
fun,
662665
args,
663666
partition_by,
664667
order_by,
665668
window_frame,
669+
null_treatment,
666670
}
667671
}
668672
}
@@ -1440,8 +1444,14 @@ impl fmt::Display for Expr {
14401444
partition_by,
14411445
order_by,
14421446
window_frame,
1447+
null_treatment,
14431448
}) => {
14441449
fmt_function(f, &fun.to_string(), false, args, true)?;
1450+
1451+
if let Some(nt) = null_treatment {
1452+
write!(f, "{}", nt)?;
1453+
}
1454+
14451455
if !partition_by.is_empty() {
14461456
write!(f, " PARTITION BY [{}]", expr_vec_fmt!(partition_by))?;
14471457
}
@@ -1768,15 +1778,23 @@ fn create_name(e: &Expr) -> Result<String> {
17681778
window_frame,
17691779
partition_by,
17701780
order_by,
1781+
null_treatment,
17711782
}) => {
17721783
let mut parts: Vec<String> =
17731784
vec![create_function_name(&fun.to_string(), false, args)?];
1785+
1786+
if let Some(nt) = null_treatment {
1787+
parts.push(format!("{}", nt));
1788+
}
1789+
17741790
if !partition_by.is_empty() {
17751791
parts.push(format!("PARTITION BY [{}]", expr_vec_fmt!(partition_by)));
17761792
}
1793+
17771794
if !order_by.is_empty() {
17781795
parts.push(format!("ORDER BY [{}]", expr_vec_fmt!(order_by)));
17791796
}
1797+
17801798
parts.push(format!("{window_frame}"));
17811799
Ok(parts.join(" "))
17821800
}

datafusion/expr/src/tree_node/expr.rs

+2
Original file line numberDiff line numberDiff line change
@@ -283,12 +283,14 @@ impl TreeNode for Expr {
283283
partition_by,
284284
order_by,
285285
window_frame,
286+
null_treatment,
286287
}) => Expr::WindowFunction(WindowFunction::new(
287288
fun,
288289
transform_vec(args, &mut transform)?,
289290
transform_vec(partition_by, &mut transform)?,
290291
transform_vec(order_by, &mut transform)?,
291292
window_frame,
293+
null_treatment,
292294
)),
293295
Expr::AggregateFunction(AggregateFunction {
294296
args,

datafusion/expr/src/udwf.rs

+1
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ impl WindowUDF {
130130
partition_by,
131131
order_by,
132132
window_frame,
133+
null_treatment: None,
133134
})
134135
}
135136

datafusion/expr/src/utils.rs

+10
Original file line numberDiff line numberDiff line change
@@ -1255,27 +1255,31 @@ mod tests {
12551255
vec![],
12561256
vec![],
12571257
WindowFrame::new(None),
1258+
None,
12581259
));
12591260
let max2 = Expr::WindowFunction(expr::WindowFunction::new(
12601261
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max),
12611262
vec![col("name")],
12621263
vec![],
12631264
vec![],
12641265
WindowFrame::new(None),
1266+
None,
12651267
));
12661268
let min3 = Expr::WindowFunction(expr::WindowFunction::new(
12671269
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min),
12681270
vec![col("name")],
12691271
vec![],
12701272
vec![],
12711273
WindowFrame::new(None),
1274+
None,
12721275
));
12731276
let sum4 = Expr::WindowFunction(expr::WindowFunction::new(
12741277
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum),
12751278
vec![col("age")],
12761279
vec![],
12771280
vec![],
12781281
WindowFrame::new(None),
1282+
None,
12791283
));
12801284
let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
12811285
let result = group_window_expr_by_sort_keys(exprs.to_vec())?;
@@ -1298,27 +1302,31 @@ mod tests {
12981302
vec![],
12991303
vec![age_asc.clone(), name_desc.clone()],
13001304
WindowFrame::new(Some(false)),
1305+
None,
13011306
));
13021307
let max2 = Expr::WindowFunction(expr::WindowFunction::new(
13031308
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max),
13041309
vec![col("name")],
13051310
vec![],
13061311
vec![],
13071312
WindowFrame::new(None),
1313+
None,
13081314
));
13091315
let min3 = Expr::WindowFunction(expr::WindowFunction::new(
13101316
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min),
13111317
vec![col("name")],
13121318
vec![],
13131319
vec![age_asc.clone(), name_desc.clone()],
13141320
WindowFrame::new(Some(false)),
1321+
None,
13151322
));
13161323
let sum4 = Expr::WindowFunction(expr::WindowFunction::new(
13171324
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum),
13181325
vec![col("age")],
13191326
vec![],
13201327
vec![name_desc.clone(), age_asc.clone(), created_at_desc.clone()],
13211328
WindowFrame::new(Some(false)),
1329+
None,
13221330
));
13231331
// FIXME use as_ref
13241332
let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
@@ -1353,6 +1361,7 @@ mod tests {
13531361
Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)),
13541362
],
13551363
WindowFrame::new(Some(false)),
1364+
None,
13561365
)),
13571366
Expr::WindowFunction(expr::WindowFunction::new(
13581367
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum),
@@ -1364,6 +1373,7 @@ mod tests {
13641373
Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)),
13651374
],
13661375
WindowFrame::new(Some(false)),
1376+
None,
13671377
)),
13681378
];
13691379
let expected = vec![

datafusion/optimizer/src/analyzer/count_wildcard_rule.rs

+3
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ impl TreeNodeRewriter for CountWildcardRewriter {
128128
partition_by,
129129
order_by,
130130
window_frame,
131+
null_treatment,
131132
}) if args.len() == 1 => match args[0] {
132133
Expr::Wildcard { qualifier: None } => {
133134
Expr::WindowFunction(expr::WindowFunction {
@@ -138,6 +139,7 @@ impl TreeNodeRewriter for CountWildcardRewriter {
138139
partition_by,
139140
order_by,
140141
window_frame,
142+
null_treatment,
141143
})
142144
}
143145

@@ -351,6 +353,7 @@ mod tests {
351353
WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))),
352354
WindowFrameBound::Following(ScalarValue::UInt32(Some(2))),
353355
),
356+
None,
354357
))])?
355358
.project(vec![count(wildcard())])?
356359
.build()?;

datafusion/optimizer/src/analyzer/type_coercion.rs

+2
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
392392
partition_by,
393393
order_by,
394394
window_frame,
395+
null_treatment,
395396
}) => {
396397
let window_frame =
397398
coerce_window_frame(window_frame, &self.schema, &order_by)?;
@@ -414,6 +415,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
414415
partition_by,
415416
order_by,
416417
window_frame,
418+
null_treatment,
417419
));
418420
Ok(expr)
419421
}

datafusion/optimizer/src/push_down_projection.rs

+2
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,7 @@ mod tests {
587587
vec![col("test.b")],
588588
vec![],
589589
WindowFrame::new(None),
590+
None,
590591
));
591592

592593
let max2 = Expr::WindowFunction(expr::WindowFunction::new(
@@ -595,6 +596,7 @@ mod tests {
595596
vec![],
596597
vec![],
597598
WindowFrame::new(None),
599+
None,
598600
));
599601
let col1 = col(max1.display_name()?);
600602
let col2 = col(max2.display_name()?);

0 commit comments

Comments
 (0)