From aa7f436bdc22fb5b25903252c437e32fbc8b33c0 Mon Sep 17 00:00:00 2001
From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com>
Date: Wed, 26 Feb 2025 18:55:25 -0800
Subject: [PATCH] Allow pivot_table to accept single label index and column arguments (#18115)

closes https://github.com/rapidsai/cudf/issues/12410
closes https://github.com/rapidsai/cudf/issues/12409

The fix just mirrors the pandas logic.

Authors:
  - Matthew Roeschke (https://github.com/mroeschke)

Approvers:
  - Vyas Ramasubramani (https://github.com/vyasr)

URL: https://github.com/rapidsai/cudf/pull/18115
---
 python/cudf/cudf/core/reshape.py       | 20 +++++++++-----------
 python/cudf/cudf/tests/test_reshape.py | 19 +++++++++++++++++++
 2 files changed, 28 insertions(+), 11 deletions(-)

diff --git a/python/cudf/cudf/core/reshape.py b/python/cudf/cudf/core/reshape.py
index c5d2fd349e9..7d76907916f 100644
--- a/python/cudf/cudf/core/reshape.py
+++ b/python/cudf/cudf/core/reshape.py
@@ -1526,9 +1526,9 @@ def pivot_table(
     ----------
     data : DataFrame
     values : column name or list of column names to aggregate, optional
-    index : list of column names
+    index : scalar or list of column names
         Values to group by in the rows.
-    columns : list of column names
+    columns : scalar or list of column names
         Values to group by in the columns.
     aggfunc : str or dict, default "mean"
         If dict is passed, the key is column to aggregate
@@ -1562,6 +1562,11 @@ def pivot_table(
     if sort is not True:
         raise NotImplementedError("sort is not supported yet")
 
+    if is_scalar(index):
+        index = [index]
+    if is_scalar(columns):
+        columns = [columns]
+
     keys = index + columns
 
     values_passed = values is not None
@@ -1620,15 +1625,8 @@ def pivot_table(
         table = table.fillna(fill_value)
 
     # discard the top level
-    if values_passed and not values_multi and table._data.multiindex:
-        column_names = table._data.level_names[1:]
-        table_columns = tuple(
-            map(lambda column: column[1:], table._column_names)
-        )
-        table.columns = pd.MultiIndex.from_tuples(
-            tuples=table_columns, names=column_names
-        )
-
+    if values_passed and not values_multi and table._data.nlevels > 1:
+        table.columns = table._data.to_pandas_index.droplevel(0)
     if len(index) == 0 and len(columns) > 0:
         table = table.T
 
diff --git a/python/cudf/cudf/tests/test_reshape.py b/python/cudf/cudf/tests/test_reshape.py
index 7fbe072dde7..eae73e47955 100644
--- a/python/cudf/cudf/tests/test_reshape.py
+++ b/python/cudf/cudf/tests/test_reshape.py
@@ -798,6 +798,25 @@ def test_dataframe_pivot_table_simple(aggfunc, fill_value):
     assert_eq(expected, actual, check_dtype=False)
 
 
+@pytest.mark.parametrize("index", ["A", ["A"]])
+@pytest.mark.parametrize("columns", ["C", ["C"]])
+def test_pivot_table_scalar_index_columns(index, columns):
+    data = {
+        "A": ["one", "one", "two", "three"] * 6,
+        "B": ["A", "B", "C"] * 8,
+        "C": ["foo", "foo", "foo", "bar", "bar", "bar"] * 4,
+        "D": range(24),
+        "E": range(24),
+    }
+    result = cudf.DataFrame(data).pivot_table(
+        values="D", index=index, columns=columns, aggfunc="sum"
+    )
+    expected = pd.DataFrame(data).pivot_table(
+        values="D", index=index, columns=columns, aggfunc="sum"
+    )
+    assert_eq(result, expected)
+
+
 def test_crosstab_simple():
     a = np.array(
         [