Skip to content

Commit

Permalink
updated plot_posterior_predictive docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
jsnyde0 committed Sep 20, 2024
1 parent 38bc101 commit 98cfeaf
Showing 1 changed file with 35 additions and 5 deletions.
40 changes: 35 additions & 5 deletions pymc_marketing/mmm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,21 +368,51 @@ def plot_posterior_predictive(
ax: plt.Axes = None,
**plt_kwargs: Any,
) -> plt.Figure:
"""Plot posterior distribution from the model fit.
"""
Plot the posterior predictive distribution from the model fit.
This function creates a visualization of the model's posterior predictive distribution,
allowing for comparison with observed data. It can include highest density intervals (HDI),
mean predictions, and a gradient representation of the full distribution.
Parameters
----------
original_scale : bool, optional
Whether to plot in the original scale.
If True, plot in the original scale of the target variable.
If False, plot in the transformed scale used for modeling. Default is False.
add_hdi : bool, optional
If True, add highest density intervals to the plot. Default is True.
add_mean : bool, optional
If True, add the mean prediction to the plot. Default is True.
add_gradient : bool, optional
If True, add a gradient representation of the full posterior distribution. Default is False.
ax : plt.Axes, optional
Matplotlib axis object.
**plt_kwargs
Keyword arguments passed to `plt.subplots`.
A matplotlib Axes object to plot on. If None, a new figure and axes will be created.
**plt_kwargs : dict
Additional keyword arguments to pass to plt.subplots() when creating a new figure.
Returns
-------
plt.Figure
The matplotlib Figure object containing the plot.
Raises
------
ValueError
If the length of the target variable doesn't match the length
of the date column in the posterior predictive data.
Notes
-----
This function visualizes the model's predictions against the observed data.
The observed data is always plotted as a black line.
Depending on the parameters, it can also show:
- HDI (Highest Density Intervals) at 94% and 50% levels
- Mean prediction line
- Gradient representation of the full posterior distribution
If predicting out-of-sample, ensure that `self.y` is overwritten with the
corresponding non-transformed target variable.
"""
posterior_predictive_data: Dataset = self._get_posterior_predictive_data(
original_scale=original_scale
Expand Down

0 comments on commit 98cfeaf

Please sign in to comment.