Skip to content

Commit

Permalink
feat: implement when/then/otherwise for PyArrow (#859)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Aug 24, 2024
1 parent 83dd6e1 commit 047de73
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 24 deletions.
121 changes: 121 additions & 0 deletions narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import TYPE_CHECKING
from typing import Any
from typing import Iterable
from typing import cast

from narwhals import dtypes
from narwhals._arrow.dataframe import ArrowDataFrame
Expand Down Expand Up @@ -234,3 +235,123 @@ def min(self, *column_names: str) -> ArrowExpr:
@property
def selectors(self) -> ArrowSelectorNamespace:
    """Return a selector namespace built with this object's backend version."""
    version = self._backend_version
    return ArrowSelectorNamespace(backend_version=version)

def when(
    self,
    *predicates: IntoArrowExpr,
) -> ArrowWhen:
    """Start a when/then/otherwise chain from one or more predicates.

    All predicates are folded into a single condition with
    ``all_horizontal`` and wrapped in an ``ArrowWhen``.

    Raises:
        TypeError: if no predicate is supplied.
    """
    namespace = self.__class__(backend_version=self._backend_version)
    if not predicates:
        msg = "at least one predicate needs to be provided"
        raise TypeError(msg)

    combined = namespace.all_horizontal(*predicates)
    return ArrowWhen(combined, self._backend_version)


class ArrowWhen:
    """Callable that evaluates a ``when(...).then(...).otherwise(...)`` chain.

    The instance stores the condition together with the optional "then" and
    "otherwise" values. Calling it with a dataframe evaluates them against
    that frame and returns the selected values as a one-element list.
    """

    def __init__(
        self,
        condition: ArrowExpr,
        backend_version: tuple[int, ...],
        then_value: Any = None,
        otherwise_value: Any = None,
    ) -> None:
        """Store the condition, backend version and optional branch values."""
        self._backend_version = backend_version
        self._condition = condition
        self._then_value = then_value
        self._otherwise_value = otherwise_value

    def __call__(self, df: ArrowDataFrame) -> list[ArrowSeries]:
        """Evaluate the chain against ``df``.

        Args:
            df: frame against which the condition and branch values are
                evaluated.

        Returns:
            A one-element list holding the series that takes the "then" value
            where the condition is true and the "otherwise" value (or null,
            when no "otherwise" was given) elsewhere.
        """
        import pyarrow as pa  # ignore-banned-import
        import pyarrow.compute as pc  # ignore-banned-import

        from narwhals._arrow.namespace import ArrowNamespace
        from narwhals._expression_parsing import parse_into_expr

        plx = ArrowNamespace(backend_version=self._backend_version)

        condition = parse_into_expr(self._condition, namespace=plx)._call(df)[0]  # type: ignore[arg-type]
        # NOTE(review): the `try` also covers `._call(df)`, so a TypeError
        # raised while evaluating a genuine expression is handled as if the
        # value were a scalar -- confirm that this is intended.
        try:
            value_series = parse_into_expr(self._then_value, namespace=plx)._call(df)[0]  # type: ignore[arg-type]
        except TypeError:
            # `self._then_value` is a scalar and can't be converted to an
            # expression: repeat it once per row of the condition instead.
            value_series = condition.__class__._from_iterable(  # type: ignore[call-arg]
                [self._then_value] * len(condition),
                name="literal",
                backend_version=self._backend_version,
            )
        value_series = cast(ArrowSeries, value_series)

        value_series_native = value_series._native_series
        # `replace_with_mask` overwrites the positions where the mask is true,
        # so the condition is inverted: rows where the condition does NOT hold
        # are the ones that receive the "otherwise" value.
        condition_native = pc.invert(condition._native_series.combine_chunks())

        if self._otherwise_value is None:
            # No "otherwise" branch: fill the non-matching rows with nulls of
            # the same type as the "then" values.
            otherwise_native = pa.nulls(
                len(condition_native), type=value_series_native.type
            )
            return [
                value_series._from_native_series(
                    pc.replace_with_mask(
                        value_series_native, condition_native, otherwise_native
                    )
                )
            ]
        try:
            otherwise_series = parse_into_expr(
                self._otherwise_value, namespace=plx
            )._call(df)[0]  # type: ignore[arg-type]
        except TypeError:
            # `self._otherwise_value` is a scalar and can't be converted to an
            # expression: hand it to `replace_with_mask` directly.
            return [
                value_series._from_native_series(
                    pc.replace_with_mask(
                        value_series_native, condition_native, self._otherwise_value
                    )
                )
            ]
        else:
            otherwise_series = cast(ArrowSeries, otherwise_series)
            condition = cast(ArrowSeries, condition)
            return [value_series.zip_with(condition, otherwise_series)]

    def then(self, value: ArrowExpr | ArrowSeries | Any) -> ArrowThen:
        """Record the "then" value on this instance and wrap it in an ``ArrowThen``."""
        self._then_value = value

        return ArrowThen(
            self,
            depth=0,
            function_name="whenthen",
            root_names=None,
            output_names=None,
            backend_version=self._backend_version,
        )


class ArrowThen(ArrowExpr):
    """Expression returned by ``ArrowWhen.then``; adds an ``otherwise`` step."""

    def __init__(
        self,
        call: ArrowWhen,
        *,
        depth: int,
        function_name: str,
        root_names: list[str] | None,
        output_names: list[str] | None,
        backend_version: tuple[int, ...],
    ) -> None:
        """Store the ``ArrowWhen`` callable and the expression metadata.

        The attributes are assigned directly on the instance; the base-class
        initializer is not invoked here.
        """
        self._call = call
        self._depth = depth
        self._function_name = function_name
        self._root_names = root_names
        self._output_names = output_names

        self._backend_version = backend_version

    def otherwise(self, value: ArrowExpr | ArrowSeries | Any) -> ArrowExpr:
        """Set the fallback value on the wrapped ``ArrowWhen`` and return ``self``."""
        self._function_name = "whenotherwise"
        # `_call` holds the `ArrowWhen` instance passed to `__init__`; the
        # type-ignore is needed because the base class declares the attribute
        # as only a `Callable`, which has no `_otherwise_value`.
        self._call._otherwise_value = value  # type: ignore[attr-defined]
        return self
7 changes: 4 additions & 3 deletions narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,11 +490,12 @@ def value_counts(
def zip_with(self: Self, mask: Self, other: Self) -> Self:
    """Keep values of ``self`` where ``mask`` is true, else take them from ``other``.

    Args:
        mask: boolean series; true marks the positions where ``self`` is kept.
        other: series supplying the values for the remaining positions.

    Returns:
        A new series wrapping the combined native array.
    """
    import pyarrow.compute as pc  # ignore-banned-import()

    # `replace_with_mask` overwrites the positions where its mask is true, so
    # invert the caller's mask to mark the positions that come from `other`.
    replace_mask = pc.invert(mask._native_series.combine_chunks())
    # Only the `other` values at the positions being overwritten are passed on
    # as replacements.
    replacements = other._native_series.combine_chunks().filter(replace_mask)
    return self._from_native_series(
        pc.replace_with_mask(
            self._native_series,
            replace_mask,
            replacements,
        )
    )

Expand Down
40 changes: 19 additions & 21 deletions tests/expr_and_series/when_test.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from __future__ import annotations

import sys
from typing import Any

import numpy as np
import pytest

import narwhals.stable.v1 as nw
from tests.utils import compare_dicts
from tests.utils import is_windows

data = {
"a": [1, 2, 3],
Expand All @@ -18,7 +20,7 @@


def test_when(request: Any, constructor: Any) -> None:
if "pyarrow_table" in str(constructor) or "dask" in str(constructor):
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)

df = nw.from_native(constructor(data))
Expand All @@ -30,7 +32,7 @@ def test_when(request: Any, constructor: Any) -> None:


def test_when_otherwise(request: Any, constructor: Any) -> None:
if "pyarrow_table" in str(constructor) or "dask" in str(constructor):
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)

df = nw.from_native(constructor(data))
Expand All @@ -42,7 +44,7 @@ def test_when_otherwise(request: Any, constructor: Any) -> None:


def test_multiple_conditions(request: Any, constructor: Any) -> None:
if "pyarrow_table" in str(constructor) or "dask" in str(constructor):
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)

