Skip to content

Commit d0a88c6

Browse files
authored
Implement compare operations for view types (#5900)
* compare kernel for view types * add binary view as well * make ci happy * better comments, better readability * add tests * add more tests
1 parent 72467c6 commit d0a88c6

File tree

3 files changed

+290
-6
lines changed

3 files changed

+290
-6
lines changed

arrow-array/src/array/byte_view_array.rs

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ use crate::builder::GenericByteViewBuilder;
2020
use crate::iterator::ArrayIter;
2121
use crate::types::bytes::ByteArrayNativeType;
2222
use crate::types::{BinaryViewType, ByteViewType, StringViewType};
23-
use crate::{Array, ArrayAccessor, ArrayRef};
23+
use crate::{Array, ArrayAccessor, ArrayRef, Scalar};
2424
use arrow_buffer::{Buffer, NullBuffer, ScalarBuffer};
2525
use arrow_data::{ArrayData, ArrayDataBuilder, ByteView};
2626
use arrow_schema::{ArrowError, DataType};
@@ -186,6 +186,11 @@ impl<T: ByteViewType + ?Sized> GenericByteViewArray<T> {
186186
}
187187
}
188188

189+
/// Create a new [`Scalar`] from `value`
190+
pub fn new_scalar(value: impl AsRef<T::Native>) -> Scalar<Self> {
191+
Scalar::new(Self::from_iter_values(std::iter::once(value)))
192+
}
193+
189194
/// Creates a [`GenericByteViewArray`] based on an iterator of values without nulls
190195
pub fn from_iter_values<Ptr, I>(iter: I) -> Self
191196
where
@@ -239,8 +244,7 @@ impl<T: ByteViewType + ?Sized> GenericByteViewArray<T> {
239244
let v = self.views.get_unchecked(idx);
240245
let len = *v as u32;
241246
let b = if len <= 12 {
242-
let ptr = self.views.as_ptr() as *const u8;
243-
std::slice::from_raw_parts(ptr.add(idx * 16 + 4), len as usize)
247+
Self::inline_value(v, len as usize)
244248
} else {
245249
let view = ByteView::from(*v);
246250
let data = self.buffers.get_unchecked(view.buffer_index as usize);
@@ -250,6 +254,17 @@ impl<T: ByteViewType + ?Sized> GenericByteViewArray<T> {
250254
T::Native::from_bytes_unchecked(b)
251255
}
252256

257+
/// Borrows the bytes stored inline in `view`: the `len` bytes that begin
/// 4 bytes past the start of the 16-byte view.
///
/// # Safety
/// - `view` must be a valid element from `Self::views()` that adheres to the view layout.
/// - `len` must be the length of the inlined value and must never exceed 12
///   (only checked in debug builds).
#[inline(always)]
pub unsafe fn inline_value(view: &u128, len: usize) -> &[u8] {
    debug_assert!(len <= 12);
    // Reinterpret the view as raw bytes and step over the 4 bytes that precede
    // the inline data.
    let first_byte = (view as *const u128).cast::<u8>();
    let inlined = first_byte.wrapping_add(4);
    std::slice::from_raw_parts(inlined, len)
}
267+
253268
/// constructs a new iterator
254269
pub fn iter(&self) -> ArrayIter<&Self> {
255270
ArrayIter::new(self)

arrow-ord/src/cmp.rs

Lines changed: 108 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@
2424
//!
2525
2626
use arrow_array::cast::AsArray;
27-
use arrow_array::types::ByteArrayType;
27+
use arrow_array::types::{ByteArrayType, ByteViewType};
2828
use arrow_array::{
2929
downcast_primitive_array, AnyDictionaryArray, Array, ArrowNativeTypeOp, BooleanArray, Datum,
30-
FixedSizeBinaryArray, GenericByteArray,
30+
FixedSizeBinaryArray, GenericByteArray, GenericByteViewArray,
3131
};
3232
use arrow_buffer::bit_util::ceil;
3333
use arrow_buffer::{BooleanBuffer, MutableBuffer, NullBuffer};
@@ -228,8 +228,10 @@ fn compare_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
228228
(l, r) => apply(op, l.values().as_ref(), l_s, l_v, r.values().as_ref(), r_s, r_v),
229229
(Boolean, Boolean) => apply(op, l.as_boolean(), l_s, l_v, r.as_boolean(), r_s, r_v),
230230
(Utf8, Utf8) => apply(op, l.as_string::<i32>(), l_s, l_v, r.as_string::<i32>(), r_s, r_v),
231+
(Utf8View, Utf8View) => apply(op, l.as_string_view(), l_s, l_v, r.as_string_view(), r_s, r_v),
231232
(LargeUtf8, LargeUtf8) => apply(op, l.as_string::<i64>(), l_s, l_v, r.as_string::<i64>(), r_s, r_v),
232233
(Binary, Binary) => apply(op, l.as_binary::<i32>(), l_s, l_v, r.as_binary::<i32>(), r_s, r_v),
234+
(BinaryView, BinaryView) => apply(op, l.as_binary_view(), l_s, l_v, r.as_binary_view(), r_s, r_v),
233235
(LargeBinary, LargeBinary) => apply(op, l.as_binary::<i64>(), l_s, l_v, r.as_binary::<i64>(), r_s, r_v),
234236
(FixedSizeBinary(_), FixedSizeBinary(_)) => apply(op, l.as_fixed_size_binary(), l_s, l_v, r.as_fixed_size_binary(), r_s, r_v),
235237
(Null, Null) => None,
@@ -459,7 +461,7 @@ fn apply_op_vectored<T: ArrayOrd>(
459461
}
460462

461463
trait ArrayOrd {
462-
type Item: Copy + Default;
464+
type Item: Copy;
463465

464466
fn len(&self) -> usize;
465467

@@ -538,6 +540,109 @@ impl<'a, T: ByteArrayType> ArrayOrd for &'a GenericByteArray<T> {
538540
}
539541
}
540542

543+
/// Comparing two ByteView types are non-trivial.
544+
/// It takes a bit of patience to understand why we don't just compare two &[u8] directly.
545+
///
546+
/// ByteView types give us the following two advantages, and we need to be careful not to lose them:
547+
/// (1) For string/byte smaller than 12 bytes, the entire data is inlined in the view.
548+
/// Meaning that reading one array element requires only one memory access
549+
/// (two memory access required for StringArray, one for offset buffer, the other for value buffer).
550+
///
551+
/// (2) For string/byte larger than 12 bytes, we can still be faster than (for certain operations) StringArray/ByteArray,
552+
/// thanks to the inlined 4 bytes.
553+
/// Consider equality check:
554+
/// If the first four bytes of the two strings are different, we can return false immediately (with just one memory access).
555+
/// If we are unlucky and the first four bytes are the same, we need to fallback to compare two full strings.
556+
impl<'a, T: ByteViewType> ArrayOrd for &'a GenericByteViewArray<T> {
557+
/// Item.0 is the array, Item.1 is the index into the array.
558+
/// Why don't we just store Item.0[Item.1] as the item?
559+
/// - Because if we do so, we materialize the entire string (i.e., make multiple memory accesses), which might be unnecessary.
560+
/// - Most of the time (eq, ord), we only need to look at the first 4 bytes to know the answer,
561+
/// e.g., if the inlined 4 bytes are different, we can directly return unequal without looking at the full string.
562+
type Item = (&'a GenericByteViewArray<T>, usize);
563+
564+
/// # Equality check flow
565+
/// (1) if both string are smaller than 12 bytes, we can directly compare the data inlined to the view.
566+
/// (2) if any of the string is larger than 12 bytes, we need to compare the full string.
567+
/// (2.1) if the inlined 4 bytes are different, we can return false immediately.
568+
/// (2.2) o.w., we need to compare the full string.
569+
///
570+
/// # Safety
571+
/// (1) Indexing. The Self::Item.1 encodes the index value, which is already checked in `value` function,
572+
/// so it is safe to index into the views.
573+
/// (2) Slice data from view. We know the bytes 4-8 are inlined data (per spec), so it is safe to slice from the view.
574+
fn is_eq(l: Self::Item, r: Self::Item) -> bool {
575+
let l_view = unsafe { l.0.views().get_unchecked(l.1) };
576+
let l_len = *l_view as u32;
577+
578+
let r_view = unsafe { r.0.views().get_unchecked(r.1) };
579+
let r_len = *r_view as u32;
580+
581+
if l_len != r_len {
582+
return false;
583+
}
584+
585+
if l_len <= 12 {
586+
let l_data = unsafe { GenericByteViewArray::<T>::inline_value(l_view, l_len as usize) };
587+
let r_data = unsafe { GenericByteViewArray::<T>::inline_value(r_view, r_len as usize) };
588+
l_data == r_data
589+
} else {
590+
let l_inlined_data = unsafe { GenericByteViewArray::<T>::inline_value(l_view, 4) };
591+
let r_inlined_data = unsafe { GenericByteViewArray::<T>::inline_value(r_view, 4) };
592+
if l_inlined_data != r_inlined_data {
593+
return false;
594+
}
595+
596+
let l_full_data: &[u8] = unsafe { l.0.value_unchecked(l.1).as_ref() };
597+
let r_full_data: &[u8] = unsafe { r.0.value_unchecked(r.1).as_ref() };
598+
l_full_data == r_full_data
599+
}
600+
}
601+
602+
/// # Ordering check flow
603+
/// (1) if both string are smaller than 12 bytes, we can directly compare the data inlined to the view.
604+
/// (2) if any of the string is larger than 12 bytes, we need to compare the full string.
605+
/// (2.1) if the inlined 4 bytes are different, we can return the result immediately.
606+
/// (2.2) o.w., we need to compare the full string.
607+
///
608+
/// # Safety
609+
/// (1) Indexing. The Self::Item.1 encodes the index value, which is already checked in `value` function,
610+
/// so it is safe to index into the views.
611+
/// (2) Slice data from view. We know the bytes 4-8 are inlined data (per spec), so it is safe to slice from the view.
612+
fn is_lt(l: Self::Item, r: Self::Item) -> bool {
613+
let l_view = l.0.views().get(l.1).unwrap();
614+
let l_len = *l_view as u32;
615+
616+
let r_view = r.0.views().get(r.1).unwrap();
617+
let r_len = *r_view as u32;
618+
619+
if l_len <= 12 && r_len <= 12 {
620+
let l_data = unsafe { GenericByteViewArray::<T>::inline_value(l_view, l_len as usize) };
621+
let r_data = unsafe { GenericByteViewArray::<T>::inline_value(r_view, r_len as usize) };
622+
return l_data < r_data;
623+
}
624+
// one of the string is larger than 12 bytes,
625+
// we then try to compare the inlined data first
626+
let l_inlined_data = unsafe { GenericByteViewArray::<T>::inline_value(l_view, 4) };
627+
let r_inlined_data = unsafe { GenericByteViewArray::<T>::inline_value(r_view, 4) };
628+
if r_inlined_data != l_inlined_data {
629+
return l_inlined_data < r_inlined_data;
630+
}
631+
// unfortunately, we need to compare the full data
632+
let l_full_data: &[u8] = unsafe { l.0.value_unchecked(l.1).as_ref() };
633+
let r_full_data: &[u8] = unsafe { r.0.value_unchecked(r.1).as_ref() };
634+
l_full_data < r_full_data
635+
}
636+
637+
fn len(&self) -> usize {
638+
Array::len(self)
639+
}
640+
641+
unsafe fn value_unchecked(&self, idx: usize) -> Self::Item {
642+
(self, idx)
643+
}
644+
}
645+
541646
impl<'a> ArrayOrd for &'a FixedSizeBinaryArray {
542647
type Item = &'a [u8];
543648

0 commit comments

Comments
 (0)