From bab74624f6f62f3d701b4f5bb850a82c34207513 Mon Sep 17 00:00:00 2001 From: Weijie Guo Date: Tue, 2 Apr 2024 13:36:19 +0800 Subject: [PATCH] fix(python): Raise if pass a negative `n` into `clear` --- py-polars/polars/dataframe/frame.py | 17 +++++++++-------- py-polars/polars/series/series.py | 4 ++++ py-polars/tests/unit/operations/test_clear.py | 10 ++++++++++ 3 files changed, 23 insertions(+), 8 deletions(-) diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index a34c4aae5b84..8ae8dd516b81 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -6912,17 +6912,18 @@ def clear(self, n: int = 0) -> Self: │ null ┆ null ┆ null │ └──────┴──────┴──────┘ """ + if n < 0: + msg = "n should be greater than or equal to 0." + raise ValueError(msg) # faster path if n == 0: return self._from_pydf(self._df.clear()) - if n > 0 or len(self) > 0: - return self.__class__( - { - nm: pl.Series(name=nm, dtype=tp).extend_constant(None, n) - for nm, tp in self.schema.items() - } - ) - return self.clone() + return self.__class__( + { + nm: pl.Series(name=nm, dtype=tp).extend_constant(None, n) + for nm, tp in self.schema.items() + } + ) def clone(self) -> Self: """ diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index a651f8f77508..527ae2458b5a 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -4759,6 +4759,10 @@ def clear(self, n: int = 0) -> Series: null ] """ + if n < 0: + msg = "n should be greater than or equal to 0." + raise ValueError(msg) + # faster path if n == 0: return self._from_pyseries(self._s.clear()) s = ( diff --git a/py-polars/tests/unit/operations/test_clear.py b/py-polars/tests/unit/operations/test_clear.py index 0ac3c1d27ba0..c4c98a611ad4 100644 --- a/py-polars/tests/unit/operations/test_clear.py +++ b/py-polars/tests/unit/operations/test_clear.py @@ -73,3 +73,13 @@ def test_clear_series_object_starting_with_null() -> None: assert result.dtype == s.dtype assert result.name == s.name assert result.is_empty() + + +def test_clear_raise_negative_n() -> None: + s = pl.Series([1, 2, 3]) + with pytest.raises(ValueError, match="n should be greater than or equal to 0"): + s.clear(-1) + + df = pl.DataFrame({"a": [1, 2, 3]}) + with pytest.raises(ValueError, match="n should be greater than or equal to 0"): + df.clear(-1)