Skip to content

Commit

Permalink
fix: Properly broadcast Struct when then validity
Browse files Browse the repository at this point in the history
  • Loading branch information
coastalwhite committed Oct 8, 2024
1 parent 133bf47 commit ffbb34d
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 74 deletions.
155 changes: 81 additions & 74 deletions crates/polars-core/src/chunked_array/ops/zip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ impl ChunkZip<StructType> for StructChunked {
let if_true = if_true.as_ref();
let if_false = if_false.as_ref();

let (l, r, mask) = align_chunks_ternary(if_true, if_false, mask);
let (if_true, if_false, mask) = align_chunks_ternary(if_true, if_false, mask);

// Prepare the boolean arrays such that Null maps to false.
// This prevents every field doing that.
Expand All @@ -287,10 +287,10 @@ impl ChunkZip<StructType> for StructChunked {
}

// Zip all the fields.
let fields = l
let fields = if_true
.fields_as_series()
.iter()
.zip(r.fields_as_series())
.zip(if_false.fields_as_series())
.map(|(lhs, rhs)| lhs.zip_with_same_type(&mask, &rhs))
.collect::<PolarsResult<Vec<_>>>()?;

Expand Down Expand Up @@ -330,138 +330,145 @@ impl ChunkZip<StructType> for StructChunked {
// We need to take two things into account:
// 1. The chunk lengths of `out` might not necessarily match `l`, `r` and `mask`.
// 2. `l` and `r` might still need to be broadcasted.
if (l.null_count + r.null_count) > 0 {
if (if_true.null_count + if_false.null_count) > 0 {
// Create one validity mask that spans the entirety of out.
let rechunked_validity = match (l.len(), r.len()) {
(1, 1) if length != 1 => match (l.null_count() == 0, r.null_count() == 0) {
(true, true) => None,
(false, true) => {
if mask.chunks().len() == 1 {
let m = mask.chunks()[0]
.as_any()
.downcast_ref::<BooleanArray>()
.unwrap()
.values();
Some(!m)
} else {
rechunk_bitmaps(
length,
mask.downcast_iter().map(|m| (m.len(), Some(!m.values()))),
)
}
},
(true, false) => {
if mask.chunks().len() == 1 {
let m = mask.chunks()[0]
.as_any()
.downcast_ref::<BooleanArray>()
.unwrap()
.values();
Some(m.clone())
} else {
rechunk_bitmaps(
length,
mask.downcast_iter()
.map(|m| (m.len(), Some(m.values().clone()))),
)
}
},
(false, false) => Some(Bitmap::new_zeroed(length)),
let rechunked_validity = match (if_true.len(), if_false.len()) {
(1, 1) if length != 1 => {
match (if_true.null_count() == 0, if_false.null_count() == 0) {
(true, true) => None,
(false, true) => {
if mask.chunks().len() == 1 {
let m = mask.chunks()[0]
.as_any()
.downcast_ref::<BooleanArray>()
.unwrap()
.values();
Some(!m)
} else {
rechunk_bitmaps(
length,
mask.downcast_iter()
.map(|m| (m.len(), Some(m.values().clone()))),
)
}
},
(true, false) => {
if mask.chunks().len() == 1 {
let m = mask.chunks()[0]
.as_any()
.downcast_ref::<BooleanArray>()
.unwrap()
.values();
Some(m.clone())
} else {
rechunk_bitmaps(
length,
mask.downcast_iter().map(|m| (m.len(), Some(!m.values()))),
)
}
},
(false, false) => Some(Bitmap::new_zeroed(length)),
}
},
(1, _) if length != 1 => {
debug_assert!(r
debug_assert!(if_false
.chunk_lengths()
.zip(mask.chunk_lengths())
.all(|(r, m)| r == m));

let combine = if l.null_count() == 0 {
|r: Option<&Bitmap>, m: &Bitmap| r.map(|r| arrow::bitmap::or(r, m))
let combine = if if_true.null_count() == 0 {
|if_false: Option<&Bitmap>, m: &Bitmap| {
if_false.map(|v| arrow::bitmap::or(v, m))
}
} else {
|r: Option<&Bitmap>, m: &Bitmap| {
Some(r.map_or_else(|| m.clone(), |r| arrow::bitmap::and_not(r, m)))
|if_false: Option<&Bitmap>, m: &Bitmap| {
Some(if_false.map_or_else(|| !m, |v| arrow::bitmap::and_not(v, m)))
}
};

if r.chunks().len() == 1 {
let r = r.chunks()[0].validity();
if if_false.chunks().len() == 1 {
let if_false = if_false.chunks()[0].validity();
let m = mask.chunks()[0]
.as_any()
.downcast_ref::<BooleanArray>()
.unwrap()
.values();

let validity = combine(r, m);
validity.and_then(|v| (v.unset_bits() > 0).then_some(v))
let validity = combine(if_false, m);
validity.filter(|v| v.unset_bits() > 0)
} else {
rechunk_bitmaps(
length,
r.chunks()
.iter()
.zip(mask.downcast_iter())
.map(|(chunk, mask)| {
if_false.chunks().iter().zip(mask.downcast_iter()).map(
|(chunk, mask)| {
(mask.len(), combine(chunk.validity(), mask.values()))
}),
},
),
)
}
},
(_, 1) if length != 1 => {
debug_assert!(l
debug_assert!(if_true
.chunk_lengths()
.zip(mask.chunk_lengths())
.all(|(l, m)| l == m));

let combine = if r.null_count() == 0 {
|l: Option<&Bitmap>, m: &Bitmap| l.map(|l| arrow::bitmap::or_not(l, m))
let combine = if if_false.null_count() == 0 {
|if_true: Option<&Bitmap>, m: &Bitmap| {
if_true.map(|v| arrow::bitmap::or_not(v, m))
}
} else {
|l: Option<&Bitmap>, m: &Bitmap| {
Some(l.map_or_else(|| m.clone(), |l| arrow::bitmap::and(l, m)))
|if_true: Option<&Bitmap>, m: &Bitmap| {
Some(if_true.map_or_else(|| m.clone(), |v| arrow::bitmap::and(v, m)))
}
};

if l.chunks().len() == 1 {
let l = l.chunks()[0].validity();
if if_true.chunks().len() == 1 {
let if_true = if_true.chunks()[0].validity();
let m = mask.chunks()[0]
.as_any()
.downcast_ref::<BooleanArray>()
.unwrap()
.values();

let validity = combine(l, m);
validity.and_then(|v| (v.unset_bits() > 0).then_some(v))
let validity = combine(if_true, m);
validity.filter(|v| v.unset_bits() > 0)
} else {
rechunk_bitmaps(
length,
l.chunks()
.iter()
.zip(mask.downcast_iter())
.map(|(chunk, mask)| {
if_true.chunks().iter().zip(mask.downcast_iter()).map(
|(chunk, mask)| {
(mask.len(), combine(chunk.validity(), mask.values()))
}),
},
),
)
}
},
(_, _) => {
debug_assert!(l
debug_assert!(if_true
.chunk_lengths()
.zip(r.chunk_lengths())
.zip(if_false.chunk_lengths())
.all(|(l, r)| l == r));
debug_assert!(l
debug_assert!(if_true
.chunk_lengths()
.zip(mask.chunk_lengths())
.all(|(l, r)| l == r));

let validities = l
let validities = if_true
.chunks()
.iter()
.zip(r.chunks())
.zip(if_false.chunks())
.map(|(l, r)| (l.validity(), r.validity()));

rechunk_bitmaps(
length,
validities
.zip(mask.downcast_iter())
.map(|((lv, rv), mask)| {
(mask.len(), if_then_else_validity(mask.values(), lv, rv))
.map(|((if_true, if_false), mask)| {
(
mask.len(),
if_then_else_validity(mask.values(), if_true, if_false),
)
}),
)
},
Expand Down
46 changes: 46 additions & 0 deletions py-polars/tests/unit/functions/test_when_then.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,3 +690,49 @@ def test_when_then_chunked_structs_18673() -> None:
df.select(pl.when(pl.col.b).then(pl.first("x")).otherwise(pl.first("x"))),
pl.DataFrame({"x": [{"a": 1}, {"a": 1}]}),
)


some_scalar = pl.Series("a", [{"x": 2}], pl.Struct)
none_scalar = pl.Series("a", [None], pl.Struct({"x": pl.Int64}))
column = pl.Series("a", [{"x": 2}, {"x": 2}], pl.Struct)


@pytest.mark.parametrize(
"values",
[
(some_scalar, some_scalar),
(some_scalar, pl.col.a),
(some_scalar, none_scalar),
(some_scalar, column),
(none_scalar, pl.col.a),
(none_scalar, none_scalar),
(none_scalar, column),
(pl.col.a, pl.col.a),
(pl.col.a, column),
(column, column),
],
)
def test_struct_when_then_broadcasting_combinations_19122(
values: tuple[Any, Any],
) -> None:
lv, rv = values

df = pl.Series("a", [{"x": 1}, {"x": 1}], pl.Struct).to_frame()

assert_frame_equal(
df.select(
pl.when(pl.col.a.struct.field("x") == 0).then(lv).otherwise(rv).alias("a")
),
df.select(
pl.when(pl.col.a.struct.field("x") == 0).then(None).otherwise(rv).alias("a")
),
)

assert_frame_equal(
df.select(
pl.when(pl.col.a.struct.field("x") != 0).then(rv).otherwise(lv).alias("a")
),
df.select(
pl.when(pl.col.a.struct.field("x") != 0).then(rv).otherwise(None).alias("a")
),
)

0 comments on commit ffbb34d

Please sign in to comment.