Skip to content

Commit 0c0fce3

Browse files
authored
LEAD/LAG calculate default value once (#9485)
* LEAD/LAG calculate default value once * refmt
1 parent 37b7375 commit 0c0fce3

File tree

3 files changed

+38
-54
lines changed

3 files changed

+38
-54
lines changed

datafusion/physical-expr/src/window/lead_lag.rs

Lines changed: 18 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
use crate::window::BuiltInWindowFunctionExpr;
2222
use crate::PhysicalExpr;
2323
use arrow::array::ArrayRef;
24-
use arrow::compute::cast;
2524
use arrow::datatypes::{DataType, Field};
2625
use arrow_array::Array;
2726
use datafusion_common::{
@@ -42,7 +41,7 @@ pub struct WindowShift {
4241
data_type: DataType,
4342
shift_offset: i64,
4443
expr: Arc<dyn PhysicalExpr>,
45-
default_value: Option<ScalarValue>,
44+
default_value: ScalarValue,
4645
ignore_nulls: bool,
4746
}
4847

@@ -53,7 +52,7 @@ impl WindowShift {
5352
}
5453

5554
/// Get the default_value for window shift expression.
56-
pub fn get_default_value(&self) -> Option<ScalarValue> {
55+
pub fn get_default_value(&self) -> ScalarValue {
5756
self.default_value.clone()
5857
}
5958
}
@@ -64,7 +63,7 @@ pub fn lead(
6463
data_type: DataType,
6564
expr: Arc<dyn PhysicalExpr>,
6665
shift_offset: Option<i64>,
67-
default_value: Option<ScalarValue>,
66+
default_value: ScalarValue,
6867
ignore_nulls: bool,
6968
) -> WindowShift {
7069
WindowShift {
@@ -83,7 +82,7 @@ pub fn lag(
8382
data_type: DataType,
8483
expr: Arc<dyn PhysicalExpr>,
8584
shift_offset: Option<i64>,
86-
default_value: Option<ScalarValue>,
85+
default_value: ScalarValue,
8786
ignore_nulls: bool,
8887
) -> WindowShift {
8988
WindowShift {
@@ -139,7 +138,7 @@ impl BuiltInWindowFunctionExpr for WindowShift {
139138
#[derive(Debug)]
140139
pub(crate) struct WindowShiftEvaluator {
141140
shift_offset: i64,
142-
default_value: Option<ScalarValue>,
141+
default_value: ScalarValue,
143142
ignore_nulls: bool,
144143
// VecDeque contains offset values that between non-null entries
145144
non_null_offsets: VecDeque<usize>,
@@ -152,45 +151,28 @@ impl WindowShiftEvaluator {
152151
}
153152
}
154153

155-
fn create_empty_array(
156-
value: Option<&ScalarValue>,
157-
data_type: &DataType,
158-
size: usize,
159-
) -> Result<ArrayRef> {
160-
use arrow::array::new_null_array;
161-
let array = value
162-
.as_ref()
163-
.map(|scalar| scalar.to_array_of_size(size))
164-
.transpose()?
165-
.unwrap_or_else(|| new_null_array(data_type, size));
166-
if array.data_type() != data_type {
167-
cast(&array, data_type).map_err(|e| arrow_datafusion_err!(e))
168-
} else {
169-
Ok(array)
170-
}
171-
}
172-
173154
// TODO: change the original arrow::compute::kernels::window::shift impl to support an optional default value
174155
fn shift_with_default_value(
175156
array: &ArrayRef,
176157
offset: i64,
177-
value: Option<&ScalarValue>,
158+
default_value: &ScalarValue,
178159
) -> Result<ArrayRef> {
179160
use arrow::compute::concat;
180161

181162
let value_len = array.len() as i64;
182163
if offset == 0 {
183164
Ok(array.clone())
184165
} else if offset == i64::MIN || offset.abs() >= value_len {
185-
create_empty_array(value, array.data_type(), array.len())
166+
default_value.to_array_of_size(value_len as usize)
186167
} else {
187168
let slice_offset = (-offset).clamp(0, value_len) as usize;
188169
let length = array.len() - offset.unsigned_abs() as usize;
189170
let slice = array.slice(slice_offset, length);
190171

191172
// Generate array with remaining `null` items
192173
let nulls = offset.unsigned_abs() as usize;
193-
let default_values = create_empty_array(value, slice.data_type(), nulls)?;
174+
let default_values = default_value.to_array_of_size(nulls)?;
175+
194176
// Concatenate both arrays, add nulls after if shift > 0 else before
195177
if offset > 0 {
196178
concat(&[default_values.as_ref(), slice.as_ref()])
@@ -236,9 +218,7 @@ impl PartitionEvaluator for WindowShiftEvaluator {
236218
values: &[ArrayRef],
237219
range: &Range<usize>,
238220
) -> Result<ScalarValue> {
239-
// TODO: do not recalculate default value every call
240221
let array = &values[0];
241-
let dtype = array.data_type();
242222
let len = array.len();
243223

244224
// LAG mode
@@ -334,10 +314,10 @@ impl PartitionEvaluator for WindowShiftEvaluator {
334314
// - ignore nulls mode and current value is null and is within window bounds
335315
// .unwrap() is safe here as there is a none check in front
336316
#[allow(clippy::unnecessary_unwrap)]
337-
if idx.is_none() || (self.ignore_nulls && array.is_null(idx.unwrap())) {
338-
get_default_value(self.default_value.as_ref(), dtype)
339-
} else {
317+
if !(idx.is_none() || (self.ignore_nulls && array.is_null(idx.unwrap()))) {
340318
ScalarValue::try_from_array(array, idx.unwrap())
319+
} else {
320+
Ok(self.default_value.clone())
341321
}
342322
}
343323

@@ -353,25 +333,14 @@ impl PartitionEvaluator for WindowShiftEvaluator {
353333
}
354334
// LEAD, LAG window functions take single column, values will have size 1
355335
let value = &values[0];
356-
shift_with_default_value(value, self.shift_offset, self.default_value.as_ref())
336+
shift_with_default_value(value, self.shift_offset, &self.default_value)
357337
}
358338

359339
fn supports_bounded_execution(&self) -> bool {
360340
true
361341
}
362342
}
363343

364-
fn get_default_value(
365-
default_value: Option<&ScalarValue>,
366-
dtype: &DataType,
367-
) -> Result<ScalarValue> {
368-
match default_value {
369-
Some(v) if !v.data_type().is_null() => v.cast_to(dtype),
370-
// If None or Null datatype
371-
_ => ScalarValue::try_from(dtype),
372-
}
373-
}
374-
375344
#[cfg(test)]
376345
mod tests {
377346
use super::*;
@@ -400,10 +369,10 @@ mod tests {
400369
test_i32_result(
401370
lead(
402371
"lead".to_owned(),
403-
DataType::Float32,
372+
DataType::Int32,
404373
Arc::new(Column::new("c3", 0)),
405374
None,
406-
None,
375+
ScalarValue::Null.cast_to(&DataType::Int32)?,
407376
false,
408377
),
409378
[
@@ -423,10 +392,10 @@ mod tests {
423392
test_i32_result(
424393
lag(
425394
"lead".to_owned(),
426-
DataType::Float32,
395+
DataType::Int32,
427396
Arc::new(Column::new("c3", 0)),
428397
None,
429-
None,
398+
ScalarValue::Null.cast_to(&DataType::Int32)?,
430399
false,
431400
),
432401
[
@@ -449,7 +418,7 @@ mod tests {
449418
DataType::Int32,
450419
Arc::new(Column::new("c3", 0)),
451420
None,
452-
Some(ScalarValue::Int32(Some(100))),
421+
ScalarValue::Int32(Some(100)),
453422
false,
454423
),
455424
[

datafusion/physical-plan/src/windows/mod.rs

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,17 @@ fn get_scalar_value_from_args(
156156
})
157157
}
158158

159+
fn get_casted_value(
160+
default_value: Option<ScalarValue>,
161+
dtype: &DataType,
162+
) -> Result<ScalarValue> {
163+
match default_value {
164+
Some(v) if !v.data_type().is_null() => v.cast_to(dtype),
165+
// If None or Null datatype
166+
_ => ScalarValue::try_from(dtype),
167+
}
168+
}
169+
159170
fn create_built_in_window_expr(
160171
fun: &BuiltInWindowFunction,
161172
args: &[Arc<dyn PhysicalExpr>],
@@ -204,7 +215,8 @@ fn create_built_in_window_expr(
204215
let shift_offset = get_scalar_value_from_args(args, 1)?
205216
.map(|v| v.try_into())
206217
.and_then(|v| v.ok());
207-
let default_value = get_scalar_value_from_args(args, 2)?;
218+
let default_value =
219+
get_casted_value(get_scalar_value_from_args(args, 2)?, data_type)?;
208220
Arc::new(lag(
209221
name,
210222
data_type.clone(),
@@ -219,7 +231,8 @@ fn create_built_in_window_expr(
219231
let shift_offset = get_scalar_value_from_args(args, 1)?
220232
.map(|v| v.try_into())
221233
.and_then(|v| v.ok());
222-
let default_value = get_scalar_value_from_args(args, 2)?;
234+
let default_value =
235+
get_casted_value(get_scalar_value_from_args(args, 2)?, data_type)?;
223236
Arc::new(lead(
224237
name,
225238
data_type.clone(),

datafusion/proto/src/physical_plan/to_proto.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -167,9 +167,11 @@ impl TryFrom<Arc<dyn WindowExpr>> for protobuf::PhysicalWindowExprNode {
167167
window_shift_expr.get_shift_offset(),
168168
)))),
169169
);
170-
if let Some(default_value) = window_shift_expr.get_default_value() {
171-
args.insert(2, Arc::new(Literal::new(default_value)));
172-
}
170+
args.insert(
171+
2,
172+
Arc::new(Literal::new(window_shift_expr.get_default_value())),
173+
);
174+
173175
if window_shift_expr.get_shift_offset() >= 0 {
174176
protobuf::BuiltInWindowFunction::Lag
175177
} else {

0 commit comments

Comments
 (0)