Skip to content

Commit 24d31b7

Browse files
BUG: Fix pyarrow categoricals not working for pivot and multiindex (#61193)
* BUG: Fix bug with DataFrame.pivot and .set_index not compatible with pyarrow dictionary categoricals
  Relates to #53051
  Code for fix taken and adapted from #59099
* TST: Add tests for faulty behavior relating to pyarrow categoricals
* CLN: Fix issues reported by pre-commit hooks
* TST: Fix failing tests for minimum version by ignoring obsolete deprecation warning
* DOC: Add entry for bugfix to whatsnew v3.0.0
* CLN: Refactor code and clean up according to PR feedback
* CLN: Refactor code and clean up according to PR feedback
* CLN: Refactor tests to address PR feedback
* CLN: Refactor tests to address PR feedback
1 parent 80795df commit 24d31b7

File tree

4 files changed

+60
-1
lines changed

4 files changed

+60
-1
lines changed

Diff for: doc/source/whatsnew/v3.0.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,7 @@ Bug fixes
639639
Categorical
640640
^^^^^^^^^^^
641641
- Bug in :func:`Series.apply` where ``nan`` was ignored for :class:`CategoricalDtype` (:issue:`59938`)
642+
- Bug in :meth:`DataFrame.pivot` and :meth:`DataFrame.set_index` raising an ``ArrowNotImplementedError`` for columns with pyarrow dictionary dtype (:issue:`53051`)
642643
- Bug in :meth:`Series.convert_dtypes` with ``dtype_backend="pyarrow"`` where empty :class:`CategoricalDtype` :class:`Series` raised an error or got converted to ``null[pyarrow]`` (:issue:`59934`)
643644
-
644645

Diff for: pandas/core/arrays/categorical.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -452,7 +452,7 @@ def __init__(
452452
if isinstance(values, Index):
453453
arr = values._data._pa_array.combine_chunks()
454454
else:
455-
arr = values._pa_array.combine_chunks()
455+
arr = extract_array(values)._pa_array.combine_chunks()
456456
categories = arr.dictionary.to_pandas(types_mapper=ArrowDtype)
457457
codes = arr.indices.to_numpy()
458458
dtype = CategoricalDtype(categories, values.dtype.pyarrow_dtype.ordered)

Diff for: pandas/tests/reshape/test_pivot.py

+29
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import pandas as pd
1717
from pandas import (
18+
ArrowDtype,
1819
Categorical,
1920
DataFrame,
2021
Grouper,
@@ -2851,3 +2852,31 @@ def test_pivot_margins_with_none_index(self):
28512852
),
28522853
)
28532854
tm.assert_frame_equal(result, expected)
2855+
2856+
@pytest.mark.filterwarnings("ignore:Passing a BlockManager:DeprecationWarning")
2857+
def test_pivot_with_pyarrow_categorical(self):
2858+
# GH#53051
2859+
pa = pytest.importorskip("pyarrow")
2860+
2861+
df = DataFrame(
2862+
{"string_column": ["A", "B", "C"], "number_column": [1, 2, 3]}
2863+
).astype(
2864+
{
2865+
"string_column": ArrowDtype(pa.dictionary(pa.int32(), pa.string())),
2866+
"number_column": "float[pyarrow]",
2867+
}
2868+
)
2869+
2870+
df = df.pivot(columns=["string_column"], values=["number_column"])
2871+
2872+
multi_index = MultiIndex.from_arrays(
2873+
[["number_column", "number_column", "number_column"], ["A", "B", "C"]],
2874+
names=(None, "string_column"),
2875+
)
2876+
df_expected = DataFrame(
2877+
[[1.0, np.nan, np.nan], [np.nan, 2.0, np.nan], [np.nan, np.nan, 3.0]],
2878+
columns=multi_index,
2879+
)
2880+
tm.assert_frame_equal(
2881+
df, df_expected, check_dtype=False, check_column_type=False
2882+
)

Diff for: pandas/tests/test_multilevel.py

+29
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import pandas as pd
77
from pandas import (
8+
ArrowDtype,
89
DataFrame,
910
MultiIndex,
1011
Series,
@@ -318,6 +319,34 @@ def test_multiindex_dt_with_nan(self):
318319
expected = Series(["a", "b", "c", "d"], name=("sub", np.nan))
319320
tm.assert_series_equal(result, expected)
320321

322+
@pytest.mark.filterwarnings("ignore:Passing a BlockManager:DeprecationWarning")
323+
def test_multiindex_with_pyarrow_categorical(self):
324+
# GH#53051
325+
pa = pytest.importorskip("pyarrow")
326+
327+
df = DataFrame(
328+
{"string_column": ["A", "B", "C"], "number_column": [1, 2, 3]}
329+
).astype(
330+
{
331+
"string_column": ArrowDtype(pa.dictionary(pa.int32(), pa.string())),
332+
"number_column": "float[pyarrow]",
333+
}
334+
)
335+
336+
df = df.set_index(["string_column", "number_column"])
337+
338+
df_expected = DataFrame(
339+
index=MultiIndex.from_arrays(
340+
[["A", "B", "C"], [1, 2, 3]], names=["string_column", "number_column"]
341+
)
342+
)
343+
tm.assert_frame_equal(
344+
df,
345+
df_expected,
346+
check_index_type=False,
347+
check_column_type=False,
348+
)
349+
321350

322351
class TestSorted:
323352
"""everything you wanted to test about sorting"""

0 commit comments

Comments
 (0)