Skip to content

Commit

Permalink
remaining methods
Browse files Browse the repository at this point in the history
  • Loading branch information
FBruzzesi committed Sep 16, 2024
1 parent 994bb07 commit 47be886
Show file tree
Hide file tree
Showing 10 changed files with 131 additions and 78 deletions.
11 changes: 8 additions & 3 deletions narwhals/_dask/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,12 +139,17 @@ def select(

col_order = list(new_series.keys())

left_most_series = next( # pragma: no cover
s for s in new_series.values() if not isinstance(s, de._collection.Scalar)
left_most_name, left_most_series = next( # pragma: no cover
(name, s)
for name, s in new_series.items()
if not isinstance(s, de._collection.Scalar)
)
new_series.pop(left_most_name)

return self._from_native_frame(
left_most_series.to_frame().assign(**new_series).loc[:, col_order]
left_most_series.to_frame(name=left_most_name)
.assign(**new_series)
.loc[:, col_order]
)

def drop_nulls(self: Self, subset: str | list[str] | None) -> Self:
Expand Down
111 changes: 77 additions & 34 deletions narwhals/_dask/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,31 +688,55 @@ def func(df: DaskLazyFrame) -> list[Any]:
backend_version=self._backend_version,
)

def mode(self: Self) -> Self:
msg = "`Expr.mode` is not supported for the Dask backend."
raise NotImplementedError(msg)
def cast(
    self: Self,
    dtype: DType | type[DType],
) -> Self:
    """Cast the expression to ``dtype``.

    The requested dtype is converted with ``reverse_translate_dtype`` and the
    result is handed to the input's ``astype``.

    Args:
        dtype: Target dtype, given either as a ``DType`` instance or class.

    Returns:
        The expression produced by ``_from_call`` under the name ``"cast"``,
        flagged as neither returning a scalar nor modifying the index.
    """

    def func(_input: Any, dtype: DType | type[DType]) -> Any:
        # Use a separate local so the `dtype` argument is not rebound to a
        # value of a different kind.
        native_dtype = reverse_translate_dtype(dtype)
        return _input.astype(native_dtype)

    return self._from_call(
        func,
        "cast",
        dtype,
        returns_scalar=False,
        modifies_index=False,
    )

# Index modifiers

def sort(self: Self, *, descending: bool = False, nulls_last: bool = False) -> Self:
na_position = "last" if nulls_last else "first"
msg = "`Expr.sort` is not supported for the Dask backend. Please use `LazyFrame.sort` instead."
raise NotImplementedError(msg)

def func(_input: Any, ascending: bool, na_position: bool) -> Any: # noqa: FBT001
name = _input.name
def gather_every(self: Self, n: int, offset: int = 0) -> NoReturn:
    """Reject ``Expr.gather_every`` for the Dask backend.

    Args:
        n: Unused; accepted for signature compatibility.
        offset: Unused; accepted for signature compatibility.

    Raises:
        NotImplementedError: Always. The message points callers to
            ``LazyFrame.gather_every``.
    """
    msg = "`Expr.gather_every` is not supported for the Dask backend. Please use `LazyFrame.gather_every` instead."
    raise NotImplementedError(msg)

def sample(
    self: Self,
    n: int | None = None,
    *,
    fraction: float | None = None,
    with_replacement: bool = False,
    seed: int | None = None,
) -> NoReturn:
    """Reject ``Expr.sample`` for the Dask backend.

    All arguments are accepted for signature compatibility and ignored.

    Raises:
        NotImplementedError: Always.
    """
    msg = "`Expr.sample` is not supported for the Dask backend."
    raise NotImplementedError(msg)

return _input.to_frame(name=name).sort_values(
by=name, ascending=ascending, na_position=na_position
)[name]
def mode(self: Self) -> Self:
def func(_input: Any) -> Any:
name = _input.name
return _input.to_frame(name=name).mode()[name]

return self._from_call(
func,
"sort",
not descending,
na_position,
"mode",
returns_scalar=False,
modifies_index=True,
)

# Index modifiers

def drop_nulls(self: Self) -> Self:
return self._from_call(
lambda _input: _input.dropna(),
Expand Down Expand Up @@ -753,10 +777,45 @@ def unique(self: Self) -> Self:
modifies_index=True,
)

def gather_every(self: Self, n: int, offset: int = 0) -> NoReturn:
# We can't (yet?) allow methods which modify the index
msg = "`Expr.gather_every` is not supported for the Dask backend. Please use `LazyFrame.gather_every` instead."
raise NotImplementedError(msg)
def filter(self: Self, *predicates: Any) -> Self:
    """Restrict this expression's output columns to rows matching all predicates.

    The predicates are combined with the namespace's ``all_horizontal``; the
    first result of evaluating that combined expression is used as a mask and
    applied with ``.loc`` to every output column.

    Args:
        *predicates: Conditions forwarded to ``all_horizontal``.

    Returns:
        A new expression of the same class, one level deeper, whose function
        name has ``"->filter"`` appended and which is flagged as modifying the
        index.

    Raises:
        ValueError: When the returned expression is evaluated and
            ``self._output_names`` is ``None`` (anonymous expression).
    """
    plx = self.__narwhals_namespace__()
    expr = plx.all_horizontal(*predicates)

    def func(df: DaskLazyFrame) -> list[Any]:
        # Output names are needed to pick the columns to mask, so anonymous
        # expressions cannot be handled here.
        if self._output_names is None:
            msg = (
                "Anonymous expressions are not supported in filter.\n"
                "Instead of `nw.all()`, try using a named expression, such as "
                "`nw.col('a', 'b')`\n"
            )
            raise ValueError(msg)
        mask = expr._call(df)[0]
        return [df._native_frame[name].loc[mask] for name in self._output_names]

    return self.__class__(
        func,
        depth=self._depth + 1,
        function_name=self._function_name + "->filter",
        root_names=self._root_names,
        output_names=self._output_names,
        returns_scalar=False,
        modifies_index=True,
        backend_version=self._backend_version,
    )

def arg_true(self: Self) -> Self:
    """Select row-index values at the positions where the expression is true.

    Returns:
        The expression produced by ``_from_call`` under the name
        ``"arg_true"``, flagged as modifying the index.
    """

    def func(_input: dask_expr.Series) -> dask_expr.Series:
        name = _input.name
        # NOTE(review): presumably `add_row_index` writes a row index into the
        # column called `name` - confirm against the helper. The input series
        # itself is then used as the row mask for `.loc`.
        return add_row_index(_input.to_frame(name=name), name).loc[_input, name]

    return self._from_call(
        func,
        "arg_true",
        returns_scalar=False,
        modifies_index=True,
    )

# Namespaces

@property
def str(self: Self) -> DaskExprStringNamespace:
Expand All @@ -770,22 +829,6 @@ def dt(self: Self) -> DaskExprDateTimeNamespace:
def name(self: Self) -> DaskExprNameNamespace:
return DaskExprNameNamespace(self)

def cast(
self: Self,
dtype: DType | type[DType],
) -> Self:
def func(_input: Any, dtype: DType | type[DType]) -> Any:
dtype = reverse_translate_dtype(dtype)
return _input.astype(dtype)

return self._from_call(
func,
"cast",
dtype,
returns_scalar=False,
modifies_index=False,
)


class DaskExprStringNamespace:
def __init__(self, expr: DaskExpr) -> None:
Expand Down
6 changes: 1 addition & 5 deletions tests/expr_and_series/arg_true_test.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
from typing import Any

import pytest

import narwhals.stable.v1 as nw
from tests.utils import Constructor
from tests.utils import compare_dicts


def test_arg_true(constructor: Constructor, request: pytest.FixtureRequest) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_arg_true(constructor: Constructor) -> None:
df = nw.from_native(constructor({"a": [1, None, None, 3]}))
result = df.select(nw.col("a").is_null().arg_true())
expected = {"a": [1, 2]}
Expand Down
7 changes: 4 additions & 3 deletions tests/expr_and_series/cat/get_categories_test.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
from __future__ import annotations

from typing import Any

import pyarrow as pa
import pytest

import narwhals.stable.v1 as nw
from narwhals.utils import parse_version
from tests.utils import Constructor
from tests.utils import compare_dicts

data = {"a": ["one", "two", "two"]}


def test_get_categories(request: pytest.FixtureRequest, constructor_eager: Any) -> None:
def test_get_categories(
request: pytest.FixtureRequest, constructor_eager: Constructor
) -> None:
if "pyarrow_table" in str(constructor_eager) and parse_version(
pa.__version__
) < parse_version("15.0.0"):
Expand Down
15 changes: 12 additions & 3 deletions tests/expr_and_series/filter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,24 @@
}


