fix: Ensure strict chunking in chunked partitioned group by (#16561)
ritchie46 authored May 29, 2024
1 parent 4ecc8a6 commit 4dc17d9
Showing 10 changed files with 27 additions and 14 deletions.

crates/polars-core/src/frame/mod.rs (4 changes: 2 additions & 2 deletions)
@@ -1627,7 +1627,7 @@ impl DataFrame {
         let n_threads = POOL.current_num_threads();

         let masks = split_ca(mask, n_threads).unwrap();
-        let dfs = split_df(self, n_threads);
+        let dfs = split_df(self, n_threads, false);
         let dfs: PolarsResult<Vec<_>> = POOL.install(|| {
             masks
                 .par_iter()

@@ -2826,7 +2826,7 @@ impl DataFrame {
         &mut self,
         hasher_builder: Option<ahash::RandomState>,
     ) -> PolarsResult<UInt64Chunked> {
-        let dfs = split_df(self, POOL.current_num_threads());
+        let dfs = split_df(self, POOL.current_num_threads(), false);
         let (cas, _) = _df_rows_to_hashes_threaded_vertical(&dfs, hasher_builder)?;

         let mut iter = cas.into_iter();

crates/polars-core/src/utils/mod.rs (4 changes: 2 additions & 2 deletions)
@@ -204,13 +204,13 @@ pub fn split_df_as_ref(df: &DataFrame, target: usize, strict: bool) -> Vec<DataFrame> {
 #[doc(hidden)]
 /// Split a [`DataFrame`] into `n` parts. We take a `&mut` to be able to repartition/align chunks.
 /// `strict` in that it respects `n` even if the chunks are suboptimal.
-pub fn split_df(df: &mut DataFrame, target: usize) -> Vec<DataFrame> {
+pub fn split_df(df: &mut DataFrame, target: usize, strict: bool) -> Vec<DataFrame> {
     if target == 0 || df.is_empty() {
         return vec![df.clone()];
     }
     // make sure that chunks are aligned.
     df.align_chunks();
-    split_df_as_ref(df, target, false)
+    split_df_as_ref(df, target, strict)
 }

 pub fn slice_slice<T>(vals: &[T], offset: i64, len: usize) -> &[T] {

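The contract is easiest to see next to the relaxed behavior. Below is a minimal sketch of the two split semantics over plain vectors, illustrating the doc comment above with invented helpers (split_strict, split_relaxed) rather than the actual Polars implementation:

    // Sketch only: models the strict-vs-relaxed contract, not Polars internals.

    /// Strict: exactly `target` parts with near-equal lengths, cut at row
    /// offsets that depend only on `vals.len()` and `target`.
    fn split_strict<T: Clone>(vals: &[T], target: usize) -> Vec<Vec<T>> {
        assert!(target > 0);
        let base = vals.len() / target;
        let rem = vals.len() % target;
        let mut out = Vec::with_capacity(target);
        let mut offset = 0;
        for i in 0..target {
            // Spread the remainder over the first `rem` parts.
            let len = base + usize::from(i < rem);
            out.push(vals[offset..offset + len].to_vec());
            offset += len;
        }
        out
    }

    /// Relaxed: may keep pre-existing chunk boundaries when they are already
    /// fine-grained enough, so the cut points depend on the incoming layout.
    fn split_relaxed<T: Clone>(chunks: &[Vec<T>], target: usize) -> Vec<Vec<T>> {
        if chunks.len() >= target {
            chunks.to_vec()
        } else {
            let flat: Vec<T> = chunks.iter().flatten().cloned().collect();
            split_strict(&flat, target)
        }
    }

    fn main() {
        // Two unevenly sized chunks, e.g. after a concat without rechunking.
        let chunks = vec![vec![1, 2, 3, 4, 5], vec![6, 7]];
        let flat: Vec<i32> = chunks.iter().flatten().cloned().collect();
        println!("{:?}", split_strict(&flat, 2)); // [[1, 2, 3, 4], [5, 6, 7]]
        println!("{:?}", split_relaxed(&chunks, 2)); // [[1, 2, 3, 4, 5], [6, 7]]
    }

Both are valid splits on their own; they only become a problem when two relaxed splits are expected to line up row-for-row, which is exactly the situation the partitioned group by below is in.
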
crates/polars-io/src/parquet/read/read_impl.rs (2 changes: 1 addition & 1 deletion)
@@ -772,7 +772,7 @@ impl BatchedParquetReader {
             // make sure that the chunks are not too large
             let n = df.height() / self.chunk_size;
             if n > 1 {
-                for df in split_df(&mut df, n) {
+                for df in split_df(&mut df, n, false) {
                     self.chunks_fifo.push_back(df)
                 }
             } else {

@@ -71,13 +71,13 @@ fn run_partitions(
     // We do a partitioned group_by.
     // Meaning that we first do the group_by operation arbitrarily
     // split on several threads. Than the final result we apply the same group_by again.
-    let dfs = split_df(df, n_threads);
+    let dfs = split_df(df, n_threads, true);

     let phys_aggs = &exec.phys_aggs;
     let keys = &exec.phys_keys;

     let mut keys = DataFrame::from_iter(compute_keys(keys, df, state)?);
-    let splitted_keys = split_df(&mut keys, n_threads);
+    let splitted_keys = split_df(&mut keys, n_threads, true);

     POOL.install(|| {
         dfs.into_par_iter()

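These two calls are the heart of the fix: `dfs` and `splitted_keys` come from independent `split_df` calls and are consumed pairwise in the parallel section, so both splits must place their cut points at identical row offsets. With `strict` set to true, both calls respect `n_threads` even when the data frame and the keys frame carry different chunk layouts. For intuition, here is a minimal sketch of the two-phase aggregation the comment describes, using a hypothetical `partitioned_sum` over plain slices rather than the Polars executor:

    use std::collections::HashMap;

    // Sketch of a two-phase partitioned group-by sum (hypothetical helper).
    fn partitioned_sum(keys: &[u32], vals: &[i64], n_parts: usize) -> HashMap<u32, i64> {
        assert_eq!(keys.len(), vals.len());
        let part_len = keys.len().div_ceil(n_parts);

        // Phase 1: aggregate each partition independently. Keys and values
        // are cut at the SAME offsets, the invariant a strict split enforces.
        let partials: Vec<HashMap<u32, i64>> = keys
            .chunks(part_len)
            .zip(vals.chunks(part_len))
            .map(|(ks, vs)| {
                let mut acc = HashMap::new();
                for (k, v) in ks.iter().zip(vs) {
                    *acc.entry(*k).or_insert(0) += v;
                }
                acc
            })
            .collect();

        // Phase 2: "apply the same group_by again", merging the partials.
        let mut out = HashMap::new();
        for part in partials {
            for (k, v) in part {
                *out.entry(k).or_insert(0) += v;
            }
        }
        out
    }

    fn main() {
        let keys = [0, 0, 1, 1, 0, 1];
        let vals = [1, 2, 3, 4, 5, 6];
        // {0: 8, 1: 13}, independent of the number of partitions.
        println!("{:?}", partitioned_sum(&keys, &vals, 3));
    }

The merge in phase 2 is only correct because the partial aggregates can themselves be re-aggregated; sums, counts, minima and maxima all compose this way.
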
crates/polars-ops/src/frame/join/asof/groups.rs (4 changes: 2 additions & 2 deletions)
@@ -317,8 +317,8 @@
     let right_val_arr = right_asof.downcast_iter().next().unwrap();

     let n_threads = POOL.current_num_threads();
-    let split_by_left = split_df(by_left, n_threads);
-    let split_by_right = split_df(by_right, n_threads);
+    let split_by_left = split_df(by_left, n_threads, false);
+    let split_by_right = split_df(by_right, n_threads, false);

     let (build_hashes, random_state) =
         _df_rows_to_hashes_threaded_vertical(&split_by_right, None).unwrap();

crates/polars-pipe/src/executors/sinks/group_by/ooc.rs (2 changes: 1 addition & 1 deletion)
@@ -99,7 +99,7 @@ impl Source for GroupBySource {
             }
         }

-        let dfs = split_df(&mut df, self.morsels_per_sink);
+        let dfs = split_df(&mut df, self.morsels_per_sink, false);
         let chunks = dfs
             .into_iter()
             .map(|data| {

crates/polars-pipe/src/executors/sinks/sort/source.rs (2 changes: 1 addition & 1 deletion)
@@ -133,7 +133,7 @@ impl SortSource {
         }?;

         // convert to chunks
-        let dfs = split_df(&mut df, self.n_threads);
+        let dfs = split_df(&mut df, self.n_threads, true);
         Ok(SourceResult::GotMoreData(self.finish_batch(dfs)))
     }
     fn print_verbose(&self, verbose: bool) {

crates/polars-pipe/src/executors/sources/frame.rs (2 changes: 1 addition & 1 deletion)
@@ -18,7 +18,7 @@ pub struct DataFrameSource {
 impl DataFrameSource {
     pub(crate) fn from_df(mut df: DataFrame) -> Self {
         let n_threads = POOL.current_num_threads();
-        let dfs = split_df(&mut df, n_threads);
+        let dfs = split_df(&mut df, n_threads, false);
         let dfs = dfs.into_iter().enumerate();
         Self { dfs, n_threads }
     }

crates/polars/tests/it/core/joins.rs (5 changes: 3 additions & 2 deletions)
@@ -16,8 +16,9 @@ fn test_chunked_left_join() -> PolarsResult<()> {
         "plays" => ["guitar", "bass", "guitar"]
     ]?;

-    let band_instruments = accumulate_dataframes_vertical(split_df(&mut band_instruments, 2))?;
-    let band_members = accumulate_dataframes_vertical(split_df(&mut band_members, 2))?;
+    let band_instruments =
+        accumulate_dataframes_vertical(split_df(&mut band_instruments, 2, false))?;
+    let band_members = accumulate_dataframes_vertical(split_df(&mut band_members, 2, false))?;
     assert_eq!(band_instruments.n_chunks(), 2);
     assert_eq!(band_members.n_chunks(), 2);

py-polars/tests/unit/operations/test_group_by.py (12 changes: 12 additions & 0 deletions)
@@ -1059,3 +1059,15 @@ def test_boolean_min_max_agg() -> None:
         schema=schema,
     )
     assert_frame_equal(result, expected)
+
+
+def test_partitioned_group_by_chunked(partition_limit: int) -> None:
+    n = partition_limit
+    df1 = pl.DataFrame(np.random.randn(n, 2))
+    df2 = pl.DataFrame(np.random.randn(n, 2))
+    gps = pl.Series(name="oo", values=[0] * n + [1] * n)
+    df = pl.concat([df1, df2], rechunk=False)
+    assert_frame_equal(
+        df.group_by(gps).sum().sort("oo"),
+        df.rechunk().group_by(gps, maintain_order=True).sum(),
+    )
