Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: Improve Bitmap construction performance #15570

Merged
merged 2 commits into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 27 additions & 38 deletions crates/polars-arrow/src/bitmap/bitmap_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,38 +5,33 @@ use super::Bitmap;
use crate::bitmap::MutableBitmap;
use crate::trusted_len::TrustedLen;

/// Creates a [Vec<u8>] from an [`Iterator`] of [`BitChunk`].
/// # Safety
/// The iterator must be [`TrustedLen`].
pub unsafe fn from_chunk_iter_unchecked<T: BitChunk, I: Iterator<Item = T>>(
iterator: I,
) -> Vec<u8> {
let (_, upper) = iterator.size_hint();
let upper = upper.expect("try_from_trusted_len_iter requires an upper limit");
let len = upper * std::mem::size_of::<T>();

let mut buffer = Vec::with_capacity(len);

let mut dst = buffer.as_mut_ptr();
for item in iterator {
let bytes = item.to_ne_bytes();
for i in 0..std::mem::size_of::<T>() {
std::ptr::write(dst, bytes[i]);
dst = dst.add(1);
}
}
assert_eq!(
dst.offset_from(buffer.as_ptr()) as usize,
len,
"Trusted iterator length was not accurately reported"
);
buffer.set_len(len);
buffer
#[inline(always)]
pub(crate) fn push_bitchunk<T: BitChunk>(buffer: &mut Vec<u8>, value: T) {
buffer.extend(value.to_ne_bytes())
}

/// Creates a [`Vec<u8>`] from a [`TrustedLen`] of [`BitChunk`].
pub fn chunk_iter_to_vec<T: BitChunk, I: TrustedLen<Item = T>>(iter: I) -> Vec<u8> {
unsafe { from_chunk_iter_unchecked(iter) }
let cap = iter.size_hint().0 * std::mem::size_of::<T>();
let mut buffer = Vec::with_capacity(cap);
for v in iter {
push_bitchunk(&mut buffer, v)
}
buffer
}

fn chunk_iter_to_vec_and_remainder<T: BitChunk, I: TrustedLen<Item = T>>(
iter: I,
remainder: T,
) -> Vec<u8> {
let cap = (iter.size_hint().0 + 1) * std::mem::size_of::<T>();
let mut buffer = Vec::with_capacity(cap);
for v in iter {
push_bitchunk(&mut buffer, v)
}
push_bitchunk(&mut buffer, remainder);
debug_assert_eq!(buffer.len(), cap);
buffer
}

/// Apply a bitwise operation `op` to four inputs and return the result as a [`Bitmap`].
Expand All @@ -62,9 +57,8 @@ where
.zip(a3_chunks)
.zip(a4_chunks)
.map(|(((a1, a2), a3), a4)| op(a1, a2, a3, a4));
let buffer =
chunk_iter_to_vec(chunks.chain(std::iter::once(op(rem_a1, rem_a2, rem_a3, rem_a4))));

let buffer = chunk_iter_to_vec_and_remainder(chunks, op(rem_a1, rem_a2, rem_a3, rem_a4));
let length = a1.len();

Bitmap::from_u8_vec(buffer, length)
Expand All @@ -90,8 +84,7 @@ where
.zip(a3_chunks)
.map(|((a1, a2), a3)| op(a1, a2, a3));

let buffer = chunk_iter_to_vec(chunks.chain(std::iter::once(op(rem_a1, rem_a2, rem_a3))));
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove chains as they require extra branches on iteration.


let buffer = chunk_iter_to_vec_and_remainder(chunks, op(rem_a1, rem_a2, rem_a3));
let length = a1.len();

Bitmap::from_u8_vec(buffer, length)
Expand All @@ -112,8 +105,7 @@ where
.zip(rhs_chunks)
.map(|(left, right)| op(left, right));

let buffer = chunk_iter_to_vec(chunks.chain(std::iter::once(op(rem_lhs, rem_rhs))));

let buffer = chunk_iter_to_vec_and_remainder(chunks, op(rem_lhs, rem_rhs));
let length = lhs.len();

Bitmap::from_u8_vec(buffer, length)
Expand All @@ -125,10 +117,7 @@ where
F: Fn(u64) -> u64,
{
let rem = op(iter.remainder());

let iterator = iter.map(op).chain(std::iter::once(rem));

let buffer = chunk_iter_to_vec(iterator);
let buffer = chunk_iter_to_vec_and_remainder(iter.map(op), rem);

Bitmap::from_u8_vec(buffer, length)
}
Expand Down
60 changes: 59 additions & 1 deletion crates/polars-arrow/src/compute/utils.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use std::borrow::Borrow;
use std::ops::{BitAnd, BitOr};

use polars_error::{polars_ensure, PolarsResult};

use crate::array::Array;
use crate::bitmap::{and_not, ternary, Bitmap};
use crate::bitmap::{and_not, push_bitchunk, ternary, Bitmap};

pub fn combine_validities_and3(
opt1: Option<&Bitmap>,
Expand Down Expand Up @@ -49,6 +50,63 @@ pub fn combine_validities_and_not(
}
}

pub fn combine_validities_and_many<B: Borrow<Bitmap>>(bitmaps: &[Option<B>]) -> Option<Bitmap> {
let mut bitmaps = bitmaps
.iter()
.flatten()
.map(|b| b.borrow())
.collect::<Vec<_>>();

match bitmaps.len() {
0 => None,
1 => bitmaps.pop().cloned(),
2 => combine_validities_and(bitmaps.pop(), bitmaps.pop()),
3 => combine_validities_and3(bitmaps.pop(), bitmaps.pop(), bitmaps.pop()),
_ => {
let mut iterators = bitmaps
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Single pass/ allocation to combine bitmaps.

.iter()
.map(|v| v.fast_iter_u64())
.collect::<Vec<_>>();
let mut buffer = Vec::with_capacity(iterators.first().unwrap().size_hint().0 + 2);

'rows: loop {
// All ones so as identity for & operation
let mut out = u64::MAX;
for iter in iterators.iter_mut() {
if let Some(v) = iter.next() {
out &= v
} else {
break 'rows;
}
}
push_bitchunk(&mut buffer, out);
}

// All ones so as identity for & operation
let mut out = [u64::MAX, u64::MAX];
let mut len = 0;
for iter in iterators.into_iter() {
let (rem, rem_len) = iter.remainder();
len = rem_len;

for (out, rem) in out.iter_mut().zip(rem) {
*out &= rem;
}
}
push_bitchunk(&mut buffer, out[0]);
if len > 64 {
push_bitchunk(&mut buffer, out[1]);
}
let bitmap = Bitmap::from_u8_vec(buffer, bitmaps[0].len());
if bitmap.unset_bits() == bitmap.len() {
None
} else {
Some(bitmap)
}
},
}
}

