Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: ConstantArray#take handles nullable indices #2631

Merged
merged 4 commits into from
Mar 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions fuzz/src/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,7 @@ pub fn sort_canonical_array(array: &dyn Array) -> VortexResult<ArrayRef> {
let mut opt_values = bool_array
.boolean_buffer()
.iter()
.zip(
bool_array
.validity_mask()
.vortex_expect("Failed to get logical validity")
.to_boolean_buffer()
.iter(),
)
.zip(bool_array.validity_mask()?.to_boolean_buffer().iter())
.map(|(b, v)| v.then_some(b))
.collect::<Vec<_>>();
opt_values.sort();
Expand Down
13 changes: 4 additions & 9 deletions vortex-array/src/arrays/constant/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ mod cast;
mod compare;
mod invert;
mod search_sorted;
mod take;

use num_traits::{CheckedMul, ToPrimitive};
use vortex_dtype::{NativePType, PType, match_each_native_ptype};
Expand Down Expand Up @@ -58,15 +59,15 @@ impl ComputeVTable for ConstantEncoding {
Some(self)
}

fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
fn sum_fn(&self) -> Option<&dyn SumFn<&dyn Array>> {
Some(self)
}

fn uncompressed_size_fn(&self) -> Option<&dyn UncompressedSizeFn<&dyn Array>> {
fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
Some(self)
}

fn sum_fn(&self) -> Option<&dyn SumFn<&dyn Array>> {
fn uncompressed_size_fn(&self) -> Option<&dyn UncompressedSizeFn<&dyn Array>> {
Some(self)
}
}
Expand All @@ -77,12 +78,6 @@ impl ScalarAtFn<&ConstantArray> for ConstantEncoding {
}
}

impl TakeFn<&ConstantArray> for ConstantEncoding {
fn take(&self, array: &ConstantArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
Ok(ConstantArray::new(array.scalar().clone(), indices.len()).into_array())
}
}

