diff --git a/py-polars/src/dataframe/serde.rs b/py-polars/src/dataframe/serde.rs index 8cfe0a692ca6..131262acf9c8 100644 --- a/py-polars/src/dataframe/serde.rs +++ b/py-polars/src/dataframe/serde.rs @@ -1,7 +1,8 @@ -use std::io::{BufWriter, Cursor}; +use std::io::{BufReader, BufWriter, Cursor}; use std::ops::Deref; use polars_io::mmap::ReaderBytes; +use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use pyo3::types::PyBytes; @@ -42,17 +43,37 @@ impl PyDataFrame { } } + /// Serialize into binary data. + fn serialize_binary(&self, py_f: PyObject) -> PyResult<()> { + let file = get_file_like(py_f, true)?; + let writer = BufWriter::new(file); + ciborium::into_writer(&self.df, writer) + .map_err(|err| PyValueError::new_err(format!("{err:?}"))) + } + + /// Serialize into a JSON string. #[cfg(feature = "json")] - pub fn serialize(&mut self, py_f: PyObject) -> PyResult<()> { - let file = BufWriter::new(get_file_like(py_f, true)?); - serde_json::to_writer(file, &self.df) - .map_err(|e| polars_err!(ComputeError: "{e}")) - .map_err(|e| PyPolarsErr::Other(format!("{e}")).into()) + pub fn serialize_json(&mut self, py_f: PyObject) -> PyResult<()> { + let file = get_file_like(py_f, true)?; + let writer = BufWriter::new(file); + serde_json::to_writer(writer, &self.df) + .map_err(|err| PyValueError::new_err(format!("{err:?}"))) + } + + /// Deserialize a file-like object containing binary data into a DataFrame. + #[staticmethod] + fn deserialize_binary(py_f: PyObject) -> PyResult { + let file = get_file_like(py_f, false)?; + let reader = BufReader::new(file); + let df = ciborium::from_reader::(reader) + .map_err(|err| PyValueError::new_err(format!("{err:?}")))?; + Ok(df.into()) } + /// Deserialize a file-like object containing JSON string data into a DataFrame. #[staticmethod] #[cfg(feature = "json")] - pub fn deserialize(py: Python, mut py_f: Bound) -> PyResult { + pub fn deserialize_json(py: Python, mut py_f: Bound) -> PyResult { use crate::file::read_if_bytesio; py_f = read_if_bytesio(py_f); let mut mmap_bytes_r = get_mmap_bytes_reader(&py_f)?; diff --git a/py-polars/src/lazyframe/serde.rs b/py-polars/src/lazyframe/serde.rs index c6a914c5e60a..af49c0f2ebd7 100644 --- a/py-polars/src/lazyframe/serde.rs +++ b/py-polars/src/lazyframe/serde.rs @@ -37,19 +37,19 @@ impl PyLazyFrame { /// Serialize into binary data. fn serialize_binary(&self, py_f: PyObject) -> PyResult<()> { - let file = BufWriter::new(get_file_like(py_f, true)?); - ciborium::into_writer(&self.ldf.logical_plan, file) - .map_err(|err| PyValueError::new_err(format!("{err:?}")))?; - Ok(()) + let file = get_file_like(py_f, true)?; + let writer = BufWriter::new(file); + ciborium::into_writer(&self.ldf.logical_plan, writer) + .map_err(|err| PyValueError::new_err(format!("{err:?}"))) } /// Serialize into a JSON string. #[cfg(feature = "json")] fn serialize_json(&self, py_f: PyObject) -> PyResult<()> { - let file = BufWriter::new(get_file_like(py_f, true)?); - serde_json::to_writer(file, &self.ldf.logical_plan) - .map_err(|err| PyValueError::new_err(format!("{err:?}")))?; - Ok(()) + let file = get_file_like(py_f, true)?; + let writer = BufWriter::new(file); + serde_json::to_writer(writer, &self.ldf.logical_plan) + .map_err(|err| PyValueError::new_err(format!("{err:?}"))) } /// Deserialize a file-like object containing binary data into a LazyFrame. diff --git a/py-polars/tests/unit/dataframe/test_serde.py b/py-polars/tests/unit/dataframe/test_serde.py index 35c45b3b913e..e1e0c1970545 100644 --- a/py-polars/tests/unit/dataframe/test_serde.py +++ b/py-polars/tests/unit/dataframe/test_serde.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any import pytest -from hypothesis import given +from hypothesis import example, given import polars as pl from polars.exceptions import ComputeError @@ -17,6 +17,13 @@ from pathlib import Path +@given(df=dataframes()) +def test_df_serde_roundtrip_binary(df: pl.DataFrame) -> None: + serialized = df.serialize() + result = pl.DataFrame.deserialize(io.BytesIO(serialized), format="binary") + assert_frame_equal(result, df, categorical_as_str=True) + + @given( df=dataframes( excluded_dtypes=[ @@ -25,24 +32,47 @@ ], ) ) -def test_df_serde_roundtrip(df: pl.DataFrame) -> None: - serialized = df.serialize() - result = pl.DataFrame.deserialize(io.StringIO(serialized)) +@example(df=pl.DataFrame({"a": [None, None]}, schema={"a": pl.Null})) +@example(df=pl.DataFrame(schema={"a": pl.List(pl.String)})) +def test_df_serde_roundtrip_json(df: pl.DataFrame) -> None: + serialized = df.serialize(format="json") + result = pl.DataFrame.deserialize(io.StringIO(serialized), format="json") assert_frame_equal(result, df, categorical_as_str=True) -def test_df_serialize() -> None: +def test_df_serde(df: pl.DataFrame) -> None: + serialized = df.serialize() + assert isinstance(serialized, bytes) + result = pl.DataFrame.deserialize(io.BytesIO(serialized)) + assert_frame_equal(result, df) + + +def test_df_serde_json_stringio(df: pl.DataFrame) -> None: + serialized = df.serialize(format="json") + assert isinstance(serialized, str) + result = pl.DataFrame.deserialize(io.StringIO(serialized), format="json") + assert_frame_equal(result, df) + + +def test_df_serialize_json() -> None: df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}).sort("a") - result = df.serialize() + result = df.serialize(format="json") expected = '{"columns":[{"name":"a","datatype":"Int64","bit_settings":"SORTED_ASC","values":[1,2,3]},{"name":"b","datatype":"Int64","bit_settings":"","values":[4,5,6]}]}' assert result == expected -@pytest.mark.parametrize("buf", [io.BytesIO(), io.StringIO()]) -def test_df_serde_to_from_buffer(df: pl.DataFrame, buf: io.IOBase) -> None: - df.serialize(buf) +@pytest.mark.parametrize( + ("format", "buf"), + [ + ("binary", io.BytesIO()), + ("json", io.StringIO()), + ("json", io.BytesIO()), + ], +) +def test_df_serde_to_from_buffer(df: pl.DataFrame, format: str, buf: io.IOBase) -> None: + df.serialize(buf, format=format) buf.seek(0) - read_df = pl.DataFrame.deserialize(buf) + read_df = pl.DataFrame.deserialize(buf, format=format) assert_frame_equal(df, read_df, categorical_as_str=True) @@ -50,19 +80,19 @@ def test_df_serde_to_from_buffer(df: pl.DataFrame, buf: io.IOBase) -> None: def test_df_serde_to_from_file(df: pl.DataFrame, tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) - file_path = tmp_path / "small.json" + file_path = tmp_path / "small.bin" df.serialize(file_path) out = pl.DataFrame.deserialize(file_path) assert_frame_equal(df, out, categorical_as_str=True) -def test_write_json(df: pl.DataFrame) -> None: +def test_df_serde2(df: pl.DataFrame) -> None: # Text-based conversion loses time info df = df.select(pl.all().exclude(["cat", "time"])) s = df.serialize() f = io.BytesIO() - f.write(s.encode()) + f.write(s) f.seek(0) out = pl.DataFrame.deserialize(f) assert_frame_equal(out, df) @@ -77,7 +107,7 @@ def test_write_json(df: pl.DataFrame) -> None: def test_df_serde_enum() -> None: dtype = pl.Enum(["foo", "bar", "ham"]) df = pl.DataFrame([pl.Series("e", ["foo", "bar", "ham"], dtype=dtype)]) - buf = io.StringIO() + buf = io.BytesIO() df.serialize(buf) buf.seek(0) df_in = pl.DataFrame.deserialize(buf) @@ -111,7 +141,7 @@ def test_df_serde_enum() -> None: ) def test_df_serde_array(data: Any, dtype: pl.DataType) -> None: df = pl.DataFrame({"foo": data}, schema={"foo": dtype}) - buf = io.StringIO() + buf = io.BytesIO() df.serialize(buf) buf.seek(0) deserialized_df = pl.DataFrame.deserialize(buf) @@ -146,33 +176,18 @@ def test_df_serde_array(data: Any, dtype: pl.DataType) -> None: ) def test_df_serde_array_logical_inner_type(data: Any, dtype: pl.DataType) -> None: df = pl.DataFrame({"foo": data}, schema={"foo": dtype}) - buf = io.StringIO() + buf = io.BytesIO() df.serialize(buf) buf.seek(0) - deserialized_df = pl.DataFrame.deserialize(buf) - assert deserialized_df.dtypes == df.dtypes - assert deserialized_df.to_dict(as_series=False) == df.to_dict(as_series=False) - - -def test_df_serde_empty_list_10458() -> None: - schema = {"LIST_OF_STRINGS": pl.List(pl.String)} - serialized_schema = pl.DataFrame(schema=schema).serialize() - df = pl.DataFrame.deserialize(io.StringIO(serialized_schema)) - assert df.schema == schema + result = pl.DataFrame.deserialize(buf) + assert_frame_equal(result, df) @pytest.mark.xfail(reason="Bug: https://github.com/pola-rs/polars/issues/17211") def test_df_serde_float_inf_nan() -> None: df = pl.DataFrame({"a": [1.0, float("inf"), float("-inf"), float("nan")]}) - ser = df.serialize() - result = pl.DataFrame.deserialize(io.StringIO(ser)) - assert_frame_equal(result, df) - - -def test_df_serde_null() -> None: - df = pl.DataFrame({"a": [None, None]}, schema={"a": pl.Null}) - ser = df.serialize() - result = pl.DataFrame.deserialize(io.StringIO(ser)) + ser = df.serialize(format="json") + result = pl.DataFrame.deserialize(io.StringIO(ser), format="json") assert_frame_equal(result, df) @@ -201,4 +216,4 @@ def test_df_deserialize_validation() -> None: """ ) with pytest.raises(ComputeError, match=r"lengths don't match"): - pl.DataFrame.deserialize(f) + pl.DataFrame.deserialize(f, format="json") diff --git a/py-polars/tests/unit/lazyframe/test_serde.py b/py-polars/tests/unit/lazyframe/test_serde.py index 85dd17b789fc..4bd7a22f7906 100644 --- a/py-polars/tests/unit/lazyframe/test_serde.py +++ b/py-polars/tests/unit/lazyframe/test_serde.py @@ -7,6 +7,7 @@ from hypothesis import example, given import polars as pl +from polars.exceptions import ComputeError from polars.testing import assert_frame_equal from polars.testing.parametric import dataframes @@ -81,3 +82,9 @@ def test_lf_serde_to_from_file(lf: pl.LazyFrame, tmp_path: Path) -> None: result = pl.LazyFrame.deserialize(file_path) assert_frame_equal(lf, result) + + +def test_lf_deserialize_validation() -> None: + f = io.BytesIO(b"hello world!") + with pytest.raises(ComputeError, match="expected value at line 1 column 1"): + pl.DataFrame.deserialize(f, format="json")