Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[backport 2.3.x] BUG/TST (string dtype): fix and update tests for Stata IO (#60130) #60155

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pandas/io/stata.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,11 @@ def _cast_to_stata_types(data: DataFrame) -> DataFrame:
if getattr(data[col].dtype, "numpy_dtype", None) is not None:
data[col] = data[col].astype(data[col].dtype.numpy_dtype)
elif is_string_dtype(data[col].dtype):
# TODO could avoid converting string dtype to object here,
# but handle string dtype in _encode_strings
data[col] = data[col].astype("object")
# generate_table checks for None values
data.loc[data[col].isna(), col] = None

dtype = data[col].dtype
empty_df = data.shape[0] == 0
Expand Down Expand Up @@ -2671,6 +2675,7 @@ def _encode_strings(self) -> None:
continue
column = self.data[col]
dtype = column.dtype
# TODO could also handle string dtype here specifically
if dtype.type is np.object_:
inferred_dtype = infer_dtype(column, skipna=True)
if not ((inferred_dtype == "string") or len(column) == 0):
Expand Down
82 changes: 43 additions & 39 deletions pandas/tests/io/test_stata.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
import numpy as np
import pytest

from pandas._config import using_string_dtype

import pandas.util._test_decorators as td

import pandas as pd
Expand Down Expand Up @@ -347,9 +345,8 @@ def test_write_dta6(self, datapath):
check_index_type=False,
)

