Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(python): DataFrame rows_by_key returning key tuples with elements in wrong order #19486

Merged
merged 7 commits into from
Nov 18, 2024
74 changes: 17 additions & 57 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
Sized,
)
from io import BytesIO, StringIO
from operator import itemgetter
from pathlib import Path
from typing import (
IO,
Expand Down Expand Up @@ -10436,69 +10435,30 @@ def rows_by_key(
{'w': 'b', 'x': 'q', 'y': 3.0, 'z': 7}],
('a', 'k'): [{'w': 'a', 'x': 'k', 'y': 4.5, 'z': 6}]})
"""
from polars.selectors import expand_selector, is_selector
key = _expand_selectors(self, key)

if is_selector(key):
key_tuple = expand_selector(target=self, selector=key)
elif not isinstance(key, str):
key_tuple = tuple(key) # type: ignore[arg-type]
else:
key_tuple = (key,)
keys = (
iter(self.get_column(key[0]))
if len(key) == 1
else self.select(key).iter_rows()
)

# establish index or name-based getters for the key and data values
data_cols = [k for k in self.schema if k not in key_tuple]
if named:
get_data = itemgetter(*data_cols)
get_key = itemgetter(*key_tuple)
if include_key:
values = self
else:
data_idxs, index_idxs = [], []
for idx, c in enumerate(self.columns):
if c in key_tuple:
index_idxs.append(idx)
else:
data_idxs.append(idx)
if not index_idxs:
msg = f"no columns found for key: {key_tuple!r}"
raise ValueError(msg)
get_data = itemgetter(*data_idxs) # type: ignore[arg-type]
get_key = itemgetter(*index_idxs) # type: ignore[arg-type]
data_cols = [k for k in self.schema if k not in key]
values = self.select(data_cols)

zipped = zip(keys, values.iter_rows(named=named)) # type: ignore[call-overload]

# if unique, we expect to write just one entry per key; otherwise, we're
# returning a list of rows for each key, so append into a defaultdict.
rows: dict[Any, Any] = {} if unique else defaultdict(list)

# return named values (key -> dict | list of dicts), eg:
# "{(key,): [{col:val, col:val, ...}],
# (key,): [{col:val, col:val, ...}],}"
if named:
if unique and include_key:
rows = {get_key(row): row for row in self.iter_rows(named=True)}
else:
for d in self.iter_rows(named=True):
k = get_key(d)
if not include_key:
for ix in key_tuple:
del d[ix] # type: ignore[arg-type]
if unique:
rows[k] = d
else:
rows[k].append(d)

# return values (key -> tuple | list of tuples), eg:
# "{(key,): [(val, val, ...)],
# (key,): [(val, val, ...)], ...}"
elif unique:
rows = (
{get_key(row): row for row in self.iter_rows()}
if include_key
else {get_key(row): get_data(row) for row in self.iter_rows()}
)
elif include_key:
for row in self.iter_rows(named=False):
rows[get_key(row)].append(row)
if unique:
rows = dict(zipped)
else:
for row in self.iter_rows(named=False):
rows[get_key(row)].append(get_data(row))
rows = defaultdict(list)
for key, data in zipped:
rows[key].append(data)

return rows

Expand Down
6 changes: 3 additions & 3 deletions py-polars/tests/unit/test_rows.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,9 @@ def test_rows_by_key() -> None:
"b": ("b", "q", 3.0, 7),
}
assert df.rows_by_key(["x", "w"]) == {
("a", "q"): [(1.0, 9)],
("b", "q"): [(2.5, 8), (3.0, 7)],
("a", "k"): [(4.5, 6)],
("q", "a"): [(1.0, 9)],
("q", "b"): [(2.5, 8), (3.0, 7)],
("k", "a"): [(4.5, 6)],
}
assert df.rows_by_key(["w", "x"], include_key=True) == {
("a", "q"): [("a", "q", 1.0, 9)],
Expand Down