Add chained `when(...).then(...).when(...).then(...)[.otherwise(...)]` support.

Touches four files:
  * narwhals/_pandas_like/namespace.py - `PandasThen.when`, plus the new
    `PandasChainedWhen` / `PandasChainedThen` classes.
  * narwhals/expr.py - module-level `_extract_predicates`, `Then.when`, plus
    the new `ChainedWhen` / `ChainedThen` classes.
  * narwhals/stable/v1.py - stable wrappers and `_stableify` dispatch for the
    new classes; `from_when` / `from_then` are renamed to `from_base`.
  * tests/expr_and_series/when_test.py - tests for chained when/then.

Review fixes folded into this revision (hunk line counts are unchanged):
  * `ChainedWhen.then` and `ChainedThen.otherwise` now pass the value through
    `extract_compliant(plx, value)`, matching `When.then` / `Then.otherwise`.
  * `PandasChainedWhen.__init__`: parameter `otherise_value` renamed to
    `otherwise_value`.
  * `_stableify`: the new branches use plain `if ... return`, like the
    surrounding branches.
  * Test `test_chained_when_otherewise` renamed to
    `test_chained_when_otherwise`.

The `index` lines still carry the blob ids of the original patch.

diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py
index 37c889549..0be930556 100644
--- a/narwhals/_pandas_like/namespace.py
+++ b/narwhals/_pandas_like/namespace.py
@@ -280,6 +280,8 @@ def __init__(
         self._then_value = then_value
         self._otherwise_value = otherwise_value
 
+        self._already_set = self._condition
+
     def __call__(self, df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
         from narwhals._expression_parsing import parse_into_expr
         from narwhals._pandas_like.namespace import PandasLikeNamespace
@@ -361,10 +363,166 @@ def __init__(
         self._root_names = root_names
         self._output_names = output_names
 
-    def otherwise(self, value: PandasLikeExpr | PandasLikeSeries | Any) -> PandasLikeExpr:
+    def otherwise(self, value: PandasLikeExpr | PandasLikeSeries | Any) -> PandasThen:
         # type ignore because we are setting the `_call` attribute to a
         # callable object of type `PandasWhen`, base class has the attribute as
         # only a `Callable`
         self._call._otherwise_value = value  # type: ignore[attr-defined]
         self._function_name = "whenotherwise"
         return self
+
+    def when(self, *predicates: IntoPandasLikeExpr) -> PandasChainedWhen:
+        plx = PandasLikeNamespace(self._implementation, self._backend_version)
+        if predicates:
+            condition = plx.all_horizontal(*predicates)
+        else:
+            msg = "at least one predicate needs to be provided"
+            raise TypeError(msg)
+        return PandasChainedWhen(
+            self, condition, self._depth + 1, self._implementation, self._backend_version
+        )
+
+
+class PandasChainedWhen:
+    def __init__(
+        self,
+        above_then: PandasThen | PandasChainedThen,
+        condition: PandasLikeExpr,
+        depth: int,
+        implementation: Implementation,
+        backend_version: tuple[int, ...],
+        then_value: Any = None,
+        otherwise_value: Any = None,
+    ) -> None:
+        self._implementation = implementation
+        self._depth = depth
+        self._backend_version = backend_version
+        self._condition = condition
+        self._above_then = above_then
+        self._then_value = then_value
+        self._otherwise_value = otherwise_value
+
+        # TODO @aivanoved: this is way slow as during computation time this takes
+        # quadratic time need to improve this to linear time
+        self._above_already_set = self._above_then._call._already_set  # type: ignore[attr-defined]
+        self._already_set = self._above_already_set | self._condition
+
+    def __call__(self, df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
+        from narwhals._expression_parsing import parse_into_expr
+        from narwhals._pandas_like.namespace import PandasLikeNamespace
+
+        plx = PandasLikeNamespace(
+            implementation=self._implementation, backend_version=self._backend_version
+        )
+
+        condition = parse_into_expr(self._condition, namespace=plx)._call(df)[0]  # type: ignore[arg-type]
+        try:
+            value_series = parse_into_expr(self._then_value, namespace=plx)._call(df)[0]  # type: ignore[arg-type]
+        except TypeError:
+            # `self._then_value` is a scalar and can't be converted to an expression
+            value_series = condition.__class__._from_iterable(  # type: ignore[call-arg]
+                [self._then_value] * len(condition),
+                name="literal",
+                index=condition._native_series.index,
+                implementation=self._implementation,
+                backend_version=self._backend_version,
+            )
+            value_series = cast(PandasLikeSeries, value_series)
+
+        set_then = condition
+        set_then_native = set_then._native_series
+        above_already_set = parse_into_expr(self._above_already_set, namespace=plx)._call(
+            df  # type: ignore[arg-type]
+        )[0]
+
+        value_series_native = value_series._native_series
+
+        above_result = self._above_then._call(df)[0]
+        above_result_native = above_result._native_series
+        set_then_native = set_then._native_series
+        above_already_set_native = above_already_set._native_series
+        if self._otherwise_value is None:
+            return [
+                above_result._from_native_series(
+                    value_series_native.where(
+                        ~above_already_set_native & set_then_native, above_result_native
+                    )
+                )
+            ]
+
+        try:
+            otherwise_series = parse_into_expr(
+                self._otherwise_value, namespace=plx
+            )._call(df)[0]  # type: ignore[arg-type]
+        except TypeError:
+            # `self._otherwise_value` is a scalar and can't be converted to an expression
+            otherwise_series = condition.__class__._from_iterable(  # type: ignore[call-arg]
+                [self._otherwise_value] * len(condition),
+                name="literal",
+                index=condition._native_series.index,
+                implementation=self._implementation,
+                backend_version=self._backend_version,
+            )
+            otherwise_series = cast(PandasLikeSeries, otherwise_series)
+        return [
+            above_result.zip_with(
+                above_already_set, value_series.zip_with(set_then, otherwise_series)
+            )
+        ]
+
+    def then(self, value: Any) -> PandasChainedThen:
+        self._then_value = value
+        return PandasChainedThen(
+            self,
+            depth=self._depth,
+            implementation=self._implementation,
+            function_name="chainedwhen",
+            root_names=None,
+            output_names=None,
+            backend_version=self._backend_version,
+        )
+
+
+class PandasChainedThen(PandasLikeExpr):
+    def __init__(
+        self,
+        call: PandasChainedWhen,
+        *,
+        depth: int,
+        function_name: str,
+        root_names: list[str] | None,
+        output_names: list[str] | None,
+        implementation: Implementation,
+        backend_version: tuple[int, ...],
+    ) -> None:
+        self._implementation = implementation
+        self._backend_version = backend_version
+
+        self._call = call
+        self._depth = depth
+        self._function_name = function_name
+        self._root_names = root_names
+        self._output_names = output_names
+
+    def when(
+        self,
+        *predicates: IntoPandasLikeExpr,
+    ) -> PandasChainedWhen:
+        plx = PandasLikeNamespace(self._implementation, self._backend_version)
+        if predicates:
+            condition = plx.all_horizontal(*predicates)
+        else:
+            msg = "at least one predicate needs to be provided"
+            raise TypeError(msg)
+        return PandasChainedWhen(
+            self,
+            condition,
+            depth=self._depth + 1,
+            implementation=self._implementation,
+            backend_version=self._backend_version,
+        )
+
+    def otherwise(self, value: Any) -> PandasChainedThen:
+        self._call._otherwise_value = value  # type: ignore[attr-defined]
+        self._function_name = "chainedwhenotherwise"
+        return self
diff --git a/narwhals/expr.py b/narwhals/expr.py
index b39a59818..13c50dec7 100644
--- a/narwhals/expr.py
+++ b/narwhals/expr.py
@@ -3994,16 +3994,17 @@ def sum_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr:
     )
 
 
+def _extract_predicates(plx: Any, predicates: IntoExpr | Iterable[IntoExpr]) -> Any:
+    return [extract_compliant(plx, v) for v in flatten([predicates])]
+
+
 class When:
     def __init__(self, *predicates: IntoExpr | Iterable[IntoExpr]) -> None:
         self._predicates = flatten([predicates])
 
-    def _extract_predicates(self, plx: Any) -> Any:
-        return [extract_compliant(plx, v) for v in self._predicates]
-
     def then(self, value: Any) -> Then:
         return Then(
-            lambda plx: plx.when(*self._extract_predicates(plx)).then(
+            lambda plx: plx.when(*_extract_predicates(plx, self._predicates)).then(
                 extract_compliant(plx, value)
             )
         )
@@ -4016,6 +4017,40 @@ def __init__(self, call: Callable[[Any], Any]) -> None:
     def otherwise(self, value: Any) -> Expr:
         return Expr(lambda plx: self._call(plx).otherwise(extract_compliant(plx, value)))
 
+    def when(self, *predicates: IntoExpr | Iterable[IntoExpr]) -> ChainedWhen:
+        return ChainedWhen(self, *predicates)
+
+
+class ChainedWhen:
+    def __init__(
+        self,
+        above_then: Then | ChainedThen,
+        *predicates: IntoExpr | Iterable[IntoExpr],
+    ) -> None:
+        self._above_then = above_then
+        self._predicates = flatten([predicates])
+
+    def then(self, value: Any) -> ChainedThen:
+        return ChainedThen(
+            lambda plx: self._above_then._call(plx)
+            .when(*_extract_predicates(plx, self._predicates))
+            .then(extract_compliant(plx, value))
+        )
+
+
+class ChainedThen(Expr):
+    def __init__(self, call: Callable[[Any], Any]) -> None:
+        self._call = call
+
+    def when(
+        self,
+        *predicates: IntoExpr | Iterable[IntoExpr],
+    ) -> ChainedWhen:
+        return ChainedWhen(self, *predicates)
+
+    def otherwise(self, value: Any) -> Expr:
+        return Expr(lambda plx: self._call(plx).otherwise(extract_compliant(plx, value)))
+
 
 def when(*predicates: IntoExpr | Iterable[IntoExpr]) -> When:
     """
diff --git a/narwhals/stable/v1.py b/narwhals/stable/v1.py
index fa98fd96f..e5282c38b 100644
--- a/narwhals/stable/v1.py
+++ b/narwhals/stable/v1.py
@@ -33,6 +33,8 @@
 from narwhals.dtypes import UInt32
 from narwhals.dtypes import UInt64
 from narwhals.dtypes import Unknown
+from narwhals.expr import ChainedThen as NwChainedThen
+from narwhals.expr import ChainedWhen as NwChainedWhen
 from narwhals.expr import Expr as NwExpr
 from narwhals.expr import Then as NwThen
 from narwhals.expr import When as NwWhen
@@ -491,12 +493,34 @@ def _stableify(obj: NwSeries) -> Series: ...
 @overload
 def _stableify(obj: NwExpr) -> Expr: ...
 @overload
+def _stableify(obj: NwWhen) -> When: ...
+@overload
+def _stableify(obj: NwChainedWhen) -> ChainedWhen: ...
+@overload
 def _stableify(obj: Any) -> Any: ...
 
 
 def _stableify(
-    obj: NwDataFrame[IntoFrameT] | NwLazyFrame[IntoFrameT] | NwSeries | NwExpr | Any,
-) -> DataFrame[IntoFrameT] | LazyFrame[IntoFrameT] | Series | Expr | Any:
+    obj: NwDataFrame[IntoFrameT]
+    | NwLazyFrame[IntoFrameT]
+    | NwSeries
+    | NwExpr
+    | NwWhen
+    | NwChainedWhen
+    | NwThen
+    | NwChainedThen
+    | Any,
+) -> (
+    DataFrame[IntoFrameT]
+    | LazyFrame[IntoFrameT]
+    | Series
+    | Expr
+    | When
+    | ChainedWhen
+    | Then
+    | ChainedThen
+    | Any
+):
     if isinstance(obj, NwDataFrame):
         return DataFrame(
             obj._compliant_frame,
@@ -512,6 +536,14 @@ def _stableify(
             obj._compliant_series,
             level=obj._level,
         )
+    if isinstance(obj, NwChainedWhen):
+        return ChainedWhen.from_base(obj)
+    if isinstance(obj, NwWhen):
+        return When.from_base(obj)
+    if isinstance(obj, NwChainedThen):
+        return ChainedThen.from_base(obj)
+    if isinstance(obj, NwThen):
+        return Then.from_base(obj)
     if isinstance(obj, NwExpr):
         return Expr(obj._call)
     return obj
@@ -1692,21 +1724,45 @@ def get_level(
 
 class When(NwWhen):
     @classmethod
-    def from_when(cls, when: NwWhen) -> Self:
+    def from_base(cls, when: NwWhen) -> Self:
         return cls(*when._predicates)
 
     def then(self, value: Any) -> Then:
-        return Then.from_then(super().then(value))
+        return Then.from_base(super().then(value))
 
 
 class Then(NwThen, Expr):
     @classmethod
-    def from_then(cls, then: NwThen) -> Self:
+    def from_base(cls, then: NwThen) -> Self:
         return cls(then._call)
 
     def otherwise(self, value: Any) -> Expr:
         return _stableify(super().otherwise(value))
 
+    def when(self, *predicates: IntoExpr | Iterable[IntoExpr]) -> ChainedWhen:
+        return _stableify(super().when(*predicates))
+
+
+class ChainedWhen(NwChainedWhen):
+    @classmethod
+    def from_base(cls, chained_when: NwChainedWhen) -> Self:
+        return cls(_stableify(chained_when._above_then), *chained_when._predicates)  # type: ignore[arg-type]
+
+    def then(self, value: Any) -> ChainedThen:
+        return _stableify(super().then(value))  # type: ignore[return-value]
+
+
+class ChainedThen(NwChainedThen, Expr):
+    @classmethod
+    def from_base(cls, chained_then: NwChainedThen) -> Self:
+        return cls(chained_then._call)
+
+    def when(self, *predicates: IntoExpr | Iterable[IntoExpr]) -> ChainedWhen:
+        return _stableify(super().when(*predicates))
+
+    def otherwise(self, value: Any) -> Expr:
+        return _stableify(super().otherwise(value))
+
 
 def when(*predicates: IntoExpr | Iterable[IntoExpr]) -> When:
     """
@@ -1753,7 +1809,7 @@ def when(*predicates: IntoExpr | Iterable[IntoExpr]) -> When:
     │ 3   ┆ 15  ┆ 6      │
     └─────┴─────┴────────┘
     """
-    return When.from_when(nw_when(*predicates))
+    return _stableify(nw_when(*predicates))
 
 
 def new_series(
diff --git a/tests/expr_and_series/when_test.py b/tests/expr_and_series/when_test.py
index bcd796a4a..fbfc5f932 100644
--- a/tests/expr_and_series/when_test.py
+++ b/tests/expr_and_series/when_test.py
@@ -17,6 +17,12 @@
     "e": [7.0, 2.0, 1.1],
 }
 
+large_data = {
+    "a": [1, 2, 3, 4, 5, 6],
+    "b": ["a", "b", "c", "d", "e", "f"],
+    "c": [True, False, True, False, True, False],
+}
+
 
 def test_when(constructor: Constructor) -> None:
     df = nw.from_native(constructor(data))
@@ -145,6 +151,107 @@ def test_when_then_otherwise_into_expr(
         request.applymarker(pytest.mark.xfail)
 
     df = nw.from_native(constructor(data))
-    result = df.select(nw.when(nw.col("a") > 1).then("c").otherwise("e"))
-    expected = {"c": [7, 5, 6]}
+    result = df.select(nw.when(nw.col("a") > 1).then("c").otherwise("e").alias("a_when"))
+    expected = {"a_when": [7, 5, 6]}
+    compare_dicts(result, expected)
+
+
+def test_chained_when(request: Any, constructor: Any) -> None:
+    if "dask" in str(constructor) or "pyarrow_table" in str(constructor):
+        request.applymarker(pytest.mark.xfail)
+
+    df = nw.from_native(constructor(data))
+    result = df.select(
+        nw.when(nw.col("a") == 1).then(3).when(nw.col("a") == 2).then(5).alias("a_when"),
+    )
+    expected = {
+        "a_when": [3, 5, np.nan],
+    }
+    compare_dicts(result, expected)
+
+
+def test_chained_when_otherwise(request: Any, constructor: Any) -> None:
+    if "dask" in str(constructor) or "pyarrow_table" in str(constructor):
+        request.applymarker(pytest.mark.xfail)
+
+    df = nw.from_native(constructor(data))
+    result = df.select(
+        nw.when(nw.col("a") == 1)
+        .then(3)
+        .when(nw.col("a") == 2)
+        .then(5)
+        .otherwise(7)
+        .alias("a_when"),
+    )
+    expected = {
+        "a_when": [3, 5, 7],
+    }
+    compare_dicts(result, expected)
+
+
+def test_multi_chained_when(request: Any, constructor: Any) -> None:
+    if "dask" in str(constructor) or "pyarrow_table" in str(constructor):
+        request.applymarker(pytest.mark.xfail)
+
+    df = nw.from_native(constructor(large_data))
+    result = df.select(
+        nw.when(nw.col("a") == 1)
+        .then(3)
+        .when(nw.col("a") == 2)
+        .then(5)
+        .when(nw.col("a") == 3)
+        .then(7)
+        .alias("a_when"),
+    )
+    expected = {
+        "a_when": [3, 5, 7, np.nan, np.nan, np.nan],
+    }
+    compare_dicts(result, expected)
+
+
+def test_multi_chained_when_otherwise(request: Any, constructor: Any) -> None:
+    if "dask" in str(constructor) or "pyarrow_table" in str(constructor):
+        request.applymarker(pytest.mark.xfail)
+
+    df = nw.from_native(constructor(large_data))
+    result = df.select(
+        nw.when(nw.col("a") == 1)
+        .then(3)
+        .when(nw.col("a") == 2)
+        .then(5)
+        .when(nw.col("a") == 3)
+        .then(7)
+        .otherwise(9)
+        .alias("a_when"),
+    )
+    expected = {
+        "a_when": [3, 5, 7, 9, 9, 9],
+    }
     compare_dicts(result, expected)
+
+
+def test_then_when_no_condition(request: Any, constructor: Any) -> None:
+    if "dask" in str(constructor) or "pyarrow_table" in str(constructor):
+        request.applymarker(pytest.mark.xfail)
+
+    df = nw.from_native(constructor(data))
+
+    with pytest.raises((TypeError, ValueError)):
+        df.select(nw.when(nw.col("a") == 1).then(value=3).when().then(value=7))
+
+
+def test_then_chained_when_no_condition(request: Any, constructor: Any) -> None:
+    if "dask" in str(constructor) or "pyarrow_table" in str(constructor):
+        request.applymarker(pytest.mark.xfail)
+
+    df = nw.from_native(constructor(data))
+
+    with pytest.raises((TypeError, ValueError)):
+        df.select(
+            nw.when(nw.col("a") == 1)
+            .then(value=3)
+            .when(nw.col("a") == 3)
+            .then(value=7)
+            .when()
+            .then(value=9)
+        )