fix: Ensure that splitted ChunkedArray also flattens chunks (#16837)
ritchie46 authored Jun 9, 2024
1 parent ceed90f commit 5553c3e
Showing 2 changed files with 158 additions and 42 deletions.
170 changes: 142 additions & 28 deletions crates/polars-core/src/utils/mod.rs
@@ -104,6 +104,7 @@ macro_rules! split_array {
}};
}

// This one splits, but doesn't flatten chunks.
pub fn split_ca<T>(ca: &ChunkedArray<T>, n: usize) -> PolarsResult<Vec<ChunkedArray<T>>>
where
T: PolarsDataType,
@@ -139,46 +140,147 @@ pub fn split_series(s: &Series, n: usize) -> PolarsResult<Vec<Series>> {
split_array!(s, n, i64)
}

/// Split a [`DataFrame`] in `target` elements. The target doesn't have to be respected if not
/// strict. Deviation of the target might be done to create more equal size chunks.
///
/// # Panics
/// if chunks are not aligned
pub fn split_df_as_ref(df: &DataFrame, target: usize, strict: bool) -> Vec<DataFrame> {
let total_len = df.height();
#[allow(clippy::len_without_is_empty)]
pub trait Container: Clone {
fn slice(&self, offset: i64, len: usize) -> Self;

fn len(&self) -> usize;

fn iter_chunks(&self) -> impl Iterator<Item = Self>;

fn n_chunks(&self) -> usize;

fn chunk_lengths(&self) -> impl Iterator<Item = usize>;
}

impl Container for DataFrame {
fn slice(&self, offset: i64, len: usize) -> Self {
DataFrame::slice(self, offset, len)
}

fn len(&self) -> usize {
self.height()
}

fn iter_chunks(&self) -> impl Iterator<Item = Self> {
flatten_df_iter(self)
}

fn n_chunks(&self) -> usize {
DataFrame::n_chunks(self)
}

fn chunk_lengths(&self) -> impl Iterator<Item = usize> {
self.get_columns()[0].chunk_lengths()
}
}

impl<T: PolarsDataType> Container for ChunkedArray<T> {
fn slice(&self, offset: i64, len: usize) -> Self {
ChunkedArray::slice(self, offset, len)
}

fn len(&self) -> usize {
ChunkedArray::len(self)
}

fn iter_chunks(&self) -> impl Iterator<Item = Self> {
self.downcast_iter()
.map(|arr| Self::with_chunk(self.name(), arr.clone()))
}

fn n_chunks(&self) -> usize {
self.chunks().len()
}

fn chunk_lengths(&self) -> impl Iterator<Item = usize> {
ChunkedArray::chunk_lengths(self)
}
}

impl Container for Series {
fn slice(&self, offset: i64, len: usize) -> Self {
self.0.slice(offset, len)
}

fn len(&self) -> usize {
self.0.len()
}

fn iter_chunks(&self) -> impl Iterator<Item = Self> {
(0..self.0.n_chunks()).map(|i| self.select_chunk(i))
}

fn n_chunks(&self) -> usize {
self.chunks().len()
}

fn chunk_lengths(&self) -> impl Iterator<Item = usize> {
self.0.chunk_lengths()
}
}
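
Together, these impls give `DataFrame`, `ChunkedArray`, and `Series` a single splitting interface. As a sketch of the trait's contract, a hypothetical always-single-chunk container (`Toy` is illustrative only, not part of the codebase) could implement it like this:

```rust
#[derive(Clone)]
struct Toy(Vec<i32>);

impl Container for Toy {
    // Simplification: ignores polars' negative-offset convention for slice.
    fn slice(&self, offset: i64, len: usize) -> Self {
        let start = offset as usize;
        Toy(self.0[start..start + len].to_vec())
    }

    fn len(&self) -> usize {
        self.0.len()
    }

    // A single-chunk container iterates its chunks as one clone of itself.
    fn iter_chunks(&self) -> impl Iterator<Item = Self> {
        std::iter::once(self.clone())
    }

    fn n_chunks(&self) -> usize {
        1
    }

    fn chunk_lengths(&self) -> impl Iterator<Item = usize> {
        std::iter::once(self.0.len())
    }
}
```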

fn split_impl<C: Container>(container: &C, target: usize, chunk_size: usize) -> Vec<C> {
let total_len = container.len();
let mut out = Vec::with_capacity(target);

for i in 0..target {
let offset = i * chunk_size;
let len = if i == (target - 1) {
total_len.saturating_sub(offset)
} else {
chunk_size
};
let container = container.slice(offset as i64, len);
out.push(container);
}
out
}

pub fn split<C: Container>(container: &C, target: usize) -> Vec<C> {
let total_len = container.len();
if total_len == 0 {
return vec![df.clone()];
return vec![container.clone()];
}

let chunk_size = std::cmp::max(total_len / target, 1);

if df.n_chunks() == target
&& df.get_columns()[0]
if container.n_chunks() == target
&& container
.chunk_lengths()
.all(|len| len.abs_diff(chunk_size) < 100)
{
return flatten_df_iter(df).collect();
return container.iter_chunks().collect();
}
split_impl(container, target, chunk_size)
}
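
A hedged usage sketch of `split` (assuming it is exported from `polars_core::utils` and that `Int32Chunked::from_vec` keeps its `(&str, Vec<i32>)` signature): the last piece absorbs the remainder.

```rust
use polars_core::prelude::*;
use polars_core::utils::split;

fn main() {
    // 10 elements into 4 parts: chunk_size = max(10 / 4, 1) = 2, and the
    // final piece takes total_len - 3 * chunk_size = 4 elements.
    let ca = Int32Chunked::from_vec("a", (0..10).collect());
    let parts = split(&ca, 4);
    let lens: Vec<usize> = parts.iter().map(|p| p.len()).collect();
    assert_eq!(lens, vec![2, 2, 2, 4]);
}
```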

let mut out = Vec::with_capacity(target);
/// Split a [`Container`] into `target` elements. The target doesn't have to be respected
/// exactly; deviation may occur to create more equally sized chunks.
pub fn split_and_flatten<C: Container>(container: &C, target: usize) -> Vec<C> {
let total_len = container.len();
if total_len == 0 {
return vec![container.clone()];
}

if df.n_chunks() == 1 || strict {
for i in 0..target {
let offset = i * chunk_size;
let len = if i == (target - 1) {
total_len.saturating_sub(offset)
} else {
chunk_size
};
let df = df.slice((i * chunk_size) as i64, len);
out.push(df);
}
let chunk_size = std::cmp::max(total_len / target, 1);

if container.n_chunks() == target
&& container
.chunk_lengths()
.all(|len| len.abs_diff(chunk_size) < 100)
{
return container.iter_chunks().collect();
}

if container.n_chunks() == 1 {
split_impl(container, target, chunk_size)
} else {
let chunks = flatten_df_iter(df);
let mut out = Vec::with_capacity(target);
let chunks = container.iter_chunks();

'new_chunk: for mut chunk in chunks {
loop {
let h = chunk.height();
let h = chunk.len();
if h < chunk_size {
// TODO if the chunk is much smaller than chunk size, we should try to merge it with the next one.
out.push(chunk);
@@ -191,14 +293,26 @@ pub fn split_df_as_ref(df: &DataFrame, target: usize, strict: bool) -> Vec<DataFrame> {
continue 'new_chunk;
}

// This would be faster if we had a `split` operation.
// TODO! use `split` operation here. That saves a null count.
out.push(chunk.slice(0, chunk_size));
chunk = chunk.slice(chunk_size as i64, h - chunk_size);
}
}
out
}
}

out
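
This is the guarantee the commit title is about: pieces handed to worker threads are each backed by a single chunk. A hedged sketch, assuming `df!` and `DataFrame::vstack` behave as in current polars:

```rust
use polars_core::prelude::*;
use polars_core::utils::split_and_flatten;

fn main() -> PolarsResult<()> {
    let df = polars_core::df!("a" => (0..500i32).collect::<Vec<_>>())?;
    // vstack concatenates without rechunking, so `stacked` has two chunks.
    let stacked = df.vstack(&df)?;
    assert_eq!(stacked.n_chunks(), 2);

    // split_and_flatten never returns a piece that straddles a chunk
    // boundary, so every piece is a single chunk.
    let parts = split_and_flatten(&stacked, 4);
    assert!(parts.iter().all(|p| p.n_chunks() == 1));
    Ok(())
}
```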
/// Split a [`DataFrame`] into `target` elements. The target doesn't have to be respected
/// if not strict; deviation may occur to create more equally sized chunks.
///
/// # Panics
/// if chunks are not aligned
pub fn split_df_as_ref(df: &DataFrame, target: usize, strict: bool) -> Vec<DataFrame> {
if strict {
split(df, target)
} else {
split_and_flatten(df, target)
}
}
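
`split_df_as_ref` is now a thin dispatcher on `strict`. Continuing the previous sketch (reusing its `stacked` frame):

```rust
// strict = true routes through `split`, which may keep multi-chunk
// pieces when a slice straddles a chunk boundary; strict = false
// routes through `split_and_flatten`, which guarantees one chunk each.
let flat = split_df_as_ref(&stacked, 4, false);
assert!(flat.iter().all(|p| p.n_chunks() == 1));
```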

#[doc(hidden)]
30 changes: 16 additions & 14 deletions crates/polars-ops/src/frame/join/asof/groups.rs
@@ -9,7 +9,7 @@ use polars_core::hashing::{
};
use polars_core::prelude::*;
use polars_core::utils::flatten::flatten_nullable;
use polars_core::utils::{_set_partition_size, split_ca, split_df};
use polars_core::utils::{_set_partition_size, split_and_flatten};
use polars_core::{with_match_physical_float_polars_type, IdBuildHasher, POOL};
use polars_utils::abs_diff::AbsDiff;
use polars_utils::hashing::{hash_to_partition, DirtyHash};
@@ -169,21 +169,24 @@
A: for<'a> AsofJoinState<T::Physical<'a>>,
F: Sync + for<'a> Fn(T::Physical<'a>, T::Physical<'a>) -> bool,
{
let left_asof = left_asof.rechunk();
let right_asof = right_asof.rechunk();
let (left_asof, right_asof) = POOL.join(|| left_asof.rechunk(), || right_asof.rechunk());
let left_val_arr = left_asof.downcast_iter().next().unwrap();
let right_val_arr = right_asof.downcast_iter().next().unwrap();
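
The two rechunks above now run through `POOL.join` instead of sequentially. A minimal self-contained sketch of the underlying rayon join pattern (plain `rayon::join`; the `ThreadPool::join` method used here mirrors it):

```rust
fn main() {
    // Both closures may run in parallel on rayon's thread pool;
    // join returns once both results are ready.
    let (sum, max) = rayon::join(
        || (0..1_000_000u64).sum::<u64>(),
        || (0..1_000_000u64).max(),
    );
    println!("sum = {sum}, max = {max:?}");
}
```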

let n_threads = POOL.current_num_threads();
let split_by_left = split_ca(by_left, n_threads).unwrap();
let split_by_right = split_ca(by_right, n_threads).unwrap();
// Use `split_and_flatten` so that we always flatten, even if there are more chunks than threads.
let split_by_left = split_and_flatten(by_left, n_threads);
let split_by_right = split_and_flatten(by_right, n_threads);
let offsets = compute_len_offsets(split_by_left.iter().map(|s| s.len()));
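
`compute_len_offsets` is defined elsewhere in this file; a hypothetical stand-in showing the assumed contract (per-split lengths mapped to each split's global start offset):

```rust
// Hypothetical sketch, not the actual helper: [3, 4, 5] -> [0, 3, 7].
fn compute_len_offsets_sketch(lens: impl IntoIterator<Item = usize>) -> Vec<usize> {
    let mut cumlen = 0;
    lens.into_iter()
        .map(|len| {
            let offset = cumlen;
            cumlen += len;
            offset
        })
        .collect()
}
```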

// TODO: handle nulls more efficiently. Right now we just join on the value
// ignoring the validity mask, and ignore the nulls later.
let right_slices = split_by_right
.iter()
.map(|ca| ca.downcast_iter().next().unwrap().values_iter().copied())
.map(|ca| {
assert_eq!(ca.chunks().len(), 1);
ca.downcast_iter().next().unwrap().values_iter().copied()
})
.collect();
let hash_tbls = build_tables(right_slices, false);
let n_tables = hash_tbls.len();
Expand All @@ -197,6 +200,7 @@ where
let mut group_states: PlHashMap<IdxSize, A> =
PlHashMap::with_capacity(_HASHMAP_INIT_SIZE);

assert_eq!(by_left.chunks().len(), 1);
let by_left_chunk = by_left.downcast_iter().next().unwrap();
for (rel_idx_left, opt_by_left_k) in by_left_chunk.iter().enumerate() {
let Some(by_left_k) = opt_by_left_k else {
@@ -245,14 +249,13 @@ where
A: for<'a> AsofJoinState<T::Physical<'a>>,
F: Sync + for<'a> Fn(T::Physical<'a>, T::Physical<'a>) -> bool,
{
let left_asof = left_asof.rechunk();
let right_asof = right_asof.rechunk();
let (left_asof, right_asof) = POOL.join(|| left_asof.rechunk(), || right_asof.rechunk());
let left_val_arr = left_asof.downcast_iter().next().unwrap();
let right_val_arr = right_asof.downcast_iter().next().unwrap();

let n_threads = POOL.current_num_threads();
let split_by_left = split_ca(by_left, n_threads).unwrap();
let split_by_right = split_ca(by_right, n_threads).unwrap();
let split_by_left = split_and_flatten(by_left, n_threads);
let split_by_right = split_and_flatten(by_right, n_threads);
let offsets = compute_len_offsets(split_by_left.iter().map(|s| s.len()));

let hb = RandomState::default();
@@ -311,14 +314,13 @@ where
A: for<'a> AsofJoinState<T::Physical<'a>>,
F: Sync + for<'a> Fn(T::Physical<'a>, T::Physical<'a>) -> bool,
{
let left_asof = left_asof.rechunk();
let right_asof = right_asof.rechunk();
let (left_asof, right_asof) = POOL.join(|| left_asof.rechunk(), || right_asof.rechunk());
let left_val_arr = left_asof.downcast_iter().next().unwrap();
let right_val_arr = right_asof.downcast_iter().next().unwrap();

let n_threads = POOL.current_num_threads();
let split_by_left = split_df(by_left, n_threads, false);
let split_by_right = split_df(by_right, n_threads, false);
let split_by_left = split_and_flatten(by_left, n_threads);
let split_by_right = split_and_flatten(by_right, n_threads);

let (build_hashes, random_state) =
_df_rows_to_hashes_threaded_vertical(&split_by_right, None).unwrap();
