diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py
index 49c2e2470534..354d1dedecb3 100644
--- a/py-polars/polars/dataframe/frame.py
+++ b/py-polars/polars/dataframe/frame.py
@@ -13,7 +13,6 @@
     Sized,
 )
 from io import BytesIO, StringIO
-from operator import itemgetter
 from pathlib import Path
 from typing import (
     IO,
@@ -10436,69 +10435,30 @@ def rows_by_key(
                           {'w': 'b', 'x': 'q', 'y': 3.0, 'z': 7}],
              ('a', 'k'): [{'w': 'a', 'x': 'k', 'y': 4.5, 'z': 6}]})
         """
-        from polars.selectors import expand_selector, is_selector
+        key = _expand_selectors(self, key)
 
-        if is_selector(key):
-            key_tuple = expand_selector(target=self, selector=key)
-        elif not isinstance(key, str):
-            key_tuple = tuple(key)  # type: ignore[arg-type]
-        else:
-            key_tuple = (key,)
+        keys = (
+            iter(self.get_column(key[0]))
+            if len(key) == 1
+            else self.select(key).iter_rows()
+        )
 
-        # establish index or name-based getters for the key and data values
-        data_cols = [k for k in self.schema if k not in key_tuple]
-        if named:
-            get_data = itemgetter(*data_cols)
-            get_key = itemgetter(*key_tuple)
+        if include_key:
+            values = self
         else:
-            data_idxs, index_idxs = [], []
-            for idx, c in enumerate(self.columns):
-                if c in key_tuple:
-                    index_idxs.append(idx)
-                else:
-                    data_idxs.append(idx)
-            if not index_idxs:
-                msg = f"no columns found for key: {key_tuple!r}"
-                raise ValueError(msg)
-            get_data = itemgetter(*data_idxs)  # type: ignore[arg-type]
-            get_key = itemgetter(*index_idxs)  # type: ignore[arg-type]
+            data_cols = [k for k in self.schema if k not in key]
+            values = self.select(data_cols)
+
+        zipped = zip(keys, values.iter_rows(named=named))  # type: ignore[call-overload]
 
         # if unique, we expect to write just one entry per key; otherwise, we're
         # returning a list of rows for each key, so append into a defaultdict.
-        rows: dict[Any, Any] = {} if unique else defaultdict(list)
-
-        # return named values (key -> dict | list of dicts), eg:
-        #   "{(key,): [{col:val, col:val, ...}],
-        #     (key,): [{col:val, col:val, ...}],}"
-        if named:
-            if unique and include_key:
-                rows = {get_key(row): row for row in self.iter_rows(named=True)}
-            else:
-                for d in self.iter_rows(named=True):
-                    k = get_key(d)
-                    if not include_key:
-                        for ix in key_tuple:
-                            del d[ix]  # type: ignore[arg-type]
-                    if unique:
-                        rows[k] = d
-                    else:
-                        rows[k].append(d)
-
-        # return values (key -> tuple | list of tuples), eg:
-        #   "{(key,): [(val, val, ...)],
-        #     (key,): [(val, val, ...)], ...}"
-        elif unique:
-            rows = (
-                {get_key(row): row for row in self.iter_rows()}
-                if include_key
-                else {get_key(row): get_data(row) for row in self.iter_rows()}
-            )
-        elif include_key:
-            for row in self.iter_rows(named=False):
-                rows[get_key(row)].append(row)
+        if unique:
+            rows = dict(zipped)
         else:
-            for row in self.iter_rows(named=False):
-                rows[get_key(row)].append(get_data(row))
+            rows = defaultdict(list)
+            for key, data in zipped:
+                rows[key].append(data)
 
         return rows
 
diff --git a/py-polars/tests/unit/test_rows.py b/py-polars/tests/unit/test_rows.py
index 91d001a58c96..e3191ad877da 100644
--- a/py-polars/tests/unit/test_rows.py
+++ b/py-polars/tests/unit/test_rows.py
@@ -102,9 +102,9 @@ def test_rows_by_key() -> None:
         "b": ("b", "q", 3.0, 7),
     }
     assert df.rows_by_key(["x", "w"]) == {
-        ("a", "q"): [(1.0, 9)],
-        ("b", "q"): [(2.5, 8), (3.0, 7)],
-        ("a", "k"): [(4.5, 6)],
+        ("q", "a"): [(1.0, 9)],
+        ("q", "b"): [(2.5, 8), (3.0, 7)],
+        ("k", "a"): [(4.5, 6)],
     }
     assert df.rows_by_key(["w", "x"], include_key=True) == {
         ("a", "q"): [("a", "q", 1.0, 9)],
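
For reference, a minimal usage sketch (illustrative only, not part of the patch) of the behavioral change the updated test asserts: after this refactor, key tuples follow the order of the columns passed in `key` rather than the frame's column order, so `["x", "w"]` yields `("q", "a")`-style keys. The frame below mirrors the data used in `test_rows_by_key` and the `rows_by_key` docstring.

```python
import polars as pl

df = pl.DataFrame(
    {
        "w": ["a", "b", "b", "a"],
        "x": ["q", "q", "q", "k"],
        "y": [1.0, 2.5, 3.0, 4.5],
        "z": [9, 8, 7, 6],
    }
)

# Key tuples are built in the caller-supplied column order ("x" first, then "w"),
# and the non-key columns are grouped into lists of row tuples per key.
assert df.rows_by_key(["x", "w"]) == {
    ("q", "a"): [(1.0, 9)],
    ("q", "b"): [(2.5, 8), (3.0, 7)],
    ("k", "a"): [(4.5, 6)],
}
```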