From 09bea00cbc5f70ca6a80b5a20c990677baf2e85c Mon Sep 17 00:00:00 2001 From: Alexander Ivanov Date: Thu, 12 Sep 2024 00:02:41 +0300 Subject: [PATCH] feat: add when then chaining back --- narwhals/_pandas_like/namespace.py | 260 +++++++++++++++++------------ narwhals/expr.py | 60 +++---- tests/expr_and_series/when_test.py | 127 ++++++++------ 3 files changed, 269 insertions(+), 178 deletions(-) diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index 753b49f69..dae72573c 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -291,6 +291,8 @@ def __init__( self._then_value = then_value self._otherwise_value = otherwise_value + self._already_set = self._condition + def __call__(self, df: PandasLikeDataFrame) -> list[PandasLikeSeries]: from narwhals._expression_parsing import parse_into_expr from narwhals._pandas_like.namespace import PandasLikeNamespace @@ -372,7 +374,7 @@ def __init__( self._root_names = root_names self._output_names = output_names - def otherwise(self, value: PandasLikeExpr | PandasLikeSeries | Any) -> PandasLikeExpr: + def otherwise(self, value: PandasLikeExpr | PandasLikeSeries | Any) -> PandasThen: # type ignore because we are setting the `_call` attribute to a # callable object of type `PandasWhen`, base class has the attribute as # only a `Callable` @@ -380,106 +382,158 @@ def otherwise(self, value: PandasLikeExpr | PandasLikeSeries | Any) -> PandasLik self._function_name = "whenotherwise" return self + def when(self, *predicates: IntoPandasLikeExpr) -> PandasChainedWhen: + plx = PandasLikeNamespace(self._implementation, self._backend_version) + if predicates: + condition = plx.all_horizontal(*predicates) + else: + msg = "at least one predicate needs to be provided" + raise TypeError(msg) + return PandasChainedWhen( + self, condition, self._depth + 1, self._implementation, self._backend_version + ) + + +class PandasChainedWhen: + def __init__( + self, + above_then: PandasThen | PandasChainedThen, + condition: PandasLikeExpr, + depth: int, + implementation: Implementation, + backend_version: tuple[int, ...], + then_value: Any = None, + otherise_value: Any = None, + ) -> None: + self._implementation = implementation + self._depth = depth + self._backend_version = backend_version + self._condition = condition + self._above_then = above_then + self._then_value = then_value + self._otherwise_value = otherise_value + + # TODO @aivanoved: this is way slow as during computation time this takes + # quadratic time need to improve this to linear time + self._above_already_set = self._above_then._call._already_set # type: ignore[attr-defined] + self._already_set = self._above_already_set | self._condition -# class PandasChainedWhen: -# def __init__( -# self, -# above_when: PandasWhen | PandasChainedWhen, -# condition: PandasLikeExpr, -# depth: int, -# implementation: Implementation, -# backend_version: tuple[int, ...], -# then_value: Any = None, -# otherise_value: Any = None, -# ) -> None: -# self._implementation = implementation -# self._depth = depth -# self._backend_version = backend_version -# self._condition = condition -# self._above_when = above_when -# self._then_value = then_value -# self._otherwise_value = otherise_value -# -# # TODO @aivanoved: this is way slow as during computation time this takes -# # quadratic time need to improve this to linear time -# self._condition = self._condition & (~self._above_when._already_set) # type: ignore[has-type] -# self._already_set = self._above_when._already_set | self._condition # type: ignore[has-type] -# -# def __call__(self, df: PandasLikeDataFrame) -> list[PandasLikeSeries]: -# from narwhals._pandas_like.namespace import PandasLikeNamespace -# -# plx = PandasLikeNamespace( -# implementation=self._implementation, backend_version=self._backend_version -# ) -# -# set_then = self._condition._call(df)[0] -# already_set = self._already_set._call(df)[0] -# -# value_series = plx._create_broadcast_series_from_scalar( -# self._then_value, set_then -# ) -# otherwise_series = plx._create_broadcast_series_from_scalar( -# self._otherwise_value, set_then -# ) -# -# above_result = self._above_when(df)[0] -# -# result = value_series.zip_with(set_then, above_result).zip_with( -# already_set, otherwise_series -# ) -# -# return [result] -# -# def then(self, value: Any) -> PandasChainedThen: -# self._then_value = value -# return PandasChainedThen( -# self, -# depth=self._depth, -# implementation=self._implementation, -# function_name="chainedwhen", -# root_names=None, -# output_names=None, -# backend_version=self._backend_version, -# ) -# -# -# class PandasChainedThen(PandasLikeExpr): -# def __init__( -# self, -# call: PandasChainedWhen, -# *, -# depth: int, -# function_name: str, -# root_names: list[str] | None, -# output_names: list[str] | None, -# implementation: Implementation, -# backend_version: tuple[int, ...], -# ) -> None: -# self._implementation = implementation -# self._backend_version = backend_version -# -# self._call = call -# self._depth = depth -# self._function_name = function_name -# self._root_names = root_names -# self._output_names = output_names -# -# def when( -# self, -# *predicates: IntoPandasLikeExpr | Iterable[IntoPandasLikeExpr], -# ) -> PandasChainedWhen: -# return PandasChainedWhen( -# self._call, # type: ignore[arg-type] -# when_processing( -# PandasLikeNamespace(self._implementation, self._backend_version), -# *predicates, -# ), -# depth=self._depth + 1, -# implementation=self._implementation, -# backend_version=self._backend_version, -# ) -# -# def otherwise(self, value: Any) -> PandasLikeExpr: -# self._call._otherwise_value = value # type: ignore[attr-defined] -# self._function_name = "chainedwhenotherwise" -# return self + def __call__(self, df: PandasLikeDataFrame) -> list[PandasLikeSeries]: + from narwhals._expression_parsing import parse_into_expr + from narwhals._pandas_like.namespace import PandasLikeNamespace + + plx = PandasLikeNamespace( + implementation=self._implementation, backend_version=self._backend_version + ) + + condition = parse_into_expr(self._condition, namespace=plx)._call(df)[0] # type: ignore[arg-type] + try: + value_series = parse_into_expr(self._then_value, namespace=plx)._call(df)[0] # type: ignore[arg-type] + except TypeError: + # `self._otherwise_value` is a scalar and can't be converted to an expression + value_series = condition.__class__._from_iterable( # type: ignore[call-arg] + [self._then_value] * len(condition), + name="literal", + index=condition._native_series.index, + implementation=self._implementation, + backend_version=self._backend_version, + ) + value_series = cast(PandasLikeSeries, value_series) + + set_then = condition + set_then_native = set_then._native_series + above_already_set = parse_into_expr(self._above_already_set, namespace=plx)._call( + df # type: ignore[arg-type] + )[0] + + value_series_native = value_series._native_series + + above_result = self._above_then._call(df)[0] + above_result_native = above_result._native_series + set_then_native = set_then._native_series + above_already_set_native = above_already_set._native_series + if self._otherwise_value is None: + return [ + above_result._from_native_series( + value_series_native.where( + ~above_already_set_native & set_then_native, above_result_native + ) + ) + ] + + try: + otherwise_series = parse_into_expr( + self._otherwise_value, namespace=plx + )._call(df)[0] # type: ignore[arg-type] + except TypeError: + # `self._otherwise_value` is a scalar and can't be converted to an expression + otherwise_series = condition.__class__._from_iterable( # type: ignore[call-arg] + [self._otherwise_value] * len(condition), + name="literal", + index=condition._native_series.index, + implementation=self._implementation, + backend_version=self._backend_version, + ) + otherwise_series = cast(PandasLikeSeries, otherwise_series) + return [ + above_result.zip_with( + above_already_set, value_series.zip_with(set_then, otherwise_series) + ) + ] + + def then(self, value: Any) -> PandasChainedThen: + self._then_value = value + return PandasChainedThen( + self, + depth=self._depth, + implementation=self._implementation, + function_name="chainedwhen", + root_names=None, + output_names=None, + backend_version=self._backend_version, + ) + + +class PandasChainedThen(PandasLikeExpr): + def __init__( + self, + call: PandasChainedWhen, + *, + depth: int, + function_name: str, + root_names: list[str] | None, + output_names: list[str] | None, + implementation: Implementation, + backend_version: tuple[int, ...], + ) -> None: + self._implementation = implementation + self._backend_version = backend_version + + self._call = call + self._depth = depth + self._function_name = function_name + self._root_names = root_names + self._output_names = output_names + + def when( + self, + *predicates: IntoPandasLikeExpr, + ) -> PandasChainedWhen: + plx = PandasLikeNamespace(self._implementation, self._backend_version) + if predicates: + condition = plx.all_horizontal(*predicates) + else: + msg = "at least one predicate needs to be provided" + raise TypeError(msg) + return PandasChainedWhen( + self, + condition, + depth=self._depth + 1, + implementation=self._implementation, + backend_version=self._backend_version, + ) + + def otherwise(self, value: Any) -> PandasChainedThen: + self._call._otherwise_value = value # type: ignore[attr-defined] + self._function_name = "chainedwhenotherwise" + return self diff --git a/narwhals/expr.py b/narwhals/expr.py index 5c5ff7d2e..a8407915a 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -3995,16 +3995,17 @@ def sum_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: ) +def _extract_predicates(plx: Any, predicates: IntoExpr | Iterable[IntoExpr]) -> Any: + return [extract_compliant(plx, v) for v in flatten([predicates])] + + class When: def __init__(self, *predicates: IntoExpr | Iterable[IntoExpr]) -> None: self._predicates = flatten([predicates]) - def _extract_predicates(self, plx: Any) -> Any: - return [extract_compliant(plx, v) for v in self._predicates] - def then(self, value: Any) -> Then: return Then( - lambda plx: plx.when(*self._extract_predicates(plx)).then( + lambda plx: plx.when(*_extract_predicates(plx, self._predicates)).then( extract_compliant(plx, value) ) ) @@ -4017,36 +4018,39 @@ def __init__(self, call: Callable[[Any], Any]) -> None: def otherwise(self, value: Any) -> Expr: return Expr(lambda plx: self._call(plx).otherwise(extract_compliant(plx, value))) + def when(self, *predicates: IntoExpr | Iterable[IntoExpr]) -> ChainedWhen: + return ChainedWhen(self, *predicates) -# class ChainedWhen: -# def __init__( -# self, -# above_then: Then | ChainedThen, -# *predicates: IntoExpr | Iterable[IntoExpr], -# ) -> None: -# self._above_then = above_then -# self._predicates = predicates -# def then(self, value: Any) -> ChainedThen: -# return ChainedThen( -# lambda plx: self._above_then._call(plx) -# .when(*_extract_predicates(plx, flatten([self._predicates]))) -# .then(value) -# ) +class ChainedWhen: + def __init__( + self, + above_then: Then | ChainedThen, + *predicates: IntoExpr | Iterable[IntoExpr], + ) -> None: + self._above_then = above_then + self._predicates = flatten([predicates]) + def then(self, value: Any) -> ChainedThen: + return ChainedThen( + lambda plx: self._above_then._call(plx) + .when(*_extract_predicates(plx, self._predicates)) + .then(value) + ) -# class ChainedThen(Expr): -# def __init__(self, call: Callable[[Any], Any]) -> None: -# self._call = call -# def when( -# self, -# *predicates: IntoExpr | Iterable[IntoExpr], -# ) -> ChainedWhen: -# return ChainedWhen(self, *predicates) +class ChainedThen(Expr): + def __init__(self, call: Callable[[Any], Any]) -> None: + self._call = call -# def otherwise(self, value: Any) -> Expr: -# return Expr(lambda plx: self._call(plx).otherwise(value)) + def when( + self, + *predicates: IntoExpr | Iterable[IntoExpr], + ) -> ChainedWhen: + return ChainedWhen(self, *predicates) + + def otherwise(self, value: Any) -> Expr: + return Expr(lambda plx: self._call(plx).otherwise(value)) def when(*predicates: IntoExpr | Iterable[IntoExpr]) -> When: diff --git a/tests/expr_and_series/when_test.py b/tests/expr_and_series/when_test.py index 50d69f5f5..0e3049440 100644 --- a/tests/expr_and_series/when_test.py +++ b/tests/expr_and_series/when_test.py @@ -16,6 +16,12 @@ "e": [7.0, 2.0, 1.1], } +large_data = { + "a": [1, 2, 3, 4, 5, 6], + "b": ["a", "b", "c", "d", "e", "f"], + "c": [True, False, True, False, True, False], +} + def test_when(constructor: Any) -> None: df = nw.from_native(constructor(data)) @@ -136,53 +142,80 @@ def test_when_then_otherwise_into_expr(request: Any, constructor: Any) -> None: request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) - result = df.select(nw.when(nw.col("a") > 1).then("c").otherwise("e")) - expected = {"c": [7, 5, 6]} + result = df.select(nw.when(nw.col("a") > 1).then("c").otherwise("e").alias("a_when")) + expected = {"a_when": [7, 5, 6]} + compare_dicts(result, expected) + + +def test_chained_when(request: Any, constructor: Any) -> None: + if "dask" in str(constructor) or "pyarrow_table" in str(constructor): + request.applymarker(pytest.mark.xfail) + + df = nw.from_native(constructor(data)) + result = df.select( + nw.when(nw.col("a") == 1).then(3).when(nw.col("a") == 2).then(5).alias("a_when"), + ) + expected = { + "a_when": [3, 5, np.nan], + } + compare_dicts(result, expected) + + +def test_chained_when_otherewise(request: Any, constructor: Any) -> None: + if "dask" in str(constructor) or "pyarrow_table" in str(constructor): + request.applymarker(pytest.mark.xfail) + + df = nw.from_native(constructor(data)) + result = df.select( + nw.when(nw.col("a") == 1) + .then(3) + .when(nw.col("a") == 2) + .then(5) + .otherwise(7) + .alias("a_when"), + ) + expected = { + "a_when": [3, 5, 7], + } compare_dicts(result, expected) -# def test_chained_when(request: Any, constructor: Any) -> None: -# if "pyarrow_table" in str(constructor): -# request.applymarker(pytest.mark.xfail) - -# df = nw.from_native(constructor(data)) -# result = df.with_columns( -# nw.when(nw.col("a") == 1) -# .then(3) -# .when(nw.col("a") == 2) -# .then(5) -# .otherwise(7) -# .alias("a_when"), -# ) -# expected = { -# "a": [1, 2, 3, 4, 5], -# "b": ["a", "b", "c", "d", "e"], -# "c": [4.1, 5.0, 6.0, 7.0, 8.0], -# "d": [True, False, True, False, True], -# "a_when": [3, 5, 7, 7, 7], -# } -# compare_dicts(result, expected) - - -# def test_when_with_multiple_conditions(request: Any, constructor: Any) -> None: -# if "pyarrow_table" in str(constructor): -# request.applymarker(pytest.mark.xfail) -# df = nw.from_native(constructor(data)) -# result = df.with_columns( -# nw.when(nw.col("a") == 1) -# .then(3) -# .when(nw.col("a") == 2) -# .then(5) -# .when(nw.col("a") == 3) -# .then(7) -# .otherwise(9) -# .alias("a_when"), -# ) -# expected = { -# "a": [1, 2, 3, 4, 5], -# "b": ["a", "b", "c", "d", "e"], -# "c": [4.1, 5.0, 6.0, 7.0, 8.0], -# "d": [True, False, True, False, True], -# "a_when": [3, 5, 7, 9, 9], -# } -# compare_dicts(result, expected) +def test_multi_chained_when(request: Any, constructor: Any) -> None: + if "dask" in str(constructor) or "pyarrow_table" in str(constructor): + request.applymarker(pytest.mark.xfail) + + df = nw.from_native(constructor(large_data)) + result = df.select( + nw.when(nw.col("a") == 1) + .then(3) + .when(nw.col("a") == 2) + .then(5) + .when(nw.col("a") == 3) + .then(7) + .alias("a_when"), + ) + expected = { + "a_when": [3, 5, 7, np.nan, np.nan, np.nan], + } + compare_dicts(result, expected) + + +def test_multi_chained_when_otherewise(request: Any, constructor: Any) -> None: + if "dask" in str(constructor) or "pyarrow_table" in str(constructor): + request.applymarker(pytest.mark.xfail) + + df = nw.from_native(constructor(large_data)) + result = df.select( + nw.when(nw.col("a") == 1) + .then(3) + .when(nw.col("a") == 2) + .then(5) + .when(nw.col("a") == 3) + .then(7) + .otherwise(9) + .alias("a_when"), + ) + expected = { + "a_when": [3, 5, 7, 9, 9, 9], + } + compare_dicts(result, expected)