Skip to content

Commit af40ea3

Browse files
XiangpengHaoalamb
andauthored
Implement specialized min/max for GenericBinaryView (StringView and BinaryView) (apache#6089)
* implement better min/max for string view * Apply suggestions from code review Co-authored-by: Andrew Lamb <[email protected]> * address review comments --------- Co-authored-by: Andrew Lamb <[email protected]>
1 parent 93e4eb2 commit af40ea3

File tree

4 files changed

+132
-9
lines changed

4 files changed

+132
-9
lines changed

arrow-arith/src/aggregate.rs

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ use arrow_buffer::{ArrowNativeType, NullBuffer};
2424
use arrow_data::bit_iterator::try_for_each_valid_idx;
2525
use arrow_schema::*;
2626
use std::borrow::BorrowMut;
27+
use std::cmp::{self, Ordering};
2728
use std::ops::{BitAnd, BitOr, BitXor};
29+
use types::ByteViewType;
2830

2931
/// An accumulator for primitive numeric values.
3032
trait NumericAccumulator<T: ArrowNativeTypeOp>: Copy + Default {
@@ -425,14 +427,55 @@ where
425427
}
426428
}
427429

430+
/// Helper to compute min/max of [`GenericByteViewArray<T>`].
431+
/// The specialized min/max leverages the inlined values to compare the byte views.
432+
/// `swap_cond` is the condition to swap current min/max with the new value.
433+
/// For example, `Ordering::Greater` for max and `Ordering::Less` for min.
434+
fn min_max_view_helper<T: ByteViewType>(
435+
array: &GenericByteViewArray<T>,
436+
swap_cond: cmp::Ordering,
437+
) -> Option<&T::Native> {
438+
let null_count = array.null_count();
439+
if null_count == array.len() {
440+
None
441+
} else if null_count == 0 {
442+
let target_idx = (0..array.len()).reduce(|acc, item| {
443+
// SAFETY: array's length is correct so item is within bounds
444+
let cmp = unsafe { GenericByteViewArray::compare_unchecked(array, item, array, acc) };
445+
if cmp == swap_cond {
446+
item
447+
} else {
448+
acc
449+
}
450+
});
451+
// SAFETY: idx came from valid range `0..array.len()`
452+
unsafe { target_idx.map(|idx| array.value_unchecked(idx)) }
453+
} else {
454+
let nulls = array.nulls().unwrap();
455+
456+
let target_idx = nulls.valid_indices().reduce(|acc_idx, idx| {
457+
let cmp =
458+
unsafe { GenericByteViewArray::compare_unchecked(array, idx, array, acc_idx) };
459+
if cmp == swap_cond {
460+
idx
461+
} else {
462+
acc_idx
463+
}
464+
});
465+
466+
// SAFETY: idx came from valid range `0..array.len()`
467+
unsafe { target_idx.map(|idx| array.value_unchecked(idx)) }
468+
}
469+
}
470+
428471
/// Returns the maximum value in the binary array, according to the natural order.
429472
pub fn max_binary<T: OffsetSizeTrait>(array: &GenericBinaryArray<T>) -> Option<&[u8]> {
430473
min_max_helper::<&[u8], _, _>(array, |a, b| *a < *b)
431474
}
432475

433476
/// Returns the maximum value in the binary view array, according to the natural order.
434477
pub fn max_binary_view(array: &BinaryViewArray) -> Option<&[u8]> {
435-
min_max_helper::<&[u8], _, _>(array, |a, b| *a < *b)
478+
min_max_view_helper(array, Ordering::Greater)
436479
}
437480

438481
/// Returns the minimum value in the binary array, according to the natural order.
@@ -442,7 +485,7 @@ pub fn min_binary<T: OffsetSizeTrait>(array: &GenericBinaryArray<T>) -> Option<&
442485

443486
/// Returns the minimum value in the binary view array, according to the natural order.
444487
pub fn min_binary_view(array: &BinaryViewArray) -> Option<&[u8]> {
445-
min_max_helper::<&[u8], _, _>(array, |a, b| *a > *b)
488+
min_max_view_helper(array, Ordering::Less)
446489
}
447490

448491
/// Returns the maximum value in the string array, according to the natural order.
@@ -452,7 +495,7 @@ pub fn max_string<T: OffsetSizeTrait>(array: &GenericStringArray<T>) -> Option<&
452495

