Skip to content

Commit

Permalink
fix(rust): deal with realiases in cluster_with_columns (#16548)
Browse files Browse the repository at this point in the history
  • Loading branch information
coastalwhite authored May 28, 2024
1 parent 54213a2 commit d53bcb2
Show file tree
Hide file tree
Showing 3 changed files with 191 additions and 9 deletions.
37 changes: 37 additions & 0 deletions crates/polars-arrow/src/bitmap/assign_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,22 @@ where
}
}

/// Apply a bitwise binary operation to a [`MutableBitmap`].
///
/// This function can be used for operations like `&=` to a [`MutableBitmap`].
/// # Panics
/// This function panics iff `lhs.len() != `rhs.len()`
pub fn binary_assign_mut<T: BitChunk, F>(lhs: &mut MutableBitmap, rhs: &MutableBitmap, op: F)
where
F: Fn(T, T) -> T,
{
assert_eq!(lhs.len(), rhs.len());

let slice = rhs.as_slice();
let iter = BitChunksExact::<T>::new(slice, rhs.len());
binary_assign_impl(lhs, iter, op)
}

#[inline]
/// Compute bitwise OR operation in-place
fn or_assign<T: BitChunk>(lhs: &mut MutableBitmap, rhs: &Bitmap) {
Expand All @@ -117,6 +133,27 @@ fn or_assign<T: BitChunk>(lhs: &mut MutableBitmap, rhs: &Bitmap) {
}
}

#[inline]
/// Compute bitwise OR operation in-place
fn or_assign_mut<T: BitChunk>(lhs: &mut MutableBitmap, rhs: &MutableBitmap) {
if rhs.unset_bits() == 0 {
assert_eq!(lhs.len(), rhs.len());
lhs.clear();
lhs.extend_constant(rhs.len(), true);
} else if rhs.unset_bits() == rhs.len() {
// bitmap remains
} else {
binary_assign_mut(lhs, rhs, |x: T, y| x | y)
}
}

impl<'a> std::ops::BitOrAssign<&'a MutableBitmap> for &mut MutableBitmap {
#[inline]
fn bitor_assign(&mut self, rhs: &'a MutableBitmap) {
or_assign_mut::<u64>(self, rhs)
}
}

