Skip to content

Commit

Permalink
Implement {Series,DataFrame}GroupBy fillna methods (#8869)
Browse files Browse the repository at this point in the history
Co-authored-by: Ian Rose <[email protected]>
  • Loading branch information
pavithraes and ian-r-rose authored May 10, 2022
1 parent 6685666 commit 5fbda77
Show file tree
Hide file tree
Showing 3 changed files with 183 additions and 0 deletions.
73 changes: 73 additions & 0 deletions dask/dataframe/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pandas as pd

from dask.base import tokenize
from dask.dataframe._compat import PANDAS_GT_150
from dask.dataframe.core import (
GROUP_KEYS_DEFAULT,
DataFrame,
Expand Down Expand Up @@ -1047,6 +1048,13 @@ def _cumcount_aggregate(a, b, fill_value=None):
return a.add(b, fill_value=fill_value) + 1


def _fillna_group(group, by, value, method, limit, fillna_axis):
# apply() conserves the grouped-by columns, so drop them to stay consistent with pandas groupby-fillna
return group.drop(columns=by).fillna(
value=value, method=method, limit=limit, axis=fillna_axis
)


class _GroupBy:
"""Superclass for DataFrameGroupBy and SeriesGroupBy
Expand Down Expand Up @@ -2008,6 +2016,71 @@ def rolling(self, window, min_periods=None, center=False, win_type=None, axis=0)
axis=axis,
)

def fillna(self, value=None, method=None, limit=None, axis=None):
"""Fill NA/NaN values using the specified method.
Parameters
----------
value : scalar, default None
Value to use to fill holes (e.g. 0).
method : {'bfill', 'ffill', None}, default None
Method to use for filling holes in reindexed Series. ffill: propagate last
valid observation forward to next valid. bfill: use next valid observation
to fill gap.
axis : {0 or 'index', 1 or 'columns'}
Axis along which to fill missing values.
limit : int, default None
If method is specified, this is the maximum number of consecutive NaN values
to forward/backward fill. In other words, if there is a gap with more than
this number of consecutive NaNs, it will only be partially filled. If method
is not specified, this is the maximum number of entries along the entire
axis where NaNs will be filled. Must be greater than 0 if not None.
Returns
-------
Series or DataFrame
Object with missing values filled
See also
--------
pandas.core.groupby.DataFrameGroupBy.fillna
"""
if not np.isscalar(value) and value is not None:
raise NotImplementedError(
"groupby-fillna with value=dict/Series/DataFrame is currently not supported"
)
meta = self._meta_nonempty.apply(
_fillna_group,
by=self.by,
value=value,
method=method,
limit=limit,
fillna_axis=axis,
)

result = self.apply(
_fillna_group,
by=self.by,
value=value,
method=method,
limit=limit,
fillna_axis=axis,
meta=meta,
)

if PANDAS_GT_150 and self.group_keys:
return result.map_partitions(M.droplevel, self.by)

return result

@derived_from(pd.core.groupby.GroupBy)
def ffill(self, limit=None):
return self.fillna(method="ffill", limit=limit)

@derived_from(pd.core.groupby.GroupBy)
def bfill(self, limit=None):
return self.fillna(method="bfill", limit=limit)


class DataFrameGroupBy(_GroupBy):
_token_prefix = "dataframe-groupby-"
Expand Down
104 changes: 104 additions & 0 deletions dask/dataframe/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -1095,6 +1095,110 @@ def test_aggregate_dask():
assert len(other.dask) == len(result2.dask)


@pytest.mark.parametrize("axis", [0, 1])
@pytest.mark.parametrize("group_keys", [True, False, None])
@pytest.mark.parametrize("method", ["ffill", "bfill"])
@pytest.mark.parametrize("limit", [None, 1, 4])
def test_fillna(axis, group_keys, method, limit):
df = pd.DataFrame(
{
"A": [1, 1, 2, 2],
"B": [3, 4, 3, 4],
"C": [np.nan, 3, np.nan, np.nan],
"D": [4, np.nan, 5, np.nan],
"E": [6, np.nan, 7, np.nan],
}
)
ddf = dd.from_pandas(df, npartitions=2)
assert_eq(
df.groupby("A", group_keys=group_keys).fillna(0, axis=axis),
ddf.groupby("A", group_keys=group_keys).fillna(0, axis=axis),
)
assert_eq(
df.groupby("A", group_keys=group_keys).B.fillna(0),
ddf.groupby("A", group_keys=group_keys).B.fillna(0),
)
assert_eq(
df.groupby(["A", "B"], group_keys=group_keys).fillna(0),
ddf.groupby(["A", "B"], group_keys=group_keys).fillna(0),
)
assert_eq(
df.groupby("A", group_keys=group_keys).fillna(
method=method, limit=limit, axis=axis
),
ddf.groupby("A", group_keys=group_keys).fillna(
method=method, limit=limit, axis=axis
),
)
assert_eq(
df.groupby(["A", "B"], group_keys=group_keys).fillna(
method=method, limit=limit, axis=axis
),
ddf.groupby(["A", "B"], group_keys=group_keys).fillna(
method=method, limit=limit, axis=axis
),
)

with pytest.raises(NotImplementedError):
ddf.groupby("A").fillna({"A": 0})

with pytest.raises(NotImplementedError):
ddf.groupby("A").fillna(pd.Series(dtype=int))

with pytest.raises(NotImplementedError):
ddf.groupby("A").fillna(pd.DataFrame)


def test_ffill():
df = pd.DataFrame(
{
"A": [1, 1, 2, 2],
"B": [3, 4, 3, 4],
"C": [np.nan, 3, np.nan, np.nan],
"D": [4, np.nan, 5, np.nan],
"E": [6, np.nan, 7, np.nan],
}
)
ddf = dd.from_pandas(df, npartitions=2)
assert_eq(
df.groupby("A").ffill(),
ddf.groupby("A").ffill(),
)
assert_eq(
df.groupby("A").B.ffill(),
ddf.groupby("A").B.ffill(),
)
assert_eq(
df.groupby(["A", "B"]).ffill(),
ddf.groupby(["A", "B"]).ffill(),
)


def test_bfill():
df = pd.DataFrame(
{
"A": [1, 1, 2, 2],
"B": [3, 4, 3, 4],
"C": [np.nan, 3, np.nan, np.nan],
"D": [np.nan, 4, np.nan, 5],
"E": [np.nan, 6, np.nan, 7],
}
)
ddf = dd.from_pandas(df, npartitions=2)
assert_eq(
df.groupby("A").bfill(),
ddf.groupby("A").bfill(),
)
assert_eq(
df.groupby("A").B.bfill(),
ddf.groupby("A").B.bfill(),
)
assert_eq(
df.groupby(["A", "B"]).bfill(),
ddf.groupby(["A", "B"]).bfill(),
)


@pytest.mark.parametrize(
"grouper",
[
Expand Down
6 changes: 6 additions & 0 deletions docs/source/dataframe-api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -411,10 +411,13 @@ DataFrame Groupby

DataFrameGroupBy.aggregate
DataFrameGroupBy.apply
DataFrameGroupBy.bfill
DataFrameGroupBy.count
DataFrameGroupBy.cumcount
DataFrameGroupBy.cumprod
DataFrameGroupBy.cumsum
DataFrameGroupBy.fillna
DataFrameGroupBy.ffill
DataFrameGroupBy.get_group
DataFrameGroupBy.max
DataFrameGroupBy.mean
Expand All @@ -440,10 +443,13 @@ Series Groupby

SeriesGroupBy.aggregate
SeriesGroupBy.apply
SeriesGroupBy.bfill
SeriesGroupBy.count
SeriesGroupBy.cumcount
SeriesGroupBy.cumprod
SeriesGroupBy.cumsum
SeriesGroupBy.fillna
SeriesGroupBy.ffill
SeriesGroupBy.get_group
SeriesGroupBy.max
SeriesGroupBy.mean
Expand Down

0 comments on commit 5fbda77

Please sign in to comment.