Skip to content

Commit 3aba67e

Browse files
huaxingao (Huaxin Gao) and alamb
authored
Implement IGNORE NULLS for FIRST_VALUE (#9411)
* Implement IGNORE NULLS for FIRST_VALUE
* fix style
* fix clippy error
* fix clippy error
* address comments
* fix error
* add test to aggregate.slt
* address comments
* Trigger Build
* Add one additional column in order by to ensure a deterministic order in the output

---------

Co-authored-by: Huaxin Gao <[email protected]>
Co-authored-by: Andrew Lamb <[email protected]>
1 parent 2873fd0 commit 3aba67e

File tree

25 files changed

+286
-39
lines changed

25 files changed

+286
-39
lines changed

datafusion/core/src/physical_planner.rs

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -246,6 +246,7 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {
246246
args,
247247
filter,
248248
order_by,
249+
null_treatment: _,
249250
}) => match func_def {
250251
AggregateFunctionDefinition::BuiltIn(..) => {
251252
create_function_physical_name(func_def.name(), *distinct, args)
@@ -1662,6 +1663,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
16621663
args,
16631664
filter,
16641665
order_by,
1666+
null_treatment,
16651667
}) => {
16661668
let args = args
16671669
.iter()
@@ -1689,6 +1691,9 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
16891691
),
16901692
None => None,
16911693
};
1694+
let ignore_nulls = null_treatment
1695+
.unwrap_or(sqlparser::ast::NullTreatment::RespectNulls)
1696+
== NullTreatment::IgnoreNulls;
16921697
let (agg_expr, filter, order_by) = match func_def {
16931698
AggregateFunctionDefinition::BuiltIn(fun) => {
16941699
let ordering_reqs = order_by.clone().unwrap_or(vec![]);
@@ -1699,6 +1704,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
16991704
&ordering_reqs,
17001705
physical_input_schema,
17011706
name,
1707+
ignore_nulls,
17021708
)?;
17031709
(agg_expr, filter, order_by)
17041710
}

datafusion/core/tests/sql/aggregates.rs

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,3 +321,83 @@ async fn test_accumulator_row_accumulator() -> Result<()> {
321321

322322
Ok(())
323323
}
324+
325+
#[tokio::test]
326+
async fn test_first_value() -> Result<()> {
327+
let session_ctx = SessionContext::new();
328+
session_ctx
329+
.sql("CREATE TABLE abc AS VALUES (null,2,3), (4,5,6)")
330+
.await?
331+
.collect()
332+
.await?;
333+
334+
let results1 = session_ctx
335+
.sql("SELECT FIRST_VALUE(column1) ignore nulls FROM abc")
336+
.await?
337+
.collect()
338+
.await?;
339+
let expected1 = [
340+
"+--------------------------+",
341+
"| FIRST_VALUE(abc.column1) |",
342+
"+--------------------------+",
343+
"| 4 |",
344+
"+--------------------------+",
345+
];
346+
assert_batches_eq!(expected1, &results1);
347+
348+
let results2 = session_ctx
349+
.sql("SELECT FIRST_VALUE(column1) respect nulls FROM abc")
350+
.await?
351+
.collect()
352+
.await?;
353+
let expected2 = [
354+
"+--------------------------+",
355+
"| FIRST_VALUE(abc.column1) |",
356+
"+--------------------------+",
357+
"| |",
358+
"+--------------------------+",
359+
];
360+
assert_batches_eq!(expected2, &results2);
361+
362+
Ok(())
363+
}
364+
365+
#[tokio::test]
366+
async fn test_first_value_with_sort() -> Result<()> {
367+
let session_ctx = SessionContext::new();
368+
session_ctx
369+
.sql("CREATE TABLE abc AS VALUES (null,2,3), (null,1,6), (4, 5, 5), (1, 4, 7), (2, 3, 8)")
370+
.await?
371+
.collect()
372+
.await?;
373+
374+
let results1 = session_ctx
375+
.sql("SELECT FIRST_VALUE(column1 ORDER BY column2) ignore nulls FROM abc")
376+
.await?
377+
.collect()
378+
.await?;
379+
let expected1 = [
380+
"+--------------------------+",
381+
"| FIRST_VALUE(abc.column1) |",
382+
"+--------------------------+",
383+
"| 2 |",
384+
"+--------------------------+",
385+
];
386+
assert_batches_eq!(expected1, &results1);
387+
388+
let results2 = session_ctx
389+
.sql("SELECT FIRST_VALUE(column1 ORDER BY column2) respect nulls FROM abc")
390+
.await?
391+
.collect()
392+
.await?;
393+
let expected2 = [
394+
"+--------------------------+",
395+
"| FIRST_VALUE(abc.column1) |",
396+
"+--------------------------+",
397+
"| |",
398+
"+--------------------------+",
399+
];
400+
assert_batches_eq!(expected2, &results2);
401+
402+
Ok(())
403+
}

