Skip to content

Commit 144c9c7

Browse files
Implement take kernel for byte view array. (#5602)
* impl take kernel for byte view array. * Add unit tests. * Use ArrayData equality * Rename to byte_view --------- Co-authored-by: Raphael Taylor-Davies <[email protected]>
1 parent 16f4a7f commit 144c9c7

File tree

2 files changed

+103
-0
lines changed

2 files changed

+103
-0
lines changed

arrow-array/src/cast.rs

+36
Original file line numberDiff line numberDiff line change
@@ -779,6 +779,34 @@ pub trait AsArray: private::Sealed {
779779
self.as_bytes_opt().expect("binary array")
780780
}
781781

782+
/// Downcast this to a [`StringViewArray`] panicking if not possible
783+
fn as_string_view(&self) -> &StringViewArray {
784+
self.as_byte_view_opt().expect("string view array")
785+
}
786+
787+
/// Downcast this to a [`StringViewArray`] returning `None` if not possible
788+
fn as_string_view_opt(&self) -> Option<&StringViewArray> {
789+
self.as_byte_view_opt()
790+
}
791+
792+
/// Downcast this to a [`BinaryViewArray`] panicking if not possible
793+
fn as_binary_view(&self) -> &BinaryViewArray {
794+
self.as_byte_view_opt().expect("binary view array")
795+
}
796+
797+
/// Downcast this to a [`BinaryViewArray`] returning `None` if not possible
798+
fn as_binary_view_opt(&self) -> Option<&BinaryViewArray> {
799+
self.as_byte_view_opt()
800+
}
801+
802+
/// Downcast this to a [`GenericByteViewArray`] panicking if not possible
803+
fn as_byte_view<T: ByteViewType>(&self) -> &GenericByteViewArray<T> {
804+
self.as_byte_view_opt().expect("byte view array")
805+
}
806+
807+
/// Downcast this to a [`GenericByteViewArray`] returning `None` if not possible
808+
fn as_byte_view_opt<T: ByteViewType>(&self) -> Option<&GenericByteViewArray<T>>;
809+
782810
/// Downcast this to a [`StructArray`] returning `None` if not possible
783811
fn as_struct_opt(&self) -> Option<&StructArray>;
784812

@@ -852,6 +880,10 @@ impl AsArray for dyn Array + '_ {
852880
self.as_any().downcast_ref()
853881
}
854882

883+
fn as_byte_view_opt<T: ByteViewType>(&self) -> Option<&GenericByteViewArray<T>> {
884+
self.as_any().downcast_ref()
885+
}
886+
855887
fn as_struct_opt(&self) -> Option<&StructArray> {
856888
self.as_any().downcast_ref()
857889
}
@@ -899,6 +931,10 @@ impl AsArray for ArrayRef {
899931
self.as_ref().as_bytes_opt()
900932
}
901933

934+
fn as_byte_view_opt<T: ByteViewType>(&self) -> Option<&GenericByteViewArray<T>> {
935+
self.as_ref().as_byte_view_opt()
936+
}
937+
902938
fn as_struct_opt(&self) -> Option<&StructArray> {
903939
self.as_ref().as_struct_opt()
904940
}

arrow-select/src/take.rs

+67
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,9 @@ fn take_impl<IndexType: ArrowPrimitiveType>(
143143
DataType::LargeUtf8 => {
144144
Ok(Arc::new(take_bytes(values.as_string::<i64>(), indices)?))
145145
}
146+
DataType::Utf8View => {
147+
Ok(Arc::new(take_byte_view(values.as_string_view(), indices)?))
148+
}
146149
DataType::List(_) => {
147150
Ok(Arc::new(take_list::<_, Int32Type>(values.as_list(), indices)?))
148151
}
@@ -204,6 +207,9 @@ fn take_impl<IndexType: ArrowPrimitiveType>(
204207
DataType::LargeBinary => {
205208
Ok(Arc::new(take_bytes(values.as_binary::<i64>(), indices)?))
206209
}
210+
DataType::BinaryView => {
211+
Ok(Arc::new(take_byte_view(values.as_binary_view(), indices)?))
212+
}
207213
DataType::FixedSizeBinary(size) => {
208214
let values = values
209215
.as_any()
@@ -437,6 +443,20 @@ fn take_bytes<T: ByteArrayType, IndexType: ArrowPrimitiveType>(
437443
Ok(GenericByteArray::from(array_data))
438444
}
439445

446+
/// `take` implementation for byte view arrays
447+
fn take_byte_view<T: ByteViewType, IndexType: ArrowPrimitiveType>(
448+
array: &GenericByteViewArray<T>,
449+
indices: &PrimitiveArray<IndexType>,
450+
) -> Result<GenericByteViewArray<T>, ArrowError> {
451+
let new_views = take_native(array.views(), indices);
452+
let new_nulls = take_nulls(array.nulls(), indices);
453+
Ok(GenericByteViewArray::new(
454+
new_views,
455+
array.data_buffers().to_vec(),
456+
new_nulls,
457+
))
458+
}
459+
440460
/// `take` implementation for list arrays
441461
///
442462
/// Calculates the index and indexed offset for the inner array,
@@ -1424,6 +1444,53 @@ mod tests {
14241444
assert_eq!(result.as_ref(), &expected);
14251445
}
14261446

1447+
fn _test_byte_view<T>()
1448+
where
1449+
T: ByteViewType,
1450+
str: AsRef<T::Native>,
1451+
T::Native: PartialEq,
1452+
{
1453+
let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(4), Some(2)]);
1454+
let array = {
1455+
// ["hello", "world", null, "large payload over 12 bytes", "lulu"]
1456+
let mut builder = GenericByteViewBuilder::<T>::new();
1457+
builder.append_value("hello");
1458+
builder.append_value("world");
1459+
builder.append_null();
1460+
builder.append_value("large payload over 12 bytes");
1461+
builder.append_value("lulu");
1462+
builder.finish()
1463+
};
1464+
1465+
let actual = take(&array, &index, None).unwrap();
1466+
1467+
assert_eq!(actual.len(), index.len());
1468+
1469+
let expected = {
1470+
// ["large payload over 12 bytes", null, "world", "large payload over 12 bytes", "lulu", null]
1471+
let mut builder = GenericByteViewBuilder::<T>::new();
1472+
builder.append_value("large payload over 12 bytes");
1473+
builder.append_null();
1474+
builder.append_value("world");
1475+
builder.append_value("large payload over 12 bytes");
1476+
builder.append_value("lulu");
1477+
builder.append_null();
1478+
builder.finish()
1479+
};
1480+
1481+
assert_eq!(actual.as_ref(), &expected);
1482+
}
1483+
1484+
#[test]
1485+
fn test_take_string_view() {
1486+
_test_byte_view::<StringViewType>()
1487+
}
1488+
1489+
#[test]
1490+
fn test_take_binary_view() {
1491+
_test_byte_view::<BinaryViewType>()
1492+
}
1493+
14271494
macro_rules! test_take_list {
14281495
($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
14291496
// Construct a value array, [[0,0,0], [-1,-2,-1], [], [2,3]]

0 commit comments

Comments
 (0)