From 31158655dbc43172370e245dbaa0a1b8a6941d07 Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Wed, 26 Jun 2024 15:59:33 +0200 Subject: [PATCH] test(python): Refactor serde tests, add hypothesis tests (#17216) --- py-polars/tests/unit/dataframe/test_serde.py | 54 ++++++++++++----- py-polars/tests/unit/expr/test_serde.py | 45 ++++++++++++++ py-polars/tests/unit/lazyframe/test_serde.py | 63 ++++++++++++++++++++ py-polars/tests/unit/test_serde.py | 51 +--------------- 4 files changed, 148 insertions(+), 65 deletions(-) create mode 100644 py-polars/tests/unit/expr/test_serde.py create mode 100644 py-polars/tests/unit/lazyframe/test_serde.py diff --git a/py-polars/tests/unit/dataframe/test_serde.py b/py-polars/tests/unit/dataframe/test_serde.py index d8509396801b..6362b4d11028 100644 --- a/py-polars/tests/unit/dataframe/test_serde.py +++ b/py-polars/tests/unit/dataframe/test_serde.py @@ -6,15 +6,32 @@ from typing import TYPE_CHECKING, Any import pytest +from hypothesis import given import polars as pl from polars.exceptions import ComputeError from polars.testing import assert_frame_equal +from polars.testing.parametric import dataframes if TYPE_CHECKING: from pathlib import Path +@given( + df=dataframes( + excluded_dtypes=[ + pl.Null, # Not implemented yet + pl.Float32, # Bug, see: https://github.com/pola-rs/polars/issues/17211 + pl.Float64, # Bug, see: https://github.com/pola-rs/polars/issues/17211 + ], + ) +) +def test_df_serde_roundtrip(df: pl.DataFrame) -> None: + serialized = df.serialize() + result = pl.DataFrame.deserialize(io.StringIO(serialized)) + assert_frame_equal(result, df, categorical_as_str=True) + + def test_df_serialize() -> None: df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}).sort("a") result = df.serialize() @@ -23,7 +40,7 @@ def test_df_serialize() -> None: @pytest.mark.parametrize("buf", [io.BytesIO(), io.StringIO()]) -def test_to_from_buffer(df: pl.DataFrame, buf: io.IOBase) -> None: +def test_df_serde_to_from_buffer(df: pl.DataFrame, buf: io.IOBase) -> None: df.serialize(buf) buf.seek(0) read_df = pl.DataFrame.deserialize(buf) @@ -31,7 +48,7 @@ def test_to_from_buffer(df: pl.DataFrame, buf: io.IOBase) -> None: @pytest.mark.write_disk() -def test_to_from_file(df: pl.DataFrame, tmp_path: Path) -> None: +def test_df_serde_to_from_file(df: pl.DataFrame, tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) file_path = tmp_path / "small.json" @@ -41,13 +58,6 @@ def test_to_from_file(df: pl.DataFrame, tmp_path: Path) -> None: assert_frame_equal(df, out, categorical_as_str=True) -def test_write_json_to_string() -> None: - # Tests if it runs if no arg given - df = pl.DataFrame({"a": [1, 2, 3]}) - expected_str = '{"columns":[{"name":"a","datatype":"Int64","bit_settings":"","values":[1,2,3]}]}' - assert df.serialize() == expected_str - - def test_write_json(df: pl.DataFrame) -> None: # Text-based conversion loses time info df = df.select(pl.all().exclude(["cat", "time"])) @@ -100,7 +110,7 @@ def test_df_serde_enum() -> None: ), ], ) -def test_write_read_json_array(data: Any, dtype: pl.DataType) -> None: +def test_df_serde_array(data: Any, dtype: pl.DataType) -> None: df = pl.DataFrame({"foo": data}, schema={"foo": dtype}) buf = io.StringIO() df.serialize(buf) @@ -135,9 +145,7 @@ def test_write_read_json_array(data: Any, dtype: pl.DataType) -> None: ), ], ) -def test_write_read_json_array_logical_inner_type( - data: Any, dtype: pl.DataType -) -> None: +def test_df_serde_array_logical_inner_type(data: Any, dtype: pl.DataType) -> None: df = pl.DataFrame({"foo": data}, schema={"foo": dtype}) buf = io.StringIO() df.serialize(buf) @@ -147,14 +155,30 @@ def test_write_read_json_array_logical_inner_type( assert deserialized_df.to_dict(as_series=False) == df.to_dict(as_series=False) -def test_json_deserialize_empty_list_10458() -> None: +def test_df_serde_empty_list_10458() -> None: schema = {"LIST_OF_STRINGS": pl.List(pl.String)} serialized_schema = pl.DataFrame(schema=schema).serialize() df = pl.DataFrame.deserialize(io.StringIO(serialized_schema)) assert df.schema == schema -def test_serde_validation() -> None: +@pytest.mark.xfail(reason="Bug: https://github.com/pola-rs/polars/issues/17211") +def test_df_serde_float_inf_nan() -> None: + df = pl.DataFrame({"a": [1.0, float("inf"), float("-inf"), float("nan")]}) + ser = df.serialize() + result = pl.DataFrame.deserialize(io.StringIO(ser)) + assert_frame_equal(result, df) + + +@pytest.mark.xfail(reason="Not implemented yet") +def test_df_serde_null() -> None: + df = pl.DataFrame({"a": [None, None]}) + ser = df.serialize() + result = pl.DataFrame.deserialize(io.StringIO(ser)) + assert_frame_equal(result, df) + + +def test_df_deserialize_validation() -> None: f = io.StringIO( """ { diff --git a/py-polars/tests/unit/expr/test_serde.py b/py-polars/tests/unit/expr/test_serde.py new file mode 100644 index 000000000000..a5100ca9d4b5 --- /dev/null +++ b/py-polars/tests/unit/expr/test_serde.py @@ -0,0 +1,45 @@ +import io + +import pytest + +import polars as pl +from polars.exceptions import ComputeError + + +def test_expr_serialization_roundtrip() -> None: + expr = pl.col("foo").sum().over("bar") + json = expr.meta.serialize() + round_tripped = pl.Expr.deserialize(io.StringIO(json)) + assert round_tripped.meta == expr + + +def test_expr_deserialize_file_not_found() -> None: + with pytest.raises(FileNotFoundError): + pl.Expr.deserialize("abcdef") + + +def test_expr_deserialize_invalid_json() -> None: + with pytest.raises( + ComputeError, match="could not deserialize input into an expression" + ): + pl.Expr.deserialize(io.StringIO("abcdef")) + + +def test_expr_write_json_from_json_deprecated() -> None: + expr = pl.col("foo").sum().over("bar") + + with pytest.deprecated_call(): + json = expr.meta.write_json() + + with pytest.deprecated_call(): + round_tripped = pl.Expr.from_json(json) + + assert round_tripped.meta == expr + + +def test_expression_json_13991() -> None: + expr = pl.col("foo").cast(pl.Decimal) + json = expr.meta.serialize() + + round_tripped = pl.Expr.deserialize(io.StringIO(json)) + assert round_tripped.meta == expr diff --git a/py-polars/tests/unit/lazyframe/test_serde.py b/py-polars/tests/unit/lazyframe/test_serde.py new file mode 100644 index 000000000000..83fc8fb042e9 --- /dev/null +++ b/py-polars/tests/unit/lazyframe/test_serde.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +import io +from typing import TYPE_CHECKING + +import pytest +from hypothesis import given + +import polars as pl +from polars.testing import assert_frame_equal +from polars.testing.parametric import dataframes + +if TYPE_CHECKING: + from pathlib import Path + + +@given( + lf=dataframes( + lazy=True, + excluded_dtypes=[ + pl.Null, # Not implemented yet + pl.Float32, # Bug, see: https://github.com/pola-rs/polars/issues/17211 + pl.Float64, # Bug, see: https://github.com/pola-rs/polars/issues/17211 + ], + ) +) +def test_lf_serde_roundtrip(lf: pl.LazyFrame) -> None: + serialized = lf.serialize() + result = pl.LazyFrame.deserialize(io.StringIO(serialized)) + assert_frame_equal(result, lf, categorical_as_str=True) + + +@pytest.fixture() +def lf() -> pl.LazyFrame: + """Sample LazyFrame for testing serialization/deserialization.""" + return pl.LazyFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]}).select("a").sum() + + +def test_lf_serde(lf: pl.LazyFrame) -> None: + serialized = lf.serialize() + assert isinstance(serialized, str) + result = pl.LazyFrame.deserialize(io.StringIO(serialized)) + + assert_frame_equal(result, lf) + + +@pytest.mark.parametrize("buf", [io.BytesIO(), io.StringIO()]) +def test_lf_serde_to_from_buffer(lf: pl.LazyFrame, buf: io.IOBase) -> None: + lf.serialize(buf) + buf.seek(0) + result = pl.LazyFrame.deserialize(buf) + assert_frame_equal(lf, result) + + +@pytest.mark.write_disk() +def test_lf_serde_to_from_file(lf: pl.LazyFrame, tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + file_path = tmp_path / "small.json" + lf.serialize(file_path) + result = pl.LazyFrame.deserialize(file_path) + + assert_frame_equal(lf, result) diff --git a/py-polars/tests/unit/test_serde.py b/py-polars/tests/unit/test_serde.py index 2b9c7a118744..869e75e3bf64 100644 --- a/py-polars/tests/unit/test_serde.py +++ b/py-polars/tests/unit/test_serde.py @@ -1,6 +1,5 @@ from __future__ import annotations -import io import pickle from datetime import datetime, timedelta @@ -8,7 +7,7 @@ import polars as pl from polars import StringCache -from polars.exceptions import ComputeError, SchemaError +from polars.exceptions import SchemaError from polars.testing import assert_frame_equal, assert_series_equal @@ -24,15 +23,6 @@ def test_pickling_as_struct_11100() -> None: assert str(pickle.loads(buf)) == str(e) -def test_lazyframe_serde() -> None: - lf = pl.DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]}).lazy().select(pl.col("a")) - - json = lf.serialize() - result = pl.LazyFrame.deserialize(io.StringIO(json)) - - assert_series_equal(result.collect().to_series(), pl.Series("a", [1, 2, 3])) - - def test_serde_time_unit() -> None: values = [datetime(2022, 1, 1) + timedelta(days=1) for _ in range(3)] s = pl.Series(values).cast(pl.Datetime("ns")) @@ -195,45 +185,6 @@ def test_serde_array_dtype() -> None: assert_series_equal(pickle.loads(pickle.dumps(nested_s)), nested_s) -def test_expr_serialization_roundtrip() -> None: - expr = pl.col("foo").sum().over("bar") - json = expr.meta.serialize() - round_tripped = pl.Expr.deserialize(io.StringIO(json)) - assert round_tripped.meta == expr - - -def test_expr_deserialize_file_not_found() -> None: - with pytest.raises(FileNotFoundError): - pl.Expr.deserialize("abcdef") - - -def test_expr_deserialize_invalid_json() -> None: - with pytest.raises( - ComputeError, match="could not deserialize input into an expression" - ): - pl.Expr.deserialize(io.StringIO("abcdef")) - - -def test_expr_write_json_from_json_deprecated() -> None: - expr = pl.col("foo").sum().over("bar") - - with pytest.deprecated_call(): - json = expr.meta.write_json() - - with pytest.deprecated_call(): - round_tripped = pl.Expr.from_json(json) - - assert round_tripped.meta == expr - - -def test_expression_json_13991() -> None: - expr = pl.col("foo").cast(pl.Decimal) - json = expr.meta.serialize() - - round_tripped = pl.Expr.deserialize(io.StringIO(json)) - assert round_tripped.meta == expr - - def test_serde_data_type_class() -> None: dtype = pl.Datetime serialized = pickle.dumps(dtype)