From a1e90db7c545f6cebcbf938708d248b3a4912f69 Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Fri, 23 Aug 2024 17:04:28 +0100 Subject: [PATCH] feat: Support more indexing: boolean lists to DataFrame.filter and Series.filter, add DataFrame.row, (#847) --- docs/api-reference/dataframe.md | 1 + narwhals/__init__.py | 4 +- narwhals/_arrow/dataframe.py | 20 +++++++--- narwhals/_arrow/series.py | 3 +- narwhals/_dask/dataframe.py | 17 ++++++--- narwhals/_pandas_like/dataframe.py | 18 +++++++-- narwhals/_pandas_like/expr.py | 4 +- narwhals/_pandas_like/series.py | 3 +- narwhals/dataframe.py | 55 +++++++++++++++++++++++++--- narwhals/expr.py | 6 ++- tests/expr_and_series/filter_test.py | 3 ++ tests/frame/filter_test.py | 8 ++-- tests/frame/get_column_test.py | 9 +++-- tests/frame/row_test.py | 14 +++++++ tests/frame/slice_test.py | 13 +++++-- 15 files changed, 139 insertions(+), 39 deletions(-) create mode 100644 tests/frame/row_test.py diff --git a/docs/api-reference/dataframe.md b/docs/api-reference/dataframe.md index 08ad0f7c8..c144b4af0 100644 --- a/docs/api-reference/dataframe.md +++ b/docs/api-reference/dataframe.md @@ -26,6 +26,7 @@ - null_count - pipe - rename + - row - rows - schema - select diff --git a/narwhals/__init__.py b/narwhals/__init__.py index 70345995f..716ffeb5f 100644 --- a/narwhals/__init__.py +++ b/narwhals/__init__.py @@ -22,11 +22,11 @@ from narwhals.dtypes import UInt64 from narwhals.dtypes import Unknown from narwhals.expr import Expr -from narwhals.expr import all +from narwhals.expr import all_ as all from narwhals.expr import all_horizontal from narwhals.expr import any_horizontal from narwhals.expr import col -from narwhals.expr import len +from narwhals.expr import len_ as len from narwhals.expr import lit from narwhals.expr import max from narwhals.expr import mean diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 65a155e04..20e507166 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -63,6 +63,9 @@ def shape(self) -> tuple[int, int]: def __len__(self) -> int: return len(self._native_frame) + def row(self, index: int) -> tuple[Any, ...]: + return tuple(col[index] for col in self._native_frame) + def rows( self, *, named: bool = False ) -> list[tuple[Any, ...]] | list[dict[str, Any]]: @@ -377,11 +380,18 @@ def filter( self, *predicates: IntoArrowExpr, ) -> Self: - plx = self.__narwhals_namespace__() - expr = plx.all_horizontal(*predicates) - # Safety: all_horizontal's expression only returns a single column. - mask = expr._call(self)[0] - return self._from_native_frame(self._native_frame.filter(mask._native_series)) + if ( + len(predicates) == 1 + and isinstance(predicates[0], list) + and all(isinstance(x, bool) for x in predicates[0]) + ): + mask = predicates[0] + else: + plx = self.__narwhals_namespace__() + expr = plx.all_horizontal(*predicates) + # Safety: all_horizontal's expression only returns a single column. + mask = expr._call(self)[0]._native_series + return self._from_native_frame(self._native_frame.filter(mask)) def null_count(self) -> Self: import pyarrow as pa # ignore-banned-import() diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index fb15f3aaf..39394924f 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -233,7 +233,8 @@ def len(self) -> int: return len(self._native_series) def filter(self, other: Any) -> Self: - other = validate_column_comparand(other) + if not (isinstance(other, list) and all(isinstance(x, bool) for x in other)): + other = validate_column_comparand(other) return self._from_native_series(self._native_series.filter(other)) def mean(self) -> int: diff --git a/narwhals/_dask/dataframe.py b/narwhals/_dask/dataframe.py index 35ef28bba..9774d6c8e 100644 --- a/narwhals/_dask/dataframe.py +++ b/narwhals/_dask/dataframe.py @@ -72,12 +72,19 @@ def filter( self, *predicates: DaskExpr, ) -> Self: - from narwhals._dask.namespace import DaskNamespace + if ( + len(predicates) == 1 + and isinstance(predicates[0], list) + and all(isinstance(x, bool) for x in predicates[0]) + ): + mask = predicates[0] + else: + from narwhals._dask.namespace import DaskNamespace - plx = DaskNamespace(backend_version=self._backend_version) - expr = plx.all_horizontal(*predicates) - # Safety: all_horizontal's expression only returns a single column. - mask = expr._call(self)[0] + plx = DaskNamespace(backend_version=self._backend_version) + expr = plx.all_horizontal(*predicates) + # Safety: all_horizontal's expression only returns a single column. + mask = expr._call(self)[0] return self._from_native_frame(self._native_frame.loc[mask]) def lazy(self) -> Self: diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index c6a51be2d..193955cbd 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -283,15 +283,25 @@ def with_row_index(self, name: str) -> Self: ) ) + def row(self, row: int) -> tuple[Any, ...]: + return tuple(x for x in self._native_frame.iloc[row]) + def filter( self, *predicates: IntoPandasLikeExpr, ) -> Self: plx = self.__narwhals_namespace__() - expr = plx.all_horizontal(*predicates) - # Safety: all_horizontal's expression only returns a single column. - mask = expr._call(self)[0] - _mask = validate_dataframe_comparand(self._native_frame.index, mask) + if ( + len(predicates) == 1 + and isinstance(predicates[0], list) + and all(isinstance(x, bool) for x in predicates[0]) + ): + _mask = predicates[0] + else: + expr = plx.all_horizontal(*predicates) + # Safety: all_horizontal's expression only returns a single column. + mask = expr._call(self)[0] + _mask = validate_dataframe_comparand(self._native_frame.index, mask) return self._from_native_frame(self._native_frame.loc[_mask]) def with_columns( diff --git a/narwhals/_pandas_like/expr.py b/narwhals/_pandas_like/expr.py index 193b1786c..44154453d 100644 --- a/narwhals/_pandas_like/expr.py +++ b/narwhals/_pandas_like/expr.py @@ -227,8 +227,8 @@ def arg_true(self) -> Self: def filter(self, *predicates: Any) -> Self: plx = self.__narwhals_namespace__() - expr = plx.all_horizontal(*predicates) - return reuse_series_implementation(self, "filter", other=expr) + other = plx.all_horizontal(*predicates) + return reuse_series_implementation(self, "filter", other=other) def drop_nulls(self) -> Self: return reuse_series_implementation(self, "drop_nulls") diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index 90434ebd5..9532c39c2 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -229,7 +229,8 @@ def arg_true(self) -> PandasLikeSeries: def filter(self, other: Any) -> PandasLikeSeries: ser = self._native_series - other = validate_column_comparand(self._native_series.index, other) + if not (isinstance(other, list) and all(isinstance(x, bool) for x in other)): + other = validate_column_comparand(self._native_series.index, other) return self._from_native_series(self._rename(ser.loc[other], ser.name)) def __eq__(self, other: object) -> PandasLikeSeries: # type: ignore[override] diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index b09f12cc0..205149721 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -158,8 +158,13 @@ def unique( ) ) - def filter(self, *predicates: IntoExpr | Iterable[IntoExpr]) -> Self: - predicates, _ = self._flatten_and_extract(*predicates) + def filter(self, *predicates: IntoExpr | Iterable[IntoExpr] | list[bool]) -> Self: + if not ( + len(predicates) == 1 + and isinstance(predicates[0], list) + and all(isinstance(x, bool) for x in predicates[0]) + ): + predicates, _ = self._flatten_and_extract(*predicates) return self._from_compliant_dataframe( self._compliant_frame.filter(*predicates), ) @@ -599,6 +604,8 @@ def __getitem__( 2 ] """ + if isinstance(item, int): + item = [item] if ( isinstance(item, tuple) and len(item) == 2 @@ -693,6 +700,40 @@ def to_dict( } return self._compliant_frame.to_dict(as_series=as_series) # type: ignore[no-any-return] + def row(self, index: int) -> tuple[Any, ...]: + """ + Get values at given row. + + !!!note + You should NEVER use this method to iterate over a DataFrame; + if you require row-iteration you should strongly prefer use of iter_rows() instead. + + Arguments: + index: Row number. + + Examples: + >>> import narwhals as nw + >>> import pandas as pd + >>> import polars as pl + >>> data = {"a": [1, 2, 3], "b": [4, 5, 6]} + >>> df_pd = pd.DataFrame(data) + >>> df_pl = pl.DataFrame(data) + + Let's define a library-agnostic function to get the second row. + + >>> @nw.narwhalify + ... def func(df): + ... return df.row(1) + + We can then pass pandas / Polars / any other supported library: + + >>> func(df_pd) + (2, 5) + >>> func(df_pl) + (2, 5) + """ + return self._compliant_frame.row(index) # type: ignore[no-any-return] + # inherited def pipe(self, function: Callable[[Any], Self], *args: Any, **kwargs: Any) -> Self: """ @@ -1483,14 +1524,15 @@ def unique( """ return super().unique(subset, keep=keep, maintain_order=maintain_order) - def filter(self, *predicates: IntoExpr | Iterable[IntoExpr]) -> Self: + def filter(self, *predicates: IntoExpr | Iterable[IntoExpr] | list[bool]) -> Self: r""" Filter the rows in the DataFrame based on one or more predicate expressions. The original order of the remaining rows is preserved. Arguments: - predicates: Expression(s) that evaluates to a boolean Series. + *predicates: Expression(s) that evaluates to a boolean Series. Can + also be a (single!) boolean list. Examples: >>> import pandas as pd @@ -2949,14 +2991,15 @@ def unique( """ return super().unique(subset, keep=keep, maintain_order=maintain_order) - def filter(self, *predicates: IntoExpr | Iterable[IntoExpr]) -> Self: + def filter(self, *predicates: IntoExpr | Iterable[IntoExpr] | list[bool]) -> Self: r""" Filter the rows in the LazyFrame based on a predicate expression. The original order of the remaining rows is preserved. Arguments: - *predicates: Expression that evaluates to a boolean Series. + *predicates: Expression that evaluates to a boolean Series. Can + also be a (single!) boolean list. Examples: >>> import pandas as pd diff --git a/narwhals/expr.py b/narwhals/expr.py index 74ead79e5..a6b330bf1 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -3657,7 +3657,8 @@ def func(plx: Any) -> Any: return Expr(func) -def all() -> Expr: +# Add underscore so it doesn't conflict with builtin `all` +def all_() -> Expr: """ Instantiate an expression representing all columns. @@ -3696,7 +3697,8 @@ def all() -> Expr: return Expr(lambda plx: plx.all()) -def len() -> Expr: +# Add underscore so it doesn't conflict with builtin `len` +def len_() -> Expr: """ Return the number of rows. diff --git a/tests/expr_and_series/filter_test.py b/tests/expr_and_series/filter_test.py index ede0c49a0..b55a0368e 100644 --- a/tests/expr_and_series/filter_test.py +++ b/tests/expr_and_series/filter_test.py @@ -27,3 +27,6 @@ def test_filter_series(constructor_eager: Any) -> None: result = df.select(df["a"].filter((df["i"] < 2) & (df["c"] == 5))) expected = {"a": [0]} compare_dicts(result, expected) + result_s = df["a"].filter([True, False, False, False, False]) + expected = {"a": [0]} + compare_dicts({"a": result_s}, expected) diff --git a/tests/frame/filter_test.py b/tests/frame/filter_test.py index 9b7ed45d2..a8d3144aa 100644 --- a/tests/frame/filter_test.py +++ b/tests/frame/filter_test.py @@ -12,11 +12,9 @@ def test_filter(constructor: Any) -> None: compare_dicts(result, expected) -def test_filter_series(constructor_eager: Any) -> None: +def test_filter_with_boolean_list(constructor: Any) -> None: data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - df = nw.from_native(constructor_eager(data), eager_only=True).with_columns( - mask=nw.col("a") > 1 - ) - result = df.filter(df["mask"]).drop("mask") + df = nw.from_native(constructor(data)) + result = df.filter([False, True, True]) expected = {"a": [3, 2], "b": [4, 6], "z": [8.0, 9.0]} compare_dicts(result, expected) diff --git a/tests/frame/get_column_test.py b/tests/frame/get_column_test.py index 70f655620..58766ac31 100644 --- a/tests/frame/get_column_test.py +++ b/tests/frame/get_column_test.py @@ -24,6 +24,9 @@ def test_non_string_name() -> None: result = nw.from_native(df, eager_only=True).get_column(0) # type: ignore[arg-type] compare_dicts({"a": result}, {"a": [1, 2]}) assert result.name == 0 # type: ignore[comparison-overlap] - with pytest.raises(TypeError, match="Expected str or slice"): - # Check that getitem would have raised - nw.from_native(df, eager_only=True)[0] # type: ignore[call-overload] + + +def test_get_single_row() -> None: + df = pd.DataFrame({"a": [1, 2], "b": [3, 4]}) + result = nw.from_native(df, eager_only=True)[0] # type: ignore[call-overload] + compare_dicts(result, {"a": [1], "b": [3]}) diff --git a/tests/frame/row_test.py b/tests/frame/row_test.py new file mode 100644 index 000000000..602c50f55 --- /dev/null +++ b/tests/frame/row_test.py @@ -0,0 +1,14 @@ +from typing import Any + +import narwhals.stable.v1 as nw + + +def test_row_column(constructor_eager: Any) -> None: + data = { + "a": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], + "b": [11, 12, 13, 14, 15, 16], + } + result = nw.from_native(constructor_eager(data), eager_only=True).row(2) + if "pyarrow_table" in str(constructor_eager): + result = tuple(x.as_py() for x in result) + assert result == (3.0, 13) diff --git a/tests/frame/slice_test.py b/tests/frame/slice_test.py index 222717d1c..eea94d440 100644 --- a/tests/frame/slice_test.py +++ b/tests/frame/slice_test.py @@ -49,9 +49,16 @@ def test_slice_lazy_fails() -> None: _ = nw.from_native(pl.LazyFrame(data))[1:] -def test_slice_int_fails(constructor_eager: Any) -> None: - with pytest.raises(TypeError, match="Expected str or slice, got: "): - _ = nw.from_native(constructor_eager(data))[1] # type: ignore[call-overload,index] +def test_slice_int(constructor_eager: Any) -> None: + result = nw.from_native(constructor_eager(data), eager_only=True)[1] # type: ignore[call-overload] + compare_dicts(result, {"a": [2], "b": [12]}) + + +def test_slice_fails(constructor_eager: Any) -> None: + class Foo: ... + + with pytest.raises(TypeError, match="Expected str or slice, got:"): + nw.from_native(constructor_eager(data), eager_only=True)[Foo()] # type: ignore[call-overload] def test_gather(constructor_eager: Any) -> None: