Skip to content

Commit

Permalink
chore: Refactor validate_column_comparand (#1102)
Browse files Browse the repository at this point in the history
* simplify

* rename argument
  • Loading branch information
MarcoGorelli authored Oct 1, 2024
1 parent 03ebe77 commit 8a1241f
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
8 changes: 3 additions & 5 deletions narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,11 +619,9 @@ def quantile(

def zip_with(self: Self, mask: Any, other: Any) -> PandasLikeSeries:
ser = self._native_series
mask = validate_column_comparand(ser.index, mask)
if isinstance(mask, str) or not isinstance(
mask, (self.__native_namespace__().Series, Sequence)
):
mask = [mask]
mask = validate_column_comparand(
ser.index, mask, treat_length_one_as_scalar=False
)
other = validate_column_comparand(ser.index, other)
res = ser.where(mask, other)
return self._from_native_series(res)
Expand Down
6 changes: 4 additions & 2 deletions narwhals/_pandas_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@
}


def validate_column_comparand(index: Any, other: Any) -> Any:
def validate_column_comparand(
index: Any, other: Any, *, treat_length_one_as_scalar: bool = True
) -> Any:
"""Validate RHS of binary operation.
If the comparison isn't supported, return `NotImplemented` so that the
Expand All @@ -53,7 +55,7 @@ def validate_column_comparand(index: Any, other: Any) -> Any:
if isinstance(other, PandasLikeDataFrame):
return NotImplemented
if isinstance(other, PandasLikeSeries):
if other.len() == 1:
if other.len() == 1 and treat_length_one_as_scalar:
# broadcast
return other.item()
if other._native_series.index is not index:
Expand Down

0 comments on commit 8a1241f

Please sign in to comment.