diff --git a/py-polars/polars/testing/parametric/strategies/core.py b/py-polars/polars/testing/parametric/strategies/core.py index b90e759ee3fb..22e8d7c65d5b 100644 --- a/py-polars/polars/testing/parametric/strategies/core.py +++ b/py-polars/polars/testing/parametric/strategies/core.py @@ -420,6 +420,7 @@ def dataframes( # noqa: D417 version="1.0.0", ) min_size = max_size = size + allow_nan = kwargs.pop("allow_nan", None) if isinstance(include_cols, column): include_cols = [include_cols] @@ -451,6 +452,11 @@ def dataframes( # noqa: D417 c.allow_null = allow_null.get(c.name, True) else: c.allow_null = allow_null + if c.allow_nan is None: + if isinstance(allow_nan, Mapping): + c.allow_nan = allow_nan.get(c.name, True) + else: + c.allow_nan = allow_nan allow_series_chunks = draw(st.booleans()) if allow_chunks else False @@ -464,6 +470,7 @@ def dataframes( # noqa: D417 max_size=size, strategy=c.strategy, allow_null=c.allow_null, # type: ignore[arg-type] + allow_nan=c.allow_nan, allow_chunks=allow_series_chunks, unique=c.unique, allowed_dtypes=allowed_dtypes, @@ -503,6 +510,8 @@ class column: supports overriding the default strategy for the given dtype. allow_null : bool, optional Allow nulls as possible values and allow the `Null` data type by default. + allow_nan : bool, optional + Allow nans as possible values. Only applicable to float/decimal dtype columns. unique : bool, optional flag indicating that all values generated for the column should be unique. @@ -540,6 +549,7 @@ class column: dtype: PolarsDataType | None = None strategy: SearchStrategy[Any] | None = None allow_null: bool | None = None + allow_nan: bool | None = None unique: bool = False null_probability: float | None = None diff --git a/py-polars/polars/testing/parametric/strategies/data.py b/py-polars/polars/testing/parametric/strategies/data.py index 1f0c656cf73f..6186ac82d30f 100644 --- a/py-polars/polars/testing/parametric/strategies/data.py +++ b/py-polars/polars/testing/parametric/strategies/data.py @@ -94,10 +94,17 @@ def integers( def floats( - bit_width: Literal[32, 64] = 64, *, allow_infinity: bool = True + bit_width: Literal[32, 64] = 64, + *, + allow_infinity: bool = True, + allow_nan: bool = True, ) -> SearchStrategy[float]: """Create a strategy for generating integers.""" - return st.floats(width=bit_width, allow_infinity=allow_infinity) + return st.floats( + width=bit_width, + allow_infinity=allow_infinity, + allow_nan=allow_nan, + ) def booleans() -> SearchStrategy[bool]: @@ -382,9 +389,17 @@ def data( if (strategy := _STATIC_STRATEGIES.get(dtype.base_type())) is not None: strategy = strategy elif dtype == Float32: - strategy = floats(32, allow_infinity=kwargs.pop("allow_infinity", True)) + strategy = floats( + 32, + allow_infinity=kwargs.pop("allow_infinity", True), + allow_nan=kwargs.pop("allow_nan", True), + ) elif dtype == Float64: - strategy = floats(64, allow_infinity=kwargs.pop("allow_infinity", True)) + strategy = floats( + 64, + allow_infinity=kwargs.pop("allow_infinity", True), + allow_nan=kwargs.pop("allow_nan", True), + ) elif dtype == Datetime: strategy = datetimes( time_unit=getattr(dtype, "time_unit", None) or "us", diff --git a/py-polars/tests/unit/testing/parametric/strategies/test_core.py b/py-polars/tests/unit/testing/parametric/strategies/test_core.py index bb54255f4d10..ebc5a3dda2f0 100644 --- a/py-polars/tests/unit/testing/parametric/strategies/test_core.py +++ b/py-polars/tests/unit/testing/parametric/strategies/test_core.py @@ -68,6 +68,18 @@ def test_series_allow_null_allowed_dtypes(s: pl.Series) -> None: assert s.dtype == pl.Null +@given( + s=series( + allowed_dtypes=[pl.Float32, pl.Float64], + allow_nan=False, + allow_null=False, + min_size=1, + ) +) +def test_series_allow_nan_false(s: pl.Series) -> None: + assert s.is_not_nan().any() + + @given(s=series(allowed_dtypes=[pl.List(pl.Int8)], allow_null=False)) def test_series_allow_null_nested(s: pl.Series) -> None: for v in s: @@ -115,6 +127,33 @@ def test_dataframes_allow_null_column(df: pl.DataFrame) -> None: assert 0 <= null_count <= df.height * df.width +@given( + df=dataframes( + cols=1, + allowed_dtypes=[pl.Float32, pl.Float64], + allow_nan=False, + ), +) +def test_dataframes_allow_nan_false_global(df: pl.DataFrame) -> None: + print(df) + nan_count = df.select(pl.col("col0").is_nan().sum()).item() + assert nan_count == 0 + + +@given( + df=dataframes( + cols=2, + allowed_dtypes=[pl.Float32, pl.Float64], + allow_nan={"col0": False}, + ), +) +def test_dataframes_allow_nan_false_column(df: pl.DataFrame) -> None: + print(df) + nan_count = sum(df.select(pl.all().is_nan().sum()).row(0)) + # The maximum nan count is all values in a single column. + assert 0 <= nan_count <= df.height + + @given( df=dataframes( cols=1,