Skip to content

Commit

Permalink
TYP: plotting, make weights kwd explicit (pandas-dev#55877)
Browse files Browse the repository at this point in the history
* TYP: plotting

* TYP: plotting

* Make weights kwd explicit
  • Loading branch information
jbrockmendel authored Nov 8, 2023
1 parent 7ae6b8e commit a167f13
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 26 deletions.
2 changes: 1 addition & 1 deletion pandas/plotting/_matplotlib/boxplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def __init__(self, data, return_type: str = "axes", **kwargs) -> None:
# error: Signature of "_plot" incompatible with supertype "MPLPlot"
@classmethod
def _plot( # type: ignore[override]
cls, ax, y, column_num=None, return_type: str = "axes", **kwds
cls, ax: Axes, y, column_num=None, return_type: str = "axes", **kwds
):
if y.ndim == 2:
y = [remove_na_arraylike(v) for v in y]
Expand Down
5 changes: 4 additions & 1 deletion pandas/plotting/_matplotlib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
)
from typing import (
TYPE_CHECKING,
Any,
Literal,
final,
)
Expand Down Expand Up @@ -998,7 +999,9 @@ def on_right(self, i: int):
return self.data.columns[i] in self.secondary_y

@final
def _apply_style_colors(self, colors, kwds, col_num, label: str):
def _apply_style_colors(
self, colors, kwds: dict[str, Any], col_num: int, label: str
):
"""
Manage style and color based on column number and its label.
Returns tuple of appropriate style and kwds which "color" may be added.
Expand Down
59 changes: 35 additions & 24 deletions pandas/plotting/_matplotlib/hist.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

from typing import (
TYPE_CHECKING,
Any,
Literal,
final,
)

import numpy as np
Expand Down Expand Up @@ -58,13 +60,15 @@ def __init__(
bottom: int | np.ndarray = 0,
*,
range=None,
weights=None,
**kwargs,
) -> None:
if is_list_like(bottom):
bottom = np.array(bottom)
self.bottom = bottom

self._bin_range = range
self.weights = weights

self.xlabel = kwargs.get("xlabel")
self.ylabel = kwargs.get("ylabel")
Expand Down Expand Up @@ -96,7 +100,7 @@ def _calculate_bins(self, data: DataFrame, bins) -> np.ndarray:
@classmethod
def _plot( # type: ignore[override]
cls,
ax,
ax: Axes,
y,
style=None,
bottom: int | np.ndarray = 0,
Expand Down Expand Up @@ -140,7 +144,7 @@ def _make_plot(self, fig: Figure) -> None:
if style is not None:
kwds["style"] = style

kwds = self._make_plot_keywords(kwds, y)
self._make_plot_keywords(kwds, y)

# the bins is multi-dimension array now and each plot need only 1-d and
# when by is applied, label should be columns that are grouped
Expand All @@ -149,21 +153,8 @@ def _make_plot(self, fig: Figure) -> None:
kwds["label"] = self.columns
kwds.pop("color")

# We allow weights to be a multi-dimensional array, e.g. a (10, 2) array,
# and each sub-array (10,) will be called in each iteration. If users only
# provide 1D array, we assume the same weights is used for all iterations
weights = kwds.get("weights", None)
if weights is not None:
if np.ndim(weights) != 1 and np.shape(weights)[-1] != 1:
try:
weights = weights[:, i]
except IndexError as err:
raise ValueError(
"weights must have the same shape as data, "
"or be a single column"
) from err
weights = weights[~isna(y)]
kwds["weights"] = weights
if self.weights is not None:
kwds["weights"] = self._get_column_weights(self.weights, i, y)

y = reformat_hist_y_given_by(y, self.by)

Expand All @@ -175,12 +166,29 @@ def _make_plot(self, fig: Figure) -> None:

self._append_legend_handles_labels(artists[0], label)

def _make_plot_keywords(self, kwds, y):
def _make_plot_keywords(self, kwds: dict[str, Any], y) -> None:
"""merge BoxPlot/KdePlot properties to passed kwds"""
# y is required for KdePlot
kwds["bottom"] = self.bottom
kwds["bins"] = self.bins
return kwds

@final
@staticmethod
def _get_column_weights(weights, i: int, y):
# We allow weights to be a multi-dimensional array, e.g. a (10, 2) array,
# and each sub-array (10,) will be called in each iteration. If users only
# provide 1D array, we assume the same weights is used for all iterations
if weights is not None:
if np.ndim(weights) != 1 and np.shape(weights)[-1] != 1:
try:
weights = weights[:, i]
except IndexError as err:
raise ValueError(
"weights must have the same shape as data, "
"or be a single column"
) from err
weights = weights[~isna(y)]
return weights

def _post_plot_logic(self, ax: Axes, data) -> None:
if self.orientation == "horizontal":
Expand All @@ -207,11 +215,14 @@ def _kind(self) -> Literal["kde"]:
def orientation(self) -> Literal["vertical"]:
return "vertical"

def __init__(self, data, bw_method=None, ind=None, **kwargs) -> None:
def __init__(
self, data, bw_method=None, ind=None, *, weights=None, **kwargs
) -> None:
# Do not call LinePlot.__init__ which may fill nan
MPLPlot.__init__(self, data, **kwargs) # pylint: disable=non-parent-init-called
self.bw_method = bw_method
self.ind = ind
self.weights = weights

@staticmethod
def _get_ind(y, ind):
Expand All @@ -233,9 +244,10 @@ def _get_ind(y, ind):
return ind

@classmethod
def _plot(
# error: Signature of "_plot" incompatible with supertype "MPLPlot"
def _plot( # type: ignore[override]
cls,
ax,
ax: Axes,
y,
style=None,
bw_method=None,
Expand All @@ -253,10 +265,9 @@ def _plot(
lines = MPLPlot._plot(ax, ind, y, style=style, **kwds)
return lines

def _make_plot_keywords(self, kwds, y):
def _make_plot_keywords(self, kwds: dict[str, Any], y) -> None:
kwds["bw_method"] = self.bw_method
kwds["ind"] = self._get_ind(y, ind=self.ind)
return kwds

def _post_plot_logic(self, ax, data) -> None:
ax.set_ylabel("Density")
Expand Down

0 comments on commit a167f13

Please sign in to comment.