From 8d2544fcae996caad9aca8485c440dc022e8098c Mon Sep 17 00:00:00 2001 From: ritchie Date: Sun, 8 Oct 2023 10:35:23 +0200 Subject: [PATCH] fix: expand all literals before group_by --- crates/polars-core/src/frame/group_by/mod.rs | 13 +++++++++---- py-polars/tests/unit/operations/test_group_by.py | 16 ++++++++++++++++ 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/crates/polars-core/src/frame/group_by/mod.rs b/crates/polars-core/src/frame/group_by/mod.rs index 08ca7f5e9f47..fc457c8bfdba 100644 --- a/crates/polars-core/src/frame/group_by/mod.rs +++ b/crates/polars-core/src/frame/group_by/mod.rs @@ -60,16 +60,21 @@ impl DataFrame { !by.is_empty(), ComputeError: "at least one key is required in a group_by operation" ); - let by_len = by[0].len(); + let minimal_by_len = by.iter().map(|s| s.len()).min().expect("at least 1 key"); + let df_height = self.height(); // we only throw this error if self.width > 0 // so that we can still call this on a dummy dataframe where we provide the keys - if (by_len != self.height()) && (self.width() > 0) { + if (minimal_by_len != df_height) && (self.width() > 0) { polars_ensure!( - by_len == 1, + minimal_by_len == 1, ShapeMismatch: "series used as keys should have the same length as the dataframe" ); - by[0] = by[0].new_from_index(0, self.height()) + for by_key in by.iter_mut() { + if by_key.len() == minimal_by_len { + *by_key = by_key.new_from_index(0, df_height) + } + } }; let n_partitions = _set_partition_size(); diff --git a/py-polars/tests/unit/operations/test_group_by.py b/py-polars/tests/unit/operations/test_group_by.py index 956e6cc5d994..aa6b10d8f0f4 100644 --- a/py-polars/tests/unit/operations/test_group_by.py +++ b/py-polars/tests/unit/operations/test_group_by.py @@ -898,3 +898,19 @@ def test_groupby_dynamic_deprecated() -> None: expected = df.group_by_dynamic("date", every="2d").agg(pl.sum("value")) assert_frame_equal(result, expected, check_row_order=False) assert_frame_equal(result_lazy, expected, check_row_order=False) + + +def test_group_by_multiple_keys_one_literal() -> None: + df = pl.DataFrame({"a": [1, 1, 2], "b": [4, 5, 6]}) + + expected = {"a": [1, 2], "literal": [1, 1], "b": [5, 6]} + for streaming in [True, False]: + assert ( + df.lazy() + .group_by("a", pl.lit(1)) + .agg(pl.col("b").max()) + .sort(["a", "b"]) + .collect(streaming=streaming) + .to_dict(False) + == expected + )