Skip to content

Commit 1ab0455

Browse files
authored
fix: ConstantArray#take handles nullable indices (#2631)
1 parent 153b26f commit 1ab0455

File tree

13 files changed

+127
-43
lines changed

13 files changed

+127
-43
lines changed

fuzz/src/sort.rs

+1-7
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,7 @@ pub fn sort_canonical_array(array: &dyn Array) -> VortexResult<ArrayRef> {
1616
let mut opt_values = bool_array
1717
.boolean_buffer()
1818
.iter()
19-
.zip(
20-
bool_array
21-
.validity_mask()
22-
.vortex_expect("Failed to get logical validity")
23-
.to_boolean_buffer()
24-
.iter(),
25-
)
19+
.zip(bool_array.validity_mask()?.to_boolean_buffer().iter())
2620
.map(|(b, v)| v.then_some(b))
2721
.collect::<Vec<_>>();
2822
opt_values.sort();

vortex-array/src/arrays/constant/compute/mod.rs

+4-9
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ mod cast;
44
mod compare;
55
mod invert;
66
mod search_sorted;
7+
mod take;
78

89
use num_traits::{CheckedMul, ToPrimitive};
910
use vortex_dtype::{NativePType, PType, match_each_native_ptype};
@@ -58,15 +59,15 @@ impl ComputeVTable for ConstantEncoding {
5859
Some(self)
5960
}
6061

61-
fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
62+
fn sum_fn(&self) -> Option<&dyn SumFn<&dyn Array>> {
6263
Some(self)
6364
}
6465

65-
fn uncompressed_size_fn(&self) -> Option<&dyn UncompressedSizeFn<&dyn Array>> {
66+
fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
6667
Some(self)
6768
}
6869

69-
fn sum_fn(&self) -> Option<&dyn SumFn<&dyn Array>> {
70+
fn uncompressed_size_fn(&self) -> Option<&dyn UncompressedSizeFn<&dyn Array>> {
7071
Some(self)
7172
}
7273
}
@@ -77,12 +78,6 @@ impl ScalarAtFn<&ConstantArray> for ConstantEncoding {
7778
}
7879
}
7980

