Skip to content

Commit

Permalink
feat: Support more indexing: boolean lists to DataFrame.filter and Series.filter, add DataFrame.row (#847)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Aug 23, 2024
1 parent 037e047 commit a1e90db
Show file tree
Hide file tree
Showing 15 changed files with 139 additions and 39 deletions.
1 change: 1 addition & 0 deletions docs/api-reference/dataframe.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
- null_count
- pipe
- rename
- row
- rows
- schema
- select
Expand Down
4 changes: 2 additions & 2 deletions narwhals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@
from narwhals.dtypes import UInt64
from narwhals.dtypes import Unknown
from narwhals.expr import Expr
from narwhals.expr import all
from narwhals.expr import all_ as all
from narwhals.expr import all_horizontal
from narwhals.expr import any_horizontal
from narwhals.expr import col
from narwhals.expr import len
from narwhals.expr import len_ as len
from narwhals.expr import lit
from narwhals.expr import max
from narwhals.expr import mean
Expand Down
20 changes: 15 additions & 5 deletions narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ def shape(self) -> tuple[int, int]:
def __len__(self) -> int:
return len(self._native_frame)

def row(self, index: int) -> tuple[Any, ...]:
    """Return the values found at position ``index`` of every column.

    Each item yielded by iterating ``self._native_frame`` is indexed with
    ``index``; the resulting elements are returned as-is (no conversion is
    performed here).

    Arguments:
        index: Row number.
    """
    native_frame = self._native_frame
    return tuple(column[index] for column in native_frame)

def rows(
self, *, named: bool = False
) -> list[tuple[Any, ...]] | list[dict[str, Any]]:
Expand Down Expand Up @@ -377,11 +380,18 @@ def filter(
self,
*predicates: IntoArrowExpr,
) -> Self:
plx = self.__narwhals_namespace__()
expr = plx.all_horizontal(*predicates)
# Safety: all_horizontal's expression only returns a single column.
mask = expr._call(self)[0]
return self._from_native_frame(self._native_frame.filter(mask._native_series))
if (
len(predicates) == 1
and isinstance(predicates[0], list)
and all(isinstance(x, bool) for x in predicates[0])
):
mask = predicates[0]
else:
plx = self.__narwhals_namespace__()
expr = plx.all_horizontal(*predicates)
# Safety: all_horizontal's expression only returns a single column.
mask = expr._call(self)[0]._native_series
return self._from_native_frame(self._native_frame.filter(mask))

def null_count(self) -> Self:
import pyarrow as pa # ignore-banned-import()
Expand Down
3 changes: 2 additions & 1 deletion narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,8 @@ def len(self) -> int:
return len(self._native_series)

def filter(self, other: Any) -> Self:
other = validate_column_comparand(other)
if not (isinstance(other, list) and all(isinstance(x, bool) for x in other)):
other = validate_column_comparand(other)
return self._from_native_series(self._native_series.filter(other))

def mean(self) -> int:
Expand Down
17 changes: 12 additions & 5 deletions narwhals/_dask/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,19 @@ def filter(
self,
*predicates: DaskExpr,
) -> Self:
from narwhals._dask.namespace import DaskNamespace
if (
len(predicates) == 1
and isinstance(predicates[0], list)
and all(isinstance(x, bool) for x in predicates[0])
):
mask = predicates[0]
else:
from narwhals._dask.namespace import DaskNamespace

plx = DaskNamespace(backend_version=self._backend_version)
expr = plx.all_horizontal(*predicates)
# Safety: all_horizontal's expression only returns a single column.
mask = expr._call(self)[0]
plx = DaskNamespace(backend_version=self._backend_version)
expr = plx.all_horizontal(*predicates)
# Safety: all_horizontal's expression only returns a single column.
mask = expr._call(self)[0]
return self._from_native_frame(self._native_frame.loc[mask])

def lazy(self) -> Self:
Expand Down
18 changes: 14 additions & 4 deletions narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,15 +283,25 @@ def with_row_index(self, name: str) -> Self:
)
)

def row(self, row: int) -> tuple[Any, ...]:
    """Return the values at integer position ``row`` as a tuple.

    The row is looked up positionally via ``.iloc`` on the native frame.

    Arguments:
        row: Row number.
    """
    selected = self._native_frame.iloc[row]
    return tuple(selected)

def filter(
self,
*predicates: IntoPandasLikeExpr,
) -> Self:
plx = self.__narwhals_namespace__()
expr = plx.all_horizontal(*predicates)
# Safety: all_horizontal's expression only returns a single column.
mask = expr._call(self)[0]
_mask = validate_dataframe_comparand(self._native_frame.index, mask)
if (
len(predicates) == 1
and isinstance(predicates[0], list)
and all(isinstance(x, bool) for x in predicates[0])
):
_mask = predicates[0]
else:
expr = plx.all_horizontal(*predicates)
# Safety: all_horizontal's expression only returns a single column.
mask = expr._call(self)[0]
_mask = validate_dataframe_comparand(self._native_frame.index, mask)
return self._from_native_frame(self._native_frame.loc[_mask])

