1
1
use std:: cmp:: Ordering ;
2
2
3
3
use vortex_array:: Array ;
4
- use vortex_array:: compute:: { SearchResult , SearchSortedFn , SearchSortedSide , SearchSortedUsizeFn } ;
4
+ use vortex_array:: compute:: {
5
+ SearchResult , SearchSortedFn , SearchSortedSide , SearchSortedUsizeFn , scalar_at,
6
+ } ;
5
7
use vortex_error:: { VortexExpect , VortexResult } ;
6
8
use vortex_scalar:: Scalar ;
7
9
@@ -14,40 +16,87 @@ impl SearchSortedFn<&SparseArray> for SparseEncoding {
14
16
value : & Scalar ,
15
17
side : SearchSortedSide ,
16
18
) -> VortexResult < SearchResult > {
17
- let min_index = array . patches ( ) . min_index ( ) ? ;
18
- // For a sorted array the patches can be either at the beginning or at the end of the array
19
- if min_index == 0 {
20
- match value
21
- . partial_cmp ( array. fill_scalar ( ) )
22
- . vortex_expect ( "value and fill scalar must have same dtype" )
23
- {
24
- Ordering :: Less => array . patches ( ) . search_sorted ( value . clone ( ) , side) ,
25
- Ordering :: Equal => Ok ( SearchResult :: Found ( if side == SearchSortedSide :: Left {
26
- // In case of patches being at the beginning we want the first index after the end of patches
27
- array . patches ( ) . indices ( ) . len ( )
19
+ // first search result in patches
20
+ let patches_result = array. patches ( ) . search_sorted ( value . clone ( ) , side ) ? ;
21
+ match patches_result {
22
+ SearchResult :: Found ( i ) => {
23
+ if value == array. fill_scalar ( ) {
24
+ // Find the relevant position of the fill value in the patches
25
+ let fill_index = fill_position ( array , side ) ? ;
26
+ match side {
27
+ SearchSortedSide :: Left => Ok ( SearchResult :: Found ( i . min ( fill_index ) ) ) ,
28
+ SearchSortedSide :: Right => Ok ( SearchResult :: Found ( i . max ( fill_index ) ) ) ,
29
+ }
28
30
} else {
29
- array. len ( )
30
- } ) ) ,
31
- Ordering :: Greater => Ok ( SearchResult :: NotFound ( array. len ( ) ) ) ,
31
+ Ok ( SearchResult :: Found ( i) )
32
+ }
32
33
}
33
- } else {
34
- match value
35
- . partial_cmp ( array. fill_scalar ( ) )
36
- . vortex_expect ( "value and fill scalar must have same dtype" )
37
- {
38
- Ordering :: Less => Ok ( SearchResult :: NotFound ( 0 ) ) ,
39
- Ordering :: Equal => Ok ( SearchResult :: Found ( if side == SearchSortedSide :: Left {
40
- 0
41
- } else {
42
- // Searching from right the min_index is one value after the last fill value
43
- min_index
44
- } ) ) ,
45
- Ordering :: Greater => array. patches ( ) . search_sorted ( value. clone ( ) , side) ,
34
+ SearchResult :: NotFound ( i) => {
35
+ // Find the relevant position of the fill value in the patches
36
+ let fill_index = fill_position ( array, side) ?;
37
+
38
+ // Adjust the position of the search value relative to the position of the fill value
39
+ match value
40
+ . partial_cmp ( array. fill_scalar ( ) )
41
+ . vortex_expect ( "value and fill scalar must have same dtype" )
42
+ {
43
+ Ordering :: Less => Ok ( SearchResult :: NotFound ( i. min ( fill_index) ) ) ,
44
+ Ordering :: Equal => match side {
45
+ SearchSortedSide :: Left => Ok ( SearchResult :: Found ( i. min ( fill_index) ) ) ,
46
+ SearchSortedSide :: Right => Ok ( SearchResult :: Found ( i. max ( fill_index) ) ) ,
47
+ } ,
48
+ Ordering :: Greater => Ok ( SearchResult :: NotFound ( i. max ( fill_index) ) ) ,
49
+ }
46
50
}
47
51
}
48
52
}
49
53
}
50
54
55
+ fn fill_position ( array : & SparseArray , side : SearchSortedSide ) -> VortexResult < usize > {
56
+ // In not found case we need to find the relative position of fill value to the patches
57
+ let fill_result = if array. fill_scalar ( ) . is_null ( ) {
58
+ // For null fill the patches can only ever be after the fill
59
+ SearchResult :: NotFound ( array. patches ( ) . min_index ( ) ?)
60
+ } else {
61
+ array
62
+ . patches ( )
63
+ . search_sorted ( array. fill_scalar ( ) . clone ( ) , side) ?
64
+ } ;
65
+ let fill_result_index = fill_result. to_index ( ) ;
66
+ // Find the relevant position of the fill value in the patches
67
+ Ok ( if fill_result_index <= array. patches ( ) . min_index ( ) ? {
68
+ // [fill, ..., patch]
69
+ 0
70
+ } else if fill_result_index > array. patches ( ) . max_index ( ) ? {
71
+ // [patch, ..., fill]
72
+ array. len ( )
73
+ } else {
74
+ // [patch, fill, ..., fill, patch]
75
+ match side {
76
+ SearchSortedSide :: Left => fill_result_index,
77
+ SearchSortedSide :: Right => {
78
+ // When searching from right we need to find the right most occurrence of our fill value. If fill value
79
+ // is present in patches this would be the index of the next value after the fill value
80
+ let fill_index = array. patches ( ) . search_index ( fill_result_index) ?. to_index ( ) ;
81
+ if fill_index < array. patches ( ) . num_patches ( ) {
82
+ // Since we are searching from right the fill_index is the index one after the found one
83
+ let next_index =
84
+ usize:: try_from ( & scalar_at ( array. patches ( ) . indices ( ) , fill_index) ?) ?;
85
+ // fill value is dense with a next patch value we want to return the original fill_index,
86
+ // i.e. the fill value cannot exist between fill_index and next_index
87
+ if fill_index + 1 == next_index {
88
+ fill_index
89
+ } else {
90
+ next_index
91
+ }
92
+ } else {
93
+ fill_index
94
+ }
95
+ }
96
+ }
97
+ } )
98
+ }
99
+
51
100
impl SearchSortedUsizeFn < & SparseArray > for SparseEncoding {
52
101
fn search_sorted_usize (
53
102
& self ,
@@ -109,10 +158,66 @@ mod tests {
109
158
. into_array ( )
110
159
}
111
160
161
+ fn sparse_low_high ( ) -> ArrayRef {
162
+ SparseArray :: try_new (
163
+ buffer ! [ 0u64 , 1 , 17 , 18 , 19 ] . into_array ( ) ,
164
+ buffer ! [ 11i32 , 22 , 33 , 44 , 55 ] . into_array ( ) ,
165
+ 20 ,
166
+ Scalar :: primitive ( 30 , Nullability :: NonNullable ) ,
167
+ )
168
+ . unwrap ( )
169
+ . into_array ( )
170
+ }
171
+
172
+ fn sparse_high_fill_in_patches ( ) -> ArrayRef {
173
+ SparseArray :: try_new (
174
+ buffer ! [ 17u64 , 18 , 19 ] . into_array ( ) ,
175
+ buffer ! [ 33_i32 , 44 , 55 ] . into_array ( ) ,
176
+ 20 ,
177
+ Scalar :: primitive ( 33 , Nullability :: NonNullable ) ,
178
+ )
179
+ . unwrap ( )
180
+ . into_array ( )
181
+ }
182
+
183
+ fn sparse_low_fill_in_patches ( ) -> ArrayRef {
184
+ SparseArray :: try_new (
185
+ buffer ! [ 0u64 , 1 , 2 ] . into_array ( ) ,
186
+ buffer ! [ 33_i32 , 44 , 55 ] . into_array ( ) ,
187
+ 20 ,
188
+ Scalar :: primitive ( 55 , Nullability :: NonNullable ) ,
189
+ )
190
+ . unwrap ( )
191
+ . into_array ( )
192
+ }
193
+
194
+ fn sparse_low_high_fill_in_patches_low ( ) -> ArrayRef {
195
+ SparseArray :: try_new (
196
+ buffer ! [ 0u64 , 1 , 17 , 18 , 19 ] . into_array ( ) ,
197
+ buffer ! [ 11i32 , 22 , 33 , 44 , 55 ] . into_array ( ) ,
198
+ 20 ,
199
+ Scalar :: primitive ( 22 , Nullability :: NonNullable ) ,
200
+ )
201
+ . unwrap ( )
202
+ . into_array ( )
203
+ }
204
+
205
+ fn sparse_low_high_fill_in_patches_high ( ) -> ArrayRef {
206
+ SparseArray :: try_new (
207
+ buffer ! [ 0u64 , 1 , 17 , 18 , 19 ] . into_array ( ) ,
208
+ buffer ! [ 11i32 , 22 , 33 , 44 , 55 ] . into_array ( ) ,
209
+ 20 ,
210
+ Scalar :: primitive ( 33 , Nullability :: NonNullable ) ,
211
+ )
212
+ . unwrap ( )
213
+ . into_array ( )
214
+ }
215
+
112
216
#[ rstest]
113
217
#[ case( sparse_high_null_fill( ) , SearchResult :: NotFound ( 20 ) ) ]
114
218
#[ case( sparse_high_non_null_fill( ) , SearchResult :: NotFound ( 20 ) ) ]
115
219
#[ case( sparse_low( ) , SearchResult :: NotFound ( 20 ) ) ]
220
+ #[ case( sparse_low_high( ) , SearchResult :: NotFound ( 20 ) ) ]
116
221
fn search_larger_than_left ( #[ case] array : ArrayRef , #[ case] expected : SearchResult ) {
117
222
let res = search_sorted ( & array, 66 , SearchSortedSide :: Left ) . unwrap ( ) ;
118
223
assert_eq ! ( res, expected) ;
@@ -122,6 +227,7 @@ mod tests {
122
227
#[ case( sparse_high_null_fill( ) , SearchResult :: NotFound ( 20 ) ) ]
123
228
#[ case( sparse_high_non_null_fill( ) , SearchResult :: NotFound ( 20 ) ) ]
124
229
#[ case( sparse_low( ) , SearchResult :: NotFound ( 20 ) ) ]
230
+ #[ case( sparse_low_high( ) , SearchResult :: NotFound ( 20 ) ) ]
125
231
fn search_larger_than_right ( #[ case] array : ArrayRef , #[ case] expected : SearchResult ) {
126
232
let res = search_sorted ( & array, 66 , SearchSortedSide :: Right ) . unwrap ( ) ;
127
233
assert_eq ! ( res, expected) ;
@@ -131,6 +237,7 @@ mod tests {
131
237
#[ case( sparse_high_null_fill( ) , SearchResult :: NotFound ( 17 ) ) ]
132
238
#[ case( sparse_high_non_null_fill( ) , SearchResult :: NotFound ( 0 ) ) ]
133
239
#[ case( sparse_low( ) , SearchResult :: NotFound ( 0 ) ) ]
240
+ #[ case( sparse_low_high( ) , SearchResult :: NotFound ( 1 ) ) ]
134
241
fn search_less_than_left ( #[ case] array : ArrayRef , #[ case] expected : SearchResult ) {
135
242
let res = search_sorted ( & array, 21 , SearchSortedSide :: Left ) . unwrap ( ) ;
136
243
assert_eq ! ( res, expected) ;
@@ -140,6 +247,7 @@ mod tests {
140
247
#[ case( sparse_high_null_fill( ) , SearchResult :: NotFound ( 17 ) ) ]
141
248
#[ case( sparse_high_non_null_fill( ) , SearchResult :: NotFound ( 0 ) ) ]
142
249
#[ case( sparse_low( ) , SearchResult :: NotFound ( 0 ) ) ]
250
+ #[ case( sparse_low_high( ) , SearchResult :: NotFound ( 1 ) ) ]
143
251
fn search_less_than_right ( #[ case] array : ArrayRef , #[ case] expected : SearchResult ) {
144
252
let res = search_sorted ( & array, 21 , SearchSortedSide :: Right ) . unwrap ( ) ;
145
253
assert_eq ! ( res, expected) ;
@@ -149,6 +257,7 @@ mod tests {
149
257
#[ case( sparse_high_null_fill( ) , SearchResult :: Found ( 18 ) ) ]
150
258
#[ case( sparse_high_non_null_fill( ) , SearchResult :: Found ( 18 ) ) ]
151
259
#[ case( sparse_low( ) , SearchResult :: Found ( 1 ) ) ]
260
+ #[ case( sparse_low_high( ) , SearchResult :: Found ( 18 ) ) ]
152
261
fn search_patches_found_left ( #[ case] array : ArrayRef , #[ case] expected : SearchResult ) {
153
262
let res = search_sorted ( & array, 44 , SearchSortedSide :: Left ) . unwrap ( ) ;
154
263
assert_eq ! ( res, expected) ;
@@ -158,11 +267,32 @@ mod tests {
158
267
#[ case( sparse_high_null_fill( ) , SearchResult :: Found ( 19 ) ) ]
159
268
#[ case( sparse_high_non_null_fill( ) , SearchResult :: Found ( 19 ) ) ]
160
269
#[ case( sparse_low( ) , SearchResult :: Found ( 2 ) ) ]
270
+ #[ case( sparse_low_high( ) , SearchResult :: Found ( 19 ) ) ]
161
271
fn search_patches_found_right ( #[ case] array : ArrayRef , #[ case] expected : SearchResult ) {
162
272
let res = search_sorted ( & array, 44 , SearchSortedSide :: Right ) . unwrap ( ) ;
163
273
assert_eq ! ( res, expected) ;
164
274
}
165
275
276
+ #[ rstest]
277
+ #[ case( sparse_high_null_fill( ) , SearchResult :: NotFound ( 19 ) ) ]
278
+ #[ case( sparse_high_non_null_fill( ) , SearchResult :: NotFound ( 19 ) ) ]
279
+ #[ case( sparse_low( ) , SearchResult :: NotFound ( 2 ) ) ]
280
+ #[ case( sparse_low_high( ) , SearchResult :: NotFound ( 19 ) ) ]
281
+ fn search_mid_patches_not_found_left ( #[ case] array : ArrayRef , #[ case] expected : SearchResult ) {
282
+ let res = search_sorted ( & array, 45 , SearchSortedSide :: Left ) . unwrap ( ) ;
283
+ assert_eq ! ( res, expected) ;
284
+ }
285
+
286
+ #[ rstest]
287
+ #[ case( sparse_high_null_fill( ) , SearchResult :: NotFound ( 19 ) ) ]
288
+ #[ case( sparse_high_non_null_fill( ) , SearchResult :: NotFound ( 19 ) ) ]
289
+ #[ case( sparse_low( ) , SearchResult :: NotFound ( 2 ) ) ]
290
+ #[ case( sparse_low_high( ) , SearchResult :: NotFound ( 19 ) ) ]
291
+ fn search_mid_patches_not_found_right ( #[ case] array : ArrayRef , #[ case] expected : SearchResult ) {
292
+ let res = search_sorted ( & array, 45 , SearchSortedSide :: Right ) . unwrap ( ) ;
293
+ assert_eq ! ( res, expected) ;
294
+ }
295
+
166
296
#[ rstest]
167
297
#[ should_panic]
168
298
#[ case( sparse_high_null_fill( ) , Scalar :: null_typed:: <i32 >( ) , SearchResult :: Found ( 18 ) ) ]
@@ -176,6 +306,31 @@ mod tests {
176
306
Scalar :: primitive( 60 , Nullability :: NonNullable ) ,
177
307
SearchResult :: Found ( 3 )
178
308
) ]
309
+ #[ case(
310
+ sparse_low_high( ) ,
311
+ Scalar :: primitive( 30 , Nullability :: NonNullable ) ,
312
+ SearchResult :: Found ( 2 )
313
+ ) ]
314
+ #[ case(
315
+ sparse_high_fill_in_patches( ) ,
316
+ Scalar :: primitive( 33 , Nullability :: NonNullable ) ,
317
+ SearchResult :: Found ( 0 )
318
+ ) ]
319
+ #[ case(
320
+ sparse_low_fill_in_patches( ) ,
321
+ Scalar :: primitive( 55 , Nullability :: NonNullable ) ,
322
+ SearchResult :: Found ( 2 )
323
+ ) ]
324
+ #[ case(
325
+ sparse_low_high_fill_in_patches_low( ) ,
326
+ Scalar :: primitive( 22 , Nullability :: NonNullable ) ,
327
+ SearchResult :: Found ( 1 )
328
+ ) ]
329
+ #[ case(
330
+ sparse_low_high_fill_in_patches_high( ) ,
331
+ Scalar :: primitive( 33 , Nullability :: NonNullable ) ,
332
+ SearchResult :: Found ( 17 )
333
+ ) ]
179
334
fn search_fill_left (
180
335
#[ case] array : ArrayRef ,
181
336
#[ case] search : Scalar ,
@@ -198,6 +353,31 @@ mod tests {
198
353
Scalar :: primitive( 60 , Nullability :: NonNullable ) ,
199
354
SearchResult :: Found ( 20 )
200
355
) ]
356
+ #[ case(
357
+ sparse_low_high( ) ,
358
+ Scalar :: primitive( 30 , Nullability :: NonNullable ) ,
359
+ SearchResult :: Found ( 17 )
360
+ ) ]
361
+ #[ case(
362
+ sparse_high_fill_in_patches( ) ,
363
+ Scalar :: primitive( 33 , Nullability :: NonNullable ) ,
364
+ SearchResult :: Found ( 18 )
365
+ ) ]
366
+ #[ case(
367
+ sparse_low_fill_in_patches( ) ,
368
+ Scalar :: primitive( 55 , Nullability :: NonNullable ) ,
369
+ SearchResult :: Found ( 20 )
370
+ ) ]
371
+ #[ case(
372
+ sparse_low_high_fill_in_patches_low( ) ,
373
+ Scalar :: primitive( 22 , Nullability :: NonNullable ) ,
374
+ SearchResult :: Found ( 17 )
375
+ ) ]
376
+ #[ case(
377
+ sparse_low_high_fill_in_patches_high( ) ,
378
+ Scalar :: primitive( 33 , Nullability :: NonNullable ) ,
379
+ SearchResult :: Found ( 18 )
380
+ ) ]
201
381
fn search_fill_right (
202
382
#[ case] array : ArrayRef ,
203
383
#[ case] search : Scalar ,
0 commit comments