Skip to content

Commit 62000b4

Browse files
authored
perf(array-agg): add fast path for array agg for merge_batch (#14299)
* perf(array-agg): add fast path for array agg for `merge_batch` * update comment * fix slice length * fix: make sure we are not inserting empty lists
1 parent 2a8b885 commit 62000b4

File tree

1 file changed

+76
-3
lines changed

1 file changed

+76
-3
lines changed

datafusion/functions-aggregate/src/array_agg.rs

Lines changed: 76 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
//! `ARRAY_AGG` aggregate implementation: [`ArrayAgg`]
1919

20-
use arrow::array::{new_empty_array, Array, ArrayRef, AsArray, StructArray};
20+
use arrow::array::{new_empty_array, Array, ArrayRef, AsArray, ListArray, StructArray};
2121
use arrow::datatypes::DataType;
2222

2323
use arrow_schema::{Field, Fields};
@@ -177,6 +177,67 @@ impl ArrayAggAccumulator {
177177
datatype: datatype.clone(),
178178
})
179179
}
180+
181+
/// This function will return the underlying list array values if all valid values are consecutive without gaps (i.e. no null value point to a non empty list)
182+
/// If there are gaps but only in the end of the list array, the function will return the values without the null values in the end
183+
fn get_optional_values_to_merge_as_is(list_array: &ListArray) -> Option<ArrayRef> {
184+
let offsets = list_array.value_offsets();
185+
// Offsets always have at least 1 value
186+
let initial_offset = offsets[0];
187+
let null_count = list_array.null_count();
188+
189+
// If no nulls than just use the fast path
190+
// This is ok as the state is a ListArray rather than a ListViewArray so all the values are consecutive
191+
if null_count == 0 {
192+
// According to Arrow specification, the first offset can be non-zero
193+
let list_values = list_array.values().slice(
194+
initial_offset as usize,
195+
(offsets[offsets.len() - 1] - initial_offset) as usize,
196+
);
197+
return Some(list_values);
198+
}
199+
200+
// If all the values are null than just return an empty values array
201+
if list_array.null_count() == list_array.len() {
202+
return Some(list_array.values().slice(0, 0));
203+
}
204+
205+
// According to the Arrow spec, null values can point to non empty lists
206+
// So this will check if all null values starting from the first valid value to the last one point to a 0 length list so we can just slice the underlying value
207+
208+
// Unwrapping is safe as we just checked if there is a null value
209+
let nulls = list_array.nulls().unwrap();
210+
211+
let mut valid_slices_iter = nulls.valid_slices();
212+
213+
// This is safe as we validated that that are at least 1 valid value in the array
214+
let (start, end) = valid_slices_iter.next().unwrap();
215+
216+
let start_offset = offsets[start];
217+
218+
// End is exclusive, so it already point to the last offset value
219+
// This is valid as the length of the array is always 1 less than the length of the offsets
220+
let mut end_offset_of_last_valid_value = offsets[end];
221+
222+
for (start, end) in valid_slices_iter {
223+
// If there is a null value that point to a non empty list than the start offset of the valid value
224+
// will be different that the end offset of the last valid value
225+
if offsets[start] != end_offset_of_last_valid_value {
226+
return None;
227+
}
228+
229+
// End is exclusive, so it already point to the last offset value
230+
// This is valid as the length of the array is always 1 less than the length of the offsets
231+
end_offset_of_last_valid_value = offsets[end];
232+
}
233+
234+
let consecutive_valid_values = list_array.values().slice(
235+
start_offset as usize,
236+
(end_offset_of_last_valid_value - start_offset) as usize,
237+
);
238+
239+
Some(consecutive_valid_values)
240+
}
180241
}
181242

182243
impl Accumulator for ArrayAggAccumulator {
@@ -208,9 +269,21 @@ impl Accumulator for ArrayAggAccumulator {
208269
}
209270

210271
let list_arr = as_list_array(&states[0])?;
211-
for arr in list_arr.iter().flatten() {
212-
self.values.push(arr);
272+
273+
match Self::get_optional_values_to_merge_as_is(list_arr) {
274+
Some(values) => {
275+
// Make sure we don't insert empty lists
276+
if values.len() > 0 {
277+
self.values.push(values);
278+
}
279+
}
280+
None => {
281+
for arr in list_arr.iter().flatten() {
282+
self.values.push(arr);
283+
}
284+
}
213285
}
286+
214287
Ok(())
215288
}
216289

0 commit comments

Comments
 (0)