def with_columns(
Expand Down
4 changes: 2 additions & 2 deletions narwhals/_pandas_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,8 @@ def arg_true(self) -> Self:

def filter(self, *predicates: Any) -> Self:
plx = self.__narwhals_namespace__()
expr = plx.all_horizontal(*predicates)
return reuse_series_implementation(self, "filter", other=expr)
other = plx.all_horizontal(*predicates)
return reuse_series_implementation(self, "filter", other=other)

def drop_nulls(self) -> Self:
return reuse_series_implementation(self, "drop_nulls")
Expand Down
3 changes: 2 additions & 1 deletion narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,8 @@ def arg_true(self) -> PandasLikeSeries:

def filter(self, other: Any) -> PandasLikeSeries:
ser = self._native_series
other = validate_column_comparand(self._native_series.index, other)
if not (isinstance(other, list) and all(isinstance(x, bool) for x in other)):
other = validate_column_comparand(self._native_series.index, other)
return self._from_native_series(self._rename(ser.loc[other], ser.name))

def __eq__(self, other: object) -> PandasLikeSeries: # type: ignore[override]
Expand Down
55 changes: 49 additions & 6 deletions narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,13 @@ def unique(
)
)

def filter(self, *predicates: IntoExpr | Iterable[IntoExpr]) -> Self:
predicates, _ = self._flatten_and_extract(*predicates)
def filter(self, *predicates: IntoExpr | Iterable[IntoExpr] | list[bool]) -> Self:
if not (
len(predicates) == 1
and isinstance(predicates[0], list)
and all(isinstance(x, bool) for x in predicates[0])
):
predicates, _ = self._flatten_and_extract(*predicates)
return self._from_compliant_dataframe(
self._compliant_frame.filter(*predicates),
)
Expand Down Expand Up @@ -599,6 +604,8 @@ def __getitem__(
2
]
"""
if isinstance(item, int):
item = [item]
if (
isinstance(item, tuple)
and len(item) == 2
Expand Down Expand Up @@ -693,6 +700,40 @@ def to_dict(
}
return self._compliant_frame.to_dict(as_series=as_series) # type: ignore[no-any-return]

def row(self, index: int) -> tuple[Any, ...]:
    """
    Return the values stored at a single row position.

    !!!note
        You should NEVER use this method to iterate over a DataFrame;
        if you require row-iteration you should strongly prefer use of
        iter_rows() instead.

    Arguments:
        index: Row number.

    Returns:
        A tuple of the row's values, as produced by the underlying
        compliant frame's `row` method.

    Examples:
        >>> import narwhals as nw
        >>> import pandas as pd
        >>> import polars as pl
        >>> data = {"a": [1, 2, 3], "b": [4, 5, 6]}
        >>> df_pd = pd.DataFrame(data)
        >>> df_pl = pl.DataFrame(data)

        Let's define a library-agnostic function to get the second row.

        >>> @nw.narwhalify
        ... def func(df):
        ...     return df.row(1)

        We can then pass pandas / Polars / any other supported library:

        >>> func(df_pd)
        (2, 5)
        >>> func(df_pl)
        (2, 5)
    """
    compliant_frame = self._compliant_frame
    return compliant_frame.row(index)  # type: ignore[no-any-return]

# inherited
def pipe(self, function: Callable[[Any], Self], *args: Any, **kwargs: Any) -> Self:
"""
Expand Down Expand Up @@ -1483,14 +1524,15 @@ def unique(
"""
return super().unique(subset, keep=keep, maintain_order=maintain_order)

def filter(self, *predicates: IntoExpr | Iterable[IntoExpr]) -> Self:
def filter(self, *predicates: IntoExpr | Iterable[IntoExpr] | list[bool]) -> Self:
r"""
Filter the rows in the DataFrame based on one or more predicate expressions.
The original order of the remaining rows is preserved.
Arguments:
predicates: Expression(s) that evaluates to a boolean Series.
*predicates: Expression(s) that evaluates to a boolean Series. Can
also be a (single!) boolean list.
Examples:
>>> import pandas as pd
Expand Down Expand Up @@ -2949,14 +2991,15 @@ def unique(
"""
return super().unique(subset, keep=keep, maintain_order=maintain_order)

def filter(self, *predicates: IntoExpr | Iterable[IntoExpr]) -> Self:
def filter(self, *predicates: IntoExpr | Iterable[IntoExpr] | list[bool]) -> Self:
r"""
Filter the rows in the LazyFrame based on a predicate expression.
The original order of the remaining rows is preserved.
Arguments:
*predicates: Expression that evaluates to a boolean Series.
*predicates: Expression that evaluates to a boolean Series. Can
also be a (single!) boolean list.
Examples:
>>> import pandas as pd
Expand Down
6 changes: 4 additions & 2 deletions narwhals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3657,7 +3657,8 @@ def func(plx: Any) -> Any:
return Expr(func)


def all() -> Expr:
# Add underscore so it doesn't conflict with builtin `all`
def all_() -> Expr:
"""
Instantiate an expression representing all columns.
Expand Down Expand Up @@ -3696,7 +3697,8 @@ def all() -> Expr:
return Expr(lambda plx: plx.all())


def len() -> Expr:
# Add underscore so it doesn't conflict with builtin `len`
def len_() -> Expr:
"""
Return the number of rows.
Expand Down
3 changes: 3 additions & 0 deletions tests/expr_and_series/filter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,6 @@ def test_filter_series(constructor_eager: Any) -> None:
result = df.select(df["a"].filter((df["i"] < 2) & (df["c"] == 5)))
expected = {"a": [0]}
compare_dicts(result, expected)
result_s = df["a"].filter([True, False, False, False, False])
expected = {"a": [0]}
compare_dicts({"a": result_s}, expected)
8 changes: 3 additions & 5 deletions tests/frame/filter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,9 @@ def test_filter(constructor: Any) -> None:
compare_dicts(result, expected)


def test_filter_series(constructor_eager: Any) -> None:
def test_filter_with_boolean_list(constructor: Any) -> None:
data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}
df = nw.from_native(constructor_eager(data), eager_only=True).with_columns(
mask=nw.col("a") > 1
)
result = df.filter(df["mask"]).drop("mask")
df = nw.from_native(constructor(data))
result = df.filter([False, True, True])
expected = {"a": [3, 2], "b": [4, 6], "z": [8.0, 9.0]}
compare_dicts(result, expected)
9 changes: 6 additions & 3 deletions tests/frame/get_column_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ def test_non_string_name() -> None:
result = nw.from_native(df, eager_only=True).get_column(0) # type: ignore[arg-type]
compare_dicts({"a": result}, {"a": [1, 2]})
assert result.name == 0 # type: ignore[comparison-overlap]
with pytest.raises(TypeError, match="Expected str or slice"):
# Check that getitem would have raised
nw.from_native(df, eager_only=True)[0] # type: ignore[call-overload]


def test_get_single_row() -> None:
    """Integer indexing on an eager frame selects just that row."""
    native = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
    frame = nw.from_native(native, eager_only=True)
    first_row = frame[0]  # type: ignore[call-overload]
    compare_dicts(first_row, {"a": [1], "b": [3]})
14 changes: 14 additions & 0 deletions tests/frame/row_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from typing import Any

import narwhals.stable.v1 as nw


def test_row_column(constructor_eager: Any) -> None:
    """``DataFrame.row`` returns the values at the requested position as a tuple."""
    columns = {
        "a": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
        "b": [11, 12, 13, 14, 15, 16],
    }
    frame = nw.from_native(constructor_eager(columns), eager_only=True)
    third_row = frame.row(2)
    # For the pyarrow table constructor the elements expose `.as_py()`;
    # unwrap them to plain Python values before comparing.
    if "pyarrow_table" in str(constructor_eager):
        third_row = tuple(value.as_py() for value in third_row)
    assert third_row == (3.0, 13)
13 changes: 10 additions & 3 deletions tests/frame/slice_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,16 @@ def test_slice_lazy_fails() -> None:
_ = nw.from_native(pl.LazyFrame(data))[1:]


def test_slice_int_fails(constructor_eager: Any) -> None:
with pytest.raises(TypeError, match="Expected str or slice, got: <class 'int'>"):
_ = nw.from_native(constructor_eager(data))[1] # type: ignore[call-overload,index]
def test_slice_int(constructor_eager: Any) -> None:
    """An integer index on an eager frame selects the row at that position."""
    frame = nw.from_native(constructor_eager(data), eager_only=True)
    second_row = frame[1]  # type: ignore[call-overload]
    compare_dicts(second_row, {"a": [2], "b": [12]})


def test_slice_fails(constructor_eager: Any) -> None:
    """Indexing with an unsupported object raises ``TypeError``."""

    class Foo: ...

    expected_message = "Expected str or slice, got:"
    with pytest.raises(TypeError, match=expected_message):
        nw.from_native(constructor_eager(data), eager_only=True)[Foo()]  # type: ignore[call-overload]


def test_gather(constructor_eager: Any) -> None:
Expand Down

0 comments on commit a1e90db

Please sign in to comment.