Skip to content

Commit 2757d7b

Browse files
authored
fix: SparseArray#search_sorted handles non null fill value that is between patch values (#2585)
Need to add some comments to the code since it took me a while to figure out this flow. The code before was naive and didn't acocunt for fill value that falls in between patch values which could still constitute a sorted sparse array, i.e. `[A, fill, fill, fill, B]`
1 parent 69d4746 commit 2757d7b

File tree

3 files changed

+214
-37
lines changed

3 files changed

+214
-37
lines changed

encodings/sparse/src/compute/search_sorted.rs

+208-28
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
use std::cmp::Ordering;
22

33
use vortex_array::Array;
4-
use vortex_array::compute::{SearchResult, SearchSortedFn, SearchSortedSide, SearchSortedUsizeFn};
4+
use vortex_array::compute::{
5+
SearchResult, SearchSortedFn, SearchSortedSide, SearchSortedUsizeFn, scalar_at,
6+
};
57
use vortex_error::{VortexExpect, VortexResult};
68
use vortex_scalar::Scalar;
79

@@ -14,40 +16,87 @@ impl SearchSortedFn<&SparseArray> for SparseEncoding {
1416
value: &Scalar,
1517
side: SearchSortedSide,
1618
) -> VortexResult<SearchResult> {
17-
let min_index = array.patches().min_index()?;
18-
// For a sorted array the patches can be either at the beginning or at the end of the array
19-
if min_index == 0 {
20-
match value
21-
.partial_cmp(array.fill_scalar())
22-
.vortex_expect("value and fill scalar must have same dtype")
23-
{
24-
Ordering::Less => array.patches().search_sorted(value.clone(), side),
25-
Ordering::Equal => Ok(SearchResult::Found(if side == SearchSortedSide::Left {
26-
// In case of patches being at the beginning we want the first index after the end of patches
27-
array.patches().indices().len()
19+
// first search result in patches
20+
let patches_result = array.patches().search_sorted(value.clone(), side)?;
21+
match patches_result {
22+
SearchResult::Found(i) => {
23+
if value == array.fill_scalar() {
24+
// Find the relevant position of the fill value in the patches
25+
let fill_index = fill_position(array, side)?;
26+
match side {
27+
SearchSortedSide::Left => Ok(SearchResult::Found(i.min(fill_index))),
28+
SearchSortedSide::Right => Ok(SearchResult::Found(i.max(fill_index))),
29+
}
2830
} else {
29-
array.len()
30-
})),
31-
Ordering::Greater => Ok(SearchResult::NotFound(array.len())),
31+
Ok(SearchResult::Found(i))
32+
}
3233
}
33-
} else {
34-
match value
35-
.partial_cmp(array.fill_scalar())
36-
.vortex_expect("value and fill scalar must have same dtype")
37-
{
38-
Ordering::Less => Ok(SearchResult::NotFound(0)),
39-
Ordering::Equal => Ok(SearchResult::Found(if side == SearchSortedSide::Left {
40-
0
41-
} else {
42-
// Searching from right the min_index is one value after the last fill value
43-
min_index
44-
})),
45-
Ordering::Greater => array.patches().search_sorted(value.clone(), side),
34+
SearchResult::NotFound(i) => {
35+
// Find the relevant position of the fill value in the patches
36+
let fill_index = fill_position(array, side)?;
37+
38+
// Adjust the position of the search value relative to the position of the fill value
39+
match value
40+
.partial_cmp(array.fill_scalar())
41+
.vortex_expect("value and fill scalar must have same dtype")
42+
{
43+
Ordering::Less => Ok(SearchResult::NotFound(i.min(fill_index))),
44+
Ordering::Equal => match side {
45+
SearchSortedSide::Left => Ok(SearchResult::Found(i.min(fill_index))),
46+
SearchSortedSide::Right => Ok(SearchResult::Found(i.max(fill_index))),
47+
},
48+
Ordering::Greater => Ok(SearchResult::NotFound(i.max(fill_index))),
49+
}
4650
}
4751
}
4852
}
4953
}
5054

55+
fn fill_position(array: &SparseArray, side: SearchSortedSide) -> VortexResult<usize> {
56+
// In not found case we need to find the relative position of fill value to the patches
57+
let fill_result = if array.fill_scalar().is_null() {
58+
// For null fill the patches can only ever be after the fill
59+
SearchResult::NotFound(array.patches().min_index()?)
60+
} else {
61+
array
62+
.patches()
63+
.search_sorted(array.fill_scalar().clone(), side)?
64+
};
65+
let fill_result_index = fill_result.to_index();
66+
// Find the relevant position of the fill value in the patches
67+
Ok(if fill_result_index <= array.patches().min_index()? {
68+
// [fill, ..., patch]
69+
0
70+
} else if fill_result_index > array.patches().max_index()? {
71+
// [patch, ..., fill]
72+
array.len()
73+
} else {
74+
// [patch, fill, ..., fill, patch]
75+
match side {
76+
SearchSortedSide::Left => fill_result_index,
77+
SearchSortedSide::Right => {
78+
// When searching from right we need to find the right most occurrence of our fill value. If fill value
79+
// is present in patches this would be the index of the next value after the fill value
80+
let fill_index = array.patches().search_index(fill_result_index)?.to_index();
81+
if fill_index < array.patches().num_patches() {
82+
// Since we are searching from right the fill_index is the index one after the found one
83+
let next_index =
84+
usize::try_from(&scalar_at(array.patches().indices(), fill_index)?)?;
85+
// fill value is dense with a next patch value we want to return the original fill_index,
86+
// i.e. the fill value cannot exist between fill_index and next_index
87+
if fill_index + 1 == next_index {
88+
fill_index
89+
} else {
90+
next_index
91+
}
92+
} else {
93+
fill_index
94+
}
95+
}
96+
}
97+
})
98+
}
99+
51100
impl SearchSortedUsizeFn<&SparseArray> for SparseEncoding {
52101
fn search_sorted_usize(
53102
&self,
@@ -109,10 +158,66 @@ mod tests {
109158
.into_array()
110159
}
111160

161+
fn sparse_low_high() -> ArrayRef {
162+
SparseArray::try_new(
163+
buffer![0u64, 1, 17, 18, 19].into_array(),
164+
buffer![11i32, 22, 33, 44, 55].into_array(),
165+
20,
166+
Scalar::primitive(30, Nullability::NonNullable),
167+
)
168+
.unwrap()
169+
.into_array()
170+
}
171+
172+
fn sparse_high_fill_in_patches() -> ArrayRef {
173+
SparseArray::try_new(
174+
buffer![17u64, 18, 19].into_array(),
175+
buffer![33_i32, 44, 55].into_array(),
176+
20,
177+
Scalar::primitive(33, Nullability::NonNullable),
178+
)
179+
.unwrap()
180+
.into_array()
181+
}
182+
183+
fn sparse_low_fill_in_patches() -> ArrayRef {
184+
SparseArray::try_new(
185+
buffer![0u64, 1, 2].into_array(),
186+
buffer![33_i32, 44, 55].into_array(),
187+
20,
188+
Scalar::primitive(55, Nullability::NonNullable),
189+
)
190+
.unwrap()
191+
.into_array()
192+
}
193+
194+
fn sparse_low_high_fill_in_patches_low() -> ArrayRef {
195+
SparseArray::try_new(
196+
buffer![0u64, 1, 17, 18, 19].into_array(),
197+
buffer![11i32, 22, 33, 44, 55].into_array(),
198+
20,
199+
Scalar::primitive(22, Nullability::NonNullable),
200+
)
201+
.unwrap()
202+
.into_array()
203+
}
204+
205+
fn sparse_low_high_fill_in_patches_high() -> ArrayRef {
206+
SparseArray::try_new(
207+
buffer![0u64, 1, 17, 18, 19].into_array(),
208+
buffer![11i32, 22, 33, 44, 55].into_array(),
209+
20,
210+
Scalar::primitive(33, Nullability::NonNullable),
211+
)
212+
.unwrap()
213+
.into_array()
214+
}
215+
112216
#[rstest]
113217
#[case(sparse_high_null_fill(), SearchResult::NotFound(20))]
114218
#[case(sparse_high_non_null_fill(), SearchResult::NotFound(20))]
115219
#[case(sparse_low(), SearchResult::NotFound(20))]
220+
#[case(sparse_low_high(), SearchResult::NotFound(20))]
116221
fn search_larger_than_left(#[case] array: ArrayRef, #[case] expected: SearchResult) {
117222
let res = search_sorted(&array, 66, SearchSortedSide::Left).unwrap();
118223
assert_eq!(res, expected);
@@ -122,6 +227,7 @@ mod tests {
122227
#[case(sparse_high_null_fill(), SearchResult::NotFound(20))]
123228
#[case(sparse_high_non_null_fill(), SearchResult::NotFound(20))]
124229
#[case(sparse_low(), SearchResult::NotFound(20))]
230+
#[case(sparse_low_high(), SearchResult::NotFound(20))]
125231
fn search_larger_than_right(#[case] array: ArrayRef, #[case] expected: SearchResult) {
126232
let res = search_sorted(&array, 66, SearchSortedSide::Right).unwrap();
127233
assert_eq!(res, expected);
@@ -131,6 +237,7 @@ mod tests {
131237
#[case(sparse_high_null_fill(), SearchResult::NotFound(17))]
132238
#[case(sparse_high_non_null_fill(), SearchResult::NotFound(0))]
133239
#[case(sparse_low(), SearchResult::NotFound(0))]
240+
#[case(sparse_low_high(), SearchResult::NotFound(1))]
134241
fn search_less_than_left(#[case] array: ArrayRef, #[case] expected: SearchResult) {
135242
let res = search_sorted(&array, 21, SearchSortedSide::Left).unwrap();
136243
assert_eq!(res, expected);
@@ -140,6 +247,7 @@ mod tests {
140247
#[case(sparse_high_null_fill(), SearchResult::NotFound(17))]
141248
#[case(sparse_high_non_null_fill(), SearchResult::NotFound(0))]
142249
#[case(sparse_low(), SearchResult::NotFound(0))]
250+
#[case(sparse_low_high(), SearchResult::NotFound(1))]
143251
fn search_less_than_right(#[case] array: ArrayRef, #[case] expected: SearchResult) {
144252
let res = search_sorted(&array, 21, SearchSortedSide::Right).unwrap();
145253
assert_eq!(res, expected);
@@ -149,6 +257,7 @@ mod tests {
149257
#[case(sparse_high_null_fill(), SearchResult::Found(18))]
150258
#[case(sparse_high_non_null_fill(), SearchResult::Found(18))]
151259
#[case(sparse_low(), SearchResult::Found(1))]
260+
#[case(sparse_low_high(), SearchResult::Found(18))]
152261
fn search_patches_found_left(#[case] array: ArrayRef, #[case] expected: SearchResult) {
153262
let res = search_sorted(&array, 44, SearchSortedSide::Left).unwrap();
154263
assert_eq!(res, expected);
@@ -158,11 +267,32 @@ mod tests {
158267
#[case(sparse_high_null_fill(), SearchResult::Found(19))]
159268
#[case(sparse_high_non_null_fill(), SearchResult::Found(19))]
160269
#[case(sparse_low(), SearchResult::Found(2))]
270+
#[case(sparse_low_high(), SearchResult::Found(19))]
161271
fn search_patches_found_right(#[case] array: ArrayRef, #[case] expected: SearchResult) {
162272
let res = search_sorted(&array, 44, SearchSortedSide::Right).unwrap();
163273
assert_eq!(res, expected);
164274
}
165275

276+
#[rstest]
277+
#[case(sparse_high_null_fill(), SearchResult::NotFound(19))]
278+
#[case(sparse_high_non_null_fill(), SearchResult::NotFound(19))]
279+
#[case(sparse_low(), SearchResult::NotFound(2))]
280+
#[case(sparse_low_high(), SearchResult::NotFound(19))]
281+
fn search_mid_patches_not_found_left(#[case] array: ArrayRef, #[case] expected: SearchResult) {
282+
let res = search_sorted(&array, 45, SearchSortedSide::Left).unwrap();
283+
assert_eq!(res, expected);
284+
}
285+
286+
#[rstest]
287+
#[case(sparse_high_null_fill(), SearchResult::NotFound(19))]
288+
#[case(sparse_high_non_null_fill(), SearchResult::NotFound(19))]
289+
#[case(sparse_low(), SearchResult::NotFound(2))]
290+
#[case(sparse_low_high(), SearchResult::NotFound(19))]
291+
fn search_mid_patches_not_found_right(#[case] array: ArrayRef, #[case] expected: SearchResult) {
292+
let res = search_sorted(&array, 45, SearchSortedSide::Right).unwrap();
293+
assert_eq!(res, expected);
294+
}
295+
166296
#[rstest]
167297
#[should_panic]
168298
#[case(sparse_high_null_fill(), Scalar::null_typed::<i32>(), SearchResult::Found(18))]
@@ -176,6 +306,31 @@ mod tests {
176306
Scalar::primitive(60, Nullability::NonNullable),
177307
SearchResult::Found(3)
178308
)]
309+
#[case(
310+
sparse_low_high(),
311+
Scalar::primitive(30, Nullability::NonNullable),
312+
SearchResult::Found(2)
313+
)]
314+
#[case(
315+
sparse_high_fill_in_patches(),
316+
Scalar::primitive(33, Nullability::NonNullable),
317+
SearchResult::Found(0)
318+
)]
319+
#[case(
320+
sparse_low_fill_in_patches(),
321+
Scalar::primitive(55, Nullability::NonNullable),
322+
SearchResult::Found(2)
323+
)]
324+
#[case(
325+
sparse_low_high_fill_in_patches_low(),
326+
Scalar::primitive(22, Nullability::NonNullable),
327+
SearchResult::Found(1)
328+
)]
329+
#[case(
330+
sparse_low_high_fill_in_patches_high(),
331+
Scalar::primitive(33, Nullability::NonNullable),
332+
SearchResult::Found(17)
333+
)]
179334
fn search_fill_left(
180335
#[case] array: ArrayRef,
181336
#[case] search: Scalar,
@@ -198,6 +353,31 @@ mod tests {
198353
Scalar::primitive(60, Nullability::NonNullable),
199354
SearchResult::Found(20)
200355
)]
356+
#[case(
357+
sparse_low_high(),
358+
Scalar::primitive(30, Nullability::NonNullable),
359+
SearchResult::Found(17)
360+
)]
361+
#[case(
362+
sparse_high_fill_in_patches(),
363+
Scalar::primitive(33, Nullability::NonNullable),
364+
SearchResult::Found(18)
365+
)]
366+
#[case(
367+
sparse_low_fill_in_patches(),
368+
Scalar::primitive(55, Nullability::NonNullable),
369+
SearchResult::Found(20)
370+
)]
371+
#[case(
372+
sparse_low_high_fill_in_patches_low(),
373+
Scalar::primitive(22, Nullability::NonNullable),
374+
SearchResult::Found(17)
375+
)]
376+
#[case(
377+
sparse_low_high_fill_in_patches_high(),
378+
Scalar::primitive(33, Nullability::NonNullable),
379+
SearchResult::Found(18)
380+
)]
201381
fn search_fill_right(
202382
#[case] array: ArrayRef,
203383
#[case] search: Scalar,

fuzz/fuzz_targets/array_ops.rs

+4-7
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@ fuzz_target!(|fuzz_action: FuzzArrayAction| -> Corpus {
2121
for (i, (action, expected)) in actions.into_iter().enumerate() {
2222
match action {
2323
Action::Compress(c) => {
24-
let compressed_array =
25-
fuzz_compress(current_array.to_canonical().unwrap().as_ref(), &c);
24+
let compressed_array = c
25+
.compress(current_array.to_canonical().unwrap().as_ref())
26+
.unwrap();
2627
assert_array_eq(&expected.array(), &compressed_array, i);
2728
current_array = compressed_array;
2829
}
@@ -50,7 +51,7 @@ fuzz_target!(|fuzz_action: FuzzArrayAction| -> Corpus {
5051
])
5152
.contains(&current_array.encoding())
5253
{
53-
sorted = fuzz_compress(&sorted, &BtrBlocksCompressor);
54+
sorted = BtrBlocksCompressor.compress(&sorted).unwrap();
5455
}
5556
assert_search_sorted(sorted, s, side, expected.search(), i)
5657
}
@@ -63,10 +64,6 @@ fuzz_target!(|fuzz_action: FuzzArrayAction| -> Corpus {
6364
Corpus::Keep
6465
});
6566

66-
fn fuzz_compress(array: &dyn Array, compressor: &BtrBlocksCompressor) -> ArrayRef {
67-
compressor.compress(array).unwrap()
68-
}
69-
7067
fn assert_search_sorted(
7168
array: ArrayRef,
7269
s: Scalar,

vortex-array/src/patches.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -217,8 +217,8 @@ impl Patches {
217217
}
218218
}
219219

220-
/// Return the insertion point of [index] in the [Self::indices].
221-
fn search_index(&self, index: usize) -> VortexResult<SearchResult> {
220+
/// Return the insertion point of `index` in the [Self::indices].
221+
pub fn search_index(&self, index: usize) -> VortexResult<SearchResult> {
222222
search_sorted_usize(&self.indices, index + self.offset, SearchSortedSide::Left)
223223
}
224224

0 commit comments

Comments
 (0)