Skip to content

Commit

Permalink
misc: make when the chaining stable
Browse files Browse the repository at this point in the history
  • Loading branch information
aivanoved committed Jul 30, 2024
1 parent c0d2f7a commit 10245c4
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 14 deletions.
72 changes: 66 additions & 6 deletions narwhals/stable/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
from narwhals.dtypes import UInt32
from narwhals.dtypes import UInt64
from narwhals.dtypes import Unknown
from narwhals.expr import ChainedThen as NwChainedThen
from narwhals.expr import ChainedWhen as NwChainedWhen
from narwhals.expr import Expr as NwExpr
from narwhals.expr import Then as NwThen
from narwhals.expr import When as NwWhen
Expand Down Expand Up @@ -479,12 +481,38 @@ def _stableify(obj: NwSeries) -> Series: ...
@overload
def _stableify(obj: NwExpr) -> Expr: ...
@overload
def _stableify(when_then: NwWhen) -> When: ...
@overload
def _stableify(when_then: NwChainedWhen) -> ChainedWhen: ...
@overload
def _stableify(when_then: NwThen) -> Then: ...
@overload
def _stableify(when_then: NwChainedThen) -> ChainedThen: ...
@overload
def _stableify(obj: Any) -> Any: ...


def _stableify(
obj: NwDataFrame[IntoFrameT] | NwLazyFrame[IntoFrameT] | NwSeries | NwExpr | Any,
) -> DataFrame[IntoFrameT] | LazyFrame[IntoFrameT] | Series | Expr | Any:
obj: NwDataFrame[IntoFrameT]
| NwLazyFrame[IntoFrameT]
| NwSeries
| NwExpr
| NwWhen
| NwChainedWhen
| NwThen
| NwChainedThen
| Any,
) -> (
DataFrame[IntoFrameT]
| LazyFrame[IntoFrameT]
| Series
| Expr
| When
| ChainedWhen
| Then
| ChainedThen
| Any
):
if isinstance(obj, NwDataFrame):
return DataFrame(
obj._compliant_frame,
Expand All @@ -500,6 +528,14 @@ def _stableify(
obj._compliant_series,
level=obj._level,
)
elif isinstance(obj, NwChainedWhen):
return ChainedWhen.from_base(obj)
if isinstance(obj, NwWhen):
return When.from_base(obj)
elif isinstance(obj, NwChainedThen):
return ChainedThen.from_base(obj)
elif isinstance(obj, NwThen):
return Then.from_base(obj)
if isinstance(obj, NwExpr):
return Expr(obj._call)
return obj
Expand Down Expand Up @@ -1474,18 +1510,42 @@ def get_level(

class When(NwWhen):
@classmethod
def from_when(cls, when: NwWhen) -> Self:
def from_base(cls, when: NwWhen) -> Self:
return cls(*when._predicates)

def then(self, value: Any) -> Then:
return Then.from_then(super().then(value))
return _stableify(super().then(value))


class Then(NwThen, Expr):
@classmethod
def from_then(cls, then: NwThen) -> Self:
def from_base(cls, then: NwThen) -> Self:
return cls(then._call)

def when(self, *predicates: IntoExpr | Iterable[IntoExpr]) -> ChainedWhen:
return _stableify(super().when(*predicates))

def otherwise(self, value: Any) -> Expr:
return _stableify(super().otherwise(value))


class ChainedWhen(NwChainedWhen):
@classmethod
def from_base(cls, chained_when: NwChainedWhen) -> Self:
return cls(_stableify(chained_when._above_then), *chained_when._predicates)

def then(self, value: Any) -> ChainedThen:
return _stableify(super().then(value))


class ChainedThen(NwChainedThen, Expr):
@classmethod
def from_base(cls, chained_then: NwChainedThen) -> Self:
return cls(chained_then._call)

def when(self, *predicates: IntoExpr | Iterable[IntoExpr]) -> ChainedWhen:
return _stableify(super().when(*predicates))

def otherwise(self, value: Any) -> Expr:
return _stableify(super().otherwise(value))

Expand Down Expand Up @@ -1535,7 +1595,7 @@ def when(*predicates: IntoExpr | Iterable[IntoExpr]) -> When:
│ 3 ┆ 15 ┆ 6 │
└─────┴─────┴────────┘
"""
return When.from_when(nw_when(*predicates))
return _stableify(nw_when(*predicates))


def from_dict(
Expand Down
17 changes: 9 additions & 8 deletions tests/expr_and_series/when_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@

import pytest

import narwhals as nw
from narwhals import when
import narwhals.stable.v1 as nw
from tests.utils import compare_dicts

data = {
Expand All @@ -21,7 +20,7 @@ def test_when(request: Any, constructor: Any) -> None:
request.applymarker(pytest.mark.xfail)

df = nw.from_native(constructor(data))
result = df.with_columns(when(nw.col("a") == 1).then(value=3).alias("a_when"))
result = df.with_columns(nw.when(nw.col("a") == 1).then(value=3).alias("a_when"))
expected = {
"a": [1, 2, 3, 4, 5],
"b": ["a", "b", "c", "d", "e"],
Expand All @@ -37,7 +36,9 @@ def test_when_otherwise(request: Any, constructor: Any) -> None:
request.applymarker(pytest.mark.xfail)

df = nw.from_native(constructor(data))
result = df.with_columns(when(nw.col("a") == 1).then(3).otherwise(6).alias("a_when"))
result = df.with_columns(
nw.when(nw.col("a") == 1).then(3).otherwise(6).alias("a_when")
)
expected = {
"a": [1, 2, 3, 4, 5],
"b": ["a", "b", "c", "d", "e"],
Expand All @@ -54,7 +55,7 @@ def test_chained_when(request: Any, constructor: Any) -> None:

df = nw.from_native(constructor(data))
result = df.with_columns(
when(nw.col("a") == 1)
nw.when(nw.col("a") == 1)
.then(3)
.when(nw.col("a") == 2)
.then(5)
Expand All @@ -76,7 +77,7 @@ def test_when_with_multiple_conditions(request: Any, constructor: Any) -> None:
request.applymarker(pytest.mark.xfail)
df = nw.from_native(constructor(data))
result = df.with_columns(
when(nw.col("a") == 1)
nw.when(nw.col("a") == 1)
.then(3)
.when(nw.col("a") == 2)
.then(5)
Expand All @@ -101,7 +102,7 @@ def test_multiple_conditions(request: Any, constructor: Any) -> None:

df = nw.from_native(constructor(data))
result = df.with_columns(
when(nw.col("a") < 3, nw.col("c") < 5.0).then(3).alias("a_when")
nw.when(nw.col("a") < 3, nw.col("c") < 5.0).then(3).alias("a_when")
)
expected = {
"a": [1, 2, 3, 4, 5],
Expand All @@ -119,4 +120,4 @@ def test_no_arg_when_fail(request: Any, constructor: Any) -> None:

df = nw.from_native(constructor(data))
with pytest.raises(TypeError):
df.with_columns(when().then(value=3).alias("a_when"))
df.with_columns(nw.when().then(value=3).alias("a_when"))

0 comments on commit 10245c4

Please sign in to comment.