From 4dc17d9f32fa2f597b6d907caeac4679285ebd42 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Wed, 29 May 2024 09:18:47 +0200 Subject: [PATCH] fix: Ensure strict chunking in chunked partitioned group by (#16561) --- crates/polars-core/src/frame/mod.rs | 4 ++-- crates/polars-core/src/utils/mod.rs | 4 ++-- crates/polars-io/src/parquet/read/read_impl.rs | 2 +- .../physical_plan/executors/group_by_partitioned.rs | 4 ++-- crates/polars-ops/src/frame/join/asof/groups.rs | 4 ++-- .../polars-pipe/src/executors/sinks/group_by/ooc.rs | 2 +- .../polars-pipe/src/executors/sinks/sort/source.rs | 2 +- crates/polars-pipe/src/executors/sources/frame.rs | 2 +- crates/polars/tests/it/core/joins.rs | 5 +++-- py-polars/tests/unit/operations/test_group_by.py | 12 ++++++++++++ 10 files changed, 27 insertions(+), 14 deletions(-) diff --git a/crates/polars-core/src/frame/mod.rs b/crates/polars-core/src/frame/mod.rs index ce38e941d8d9..2653eec0939e 100644 --- a/crates/polars-core/src/frame/mod.rs +++ b/crates/polars-core/src/frame/mod.rs @@ -1627,7 +1627,7 @@ impl DataFrame { let n_threads = POOL.current_num_threads(); let masks = split_ca(mask, n_threads).unwrap(); - let dfs = split_df(self, n_threads); + let dfs = split_df(self, n_threads, false); let dfs: PolarsResult> = POOL.install(|| { masks .par_iter() @@ -2826,7 +2826,7 @@ impl DataFrame { &mut self, hasher_builder: Option, ) -> PolarsResult { - let dfs = split_df(self, POOL.current_num_threads()); + let dfs = split_df(self, POOL.current_num_threads(), false); let (cas, _) = _df_rows_to_hashes_threaded_vertical(&dfs, hasher_builder)?; let mut iter = cas.into_iter(); diff --git a/crates/polars-core/src/utils/mod.rs b/crates/polars-core/src/utils/mod.rs index 12da72a1b290..2ff86df29503 100644 --- a/crates/polars-core/src/utils/mod.rs +++ b/crates/polars-core/src/utils/mod.rs @@ -204,13 +204,13 @@ pub fn split_df_as_ref(df: &DataFrame, target: usize, strict: bool) -> Vec Vec { +pub fn split_df(df: &mut DataFrame, target: usize, strict: bool) -> Vec { if target == 0 || df.is_empty() { return vec![df.clone()]; } // make sure that chunks are aligned. df.align_chunks(); - split_df_as_ref(df, target, false) + split_df_as_ref(df, target, strict) } pub fn slice_slice(vals: &[T], offset: i64, len: usize) -> &[T] { diff --git a/crates/polars-io/src/parquet/read/read_impl.rs b/crates/polars-io/src/parquet/read/read_impl.rs index 5647726ca16c..0226e0679ba2 100644 --- a/crates/polars-io/src/parquet/read/read_impl.rs +++ b/crates/polars-io/src/parquet/read/read_impl.rs @@ -772,7 +772,7 @@ impl BatchedParquetReader { // make sure that the chunks are not too large let n = df.height() / self.chunk_size; if n > 1 { - for df in split_df(&mut df, n) { + for df in split_df(&mut df, n, false) { self.chunks_fifo.push_back(df) } } else { diff --git a/crates/polars-lazy/src/physical_plan/executors/group_by_partitioned.rs b/crates/polars-lazy/src/physical_plan/executors/group_by_partitioned.rs index 94cb433afed5..3867012d3f0c 100644 --- a/crates/polars-lazy/src/physical_plan/executors/group_by_partitioned.rs +++ b/crates/polars-lazy/src/physical_plan/executors/group_by_partitioned.rs @@ -71,13 +71,13 @@ fn run_partitions( // We do a partitioned group_by. // Meaning that we first do the group_by operation arbitrarily // split on several threads. Than the final result we apply the same group_by again. - let dfs = split_df(df, n_threads); + let dfs = split_df(df, n_threads, true); let phys_aggs = &exec.phys_aggs; let keys = &exec.phys_keys; let mut keys = DataFrame::from_iter(compute_keys(keys, df, state)?); - let splitted_keys = split_df(&mut keys, n_threads); + let splitted_keys = split_df(&mut keys, n_threads, true); POOL.install(|| { dfs.into_par_iter() diff --git a/crates/polars-ops/src/frame/join/asof/groups.rs b/crates/polars-ops/src/frame/join/asof/groups.rs index 5222b36f745c..4dc0b7010a34 100644 --- a/crates/polars-ops/src/frame/join/asof/groups.rs +++ b/crates/polars-ops/src/frame/join/asof/groups.rs @@ -317,8 +317,8 @@ where let right_val_arr = right_asof.downcast_iter().next().unwrap(); let n_threads = POOL.current_num_threads(); - let split_by_left = split_df(by_left, n_threads); - let split_by_right = split_df(by_right, n_threads); + let split_by_left = split_df(by_left, n_threads, false); + let split_by_right = split_df(by_right, n_threads, false); let (build_hashes, random_state) = _df_rows_to_hashes_threaded_vertical(&split_by_right, None).unwrap(); diff --git a/crates/polars-pipe/src/executors/sinks/group_by/ooc.rs b/crates/polars-pipe/src/executors/sinks/group_by/ooc.rs index b83767694ffd..f80979334e5e 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/ooc.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/ooc.rs @@ -99,7 +99,7 @@ impl Source for GroupBySource { } } - let dfs = split_df(&mut df, self.morsels_per_sink); + let dfs = split_df(&mut df, self.morsels_per_sink, false); let chunks = dfs .into_iter() .map(|data| { diff --git a/crates/polars-pipe/src/executors/sinks/sort/source.rs b/crates/polars-pipe/src/executors/sinks/sort/source.rs index 47a37cd2ea8c..1c1fa2984a0e 100644 --- a/crates/polars-pipe/src/executors/sinks/sort/source.rs +++ b/crates/polars-pipe/src/executors/sinks/sort/source.rs @@ -133,7 +133,7 @@ impl SortSource { }?; // convert to chunks - let dfs = split_df(&mut df, self.n_threads); + let dfs = split_df(&mut df, self.n_threads, true); Ok(SourceResult::GotMoreData(self.finish_batch(dfs))) } fn print_verbose(&self, verbose: bool) { diff --git a/crates/polars-pipe/src/executors/sources/frame.rs b/crates/polars-pipe/src/executors/sources/frame.rs index d4346a7d7a7a..3d3d6d2298f1 100644 --- a/crates/polars-pipe/src/executors/sources/frame.rs +++ b/crates/polars-pipe/src/executors/sources/frame.rs @@ -18,7 +18,7 @@ pub struct DataFrameSource { impl DataFrameSource { pub(crate) fn from_df(mut df: DataFrame) -> Self { let n_threads = POOL.current_num_threads(); - let dfs = split_df(&mut df, n_threads); + let dfs = split_df(&mut df, n_threads, false); let dfs = dfs.into_iter().enumerate(); Self { dfs, n_threads } } diff --git a/crates/polars/tests/it/core/joins.rs b/crates/polars/tests/it/core/joins.rs index ac2acb8d91db..47baf1388ecd 100644 --- a/crates/polars/tests/it/core/joins.rs +++ b/crates/polars/tests/it/core/joins.rs @@ -16,8 +16,9 @@ fn test_chunked_left_join() -> PolarsResult<()> { "plays" => ["guitar", "bass", "guitar"] ]?; - let band_instruments = accumulate_dataframes_vertical(split_df(&mut band_instruments, 2))?; - let band_members = accumulate_dataframes_vertical(split_df(&mut band_members, 2))?; + let band_instruments = + accumulate_dataframes_vertical(split_df(&mut band_instruments, 2, false))?; + let band_members = accumulate_dataframes_vertical(split_df(&mut band_members, 2, false))?; assert_eq!(band_instruments.n_chunks(), 2); assert_eq!(band_members.n_chunks(), 2); diff --git a/py-polars/tests/unit/operations/test_group_by.py b/py-polars/tests/unit/operations/test_group_by.py index dc758b1709fe..6d61b2e7db5d 100644 --- a/py-polars/tests/unit/operations/test_group_by.py +++ b/py-polars/tests/unit/operations/test_group_by.py @@ -1059,3 +1059,15 @@ def test_boolean_min_max_agg() -> None: schema=schema, ) assert_frame_equal(result, expected) + + +def test_partitioned_group_by_chunked(partition_limit: int) -> None: + n = partition_limit + df1 = pl.DataFrame(np.random.randn(n, 2)) + df2 = pl.DataFrame(np.random.randn(n, 2)) + gps = pl.Series(name="oo", values=[0] * n + [1] * n) + df = pl.concat([df1, df2], rechunk=False) + assert_frame_equal( + df.group_by(gps).sum().sort("oo"), + df.rechunk().group_by(gps, maintain_order=True).sum(), + )