Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve perf #21

Merged
merged 1 commit into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from narwhals.group_by import GroupBy
from narwhals.series import Series
from narwhals.typing import IntoExpr
from narwhals.typing import T


class BaseFrame:
Expand Down Expand Up @@ -208,7 +207,7 @@ def to_dict(self, *, as_series: bool = True) -> dict[str, Any]:
class LazyFrame(BaseFrame):
def __init__(
self,
df: T,
df: Any,
*,
implementation: str | None = None,
) -> None:
Expand Down
29 changes: 15 additions & 14 deletions narwhals/pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from narwhals.pandas_like.utils import evaluate_into_exprs
from narwhals.pandas_like.utils import horizontal_concat
from narwhals.pandas_like.utils import reset_index
from narwhals.pandas_like.utils import translate_dtype
from narwhals.pandas_like.utils import validate_dataframe_comparand
from narwhals.utils import flatten_str
Expand All @@ -33,17 +32,18 @@ def __init__(
implementation: str,
) -> None:
self._validate_columns(dataframe.columns)
self._dataframe = reset_index(dataframe)
self._dataframe = dataframe
self._implementation = implementation

def _validate_columns(self, columns: Sequence[str]) -> None:
counter = collections.Counter(columns)
for col, count in counter.items():
if count > 1:
msg = f"Expected unique column names, got {col!r} {count} time(s)"
raise ValueError(
msg,
)
if len(columns) != len(set(columns)):
counter = collections.Counter(columns)
for col, count in counter.items():
if count > 1:
msg = f"Expected unique column names, got {col!r} {count} time(s)"
raise ValueError(
msg,
)

