@@ -165,8 +165,6 @@ struct FirstValueAccumulator {
165
165
orderings : Vec < ScalarValue > ,
166
166
// Stores the applicable ordering requirement.
167
167
ordering_req : LexOrdering ,
168
- // Whether merge_batch() is called before
169
- is_merge_called : bool ,
170
168
}
171
169
172
170
impl FirstValueAccumulator {
@@ -185,7 +183,6 @@ impl FirstValueAccumulator {
185
183
is_set : false ,
186
184
orderings,
187
185
ordering_req,
188
- is_merge_called : false ,
189
186
} )
190
187
}
191
188
@@ -201,9 +198,7 @@ impl Accumulator for FirstValueAccumulator {
201
198
fn state ( & self ) -> Result < Vec < ScalarValue > > {
202
199
let mut result = vec ! [ self . first. clone( ) ] ;
203
200
result. extend ( self . orderings . iter ( ) . cloned ( ) ) ;
204
- if !self . is_merge_called {
205
- result. push ( ScalarValue :: Boolean ( Some ( self . is_set ) ) ) ;
206
- }
201
+ result. push ( ScalarValue :: Boolean ( Some ( self . is_set ) ) ) ;
207
202
Ok ( result)
208
203
}
209
204
@@ -218,7 +213,6 @@ impl Accumulator for FirstValueAccumulator {
218
213
}
219
214
220
215
fn merge_batch ( & mut self , states : & [ ArrayRef ] ) -> Result < ( ) > {
221
- self . is_merge_called = true ;
222
216
// FIRST_VALUE(first1, first2, first3, ...)
223
217
// last index contains is_set flag.
224
218
let is_set_idx = states. len ( ) - 1 ;
@@ -237,13 +231,17 @@ impl Accumulator for FirstValueAccumulator {
237
231
} ;
238
232
if !ordered_states[ 0 ] . is_empty ( ) {
239
233
let first_row = get_row_at_idx ( & ordered_states, 0 ) ?;
240
- let first_ordering = & first_row[ 1 ..] ;
234
+ // When collecting orderings, we exclude the is_set flag from the state.
235
+ let first_ordering = & first_row[ 1 ..is_set_idx] ;
241
236
let sort_options = get_sort_options ( & self . ordering_req ) ;
242
237
// Either there is no existing value, or there is an earlier version in the new data.
243
238
if !self . is_set
244
239
|| compare_rows ( first_ordering, & self . orderings , & sort_options) ?. is_lt ( )
245
240
{
246
- self . update_with_new_row ( & first_row) ;
241
+ // Update with first value in the state. Note that we should exclude the
242
+ // is_set flag from the state. Otherwise, we will end up with a state
243
+ // containing two is_set flags.
244
+ self . update_with_new_row ( & first_row[ 0 ..is_set_idx] ) ;
247
245
}
248
246
}
249
247
Ok ( ( ) )
@@ -390,8 +388,6 @@ struct LastValueAccumulator {
390
388
orderings : Vec < ScalarValue > ,
391
389
// Stores the applicable ordering requirement.
392
390
ordering_req : LexOrdering ,
393
- // Whether merge_batch() is called before
394
- is_merge_called : bool ,
395
391
}
396
392
397
393
impl LastValueAccumulator {
@@ -410,7 +406,6 @@ impl LastValueAccumulator {
410
406
is_set : false ,
411
407
orderings,
412
408
ordering_req,
413
- is_merge_called : false ,
414
409
} )
415
410
}
416
411
@@ -426,9 +421,7 @@ impl Accumulator for LastValueAccumulator {
426
421
fn state ( & self ) -> Result < Vec < ScalarValue > > {
427
422
let mut result = vec ! [ self . last. clone( ) ] ;
428
423
result. extend ( self . orderings . clone ( ) ) ;
429
- if !self . is_merge_called {
430
- result. push ( ScalarValue :: Boolean ( Some ( self . is_set ) ) ) ;
431
- }
424
+ result. push ( ScalarValue :: Boolean ( Some ( self . is_set ) ) ) ;
432
425
Ok ( result)
433
426
}
434
427
@@ -442,7 +435,6 @@ impl Accumulator for LastValueAccumulator {
442
435
}
443
436
444
437
fn merge_batch ( & mut self , states : & [ ArrayRef ] ) -> Result < ( ) > {
445
- self . is_merge_called = true ;
446
438
// LAST_VALUE(last1, last2, last3, ...)
447
439
// last index contains is_set flag.
448
440
let is_set_idx = states. len ( ) - 1 ;
@@ -463,14 +455,18 @@ impl Accumulator for LastValueAccumulator {
463
455
if !ordered_states[ 0 ] . is_empty ( ) {
464
456
let last_idx = ordered_states[ 0 ] . len ( ) - 1 ;
465
457
let last_row = get_row_at_idx ( & ordered_states, last_idx) ?;
466
- let last_ordering = & last_row[ 1 ..] ;
458
+ // When collecting orderings, we exclude the is_set flag from the state.
459
+ let last_ordering = & last_row[ 1 ..is_set_idx] ;
467
460
let sort_options = get_sort_options ( & self . ordering_req ) ;
468
461
// Either there is no existing value, or there is a newer (latest)
469
462
// version in the new data:
470
463
if !self . is_set
471
464
|| compare_rows ( last_ordering, & self . orderings , & sort_options) ?. is_gt ( )
472
465
{
473
- self . update_with_new_row ( & last_row) ;
466
+ // Update with last value in the state. Note that we should exclude the
467
+ // is_set flag from the state. Otherwise, we will end up with a state
468
+ // containing two is_set flags.
469
+ self . update_with_new_row ( & last_row[ 0 ..is_set_idx] ) ;
474
470
}
475
471
}
476
472
Ok ( ( ) )
@@ -531,6 +527,7 @@ mod tests {
531
527
use datafusion_common:: { Result , ScalarValue } ;
532
528
use datafusion_expr:: Accumulator ;
533
529
530
+ use arrow:: compute:: concat;
534
531
use std:: sync:: Arc ;
535
532
536
533
#[ test]
@@ -562,4 +559,72 @@ mod tests {
562
559
assert_eq ! ( last_accumulator. evaluate( ) ?, ScalarValue :: Int64 ( Some ( 12 ) ) ) ;
563
560
Ok ( ( ) )
564
561
}
562
+
563
+ #[ test]
564
+ fn test_first_last_state_after_merge ( ) -> Result < ( ) > {
565
+ let ranges: Vec < ( i64 , i64 ) > = vec ! [ ( 0 , 10 ) , ( 1 , 11 ) , ( 2 , 13 ) ] ;
566
+ // Create one ArrayRef per (start, end) range; the end is exclusive, so the arrays hold 0 to 9, 1 to 10, and 2 to 12.
567
+ let arrs = ranges
568
+ . into_iter ( )
569
+ . map ( |( start, end) | {
570
+ Arc :: new ( ( start..end) . collect :: < Int64Array > ( ) ) as ArrayRef
571
+ } )
572
+ . collect :: < Vec < _ > > ( ) ;
573
+
574
+ // FirstValueAccumulator
575
+ let mut first_accumulator =
576
+ FirstValueAccumulator :: try_new ( & DataType :: Int64 , & [ ] , vec ! [ ] ) ?;
577
+
578
+ first_accumulator. update_batch ( & [ arrs[ 0 ] . clone ( ) ] ) ?;
579
+ let state1 = first_accumulator. state ( ) ?;
580
+
581
+ let mut first_accumulator =
582
+ FirstValueAccumulator :: try_new ( & DataType :: Int64 , & [ ] , vec ! [ ] ) ?;
583
+ first_accumulator. update_batch ( & [ arrs[ 1 ] . clone ( ) ] ) ?;
584
+ let state2 = first_accumulator. state ( ) ?;
585
+
586
+ assert_eq ! ( state1. len( ) , state2. len( ) ) ;
587
+
588
+ let mut states = vec ! [ ] ;
589
+
590
+ for idx in 0 ..state1. len ( ) {
591
+ states. push ( concat ( & [ & state1[ idx] . to_array ( ) , & state2[ idx] . to_array ( ) ] ) ?) ;
592
+ }
593
+
594
+ let mut first_accumulator =
595
+ FirstValueAccumulator :: try_new ( & DataType :: Int64 , & [ ] , vec ! [ ] ) ?;
596
+ first_accumulator. merge_batch ( & states) ?;
597
+
598
+ let merged_state = first_accumulator. state ( ) ?;
599
+ assert_eq ! ( merged_state. len( ) , state1. len( ) ) ;
600
+
601
+ // LastValueAccumulator
602
+ let mut last_accumulator =
603
+ LastValueAccumulator :: try_new ( & DataType :: Int64 , & [ ] , vec ! [ ] ) ?;
604
+
605
+ last_accumulator. update_batch ( & [ arrs[ 0 ] . clone ( ) ] ) ?;
606
+ let state1 = last_accumulator. state ( ) ?;
607
+
608
+ let mut last_accumulator =
609
+ LastValueAccumulator :: try_new ( & DataType :: Int64 , & [ ] , vec ! [ ] ) ?;
610
+ last_accumulator. update_batch ( & [ arrs[ 1 ] . clone ( ) ] ) ?;
611
+ let state2 = last_accumulator. state ( ) ?;
612
+
613
+ assert_eq ! ( state1. len( ) , state2. len( ) ) ;
614
+
615
+ let mut states = vec ! [ ] ;
616
+
617
+ for idx in 0 ..state1. len ( ) {
618
+ states. push ( concat ( & [ & state1[ idx] . to_array ( ) , & state2[ idx] . to_array ( ) ] ) ?) ;
619
+ }
620
+
621
+ let mut last_accumulator =
622
+ LastValueAccumulator :: try_new ( & DataType :: Int64 , & [ ] , vec ! [ ] ) ?;
623
+ last_accumulator. merge_batch ( & states) ?;
624
+
625
+ let merged_state = last_accumulator. state ( ) ?;
626
+ assert_eq ! ( merged_state. len( ) , state1. len( ) ) ;
627
+
628
+ Ok ( ( ) )
629
+ }
565
630
}
0 commit comments