Skip to content

Commit 019815a

Browse files
authored
fix: SparseArray#search_sorted properly handles fill values and behaves differently than Patches#search_sorted (#2581)
1 parent 36300ff commit 019815a

File tree

6 files changed

+251
-144
lines changed

6 files changed

+251
-144
lines changed

encodings/sparse/src/compute/mod.rs

+4-87
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use vortex_array::arrays::ConstantArray;
22
use vortex_array::compute::{
3-
BinaryNumericFn, FilterFn, InvertFn, ScalarAtFn, SearchResult, SearchSortedFn,
4-
SearchSortedSide, SearchSortedUsizeFn, SliceFn, TakeFn,
3+
BinaryNumericFn, FilterFn, InvertFn, ScalarAtFn, SearchSortedFn, SearchSortedUsizeFn, SliceFn,
4+
TakeFn,
55
};
66
use vortex_array::vtable::ComputeVTable;
77
use vortex_array::{Array, ArrayRef};
@@ -13,6 +13,7 @@ use crate::{SparseArray, SparseEncoding};
1313

1414
mod binary_numeric;
1515
mod invert;
16+
mod search_sorted;
1617
mod slice;
1718
mod take;
1819

@@ -59,34 +60,6 @@ impl ScalarAtFn<&SparseArray> for SparseEncoding {
5960
}
6061
}
6162

