diff --git a/pymc_marketing/mmm/base.py b/pymc_marketing/mmm/base.py index 25d056bc..1a62dae3 100644 --- a/pymc_marketing/mmm/base.py +++ b/pymc_marketing/mmm/base.py @@ -368,21 +368,51 @@ def plot_posterior_predictive( ax: plt.Axes = None, **plt_kwargs: Any, ) -> plt.Figure: - """Plot posterior distribution from the model fit. + """ + Plot the posterior predictive distribution from the model fit. + + This function creates a visualization of the model's posterior predictive distribution, + allowing for comparison with observed data. It can include highest density intervals (HDI), + mean predictions, and a gradient representation of the full distribution. Parameters ---------- original_scale : bool, optional - Whether to plot in the original scale. + If True, plot in the original scale of the target variable. + If False, plot in the transformed scale used for modeling. Default is False. + add_hdi : bool, optional + If True, add highest density intervals to the plot. Default is True. + add_mean : bool, optional + If True, add the mean prediction to the plot. Default is True. + add_gradient : bool, optional + If True, add a gradient representation of the full posterior distribution. Default is False. ax : plt.Axes, optional - Matplotlib axis object. - **plt_kwargs - Keyword arguments passed to `plt.subplots`. + A matplotlib Axes object to plot on. If None, a new figure and axes will be created. + **plt_kwargs : dict + Additional keyword arguments to pass to plt.subplots() when creating a new figure. Returns ------- plt.Figure + The matplotlib Figure object containing the plot. + Raises + ------ + ValueError + If the length of the target variable doesn't match the length + of the date column in the posterior predictive data. + + Notes + ----- + This function visualizes the model's predictions against the observed data. + The observed data is always plotted as a black line. + Depending on the parameters, it can also show: + - HDI (Highest Density Intervals) at 94% and 50% levels + - Mean prediction line + - Gradient representation of the full posterior distribution + + If predicting out-of-sample, ensure that `self.y` is overwritten with the + corresponding non-transformed target variable. """ posterior_predictive_data: Dataset = self._get_posterior_predictive_data( original_scale=original_scale