From ce542f2557fc77fdccd7b803a091e7bd63dd80df Mon Sep 17 00:00:00 2001 From: raisa <> Date: Tue, 19 Mar 2024 15:12:24 +0000 Subject: [PATCH 1/4] change sample method to work for pandas --- narwhals/pandas_like/expr.py | 12 ++++++++++-- narwhals/pandas_like/series.py | 12 ++++++++---- narwhals/series.py | 10 ++++++++-- 3 files changed, 26 insertions(+), 8 deletions(-) diff --git a/narwhals/pandas_like/expr.py b/narwhals/pandas_like/expr.py index a8457fa53..b2b64d2cf 100644 --- a/narwhals/pandas_like/expr.py +++ b/narwhals/pandas_like/expr.py @@ -180,8 +180,16 @@ def n_unique(self) -> Self: def unique(self) -> Self: return register_expression_call(self, "unique") - def sample(self, n: int, fraction: float, *, with_replacement: bool) -> Self: - return register_expression_call(self, "sample", n, fraction, with_replacement) + def sample( + self, + n: int | None = None, + fraction: float | None = None, + *, + with_replacement: bool = False, + ) -> Self: + return register_expression_call( + self, "sample", n, fraction=fraction, with_replacement=with_replacement + ) def alias(self, name: str) -> Self: # Define this one manually, so that we can diff --git a/narwhals/pandas_like/series.py b/narwhals/pandas_like/series.py index 59a1e0bcc..6b264112e 100644 --- a/narwhals/pandas_like/series.py +++ b/narwhals/pandas_like/series.py @@ -291,11 +291,15 @@ def zip_with(self, mask: PandasSeries, other: PandasSeries) -> PandasSeries: ser = self._series return self._from_series(ser.where(mask, other)) - def sample(self, n: int, fraction: float, *, with_replacement: bool) -> PandasSeries: + def sample( + self, + n: int | None = None, + fraction: float | None = None, + *, + with_replacement: bool = False, + ) -> PandasSeries: ser = self._series - return self._from_series( - ser.sample(n=n, frac=fraction, with_replacement=with_replacement) - ) + return self._from_series(ser.sample(n=n, frac=fraction, replace=with_replacement)) def unique(self) -> PandasSeries: if self._implementation != "pandas": diff --git a/narwhals/series.py b/narwhals/series.py index a99cba85e..fc4c442ed 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -118,9 +118,15 @@ def zip_with(self, mask: Self, other: Self) -> Self: self._series.zip_with(self._extract_native(mask), self._extract_native(other)) ) - def sample(self, n: int, fraction: float, *, with_replacement: bool) -> Self: + def sample( + self, + n: int | None = None, + fraction: float | None = None, + *, + with_replacement: bool = False, + ) -> Self: return self._from_series( - self._series.sample(n, fraction=fraction, with_replacement=with_replacement) + self._series.sample(n=n, fraction=fraction, with_replacement=with_replacement) ) def to_numpy(self) -> Any: From 998ebe5a1f7534857e2c8b0583599577b5d9c700 Mon Sep 17 00:00:00 2001 From: raisa <> Date: Wed, 20 Mar 2024 10:31:52 +0000 Subject: [PATCH 2/4] add invert to expressions, add tests for invert and sample --- narwhals/expression.py | 3 +++ tests/test_common.py | 20 ++++++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/narwhals/expression.py b/narwhals/expression.py index bd9c2aee9..b786b2493 100644 --- a/narwhals/expression.py +++ b/narwhals/expression.py @@ -112,6 +112,9 @@ def __ge__(self, other: Any) -> Expr: ) # --- unary --- + def __invert__(self) -> Expr: + return self.__class__(lambda plx: self._call(plx).__invert__()) + def mean(self) -> Expr: return self.__class__(lambda plx: self._call(plx).mean()) diff --git a/tests/test_common.py b/tests/test_common.py index 78c160d6a..2337bcbe8 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -15,6 +15,8 @@ df_pandas = pd.DataFrame({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}) df_polars = pl.DataFrame({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}) df_lazy = pl.LazyFrame({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}) +df_pandas_na = pd.DataFrame({"a": [None, 3, 2], "b": [4, 4, 6], "z": [7.0, None, 9]}) +df_lazy_na = pl.LazyFrame({"a": [None, 3, 2], "b": [4, 4, 6], "z": [7.0, None, 9]}) if os.environ.get("CI", None): import modin.pandas as mpd @@ -321,3 +323,21 @@ def test_expr_min_max(df_raw: Any) -> None: expected_max = {"a": [3], "b": [6], "z": [9]} compare_dicts(result_min, expected_min) compare_dicts(result_max, expected_max) + + +@pytest.mark.parametrize("df_raw", [df_polars, df_pandas, df_mpd, df_lazy]) +def test_expr_sample(df_raw: Any) -> None: + df = nw.LazyFrame(df_raw) + result_shape = nw.to_native(df.select(nw.col("a", "b").sample(n=2))).collect().shape + expected = (2, 2) + assert result_shape == expected + + +@pytest.mark.parametrize("df_raw", [df_pandas_na, df_lazy_na]) +def test_expr_na(df_raw: Any) -> None: + df = nw.LazyFrame(df_raw) + result_nna = nw.to_native( + df.filter((~nw.col("a").is_null()) & (~nw.col("z").is_null())) + ) + expected = {"a": [2], "b": [6], "z": [9]} + compare_dicts(result_nna, expected) From dc98e59b70944393f47e0d440f0de494c0bdbe32 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Wed, 20 Mar 2024 10:44:51 +0000 Subject: [PATCH 3/4] reset indices if we really have to --- narwhals/pandas_like/dataframe.py | 2 ++ narwhals/pandas_like/utils.py | 8 ++++++++ tests/test_common.py | 2 +- 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/narwhals/pandas_like/dataframe.py b/narwhals/pandas_like/dataframe.py index bbfdfc74b..1724a5395 100644 --- a/narwhals/pandas_like/dataframe.py +++ b/narwhals/pandas_like/dataframe.py @@ -8,6 +8,7 @@ from narwhals.pandas_like.utils import evaluate_into_exprs from narwhals.pandas_like.utils import horizontal_concat +from narwhals.pandas_like.utils import maybe_reset_indices from narwhals.pandas_like.utils import translate_dtype from narwhals.pandas_like.utils import validate_dataframe_comparand from narwhals.utils import flatten_str @@ -86,6 +87,7 @@ def select( **named_exprs: IntoPandasExpr, ) -> Self: new_series = evaluate_into_exprs(self, *exprs, **named_exprs) + new_series = maybe_reset_indices(new_series) df = horizontal_concat( [series._series for series in new_series], implementation=self._implementation, diff --git a/narwhals/pandas_like/utils.py b/narwhals/pandas_like/utils.py index 9486e94fb..0da6c046b 100644 --- a/narwhals/pandas_like/utils.py +++ b/narwhals/pandas_like/utils.py @@ -381,3 +381,11 @@ def reverse_translate_dtype(dtype: DType | type[DType]) -> Any: return "bool" msg = f"Unknown dtype: {dtype}" raise TypeError(msg) + + +def maybe_reset_indices(series: list[PandasSeries]) -> list[PandasSeries]: + idx = series[0]._series.index + for s in series[1:]: + if s._series.index is not idx and not (s._series.index == idx).all(): + break + return [s._from_series(s._series.reset_index(drop=True)) for s in series] diff --git a/tests/test_common.py b/tests/test_common.py index 2337bcbe8..040cb6c49 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -328,7 +328,7 @@ def test_expr_min_max(df_raw: Any) -> None: @pytest.mark.parametrize("df_raw", [df_polars, df_pandas, df_mpd, df_lazy]) def test_expr_sample(df_raw: Any) -> None: df = nw.LazyFrame(df_raw) - result_shape = nw.to_native(df.select(nw.col("a", "b").sample(n=2))).collect().shape + result_shape = nw.to_native(df.select(nw.col("a", "b").sample(n=2)).collect()).shape expected = (2, 2) assert result_shape == expected From 1a9cb0394faf4b1adbd416cd7edcf1b768b2891a Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Wed, 20 Mar 2024 10:46:40 +0000 Subject: [PATCH 4/4] but only reset indices if you have to --- narwhals/pandas_like/utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/narwhals/pandas_like/utils.py b/narwhals/pandas_like/utils.py index 0da6c046b..6c870506d 100644 --- a/narwhals/pandas_like/utils.py +++ b/narwhals/pandas_like/utils.py @@ -385,7 +385,11 @@ def reverse_translate_dtype(dtype: DType | type[DType]) -> Any: def maybe_reset_indices(series: list[PandasSeries]) -> list[PandasSeries]: idx = series[0]._series.index + found_non_matching_index = False for s in series[1:]: if s._series.index is not idx and not (s._series.index == idx).all(): + found_non_matching_index = True break - return [s._from_series(s._series.reset_index(drop=True)) for s in series] + if found_non_matching_index: + return [s._from_series(s._series.reset_index(drop=True)) for s in series] + return series