def test_filter(constructor: Constructor, request: pytest.FixtureRequest) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_filter_single_expr(constructor: Constructor) -> None:
df = nw.from_native(constructor(data))
result = df.select(nw.col("a").filter(nw.col("i") < 2, nw.col("c") == 5))
expected = {"a": [0]}
compare_dicts(result, expected)


def test_filter_multi_expr(
    constructor: Constructor, request: pytest.FixtureRequest
) -> None:
    """Select a filtered column alongside an unfiltered one.

    Marked as an expected failure when the constructor's string form
    contains ``"dask"``.
    """
    if "dask" in str(constructor):
        request.applymarker(pytest.mark.xfail)
    df = nw.from_native(constructor(data))
    result = df.select(nw.col("a").filter(nw.col("i") < 2, nw.col("c") == 5), nw.col("b"))
    expected = {"a": [0] * 5, "b": [1, 2, 3, 5, 3]}
    compare_dicts(result, expected)


def test_filter_series(constructor_eager: Any) -> None:
df = nw.from_native(constructor_eager(data), eager_only=True)
result = df.select(df["a"].filter((df["i"] < 2) & (df["c"] == 5)))
Expand Down
6 changes: 1 addition & 5 deletions tests/expr_and_series/len_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from typing import Any

import pytest

