Skip to content

Commit 9cdd9c7

Browse files
committed
fix: ConstantArray#take handles nullable indices
1 parent 9f331e1 commit 9cdd9c7

File tree

4 files changed

+72
-17
lines changed

4 files changed

+72
-17
lines changed

fuzz/src/sort.rs

+1-7
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,7 @@ pub fn sort_canonical_array(array: &dyn Array) -> VortexResult<ArrayRef> {
1616
let mut opt_values = bool_array
1717
.boolean_buffer()
1818
.iter()
19-
.zip(
20-
bool_array
21-
.validity_mask()
22-
.vortex_expect("Failed to get logical validity")
23-
.to_boolean_buffer()
24-
.iter(),
25-
)
19+
.zip(bool_array.validity_mask()?.to_boolean_buffer().iter())
2620
.map(|(b, v)| v.then_some(b))
2721
.collect::<Vec<_>>();
2822
opt_values.sort();

vortex-array/src/arrays/constant/compute/mod.rs

+4-9
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ mod cast;
44
mod compare;
55
mod invert;
66
mod search_sorted;
7+
mod take;
78

89
use num_traits::{CheckedMul, ToPrimitive};
910
use vortex_dtype::{NativePType, PType, match_each_native_ptype};
@@ -58,15 +59,15 @@ impl ComputeVTable for ConstantEncoding {
5859
Some(self)
5960
}
6061

61-
fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
62+
fn sum_fn(&self) -> Option<&dyn SumFn<&dyn Array>> {
6263
Some(self)
6364
}
6465

65-
fn uncompressed_size_fn(&self) -> Option<&dyn UncompressedSizeFn<&dyn Array>> {
66+
fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
6667
Some(self)
6768
}
6869

69-
fn sum_fn(&self) -> Option<&dyn SumFn<&dyn Array>> {
70+
fn uncompressed_size_fn(&self) -> Option<&dyn UncompressedSizeFn<&dyn Array>> {
7071
Some(self)
7172
}
7273
}
@@ -77,12 +78,6 @@ impl ScalarAtFn<&ConstantArray> for ConstantEncoding {
7778
}
7879
}
7980

80-
impl TakeFn<&ConstantArray> for ConstantEncoding {
81-
fn take(&self, array: &ConstantArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
82-
Ok(ConstantArray::new(array.scalar().clone(), indices.len()).into_array())
83-
}
84-
}
85-
8681
impl SliceFn<&ConstantArray> for ConstantEncoding {
8782
fn slice(&self, array: &ConstantArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
8883
Ok(ConstantArray::new(array.scalar().clone(), stop - start).into_array())
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
use vortex_error::VortexResult;
2+
use vortex_mask::AllOr;
3+
use vortex_scalar::Scalar;
4+
5+
use crate::arrays::{ConstantArray, ConstantEncoding};
6+
use crate::builders::{ArrayBuilderExt, builder_with_capacity};
7+
use crate::compute::TakeFn;
8+
use crate::{Array, ArrayRef};
9+
10+
impl TakeFn<&ConstantArray> for ConstantEncoding {
11+
fn take(&self, array: &ConstantArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
12+
match indices.validity_mask()?.boolean_buffer() {
13+
AllOr::All => {
14+
Ok(ConstantArray::new(array.scalar().clone(), indices.len()).into_array())
15+
}
16+
AllOr::None => Ok(ConstantArray::new(
17+
Scalar::null(array.dtype().clone()),
18+
indices.len(),
19+
)
20+
.into_array()),
21+
AllOr::Some(v) => {
22+
let mut result_builder =
23+
builder_with_capacity(&array.dtype().as_nullable(), indices.len());
24+
for valid in v.iter() {
25+
if valid {
26+
result_builder.append_scalar_value(array.scalar().value().clone())?;
27+
} else {
28+
result_builder.append_null();
29+
}
30+
}
31+
Ok(result_builder.finish())
32+
}
33+
}
34+
}
35+
}
36+
37+
#[cfg(test)]
38+
mod tests {
39+
use vortex_buffer::buffer;
40+
use vortex_mask::AllOr;
41+
42+
use crate::arrays::{ConstantArray, PrimitiveArray};
43+
use crate::compute::take;
44+
use crate::validity::Validity;
45+
use crate::{Array, ToCanonical};
46+
47+
#[test]
48+
fn take_nullable_indices() {
49+
let array = ConstantArray::new(42, 10).to_array();
50+
let taken = take(
51+
&array,
52+
&PrimitiveArray::new(
53+
buffer![0, 5, 7],
54+
Validity::from_iter(vec![false, true, false]),
55+
)
56+
.into_array(),
57+
)
58+
.unwrap();
59+
let valid_indices: &[usize] = &[1usize];
60+
assert_eq!(taken.to_primitive().unwrap().as_slice::<i32>(), &[0, 42, 0]);
61+
assert_eq!(
62+
taken.validity_mask().unwrap().indices(),
63+
AllOr::Some(valid_indices)
64+
);
65+
}
66+
}

vortex-array/src/validity.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ impl Validity {
147147
Self::AllInvalid => Ok(Self::AllInvalid),
148148
Self::Array(is_valid) => {
149149
let maybe_is_valid = take(is_valid, indices)?;
150-
// Null indices invalidite that position.
150+
// Null indices invalidate that position.
151151
let is_valid = fill_null(&maybe_is_valid, Scalar::from(false))?;
152152
Ok(Self::Array(is_valid))
153153
}

0 commit comments

Comments
 (0)