diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index f4bcaec07..74e1c492d 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -619,11 +619,9 @@ def quantile( def zip_with(self: Self, mask: Any, other: Any) -> PandasLikeSeries: ser = self._native_series - mask = validate_column_comparand(ser.index, mask) - if isinstance(mask, str) or not isinstance( - mask, (self.__native_namespace__().Series, Sequence) - ): - mask = [mask] + mask = validate_column_comparand( + ser.index, mask, treat_length_one_as_scalar=False + ) other = validate_column_comparand(ser.index, other) res = ser.where(mask, other) return self._from_native_series(res) diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index 910c19a11..df87e6499 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -32,7 +32,9 @@ } -def validate_column_comparand(index: Any, other: Any) -> Any: +def validate_column_comparand( + index: Any, other: Any, *, treat_length_one_as_scalar: bool = True +) -> Any: """Validate RHS of binary operation. If the comparison isn't supported, return `NotImplemented` so that the @@ -53,7 +55,7 @@ def validate_column_comparand(index: Any, other: Any) -> Any: if isinstance(other, PandasLikeDataFrame): return NotImplemented if isinstance(other, PandasLikeSeries): - if other.len() == 1: + if other.len() == 1 and treat_length_one_as_scalar: # broadcast return other.item() if other._native_series.index is not index: