diff --git a/fuzz/src/sort.rs b/fuzz/src/sort.rs
index 4525cea6e4..e1daca5d34 100644
--- a/fuzz/src/sort.rs
+++ b/fuzz/src/sort.rs
@@ -16,13 +16,7 @@ pub fn sort_canonical_array(array: &dyn Array) -> VortexResult<ArrayRef> {
             let mut opt_values = bool_array
                 .boolean_buffer()
                 .iter()
-                .zip(
-                    bool_array
-                        .validity_mask()
-                        .vortex_expect("Failed to get logical validity")
-                        .to_boolean_buffer()
-                        .iter(),
-                )
+                .zip(bool_array.validity_mask()?.to_boolean_buffer().iter())
                 .map(|(b, v)| v.then_some(b))
                 .collect::<Vec<_>>();
             opt_values.sort();
diff --git a/vortex-array/src/arrays/constant/compute/mod.rs b/vortex-array/src/arrays/constant/compute/mod.rs
index db84ac6939..c96c9d0914 100644
--- a/vortex-array/src/arrays/constant/compute/mod.rs
+++ b/vortex-array/src/arrays/constant/compute/mod.rs
@@ -4,6 +4,7 @@ mod cast;
 mod compare;
 mod invert;
 mod search_sorted;
+mod take;
 
 use num_traits::{CheckedMul, ToPrimitive};
 use vortex_dtype::{NativePType, PType, match_each_native_ptype};
@@ -58,15 +59,15 @@ impl ComputeVTable for ConstantEncoding {
         Some(self)
     }
 
-    fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
+    fn sum_fn(&self) -> Option<&dyn SumFn<&dyn Array>> {
         Some(self)
     }
 
-    fn uncompressed_size_fn(&self) -> Option<&dyn UncompressedSizeFn<&dyn Array>> {
+    fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
         Some(self)
     }
 
-    fn sum_fn(&self) -> Option<&dyn SumFn<&dyn Array>> {
+    fn uncompressed_size_fn(&self) -> Option<&dyn UncompressedSizeFn<&dyn Array>> {
         Some(self)
     }
 }
@@ -77,12 +78,6 @@ impl ScalarAtFn<&ConstantArray> for ConstantEncoding {
     }
 }
 
