diff --git a/arrow-ord/src/partition.rs b/arrow-ord/src/partition.rs
index ec1647393239..fa00edab69fd 100644
--- a/arrow-ord/src/partition.rs
+++ b/arrow-ord/src/partition.rs
@@ -21,9 +21,10 @@ use std::ops::Range;
 
 use arrow_array::{Array, ArrayRef};
 use arrow_buffer::BooleanBuffer;
-use arrow_schema::ArrowError;
+use arrow_schema::{ArrowError, SortOptions};
 
 use crate::cmp::distinct;
+use crate::ord::make_comparator;
 
 /// A computed set of partitions, see [`partition`]
 #[derive(Debug, Clone)]
@@ -156,18 +157,24 @@ fn find_boundaries(v: &dyn Array) -> Result<BooleanBuffer, ArrowError> {
     let slice_len = v.len() - 1;
     let v1 = v.slice(0, slice_len);
     let v2 = v.slice(1, slice_len);
-    Ok(distinct(&v1, &v2)?.values().clone())
+
+    if !v.data_type().is_nested() {
+        return Ok(distinct(&v1, &v2)?.values().clone());
+    }
+    // Given that we're only comparing values, null ordering in the input or
+    // sort options do not matter.
+    let cmp = make_comparator(&v1, &v2, SortOptions::default())?;
+    Ok((0..slice_len).map(|i| !cmp(i, i).is_eq()).collect())
 }
 
 #[cfg(test)]
 mod tests {
     use std::sync::Arc;
 
+    use super::*;
     use arrow_array::*;
     use arrow_schema::DataType;
 
-    use super::*;
-
     #[test]
     fn test_partition_empty() {
         let err = partition(&[]).unwrap_err();
@@ -298,4 +305,31 @@ mod tests {
             vec![(0..1), (1..2), (2..4), (4..5), (5..7), (7..8), (8..9)],
         );
     }
+
+    #[test]
+    fn test_partition_nested() {
+        let input = vec![
+            Arc::new(
+                StructArray::try_from(vec![(
+                    "f1",
+                    Arc::new(Int64Array::from(vec![
+                        None,
+                        None,
+                        Some(1),
+                        Some(2),
+                        Some(2),
+                        Some(2),
+                        Some(3),
+                        Some(4),
+                    ])) as _,
+                )])
+                .unwrap(),
+            ) as _,
+            Arc::new(Int64Array::from(vec![1, 1, 1, 2, 3, 3, 3, 4])) as _,
+        ];
+        assert_eq!(
+            partition(&input).unwrap().ranges(),
+            vec![0..2, 2..3, 3..4, 4..6, 6..7, 7..8]
+        )
+    }
 }