Skip to content

Commit 07a5285

Browse files
committed
arrow-ord: add support for partitioning nested types
This support is currently incorrectly assumed by `BoundedWindowAggExec`, so partitioning on a nested type (e.g. struct) causes a nested comparison failure on execution. This commit adds a check to use distinct on non-nested types and falls back to using make_comparator on nested types.
1 parent d4b9482 commit 07a5285

File tree

1 file changed

+29
-2
lines changed

1 file changed

+29
-2
lines changed

arrow-ord/src/partition.rs

+29-2
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@ use std::ops::Range;
2121

2222
use arrow_array::{Array, ArrayRef};
2323
use arrow_buffer::BooleanBuffer;
24-
use arrow_schema::ArrowError;
24+
use arrow_schema::{ArrowError, SortOptions};
2525

2626
use crate::cmp::distinct;
27+
use crate::ord::make_comparator;
2728

2829
/// A computed set of partitions, see [`partition`]
2930
#[derive(Debug, Clone)]
@@ -156,7 +157,14 @@ fn find_boundaries(v: &dyn Array) -> Result<BooleanBuffer, ArrowError> {
156157
let slice_len = v.len() - 1;
157158
let v1 = v.slice(0, slice_len);
158159
let v2 = v.slice(1, slice_len);
159-
Ok(distinct(&v1, &v2)?.values().clone())
160+
161+
if !v.data_type().is_nested() {
162+
return Ok(distinct(&v1, &v2)?.values().clone());
163+
}
164+
// Given that we're only comparing values, null ordering in the input or
165+
// sort options do not matter.
166+
let cmp = make_comparator(&v1, &v2, SortOptions::default())?;
167+
Ok((0..slice_len).map(|i| !cmp(i, i).is_eq()).collect())
160168
}
161169

162170
#[cfg(test)]
@@ -298,4 +306,23 @@ mod tests {
298306
vec![(0..1), (1..2), (2..4), (4..5), (5..7), (7..8), (8..9)],
299307
);
300308
}
309+
310+
#[test]
311+
fn test_partition_nested() {
312+
let ints: ArrayRef = Arc::new(Int64Array::from(vec![
313+
None,
314+
None,
315+
Some(1),
316+
Some(2),
317+
Some(2),
318+
Some(2),
319+
Some(3),
320+
Some(4),
321+
]));
322+
let input = vec![Arc::new(StructArray::try_from(vec![("f1", ints)]).unwrap()) as _];
323+
assert_eq!(
324+
partition(&input).unwrap().ranges(),
325+
vec![0..2, 2..3, 3..6, 6..7, 7..8]
326+
)
327+
}
301328
}

0 commit comments

Comments
 (0)