datafusion/expr/src/expr.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,7 @@ pub struct AggregateFunction {
543543
pub filter: Option<Box<Expr>>,
544544
/// Optional ordering
545545
pub order_by: Option<Vec<Expr>>,
546+
pub null_treatment: Option<NullTreatment>,
546547
}
547548

548549
impl AggregateFunction {
@@ -552,13 +553,15 @@ impl AggregateFunction {
552553
distinct: bool,
553554
filter: Option<Box<Expr>>,
554555
order_by: Option<Vec<Expr>>,
556+
null_treatment: Option<NullTreatment>,
555557
) -> Self {
556558
Self {
557559
func_def: AggregateFunctionDefinition::BuiltIn(fun),
558560
args,
559561
distinct,
560562
filter,
561563
order_by,
564+
null_treatment,
562565
}
563566
}
564567

@@ -576,6 +579,7 @@ impl AggregateFunction {
576579
distinct,
577580
filter,
578581
order_by,
582+
null_treatment: None,
579583
}
580584
}
581585
}
@@ -646,6 +650,7 @@ pub struct WindowFunction {
646650
pub order_by: Vec<Expr>,
647651
/// Window frame
648652
pub window_frame: window_frame::WindowFrame,
653+
/// Specifies how NULL value is treated: ignore or respect
649654
pub null_treatment: Option<NullTreatment>,
650655
}
651656

@@ -1471,9 +1476,13 @@ impl fmt::Display for Expr {
14711476
ref args,
14721477
filter,
14731478
order_by,
1479+
null_treatment,
14741480
..
14751481
}) => {
14761482
fmt_function(f, func_def.name(), *distinct, args, true)?;
1483+
if let Some(nt) = null_treatment {
1484+
write!(f, " {}", nt)?;
1485+
}
14771486
if let Some(fe) = filter {
14781487
write!(f, " FILTER (WHERE {fe})")?;
14791488
}
@@ -1804,6 +1813,7 @@ fn create_name(e: &Expr) -> Result<String> {
18041813
args,
18051814
filter,
18061815
order_by,
1816+
null_treatment,
18071817
}) => {
18081818
let name = match func_def {
18091819
AggregateFunctionDefinition::BuiltIn(..)
@@ -1823,6 +1833,9 @@ fn create_name(e: &Expr) -> Result<String> {
18231833
if let Some(order_by) = order_by {
18241834
info += &format!(" ORDER BY [{}]", expr_vec_fmt!(order_by));
18251835
};
1836+
if let Some(nt) = null_treatment {
1837+
info += &format!(" {}", nt);
1838+
}
18261839
match func_def {
18271840
AggregateFunctionDefinition::BuiltIn(..)
18281841
| AggregateFunctionDefinition::Name(..) => {

datafusion/expr/src/expr_fn.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ pub fn min(expr: Expr) -> Expr {
150150
false,
151151
None,
152152
None,
153+
None,
153154
))
154155
}
155156

@@ -161,6 +162,7 @@ pub fn max(expr: Expr) -> Expr {
161162
false,
162163
None,
163164
None,
165+
None,
164166
))
165167
}
166168

@@ -172,6 +174,7 @@ pub fn sum(expr: Expr) -> Expr {
172174
false,
173175
None,
174176
None,
177+
None,
175178
))
176179
}
177180

@@ -183,6 +186,7 @@ pub fn array_agg(expr: Expr) -> Expr {
183186
false,
184187
None,
185188
None,
189+
None,
186190
))
187191
}
188192

@@ -194,6 +198,7 @@ pub fn avg(expr: Expr) -> Expr {
194198
false,
195199
None,
196200
None,
201+
None,
197202
))
198203
}
199204

@@ -205,6 +210,7 @@ pub fn count(expr: Expr) -> Expr {
205210
false,
206211
None,
207212
None,
213+
None,
208214
))
209215
}
210216

@@ -261,6 +267,7 @@ pub fn count_distinct(expr: Expr) -> Expr {
261267
true,
262268
None,
263269
None,
270+
None,
264271
))
265272
}
266273

@@ -313,6 +320,7 @@ pub fn approx_distinct(expr: Expr) -> Expr {
313320
false,
314321
None,
315322
None,
323+
None,
316324
))
317325
}
318326