80-
impl TakeFn<&ConstantArray> for ConstantEncoding {
81-
fn take(&self, array: &ConstantArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
82-
Ok(ConstantArray::new(array.scalar().clone(), indices.len()).into_array())
83-
}
84-
}
85-
8681
impl SliceFn<&ConstantArray> for ConstantEncoding {
8782
fn slice(&self, array: &ConstantArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
8883
Ok(ConstantArray::new(array.scalar().clone(), stop - start).into_array())
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
use vortex_error::VortexResult;
2+
use vortex_mask::{AllOr, Mask};
3+
use vortex_scalar::Scalar;
4+
5+
use crate::arrays::{ConstantArray, ConstantEncoding};
6+
use crate::builders::builder_with_capacity;
7+
use crate::compute::TakeFn;
8+
use crate::{Array, ArrayRef};
9+
10+
impl TakeFn<&ConstantArray> for ConstantEncoding {
11+
fn take(&self, array: &ConstantArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
12+
match indices.validity_mask()?.boolean_buffer() {
13+
AllOr::All => {
14+
Ok(ConstantArray::new(array.scalar().clone(), indices.len()).into_array())
15+
}
16+
AllOr::None => Ok(ConstantArray::new(
17+
Scalar::null(array.dtype().clone()),
18+
indices.len(),
19+
)
20+
.into_array()),
21+
AllOr::Some(v) => {
22+
let arr = ConstantArray::new(array.scalar().clone(), indices.len()).into_array();
23+
24+
if array.scalar().is_null() {
25+
return Ok(arr);
26+
}
27+
28+
let mut result_builder =
29+
builder_with_capacity(&array.dtype().as_nullable(), indices.len());
30+
result_builder.extend_from_array(&arr)?;
31+
result_builder.set_validity(Mask::from_buffer(v.clone()));
32+
Ok(result_builder.finish())
33+
}
34+
}
35+
}
36+
}
37+
38+
#[cfg(test)]
39+
mod tests {
40+
use vortex_buffer::buffer;
41+
use vortex_mask::AllOr;
42+
43+
use crate::arrays::{ConstantArray, PrimitiveArray};
44+
use crate::compute::take;
45+
use crate::validity::Validity;
46+
use crate::{Array, ToCanonical};
47+
48+
#[test]
49+
fn take_nullable_indices() {
50+
let array = ConstantArray::new(42, 10).to_array();
51+
let taken = take(
52+
&array,
53+
&PrimitiveArray::new(
54+
buffer![0, 5, 7],
55+
Validity::from_iter(vec![false, true, false]),
56+
)
57+
.into_array(),
58+
)
59+
.unwrap();
60+
let valid_indices: &[usize] = &[1usize];
61+
assert_eq!(
62+
taken.to_primitive().unwrap().as_slice::<i32>(),
63+
&[42, 42, 42]
64+
);
65+
assert_eq!(
66+
taken.validity_mask().unwrap().indices(),
67+
AllOr::Some(valid_indices)
68+
);
69+
}
70+
}

vortex-array/src/builders/bool.rs

+6
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::any::Any;
33
use arrow_buffer::BooleanBufferBuilder;
44
use vortex_dtype::{DType, Nullability};
55
use vortex_error::{VortexResult, vortex_bail};
6+
use vortex_mask::Mask;
67

78
use crate::arrays::BoolArray;
89
use crate::builders::ArrayBuilder;
@@ -85,6 +86,11 @@ impl ArrayBuilder for BoolBuilder {
8586
Ok(())
8687
}
8788

89+
fn set_validity(&mut self, validity: Mask) {
90+
self.nulls = LazyNullBufferBuilder::new(validity.len());
91+
self.nulls.append_validity_mask(validity);
92+
}
93+
8894
fn finish(&mut self) -> ArrayRef {
8995
assert_eq!(
9096
self.nulls.len(),

vortex-array/src/builders/extension.rs

+5
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::sync::Arc;
33

44
use vortex_dtype::{DType, ExtDType};
55
use vortex_error::{VortexResult, vortex_bail};
6+
use vortex_mask::Mask;
67
use vortex_scalar::ExtScalar;
78

89
use crate::arrays::ExtensionArray;
@@ -82,6 +83,10 @@ impl ArrayBuilder for ExtensionBuilder {
8283
array.storage().append_to_builder(self.storage.as_mut())
8384
}
8485

86+
fn set_validity(&mut self, validity: Mask) {
87+
self.storage.set_validity(validity);
88+
}
89+
8590
fn finish(&mut self) -> ArrayRef {
8691
let storage = self.storage.finish();
8792
ExtensionArray::new(self.ext_dtype(), storage).into_array()

vortex-array/src/builders/lazy_validity_builder.rs

-2
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ impl LazyNullBufferBuilder {
5252
.append_n(n, false);
5353
}
5454

55-
#[allow(dead_code)]
5655
#[inline]
5756
pub fn append_null(&mut self) {
5857
self.materialize_if_needed();
@@ -62,7 +61,6 @@ impl LazyNullBufferBuilder {
6261
.append(false);
6362
}
6463

65-
#[allow(dead_code)]
6664
#[inline]
6765
pub fn append(&mut self, not_null: bool) {
6866
if not_null {

vortex-array/src/builders/list.rs

+6
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use std::sync::Arc;
44
use vortex_dtype::Nullability::NonNullable;
55
use vortex_dtype::{DType, NativePType, Nullability};
66
use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
7+
use vortex_mask::Mask;
78
use vortex_scalar::{BinaryNumericOperator, ListScalar};
89

910
use crate::arrays::{ConstantArray, ListArray, OffsetPType};
@@ -150,6 +151,11 @@ impl<O: OffsetPType> ArrayBuilder for ListBuilder<O> {
150151
Ok(())
151152
}
152153

154+
fn set_validity(&mut self, validity: Mask) {
155+
self.nulls = LazyNullBufferBuilder::new(validity.len());
156+
self.nulls.append_validity_mask(validity);
157+
}
158+
153159
fn finish(&mut self) -> ArrayRef {
154160
assert_eq!(
155161
self.index_builder.len(),

vortex-array/src/builders/mod.rs

+4
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ pub use primitive::*;
1717
pub use varbinview::*;
1818
use vortex_dtype::{DType, match_each_native_ptype};
1919
use vortex_error::{VortexResult, vortex_bail, vortex_err};
20+
use vortex_mask::Mask;
2021
use vortex_scalar::{
2122
BinaryScalar, BoolScalar, ExtScalar, ListScalar, PrimitiveScalar, Scalar, ScalarValue,
2223
StructScalar, Utf8Scalar,
@@ -57,6 +58,9 @@ pub trait ArrayBuilder: Send {
5758
/// Extends the array with the provided array, canonicalizing if necessary.
5859
fn extend_from_array(&mut self, array: &dyn Array) -> VortexResult<()>;
5960

61+
/// Override builders validity with the one provided
62+
fn set_validity(&mut self, validity: Mask);
63+
6064
/// Constructs an Array from the builder components.
6165
///
6266
/// # Panics

vortex-array/src/builders/null.rs

+5
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use std::any::Any;
22

33
use vortex_dtype::DType;
44
use vortex_error::VortexResult;
5+
use vortex_mask::Mask;
56

67
use crate::arrays::NullArray;
78
use crate::builders::ArrayBuilder;
@@ -54,6 +55,10 @@ impl ArrayBuilder for NullBuilder {
5455
Ok(())
5556
}
5657

58+
fn set_validity(&mut self, validity: Mask) {
59+
self.length = validity.len();
60+
}
61+
5762
fn finish(&mut self) -> ArrayRef {
5863
NullArray::new(self.length).into_array()
5964
}

vortex-array/src/builders/primitive.rs

+5
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,11 @@ impl<T: NativePType> ArrayBuilder for PrimitiveBuilder<T> {
178178
Ok(())
179179
}
180180

181+
fn set_validity(&mut self, validity: Mask) {
182+
self.nulls = LazyNullBufferBuilder::new(validity.len());
183+
self.nulls.append_validity_mask(validity);
184+
}
185+
181186
fn finish(&mut self) -> ArrayRef {
182187
self.finish_into_primitive().into_array()
183188
}

vortex-array/src/builders/struct_.rs

+15-24
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,18 @@ use std::sync::Arc;
44
use itertools::Itertools;
55
use vortex_dtype::{DType, Nullability, StructDType};
66
use vortex_error::{VortexExpect, VortexResult, vortex_bail};
7+
use vortex_mask::Mask;
78
use vortex_scalar::StructScalar;
89

910
use crate::arrays::StructArray;
10-
use crate::builders::{ArrayBuilder, ArrayBuilderExt, BoolBuilder, builder_with_capacity};
11-
use crate::validity::Validity;
11+
use crate::builders::lazy_validity_builder::LazyNullBufferBuilder;
12+
use crate::builders::{ArrayBuilder, ArrayBuilderExt, builder_with_capacity};
1213
use crate::variants::StructArrayTrait;
1314
use crate::{Array, ArrayRef, Canonical};
1415

1516
pub struct StructBuilder {
1617
builders: Vec<Box<dyn ArrayBuilder>>,
17-
// TODO(ngates): this should be a NullBufferBuilder? Or mask builder?
18-
validity: BoolBuilder,
18+
validity: LazyNullBufferBuilder,
1919
struct_dtype: Arc<StructDType>,
2020
nullability: Nullability,
2121
dtype: DType,
@@ -34,7 +34,7 @@ impl StructBuilder {
3434

3535
Self {
3636
builders,
37-
validity: BoolBuilder::with_capacity(Nullability::NonNullable, capacity),
37+
validity: LazyNullBufferBuilder::new(capacity),
3838
struct_dtype: struct_dtype.clone(),
3939
nullability,
4040
dtype: DType::Struct(struct_dtype, nullability),
@@ -54,7 +54,7 @@ impl StructBuilder {
5454
for (builder, field) in self.builders.iter_mut().zip_eq(fields) {
5555
builder.append_scalar(&field)?;
5656
}
57-
self.validity.append_value(true);
57+
self.validity.append_non_null();
5858
} else {
5959
self.append_null()
6060
}
@@ -84,7 +84,7 @@ impl ArrayBuilder for StructBuilder {
8484
self.builders
8585
.iter_mut()
8686
.for_each(|builder| builder.append_zeros(n));
87-
self.validity.append_values(true, n);
87+
self.validity.append_n_non_nulls(n);
8888
}
8989

9090
fn append_nulls(&mut self, n: usize) {
@@ -93,7 +93,7 @@ impl ArrayBuilder for StructBuilder {
9393
// We push zero values into our children when appending a null in case the children are
9494
// themselves non-nullable.
9595
.for_each(|builder| builder.append_zeros(n));
96-
self.validity.append_value(false);
96+
self.validity.append_null();
9797
}
9898

9999
fn extend_from_array(&mut self, array: &dyn Array) -> VortexResult<()> {
@@ -120,21 +120,15 @@ impl ArrayBuilder for StructBuilder {
120120
a.append_to_builder(builder.as_mut())?;
121121
}
122122

123-
match array.validity() {
124-
Validity::NonNullable | Validity::AllValid => {
125-
self.validity.append_values(true, array.len());
126-
}
127-
Validity::AllInvalid => {
128-
self.validity.append_values(false, array.len());
129-
}
130-
Validity::Array(validity) => {
131-
validity.append_to_builder(&mut self.validity)?;
132-
}
133-
}
134-
123+
self.validity.append_validity_mask(array.validity_mask()?);
135124
Ok(())
136125
}
137126

127+
fn set_validity(&mut self, validity: Mask) {
128+
self.validity = LazyNullBufferBuilder::new(validity.len());
129+
self.validity.append_validity_mask(validity);
130+
}
131+
138132
fn finish(&mut self) -> ArrayRef {
139133
let len = self.len();
140134
let fields = self
@@ -156,10 +150,7 @@ impl ArrayBuilder for StructBuilder {
156150
}
157151
}
158152

159-
let validity = match self.nullability {
160-
Nullability::NonNullable => Validity::NonNullable,
161-
Nullability::Nullable => Validity::Array(self.validity.finish()),
162-
};
153+
let validity = self.validity.finish_with_nullability(self.nullability);
163154

164155
StructArray::try_new(self.struct_dtype.names().clone(), fields, len, validity)
165156
.vortex_expect("Fields must all have same length.")

vortex-array/src/builders/varbinview.rs

+5
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,11 @@ impl ArrayBuilder for VarBinViewBuilder {
174174
Ok(())
175175
}
176176

177+
fn set_validity(&mut self, validity: Mask) {
178+
self.null_buffer_builder = LazyNullBufferBuilder::new(validity.len());
179+
self.null_buffer_builder.append_validity_mask(validity);
180+
}
181+
177182
fn finish(&mut self) -> ArrayRef {
178183
self.flush_in_progress();
179184
let buffers = std::mem::take(&mut self.completed);

vortex-array/src/validity.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ impl Validity {
147147
Self::AllInvalid => Ok(Self::AllInvalid),
148148
Self::Array(is_valid) => {
149149
let maybe_is_valid = take(is_valid, indices)?;
150-
// Null indices invalidite that position.
150+
// Null indices invalidate that position.
151151
let is_valid = fill_null(&maybe_is_valid, Scalar::from(false))?;
152152
Ok(Self::Array(is_valid))
153153
}

0 commit comments

Comments
 (0)