impl SliceFn<&ConstantArray> for ConstantEncoding {
fn slice(&self, array: &ConstantArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
Ok(ConstantArray::new(array.scalar().clone(), stop - start).into_array())
Expand Down
70 changes: 70 additions & 0 deletions vortex-array/src/arrays/constant/compute/take.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
use vortex_error::VortexResult;
use vortex_mask::{AllOr, Mask};
use vortex_scalar::Scalar;

use crate::arrays::{ConstantArray, ConstantEncoding};
use crate::builders::builder_with_capacity;
use crate::compute::TakeFn;
use crate::{Array, ArrayRef};

impl TakeFn<&ConstantArray> for ConstantEncoding {
fn take(&self, array: &ConstantArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
match indices.validity_mask()?.boolean_buffer() {
AllOr::All => {
Ok(ConstantArray::new(array.scalar().clone(), indices.len()).into_array())
}
AllOr::None => Ok(ConstantArray::new(
Scalar::null(array.dtype().clone()),
indices.len(),
)
.into_array()),
AllOr::Some(v) => {
let arr = ConstantArray::new(array.scalar().clone(), indices.len()).into_array();

if array.scalar().is_null() {
return Ok(arr);
}

let mut result_builder =
builder_with_capacity(&array.dtype().as_nullable(), indices.len());
result_builder.extend_from_array(&arr)?;
result_builder.set_validity(Mask::from_buffer(v.clone()));
Ok(result_builder.finish())
}
}
}
}

#[cfg(test)]
mod tests {
use vortex_buffer::buffer;
use vortex_mask::AllOr;

use crate::arrays::{ConstantArray, PrimitiveArray};
use crate::compute::take;
use crate::validity::Validity;
use crate::{Array, ToCanonical};

#[test]
fn take_nullable_indices() {
let array = ConstantArray::new(42, 10).to_array();
let taken = take(
&array,
&PrimitiveArray::new(
buffer![0, 5, 7],
Validity::from_iter(vec![false, true, false]),
)
.into_array(),
)
.unwrap();
let valid_indices: &[usize] = &[1usize];
assert_eq!(
taken.to_primitive().unwrap().as_slice::<i32>(),
&[42, 42, 42]
);
assert_eq!(
taken.validity_mask().unwrap().indices(),
AllOr::Some(valid_indices)
);
}
}
6 changes: 6 additions & 0 deletions vortex-array/src/builders/bool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::any::Any;
use arrow_buffer::BooleanBufferBuilder;
use vortex_dtype::{DType, Nullability};
use vortex_error::{VortexResult, vortex_bail};
use vortex_mask::Mask;

use crate::arrays::BoolArray;
use crate::builders::ArrayBuilder;
Expand Down Expand Up @@ -85,6 +86,11 @@ impl ArrayBuilder for BoolBuilder {
Ok(())
}

fn set_validity(&mut self, validity: Mask) {
self.nulls = LazyNullBufferBuilder::new(validity.len());
self.nulls.append_validity_mask(validity);
}

fn finish(&mut self) -> ArrayRef {
assert_eq!(
self.nulls.len(),
Expand Down
5 changes: 5 additions & 0 deletions vortex-array/src/builders/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::sync::Arc;

use vortex_dtype::{DType, ExtDType};
use vortex_error::{VortexResult, vortex_bail};
use vortex_mask::Mask;
use vortex_scalar::ExtScalar;

use crate::arrays::ExtensionArray;
Expand Down Expand Up @@ -82,6 +83,10 @@ impl ArrayBuilder for ExtensionBuilder {
array.storage().append_to_builder(self.storage.as_mut())
}

fn set_validity(&mut self, validity: Mask) {
self.storage.set_validity(validity);
}

fn finish(&mut self) -> ArrayRef {
let storage = self.storage.finish();
ExtensionArray::new(self.ext_dtype(), storage).into_array()
Expand Down
2 changes: 0 additions & 2 deletions vortex-array/src/builders/lazy_validity_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ impl LazyNullBufferBuilder {
.append_n(n, false);
}

#[allow(dead_code)]
#[inline]
pub fn append_null(&mut self) {
self.materialize_if_needed();
Expand All @@ -62,7 +61,6 @@ impl LazyNullBufferBuilder {
.append(false);
}

#[allow(dead_code)]
#[inline]
pub fn append(&mut self, not_null: bool) {
if not_null {
Expand Down
6 changes: 6 additions & 0 deletions vortex-array/src/builders/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::sync::Arc;
use vortex_dtype::Nullability::NonNullable;
use vortex_dtype::{DType, NativePType, Nullability};
use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
use vortex_mask::Mask;
use vortex_scalar::{BinaryNumericOperator, ListScalar};

use crate::arrays::{ConstantArray, ListArray, OffsetPType};
Expand Down Expand Up @@ -150,6 +151,11 @@ impl<O: OffsetPType> ArrayBuilder for ListBuilder<O> {
Ok(())
}

fn set_validity(&mut self, validity: Mask) {
self.nulls = LazyNullBufferBuilder::new(validity.len());
self.nulls.append_validity_mask(validity);
}

fn finish(&mut self) -> ArrayRef {
assert_eq!(
self.index_builder.len(),
Expand Down
4 changes: 4 additions & 0 deletions vortex-array/src/builders/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ pub use primitive::*;
pub use varbinview::*;
use vortex_dtype::{DType, match_each_native_ptype};
use vortex_error::{VortexResult, vortex_bail, vortex_err};
use vortex_mask::Mask;
use vortex_scalar::{
BinaryScalar, BoolScalar, ExtScalar, ListScalar, PrimitiveScalar, Scalar, ScalarValue,
StructScalar, Utf8Scalar,
Expand Down Expand Up @@ -57,6 +58,9 @@ pub trait ArrayBuilder: Send {
/// Extends the array with the provided array, canonicalizing if necessary.
fn extend_from_array(&mut self, array: &dyn Array) -> VortexResult<()>;

/// Override builders validity with the one provided
fn set_validity(&mut self, validity: Mask);

/// Constructs an Array from the builder components.
///
/// # Panics
Expand Down
5 changes: 5 additions & 0 deletions vortex-array/src/builders/null.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::any::Any;

use vortex_dtype::DType;
use vortex_error::VortexResult;
use vortex_mask::Mask;

use crate::arrays::NullArray;
use crate::builders::ArrayBuilder;
Expand Down Expand Up @@ -54,6 +55,10 @@ impl ArrayBuilder for NullBuilder {
Ok(())
}

fn set_validity(&mut self, validity: Mask) {
self.length = validity.len();
}

fn finish(&mut self) -> ArrayRef {
NullArray::new(self.length).into_array()
}
Expand Down
5 changes: 5 additions & 0 deletions vortex-array/src/builders/primitive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,11 @@ impl<T: NativePType> ArrayBuilder for PrimitiveBuilder<T> {
Ok(())
}

fn set_validity(&mut self, validity: Mask) {
self.nulls = LazyNullBufferBuilder::new(validity.len());
self.nulls.append_validity_mask(validity);
}

fn finish(&mut self) -> ArrayRef {
self.finish_into_primitive().into_array()
}
Expand Down
39 changes: 15 additions & 24 deletions vortex-array/src/builders/struct_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,18 @@ use std::sync::Arc;
use itertools::Itertools;
use vortex_dtype::{DType, Nullability, StructDType};
use vortex_error::{VortexExpect, VortexResult, vortex_bail};
use vortex_mask::Mask;
use vortex_scalar::StructScalar;

use crate::arrays::StructArray;
use crate::builders::{ArrayBuilder, ArrayBuilderExt, BoolBuilder, builder_with_capacity};
use crate::validity::Validity;
use crate::builders::lazy_validity_builder::LazyNullBufferBuilder;
use crate::builders::{ArrayBuilder, ArrayBuilderExt, builder_with_capacity};
use crate::variants::StructArrayTrait;
use crate::{Array, ArrayRef, Canonical};

pub struct StructBuilder {
builders: Vec<Box<dyn ArrayBuilder>>,
// TODO(ngates): this should be a NullBufferBuilder? Or mask builder?
validity: BoolBuilder,
validity: LazyNullBufferBuilder,
struct_dtype: Arc<StructDType>,
nullability: Nullability,
dtype: DType,
Expand All @@ -34,7 +34,7 @@ impl StructBuilder {

Self {
builders,
validity: BoolBuilder::with_capacity(Nullability::NonNullable, capacity),
validity: LazyNullBufferBuilder::new(capacity),
struct_dtype: struct_dtype.clone(),
nullability,
dtype: DType::Struct(struct_dtype, nullability),
Expand All @@ -54,7 +54,7 @@ impl StructBuilder {
for (builder, field) in self.builders.iter_mut().zip_eq(fields) {
builder.append_scalar(&field)?;
}
self.validity.append_value(true);
self.validity.append_non_null();
} else {
self.append_null()
}
Expand Down Expand Up @@ -84,7 +84,7 @@ impl ArrayBuilder for StructBuilder {
self.builders
.iter_mut()
.for_each(|builder| builder.append_zeros(n));
self.validity.append_values(true, n);
self.validity.append_n_non_nulls(n);
}

fn append_nulls(&mut self, n: usize) {
Expand All @@ -93,7 +93,7 @@ impl ArrayBuilder for StructBuilder {
// We push zero values into our children when appending a null in case the children are
// themselves non-nullable.
.for_each(|builder| builder.append_zeros(n));
self.validity.append_value(false);
self.validity.append_null();
}

fn extend_from_array(&mut self, array: &dyn Array) -> VortexResult<()> {
Expand All @@ -120,21 +120,15 @@ impl ArrayBuilder for StructBuilder {
a.append_to_builder(builder.as_mut())?;
}

match array.validity() {
Validity::NonNullable | Validity::AllValid => {
self.validity.append_values(true, array.len());
}
Validity::AllInvalid => {
self.validity.append_values(false, array.len());
}
Validity::Array(validity) => {
validity.append_to_builder(&mut self.validity)?;
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also this builder here was wrong, not sure how a bug like this could happen...

}
}

self.validity.append_validity_mask(array.validity_mask()?);
Ok(())
}

fn set_validity(&mut self, validity: Mask) {
self.validity = LazyNullBufferBuilder::new(validity.len());
self.validity.append_validity_mask(validity);
}

fn finish(&mut self) -> ArrayRef {
let len = self.len();
let fields = self
Expand All @@ -156,10 +150,7 @@ impl ArrayBuilder for StructBuilder {
}
}

let validity = match self.nullability {
Nullability::NonNullable => Validity::NonNullable,
Nullability::Nullable => Validity::Array(self.validity.finish()),
};
let validity = self.validity.finish_with_nullability(self.nullability);

StructArray::try_new(self.struct_dtype.names().clone(), fields, len, validity)
.vortex_expect("Fields must all have same length.")
Expand Down
5 changes: 5 additions & 0 deletions vortex-array/src/builders/varbinview.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,11 @@ impl ArrayBuilder for VarBinViewBuilder {
Ok(())
}

fn set_validity(&mut self, validity: Mask) {
self.null_buffer_builder = LazyNullBufferBuilder::new(validity.len());
self.null_buffer_builder.append_validity_mask(validity);
}

fn finish(&mut self) -> ArrayRef {
self.flush_in_progress();
let buffers = std::mem::take(&mut self.completed);
Expand Down
2 changes: 1 addition & 1 deletion vortex-array/src/validity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ impl Validity {
Self::AllInvalid => Ok(Self::AllInvalid),
Self::Array(is_valid) => {
let maybe_is_valid = take(is_valid, indices)?;
// Null indices invalidite that position.
// Null indices invalidate that position.
let is_valid = fill_null(&maybe_is_valid, Scalar::from(false))?;
Ok(Self::Array(is_valid))
}
Expand Down
Loading