Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed May 28, 2024
1 parent d16f4df commit dd26991
Show file tree
Hide file tree
Showing 3 changed files with 181 additions and 174 deletions.
3 changes: 3 additions & 0 deletions py-polars/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,9 @@ filterwarnings = [
# TODO: Remove when behavior is updated
# https://github.com/pola-rs/polars/issues/13441
"ignore:.*default coalesce behavior of left join.*:DeprecationWarning",
# TODO: Remove when default is updated
# https://github.com/pola-rs/polars/issues/14526
"ignore:.*will only write row-oriented JSON.*:DeprecationWarning",
]
xfail_strict = true

Expand Down
136 changes: 136 additions & 0 deletions py-polars/tests/unit/dataframe/test_serde.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
from __future__ import annotations

import io
from datetime import date, datetime, timedelta
from typing import TYPE_CHECKING, Any

import pytest

import polars as pl
from polars.testing import assert_frame_equal

if TYPE_CHECKING:
from pathlib import Path


def test_df_serialize() -> None:
Expand All @@ -10,6 +20,132 @@ def test_df_serialize() -> None:
assert result == expected


@pytest.mark.parametrize("buf", [io.BytesIO(), io.StringIO()])
def test_to_from_buffer(df: pl.DataFrame, buf: io.IOBase) -> None:
df.serialize(buf)
buf.seek(0)
read_df = pl.DataFrame.deserialize(buf)
assert_frame_equal(df, read_df, categorical_as_str=True)


@pytest.mark.write_disk()
def test_to_from_file(df: pl.DataFrame, tmp_path: Path) -> None:
tmp_path.mkdir(exist_ok=True)

file_path = tmp_path / "small.json"
df.serialize(file_path)
out = pl.DataFrame.deserialize(file_path)

assert_frame_equal(df, out, categorical_as_str=True)


def test_write_json_to_string() -> None:
# Tests if it runs if no arg given
df = pl.DataFrame({"a": [1, 2, 3]})
expected_str = '{"columns":[{"name":"a","datatype":"Int64","bit_settings":"","values":[1,2,3]}]}'
assert df.serialize() == expected_str


def test_write_json(df: pl.DataFrame) -> None:
# Text-based conversion loses time info
df = df.select(pl.all().exclude(["cat", "time"]))
s = df.serialize()
f = io.BytesIO()
f.write(s.encode())
f.seek(0)
out = pl.DataFrame.deserialize(f)
assert_frame_equal(out, df)

file = io.BytesIO()
df.serialize(file)
file.seek(0)
out = pl.DataFrame.deserialize(file)
assert_frame_equal(out, df)


def test_df_serde_enum() -> None:
dtype = pl.Enum(["foo", "bar", "ham"])
df = pl.DataFrame([pl.Series("e", ["foo", "bar", "ham"], dtype=dtype)])
buf = io.StringIO()
df.serialize(buf)
buf.seek(0)
df_in = pl.DataFrame.deserialize(buf)
assert df_in.schema["e"] == dtype


@pytest.mark.parametrize(
("data", "dtype"),
[
([[1, 2, 3], [None, None, None], [1, None, 3]], pl.Array(pl.Int32(), width=3)),
([["a", "b"], [None, None]], pl.Array(pl.Utf8, width=2)),
([[True, False, None], [None, None, None]], pl.Array(pl.Utf8, width=3)),
(
[[[1, 2, 3], [4, None, 5]], None, [[None, None, 2]]],
pl.List(pl.Array(pl.Int32(), width=3)),
),
(
[
[datetime(1991, 1, 1), datetime(1991, 1, 1), None],
[None, None, None],
],
pl.Array(pl.Datetime, width=3),
),
],
)
def test_write_read_json_array(data: Any, dtype: pl.DataType) -> None:
df = pl.DataFrame({"foo": data}, schema={"foo": dtype})
buf = io.StringIO()
df.serialize(buf)
buf.seek(0)
deserialized_df = pl.DataFrame.deserialize(buf)
assert_frame_equal(deserialized_df, df)


@pytest.mark.parametrize(
("data", "dtype"),
[
(
[
[
datetime(1997, 10, 1),
datetime(2000, 1, 2, 10, 30, 1),
],
[None, None],
],
pl.Array(pl.Datetime, width=2),
),
(
[[date(1997, 10, 1), date(2000, 1, 1)], [None, None]],
pl.Array(pl.Date, width=2),
),
(
[
[timedelta(seconds=1), timedelta(seconds=10)],
[None, None],
],
pl.Array(pl.Duration, width=2),
),
],
)
def test_write_read_json_array_logical_inner_type(
data: Any, dtype: pl.DataType
) -> None:
df = pl.DataFrame({"foo": data}, schema={"foo": dtype})
buf = io.StringIO()
df.serialize(buf)
buf.seek(0)
deserialized_df = pl.DataFrame.deserialize(buf)
assert deserialized_df.dtypes == df.dtypes
assert deserialized_df.to_dict(as_series=False) == df.to_dict(as_series=False)


def test_json_deserialize_empty_list_10458() -> None:
schema = {"LIST_OF_STRINGS": pl.List(pl.String)}
serialized_schema = pl.DataFrame(schema=schema).serialize()
df = pl.DataFrame.deserialize(io.StringIO(serialized_schema))
assert df.schema == schema


def test_df_write_json_deprecated() -> None:
df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
with pytest.deprecated_call():
Expand Down
Loading

0 comments on commit dd26991

Please sign in to comment.