import narwhals.stable.v1 as nw
from tests.utils import Constructor
from tests.utils import compare_dicts
Expand All @@ -18,11 +16,9 @@ def test_len_no_filter(constructor: Constructor) -> None:
compare_dicts(df, expected)


def test_len_chaining(constructor: Constructor, request: pytest.FixtureRequest) -> None:
def test_len_chaining(constructor: Constructor) -> None:
data = {"a": list("xyz"), "b": [1, 2, 1]}
expected = {"a1": [2], "a2": [1]}
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)
df = nw.from_native(constructor(data)).select(
nw.col("a").filter(nw.col("b") == 1).len().alias("a1"),
nw.col("a").filter(nw.col("b") == 2).len().alias("a2"),
Expand Down
7 changes: 1 addition & 6 deletions tests/expr_and_series/mode_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,7 @@
}


def test_mode_single_expr(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)

def test_mode_single_expr(constructor: Constructor) -> None:
df = nw.from_native(constructor(data))
result = df.select(nw.col("a").mode()).sort("a")
expected = {"a": [1, 2]}
Expand Down
34 changes: 23 additions & 11 deletions tests/expr_and_series/sort_test.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import Any
from __future__ import annotations

import pytest

import narwhals.stable.v1 as nw
from tests.utils import Constructor
from tests.utils import compare_dicts

data = {"a": [0, 0, 2, -1], "b": [1, 3, 2, None]}
Expand All @@ -18,8 +19,14 @@
],
)
def test_sort_single_expr(
constructor: Any, descending: Any, nulls_last: Any, expected: Any
constructor: Constructor,
descending: bool, # noqa: FBT001
nulls_last: bool, # noqa: FBT001
expected: dict[str, float],
request: pytest.FixtureRequest,
) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)
df = nw.from_native(constructor(data))
result = df.select(nw.col("b").sort(descending=descending, nulls_last=nulls_last))
compare_dicts(result, expected)
Expand All @@ -35,7 +42,11 @@ def test_sort_single_expr(
],
)
def test_sort_multiple_expr(
constructor: Any, descending: Any, nulls_last: Any, expected: Any, request: Any
constructor: Constructor,
descending: bool, # noqa: FBT001
nulls_last: bool, # noqa: FBT001
expected: dict[str, float],
request: pytest.FixtureRequest,
) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)
Expand All @@ -51,17 +62,18 @@ def test_sort_multiple_expr(
@pytest.mark.parametrize(
("descending", "nulls_last", "expected"),
[
(True, True, [3, 2, 1, None]),
(True, False, [None, 3, 2, 1]),
(False, True, [1, 2, 3, None]),
(False, False, [None, 1, 2, 3]),
(True, True, [3, 2, 1, float("nan")]),
(True, False, [float("nan"), 3, 2, 1]),
(False, True, [1, 2, 3, float("nan")]),
(False, False, [float("nan"), 1, 2, 3]),
],
)
def test_sort_series(
constructor_eager: Any, descending: Any, nulls_last: Any, expected: Any
constructor_eager: Constructor,
descending: bool, # noqa: FBT001
nulls_last: bool, # noqa: FBT001
expected: dict[str, float],
) -> None:
series = nw.from_native(constructor_eager(data), eager_only=True)["b"]
result = series.sort(descending=descending, nulls_last=nulls_last)
assert (
result == nw.from_native(constructor_eager({"a": expected}), eager_only=True)["a"]
)
compare_dicts({"b": result}, {"b": expected})
6 changes: 2 additions & 4 deletions tpch/queries/q20.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,15 @@ def query(

return (
part_ds.filter(nw.col("p_name").str.starts_with(var4))
.select("p_partkey")
.unique("p_partkey")
.select(nw.col("p_partkey").unique())
.join(partsupp_ds, left_on="p_partkey", right_on="ps_partkey")
.join(
query1,
left_on=["ps_suppkey", "p_partkey"],
right_on=["l_suppkey", "l_partkey"],
)
.filter(nw.col("ps_availqty") > nw.col("sum_quantity"))
.select("ps_suppkey")
.unique("ps_suppkey")
.select(nw.col("ps_suppkey").unique())
.join(query3, left_on="ps_suppkey", right_on="s_suppkey")
.select("s_name", "s_address")
.sort("s_name")
Expand Down
6 changes: 2 additions & 4 deletions tpch/queries/q22.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,8 @@ def query(customer_ds: FrameT, orders_ds: FrameT) -> FrameT:
nw.col("c_acctbal").mean().alias("avg_acctbal")
)

q3 = (
orders_ds.select("o_custkey")
.unique("o_custkey")
.with_columns(nw.col("o_custkey").alias("c_custkey"))
q3 = orders_ds.select(nw.col("o_custkey").unique()).with_columns(
nw.col("o_custkey").alias("c_custkey")
)

return (
Expand Down

0 comments on commit 47be886

Please sign in to comment.