From 2de91c3b157f8ed05203eb58f974c6b157d58790 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Mon, 18 Mar 2024 16:48:05 +0000 Subject: [PATCH] try filter fastpath --- narwhals/pandas_like/dataframe.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/narwhals/pandas_like/dataframe.py b/narwhals/pandas_like/dataframe.py index 32dc9ae75..b25e8759a 100644 --- a/narwhals/pandas_like/dataframe.py +++ b/narwhals/pandas_like/dataframe.py @@ -94,16 +94,23 @@ def select( def filter( self, - *predicates: IntoPandasExpr | Iterable[IntoPandasExpr], + *predicates: Iterable[IntoPandasExpr], ) -> Self: - from narwhals.pandas_like.namespace import PandasNamespace - - plx = PandasNamespace(self._implementation) - expr = plx.all_horizontal(*predicates) - # Safety: all_horizontal's expression only returns a single column. - mask = expr._call(self)[0] - _mask = validate_dataframe_comparand(mask) - return self._from_dataframe(self._dataframe[_mask]) + masks = evaluate_into_exprs(self, *predicates) + if len(masks) == 1: + return self._from_dataframe( + self._dataframe[validate_dataframe_comparand(masks[0])] + ) + + return self._from_dataframe( + self._dataframe[ + validate_dataframe_comparand( + horizontal_concat(masks, implementation=self._implementation).all( + axis=1 + ) + ) + ] + ) def with_columns( self,