From 6ffc338b9549be72d695729ecb075d1954c49c7e Mon Sep 17 00:00:00 2001
From: ritchie
Date: Sun, 9 Jun 2024 19:38:22 +0200
Subject: [PATCH] fix: Ensure that a split ChunkedArray also flattens chunks

---
 crates/polars-core/src/utils/mod.rs          | 170 +++++++++++++++---
 .../polars-ops/src/frame/join/asof/groups.rs |  30 ++--
 2 files changed, 158 insertions(+), 42 deletions(-)

diff --git a/crates/polars-core/src/utils/mod.rs b/crates/polars-core/src/utils/mod.rs
index 2ff86df29503..53f1f6231126 100644
--- a/crates/polars-core/src/utils/mod.rs
+++ b/crates/polars-core/src/utils/mod.rs
@@ -104,6 +104,7 @@ macro_rules! split_array {
     }};
 }

+// This one splits, but does not flatten the chunks.
 pub fn split_ca<T>(ca: &ChunkedArray<T>, n: usize) -> PolarsResult<Vec<ChunkedArray<T>>>
 where
     T: PolarsDataType,
@@ -139,46 +140,147 @@ pub fn split_series(s: &Series, n: usize) -> PolarsResult<Vec<Series>> {
     split_array!(s, n, i64)
 }

-/// Split a [`DataFrame`] in `target` elements. The target doesn't have to be respected if not
-/// strict. Deviation of the target might be done to create more equal size chunks.
-///
-/// # Panics
-/// if chunks are not aligned
-pub fn split_df_as_ref(df: &DataFrame, target: usize, strict: bool) -> Vec<DataFrame> {
-    let total_len = df.height();
+#[allow(clippy::len_without_is_empty)]
+pub trait Container: Clone {
+    fn slice(&self, offset: i64, len: usize) -> Self;
+
+    fn len(&self) -> usize;
+
+    fn iter_chunks(&self) -> impl Iterator<Item = Self>;
+
+    fn n_chunks(&self) -> usize;
+
+    fn chunk_lengths(&self) -> impl Iterator<Item = usize>;
+}
+
+impl Container for DataFrame {
+    fn slice(&self, offset: i64, len: usize) -> Self {
+        DataFrame::slice(self, offset, len)
+    }
+
+    fn len(&self) -> usize {
+        self.height()
+    }
+
+    fn iter_chunks(&self) -> impl Iterator<Item = Self> {
+        flatten_df_iter(self)
+    }
+
+    fn n_chunks(&self) -> usize {
+        DataFrame::n_chunks(self)
+    }
+
+    fn chunk_lengths(&self) -> impl Iterator<Item = usize> {
+        self.get_columns()[0].chunk_lengths()
+    }
+}
+
+impl<T: PolarsDataType> Container for ChunkedArray<T> {
+    fn slice(&self, offset: i64, len: usize) -> Self {
+        ChunkedArray::slice(self, offset, len)
+    }
+
+    fn len(&self) -> usize {
+        ChunkedArray::len(self)
+    }
+
+    fn iter_chunks(&self) -> impl Iterator<Item = Self> {
+        self.downcast_iter()
+            .map(|arr| Self::with_chunk(self.name(), arr.clone()))
+    }
+
+    fn n_chunks(&self) -> usize {
+        self.chunks().len()
+    }
+
+    fn chunk_lengths(&self) -> impl Iterator<Item = usize> {
+        ChunkedArray::chunk_lengths(self)
+    }
+}
+
+impl Container for Series {
+    fn slice(&self, offset: i64, len: usize) -> Self {
+        self.0.slice(offset, len)
+    }
+
+    fn len(&self) -> usize {
+        self.0.len()
+    }
+
+    fn iter_chunks(&self) -> impl Iterator<Item = Self> {
+        (0..self.0.n_chunks()).map(|i| self.select_chunk(i))
+    }
+
+    fn n_chunks(&self) -> usize {
+        self.chunks().len()
+    }
+
+    fn chunk_lengths(&self) -> impl Iterator<Item = usize> {
+        self.0.chunk_lengths()
+    }
+}
+
+fn split_impl<C: Container>(container: &C, target: usize, chunk_size: usize) -> Vec<C> {
+    let total_len = container.len();
+    let mut out = Vec::with_capacity(target);
+
+    for i in 0..target {
+        let offset = i * chunk_size;
+        let len = if i == (target - 1) {
+            total_len.saturating_sub(offset)
+        } else {
+            chunk_size
+        };
+        let container = container.slice(offset as i64, len);
+        out.push(container);
+    }
+    out
+}
+
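+/// Split a [`Container`] into `target` pieces. Unlike [`split_and_flatten`],
+/// the result is not rechunked: a single piece may still span multiple chunks.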
+pub fn split<C: Container>(container: &C, target: usize) -> Vec<C> {
+    let total_len = container.len();
     if total_len == 0 {
-        return vec![df.clone()];
+        return vec![container.clone()];
     }

     let chunk_size = std::cmp::max(total_len / target, 1);
-
-    if df.n_chunks() == target
-        && df.get_columns()[0]
+    if container.n_chunks() == target
+        && container
             .chunk_lengths()
             .all(|len| len.abs_diff(chunk_size) < 100)
     {
-        return flatten_df_iter(df).collect();
+        return container.iter_chunks().collect();
     }
+    split_impl(container, target, chunk_size)
+}

-    let mut out = Vec::with_capacity(target);
+/// Split a [`Container`] into `target` elements. The target doesn't have to be
+/// respected; deviating from it is allowed to create chunks of more equal size.
+pub fn split_and_flatten<C: Container>(container: &C, target: usize) -> Vec<C> {
+    let total_len = container.len();
+    if total_len == 0 {
+        return vec![container.clone()];
+    }

-    if df.n_chunks() == 1 || strict {
-        for i in 0..target {
-            let offset = i * chunk_size;
-            let len = if i == (target - 1) {
-                total_len.saturating_sub(offset)
-            } else {
-                chunk_size
-            };
-            let df = df.slice((i * chunk_size) as i64, len);
-            out.push(df);
-        }
+    let chunk_size = std::cmp::max(total_len / target, 1);
+
+    if container.n_chunks() == target
+        && container
+            .chunk_lengths()
+            .all(|len| len.abs_diff(chunk_size) < 100)
+    {
+        return container.iter_chunks().collect();
+    }
+
+    if container.n_chunks() == 1 {
+        split_impl(container, target, chunk_size)
     } else {
-        let chunks = flatten_df_iter(df);
+        let mut out = Vec::with_capacity(target);
+        let chunks = container.iter_chunks();

         'new_chunk: for mut chunk in chunks {
             loop {
-                let h = chunk.height();
+                let h = chunk.len();
                 if h < chunk_size {
                     // TODO if the chunk is much smaller than chunk size, we should try to merge it with the next one.
                     out.push(chunk);
@@ -191,14 +293,26 @@ pub fn split_df_as_ref(df: &DataFrame, target: usize, strict: bool) -> Vec<DataFrame>
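+/// Split a [`DataFrame`] into `target` elements. When `strict`, exactly
+/// `target` pieces are returned; otherwise the pieces are also flattened to
+/// single chunks, and their count may deviate from `target` in favour of more
+/// equally sized pieces.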
+pub fn split_df_as_ref(df: &DataFrame, target: usize, strict: bool) -> Vec<DataFrame> {
+    if strict {
+        split(df, target)
+    } else {
+        split_and_flatten(df, target)
+    }
 }

 #[doc(hidden)]
diff --git a/crates/polars-ops/src/frame/join/asof/groups.rs b/crates/polars-ops/src/frame/join/asof/groups.rs
index 4dc0b7010a34..22a72be8de9a 100644
--- a/crates/polars-ops/src/frame/join/asof/groups.rs
+++ b/crates/polars-ops/src/frame/join/asof/groups.rs
@@ -9,7 +9,7 @@ use polars_core::hashing::{
 };
 use polars_core::prelude::*;
 use polars_core::utils::flatten::flatten_nullable;
-use polars_core::utils::{_set_partition_size, split_ca, split_df};
+use polars_core::utils::{_set_partition_size, split_and_flatten};
 use polars_core::{with_match_physical_float_polars_type, IdBuildHasher, POOL};
 use polars_utils::abs_diff::AbsDiff;
 use polars_utils::hashing::{hash_to_partition, DirtyHash};
@@ -169,21 +169,24 @@ where
     A: for<'a> AsofJoinState<T::Physical<'a>>,
     F: Sync + for<'a> Fn(T::Physical<'a>, T::Physical<'a>) -> bool,
 {
-    let left_asof = left_asof.rechunk();
-    let right_asof = right_asof.rechunk();
+    let (left_asof, right_asof) = POOL.join(|| left_asof.rechunk(), || right_asof.rechunk());
     let left_val_arr = left_asof.downcast_iter().next().unwrap();
     let right_val_arr = right_asof.downcast_iter().next().unwrap();

     let n_threads = POOL.current_num_threads();
-    let split_by_left = split_ca(by_left, n_threads).unwrap();
-    let split_by_right = split_ca(by_right, n_threads).unwrap();
+    // Take the flattening split so that every piece is a single chunk, even if
+    // there are more chunks than threads.
+    let split_by_left = split_and_flatten(by_left, n_threads);
+    let split_by_right = split_and_flatten(by_right, n_threads);

     let offsets = compute_len_offsets(split_by_left.iter().map(|s| s.len()));

     // TODO: handle nulls more efficiently. Right now we just join on the value
     // ignoring the validity mask, and ignore the nulls later.
     let right_slices = split_by_right
         .iter()
-        .map(|ca| ca.downcast_iter().next().unwrap().values_iter().copied())
+        .map(|ca| {
+            assert_eq!(ca.chunks().len(), 1);
+            ca.downcast_iter().next().unwrap().values_iter().copied()
+        })
         .collect();
     let hash_tbls = build_tables(right_slices, false);
     let n_tables = hash_tbls.len();
@@ -197,6 +200,7 @@ where
             let mut group_states: PlHashMap<S::Native, A> =
                 PlHashMap::with_capacity(_HASHMAP_INIT_SIZE);

+            assert_eq!(by_left.chunks().len(), 1);
             let by_left_chunk = by_left.downcast_iter().next().unwrap();
             for (rel_idx_left, opt_by_left_k) in by_left_chunk.iter().enumerate() {
                 let Some(by_left_k) = opt_by_left_k else {
@@ -245,14 +249,13 @@ where
     A: for<'a> AsofJoinState<T::Physical<'a>>,
     F: Sync + for<'a> Fn(T::Physical<'a>, T::Physical<'a>) -> bool,
 {
-    let left_asof = left_asof.rechunk();
-    let right_asof = right_asof.rechunk();
+    let (left_asof, right_asof) = POOL.join(|| left_asof.rechunk(), || right_asof.rechunk());
     let left_val_arr = left_asof.downcast_iter().next().unwrap();
     let right_val_arr = right_asof.downcast_iter().next().unwrap();

     let n_threads = POOL.current_num_threads();
-    let split_by_left = split_ca(by_left, n_threads).unwrap();
-    let split_by_right = split_ca(by_right, n_threads).unwrap();
+    let split_by_left = split_and_flatten(by_left, n_threads);
+    let split_by_right = split_and_flatten(by_right, n_threads);

     let offsets = compute_len_offsets(split_by_left.iter().map(|s| s.len()));

     let hb = RandomState::default();
@@ -311,14 +314,13 @@ where
     A: for<'a> AsofJoinState<T::Physical<'a>>,
     F: Sync + for<'a> Fn(T::Physical<'a>, T::Physical<'a>) -> bool,
 {
-    let left_asof = left_asof.rechunk();
-    let right_asof = right_asof.rechunk();
+    let (left_asof, right_asof) = POOL.join(|| left_asof.rechunk(), || right_asof.rechunk());
     let left_val_arr = left_asof.downcast_iter().next().unwrap();
     let right_val_arr = right_asof.downcast_iter().next().unwrap();

     let n_threads = POOL.current_num_threads();
-    let split_by_left = split_df(by_left, n_threads, false);
-    let split_by_right = split_df(by_right, n_threads, false);
+    let split_by_left = split_and_flatten(by_left, n_threads);
+    let split_by_right = split_and_flatten(by_right, n_threads);

     let (build_hashes, random_state) =
         _df_rows_to_hashes_threaded_vertical(&split_by_right, None).unwrap();
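
A minimal sketch (not part of the patch) of the guarantee the asof-join code
now relies on: after `split_and_flatten`, every piece holds exactly one chunk,
so the `downcast_iter().next().unwrap()` calls above cannot silently drop data.
It assumes a build of polars-core that includes this change and uses only APIs
visible in the diff (`split_and_flatten`, `chunks`) plus the ordinary
`Int32Chunked::from_slice`/`append` constructors.

    use polars_core::prelude::*;
    use polars_core::utils::split_and_flatten;

    fn main() {
        // Build a two-chunk ChunkedArray: `append` adds the other array's
        // chunks instead of copying values into the existing buffer.
        let mut ca = Int32Chunked::from_slice("a", &[1, 2, 3, 4]);
        let _ = ca.append(&Int32Chunked::from_slice("a", &[5, 6, 7, 8]));
        assert_eq!(ca.chunks().len(), 2);

        // Split for two threads: plain `split`/`split_ca` may hand back
        // pieces that still span multiple chunks; `split_and_flatten` never does.
        for piece in split_and_flatten(&ca, 2) {
            assert_eq!(piece.chunks().len(), 1);
        }
    }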