diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index 13a4596d5852..7df33cda0208 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -3034,7 +3034,7 @@ def extend(self, other: Series) -> Self: raise return self - def filter(self, predicate: Series | list[bool]) -> Self: + def filter(self, predicate: Series | Iterable[bool]) -> Self: """ Filter elements by a boolean mask. @@ -3060,7 +3060,7 @@ def filter(self, predicate: Series | list[bool]) -> Self: 3 ] """ - if isinstance(predicate, list): + if not isinstance(predicate, Series): predicate = Series("", predicate) return self._from_pyseries(self._s.filter(predicate._s)) diff --git a/py-polars/tests/unit/series/test_series.py b/py-polars/tests/unit/series/test_series.py index 40df19738f6e..fc3872ce1f3b 100644 --- a/py-polars/tests/unit/series/test_series.py +++ b/py-polars/tests/unit/series/test_series.py @@ -1374,6 +1374,10 @@ def test_filter() -> None: assert_series_equal(s.filter(mask), pl.Series("a", [1, 3])) assert_series_equal(s.filter([True, False, True]), pl.Series("a", [1, 3])) + assert_series_equal(s.filter(np.array([True, False, True])), pl.Series("a", [1, 3])) + + with pytest.raises(RuntimeError, match="Expected a boolean mask"): + s.filter(np.array([1, 0, 1])) def test_gather_every() -> None: