From a919601325d71f7b427c4d05dffaf28db636dbf6 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Wed, 10 Apr 2024 09:43:58 +0200 Subject: [PATCH] perf: Improve Bitmap construction performance (#15570) --- crates/polars-arrow/src/bitmap/bitmap_ops.rs | 65 ++++++++----------- crates/polars-arrow/src/compute/utils.rs | 60 ++++++++++++++++- crates/polars-arrow/src/types/native.rs | 3 +- .../ops/sort/arg_sort_multiple.rs | 8 ++- crates/polars-ops/src/frame/join/mod.rs | 10 +-- crates/polars-ops/src/series/ops/fused.rs | 17 ++--- .../src/executors/sinks/joins/row_values.rs | 9 +-- py-polars/tests/unit/operations/test_join.py | 22 +++++++ 8 files changed, 126 insertions(+), 68 deletions(-) diff --git a/crates/polars-arrow/src/bitmap/bitmap_ops.rs b/crates/polars-arrow/src/bitmap/bitmap_ops.rs index 433fbbc93a71..cda8a8bd2356 100644 --- a/crates/polars-arrow/src/bitmap/bitmap_ops.rs +++ b/crates/polars-arrow/src/bitmap/bitmap_ops.rs @@ -5,38 +5,33 @@ use super::Bitmap; use crate::bitmap::MutableBitmap; use crate::trusted_len::TrustedLen; -/// Creates a [Vec] from an [`Iterator`] of [`BitChunk`]. -/// # Safety -/// The iterator must be [`TrustedLen`]. -pub unsafe fn from_chunk_iter_unchecked>( - iterator: I, -) -> Vec { - let (_, upper) = iterator.size_hint(); - let upper = upper.expect("try_from_trusted_len_iter requires an upper limit"); - let len = upper * std::mem::size_of::(); - - let mut buffer = Vec::with_capacity(len); - - let mut dst = buffer.as_mut_ptr(); - for item in iterator { - let bytes = item.to_ne_bytes(); - for i in 0..std::mem::size_of::() { - std::ptr::write(dst, bytes[i]); - dst = dst.add(1); - } - } - assert_eq!( - dst.offset_from(buffer.as_ptr()) as usize, - len, - "Trusted iterator length was not accurately reported" - ); - buffer.set_len(len); - buffer +#[inline(always)] +pub(crate) fn push_bitchunk(buffer: &mut Vec, value: T) { + buffer.extend(value.to_ne_bytes()) } /// Creates a [`Vec`] from a [`TrustedLen`] of [`BitChunk`]. pub fn chunk_iter_to_vec>(iter: I) -> Vec { - unsafe { from_chunk_iter_unchecked(iter) } + let cap = iter.size_hint().0 * std::mem::size_of::(); + let mut buffer = Vec::with_capacity(cap); + for v in iter { + push_bitchunk(&mut buffer, v) + } + buffer +} + +fn chunk_iter_to_vec_and_remainder>( + iter: I, + remainder: T, +) -> Vec { + let cap = (iter.size_hint().0 + 1) * std::mem::size_of::(); + let mut buffer = Vec::with_capacity(cap); + for v in iter { + push_bitchunk(&mut buffer, v) + } + push_bitchunk(&mut buffer, remainder); + debug_assert_eq!(buffer.len(), cap); + buffer } /// Apply a bitwise operation `op` to four inputs and return the result as a [`Bitmap`]. @@ -62,9 +57,8 @@ where .zip(a3_chunks) .zip(a4_chunks) .map(|(((a1, a2), a3), a4)| op(a1, a2, a3, a4)); - let buffer = - chunk_iter_to_vec(chunks.chain(std::iter::once(op(rem_a1, rem_a2, rem_a3, rem_a4)))); + let buffer = chunk_iter_to_vec_and_remainder(chunks, op(rem_a1, rem_a2, rem_a3, rem_a4)); let length = a1.len(); Bitmap::from_u8_vec(buffer, length) @@ -90,8 +84,7 @@ where .zip(a3_chunks) .map(|((a1, a2), a3)| op(a1, a2, a3)); - let buffer = chunk_iter_to_vec(chunks.chain(std::iter::once(op(rem_a1, rem_a2, rem_a3)))); - + let buffer = chunk_iter_to_vec_and_remainder(chunks, op(rem_a1, rem_a2, rem_a3)); let length = a1.len(); Bitmap::from_u8_vec(buffer, length) @@ -112,8 +105,7 @@ where .zip(rhs_chunks) .map(|(left, right)| op(left, right)); - let buffer = chunk_iter_to_vec(chunks.chain(std::iter::once(op(rem_lhs, rem_rhs)))); - + let buffer = chunk_iter_to_vec_and_remainder(chunks, op(rem_lhs, rem_rhs)); let length = lhs.len(); Bitmap::from_u8_vec(buffer, length) @@ -125,10 +117,7 @@ where F: Fn(u64) -> u64, { let rem = op(iter.remainder()); - - let iterator = iter.map(op).chain(std::iter::once(rem)); - - let buffer = chunk_iter_to_vec(iterator); + let buffer = chunk_iter_to_vec_and_remainder(iter.map(op), rem); Bitmap::from_u8_vec(buffer, length) } diff --git a/crates/polars-arrow/src/compute/utils.rs b/crates/polars-arrow/src/compute/utils.rs index 3198abd444ca..0b8e1ecd69f4 100644 --- a/crates/polars-arrow/src/compute/utils.rs +++ b/crates/polars-arrow/src/compute/utils.rs @@ -1,9 +1,10 @@ +use std::borrow::Borrow; use std::ops::{BitAnd, BitOr}; use polars_error::{polars_ensure, PolarsResult}; use crate::array::Array; -use crate::bitmap::{and_not, ternary, Bitmap}; +use crate::bitmap::{and_not, push_bitchunk, ternary, Bitmap}; pub fn combine_validities_and3( opt1: Option<&Bitmap>, @@ -49,6 +50,63 @@ pub fn combine_validities_and_not( } } +pub fn combine_validities_and_many>(bitmaps: &[Option]) -> Option { + let mut bitmaps = bitmaps + .iter() + .flatten() + .map(|b| b.borrow()) + .collect::>(); + + match bitmaps.len() { + 0 => None, + 1 => bitmaps.pop().cloned(), + 2 => combine_validities_and(bitmaps.pop(), bitmaps.pop()), + 3 => combine_validities_and3(bitmaps.pop(), bitmaps.pop(), bitmaps.pop()), + _ => { + let mut iterators = bitmaps + .iter() + .map(|v| v.fast_iter_u64()) + .collect::>(); + let mut buffer = Vec::with_capacity(iterators.first().unwrap().size_hint().0 + 2); + + 'rows: loop { + // All ones so as identity for & operation + let mut out = u64::MAX; + for iter in iterators.iter_mut() { + if let Some(v) = iter.next() { + out &= v + } else { + break 'rows; + } + } + push_bitchunk(&mut buffer, out); + } + + // All ones so as identity for & operation + let mut out = [u64::MAX, u64::MAX]; + let mut len = 0; + for iter in iterators.into_iter() { + let (rem, rem_len) = iter.remainder(); + len = rem_len; + + for (out, rem) in out.iter_mut().zip(rem) { + *out &= rem; + } + } + push_bitchunk(&mut buffer, out[0]); + if len > 64 { + push_bitchunk(&mut buffer, out[1]); + } + let bitmap = Bitmap::from_u8_vec(buffer, bitmaps[0].len()); + if bitmap.unset_bits() == bitmap.len() { + None + } else { + Some(bitmap) + } + }, + } +} + // Errors iff the two arrays have a different length. #[inline] pub fn check_same_len(lhs: &dyn Array, rhs: &dyn Array) -> PolarsResult<()> { diff --git a/crates/polars-arrow/src/types/native.rs b/crates/polars-arrow/src/types/native.rs index f167170d0a29..6f869df32602 100644 --- a/crates/polars-arrow/src/types/native.rs +++ b/crates/polars-arrow/src/types/native.rs @@ -38,7 +38,8 @@ pub trait NativeType: + std::ops::IndexMut + for<'a> TryFrom<&'a [u8]> + std::fmt::Debug - + Default; + + Default + + IntoIterator; /// To bytes in little endian fn to_le_bytes(&self) -> Self::Bytes; diff --git a/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs b/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs index ed972bdf48ec..170639b364be 100644 --- a/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs +++ b/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs @@ -1,4 +1,4 @@ -use arrow::compute::utils::combine_validities_and; +use arrow::compute::utils::combine_validities_and_many; use compare_inner::NullOrderCmp; use polars_row::{convert_columns, EncodingField, RowsEncoded}; use polars_utils::iter::EnumerateIdxTrait; @@ -121,7 +121,7 @@ pub fn encode_rows_vertical_par_unordered_broadcast_nulls( .collect::>(); let rows = _get_rows_encoded_unordered(&sliced)?; - let validity = sliced + let validities = sliced .iter() .flat_map(|s| { let s = s.rechunk(); @@ -131,7 +131,9 @@ pub fn encode_rows_vertical_par_unordered_broadcast_nulls( .into_iter() .map(|arr| arr.validity().cloned()) }) - .fold(None, |l, r| combine_validities_and(l.as_ref(), r.as_ref())); + .collect::>(); + + let validity = combine_validities_and_many(&validities); Ok(rows.into_array().with_validity_typed(validity)) }); let chunks = POOL.install(|| chunks.collect::>>()); diff --git a/crates/polars-ops/src/frame/join/mod.rs b/crates/polars-ops/src/frame/join/mod.rs index 442d9bdd172f..2c06882b4c35 100644 --- a/crates/polars-ops/src/frame/join/mod.rs +++ b/crates/polars-ops/src/frame/join/mod.rs @@ -45,8 +45,6 @@ use polars_utils::hashing::BytesHash; use rayon::prelude::*; use super::IntoDf; -const LHS_NAME: &str = "POLARS_K_L"; -const RHS_NAME: &str = "POLARS_K_R"; pub trait DataFrameJoinOps: IntoDf { /// Generic join method. Can be used to join on multiple columns. @@ -260,12 +258,8 @@ pub trait DataFrameJoinOps: IntoDf { }; } - let lhs_keys = prepare_keys_multiple(&selected_left, args.join_nulls)? - .into_series() - .with_name(LHS_NAME); - let rhs_keys = prepare_keys_multiple(&selected_right, args.join_nulls)? - .into_series() - .with_name(RHS_NAME); + let lhs_keys = prepare_keys_multiple(&selected_left, args.join_nulls)?.into_series(); + let rhs_keys = prepare_keys_multiple(&selected_right, args.join_nulls)?.into_series(); let names_right = selected_right.iter().map(|s| s.name()).collect::>(); // Multiple keys. diff --git a/crates/polars-ops/src/series/ops/fused.rs b/crates/polars-ops/src/series/ops/fused.rs index d52aecd12774..7456cdfc6f70 100644 --- a/crates/polars-ops/src/series/ops/fused.rs +++ b/crates/polars-ops/src/series/ops/fused.rs @@ -1,5 +1,5 @@ use arrow::array::PrimitiveArray; -use arrow::compute::utils::combine_validities_and; +use arrow::compute::utils::combine_validities_and3; use polars_core::prelude::*; use polars_core::utils::align_chunks_ternary; use polars_core::with_match_physical_numeric_polars_type; @@ -11,10 +11,7 @@ fn fma_arr( c: &PrimitiveArray, ) -> PrimitiveArray { assert_eq!(a.len(), b.len()); - let validity = combine_validities_and( - combine_validities_and(a.validity(), b.validity()).as_ref(), - c.validity(), - ); + let validity = combine_validities_and3(a.validity(), b.validity(), c.validity()); let a = a.values().as_slice(); let b = b.values().as_slice(); let c = c.values().as_slice(); @@ -65,10 +62,7 @@ fn fsm_arr( c: &PrimitiveArray, ) -> PrimitiveArray { assert_eq!(a.len(), b.len()); - let validity = combine_validities_and( - combine_validities_and(a.validity(), b.validity()).as_ref(), - c.validity(), - ); + let validity = combine_validities_and3(a.validity(), b.validity(), c.validity()); let a = a.values().as_slice(); let b = b.values().as_slice(); let c = c.values().as_slice(); @@ -118,10 +112,7 @@ fn fms_arr( c: &PrimitiveArray, ) -> PrimitiveArray { assert_eq!(a.len(), b.len()); - let validity = combine_validities_and( - combine_validities_and(a.validity(), b.validity()).as_ref(), - c.validity(), - ); + let validity = combine_validities_and3(a.validity(), b.validity(), c.validity()); let a = a.values().as_slice(); let b = b.values().as_slice(); let c = c.values().as_slice(); diff --git a/crates/polars-pipe/src/executors/sinks/joins/row_values.rs b/crates/polars-pipe/src/executors/sinks/joins/row_values.rs index b144e98cf87d..ecbbbadac8b0 100644 --- a/crates/polars-pipe/src/executors/sinks/joins/row_values.rs +++ b/crates/polars-pipe/src/executors/sinks/joins/row_values.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use arrow::array::{ArrayRef, BinaryArray, StaticArray}; -use arrow::compute::utils::combine_validities_and; +use arrow::compute::utils::combine_validities_and_many; use polars_core::error::PolarsResult; use polars_row::RowsEncoded; @@ -80,11 +80,12 @@ impl RowValues { Ok(if join_nulls { array } else { - let validity = self + let validities = self .join_columns_material .iter() - .map(|arr| arr.validity().cloned()) - .fold(None, |l, r| combine_validities_and(l.as_ref(), r.as_ref())); + .map(|arr| arr.validity()) + .collect::>(); + let validity = combine_validities_and_many(&validities); array.with_validity_typed(validity) }) } diff --git a/py-polars/tests/unit/operations/test_join.py b/py-polars/tests/unit/operations/test_join.py index 2c9e441def3e..ea8ec9e8e2fe 100644 --- a/py-polars/tests/unit/operations/test_join.py +++ b/py-polars/tests/unit/operations/test_join.py @@ -835,3 +835,25 @@ def test_join_list_non_numeric() -> None: "lists": [["a", "b", "c"], ["a", "c", "b"], ["a", "c", "d"]], "count": [1, 2, 1], } + + +@pytest.mark.slow() +def test_join_4_columns_with_validity() -> None: + # join on 4 columns so we trigger combine validities + # use 138 as that is 2 u64 and a remainder + a = pl.DataFrame( + {"a": [None if a % 6 == 0 else a for a in range(138)]} + ).with_columns( + b=pl.col("a"), + c=pl.col("a"), + d=pl.col("a"), + ) + + assert a.join(a, on=["a", "b", "c", "d"], how="inner", join_nulls=True).shape == ( + 644, + 4, + ) + assert a.join(a, on=["a", "b", "c", "d"], how="inner", join_nulls=False).shape == ( + 115, + 4, + )