Skip to content

Commit 6993561

Browse files
authored
fix: Be more lenient in interpreting input args for builtin window functions (#11199)
* fix: Be more lenient in interpreting input args for builtin window functions. The built-in window functions Lag, Lead, NthValue, and Ntile accept integer arguments. However, while they should accept any integer type, they currently use ScalarValue's try_from to convert the argument into an i64, so they actually only accept i64s. Any other argument, e.g. an i32, would be converted into a None and ignored. Before - lag and lead would silently ignore the argument, while ntile and nth_value would fail: ``` > SELECT id, lead(id, -1) OVER (ORDER BY id) AS correct, lead(id, arrow_cast(-1,'Int32')) OVER (ORDER BY id) as wrong from (values (1), (2)) as tbl(id); +----+---------+-------+ | id | correct | wrong | +----+---------+-------+ | 1 | | 2 | | 2 | 1 | | +----+---------+-------+ > SELECT id, lag(id, -1) OVER (ORDER BY id) AS correct, lag(id, arrow_cast(-1,'Int32')) OVER (ORDER BY id) as wrong from (values (1), (2)) as tbl(id); +----+---------+-------+ | id | correct | wrong | +----+---------+-------+ | 1 | 2 | | | 2 | | 1 | +----+---------+-------+ > SELECT id, nth_value(id, 2) OVER (ORDER BY id) AS correct, nth_value(id, arrow_cast(2,'Int32')) OVER (ORDER BY id) as corrected from (values (1), (2)) as tbl(id); Execution error: Internal("Cannot convert Int32(2) to i64") > SELECT id, ntile(2) OVER (ORDER BY id) AS correct, ntile(arrow_cast(2,'Int32')) OVER (ORDER BY id) as corrected from (values (1), (2)) as tbl(id); Internal error: Cannot convert Int32(2) to i64. 
This was likely caused by a bug in DataFusion's code and we would welcome that you file an bug report in our issue tracker ``` After - all four produce expected results: ``` SELECT id, lead(id, -1) OVER (ORDER BY id) AS correct, lead(id, arrow_cast(-1,'Int32')) OVER (ORDER BY id) as corrected from (values (1), (2)) as tbl(id) +----+---------+-----------+ | id | correct | corrected | +----+---------+-----------+ | 1 | | | | 2 | 1 | 1 | +----+---------+-----------+ SELECT id, lag(id, -1) OVER (ORDER BY id) AS correct, lag(id, arrow_cast(-1,'Int32')) OVER (ORDER BY id) as corrected from (values (1), (2)) as tbl(id) +----+---------+-----------+ | id | correct | corrected | +----+---------+-----------+ | 1 | 2 | 2 | | 2 | | | +----+---------+-----------+ SELECT id, nth_value(id, 2) OVER (ORDER BY id) AS correct, nth_value(id, arrow_cast(2,'Int32')) OVER (ORDER BY id) as corrected from (values (1), (2)) as tbl(id) +----+---------+-----------+ | id | correct | corrected | +----+---------+-----------+ | 1 | | | | 2 | 2 | 2 | +----+---------+-----------+ SELECT id, ntile(2) OVER (ORDER BY id) AS correct, ntile(arrow_cast(2,'Int32')) OVER (ORDER BY id) as corrected from (values (1), (2)) as tbl(id) +----+---------+-----------+ | id | correct | corrected | +----+---------+-----------+ | 1 | 1 | 1 | | 2 | 2 | 2 | +----+---------+-----------+ ``` * cleanup * make lead/lag throw if arg is invalid, check that the arg is int before casting, add tests * return unsigned handling to ntile and move tests to sqllogictests window.slt * remove unused import
1 parent c049a94 commit 6993561

File tree

2 files changed

+69
-12
lines changed

2 files changed

+69
-12
lines changed

datafusion/physical-plan/src/windows/mod.rs

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,24 @@ fn get_scalar_value_from_args(
220220
})
221221
}
222222