// Errors iff the two arrays have a different length.
#[inline]
pub fn check_same_len(lhs: &dyn Array, rhs: &dyn Array) -> PolarsResult<()> {
Expand Down
3 changes: 2 additions & 1 deletion crates/polars-arrow/src/types/native.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ pub trait NativeType:
+ std::ops::IndexMut<usize, Output = u8>
+ for<'a> TryFrom<&'a [u8]>
+ std::fmt::Debug
+ Default;
+ Default
+ IntoIterator<Item = u8>;

/// To bytes in little endian
fn to_le_bytes(&self) -> Self::Bytes;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use arrow::compute::utils::combine_validities_and;
use arrow::compute::utils::combine_validities_and_many;
use compare_inner::NullOrderCmp;
use polars_row::{convert_columns, EncodingField, RowsEncoded};
use polars_utils::iter::EnumerateIdxTrait;
Expand Down Expand Up @@ -121,7 +121,7 @@ pub fn encode_rows_vertical_par_unordered_broadcast_nulls(
.collect::<Vec<_>>();
let rows = _get_rows_encoded_unordered(&sliced)?;

let validity = sliced
let validities = sliced
.iter()
.flat_map(|s| {
let s = s.rechunk();
Expand All @@ -131,7 +131,9 @@ pub fn encode_rows_vertical_par_unordered_broadcast_nulls(
.into_iter()
.map(|arr| arr.validity().cloned())
})
.fold(None, |l, r| combine_validities_and(l.as_ref(), r.as_ref()));
.collect::<Vec<_>>();

let validity = combine_validities_and_many(&validities);
Ok(rows.into_array().with_validity_typed(validity))
});
let chunks = POOL.install(|| chunks.collect::<PolarsResult<Vec<_>>>());
Expand Down
10 changes: 2 additions & 8 deletions crates/polars-ops/src/frame/join/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@ use polars_utils::hashing::BytesHash;
use rayon::prelude::*;

use super::IntoDf;
const LHS_NAME: &str = "POLARS_K_L";
const RHS_NAME: &str = "POLARS_K_R";

pub trait DataFrameJoinOps: IntoDf {
/// Generic join method. Can be used to join on multiple columns.
Expand Down Expand Up @@ -260,12 +258,8 @@ pub trait DataFrameJoinOps: IntoDf {
};
}

let lhs_keys = prepare_keys_multiple(&selected_left, args.join_nulls)?
.into_series()
.with_name(LHS_NAME);
let rhs_keys = prepare_keys_multiple(&selected_right, args.join_nulls)?
.into_series()
.with_name(RHS_NAME);
let lhs_keys = prepare_keys_multiple(&selected_left, args.join_nulls)?.into_series();
let rhs_keys = prepare_keys_multiple(&selected_right, args.join_nulls)?.into_series();
let names_right = selected_right.iter().map(|s| s.name()).collect::<Vec<_>>();

// Multiple keys.
Expand Down
17 changes: 4 additions & 13 deletions crates/polars-ops/src/series/ops/fused.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use arrow::array::PrimitiveArray;
use arrow::compute::utils::combine_validities_and;
use arrow::compute::utils::combine_validities_and3;
use polars_core::prelude::*;
use polars_core::utils::align_chunks_ternary;
use polars_core::with_match_physical_numeric_polars_type;
Expand All @@ -11,10 +11,7 @@ fn fma_arr<T: NumericNative>(
c: &PrimitiveArray<T>,
) -> PrimitiveArray<T> {
assert_eq!(a.len(), b.len());
let validity = combine_validities_and(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could trigger two allocations/passes.

combine_validities_and(a.validity(), b.validity()).as_ref(),
c.validity(),
);
let validity = combine_validities_and3(a.validity(), b.validity(), c.validity());
let a = a.values().as_slice();
let b = b.values().as_slice();
let c = c.values().as_slice();
Expand Down Expand Up @@ -65,10 +62,7 @@ fn fsm_arr<T: NumericNative>(
c: &PrimitiveArray<T>,
) -> PrimitiveArray<T> {
assert_eq!(a.len(), b.len());
let validity = combine_validities_and(
combine_validities_and(a.validity(), b.validity()).as_ref(),
c.validity(),
);
let validity = combine_validities_and3(a.validity(), b.validity(), c.validity());
let a = a.values().as_slice();
let b = b.values().as_slice();
let c = c.values().as_slice();
Expand Down Expand Up @@ -118,10 +112,7 @@ fn fms_arr<T: NumericNative>(
c: &PrimitiveArray<T>,
) -> PrimitiveArray<T> {
assert_eq!(a.len(), b.len());
let validity = combine_validities_and(
combine_validities_and(a.validity(), b.validity()).as_ref(),
c.validity(),
);
let validity = combine_validities_and3(a.validity(), b.validity(), c.validity());
let a = a.values().as_slice();
let b = b.values().as_slice();
let c = c.values().as_slice();
Expand Down
9 changes: 5 additions & 4 deletions crates/polars-pipe/src/executors/sinks/joins/row_values.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::sync::Arc;

use arrow::array::{ArrayRef, BinaryArray, StaticArray};
use arrow::compute::utils::combine_validities_and;
use arrow::compute::utils::combine_validities_and_many;
use polars_core::error::PolarsResult;
use polars_row::RowsEncoded;

Expand Down Expand Up @@ -80,11 +80,12 @@ impl RowValues {
Ok(if join_nulls {
array
} else {
let validity = self
let validities = self
.join_columns_material
.iter()
.map(|arr| arr.validity().cloned())
.fold(None, |l, r| combine_validities_and(l.as_ref(), r.as_ref()));
.map(|arr| arr.validity())
.collect::<Vec<_>>();
let validity = combine_validities_and_many(&validities);
array.with_validity_typed(validity)
})
}
Expand Down
22 changes: 22 additions & 0 deletions py-polars/tests/unit/operations/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,3 +835,25 @@ def test_join_list_non_numeric() -> None:
"lists": [["a", "b", "c"], ["a", "c", "b"], ["a", "c", "d"]],
"count": [1, 2, 1],
}


@pytest.mark.slow()
def test_join_4_columns_with_validity() -> None:
# join on 4 columns so we trigger combine validities
# use 138 as that is 2 u64 and a remainder
a = pl.DataFrame(
{"a": [None if a % 6 == 0 else a for a in range(138)]}
).with_columns(
b=pl.col("a"),
c=pl.col("a"),
d=pl.col("a"),
)

assert a.join(a, on=["a", "b", "c", "d"], how="inner", join_nulls=True).shape == (
644,
4,
)
assert a.join(a, on=["a", "b", "c", "d"], how="inner", join_nulls=False).shape == (
115,
4,
)
Loading