Skip to content

Commit

Permalink
fix: Preserve Series name in __rpow__ operation (#20072)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Nov 30, 2024
1 parent ca8c1ef commit 954000c
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
6 changes: 5 additions & 1 deletion py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1194,7 +1194,11 @@ def __pow__(self, exponent: int | float | Series) -> Series:
return self.pow(exponent)

def __rpow__(self, other: Any) -> Series:
return self.to_frame().select_seq(other ** F.col(self.name)).to_series()
return (
self.to_frame()
.select_seq((other ** F.col(self.name)).alias(self.name))
.to_series()
)

def __matmul__(self, other: Any) -> float | Series | None:
if isinstance(other, Sequence) or (
Expand Down
10 changes: 8 additions & 2 deletions py-polars/tests/unit/operations/arithmetic/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,8 +549,8 @@ def test_power_series() -> None:
assert_series_equal(a**j, pl.Series([1, 4], dtype=Int64))

# rpow
assert_series_equal(2.0**a, pl.Series("literal", [2.0, 4.0], dtype=Float64))
assert_series_equal(2**b, pl.Series("literal", [None, 4.0], dtype=Float64))
assert_series_equal(2.0**a, pl.Series(None, [2.0, 4.0], dtype=Float64))
assert_series_equal(2**b, pl.Series(None, [None, 4.0], dtype=Float64))

with pytest.raises(ColumnNotFoundError):
"hi" ** a
Expand All @@ -559,6 +559,12 @@ def test_power_series() -> None:
assert_series_equal(a.pow(2), pl.Series([1, 4], dtype=Int64))


def test_rpow_name_20071() -> None:
result = 1 ** pl.Series("a", [1, 2])
expected = pl.Series("a", [1, 1], pl.Int32)
assert_series_equal(result, expected)


@pytest.mark.parametrize(
("expected", "expr", "column_names"),
[
Expand Down

0 comments on commit 954000c

Please sign in to comment.