Skip to content

Commit

Permalink
fix: Properly zip struct validities
Browse files Browse the repository at this point in the history
  • Loading branch information
coastalwhite committed Sep 24, 2024
1 parent ea7953e commit ff0eaa0
Show file tree
Hide file tree
Showing 4 changed files with 227 additions and 70 deletions.
1 change: 1 addition & 0 deletions crates/polars-arrow/src/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ pub trait Array: Send + Sync + dyn_clone::DynClone + 'static {
/// The caller must ensure that `offset + length <= self.len()`
#[must_use]
unsafe fn sliced_unchecked(&self, offset: usize, length: usize) -> Box<dyn Array> {
debug_assert!(offset + length <= self.len());
let mut new = self.to_boxed();
new.slice_unchecked(offset, length);
new
Expand Down
2 changes: 2 additions & 0 deletions crates/polars-core/src/chunked_array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -768,6 +768,8 @@ where
})
.collect();

debug_assert_eq!(offset, array.len());

// SAFETY: We just slice the original chunks, their type will not change.
unsafe {
Self::from_chunks_and_dtype(self.name().clone(), chunks, self.dtype().clone())
Expand Down
265 changes: 195 additions & 70 deletions crates/polars-core/src/chunked_array/ops/zip.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::borrow::Cow;

use arrow::bitmap::Bitmap;
use arrow::bitmap::{Bitmap, MutableBitmap};
use arrow::compute::utils::{combine_validities_and, combine_validities_and_not};
use polars_compute::if_then_else::{if_then_else_validity, IfThenElseKernel};

Expand Down Expand Up @@ -216,7 +216,10 @@ impl ChunkZip<StructType> for StructChunked {
mask: &BooleanChunked,
other: &ChunkedArray<StructType>,
) -> PolarsResult<ChunkedArray<StructType>> {
let length = self.length.max(mask.length).max(other.length);
let min_length = self.length.min(mask.length).min(other.length);
let max_length = self.length.max(mask.length).max(other.length);

let length = if min_length == 0 { 0 } else { max_length };

debug_assert!(self.length == 1 || self.length == length);
debug_assert!(mask.length == 1 || mask.length == length);
Expand All @@ -227,6 +230,22 @@ impl ChunkZip<StructType> for StructChunked {
let mut if_true: Cow<ChunkedArray<StructType>> = Cow::Borrowed(self);
let mut if_false: Cow<ChunkedArray<StructType>> = Cow::Borrowed(other);

// Special case. In this case, we know what to do.
// @TODO: Optimization. If all mask values are the same, select one of the two.
if mask.length == 1 {
// pl.when(None) <=> pl.when(False)
let is_true = mask.get(0).unwrap_or(false);
return Ok(if is_true && self.length == 1 {
self.new_from_index(0, length)
} else if is_true {
self.clone()
} else if other.length == 1 {
other.new_from_index(0, length)
} else {
other.clone()
});
}

// align_chunks_ternary can only align chunks if:
// - Each chunkedarray only has 1 chunk
// - Each chunkedarray has an equal length (i.e. is broadcasted)
Expand All @@ -235,21 +254,6 @@ impl ChunkZip<StructType> for StructChunked {
let needs_broadcast =
if_true.chunks().len() > 1 || if_false.chunks().len() > 1 || mask.chunks().len() > 1;
if needs_broadcast && length > 1 {
// Special case. In this case, we know what to do.
if mask.length == 1 {
// pl.when(None) <=> pl.when(False)
let is_true = mask.get(0).unwrap_or(false);
return Ok(if is_true && self.length == 1 {
self.new_from_index(0, length)
} else if is_true {
self.clone()
} else if other.length == 1 {
other.new_from_index(0, length)
} else {
other.clone()
});
}

if self.length == 1 {
let broadcasted = self.new_from_index(0, length);
if_true = Cow::Owned(broadcasted);
Expand Down Expand Up @@ -288,70 +292,191 @@ impl ChunkZip<StructType> for StructChunked {

let mut out = StructChunked::from_series(self.name().clone(), fields.iter())?;

// Zip the validities.
if (l.null_count + r.null_count) > 0 {
let validities = l
.chunks()
.iter()
.zip(r.chunks())
.map(|(l, r)| (l.validity(), r.validity()));

fn broadcast(v: Option<&Bitmap>, arr: &ArrayRef) -> Bitmap {
if v.unwrap().get(0).unwrap() {
Bitmap::new_with_value(true, arr.len())
} else {
Bitmap::new_zeroed(arr.len())
fn rechunk_bitmaps(
total_length: usize,
iter: impl Iterator<Item = (usize, Option<Bitmap>)>,
) -> Option<Bitmap> {
let mut rechunked_length = 0;
let mut rechunked_validity = None;
for (chunk_length, validity) in iter {
if let Some(validity) = validity {
if validity.unset_bits() > 0 {
rechunked_validity
.get_or_insert_with(|| {
let mut bm = MutableBitmap::with_capacity(total_length);
bm.extend_constant(rechunked_length, true);
bm
})
.extend_from_bitmap(&validity);
}
}

rechunked_length += chunk_length;
}

// # SAFETY
// We don't modify the length and update the null count.
unsafe {
for ((arr, (lv, rv)), mask) in out
.chunks_mut()
.iter_mut()
.zip(validities)
.zip(mask.downcast_iter())
{
// TODO! we can optimize this and use a kernel that is able to broadcast wo/ allocating.
let (lv, rv) = match (lv.map(|b| b.len()), rv.map(|b| b.len())) {
(Some(1), Some(1)) if arr.len() != 1 => {
let lv = broadcast(lv, arr);
let rv = broadcast(rv, arr);
(Some(lv), Some(rv))
},
(Some(a), Some(b)) if a == b => (lv.cloned(), rv.cloned()),
(Some(1), _) => {
let lv = broadcast(lv, arr);
(Some(lv), rv.cloned())
},
(_, Some(1)) => {
let rv = broadcast(rv, arr);
(lv.cloned(), Some(rv))
},
(None, Some(_)) | (Some(_), None) | (None, None) => {
(lv.cloned(), rv.cloned())
},
(Some(a), Some(b)) => {
polars_bail!(InvalidOperation: "got different sizes in 'zip' operation, got length: {a} and {b}")
},
};
if let Some(rechunked_validity) = rechunked_validity.as_mut() {
rechunked_validity.extend_constant(total_length - rechunked_validity.len(), true);
}

rechunked_validity.map(MutableBitmap::freeze)
}

// broadcast mask
let validity = if mask.len() != arr.len() && mask.len() == 1 {
if mask.get(0).unwrap() {
lv
} else {
rv
// Zip the validities.
//
// We need to take two things into account:
// 1. The chunk lengths of `out` might not necessarily match `l`, `r` and `mask`.
// 2. `l` and `r` might still need to be broadcasted.
if (l.null_count + r.null_count) > 0 {
// Create one validity mask that spans the entirety of out.
let rechunked_validity = match (l.len(), r.len()) {
(1, _) => {
debug_assert!(r
.chunk_lengths()
.zip(mask.chunk_lengths())
.all(|(r, m)| r == m));

let combine = if l.null_count() == 0 {
|r: Option<&Bitmap>, m: &Bitmap| r.map(|r| arrow::bitmap::or_not(r, m))
} else {
|r: Option<&Bitmap>, m: &Bitmap| {
Some(r.map_or_else(|| m.clone(), |r| arrow::bitmap::and(r, m)))
}
};

if r.chunks().len() == 1 {
let r = r.chunks()[0].validity();
let m = mask.chunks()[0]
.as_any()
.downcast_ref::<BooleanArray>()
.unwrap()
.values();

let validity = combine(r, m);
validity.and_then(|v| (v.unset_bits() > 0).then_some(v))
} else {
if_then_else_validity(mask.values(), lv.as_ref(), rv.as_ref())
rechunk_bitmaps(
length,
r.chunks()
.iter()
.zip(mask.downcast_iter())
.map(|(chunk, mask)| {
(mask.len(), combine(chunk.validity(), mask.values()))
}),
)
}
},
(_, 1) => {
debug_assert!(l
.chunk_lengths()
.zip(mask.chunk_lengths())
.all(|(l, m)| l == m));

let combine = if r.null_count() == 0 {
|r: Option<&Bitmap>, m: &Bitmap| r.map(|r| arrow::bitmap::or(r, m))
} else {
|r: Option<&Bitmap>, m: &Bitmap| {
Some(r.map_or_else(|| m.clone(), |r| arrow::bitmap::and_not(r, m)))
}
};

*arr = arr.with_validity(validity);
if l.chunks().len() == 1 {
let l = l.chunks()[0].validity();
let m = mask.chunks()[0]
.as_any()
.downcast_ref::<BooleanArray>()
.unwrap()
.values();

let validity = combine(l, m);
validity.and_then(|v| (v.unset_bits() > 0).then_some(v))
} else {
rechunk_bitmaps(
length,
l.chunks()
.iter()
.zip(mask.downcast_iter())
.map(|(chunk, mask)| {
(mask.len(), combine(chunk.validity(), mask.values()))
}),
)
}
},
(_, _) => {
debug_assert!(l
.chunk_lengths()
.zip(r.chunk_lengths())
.all(|(l, r)| l == r));
debug_assert!(l
.chunk_lengths()
.zip(mask.chunk_lengths())
.all(|(l, r)| l == r));

let validities = l
.chunks()
.iter()
.zip(r.chunks())
.map(|(l, r)| (l.validity(), r.validity()));

rechunk_bitmaps(
length,
validities
.zip(mask.downcast_iter())
.map(|((lv, rv), mask)| {
(mask.len(), if_then_else_validity(mask.values(), lv, rv))
}),
)
},
};

// Apply the validity spreading over the chunks of out.
if let Some(mut rechunked_validity) = rechunked_validity {
assert_eq!(rechunked_validity.len(), out.len());

let num_chunks = out.chunks().len();
let null_count = rechunked_validity.unset_bits();

// SAFETY: We do not change the lengths of the chunks and we update the null_count
// afterwards.
let chunks = unsafe { out.chunks_mut() };

if num_chunks == 1 {
chunks[0] = chunks[0].with_validity(Some(rechunked_validity));
} else {
for chunk in chunks {
let chunk_len = chunk.len();
let chunk_validity;

// SAFETY: We know that rechunked_validity.len() == out.len()
(chunk_validity, rechunked_validity) =
unsafe { rechunked_validity.split_at_unchecked(chunk_len) };
*chunk = chunk.with_validity(
(chunk_validity.unset_bits() > 0).then_some(chunk_validity),
);
}
}

out.null_count = null_count as IdxSize;
} else {
// SAFETY: We do not change the lengths of the chunks and we update the null_count
// afterwards.
let chunks = unsafe { out.chunks_mut() };

for chunk in chunks {
*chunk = chunk.with_validity(None);
}

out.null_count = 0 as IdxSize;
}
}

if cfg!(debug_assertions) {
let start_length = out.len();
let start_null_count = out.null_count();

out.compute_len();

assert_eq!(start_length, out.len());
assert_eq!(start_null_count, out.null_count());
}
Ok(out)
}
Expand Down
29 changes: 29 additions & 0 deletions py-polars/tests/unit/datatypes/test_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -1019,3 +1019,32 @@ def test_struct_group_by_shift_18107() -> None:
[{"lon": 60, "lat": 50}, {"lon": 70, "lat": 60}, None],
],
}


def test_struct_chunked_zip_18119() -> None:
    # Regression test for GH-18119: zipping struct columns whose chunks are
    # laid out differently must align the validities chunk by chunk.
    struct_dtype = pl.Struct({"x": pl.Null})

    def frames_of_nulls(name, dt):
        # One-column frames holding 0..4 null rows of the given dtype.
        return [pl.DataFrame([pl.Series(name, [None] * n, dt)]) for n in range(5)]

    lhs_parts = frames_of_nulls("a", struct_dtype)
    rhs_parts = frames_of_nulls("b", struct_dtype)
    cond_parts = frames_of_nulls("f", pl.Boolean)

    # Every column totals 5 rows, but each has distinct chunk boundaries.
    lhs = pl.concat([lhs_parts[2], lhs_parts[2], lhs_parts[1]])
    rhs = pl.concat([rhs_parts[4], rhs_parts[1]])
    cond = pl.concat([cond_parts[3], cond_parts[2]])

    combined = pl.concat([lhs, rhs, cond], how="horizontal")

    # All mask values are null; `when(None)` selects the `otherwise` branch,
    # and every row is null anyway, so the result is 5 null structs named "a".
    assert_frame_equal(
        combined.select(pl.when(pl.col.f).then(pl.col.a).otherwise(pl.col.b)),
        pl.DataFrame([pl.Series("a", [None] * 5, struct_dtype)]),
    )


def test_struct_null_zip() -> None:
    # Zipping an empty struct column through when/then/otherwise must yield
    # an empty column again, even with a length-1 (broadcast) mask.
    struct_dtype = pl.Struct({"x": pl.Int64})
    empty = pl.Series("int", [], dtype=struct_dtype).to_frame()

    result = empty.select(
        pl.when(pl.Series([True])).then(pl.col.int).otherwise(pl.col.int)
    )
    assert_frame_equal(result, pl.Series("int", [], dtype=struct_dtype).to_frame())

0 comments on commit ff0eaa0

Please sign in to comment.