62-
// FIXME(ngates): these are broken in a way that works for array patches, this will be fixed soon.
63-
impl SearchSortedFn<&SparseArray> for SparseEncoding {
64-
fn search_sorted(
65-
&self,
66-
array: &SparseArray,
67-
value: &Scalar,
68-
side: SearchSortedSide,
69-
) -> VortexResult<SearchResult> {
70-
array.patches().search_sorted(value.clone(), side)
71-
}
72-
}
73-
74-
// FIXME(ngates): these are broken in a way that works for array patches, this will be fixed soon.
75-
impl SearchSortedUsizeFn<&SparseArray> for SparseEncoding {
76-
fn search_sorted_usize(
77-
&self,
78-
array: &SparseArray,
79-
value: usize,
80-
side: SearchSortedSide,
81-
) -> VortexResult<SearchResult> {
82-
let Ok(target) = Scalar::from(value).cast(array.dtype()) else {
83-
// If the downcast fails, then the target is too large for the dtype.
84-
return Ok(SearchResult::NotFound(array.len()));
85-
};
86-
SearchSortedFn::search_sorted(self, array, &target, side)
87-
}
88-
}
89-
9063
impl FilterFn<&SparseArray> for SparseEncoding {
9164
fn filter(&self, array: &SparseArray, mask: &Mask) -> VortexResult<ArrayRef> {
9265
let new_length = mask.true_count();
@@ -107,9 +80,7 @@ mod test {
10780
use rstest::{fixture, rstest};
10881
use vortex_array::arrays::PrimitiveArray;
10982
use vortex_array::compute::test_harness::{test_binary_numeric, test_mask};
110-
use vortex_array::compute::{
111-
SearchResult, SearchSortedSide, filter, search_sorted, slice, try_cast,
112-
};
83+
use vortex_array::compute::{filter, try_cast};
11384
use vortex_array::validity::Validity;
11485
use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
11586
use vortex_buffer::buffer;
@@ -131,60 +102,6 @@ mod test {
131102
.into_array()
132103
}
133104

134-
#[rstest]
135-
fn search_larger_than(array: ArrayRef) {
136-
let res = search_sorted(&array, 66, SearchSortedSide::Left).unwrap();
137-
assert_eq!(res, SearchResult::NotFound(16));
138-
}
139-
140-
#[rstest]
141-
fn search_less_than(array: ArrayRef) {
142-
let res = search_sorted(&array, 22, SearchSortedSide::Left).unwrap();
143-
assert_eq!(res, SearchResult::NotFound(2));
144-
}
145-
146-
#[rstest]
147-
fn search_found(array: ArrayRef) {
148-
let res = search_sorted(&array, 44, SearchSortedSide::Left).unwrap();
149-
assert_eq!(res, SearchResult::Found(9));
150-
}
151-
152-
#[rstest]
153-
fn search_not_found_right(array: ArrayRef) {
154-
let res = search_sorted(&array, 56, SearchSortedSide::Right).unwrap();
155-
assert_eq!(res, SearchResult::NotFound(16));
156-
}
157-
158-
#[rstest]
159-
fn search_sliced(array: ArrayRef) {
160-
let array = slice(&array, 7, 20).unwrap();
161-
assert_eq!(
162-
search_sorted(&array, 22, SearchSortedSide::Left).unwrap(),
163-
SearchResult::NotFound(2)
164-
);
165-
}
166-
167-
#[test]
168-
fn search_right() {
169-
let array = SparseArray::try_new(
170-
buffer![0u64].into_array(),
171-
PrimitiveArray::new(buffer![0u8], Validity::AllValid).into_array(),
172-
2,
173-
Scalar::null_typed::<u8>(),
174-
)
175-
.unwrap()
176-
.into_array();
177-
178-
assert_eq!(
179-
search_sorted(&array, 0, SearchSortedSide::Right).unwrap(),
180-
SearchResult::Found(1)
181-
);
182-
assert_eq!(
183-
search_sorted(&array, 1, SearchSortedSide::Right).unwrap(),
184-
SearchResult::NotFound(1)
185-
);
186-
}
187-
188105
#[rstest]
189106
fn test_filter(array: ArrayRef) {
190107
let mut predicate = vec![false, false, true];
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
use std::cmp::Ordering;
2+
3+
use vortex_array::Array;
4+
use vortex_array::compute::{SearchResult, SearchSortedFn, SearchSortedSide, SearchSortedUsizeFn};
5+
use vortex_error::{VortexExpect, VortexResult};
6+
use vortex_scalar::Scalar;
7+
8+
use crate::{SparseArray, SparseEncoding};
9+
10+
impl SearchSortedFn<&SparseArray> for SparseEncoding {
11+
fn search_sorted(
12+
&self,
13+
array: &SparseArray,
14+
value: &Scalar,
15+
side: SearchSortedSide,
16+
) -> VortexResult<SearchResult> {
17+
let min_index = array.patches().min_index()?;
18+
// For a sorted array the patches can be either at the beginning or at the end of the array
19+
if min_index == 0 {
20+
match value
21+
.partial_cmp(array.fill_scalar())
22+
.vortex_expect("value and fill scalar must have same dtype")
23+
{
24+
Ordering::Less => array.patches().search_sorted(value.clone(), side),
25+
Ordering::Equal => Ok(SearchResult::Found(if side == SearchSortedSide::Left {
26+
// In case of patches being at the beginning we want the first index after the end of patches
27+
array.patches().indices().len()
28+
} else {
29+
array.len()
30+
})),
31+
Ordering::Greater => Ok(SearchResult::NotFound(array.len())),
32+
}
33+
} else {
34+
match value
35+
.partial_cmp(array.fill_scalar())
36+
.vortex_expect("value and fill scalar must have same dtype")
37+
{
38+
Ordering::Less => Ok(SearchResult::NotFound(0)),
39+
Ordering::Equal => Ok(SearchResult::Found(if side == SearchSortedSide::Left {
40+
0
41+
} else {
42+
// Searching from right the min_index is one value after the last fill value
43+
min_index
44+
})),
45+
Ordering::Greater => array.patches().search_sorted(value.clone(), side),
46+
}
47+
}
48+
}
49+
}
50+
51+
impl SearchSortedUsizeFn<&SparseArray> for SparseEncoding {
52+
fn search_sorted_usize(
53+
&self,
54+
array: &SparseArray,
55+
value: usize,
56+
side: SearchSortedSide,
57+
) -> VortexResult<SearchResult> {
58+
let Ok(target) = Scalar::from(value).cast(array.dtype()) else {
59+
// If the downcast fails, then the target is too large for the dtype.
60+
return Ok(SearchResult::NotFound(array.len()));
61+
};
62+
SearchSortedFn::search_sorted(self, array, &target, side)
63+
}
64+
}
65+
66+
#[cfg(test)]
67+
mod tests {
68+
use rstest::rstest;
69+
use vortex_array::arrays::PrimitiveArray;
70+
use vortex_array::compute::{SearchResult, SearchSortedSide, search_sorted};
71+
use vortex_array::validity::Validity;
72+
use vortex_array::{Array, ArrayRef, IntoArray};
73+
use vortex_buffer::buffer;
74+
use vortex_dtype::Nullability;
75+
use vortex_scalar::Scalar;
76+
77+
use crate::SparseArray;
78+
79+
fn sparse_high_null_fill() -> ArrayRef {
80+
SparseArray::try_new(
81+
buffer![17u64, 18, 19].into_array(),
82+
PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array(),
83+
20,
84+
Scalar::null_typed::<i32>(),
85+
)
86+
.unwrap()
87+
.into_array()
88+
}
89+
90+
fn sparse_high_non_null_fill() -> ArrayRef {
91+
SparseArray::try_new(
92+
buffer![17u64, 18, 19].into_array(),
93+
buffer![33_i32, 44, 55].into_array(),
94+
20,
95+
Scalar::primitive(22, Nullability::NonNullable),
96+
)
97+
.unwrap()
98+
.into_array()
99+
}
100+
101+
fn sparse_low() -> ArrayRef {
102+
SparseArray::try_new(
103+
buffer![0u64, 1, 2].into_array(),
104+
buffer![33_i32, 44, 55].into_array(),
105+
20,
106+
Scalar::primitive(60, Nullability::NonNullable),
107+
)
108+
.unwrap()
109+
.into_array()
110+
}
111+
112+
#[rstest]
113+
#[case(sparse_high_null_fill(), SearchResult::NotFound(20))]
114+
#[case(sparse_high_non_null_fill(), SearchResult::NotFound(20))]
115+
#[case(sparse_low(), SearchResult::NotFound(20))]
116+
fn search_larger_than_left(#[case] array: ArrayRef, #[case] expected: SearchResult) {
117+
let res = search_sorted(&array, 66, SearchSortedSide::Left).unwrap();
118+
assert_eq!(res, expected);
119+
}
120+
121+
#[rstest]
122+
#[case(sparse_high_null_fill(), SearchResult::NotFound(20))]
123+
#[case(sparse_high_non_null_fill(), SearchResult::NotFound(20))]
124+
#[case(sparse_low(), SearchResult::NotFound(20))]
125+
fn search_larger_than_right(#[case] array: ArrayRef, #[case] expected: SearchResult) {
126+
let res = search_sorted(&array, 66, SearchSortedSide::Right).unwrap();
127+
assert_eq!(res, expected);
128+
}
129+
130+
#[rstest]
131+
#[case(sparse_high_null_fill(), SearchResult::NotFound(17))]
132+
#[case(sparse_high_non_null_fill(), SearchResult::NotFound(0))]
133+
#[case(sparse_low(), SearchResult::NotFound(0))]
134+
fn search_less_than_left(#[case] array: ArrayRef, #[case] expected: SearchResult) {
135+
let res = search_sorted(&array, 21, SearchSortedSide::Left).unwrap();
136+
assert_eq!(res, expected);
137+
}
138+
139+
#[rstest]
140+
#[case(sparse_high_null_fill(), SearchResult::NotFound(17))]
141+
#[case(sparse_high_non_null_fill(), SearchResult::NotFound(0))]
142+
#[case(sparse_low(), SearchResult::NotFound(0))]
143+
fn search_less_than_right(#[case] array: ArrayRef, #[case] expected: SearchResult) {
144+
let res = search_sorted(&array, 21, SearchSortedSide::Right).unwrap();
145+
assert_eq!(res, expected);
146+
}
147+
148+
#[rstest]
149+
#[case(sparse_high_null_fill(), SearchResult::Found(18))]
150+
#[case(sparse_high_non_null_fill(), SearchResult::Found(18))]
151+
#[case(sparse_low(), SearchResult::Found(1))]
152+
fn search_patches_found_left(#[case] array: ArrayRef, #[case] expected: SearchResult) {
153+
let res = search_sorted(&array, 44, SearchSortedSide::Left).unwrap();
154+
assert_eq!(res, expected);
155+
}
156+
157+
#[rstest]
158+
#[case(sparse_high_null_fill(), SearchResult::Found(19))]
159+
#[case(sparse_high_non_null_fill(), SearchResult::Found(19))]
160+
#[case(sparse_low(), SearchResult::Found(2))]
161+
fn search_patches_found_right(#[case] array: ArrayRef, #[case] expected: SearchResult) {
162+
let res = search_sorted(&array, 44, SearchSortedSide::Right).unwrap();
163+
assert_eq!(res, expected);
164+
}
165+
166+
#[rstest]
167+
#[should_panic]
168+
#[case(sparse_high_null_fill(), Scalar::null_typed::<i32>(), SearchResult::Found(18))]
169+
#[case(
170+
sparse_high_non_null_fill(),
171+
Scalar::primitive(22, Nullability::NonNullable),
172+
SearchResult::Found(0)
173+
)]
174+
#[case(
175+
sparse_low(),
176+
Scalar::primitive(60, Nullability::NonNullable),
177+
SearchResult::Found(3)
178+
)]
179+
fn search_fill_left(
180+
#[case] array: ArrayRef,
181+
#[case] search: Scalar,
182+
#[case] expected: SearchResult,
183+
) {
184+
let res = search_sorted(&array, search, SearchSortedSide::Left).unwrap();
185+
assert_eq!(res, expected);
186+
}
187+
188+
#[rstest]
189+
#[should_panic]
190+
#[case(sparse_high_null_fill(), Scalar::null_typed::<i32>(), SearchResult::Found(18))]
191+
#[case(
192+
sparse_high_non_null_fill(),
193+
Scalar::primitive(22, Nullability::NonNullable),
194+
SearchResult::Found(17)
195+
)]
196+
#[case(
197+
sparse_low(),
198+
Scalar::primitive(60, Nullability::NonNullable),
199+
SearchResult::Found(20)
200+
)]
201+
fn search_fill_right(
202+
#[case] array: ArrayRef,
203+
#[case] search: Scalar,
204+
#[case] expected: SearchResult,
205+
) {
206+
let res = search_sorted(&array, search, SearchSortedSide::Right).unwrap();
207+
assert_eq!(res, expected);
208+
}
209+
}

0 commit comments

Comments
 (0)