Skip to content

Commit 960acf6

Browse files
authored
fix: take_into mask length must equal _indices_ length (#2396)
I added a test that catches this behavior. In particular, the indices must differ in length from the array and also it must be either all valid or all invalid. I additionally changed `Validity::to_logical` to assert the length matches the array length. Since mixed validity is more common than the other two, I hope this new assertion will surface these kinds of issues in more tests. An example of a failure on `develop` with these new tests: ``` ---- array::primitive::compute::take::test::test_take_into stdout ---- thread 'array::primitive::compute::take::test::test_take_into' panicked at vortex-array/src/builders/primitive.rs:98:9: assertion `left == right` failed: null count must equal value count left: 5 right: 3 stack backtrace: 0: rust_begin_unwind at /rustc/409998c4e8cae45344fd434b358b697cc93870d0/library/std/src/panicking.rs:676:5 1: core::panicking::panic_fmt at /rustc/409998c4e8cae45344fd434b358b697cc93870d0/library/core/src/panicking.rs:75:14 2: core::panicking::assert_failed_inner 3: core::panicking::assert_failed at /rustc/409998c4e8cae45344fd434b358b697cc93870d0/library/core/src/panicking.rs:364:5 4: vortex_array::builders::primitive::PrimitiveBuilder<T>::finish_into_primitive at ./src/builders/primitive.rs:98:9 5: <vortex_array::builders::primitive::PrimitiveBuilder<T> as vortex_array::builders::ArrayBuilder>::finish at ./src/builders/primitive.rs:176:9 6: vortex_array::array::primitive::compute::take::test::test_take_into at ./src/array/primitive/compute/take.rs:150:22 7: vortex_array::array::primitive::compute::take::test::test_take_into::{{closure}} at ./src/array/primitive/compute/take.rs:142:24 ```
1 parent 3883c7a commit 960acf6

File tree

3 files changed

+70
-6
lines changed

3 files changed

+70
-6
lines changed

vortex-array/src/array/primitive/compute/take.rs

+52-2
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ impl TakeFn<PrimitiveArray> for PrimitiveEncoding {
5050
let indices = indices.clone().into_primitive()?;
5151
// TODO(joe): impl take over mask and use `Array::validity_mask`, instead of `validity()`.
5252
let validity = array.validity().take(indices.as_ref())?;
53-
let mask = validity.to_logical(array.len())?;
53+
let mask = validity.to_logical(indices.len())?;
5454

5555
match_each_native_ptype!(array.ptype(), |$T| {
5656
match_each_integer_ptype!(indices.ptype(), |$I| {
@@ -66,6 +66,8 @@ fn take_into_impl<T: NativePType, I: NativePType + AsPrimitive<usize>>(
6666
mask: Mask,
6767
builder: &mut dyn ArrayBuilder,
6868
) -> VortexResult<()> {
69+
assert_eq!(indices.len(), mask.len());
70+
6971
let array = array.as_slice::<T>();
7072
let indices = indices.as_slice::<I>();
7173
let builder = builder
@@ -103,11 +105,13 @@ unsafe fn take_primitive_unchecked<T: NativePType, I: NativePType + AsPrimitive<
103105
#[cfg(test)]
104106
mod test {
105107
use vortex_buffer::buffer;
108+
use vortex_dtype::Nullability;
106109
use vortex_scalar::Scalar;
107110

108111
use crate::array::primitive::compute::take::take_primitive;
109112
use crate::array::{BoolArray, PrimitiveArray};
110-
use crate::compute::{scalar_at, take};
113+
use crate::builders::{ArrayBuilder as _, PrimitiveBuilder};
114+
use crate::compute::{scalar_at, take, take_into};
111115
use crate::validity::Validity;
112116
use crate::IntoArray as _;
113117

@@ -135,4 +139,50 @@ mod test {
135139
// the third index is null
136140
assert_eq!(scalar_at(&actual, 2).unwrap(), Scalar::null_typed::<i32>());
137141
}
142+
143+
#[test]
144+
fn test_take_into() {
145+
let values = PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5], Validity::NonNullable);
146+
let all_valid_indices = PrimitiveArray::new(
147+
buffer![0, 3, 4],
148+
Validity::Array(BoolArray::from_iter([true, true, true]).into_array()),
149+
);
150+
let mut builder = PrimitiveBuilder::<i32>::new(Nullability::Nullable);
151+
take_into(&values, all_valid_indices, &mut builder).unwrap();
152+
let actual = builder.finish().unwrap();
153+
assert_eq!(scalar_at(&actual, 0).unwrap(), Scalar::from(Some(1)));
154+
assert_eq!(scalar_at(&actual, 1).unwrap(), Scalar::from(Some(4)));
155+
assert_eq!(scalar_at(&actual, 2).unwrap(), Scalar::from(Some(5)));
156+
157+
let mixed_valid_indices = PrimitiveArray::new(
158+
buffer![0, 3, 4],
159+
Validity::Array(BoolArray::from_iter([true, true, false]).into_array()),
160+
);
161+
let mut builder = PrimitiveBuilder::<i32>::new(Nullability::Nullable);
162+
take_into(&values, mixed_valid_indices, &mut builder).unwrap();
163+
let actual = builder.finish().unwrap();
164+
assert_eq!(scalar_at(&actual, 0).unwrap(), Scalar::from(Some(1)));
165+
assert_eq!(scalar_at(&actual, 1).unwrap(), Scalar::from(Some(4)));
166+
// the third index is null
167+
assert_eq!(scalar_at(&actual, 2).unwrap(), Scalar::null_typed::<i32>());
168+
169+
let all_invalid_indices = PrimitiveArray::new(
170+
buffer![0, 3, 4],
171+
Validity::Array(BoolArray::from_iter([false, false, false]).into_array()),
172+
);
173+
let mut builder = PrimitiveBuilder::<i32>::new(Nullability::Nullable);
174+
take_into(&values, all_invalid_indices, &mut builder).unwrap();
175+
let actual = builder.finish().unwrap();
176+
assert_eq!(scalar_at(&actual, 0).unwrap(), Scalar::null_typed::<i32>());
177+
assert_eq!(scalar_at(&actual, 1).unwrap(), Scalar::null_typed::<i32>());
178+
assert_eq!(scalar_at(&actual, 2).unwrap(), Scalar::null_typed::<i32>());
179+
180+
let non_null_indices = PrimitiveArray::new(buffer![0, 3, 4], Validity::NonNullable);
181+
let mut builder = PrimitiveBuilder::<i32>::new(Nullability::NonNullable);
182+
take_into(&values, non_null_indices, &mut builder).unwrap();
183+
let actual = builder.finish().unwrap();
184+
assert_eq!(scalar_at(&actual, 0).unwrap(), Scalar::from(1));
185+
assert_eq!(scalar_at(&actual, 1).unwrap(), Scalar::from(4));
186+
assert_eq!(scalar_at(&actual, 2).unwrap(), Scalar::from(5));
187+
}
138188
}

vortex-array/src/compute/take.rs

+6-4
Original file line numberDiff line numberDiff line change
@@ -224,12 +224,14 @@ fn take_into_impl(
224224
indices: &Array,
225225
builder: &mut dyn ArrayBuilder,
226226
) -> VortexResult<()> {
227-
if array.dtype() != builder.dtype() {
227+
let result_nullability = array.dtype().nullability() | indices.dtype().nullability();
228+
let result_dtype = array.dtype().with_nullability(result_nullability);
229+
if &result_dtype != builder.dtype() {
228230
vortex_bail!(
229-
"TakeIntoFn {} had a builder with a different dtype {} to the array dtype {}",
231+
"TakeIntoFn {} had a builder with a different dtype {} to the resulting array dtype {}",
230232
array.encoding(),
231-
array.dtype(),
232-
builder.dtype()
233+
builder.dtype(),
234+
result_dtype,
233235
);
234236
}
235237
if let Some(take_fn) = array.vtable().take_fn() {

vortex-dtype/src/nullability.rs

+12
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use std::fmt::{Display, Formatter};
2+
use std::ops::BitOr;
23

34
/// Whether an instance of a DType can be `null or not
45
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
@@ -10,6 +11,17 @@ pub enum Nullability {
1011
Nullable,
1112
}
1213

14+
impl BitOr for Nullability {
15+
type Output = Nullability;
16+
17+
fn bitor(self, rhs: Self) -> Self::Output {
18+
match (self, rhs) {
19+
(Self::NonNullable, Self::NonNullable) => Self::NonNullable,
20+
_ => Self::Nullable,
21+
}
22+
}
23+
}
24+
1325
impl From<bool> for Nullability {
1426
fn from(value: bool) -> Self {
1527
if value {

0 commit comments

Comments
 (0)