Skip to content

Commit 61ed374

Browse files
viiryaalamb
andauthored
First and Last Accumulators should update with state row excluding is_set flag (#7565)
* First and Last Accumulators should update with state row excluding is_set flag * Add test * Update datafusion/physical-expr/src/aggregate/first_last.rs Co-authored-by: Andrew Lamb <[email protected]> * Update datafusion/physical-expr/src/aggregate/first_last.rs * Remove --------- Co-authored-by: Andrew Lamb <[email protected]>
1 parent 7b12666 commit 61ed374

File tree

1 file changed

+83
-18
lines changed

1 file changed

+83
-18
lines changed

datafusion/physical-expr/src/aggregate/first_last.rs

Lines changed: 83 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -165,8 +165,6 @@ struct FirstValueAccumulator {
165165
orderings: Vec<ScalarValue>,
166166
// Stores the applicable ordering requirement.
167167
ordering_req: LexOrdering,
168-
// Whether merge_batch() is called before
169-
is_merge_called: bool,
170168
}
171169

172170
impl FirstValueAccumulator {
@@ -185,7 +183,6 @@ impl FirstValueAccumulator {
185183
is_set: false,
186184
orderings,
187185
ordering_req,
188-
is_merge_called: false,
189186
})
190187
}
191188

@@ -201,9 +198,7 @@ impl Accumulator for FirstValueAccumulator {
201198
fn state(&self) -> Result<Vec<ScalarValue>> {
202199
let mut result = vec![self.first.clone()];
203200
result.extend(self.orderings.iter().cloned());
204-
if !self.is_merge_called {
205-
result.push(ScalarValue::Boolean(Some(self.is_set)));
206-
}
201+
result.push(ScalarValue::Boolean(Some(self.is_set)));
207202
Ok(result)
208203
}
209204

@@ -218,7 +213,6 @@ impl Accumulator for FirstValueAccumulator {
218213
}
219214

220215
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
221-
self.is_merge_called = true;
222216
// FIRST_VALUE(first1, first2, first3, ...)
223217
// last index contains is_set flag.
224218
let is_set_idx = states.len() - 1;
@@ -237,13 +231,17 @@ impl Accumulator for FirstValueAccumulator {
237231
};
238232
if !ordered_states[0].is_empty() {
239233
let first_row = get_row_at_idx(&ordered_states, 0)?;
240-
let first_ordering = &first_row[1..];
234+
// When collecting orderings, we exclude the is_set flag from the state.
235+
let first_ordering = &first_row[1..is_set_idx];
241236
let sort_options = get_sort_options(&self.ordering_req);
242237
// Either there is no existing value, or there is an earlier version in new data.
243238
if !self.is_set
244239
|| compare_rows(first_ordering, &self.orderings, &sort_options)?.is_lt()
245240
{
246-
self.update_with_new_row(&first_row);
241+
// Update with first value in the state. Note that we should exclude the
242+
// is_set flag from the state. Otherwise, we will end up with a state
243+
// containing two is_set flags.
244+
self.update_with_new_row(&first_row[0..is_set_idx]);
247245
}
248246
}
249247
Ok(())
@@ -390,8 +388,6 @@ struct LastValueAccumulator {
390388
orderings: Vec<ScalarValue>,
391389
// Stores the applicable ordering requirement.
392390
ordering_req: LexOrdering,
393-
// Whether merge_batch() is called before
394-
is_merge_called: bool,
395391
}
396392

397393
impl LastValueAccumulator {
@@ -410,7 +406,6 @@ impl LastValueAccumulator {
410406
is_set: false,
411407
orderings,
412408
ordering_req,
413-
is_merge_called: false,
414409
})
415410
}
416411