df = nw.from_native(constructor(data))
Expand All @@ -56,7 +58,7 @@ def test_multiple_conditions(request: Any, constructor: Any) -> None:


def test_no_arg_when_fail(request: Any, constructor: Any) -> None:
if "pyarrow_table" in str(constructor) or "dask" in str(constructor):
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)

df = nw.from_native(constructor(data))
Expand All @@ -65,7 +67,7 @@ def test_no_arg_when_fail(request: Any, constructor: Any) -> None:


def test_value_numpy_array(request: Any, constructor: Any) -> None:
if "pyarrow_table" in str(constructor) or "dask" in str(constructor):
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)

df = nw.from_native(constructor(data))
Expand All @@ -80,10 +82,7 @@ def test_value_numpy_array(request: Any, constructor: Any) -> None:
compare_dicts(result, expected)


def test_value_series(request: Any, constructor_eager: Any) -> None:
if "pyarrow_table" in str(constructor_eager):
request.applymarker(pytest.mark.xfail)

def test_value_series(constructor_eager: Any) -> None:
df = nw.from_native(constructor_eager(data))
s_data = {"s": [3, 4, 5]}
s = nw.from_native(constructor_eager(s_data))["s"]
Expand All @@ -96,7 +95,7 @@ def test_value_series(request: Any, constructor_eager: Any) -> None:


