From 782a1547cdf57bbb9a8df282bc14945eff281196 Mon Sep 17 00:00:00 2001 From: coastalwhite Date: Thu, 30 May 2024 08:44:41 +0200 Subject: [PATCH] fix(rust): small safety issue in CWC filtermap I woke up in the middle of the night and realized I wrote an unsound implementation here. This fixes an issue in CWC where some unsafe code could potentially lead to unsoundness. This should never actually have triggered. But in the spirit of keeping things maintainable in the future, I thought I would properly fix it. --- .../optimizer/cluster_with_columns.rs | 149 +++++++++++------- 1 file changed, 89 insertions(+), 60 deletions(-) diff --git a/crates/polars-plan/src/logical_plan/optimizer/cluster_with_columns.rs b/crates/polars-plan/src/logical_plan/optimizer/cluster_with_columns.rs index 8ce64f2448bf..e3244076d4ad 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/cluster_with_columns.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/cluster_with_columns.rs @@ -29,59 +29,6 @@ fn column_map_set(bitset: &mut MutableBitmap, column_map: &mut ColumnMap, column }); } -/// Perform a inplace `filtermap` over two vectors at the same time. -fn inplace_zip_filtermap( - x: &mut Vec, - y: &mut Vec, - mut f: impl FnMut(&mut T, &mut U) -> bool, -) { - assert_eq!(x.len(), y.len()); - - let mut num_deleted = 0; - - let x_ptr = x.as_mut_ptr(); - let y_ptr = y.as_mut_ptr(); - - // SAFETY: - // - // We know we have a exclusive reference to x and y. - // - // We know that `i` is always smaller than `x.len()` and `y.len()`. Furthermore, we also know - // that `i - num_deleted > 0`. - // - // We know we don't have ownership of any element in x or y when we call `f`, so it is safe to - // panic. - // - // Items that are deleted are also dropped. 
- for i in 0..x.len() { - let xi = unsafe { x_ptr.wrapping_add(i).as_mut().unwrap_unchecked() }; - let yi = unsafe { y_ptr.wrapping_add(i).as_mut().unwrap_unchecked() }; - - // We cannot just give `f` ownership over x[i] and y[i], because a panic would then mean - // that x[i] and y[i] are dropped twice. - let do_use = f(xi, yi); - - // Now we take ownership of x[i] and y[i] - let xi = unsafe { x_ptr.wrapping_add(i).read() }; - let yi = unsafe { y_ptr.wrapping_add(i).read() }; - - if do_use { - unsafe { - x_ptr.wrapping_add(i - num_deleted).write(xi); - y_ptr.wrapping_add(i - num_deleted).write(yi); - } - } else { - // Here we drop x[i] and y[i] which is intentional - num_deleted += 1; - } - } - - unsafe { - x.set_len(x.len() - num_deleted); - y.set_len(y.len() - num_deleted); - } -} - pub fn optimize(root: Node, lp_arena: &mut Arena, expr_arena: &Arena) { let mut ir_stack = Vec::with_capacity(16); ir_stack.push(root); @@ -172,12 +119,13 @@ pub fn optimize(root: Node, lp_arena: &mut Arena, expr_arena: &Arena) inplace_zip_filtermap( current_exprs.exprs_mut(), &mut current_expr_livesets, - |expr, liveset| { - let does_input_assign_column_that_expr_used = input_genset.intersects_with(liveset); + |mut expr, liveset| { + let does_input_assign_column_that_expr_used = + input_genset.intersects_with(&liveset); if does_input_assign_column_that_expr_used { pushable.push(false); - return true; + return Some((expr, liveset)); } let column_name = expr.output_name_arc(); @@ -207,8 +155,8 @@ pub fn optimize(root: Node, lp_arena: &mut Arena, expr_arena: &Arena) // @NOTE // Since we are reassigning a column and we are pushing to the input, we do // not need to change the schema of the current or input nodes. 
- std::mem::swap(expr, input_expr); - return false; + std::mem::swap(&mut expr, input_expr); + return None; } // We cannot have multiple assignments to the same column in one WITH_COLUMNS @@ -220,7 +168,7 @@ pub fn optimize(root: Node, lp_arena: &mut Arena, expr_arena: &Arena) if !does_input_alias_also_expr && is_alias_live_in_current { potential_pushable.push(pushable.len()); pushable.push(false); - return true; + return Some((expr, liveset)); } !does_input_alias_also_expr && !is_alias_live_in_current @@ -229,7 +177,7 @@ pub fn optimize(root: Node, lp_arena: &mut Arena, expr_arena: &Arena) }; pushable.push(is_pushable); - true + Some((expr, liveset)) }, ); @@ -351,3 +299,84 @@ pub fn optimize(root: Node, lp_arena: &mut Arena, expr_arena: &Arena) } } } + +/// Perform an in-place `filtermap` over two vectors at the same time. +fn inplace_zip_filtermap( + x: &mut Vec, + y: &mut Vec, + mut f: impl FnMut(T, U) -> Option<(T, U)>, +) { + assert_eq!(x.len(), y.len()); + + let length = x.len(); + + struct OwnedBuffer { + end: *mut T, + length: usize, + } + + impl Drop for OwnedBuffer { + fn drop(&mut self) { + for i in 0..self.length { + unsafe { self.end.wrapping_sub(i + 1).read() }; + } + } + } + + let x_ptr = x.as_mut_ptr(); + let y_ptr = y.as_mut_ptr(); + + let mut x_buf = OwnedBuffer { + end: x_ptr.wrapping_add(length), + length, + }; + let mut y_buf = OwnedBuffer { + end: y_ptr.wrapping_add(length), + length, + }; + + // SAFETY: All items are now owned by `x_buf` and `y_buf`. Since we know that `x_buf` and + // `y_buf` will be dropped before the vecs representing `x` and `y`, this is safe. + unsafe { + x.set_len(0); + y.set_len(0); + } + + // SAFETY: + // + // We know we have an exclusive reference to x and y. + // + // We know that `i` is always smaller than `length`, the original length of `x` and `y`. We also + // know that `x.len() <= i` and `y.len() <= i`, so `push` only fills slots that were already read. + // + // Items are dropped exactly once, even if `f` panics. 
+ for i in 0..length { + let xi = unsafe { x_ptr.wrapping_add(i).read() }; + let yi = unsafe { y_ptr.wrapping_add(i).read() }; + + x_buf.length -= 1; + y_buf.length -= 1; + + // We hold the invariant here that all items that are not yet deleted are either in + // - `xi` or `yi` + // - `x_buf` or `y_buf` + // - `x` or `y` + // + // This way if `f` ever panics, we are sure that all items are dropped exactly once. + // Deleted items will be dropped when they are deleted. + let result = f(xi, yi); + + if let Some((xi, yi)) = result { + x.push(xi); + y.push(yi); + } + } + + debug_assert_eq!(x_buf.length, 0); + debug_assert_eq!(y_buf.length, 0); + + // We are safe to forget `x_buf` and `y_buf` here since they no longer own any items (their + // `length` is 0). + std::mem::forget(x_buf); + std::mem::forget(y_buf); +}