Skip to content

Commit

Permalink
Merge branch 'branch-25.04' into cln/dtype/astype
Browse files Browse the repository at this point in the history
  • Loading branch information
mroeschke authored Feb 27, 2025
2 parents d58363f + aa7f436 commit f8b2af0
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 11 deletions.
20 changes: 9 additions & 11 deletions python/cudf/cudf/core/reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -1526,9 +1526,9 @@ def pivot_table(
----------
data : DataFrame
values : column name or list of column names to aggregate, optional
index : list of column names
index : scalar or list of column names
Values to group by in the rows.
columns : list of column names
columns : scalar or list of column names
Values to group by in the columns.
aggfunc : str or dict, default "mean"
If dict is passed, the key is column to aggregate
Expand Down Expand Up @@ -1562,6 +1562,11 @@ def pivot_table(
if sort is not True:
raise NotImplementedError("sort is not supported yet")

if is_scalar(index):
index = [index]
if is_scalar(columns):
columns = [columns]

keys = index + columns

values_passed = values is not None
Expand Down Expand Up @@ -1620,15 +1625,8 @@ def pivot_table(
table = table.fillna(fill_value)

# discard the top level
if values_passed and not values_multi and table._data.multiindex:
column_names = table._data.level_names[1:]
table_columns = tuple(
map(lambda column: column[1:], table._column_names)
)
table.columns = pd.MultiIndex.from_tuples(
tuples=table_columns, names=column_names
)

if values_passed and not values_multi and table._data.nlevels > 1:
table.columns = table._data.to_pandas_index.droplevel(0)
if len(index) == 0 and len(columns) > 0:
table = table.T

Expand Down
19 changes: 19 additions & 0 deletions python/cudf/cudf/tests/test_reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,25 @@ def test_dataframe_pivot_table_simple(aggfunc, fill_value):
assert_eq(expected, actual, check_dtype=False)


@pytest.mark.parametrize("index", ["A", ["A"]])
@pytest.mark.parametrize("columns", ["C", ["C"]])
def test_pivot_table_scalar_index_columns(index, columns):
    """Compare cudf and pandas ``pivot_table`` for scalar and list keys.

    ``index`` and ``columns`` are each supplied both as a bare column name
    and as a one-element list; the cudf result must equal the pandas
    result for every combination.
    """
    raw = {
        "A": ["one", "one", "two", "three"] * 6,
        "B": ["A", "B", "C"] * 8,
        "C": ["foo", "foo", "foo", "bar", "bar", "bar"] * 4,
        "D": range(24),
        "E": range(24),
    }
    # Both libraries receive exactly the same keyword arguments.
    pivot_kwargs = {
        "values": "D",
        "index": index,
        "columns": columns,
        "aggfunc": "sum",
    }

    gdf = cudf.DataFrame(raw)
    result = gdf.pivot_table(**pivot_kwargs)

    pdf = pd.DataFrame(raw)
    expected = pdf.pivot_table(**pivot_kwargs)

    assert_eq(result, expected)


def test_crosstab_simple():
a = np.array(
[
Expand Down

0 comments on commit f8b2af0

Please sign in to comment.