@@ -426,9 +421,7 @@ impl Accumulator for LastValueAccumulator {
426421
fn state(&self) -> Result<Vec<ScalarValue>> {
427422
let mut result = vec![self.last.clone()];
428423
result.extend(self.orderings.clone());
429-
if !self.is_merge_called {
430-
result.push(ScalarValue::Boolean(Some(self.is_set)));
431-
}
424+
result.push(ScalarValue::Boolean(Some(self.is_set)));
432425
Ok(result)
433426
}
434427

@@ -442,7 +435,6 @@ impl Accumulator for LastValueAccumulator {
442435
}
443436

444437
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
445-
self.is_merge_called = true;
446438
// LAST_VALUE(last1, last2, last3, ...)
447439
// last index contains is_set flag.
448440
let is_set_idx = states.len() - 1;
@@ -463,14 +455,18 @@ impl Accumulator for LastValueAccumulator {
463455
if !ordered_states[0].is_empty() {
464456
let last_idx = ordered_states[0].len() - 1;
465457
let last_row = get_row_at_idx(&ordered_states, last_idx)?;
466-
let last_ordering = &last_row[1..];
458+
// When collecting orderings, we exclude the is_set flag from the state.
459+
let last_ordering = &last_row[1..is_set_idx];
467460
let sort_options = get_sort_options(&self.ordering_req);
468461
// Either there is no existing value, or there is a newer (latest)
469462
// version in the new data:
470463
if !self.is_set
471464
|| compare_rows(last_ordering, &self.orderings, &sort_options)?.is_gt()
472465
{
473-
self.update_with_new_row(&last_row);
466+
// Update with last value in the state. Note that we should exclude the
467+
// is_set flag from the state. Otherwise, we will end up with a state
468+
// containing two is_set flags.
469+
self.update_with_new_row(&last_row[0..is_set_idx]);
474470
}
475471
}
476472
Ok(())
@@ -531,6 +527,7 @@ mod tests {
531527
use datafusion_common::{Result, ScalarValue};
532528
use datafusion_expr::Accumulator;
533529

530+
use arrow::compute::concat;
534531
use std::sync::Arc;
535532

536533
#[test]
@@ -562,4 +559,72 @@ mod tests {
562559
assert_eq!(last_accumulator.evaluate()?, ScalarValue::Int64(Some(12)));
563560
Ok(())
564561
}
562+
563+
#[test]
564+
fn test_first_last_state_after_merge() -> Result<()> {
565+
let ranges: Vec<(i64, i64)> = vec![(0, 10), (1, 11), (2, 13)];
566+
// create 3 ArrayRefs between each interval e.g from 0 to 9, 1 to 10, 2 to 12
567+
let arrs = ranges
568+
.into_iter()
569+
.map(|(start, end)| {
570+
Arc::new((start..end).collect::<Int64Array>()) as ArrayRef
571+
})
572+
.collect::<Vec<_>>();
573+
574+
// FirstValueAccumulator
575+
let mut first_accumulator =
576+
FirstValueAccumulator::try_new(&DataType::Int64, &[], vec![])?;
577+
578+
first_accumulator.update_batch(&[arrs[0].clone()])?;
579+
let state1 = first_accumulator.state()?;
580+
581+
let mut first_accumulator =
582+
FirstValueAccumulator::try_new(&DataType::Int64, &[], vec![])?;
583+
first_accumulator.update_batch(&[arrs[1].clone()])?;
584+
let state2 = first_accumulator.state()?;
585+
586+
assert_eq!(state1.len(), state2.len());
587+
588+
let mut states = vec![];
589+
590+
for idx in 0..state1.len() {
591+
states.push(concat(&[&state1[idx].to_array(), &state2[idx].to_array()])?);
592+
}
593+
594+
let mut first_accumulator =
595+
FirstValueAccumulator::try_new(&DataType::Int64, &[], vec![])?;
596+
first_accumulator.merge_batch(&states)?;
597+
598+
let merged_state = first_accumulator.state()?;
599+
assert_eq!(merged_state.len(), state1.len());
600+
601+
// LastValueAccumulator
602+
let mut last_accumulator =
603+
LastValueAccumulator::try_new(&DataType::Int64, &[], vec![])?;
604+
605+
last_accumulator.update_batch(&[arrs[0].clone()])?;
606+
let state1 = last_accumulator.state()?;
607+
608+
let mut last_accumulator =
609+
LastValueAccumulator::try_new(&DataType::Int64, &[], vec![])?;
610+
last_accumulator.update_batch(&[arrs[1].clone()])?;
611+
let state2 = last_accumulator.state()?;
612+
613+
assert_eq!(state1.len(), state2.len());
614+
615+
let mut states = vec![];
616+
617+
for idx in 0..state1.len() {
618+
states.push(concat(&[&state1[idx].to_array(), &state2[idx].to_array()])?);
619+
}
620+
621+
let mut last_accumulator =
622+
LastValueAccumulator::try_new(&DataType::Int64, &[], vec![])?;
623+
last_accumulator.merge_batch(&states)?;
624+
625+
let merged_state = last_accumulator.state()?;
626+
assert_eq!(merged_state.len(), state1.len());
627+
628+
Ok(())
629+
}
565630
}

0 commit comments

Comments
 (0)