Skip to content

Commit 44127ec

Browse files
authored
Fix: handle NULL input in lead/lag window function (#12811)
1 parent db85d07 commit 44127ec

File tree

2 files changed

+94
-7
lines changed

2 files changed

+94
-7
lines changed

datafusion/physical-plan/src/windows/mod.rs

+45-6
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,41 @@ fn get_casted_value(
217217
}
218218
}
219219

220+
/// Rewrites the NULL expression (1st argument) with an expression
221+
/// which is the same data type as the default value (3rd argument).
222+
/// Also rewrites the return type with the same data type as the
223+
/// default value.
224+
///
225+
/// If a default value is not provided, or it is NULL the original
226+
/// expression (1st argument) and return type is returned without
227+
/// any modifications.
228+
fn rewrite_null_expr_and_data_type(
229+
args: &[Arc<dyn PhysicalExpr>],
230+
expr_type: &DataType,
231+
) -> Result<(Arc<dyn PhysicalExpr>, DataType)> {
232+
assert!(!args.is_empty());
233+
let expr = Arc::clone(&args[0]);
234+
235+
// The input expression and the return is type is unchanged
236+
// when the input expression is not NULL.
237+
if !expr_type.is_null() {
238+
return Ok((expr, expr_type.clone()));
239+
}
240+
241+
get_scalar_value_from_args(args, 2)?
242+
.and_then(|value| {
243+
ScalarValue::try_from(value.data_type().clone())
244+
.map(|sv| {
245+
Ok((
246+
Arc::new(Literal::new(sv)) as Arc<dyn PhysicalExpr>,
247+
value.data_type().clone(),
248+
))
249+
})
250+
.ok()
251+
})
252+
.unwrap_or(Ok((expr, expr_type.clone())))
253+
}
254+
220255
fn create_built_in_window_expr(
221256
fun: &BuiltInWindowFunction,
222257
args: &[Arc<dyn PhysicalExpr>],
@@ -252,31 +287,35 @@ fn create_built_in_window_expr(
252287
}
253288
}
254289
BuiltInWindowFunction::Lag => {
255-
let arg = Arc::clone(&args[0]);
290+
// rewrite NULL expression and the return datatype
291+
let (arg, out_data_type) =
292+
rewrite_null_expr_and_data_type(args, out_data_type)?;
256293
let shift_offset = get_scalar_value_from_args(args, 1)?
257294
.map(get_signed_integer)
258295
.map_or(Ok(None), |v| v.map(Some))?;
259296
let default_value =
260-
get_casted_value(get_scalar_value_from_args(args, 2)?, out_data_type)?;
297+
get_casted_value(get_scalar_value_from_args(args, 2)?, &out_data_type)?;
261298
Arc::new(lag(
262299
name,
263-
out_data_type.clone(),
300+
default_value.data_type().clone(),
264301
arg,
265302
shift_offset,
266303
default_value,
267304
ignore_nulls,
268305
))
269306
}
270307
BuiltInWindowFunction::Lead => {
271-
let arg = Arc::clone(&args[0]);
308+
// rewrite NULL expression and the return datatype
309+
let (arg, out_data_type) =
310+
rewrite_null_expr_and_data_type(args, out_data_type)?;
272311
let shift_offset = get_scalar_value_from_args(args, 1)?
273312
.map(get_signed_integer)
274313
.map_or(Ok(None), |v| v.map(Some))?;
275314
let default_value =
276-
get_casted_value(get_scalar_value_from_args(args, 2)?, out_data_type)?;
315+
get_casted_value(get_scalar_value_from_args(args, 2)?, &out_data_type)?;
277316
Arc::new(lead(
278317
name,
279-
out_data_type.clone(),
318+
default_value.data_type().clone(),
280319
arg,
281320
shift_offset,
282321
default_value,

datafusion/sqllogictest/test_files/window.slt

+49-1
Original file line numberDiff line numberDiff line change
@@ -4941,4 +4941,52 @@ NULL
49414941
statement ok
49424942
DROP TABLE t;
49434943

4944-
## end test handle NULL and 0 of NTH_VALUE
4944+
## end test handle NULL and 0 of NTH_VALUE
4945+
4946+
## test handle NULL of lead
4947+
4948+
statement ok
4949+
create table t1(v1 int);
4950+
4951+
statement ok
4952+
insert into t1 values (1);
4953+
4954+
query B
4955+
SELECT LEAD(NULL, 0, false) OVER () FROM t1;
4956+
----
4957+
NULL
4958+
4959+
query B
4960+
SELECT LAG(NULL, 0, false) OVER () FROM t1;
4961+
----
4962+
NULL
4963+
4964+
query B
4965+
SELECT LEAD(NULL, 1, false) OVER () FROM t1;
4966+
----
4967+
false
4968+
4969+
query B
4970+
SELECT LAG(NULL, 1, false) OVER () FROM t1;
4971+
----
4972+
false
4973+
4974+
statement ok
4975+
insert into t1 values (2);
4976+
4977+
query B
4978+
SELECT LEAD(NULL, 1, false) OVER () FROM t1;
4979+
----
4980+
NULL
4981+
false
4982+
4983+
query B
4984+
SELECT LAG(NULL, 1, false) OVER () FROM t1;
4985+
----
4986+
false
4987+
NULL
4988+
4989+
statement ok
4990+
DROP TABLE t1;
4991+
4992+
## end test handle NULL of lead

0 commit comments

Comments
 (0)