21
21
use crate :: window:: BuiltInWindowFunctionExpr ;
22
22
use crate :: PhysicalExpr ;
23
23
use arrow:: array:: ArrayRef ;
24
- use arrow:: compute:: cast;
25
24
use arrow:: datatypes:: { DataType , Field } ;
26
25
use arrow_array:: Array ;
27
26
use datafusion_common:: {
@@ -42,7 +41,7 @@ pub struct WindowShift {
42
41
data_type : DataType ,
43
42
shift_offset : i64 ,
44
43
expr : Arc < dyn PhysicalExpr > ,
45
- default_value : Option < ScalarValue > ,
44
+ default_value : ScalarValue ,
46
45
ignore_nulls : bool ,
47
46
}
48
47
@@ -53,7 +52,7 @@ impl WindowShift {
53
52
}
54
53
55
54
/// Get the default_value for window shift expression.
56
- pub fn get_default_value ( & self ) -> Option < ScalarValue > {
55
+ pub fn get_default_value ( & self ) -> ScalarValue {
57
56
self . default_value . clone ( )
58
57
}
59
58
}
@@ -64,7 +63,7 @@ pub fn lead(
64
63
data_type : DataType ,
65
64
expr : Arc < dyn PhysicalExpr > ,
66
65
shift_offset : Option < i64 > ,
67
- default_value : Option < ScalarValue > ,
66
+ default_value : ScalarValue ,
68
67
ignore_nulls : bool ,
69
68
) -> WindowShift {
70
69
WindowShift {
@@ -83,7 +82,7 @@ pub fn lag(
83
82
data_type : DataType ,
84
83
expr : Arc < dyn PhysicalExpr > ,
85
84
shift_offset : Option < i64 > ,
86
- default_value : Option < ScalarValue > ,
85
+ default_value : ScalarValue ,
87
86
ignore_nulls : bool ,
88
87
) -> WindowShift {
89
88
WindowShift {
@@ -139,7 +138,7 @@ impl BuiltInWindowFunctionExpr for WindowShift {
139
138
#[ derive( Debug ) ]
140
139
pub ( crate ) struct WindowShiftEvaluator {
141
140
shift_offset : i64 ,
142
- default_value : Option < ScalarValue > ,
141
+ default_value : ScalarValue ,
143
142
ignore_nulls : bool ,
144
143
// VecDeque contains offset values that between non-null entries
145
144
non_null_offsets : VecDeque < usize > ,
@@ -152,45 +151,28 @@ impl WindowShiftEvaluator {
152
151
}
153
152
}
154
153
155
- fn create_empty_array (
156
- value : Option < & ScalarValue > ,
157
- data_type : & DataType ,
158
- size : usize ,
159
- ) -> Result < ArrayRef > {
160
- use arrow:: array:: new_null_array;
161
- let array = value
162
- . as_ref ( )
163
- . map ( |scalar| scalar. to_array_of_size ( size) )
164
- . transpose ( ) ?
165
- . unwrap_or_else ( || new_null_array ( data_type, size) ) ;
166
- if array. data_type ( ) != data_type {
167
- cast ( & array, data_type) . map_err ( |e| arrow_datafusion_err ! ( e) )
168
- } else {
169
- Ok ( array)
170
- }
171
- }
172
-
173
154
// TODO: change the original arrow::compute::kernels::window::shift impl to support an optional default value
174
155
fn shift_with_default_value (
175
156
array : & ArrayRef ,
176
157
offset : i64 ,
177
- value : Option < & ScalarValue > ,
158
+ default_value : & ScalarValue ,
178
159
) -> Result < ArrayRef > {
179
160
use arrow:: compute:: concat;
180
161
181
162
let value_len = array. len ( ) as i64 ;
182
163
if offset == 0 {
183
164
Ok ( array. clone ( ) )
184
165
} else if offset == i64:: MIN || offset. abs ( ) >= value_len {
185
- create_empty_array ( value , array . data_type ( ) , array . len ( ) )
166
+ default_value . to_array_of_size ( value_len as usize )
186
167
} else {
187
168
let slice_offset = ( -offset) . clamp ( 0 , value_len) as usize ;
188
169
let length = array. len ( ) - offset. unsigned_abs ( ) as usize ;
189
170
let slice = array. slice ( slice_offset, length) ;
190
171
191
172
// Generate array with remaining `null` items
192
173
let nulls = offset. unsigned_abs ( ) as usize ;
193
- let default_values = create_empty_array ( value, slice. data_type ( ) , nulls) ?;
174
+ let default_values = default_value. to_array_of_size ( nulls) ?;
175
+
194
176
// Concatenate both arrays, add nulls after if shift > 0 else before
195
177
if offset > 0 {
196
178
concat ( & [ default_values. as_ref ( ) , slice. as_ref ( ) ] )
@@ -236,9 +218,7 @@ impl PartitionEvaluator for WindowShiftEvaluator {
236
218
values : & [ ArrayRef ] ,
237
219
range : & Range < usize > ,
238
220
) -> Result < ScalarValue > {
239
- // TODO: do not recalculate default value every call
240
221
let array = & values[ 0 ] ;
241
- let dtype = array. data_type ( ) ;
242
222
let len = array. len ( ) ;
243
223
244
224
// LAG mode
@@ -334,10 +314,10 @@ impl PartitionEvaluator for WindowShiftEvaluator {
334
314
// - ignore nulls mode and current value is null and is within window bounds
335
315
// .unwrap() is safe here as there is a none check in front
336
316
#[ allow( clippy:: unnecessary_unwrap) ]
337
- if idx. is_none ( ) || ( self . ignore_nulls && array. is_null ( idx. unwrap ( ) ) ) {
338
- get_default_value ( self . default_value . as_ref ( ) , dtype)
339
- } else {
317
+ if !( idx. is_none ( ) || ( self . ignore_nulls && array. is_null ( idx. unwrap ( ) ) ) ) {
340
318
ScalarValue :: try_from_array ( array, idx. unwrap ( ) )
319
+ } else {
320
+ Ok ( self . default_value . clone ( ) )
341
321
}
342
322
}
343
323
@@ -353,25 +333,14 @@ impl PartitionEvaluator for WindowShiftEvaluator {
353
333
}
354
334
// LEAD, LAG window functions take single column, values will have size 1
355
335
let value = & values[ 0 ] ;
356
- shift_with_default_value ( value, self . shift_offset , self . default_value . as_ref ( ) )
336
+ shift_with_default_value ( value, self . shift_offset , & self . default_value )
357
337
}
358
338
359
339
fn supports_bounded_execution ( & self ) -> bool {
360
340
true
361
341
}
362
342
}
363
343
364
- fn get_default_value (
365
- default_value : Option < & ScalarValue > ,
366
- dtype : & DataType ,
367
- ) -> Result < ScalarValue > {
368
- match default_value {
369
- Some ( v) if !v. data_type ( ) . is_null ( ) => v. cast_to ( dtype) ,
370
- // If None or Null datatype
371
- _ => ScalarValue :: try_from ( dtype) ,
372
- }
373
- }
374
-
375
344
#[ cfg( test) ]
376
345
mod tests {
377
346
use super :: * ;
@@ -400,10 +369,10 @@ mod tests {
400
369
test_i32_result (
401
370
lead (
402
371
"lead" . to_owned ( ) ,
403
- DataType :: Float32 ,
372
+ DataType :: Int32 ,
404
373
Arc :: new ( Column :: new ( "c3" , 0 ) ) ,
405
374
None ,
406
- None ,
375
+ ScalarValue :: Null . cast_to ( & DataType :: Int32 ) ? ,
407
376
false ,
408
377
) ,
409
378
[
@@ -423,10 +392,10 @@ mod tests {
423
392
test_i32_result (
424
393
lag (
425
394
"lead" . to_owned ( ) ,
426
- DataType :: Float32 ,
395
+ DataType :: Int32 ,
427
396
Arc :: new ( Column :: new ( "c3" , 0 ) ) ,
428
397
None ,
429
- None ,
398
+ ScalarValue :: Null . cast_to ( & DataType :: Int32 ) ? ,
430
399
false ,
431
400
) ,
432
401
[
@@ -449,7 +418,7 @@ mod tests {
449
418
DataType :: Int32 ,
450
419
Arc :: new ( Column :: new ( "c3" , 0 ) ) ,
451
420
None ,
452
- Some ( ScalarValue :: Int32 ( Some ( 100 ) ) ) ,
421
+ ScalarValue :: Int32 ( Some ( 100 ) ) ,
453
422
false ,
454
423
) ,
455
424
[
0 commit comments