impl<'a> std::ops::BitOrAssign<&'a Bitmap> for &mut MutableBitmap {
#[inline]
fn bitor_assign(&mut self, rhs: &'a Bitmap) {
Expand Down
108 changes: 99 additions & 9 deletions crates/polars-plan/src/logical_plan/optimizer/cluster_with_columns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,10 @@ pub fn optimize(root: Node, lp_arena: &mut Arena<IR>, expr_arena: &Arena<AExpr>)
// We define these here to reuse the allocations across the loops
let mut column_map = ColumnMap::with_capacity(8);
let mut input_genset = MutableBitmap::with_capacity(16);
let mut current_livesets: Vec<MutableBitmap> = Vec::with_capacity(16);
let mut current_expr_livesets: Vec<MutableBitmap> = Vec::with_capacity(16);
let mut current_liveset = MutableBitmap::with_capacity(16);
let mut pushable = MutableBitmap::with_capacity(16);
let mut potential_pushable = Vec::with_capacity(4);

while let Some(current) = ir_stack.pop() {
let current_ir = lp_arena.get(current);
Expand Down Expand Up @@ -73,8 +75,10 @@ pub fn optimize(root: Node, lp_arena: &mut Arena<IR>, expr_arena: &Arena<AExpr>)
// Reuse the allocations of the previous loop
column_map.clear();
input_genset.clear();
current_livesets.clear();
current_expr_livesets.clear();
current_liveset.clear();
pushable.clear();
potential_pushable.clear();

// @NOTE
// We can pushdown any column that utilizes no live columns that are generated in the
Expand All @@ -95,23 +99,109 @@ pub fn optimize(root: Node, lp_arena: &mut Arena<IR>, expr_arena: &Arena<AExpr>)
column_map_set(&mut liveset, column_map, live.clone());
}

current_livesets.push(liveset);
current_expr_livesets.push(liveset);
}

// Force that column_map is not further mutated from this point on
let column_map = column_map as &_;

column_map_finalize_bitset(&mut input_genset, column_map);

// Check for every expression in the current WITH_COLUMNS node whether it can be pushed
// down.
for expr_liveset in &mut current_livesets {
current_liveset.extend_constant(column_map.len(), false);
for expr_liveset in &mut current_expr_livesets {
use std::ops::BitOrAssign;
column_map_finalize_bitset(expr_liveset, column_map);
(&mut current_liveset).bitor_assign(expr_liveset as &_);
}

// Check for every expression in the current WITH_COLUMNS node whether it can be pushed
// down or pruned.
*current_exprs.exprs_mut() = std::mem::take(current_exprs.exprs_mut())
.into_iter()
.zip(current_expr_livesets.iter_mut())
.filter_map(|(mut expr, liveset)| {
let does_input_assign_column_that_expr_used = input_genset.intersects_with(liveset);

if does_input_assign_column_that_expr_used {
pushable.push(false);
return Some(expr);
}

let column_name = expr.output_name_arc();
let is_pushable = if let Some(idx) = column_map.get(column_name) {
let does_input_alias_also_expr = input_genset.get(*idx);
let is_alias_live_in_current = current_liveset.get(*idx);

if does_input_alias_also_expr && !is_alias_live_in_current {
let column_name = column_name.as_ref();

// @NOTE: Pruning of re-assigned columns
//
// We checked if this expression output is also assigned by the input and
// that that assignment is not used in the current WITH_COLUMNS.
// Consequently, we are free to prune the input's assignment to the output.
//
// We immediately prune here to simplify the later code.
//
// @NOTE: Expressions in a `WITH_COLUMNS` cannot alias to the same column.
// Otherwise, this would be faulty and would panic.
let input_expr = input_exprs
.exprs_mut()
.iter_mut()
.find(|input_expr| column_name == input_expr.output_name())
.expect("No assigning expression for generated column");

// @NOTE
// Since we are reassigning a column and we are pushing to the input, we do
// not need to change the schema of the current or input nodes.
std::mem::swap(&mut expr, input_expr);
return None;
}

// We cannot have multiple assignments to the same column in one WITH_COLUMNS
// and we need to make sure that we are not changing the column value that
// neighbouring expressions are seeing.

// @NOTE: In this case it might be possible to push this down if all the
// expressions that use the output are also being pushed down.
if !does_input_alias_also_expr && is_alias_live_in_current {
potential_pushable.push(pushable.len());
pushable.push(false);
return Some(expr);
}

!does_input_alias_also_expr && !is_alias_live_in_current
} else {
true
};

pushable.push(is_pushable);
Some(expr)
})
.collect();

debug_assert_eq!(pushable.len(), current_exprs.len());

let has_intersection = input_genset.intersects_with(expr_liveset);
let is_pushable = !has_intersection;
// Here we do a last check for expressions to push down.
// This will pushdown the expressions that "has an output column that is mentioned by
// neighbour columns, but all those neighbours were being pushed down".
for candidate in potential_pushable.iter().copied() {
let column_name = current_exprs.as_exprs()[candidate].output_name_arc();
let column_idx = column_map.get(column_name).unwrap();

pushable.push(is_pushable);
current_liveset.clear();
current_liveset.extend_constant(column_map.len(), false);
for (i, expr_liveset) in current_expr_livesets.iter().enumerate() {
if pushable.get(i) || i == candidate {
continue;
}
use std::ops::BitOrAssign;
(&mut current_liveset).bitor_assign(expr_liveset as &_);
}

if !current_liveset.get(*column_idx) {
pushable.set(candidate, true);
}
}

let pushable_set_bits = pushable.set_bits();
Expand Down
55 changes: 55 additions & 0 deletions py-polars/tests/unit/test_cwc.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,3 +194,58 @@ def test_reverse_order() -> None:
)

df.collect()


def test_realias_of_unread_column_16530() -> None:
df = (
pl.LazyFrame({"x": [True]})
.with_columns(x=pl.lit(False))
.with_columns(y=~pl.col("x"))
.with_columns(y=pl.lit(False))
)

explain = df.explain()

assert explain.count("WITH_COLUMNS") == 1
assert df.collect().equals(pl.DataFrame({"x": [False], "y": [False]}))


def test_realias_with_dependencies() -> None:
df = (
pl.LazyFrame({"x": [True]})
.with_columns(x=pl.lit(False))
.with_columns(y=~pl.col("x"))
.with_columns(y=pl.lit(False), z=pl.col("y") | True)
)

explain = df.explain()

assert explain.count("WITH_COLUMNS") == 3
assert df.collect().equals(pl.DataFrame({"x": [False], "y": [False], "z": [True]}))


def test_refuse_pushdown_with_aliases() -> None:
df = (
pl.LazyFrame({"x": [True]})
.with_columns(x=pl.lit(False))
.with_columns(y=pl.lit(True))
.with_columns(y=pl.lit(False), z=pl.col("y") | True)
)

explain = df.explain()

assert explain.count("WITH_COLUMNS") == 2
assert df.collect().equals(pl.DataFrame({"x": [False], "y": [False], "z": [True]}))


def test_neighbour_live_expr() -> None:
df = (
pl.LazyFrame({"x": [True]})
.with_columns(y=pl.lit(False))
.with_columns(x=pl.lit(False), z=pl.col("x") | False)
)

explain = df.explain()

assert explain.count("WITH_COLUMNS") == 1
assert df.collect().equals(pl.DataFrame({"x": [False], "y": [False], "z": [True]}))

0 comments on commit d53bcb2

Please sign in to comment.