From 78db043c460de2d675584b844ca2402a564132bf Mon Sep 17 00:00:00 2001 From: Anant Pandey <67699077+anantp316@users.noreply.github.com> Date: Fri, 22 Mar 2024 17:46:39 +0530 Subject: [PATCH 1/2] Added functionality that allows user to save figure generated by the `plot_media_baseline_contribution_area_plot` function. --- lightweight_mmm/plot.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/lightweight_mmm/plot.py b/lightweight_mmm/plot.py index cd1c74f..e330c29 100644 --- a/lightweight_mmm/plot.py +++ b/lightweight_mmm/plot.py @@ -1014,6 +1014,7 @@ def plot_media_baseline_contribution_area_plot( channel_names: Optional[Sequence[Any]] = None, fig_size: Optional[Tuple[int, int]] = (20, 7), legend_outside: Optional[bool] = False, + save_path: Optional[str] = None ) -> matplotlib.figure.Figure: """Plots an area chart to visualize weekly media & baseline contribution. @@ -1023,6 +1024,7 @@ def plot_media_baseline_contribution_area_plot( channel_names: Names of media channels. fig_size: Size of the figure to plot as used by matplotlib. legend_outside: Put the legend outside of the chart, center-right. + save_path: Path to save the plotted figure. Returns: Stacked area chart of weekly baseline & media contribution. @@ -1072,6 +1074,11 @@ def plot_media_baseline_contribution_area_plot( for tick in ax.get_xticklabels(): tick.set_rotation(45) + + # Save the plot if save_path is provided + if save_path: + fig.savefig(save_path, bbox_inches="tight") + plt.close() return fig From 4c3a8598298dd151d3200c1c5e365a46bedd418b Mon Sep 17 00:00:00 2001 From: Anant Pandey <67699077+anantp316@users.noreply.github.com> Date: Fri, 22 Mar 2024 19:32:30 +0530 Subject: [PATCH 2/2] Added functionality that allows user to save figure generated by the `plot_media_baseline_contribution_area_plot` and `plot_pre_post_budget_allocation_comparison` functions. --- lightweight_mmm/plot.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/lightweight_mmm/plot.py b/lightweight_mmm/plot.py index e330c29..61b2b55 100644 --- a/lightweight_mmm/plot.py +++ b/lightweight_mmm/plot.py @@ -890,7 +890,8 @@ def plot_pre_post_budget_allocation_comparison( optimal_buget_allocation: jnp.ndarray, previous_budget_allocation: jnp.ndarray, channel_names: Optional[Sequence[Any]] = None, - figure_size: Tuple[int, int] = (20, 10) + figure_size: Tuple[int, int] = (20, 10), + save_path: Optional[str] = None ) -> matplotlib.figure.Figure: """Plots a barcharts to compare pre & post budget allocation. @@ -905,6 +906,7 @@ def plot_pre_post_budget_allocation_comparison( budget allocation proportion. channel_names: Names of media channels to be added to plot. figure_size: size of the plot. + save_path: Path to save the plotted figure. Returns: Barplots of budget allocation across media channels pre & post optimization. @@ -1004,6 +1006,11 @@ def plot_pre_post_budget_allocation_comparison( textcoords="offset points") plt.tight_layout() + + # Save the plot if save_path is provided + if save_path: + fig.savefig(save_path, bbox_inches="tight") + plt.close() return fig