def _validate_booleanness(self) -> None:
if not (
Expand Down Expand Up @@ -102,7 +102,7 @@ def filter(
expr = plx.all_horizontal(*predicates)
# Safety: all_horizontal's expression only returns a single column.
mask = expr._call(self)[0]
_mask = validate_dataframe_comparand(mask)
_mask = validate_dataframe_comparand(self._dataframe.index, mask)
return self._from_dataframe(self._dataframe.loc[_mask])

def with_columns(
Expand All @@ -112,7 +112,10 @@ def with_columns(
) -> Self:
new_series = evaluate_into_exprs(self, *exprs, **named_exprs)
df = self._dataframe.assign(
**{series.name: validate_dataframe_comparand(series) for series in new_series}
**{
series.name: validate_dataframe_comparand(self._dataframe.index, series)
for series in new_series
}
)
return self._from_dataframe(df)

Expand All @@ -137,9 +140,7 @@ def sort(
ascending: bool | list[bool] = not descending
else:
ascending = [not d for d in descending]
return self._from_dataframe(
df.sort_values(flat_keys, ascending=ascending),
)
return self._from_dataframe(df.sort_values(flat_keys, ascending=ascending))

# --- convert ---
def collect(self) -> PandasDataFrame:
Expand Down
59 changes: 30 additions & 29 deletions narwhals/pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from pandas.api.types import is_extension_array_dtype

from narwhals.pandas_like.utils import item
from narwhals.pandas_like.utils import reset_index
from narwhals.pandas_like.utils import reverse_translate_dtype
from narwhals.pandas_like.utils import translate_dtype
from narwhals.pandas_like.utils import validate_column_comparand
Expand All @@ -32,7 +31,7 @@ def __init__(
"""

self._name = str(series.name) if series.name is not None else ""
self._series = reset_index(series)
self._series = series
self._implementation = implementation

def _from_series(self, series: Any) -> Self:
Expand Down Expand Up @@ -70,7 +69,9 @@ def cast(

def filter(self, mask: Self) -> Self:
ser = self._series
return self._from_series(ser.loc[validate_column_comparand(mask)])
return self._from_series(
ser.loc[validate_column_comparand(self._series.index, mask)]
)

def item(self) -> Any:
return item(self._series)
Expand All @@ -93,122 +94,122 @@ def is_in(self, other: Any) -> PandasSeries:

def __eq__(self, other: object) -> PandasSeries: # type: ignore[override]
ser = self._series
other = validate_column_comparand(other)
other = validate_column_comparand(self._series.index, other)
return self._from_series((ser.__eq__(other)).rename(ser.name, copy=False))

def __ne__(self, other: object) -> PandasSeries: # type: ignore[override]
ser = self._series
other = validate_column_comparand(other)
other = validate_column_comparand(self._series.index, other)
return self._from_series((ser.__ne__(other)).rename(ser.name, copy=False))

def __ge__(self, other: Any) -> PandasSeries:
ser = self._series
other = validate_column_comparand(other)
other = validate_column_comparand(self._series.index, other)
return self._from_series((ser.__ge__(other)).rename(ser.name, copy=False))

def __gt__(self, other: Any) -> PandasSeries:
ser = self._series
other = validate_column_comparand(other)
other = validate_column_comparand(self._series.index, other)
return self._from_series((ser.__gt__(other)).rename(ser.name, copy=False))

def __le__(self, other: Any) -> PandasSeries:
ser = self._series
other = validate_column_comparand(other)
other = validate_column_comparand(self._series.index, other)
return self._from_series((ser.__le__(other)).rename(ser.name, copy=False))

def __lt__(self, other: Any) -> PandasSeries:
ser = self._series
other = validate_column_comparand(other)
other = validate_column_comparand(self._series.index, other)
return self._from_series((ser.__lt__(other)).rename(ser.name, copy=False))

def __and__(self, other: Any) -> PandasSeries:
ser = self._series
other = validate_column_comparand(other)
other = validate_column_comparand(self._series.index, other)
return self._from_series((ser.__and__(other)).rename(ser.name, copy=False))

def __rand__(self, other: Any) -> PandasSeries:
ser = self._series
other = validate_column_comparand(other)
other = validate_column_comparand(self._series.index, other)
return self._from_series((ser.__rand__(other)).rename(ser.name, copy=False))

def __or__(self, other: Any) -> PandasSeries:
ser = self._series
other = validate_column_comparand(other)
other = validate_column_comparand(self._series.index, other)
return self._from_series((ser.__or__(other)).rename(ser.name, copy=False))

def __ror__(self, other: Any) -> PandasSeries:
ser = self._series
other = validate_column_comparand(other)
other = validate_column_comparand(self._series.index, other)
return self._from_series((ser.__ror__(other)).rename(ser.name, copy=False))

def __add__(self, other: Any) -> PandasSeries:
ser = self._series
other = validate_column_comparand(other)
other = validate_column_comparand(self._series.index, other)
return self._from_series((ser.__add__(other)).rename(ser.name, copy=False))

def __radd__(self, other: Any) -> PandasSeries:
ser = self._series
other = validate_column_comparand(other)
other = validate_column_comparand(self._series.index, other)
return self._from_series((ser.__radd__(other)).rename(ser.name, copy=False))

def __sub__(self, other: Any) -> PandasSeries:
ser = self._series
other = validate_column_comparand(other)
other = validate_column_comparand(self._series.index, other)
return self._from_series((ser.__sub__(other)).rename(ser.name, copy=False))

def __rsub__(self, other: Any) -> PandasSeries:
ser = self._series
other = validate_column_comparand(other)
other = validate_column_comparand(self._series.index, other)
return self._from_series((ser.__rsub__(other)).rename(ser.name, copy=False))

def __mul__(self, other: Any) -> PandasSeries:
ser = self._series
other = validate_column_comparand(other)
other = validate_column_comparand(self._series.index, other)
return self._from_series((ser.__mul__(other)).rename(ser.name, copy=False))

def __rmul__(self, other: Any) -> PandasSeries:
ser = self._series
other = validate_column_comparand(other)
other = validate_column_comparand(self._series.index, other)
return self._from_series((ser.__rmul__(other)).rename(ser.name, copy=False))

def __truediv__(self, other: Any) -> PandasSeries:
ser = self._series
other = validate_column_comparand(other)
other = validate_column_comparand(self._series.index, other)
return self._from_series((ser.__truediv__(other)).rename(ser.name, copy=False))

def __rtruediv__(self, other: Any) -> PandasSeries:
ser = self._series
other = validate_column_comparand(other)
other = validate_column_comparand(self._series.index, other)
return self._from_series((ser.__rtruediv__(other)).rename(ser.name, copy=False))

def __floordiv__(self, other: Any) -> PandasSeries:
ser = self._series
other = validate_column_comparand(other)
other = validate_column_comparand(self._series.index, other)
return self._from_series((ser.__floordiv__(other)).rename(ser.name, copy=False))

def __rfloordiv__(self, other: Any) -> PandasSeries:
ser = self._series
other = validate_column_comparand(other)
other = validate_column_comparand(self._series.index, other)
return self._from_series((ser.__rfloordiv__(other)).rename(ser.name, copy=False))

def __pow__(self, other: Any) -> PandasSeries:
ser = self._series
other = validate_column_comparand(other)
other = validate_column_comparand(self._series.index, other)
return self._from_series((ser.__pow__(other)).rename(ser.name, copy=False))

def __rpow__(self, other: Any) -> PandasSeries: # pragma: no cover
ser = self._series
other = validate_column_comparand(other)
other = validate_column_comparand(self._series.index, other)
return self._from_series((ser.__rpow__(other)).rename(ser.name, copy=False))

def __mod__(self, other: Any) -> PandasSeries:
ser = self._series
other = validate_column_comparand(other)
other = validate_column_comparand(self._series.index, other)
return self._from_series((ser.__mod__(other)).rename(ser.name, copy=False))

def __rmod__(self, other: Any) -> PandasSeries: # pragma: no cover
ser = self._series
other = validate_column_comparand(other)
other = validate_column_comparand(self._series.index, other)
return self._from_series((ser.__rmod__(other)).rename(ser.name, copy=False))

# Unary
Expand Down Expand Up @@ -285,8 +286,8 @@ def n_unique(self) -> int:
return ser.nunique() # type: ignore[no-any-return]

def zip_with(self, mask: PandasSeries, other: PandasSeries) -> PandasSeries:
mask = validate_column_comparand(mask)
other = validate_column_comparand(other)
mask = validate_column_comparand(self._series.index, mask)
other = validate_column_comparand(self._series.index, other)
ser = self._series
return self._from_series(ser.where(mask, other))

Expand Down
42 changes: 20 additions & 22 deletions narwhals/pandas_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from narwhals.pandas_like.typing import IntoPandasExpr


def validate_column_comparand(other: Any) -> Any:
def validate_column_comparand(index: Any, other: Any) -> Any:
"""Validate RHS of binary operation.

If the comparison isn't supported, return `NotImplemented` so that the
Expand All @@ -47,11 +47,17 @@ def validate_column_comparand(other: Any) -> Any:
if other.len() == 1:
# broadcast
return other.item()
if other._series.index is not index and not (other._series.index == index).all():
msg = (
"Narwhals does not support automated index alignment. "
"Please reset the index of the Series or DataFrame."
)
raise ValueError(msg)
return other._series
return other


def validate_dataframe_comparand(other: Any) -> Any:
def validate_dataframe_comparand(index: Any, other: Any) -> Any:
"""Validate RHS of binary operation.

If the comparison isn't supported, return `NotImplemented` so that the
Expand All @@ -60,19 +66,25 @@ def validate_dataframe_comparand(other: Any) -> Any:
from narwhals.pandas_like.dataframe import PandasDataFrame
from narwhals.pandas_like.series import PandasSeries

if isinstance(other, list) and len(other) > 1:
# e.g. `plx.all() + plx.all()`
msg = "Multi-output expressions are not supported in this context"
raise ValueError(msg)
if isinstance(other, list):
other = other[0]
if isinstance(other, PandasDataFrame):
return NotImplemented
if isinstance(other, PandasSeries):
if other.len() == 1:
# broadcast
return item(other._series)
if other._series.index is not index and not (other._series.index == index).all():
msg = (
"Narwhals does not support automated index alignment. "
"Please reset the index of the Series or DataFrame."
)
raise ValueError(msg)
return other._series
if isinstance(other, list) and len(other) > 1:
# e.g. `plx.all() + plx.all()`
msg = "Multi-output expressions are not supported in this context"
raise ValueError(msg)
if isinstance(other, list):
other = other[0]
return other


Expand Down Expand Up @@ -368,17 +380,3 @@ def reverse_translate_dtype(dtype: DType | type[DType]) -> Any:
return "bool"
msg = f"Unknown dtype: {dtype}"
raise TypeError(msg)


def reset_index(obj: Any) -> Any:
index = obj.index
if (
hasattr(index, "start")
and hasattr(index, "stop")
and hasattr(index, "step")
and index.start == 0
and index.stop == len(obj)
and index.step == 1
):
return obj
return obj.reset_index(drop=True)
1 change: 1 addition & 0 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ def test_accepted_dataframes() -> None:


@pytest.mark.parametrize("df_raw", [df_polars, df_pandas, df_mpd])
@pytest.mark.filterwarnings("ignore:.*Passing a BlockManager.*:DeprecationWarning")
def test_convert_pandas(df_raw: Any) -> None:
result = nw.DataFrame(df_raw).to_pandas()
expected = pd.DataFrame({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]})
Expand Down
Loading
Loading