453496
/// Returns the maximum value in the string view array, according to the natural order.
454497
pub fn max_string_view(array: &StringViewArray) -> Option<&str> {
455-
min_max_helper::<&str, _, _>(array, |a, b| *a < *b)
498+
min_max_view_helper(array, Ordering::Greater)
456499
}
457500

458501
/// Returns the minimum value in the string array, according to the natural order.
@@ -462,7 +505,7 @@ pub fn min_string<T: OffsetSizeTrait>(array: &GenericStringArray<T>) -> Option<&
462505

463506
/// Returns the minimum value in the string view array, according to the natural order.
464507
pub fn min_string_view(array: &StringViewArray) -> Option<&str> {
465-
min_max_helper::<&str, _, _>(array, |a, b| *a > *b)
508+
min_max_view_helper(array, Ordering::Less)
466509
}
467510

468511
/// Returns the sum of values in the array.

arrow-array/src/array/byte_view_array.rs

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,66 @@ impl<T: ByteViewType + ?Sized> GenericByteViewArray<T> {
336336

337337
builder.finish()
338338
}
339+
340+
/// Comparing two [`GenericByteViewArray`] at index `left_idx` and `right_idx`
341+
///
342+
/// Comparing two ByteView types are non-trivial.
343+
/// It takes a bit of patience to understand why we don't just compare two &[u8] directly.
344+
///
345+
/// ByteView types give us the following two advantages, and we need to be careful not to lose them:
346+
/// (1) For string/byte smaller than 12 bytes, the entire data is inlined in the view.
347+
/// Meaning that reading one array element requires only one memory access
348+
/// (two memory access required for StringArray, one for offset buffer, the other for value buffer).
349+
///
350+
/// (2) For string/byte larger than 12 bytes, we can still be faster than (for certain operations) StringArray/ByteArray,
351+
/// thanks to the inlined 4 bytes.
352+
/// Consider equality check:
353+
/// If the first four bytes of the two strings are different, we can return false immediately (with just one memory access).
354+
///
355+
/// If we directly compare two &[u8], we materialize the entire string (i.e., make multiple memory accesses), which might be unnecessary.
356+
/// - Most of the time (eq, ord), we only need to look at the first 4 bytes to know the answer,
357+
/// e.g., if the inlined 4 bytes are different, we can directly return unequal without looking at the full string.
358+
///
359+
/// # Order check flow
360+
/// (1) if both string are smaller than 12 bytes, we can directly compare the data inlined to the view.
361+
/// (2) if any of the string is larger than 12 bytes, we need to compare the full string.
362+
/// (2.1) if the inlined 4 bytes are different, we can return the result immediately.
363+
/// (2.2) o.w., we need to compare the full string.
364+
///
365+
/// # Safety
366+
/// The left/right_idx must within range of each array
367+
pub unsafe fn compare_unchecked(
368+
left: &GenericByteViewArray<T>,
369+
left_idx: usize,
370+
right: &GenericByteViewArray<T>,
371+
right_idx: usize,
372+
) -> std::cmp::Ordering {
373+
let l_view = left.views().get_unchecked(left_idx);
374+
let l_len = *l_view as u32;
375+
376+
let r_view = right.views().get_unchecked(right_idx);
377+
let r_len = *r_view as u32;
378+
379+
if l_len <= 12 && r_len <= 12 {
380+
let l_data = unsafe { GenericByteViewArray::<T>::inline_value(l_view, l_len as usize) };
381+
let r_data = unsafe { GenericByteViewArray::<T>::inline_value(r_view, r_len as usize) };
382+
return l_data.cmp(r_data);
383+
}
384+
385+
// one of the string is larger than 12 bytes,
386+
// we then try to compare the inlined data first
387+
let l_inlined_data = unsafe { GenericByteViewArray::<T>::inline_value(l_view, 4) };
388+
let r_inlined_data = unsafe { GenericByteViewArray::<T>::inline_value(r_view, 4) };
389+
if r_inlined_data != l_inlined_data {
390+
return l_inlined_data.cmp(r_inlined_data);
391+
}
392+
393+
// unfortunately, we need to compare the full data
394+
let l_full_data: &[u8] = unsafe { left.value_unchecked(left_idx).as_ref() };
395+
let r_full_data: &[u8] = unsafe { right.value_unchecked(right_idx).as_ref() };
396+
397+
l_full_data.cmp(r_full_data)
398+
}
339399
}
340400