def test_value_expression(request: Any, constructor: Any) -> None:
if "pyarrow_table" in str(constructor) or "dask" in str(constructor):
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)

df = nw.from_native(constructor(data))
Expand All @@ -108,28 +107,27 @@ def test_value_expression(request: Any, constructor: Any) -> None:


def test_otherwise_numpy_array(request: Any, constructor: Any) -> None:
    """``otherwise`` accepts a NumPy array as the fallback values."""
    if "dask" in str(constructor):
        request.applymarker(pytest.mark.xfail)
    if (
        "pyarrow_table" in str(constructor) and is_windows() and sys.version_info < (3, 9)
    ):  # pragma: no cover
        # Expected to fail for the PyArrow constructor on Windows with
        # Python < 3.9; the underlying reason is not recorded here.
        request.applymarker(pytest.mark.xfail)

    df = nw.from_native(constructor(data))

    result = df.select(
        nw.when(nw.col("a") == 1).then(-1).otherwise(np.array([0, 9, 10])).alias("a_when")
    )
    expected = {
        "a_when": [-1, 9, 10],
    }
    compare_dicts(result, expected)


def test_otherwise_series(request: Any, constructor_eager: Any) -> None:
if "pyarrow_table" in str(constructor_eager):
request.applymarker(pytest.mark.xfail)

def test_otherwise_series(constructor_eager: Any) -> None:
df = nw.from_native(constructor_eager(data))
s_data = {"s": [0, 9, 10]}
s = nw.from_native(constructor_eager(s_data))["s"]
Expand All @@ -142,7 +140,7 @@ def test_otherwise_series(request: Any, constructor_eager: Any) -> None:


def test_otherwise_expression(request: Any, constructor: Any) -> None:
if "pyarrow_table" in str(constructor) or "dask" in str(constructor):
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)

df = nw.from_native(constructor(data))
Expand All @@ -156,7 +154,7 @@ def test_otherwise_expression(request: Any, constructor: Any) -> None:


def test_when_then_otherwise_into_expr(request: Any, constructor: Any) -> None:
if "pyarrow_table" in str(constructor) or "dask" in str(constructor):
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)

df = nw.from_native(constructor(data))
Expand Down

0 comments on commit 047de73

Please sign in to comment.