223+
/// Converts an integer-typed `ScalarValue` into an `i64`.
///
/// # Errors
///
/// Returns `DataFusionError::Execution("Expected an integer value")` when the
/// value's data type does not report `is_integer()`. Any error raised by the
/// cast to `DataType::Int64` or by the final `try_into` is propagated as-is.
fn get_signed_integer(value: ScalarValue) -> Result<i64> {
224+
// Reject non-integer inputs up front, before any cast is attempted.
if !value.data_type().is_integer() {
225+
return Err(DataFusionError::Execution(
226+
"Expected an integer value".to_string(),
227+
));
228+
}
229+
// Cast to Int64 first so the `try_into` conversion to i64 sees an Int64 scalar.
value.cast_to(&DataType::Int64)?.try_into()
230+
}
231+
232+
/// Converts an integer-typed `ScalarValue` into a `u64`.
///
/// # Errors
///
/// Returns `DataFusionError::Execution("Expected an integer value")` when the
/// value's data type does not report `is_integer()`. Any error raised by the
/// cast to `DataType::UInt64` or by the final `try_into` is propagated as-is.
fn get_unsigned_integer(value: ScalarValue) -> Result<u64> {
233+
// Reject non-integer inputs up front, before any cast is attempted.
if !value.data_type().is_integer() {
234+
return Err(DataFusionError::Execution(
235+
"Expected an integer value".to_string(),
236+
));
237+
}
238+
// Cast to UInt64 first so the `try_into` conversion to u64 sees a UInt64 scalar.
value.cast_to(&DataType::UInt64)?.try_into()
239+
}
240+
223241
fn get_casted_value(
224242
default_value: Option<ScalarValue>,
225243
dtype: &DataType,
@@ -259,10 +277,10 @@ fn create_built_in_window_expr(
259277
}
260278

261279
if n.is_unsigned() {
262-
let n: u64 = n.try_into()?;
280+
let n = get_unsigned_integer(n)?;
263281
Arc::new(Ntile::new(name, n, out_data_type))
264282
} else {
265-
let n: i64 = n.try_into()?;
283+
let n: i64 = get_signed_integer(n)?;
266284
if n <= 0 {
267285
return exec_err!("NTILE requires a positive integer");
268286
}
@@ -272,8 +290,8 @@ fn create_built_in_window_expr(
272290
BuiltInWindowFunction::Lag => {
273291
let arg = args[0].clone();
274292
let shift_offset = get_scalar_value_from_args(args, 1)?
275-
.map(|v| v.try_into())
276-
.and_then(|v| v.ok());
293+
.map(get_signed_integer)
294+
.map_or(Ok(None), |v| v.map(Some))?;
277295
let default_value =
278296
get_casted_value(get_scalar_value_from_args(args, 2)?, out_data_type)?;
279297
Arc::new(lag(
@@ -288,8 +306,8 @@ fn create_built_in_window_expr(
288306
BuiltInWindowFunction::Lead => {
289307
let arg = args[0].clone();
290308
let shift_offset = get_scalar_value_from_args(args, 1)?
291-
.map(|v| v.try_into())
292-
.and_then(|v| v.ok());
309+
.map(get_signed_integer)
310+
.map_or(Ok(None), |v| v.map(Some))?;
293311
let default_value =
294312
get_casted_value(get_scalar_value_from_args(args, 2)?, out_data_type)?;
295313
Arc::new(lead(
@@ -303,11 +321,14 @@ fn create_built_in_window_expr(
303321
}
304322
BuiltInWindowFunction::NthValue => {
305323
let arg = args[0].clone();
306-
let n = args[1].as_any().downcast_ref::<Literal>().unwrap().value();
307-
let n: i64 = n
308-
.clone()
309-
.try_into()
310-
.map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
324+
let n = get_signed_integer(
325+
args[1]
326+
.as_any()
327+
.downcast_ref::<Literal>()
328+
.unwrap()
329+
.value()
330+
.clone(),
331+
)?;
311332
Arc::new(NthValue::nth(
312333
name,
313334
arg,
@@ -618,7 +639,6 @@ mod tests {
618639

619640
use datafusion_functions_aggregate::count::count_udaf;
620641
use futures::FutureExt;
621-
622642
use InputOrderMode::{Linear, PartiallySorted, Sorted};
623643

624644
fn create_test_schema() -> Result<SchemaRef> {

datafusion/sqllogictest/test_files/window.slt

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4830,6 +4830,8 @@ NULL 3
48304830
NULL 2
48314831
NULL 1
48324832

4833+
statement ok
4834+
drop table t
48334835

48344836
### Test for window functions with arrays
48354837
statement ok
@@ -4852,3 +4854,38 @@ c [4, 5, 6] NULL
48524854

48534855
statement ok
48544856
drop table array_data
4857+
4858+
# Test for non-i64 offsets for NTILE, LAG, LEAD, NTH_VALUE
4859+
statement ok
4860+
CREATE TABLE t AS VALUES (3, 3), (4, 4), (5, 5), (6, 6);
4861+
4862+
query IIIIIIIII
4863+
SELECT
4864+
column1,
4865+
ntile(2) OVER (order by column1),
4866+
ntile(arrow_cast(2, 'Int32')) OVER (order by column1),
4867+
lag(column2, -1) OVER (order by column1),
4868+
lag(column2, arrow_cast(-1, 'Int32')) OVER (order by column1),
4869+
lead(column2, -1) OVER (order by column1),
4870+
lead(column2, arrow_cast(-1, 'Int32')) OVER (order by column1),
4871+
nth_value(column2, 2) OVER (order by column1),
4872+
nth_value(column2, arrow_cast(2, 'Int32')) OVER (order by column1)
4873+
FROM t;
4874+
----
4875+
3 1 1 4 4 NULL NULL NULL NULL
4876+
4 1 1 5 5 3 3 4 4
4877+
5 2 2 6 6 4 4 4 4
4878+
6 2 2 NULL NULL 5 5 4 4
4879+
4880+
# NTILE specifies the argument types so the error is different
4881+
query error
4882+
SELECT ntile(1.1) OVER (order by column1) FROM t;
4883+
4884+
query error DataFusion error: Execution error: Expected an integer value
4885+
SELECT lag(column2, 1.1) OVER (order by column1) FROM t;
4886+
4887+
query error DataFusion error: Execution error: Expected an integer value
4888+
SELECT lead(column2, 1.1) OVER (order by column1) FROM t;
4889+
4890+
query error DataFusion error: Execution error: Expected an integer value
4891+
SELECT nth_value(column2, 1.1) OVER (order by column1) FROM t;

0 commit comments

Comments
 (0)