-impl TakeFn<&ConstantArray> for ConstantEncoding {
-    fn take(&self, array: &ConstantArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
-        Ok(ConstantArray::new(array.scalar().clone(), indices.len()).into_array())
-    }
-}
-
 impl SliceFn<&ConstantArray> for ConstantEncoding {
     fn slice(&self, array: &ConstantArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
         Ok(ConstantArray::new(array.scalar().clone(), stop - start).into_array())
diff --git a/vortex-array/src/arrays/constant/compute/take.rs b/vortex-array/src/arrays/constant/compute/take.rs
new file mode 100644
index 0000000000..f67f105cfd
--- /dev/null
+++ b/vortex-array/src/arrays/constant/compute/take.rs
@@ -0,0 +1,76 @@
+use vortex_error::VortexResult;
+use vortex_mask::{AllOr, Mask};
+use vortex_scalar::Scalar;
+
+use crate::arrays::{ConstantArray, ConstantEncoding};
+use crate::builders::builder_with_capacity;
+use crate::compute::TakeFn;
+use crate::{Array, ArrayRef};
+
+impl TakeFn<&ConstantArray> for ConstantEncoding {
+    /// Takes `indices` from a constant array.
+    ///
+    /// Valid indices yield the constant scalar; null indices yield nulls, so a partially
+    /// null `indices` array produces a result with a nullable dtype.
+    fn take(&self, array: &ConstantArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
+        match indices.validity_mask()?.boolean_buffer() {
+            AllOr::All => {
+                Ok(ConstantArray::new(array.scalar().clone(), indices.len()).into_array())
+            }
+            // Every index is null, so the result is all-null. Use the nullable dtype (as the
+            // `Some` branch does): the constant itself may be non-nullable.
+            AllOr::None => Ok(ConstantArray::new(
+                Scalar::null(array.dtype().as_nullable()),
+                indices.len(),
+            )
+            .into_array()),
+            AllOr::Some(v) => {
+                let arr = ConstantArray::new(array.scalar().clone(), indices.len()).into_array();
+
+                if array.scalar().is_null() {
+                    return Ok(arr);
+                }
+
+                let mut result_builder =
+                    builder_with_capacity(&array.dtype().as_nullable(), indices.len());
+                result_builder.extend_from_array(&arr)?;
+                result_builder.set_validity(Mask::from_buffer(v.clone()));
+                Ok(result_builder.finish())
+            }
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use vortex_buffer::buffer;
+    use vortex_mask::AllOr;
+
+    use crate::arrays::{ConstantArray, PrimitiveArray};
+    use crate::compute::take;
+    use crate::validity::Validity;
+    use crate::{Array, ToCanonical};
+
+    #[test]
+    fn take_nullable_indices() {
+        let array = ConstantArray::new(42, 10).to_array();
+        let taken = take(
+            &array,
+            &PrimitiveArray::new(
+                buffer![0, 5, 7],
+                Validity::from_iter(vec![false, true, false]),
+            )
+            .into_array(),
+        )
+        .unwrap();
+        let valid_indices: &[usize] = &[1usize];
+        assert_eq!(
+            taken.to_primitive().unwrap().as_slice::<i32>(),
+            &[42, 42, 42]
+        );
+        assert_eq!(
+            taken.validity_mask().unwrap().indices(),
+            AllOr::Some(valid_indices)
+        );
+    }
+}
diff --git a/vortex-array/src/builders/bool.rs b/vortex-array/src/builders/bool.rs
index a858ea3687..38f211eed9 100644
--- a/vortex-array/src/builders/bool.rs
+++ b/vortex-array/src/builders/bool.rs
@@ -3,6 +3,7 @@ use std::any::Any;
 use arrow_buffer::BooleanBufferBuilder;
 use vortex_dtype::{DType, Nullability};
 use vortex_error::{VortexResult, vortex_bail};
+use vortex_mask::Mask;
 
 use crate::arrays::BoolArray;
 use crate::builders::ArrayBuilder;
@@ -85,6 +86,11 @@ impl ArrayBuilder for BoolBuilder {
         Ok(())
     }
 
+    fn set_validity(&mut self, validity: Mask) {
+        self.nulls = LazyNullBufferBuilder::new(validity.len());
+        self.nulls.append_validity_mask(validity);
+    }
+
     fn finish(&mut self) -> ArrayRef {
         assert_eq!(
             self.nulls.len(),
diff --git a/vortex-array/src/builders/extension.rs b/vortex-array/src/builders/extension.rs
index dfc5fdbbeb..77522370d4 100644
--- a/vortex-array/src/builders/extension.rs
+++ b/vortex-array/src/builders/extension.rs
@@ -3,6 +3,7 @@ use std::sync::Arc;
 
 use vortex_dtype::{DType, ExtDType};
 use vortex_error::{VortexResult, vortex_bail};
+use vortex_mask::Mask;
 use vortex_scalar::ExtScalar;
 
 use crate::arrays::ExtensionArray;
@@ -82,6 +83,10 @@ impl ArrayBuilder for ExtensionBuilder {
         array.storage().append_to_builder(self.storage.as_mut())
     }
 
+    fn set_validity(&mut self, validity: Mask) {
+        self.storage.set_validity(validity);
+    }
+
     fn finish(&mut self) -> ArrayRef {
         let storage = self.storage.finish();
         ExtensionArray::new(self.ext_dtype(), storage).into_array()
diff --git a/vortex-array/src/builders/lazy_validity_builder.rs b/vortex-array/src/builders/lazy_validity_builder.rs
index 90ce193dde..38b35936cd 100644
--- a/vortex-array/src/builders/lazy_validity_builder.rs
+++ b/vortex-array/src/builders/lazy_validity_builder.rs
@@ -52,7 +52,6 @@ impl LazyNullBufferBuilder {
             .append_n(n, false);
     }
 
-    #[allow(dead_code)]
     #[inline]
     pub fn append_null(&mut self) {
         self.materialize_if_needed();
@@ -62,7 +61,6 @@ impl LazyNullBufferBuilder {
             .append(false);
     }
 
-    #[allow(dead_code)]
     #[inline]
     pub fn append(&mut self, not_null: bool) {
         if not_null {
diff --git a/vortex-array/src/builders/list.rs b/vortex-array/src/builders/list.rs
index 0ca9d1b485..733c90f486 100644
--- a/vortex-array/src/builders/list.rs
+++ b/vortex-array/src/builders/list.rs
@@ -4,6 +4,7 @@ use std::sync::Arc;
 use vortex_dtype::Nullability::NonNullable;
 use vortex_dtype::{DType, NativePType, Nullability};
 use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
+use vortex_mask::Mask;
 use vortex_scalar::{BinaryNumericOperator, ListScalar};
 
 use crate::arrays::{ConstantArray, ListArray, OffsetPType};
@@ -150,6 +151,11 @@ impl ArrayBuilder for ListBuilder {
         Ok(())
     }
 
+    fn set_validity(&mut self, validity: Mask) {
+        self.nulls = LazyNullBufferBuilder::new(validity.len());
+        self.nulls.append_validity_mask(validity);
+    }
+
     fn finish(&mut self) -> ArrayRef {
         assert_eq!(
             self.index_builder.len(),
diff --git a/vortex-array/src/builders/mod.rs b/vortex-array/src/builders/mod.rs
index c01199790c..c12a810aa2 100644
--- a/vortex-array/src/builders/mod.rs
+++ b/vortex-array/src/builders/mod.rs
@@ -17,6 +17,7 @@ pub use primitive::*;
 pub use varbinview::*;
 use vortex_dtype::{DType, match_each_native_ptype};
 use vortex_error::{VortexResult, vortex_bail, vortex_err};
+use vortex_mask::Mask;
 use vortex_scalar::{
     BinaryScalar, BoolScalar, ExtScalar, ListScalar, PrimitiveScalar, Scalar, ScalarValue,
     StructScalar, Utf8Scalar,
@@ -57,6 +58,9 @@ pub trait ArrayBuilder: Send {
     /// Extends the array with the provided array, canonicalizing if necessary.
     fn extend_from_array(&mut self, array: &dyn Array) -> VortexResult<()>;
 
+    /// Replaces the builder's validity with `validity`, discarding any validity appended so far.
+    fn set_validity(&mut self, validity: Mask);
+
     /// Constructs an Array from the builder components.
     ///
     /// # Panics
diff --git a/vortex-array/src/builders/null.rs b/vortex-array/src/builders/null.rs
index eeeec7ced2..10e054db97 100644
--- a/vortex-array/src/builders/null.rs
+++ b/vortex-array/src/builders/null.rs
@@ -2,6 +2,7 @@ use std::any::Any;
 
 use vortex_dtype::DType;
 use vortex_error::VortexResult;
+use vortex_mask::Mask;
 
 use crate::arrays::NullArray;
 use crate::builders::ArrayBuilder;
@@ -54,6 +55,10 @@ impl ArrayBuilder for NullBuilder {
         Ok(())
     }
 
+    fn set_validity(&mut self, validity: Mask) {
+        self.length = validity.len();
+    }
+
     fn finish(&mut self) -> ArrayRef {
         NullArray::new(self.length).into_array()
     }
diff --git a/vortex-array/src/builders/primitive.rs b/vortex-array/src/builders/primitive.rs
index a42cb7d92c..909ae6b8b3 100644
--- a/vortex-array/src/builders/primitive.rs
+++ b/vortex-array/src/builders/primitive.rs
@@ -178,6 +178,11 @@ impl ArrayBuilder for PrimitiveBuilder {
         Ok(())
     }
 
+    fn set_validity(&mut self, validity: Mask) {
+        self.nulls = LazyNullBufferBuilder::new(validity.len());
+        self.nulls.append_validity_mask(validity);
+    }
+
     fn finish(&mut self) -> ArrayRef {
         self.finish_into_primitive().into_array()
     }
diff --git a/vortex-array/src/builders/struct_.rs b/vortex-array/src/builders/struct_.rs
index 7183b99517..ce77676fcb 100644
--- a/vortex-array/src/builders/struct_.rs
+++ b/vortex-array/src/builders/struct_.rs
@@ -4,18 +4,18 @@ use std::sync::Arc;
 use itertools::Itertools;
 use vortex_dtype::{DType, Nullability, StructDType};
 use vortex_error::{VortexExpect, VortexResult, vortex_bail};
+use vortex_mask::Mask;
 use vortex_scalar::StructScalar;
 
 use crate::arrays::StructArray;
-use crate::builders::{ArrayBuilder, ArrayBuilderExt, BoolBuilder, builder_with_capacity};
-use crate::validity::Validity;
+use crate::builders::lazy_validity_builder::LazyNullBufferBuilder;
+use crate::builders::{ArrayBuilder, ArrayBuilderExt, builder_with_capacity};
 use crate::variants::StructArrayTrait;
 use crate::{Array, ArrayRef, Canonical};
 
 pub struct StructBuilder {
     builders: Vec<Box<dyn ArrayBuilder>>,
-    // TODO(ngates): this should be a NullBufferBuilder? Or mask builder?
-    validity: BoolBuilder,
+    validity: LazyNullBufferBuilder,
     struct_dtype: Arc<StructDType>,
     nullability: Nullability,
     dtype: DType,
@@ -34,7 +34,7 @@ impl StructBuilder {
 
         Self {
             builders,
-            validity: BoolBuilder::with_capacity(Nullability::NonNullable, capacity),
+            validity: LazyNullBufferBuilder::new(capacity),
             struct_dtype: struct_dtype.clone(),
             nullability,
             dtype: DType::Struct(struct_dtype, nullability),
@@ -54,7 +54,7 @@ impl StructBuilder {
             for (builder, field) in self.builders.iter_mut().zip_eq(fields) {
                 builder.append_scalar(&field)?;
             }
-            self.validity.append_value(true);
+            self.validity.append_non_null();
         } else {
             self.append_null()
         }
@@ -84,7 +84,7 @@ impl ArrayBuilder for StructBuilder {
         self.builders
             .iter_mut()
             .for_each(|builder| builder.append_zeros(n));
-        self.validity.append_values(true, n);
+        self.validity.append_n_non_nulls(n);
     }
 
     fn append_nulls(&mut self, n: usize) {
@@ -93,7 +93,7 @@ impl ArrayBuilder for StructBuilder {
             // We push zero values into our children when appending a null in case the children are
             // themselves non-nullable.
             .for_each(|builder| builder.append_zeros(n));
-        self.validity.append_value(false);
+        self.validity.append_n_nulls(n);
     }
 
     fn extend_from_array(&mut self, array: &dyn Array) -> VortexResult<()> {
@@ -120,21 +120,15 @@ impl ArrayBuilder for StructBuilder {
             a.append_to_builder(builder.as_mut())?;
         }
 
-        match array.validity() {
-            Validity::NonNullable | Validity::AllValid => {
-                self.validity.append_values(true, array.len());
-            }
-            Validity::AllInvalid => {
-                self.validity.append_values(false, array.len());
-            }
-            Validity::Array(validity) => {
-                validity.append_to_builder(&mut self.validity)?;
-            }
-        }
-
+        self.validity.append_validity_mask(array.validity_mask()?);
         Ok(())
     }
 
+    fn set_validity(&mut self, validity: Mask) {
+        self.validity = LazyNullBufferBuilder::new(validity.len());
+        self.validity.append_validity_mask(validity);
+    }
+
     fn finish(&mut self) -> ArrayRef {
         let len = self.len();
         let fields = self
@@ -156,10 +150,7 @@ impl ArrayBuilder for StructBuilder {
             }
         }
 
-        let validity = match self.nullability {
-            Nullability::NonNullable => Validity::NonNullable,
-            Nullability::Nullable => Validity::Array(self.validity.finish()),
-        };
+        let validity = self.validity.finish_with_nullability(self.nullability);
 
         StructArray::try_new(self.struct_dtype.names().clone(), fields, len, validity)
             .vortex_expect("Fields must all have same length.")
diff --git a/vortex-array/src/builders/varbinview.rs b/vortex-array/src/builders/varbinview.rs
index 8b8f4f718c..5a062622b8 100644
--- a/vortex-array/src/builders/varbinview.rs
+++ b/vortex-array/src/builders/varbinview.rs
@@ -177,6 +177,11 @@ impl ArrayBuilder for VarBinViewBuilder {
         Ok(())
     }
 
+    fn set_validity(&mut self, validity: Mask) {
+        self.null_buffer_builder = LazyNullBufferBuilder::new(validity.len());
+        self.null_buffer_builder.append_validity_mask(validity);
+    }
+
     fn finish(&mut self) -> ArrayRef {
         self.flush_in_progress();
         let buffers = std::mem::take(&mut self.completed);
diff --git a/vortex-array/src/validity.rs b/vortex-array/src/validity.rs
index 5ce9d27a18..c12275260c 100644
--- a/vortex-array/src/validity.rs
+++ b/vortex-array/src/validity.rs
@@ -147,7 +147,7 @@ impl Validity {
             Self::AllInvalid => Ok(Self::AllInvalid),
             Self::Array(is_valid) => {
                 let maybe_is_valid = take(is_valid, indices)?;
-                // Null indices invalidite that position.
+                // Null indices invalidate that position.
                 let is_valid = fill_null(&maybe_is_valid, Scalar::from(false))?;
                 Ok(Self::Array(is_valid))
             }