341401
impl<T: ByteViewType + ?Sized> Debug for GenericByteViewArray<T> {

arrow-ord/src/cmp.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -579,13 +579,13 @@ impl<'a, T: ByteViewType> ArrayOrd for &'a GenericByteViewArray<T> {
579579
return false;
580580
}
581581

582-
unsafe { compare_byte_view_unchecked(l.0, l.1, r.0, r.1).is_eq() }
582+
unsafe { GenericByteViewArray::compare_unchecked(l.0, l.1, r.0, r.1).is_eq() }
583583
}
584584

585585
fn is_lt(l: Self::Item, r: Self::Item) -> bool {
586586
// # Safety
587587
// The index is within bounds as it is checked in value()
588-
unsafe { compare_byte_view_unchecked(l.0, l.1, r.0, r.1).is_lt() }
588+
unsafe { GenericByteViewArray::compare_unchecked(l.0, l.1, r.0, r.1).is_lt() }
589589
}
590590

591591
fn len(&self) -> usize {
@@ -626,7 +626,7 @@ pub fn compare_byte_view<T: ByteViewType>(
626626
) -> std::cmp::Ordering {
627627
assert!(left_idx < left.len());
628628
assert!(right_idx < right.len());
629-
unsafe { compare_byte_view_unchecked(left, left_idx, right, right_idx) }
629+
unsafe { GenericByteViewArray::compare_unchecked(left, left_idx, right, right_idx) }
630630
}
631631

632632
/// Comparing two [`GenericByteViewArray`] at index `left_idx` and `right_idx`
@@ -656,6 +656,7 @@ pub fn compare_byte_view<T: ByteViewType>(
656656
///
657657
/// # Safety
658658
/// The left/right_idx must within range of each array
659+
#[deprecated(note = "Use `GenericByteViewArray::compare_unchecked` instead")]
659660
pub unsafe fn compare_byte_view_unchecked<T: ByteViewType>(
660661
left: &GenericByteViewArray<T>,
661662
left_idx: usize,

arrow/benches/aggregate_kernels.rs

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ fn add_benchmark(c: &mut Criterion) {
5757
primitive_benchmark::<Int64Type>(c, "int64");
5858

5959
{
60-
let nonnull_strings = create_string_array::<i32>(BATCH_SIZE, 0.0);
61-
let nullable_strings = create_string_array::<i32>(BATCH_SIZE, 0.5);
60+
let nonnull_strings = create_string_array_with_len::<i32>(BATCH_SIZE, 0.0, 16);
61+
let nullable_strings = create_string_array_with_len::<i32>(BATCH_SIZE, 0.5, 16);
6262
c.benchmark_group("string")
6363
.throughput(Throughput::Elements(BATCH_SIZE as u64))
6464
.bench_function("min nonnull", |b| b.iter(|| min_string(&nonnull_strings)))
@@ -67,6 +67,25 @@ fn add_benchmark(c: &mut Criterion) {
6767
.bench_function("max nullable", |b| b.iter(|| max_string(&nullable_strings)));
6868
}
6969

70+
{
71+
let nonnull_strings = create_string_view_array_with_len(BATCH_SIZE, 0.0, 16, false);
72+
let nullable_strings = create_string_view_array_with_len(BATCH_SIZE, 0.5, 16, false);
73+
c.benchmark_group("string view")
74+
.throughput(Throughput::Elements(BATCH_SIZE as u64))
75+
.bench_function("min nonnull", |b| {
76+
b.iter(|| min_string_view(&nonnull_strings))
77+
})
78+
.bench_function("max nonnull", |b| {
79+
b.iter(|| max_string_view(&nonnull_strings))
80+
})
81+
.bench_function("min nullable", |b| {
82+
b.iter(|| min_string_view(&nullable_strings))
83+
})
84+
.bench_function("max nullable", |b| {
85+
b.iter(|| max_string_view(&nullable_strings))
86+
});
87+
}
88+
7089
{
7190
let nonnull_bools_mixed = create_boolean_array(BATCH_SIZE, 0.0, 0.5);
7291
let nonnull_bools_all_false = create_boolean_array(BATCH_SIZE, 0.0, 0.0);

0 commit comments

Comments
 (0)