|
17 | 17 |
|
18 | 18 | //! Defines filter kernels
|
19 | 19 |
|
| 20 | +use std::ops::AddAssign; |
20 | 21 | use std::sync::Arc;
|
21 | 22 |
|
22 | 23 | use arrow_array::builder::BooleanBufferBuilder;
|
23 | 24 | use arrow_array::cast::AsArray;
|
24 |
| -use arrow_array::types::{ArrowDictionaryKeyType, ByteArrayType}; |
| 25 | +use arrow_array::types::{ |
| 26 | + ArrowDictionaryKeyType, ArrowPrimitiveType, ByteArrayType, RunEndIndexType, |
| 27 | +}; |
25 | 28 | use arrow_array::*;
|
26 |
| -use arrow_buffer::{bit_util, BooleanBuffer, NullBuffer}; |
| 29 | +use arrow_buffer::{bit_util, BooleanBuffer, NullBuffer, RunEndBuffer}; |
27 | 30 | use arrow_buffer::{Buffer, MutableBuffer};
|
28 | 31 | use arrow_data::bit_iterator::{BitIndexIterator, BitSliceIterator};
|
29 | 32 | use arrow_data::transform::MutableArrayData;
|
@@ -336,6 +339,12 @@ fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result<Array
|
336 | 339 | DataType::LargeBinary => {
|
337 | 340 | Ok(Arc::new(filter_bytes(values.as_binary::<i64>(), predicate)))
|
338 | 341 | }
|
| 342 | + DataType::RunEndEncoded(_, _) => { |
| 343 | + downcast_run_array!{ |
| 344 | + values => Ok(Arc::new(filter_run_end_array(values, predicate)?)), |
| 345 | + t => unimplemented!("Filter not supported for RunEndEncoded type {:?}", t) |
| 346 | + } |
| 347 | + } |
339 | 348 | DataType::Dictionary(_, _) => downcast_dictionary_array! {
|
340 | 349 | values => Ok(Arc::new(filter_dict(values, predicate))),
|
341 | 350 | t => unimplemented!("Filter not supported for dictionary type {:?}", t)
|
@@ -368,6 +377,55 @@ fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result<Array
|
368 | 377 | }
|
369 | 378 | }
|
370 | 379 |
|
| 380 | +/// Filter any supported [`RunArray`] based on a [`FilterPredicate`] |
| 381 | +fn filter_run_end_array<R: RunEndIndexType>( |
| 382 | + re_arr: &RunArray<R>, |
| 383 | + pred: &FilterPredicate, |
| 384 | +) -> Result<RunArray<R>, ArrowError> |
| 385 | +where |
| 386 | + R::Native: Into<i64> + From<bool>, |
| 387 | + R::Native: AddAssign, |
| 388 | +{ |
| 389 | + let run_ends: &RunEndBuffer<R::Native> = re_arr.run_ends(); |
| 390 | + let mut values_filter = BooleanBufferBuilder::new(run_ends.len()); |
| 391 | + let mut new_run_ends = vec![R::default_value(); run_ends.len()]; |
| 392 | + |
| 393 | + let mut start = 0i64; |
| 394 | + let mut i = 0; |
| 395 | + let filter_values = pred.filter.values(); |
| 396 | + let mut count = R::default_value(); |
| 397 | + |
| 398 | + for end in run_ends.inner().into_iter().map(|i| (*i).into()) { |
| 399 | + let mut keep = false; |
| 400 | + // in filter_array the predicate array is checked to have the same len as the run end array |
| 401 | + // this means the largest value in the run_ends is == to pred.len() |
| 402 | + // so we're always within bounds when calling value_unchecked |
| 403 | + for pred in (start..end).map(|i| unsafe { filter_values.value_unchecked(i as usize) }) { |
| 404 | + count += R::Native::from(pred); |
| 405 | + keep |= pred |
| 406 | + } |
| 407 | + // this is to avoid branching |
| 408 | + new_run_ends[i] = count; |
| 409 | + i += keep as usize; |
| 410 | + |
| 411 | + values_filter.append(keep); |
| 412 | + start = end; |
| 413 | + } |
| 414 | + |
| 415 | + new_run_ends.truncate(i); |
| 416 | + |
| 417 | + if values_filter.is_empty() { |
| 418 | + new_run_ends.clear(); |
| 419 | + } |
| 420 | + |
| 421 | + let values = re_arr.values(); |
| 422 | + let pred = BooleanArray::new(values_filter.finish(), None); |
| 423 | + let values = filter(&values, &pred)?; |
| 424 | + |
| 425 | + let run_ends = PrimitiveArray::<R>::new(new_run_ends.into(), None); |
| 426 | + RunArray::try_new(&run_ends, &values) |
| 427 | +} |
| 428 | + |
371 | 429 | /// Computes a new null mask for `data` based on `predicate`
|
372 | 430 | ///
|
373 | 431 | /// If the predicate selected no null-rows, returns `None`, otherwise returns
|
@@ -635,6 +693,7 @@ where
|
635 | 693 | #[cfg(test)]
|
636 | 694 | mod tests {
|
637 | 695 | use arrow_array::builder::*;
|
| 696 | + use arrow_array::cast::as_run_array; |
638 | 697 | use arrow_array::types::*;
|
639 | 698 | use rand::distributions::{Alphanumeric, Standard};
|
640 | 699 | use rand::prelude::*;
|
@@ -844,6 +903,78 @@ mod tests {
|
844 | 903 | assert_eq!(9, d.value(1));
|
845 | 904 | }
|
846 | 905 |
|
/// Keeping every other row of `[7, 7, -2, 9, 9, 9, 9, 9]` must leave all three
/// runs in place, with their run ends shrunk to `[1, 2, 4]`.
#[test]
fn test_filter_run_end_encoding_array() {
    let input = RunArray::try_new(
        &Int64Array::from(vec![2, 3, 8]),
        &Int64Array::from(vec![7, -2, 9]),
    )
    .expect("Failed to create RunArray");

    // Rows 0, 2, 4 and 6 are selected.
    let selection: Vec<bool> = (0..8).map(|row| row % 2 == 0).collect();
    let mask = BooleanArray::from(selection);

    let filtered = filter(&input, &mask).unwrap();
    let actual: &RunArray<Int64Type> = as_run_array(&filtered);
    assert_eq!(4, actual.len());

    let expected_run_ends = Int64Array::from(vec![1, 2, 4]);
    let expected_values = Int64Array::from(vec![7, -2, 9]);
    let expected = RunArray::try_new(&expected_run_ends, &expected_values)
        .expect("Failed to make expected RunArray test is broken");

    assert_eq!(actual.run_ends().values(), expected.run_ends().values());
    assert_eq!(actual.values(), expected.values());
}
| 925 | + |
/// Runs with no selected row (`-2` and `-8` here) must disappear from both the
/// run ends and the values of the result.
#[test]
fn test_filter_run_end_encoding_array_remove_value() {
    let input = RunArray::try_new(
        &Int32Array::from(vec![2, 3, 8, 10]),
        &Int32Array::from(vec![7, -2, 9, -8]),
    )
    .expect("Failed to create RunArray");

    // Only rows 1, 4 and 6 are selected: one row of `7` and two rows of `9`.
    let selection: Vec<bool> = (0..10).map(|row| matches!(row, 1 | 4 | 6)).collect();
    let mask = BooleanArray::from(selection);

    let filtered = filter(&input, &mask).unwrap();
    let actual: &RunArray<Int32Type> = as_run_array(&filtered);
    assert_eq!(3, actual.len());

    let expected_run_ends = Int32Array::from(vec![1, 3]);
    let expected_values = Int32Array::from(vec![7, 9]);
    let expected = RunArray::try_new(&expected_run_ends, &expected_values)
        .expect("Failed to make expected RunArray test is broken");

    assert_eq!(actual.run_ends().values(), expected.run_ends().values());
    assert_eq!(actual.values(), expected.values());
}
| 945 | + |
/// Selecting a single row must collapse the array to one run of length one.
#[test]
fn test_filter_run_end_encoding_array_remove_all_but_one() {
    let input = RunArray::try_new(
        &Int16Array::from(vec![2, 3, 8, 10]),
        &Int16Array::from(vec![7, -2, 9, -8]),
    )
    .expect("Failed to create RunArray");

    // Row 6 lies inside the run of `9`s and is the only selected row.
    let selection: Vec<bool> = (0..10).map(|row| row == 6).collect();
    let mask = BooleanArray::from(selection);

    let filtered = filter(&input, &mask).unwrap();
    let actual: &RunArray<Int16Type> = as_run_array(&filtered);
    assert_eq!(1, actual.len());

    let expected_run_ends = Int16Array::from(vec![1]);
    let expected_values = Int16Array::from(vec![9]);
    let expected = RunArray::try_new(&expected_run_ends, &expected_values)
        .expect("Failed to make expected RunArray test is broken");

    assert_eq!(actual.run_ends().values(), expected.run_ends().values());
    assert_eq!(actual.values(), expected.values());
}
| 964 | + |
/// A predicate that selects nothing must produce an empty run array.
#[test]
fn test_filter_run_end_encoding_array_empty() {
    let input = RunArray::try_new(
        &Int64Array::from(vec![2, 3, 8, 10]),
        &Int64Array::from(vec![7, -2, 9, -8]),
    )
    .expect("Failed to create RunArray");

    // One `false` per logical row.
    let mask = BooleanArray::from(vec![false; 10]);

    let filtered = filter(&input, &mask).unwrap();
    let actual: &RunArray<Int64Type> = as_run_array(&filtered);
    assert_eq!(0, actual.len());
}
| 977 | + |
847 | 978 | #[test]
|
848 | 979 | fn test_filter_dictionary_array() {
|
849 | 980 | let values = [Some("hello"), None, Some("world"), Some("!")];
|
|
0 commit comments