@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
@pytest.mark.parametrize("version", [114, 117, 118, 119, None])
def test_read_write_dta10(self, version):
def test_read_write_dta10(self, version, using_infer_string):
original = DataFrame(
data=[["string", "object", 1, 1.1, np.datetime64("2003-12-25")]],
columns=["string", "object", "integer", "floating", "datetime"],
Expand All @@ -362,12 +359,17 @@ def test_read_write_dta10(self, version):
with tm.ensure_clean() as path:
original.to_stata(path, convert_dates={"datetime": "tc"}, version=version)
written_and_read_again = self.read_dta(path)
# original.index is np.int32, read index is np.int64
tm.assert_frame_equal(
written_and_read_again.set_index("index"),
original,
check_index_type=False,
)

expected = original.copy()
if using_infer_string:
expected["object"] = expected["object"].astype("str")

# original.index is np.int32, read index is np.int64
tm.assert_frame_equal(
written_and_read_again.set_index("index"),
expected,
check_index_type=False,
)

def test_stata_doc_examples(self):
with tm.ensure_clean() as path:
Expand Down Expand Up @@ -1153,7 +1155,6 @@ def test_categorical_ordering(self, file, datapath):
assert parsed[col].cat.ordered
assert not parsed_unordered[col].cat.ordered

@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
@pytest.mark.filterwarnings("ignore::UserWarning")
@pytest.mark.parametrize(
"file",
Expand Down Expand Up @@ -1215,6 +1216,10 @@ def _convert_categorical(from_frame: DataFrame) -> DataFrame:
if cat.categories.dtype == object:
categories = pd.Index._with_infer(cat.categories._values)
cat = cat.set_categories(categories)
elif cat.categories.dtype == "string" and len(cat.categories) == 0:
# if the read categories are empty, it comes back as object dtype
categories = cat.categories.astype(object)
cat = cat.set_categories(categories)
from_frame[col] = cat
return from_frame

Expand Down Expand Up @@ -1244,7 +1249,6 @@ def test_iterator(self, datapath):
from_chunks = pd.concat(itr)
tm.assert_frame_equal(parsed, from_chunks)

@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
@pytest.mark.filterwarnings("ignore::UserWarning")
@pytest.mark.parametrize(
"file",
Expand Down Expand Up @@ -1548,12 +1552,11 @@ def test_inf(self, infval):
with tm.ensure_clean() as path:
df.to_stata(path)

@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
def test_path_pathlib(self):
df = DataFrame(
1.1 * np.arange(120).reshape((30, 4)),
columns=pd.Index(list("ABCD"), dtype=object),
index=pd.Index([f"i-{i}" for i in range(30)], dtype=object),
columns=pd.Index(list("ABCD")),
index=pd.Index([f"i-{i}" for i in range(30)]),
)
df.index.name = "index"
reader = lambda x: read_stata(x).set_index("index")
Expand Down Expand Up @@ -1584,13 +1587,12 @@ def test_value_labels_iterator(self, write_index):
value_labels = dta_iter.value_labels()
assert value_labels == {"A": {0: "A", 1: "B", 2: "C", 3: "E"}}

@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
def test_set_index(self):
# GH 17328
df = DataFrame(
1.1 * np.arange(120).reshape((30, 4)),
columns=pd.Index(list("ABCD"), dtype=object),
index=pd.Index([f"i-{i}" for i in range(30)], dtype=object),
columns=pd.Index(list("ABCD")),
index=pd.Index([f"i-{i}" for i in range(30)]),
)
df.index.name = "index"
with tm.ensure_clean() as path:
Expand Down Expand Up @@ -1618,8 +1620,7 @@ def test_date_parsing_ignores_format_details(self, column, datapath):
formatted = df.loc[0, column + "_fmt"]
assert unformatted == formatted

@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
def test_writer_117(self):
def test_writer_117(self, using_infer_string):
original = DataFrame(
data=[
[
Expand Down Expand Up @@ -1682,13 +1683,17 @@ def test_writer_117(self):
version=117,
)
written_and_read_again = self.read_dta(path)
# original.index is np.int32, read index is np.int64
tm.assert_frame_equal(
written_and_read_again.set_index("index"),
original,
check_index_type=False,
)
tm.assert_frame_equal(original, copy)

expected = original[:]
if using_infer_string:
# object dtype (with only strings/None) comes back as string dtype
expected["object"] = expected["object"].astype("str")

tm.assert_frame_equal(
written_and_read_again.set_index("index"),
expected,
)
tm.assert_frame_equal(original, copy)

def test_convert_strl_name_swap(self):
original = DataFrame(
Expand Down Expand Up @@ -1725,15 +1730,14 @@ def test_invalid_date_conversion(self):
with pytest.raises(ValueError, match=msg):
original.to_stata(path, convert_dates={"wrong_name": "tc"})

@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
@pytest.mark.parametrize("version", [114, 117, 118, 119, None])
def test_nonfile_writing(self, version):
# GH 21041
bio = io.BytesIO()
df = DataFrame(
1.1 * np.arange(120).reshape((30, 4)),
columns=pd.Index(list("ABCD"), dtype=object),
index=pd.Index([f"i-{i}" for i in range(30)], dtype=object),
columns=pd.Index(list("ABCD")),
index=pd.Index([f"i-{i}" for i in range(30)]),
)
df.index.name = "index"
with tm.ensure_clean() as path:
Expand All @@ -1744,13 +1748,12 @@ def test_nonfile_writing(self, version):
reread = read_stata(path, index_col="index")
tm.assert_frame_equal(df, reread)

@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
def test_gzip_writing(self):
# writing version 117 requires seek and cannot be used with gzip
df = DataFrame(
1.1 * np.arange(120).reshape((30, 4)),
columns=pd.Index(list("ABCD"), dtype=object),
index=pd.Index([f"i-{i}" for i in range(30)], dtype=object),
columns=pd.Index(list("ABCD")),
index=pd.Index([f"i-{i}" for i in range(30)]),
)
df.index.name = "index"
with tm.ensure_clean() as path:
Expand All @@ -1777,8 +1780,7 @@ def test_unicode_dta_118(self, datapath):

tm.assert_frame_equal(unicode_df, expected)

@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
def test_mixed_string_strl(self):
def test_mixed_string_strl(self, using_infer_string):
# GH 23633
output = [{"mixed": "string" * 500, "number": 0}, {"mixed": None, "number": 1}]
output = DataFrame(output)
Expand All @@ -1796,7 +1798,10 @@ def test_mixed_string_strl(self):
path, write_index=False, convert_strl=["mixed"], version=117
)
reread = read_stata(path)
expected = output.fillna("")
expected = output.copy()
if using_infer_string:
expected["mixed"] = expected["mixed"].astype("str")
expected = expected.fillna("")
tm.assert_frame_equal(reread, expected)

@pytest.mark.parametrize("version", [114, 117, 118, 119, None])
Expand Down Expand Up @@ -1875,7 +1880,7 @@ def test_stata_119(self, datapath):
reader._ensure_open()
assert reader._nvar == 32999

@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
@pytest.mark.filterwarnings("ignore:Downcasting behavior:FutureWarning")
@pytest.mark.parametrize("version", [118, 119, None])
def test_utf8_writer(self, version):
cat = pd.Categorical(["a", "β", "ĉ"], ordered=True)
Expand Down Expand Up @@ -2143,14 +2148,13 @@ def test_iterator_errors(datapath, chunksize):
pass


@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
def test_iterator_value_labels():
# GH 31544
values = ["c_label", "b_label"] + ["a_label"] * 500
df = DataFrame({f"col{k}": pd.Categorical(values, ordered=True) for k in range(2)})
with tm.ensure_clean() as path:
df.to_stata(path, write_index=False)
expected = pd.Index(["a_label", "b_label", "c_label"], dtype="object")
expected = pd.Index(["a_label", "b_label", "c_label"])
with read_stata(path, chunksize=100) as reader:
for j, chunk in enumerate(reader):
for i in range(2):
Expand Down
Loading