diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index 9bb4be587..19c810a6e 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING from typing import Any from typing import Iterable +from typing import cast from narwhals import dtypes from narwhals._arrow.dataframe import ArrowDataFrame @@ -234,3 +235,123 @@ def min(self, *column_names: str) -> ArrowExpr: @property def selectors(self) -> ArrowSelectorNamespace: return ArrowSelectorNamespace(backend_version=self._backend_version) + + def when( + self, + *predicates: IntoArrowExpr, + ) -> ArrowWhen: + plx = self.__class__(backend_version=self._backend_version) + if predicates: + condition = plx.all_horizontal(*predicates) + else: + msg = "at least one predicate needs to be provided" + raise TypeError(msg) + + return ArrowWhen(condition, self._backend_version) + + +class ArrowWhen: + def __init__( + self, + condition: ArrowExpr, + backend_version: tuple[int, ...], + then_value: Any = None, + otherwise_value: Any = None, + ) -> None: + self._backend_version = backend_version + self._condition = condition + self._then_value = then_value + self._otherwise_value = otherwise_value + + def __call__(self, df: ArrowDataFrame) -> list[ArrowSeries]: + import pyarrow as pa # ignore-banned-import + import pyarrow.compute as pc # ignore-banned-import + + from narwhals._arrow.namespace import ArrowNamespace + from narwhals._expression_parsing import parse_into_expr + + plx = ArrowNamespace(backend_version=self._backend_version) + + condition = parse_into_expr(self._condition, namespace=plx)._call(df)[0] # type: ignore[arg-type] + try: + value_series = parse_into_expr(self._then_value, namespace=plx)._call(df)[0] # type: ignore[arg-type] + except TypeError: + # `self._otherwise_value` is a scalar and can't be converted to an expression + value_series = condition.__class__._from_iterable( # type: ignore[call-arg] + [self._then_value] * len(condition), + name="literal", + backend_version=self._backend_version, + ) + value_series = cast(ArrowSeries, value_series) + + value_series_native = value_series._native_series + condition_native = pc.invert(condition._native_series.combine_chunks()) + + if self._otherwise_value is None: + otherwise_native = pa.array( + [None] * len(condition_native), type=value_series_native.type + ) + return [ + value_series._from_native_series( + pc.replace_with_mask( + value_series_native, condition_native, otherwise_native + ) + ) + ] + try: + otherwise_series = parse_into_expr( + self._otherwise_value, namespace=plx + )._call(df)[0] # type: ignore[arg-type] + except TypeError: + # `self._otherwise_value` is a scalar and can't be converted to an expression + return [ + value_series._from_native_series( + pc.replace_with_mask( + value_series_native, condition_native, self._otherwise_value + ) + ) + ] + else: + otherwise_series = cast(ArrowSeries, otherwise_series) + condition = cast(ArrowSeries, condition) + return [value_series.zip_with(condition, otherwise_series)] + + def then(self, value: ArrowExpr | ArrowSeries | Any) -> ArrowThen: + self._then_value = value + + return ArrowThen( + self, + depth=0, + function_name="whenthen", + root_names=None, + output_names=None, + backend_version=self._backend_version, + ) + + +class ArrowThen(ArrowExpr): + def __init__( + self, + call: ArrowWhen, + *, + depth: int, + function_name: str, + root_names: list[str] | None, + output_names: list[str] | None, + backend_version: tuple[int, ...], + ) -> None: + self._backend_version = backend_version + + self._call = call + self._depth = depth + self._function_name = function_name + self._root_names = root_names + self._output_names = output_names + + def otherwise(self, value: ArrowExpr | ArrowSeries | Any) -> ArrowExpr: + # type ignore because we are setting the `_call` attribute to a + # callable object of type `PandasWhen`, base class has the attribute as + # only a `Callable` + self._call._otherwise_value = value # type: ignore[attr-defined] + self._function_name = "whenotherwise" + return self diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index 39394924f..c76acb5f3 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -490,11 +490,12 @@ def value_counts( def zip_with(self: Self, mask: Self, other: Self) -> Self: import pyarrow.compute as pc # ignore-banned-import() + mask = pc.invert(mask._native_series.combine_chunks()) return self._from_native_series( pc.replace_with_mask( - self._native_series.combine_chunks(), - pc.invert(mask._native_series.combine_chunks()), - other._native_series.combine_chunks(), + self._native_series, + mask, + other._native_series.combine_chunks().filter(mask), ) ) diff --git a/tests/expr_and_series/when_test.py b/tests/expr_and_series/when_test.py index 741c676a8..e0a20f2ab 100644 --- a/tests/expr_and_series/when_test.py +++ b/tests/expr_and_series/when_test.py @@ -1,5 +1,6 @@ from __future__ import annotations +import sys from typing import Any import numpy as np @@ -7,6 +8,7 @@ import narwhals.stable.v1 as nw from tests.utils import compare_dicts +from tests.utils import is_windows data = { "a": [1, 2, 3], @@ -18,7 +20,7 @@ def test_when(request: Any, constructor: Any) -> None: - if "pyarrow_table" in str(constructor) or "dask" in str(constructor): + if "dask" in str(constructor): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) @@ -30,7 +32,7 @@ def test_when(request: Any, constructor: Any) -> None: def test_when_otherwise(request: Any, constructor: Any) -> None: - if "pyarrow_table" in str(constructor) or "dask" in str(constructor): + if "dask" in str(constructor): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) @@ -42,7 +44,7 @@ def test_when_otherwise(request: Any, constructor: Any) -> None: def test_multiple_conditions(request: Any, constructor: Any) -> None: - if "pyarrow_table" in str(constructor) or "dask" in str(constructor): + if "dask" in str(constructor): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) @@ -56,7 +58,7 @@ def test_multiple_conditions(request: Any, constructor: Any) -> None: def test_no_arg_when_fail(request: Any, constructor: Any) -> None: - if "pyarrow_table" in str(constructor) or "dask" in str(constructor): + if "dask" in str(constructor): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) @@ -65,7 +67,7 @@ def test_no_arg_when_fail(request: Any, constructor: Any) -> None: def test_value_numpy_array(request: Any, constructor: Any) -> None: - if "pyarrow_table" in str(constructor) or "dask" in str(constructor): + if "dask" in str(constructor): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) @@ -80,10 +82,7 @@ def test_value_numpy_array(request: Any, constructor: Any) -> None: compare_dicts(result, expected) -def test_value_series(request: Any, constructor_eager: Any) -> None: - if "pyarrow_table" in str(constructor_eager): - request.applymarker(pytest.mark.xfail) - +def test_value_series(constructor_eager: Any) -> None: df = nw.from_native(constructor_eager(data)) s_data = {"s": [3, 4, 5]} s = nw.from_native(constructor_eager(s_data))["s"] @@ -96,7 +95,7 @@ def test_value_series(request: Any, constructor_eager: Any) -> None: def test_value_expression(request: Any, constructor: Any) -> None: - if "pyarrow_table" in str(constructor) or "dask" in str(constructor): + if "dask" in str(constructor): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) @@ -108,17 +107,19 @@ def test_value_expression(request: Any, constructor: Any) -> None: def test_otherwise_numpy_array(request: Any, constructor: Any) -> None: - if "pyarrow_table" in str(constructor) or "dask" in str(constructor): + if "dask" in str(constructor): + request.applymarker(pytest.mark.xfail) + if ( + "pyarrow_table" in str(constructor) and is_windows() and sys.version_info < (3, 9) + ): # pragma: no cover + # seriously... request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) import numpy as np result = df.select( - nw.when(nw.col("a") == 1) - .then(-1) - .otherwise(np.asanyarray([0, 9, 10])) - .alias("a_when") + nw.when(nw.col("a") == 1).then(-1).otherwise(np.array([0, 9, 10])).alias("a_when") ) expected = { "a_when": [-1, 9, 10], @@ -126,10 +127,7 @@ def test_otherwise_numpy_array(request: Any, constructor: Any) -> None: compare_dicts(result, expected) -def test_otherwise_series(request: Any, constructor_eager: Any) -> None: - if "pyarrow_table" in str(constructor_eager): - request.applymarker(pytest.mark.xfail) - +def test_otherwise_series(constructor_eager: Any) -> None: df = nw.from_native(constructor_eager(data)) s_data = {"s": [0, 9, 10]} s = nw.from_native(constructor_eager(s_data))["s"] @@ -142,7 +140,7 @@ def test_otherwise_series(request: Any, constructor_eager: Any) -> None: def test_otherwise_expression(request: Any, constructor: Any) -> None: - if "pyarrow_table" in str(constructor) or "dask" in str(constructor): + if "dask" in str(constructor): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) @@ -156,7 +154,7 @@ def test_otherwise_expression(request: Any, constructor: Any) -> None: def test_when_then_otherwise_into_expr(request: Any, constructor: Any) -> None: - if "pyarrow_table" in str(constructor) or "dask" in str(constructor): + if "dask" in str(constructor): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data))