@@ -324,6 +332,7 @@ pub fn median(expr: Expr) -> Expr {
324332
false,
325333
None,
326334
None,
335+
None,
327336
))
328337
}
329338

@@ -335,6 +344,7 @@ pub fn approx_median(expr: Expr) -> Expr {
335344
false,
336345
None,
337346
None,
347+
None,
338348
))
339349
}
340350

@@ -346,6 +356,7 @@ pub fn approx_percentile_cont(expr: Expr, percentile: Expr) -> Expr {
346356
false,
347357
None,
348358
None,
359+
None,
349360
))
350361
}
351362

@@ -361,6 +372,7 @@ pub fn approx_percentile_cont_with_weight(
361372
false,
362373
None,
363374
None,
375+
None,
364376
))
365377
}
366378

@@ -431,6 +443,7 @@ pub fn stddev(expr: Expr) -> Expr {
431443
false,
432444
None,
433445
None,
446+
None,
434447
))
435448
}
436449

datafusion/expr/src/tree_node/expr.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,7 @@ impl TreeNode for Expr {
350350
distinct,
351351
filter,
352352
order_by,
353+
null_treatment,
353354
}) => transform_vec(args, &mut f)?
354355
.update_data(|new_args| (new_args, filter, order_by))
355356
.try_transform_node(|(new_args, filter, order_by)| {
@@ -368,6 +369,7 @@ impl TreeNode for Expr {
368369
distinct,
369370
new_filter,
370371
new_order_by,
372+
null_treatment,
371373
)))
372374
}
373375
AggregateFunctionDefinition::UDF(fun) => {

datafusion/optimizer/src/analyzer/count_wildcard_rule.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ impl TreeNodeRewriter for CountWildcardRewriter {
158158
distinct,
159159
filter,
160160
order_by,
161+
null_treatment,
161162
}) if args.len() == 1 => match args[0] {
162163
Expr::Wildcard { qualifier: None } => {
163164
Transformed::yes(Expr::AggregateFunction(AggregateFunction::new(
@@ -166,6 +167,7 @@ impl TreeNodeRewriter for CountWildcardRewriter {
166167
distinct,
167168
filter,
168169
order_by,
170+
null_treatment,
169171
)))
170172
}
171173
_ => Transformed::no(old_expr),

datafusion/optimizer/src/analyzer/type_coercion.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
345345
distinct,
346346
filter,
347347
order_by,
348+
null_treatment,
348349
}) => match func_def {
349350
AggregateFunctionDefinition::BuiltIn(fun) => {
350351
let new_expr = coerce_agg_exprs_for_signature(
@@ -355,7 +356,12 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
355356
)?;
356357
Ok(Transformed::yes(Expr::AggregateFunction(
357358
expr::AggregateFunction::new(
358-
fun, new_expr, distinct, filter, order_by,
359+
fun,
360+
new_expr,
361+
distinct,
362+
filter,
363+
order_by,
364+
null_treatment,
359365
),
360366
)))
361367
}
@@ -946,6 +952,7 @@ mod test {
946952
false,
947953
None,
948954
None,
955+
None,
949956
));
950957
let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?);
951958
let expected = "Projection: AVG(CAST(Int64(12) AS Float64))\n EmptyRelation";
@@ -959,6 +966,7 @@ mod test {
959966
false,
960967
None,
961968
None,
969+
None,
962970
));
963971
let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?);
964972
let expected = "Projection: AVG(CAST(a AS Float64))\n EmptyRelation";
@@ -976,6 +984,7 @@ mod test {
976984
false,
977985
None,
978986
None,
987+
None,
979988
));
980989
let err = Projection::try_new(vec![agg_expr], empty)
981990
.err()
@@ -998,6 +1007,7 @@ mod test {
9981007
false,
9991008
None,
10001009
None,
1010+
None,
10011011
));
10021012

10031013
let err = Projection::try_new(vec![agg_expr], empty)

datafusion/optimizer/src/push_down_projection.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,6 +545,7 @@ mod tests {
545545
false,
546546
Some(Box::new(col("c").gt(lit(42)))),
547547
None,
548+
None,
548549
));
549550

550551
let plan = LogicalPlanBuilder::from(table_scan)

datafusion/optimizer/src/replace_distinct_aggregate.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ impl OptimizerRule for ReplaceDistinctWithAggregate {
9797
false,
9898
None,
9999
sort_expr.clone(),
100+
None,
100101
))
101102
})
102103
.collect::<Vec<Expr>>();

0 commit comments

Comments
 (0)