Skip to content

Commit

Permalink
Add saturation fixture
Browse files Browse the repository at this point in the history
  • Loading branch information
PabloRoque committed Sep 20, 2024
1 parent 0c26c0e commit 245f771
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions tests/mmm/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@
from matplotlib import pyplot as plt

from pymc_marketing.mmm.components.adstock import GeometricAdstock
from pymc_marketing.mmm.components.saturation import LogisticSaturation
from pymc_marketing.mmm.components.saturation import (
LogisticSaturation,
MichaelisMentenSaturation,
)
from pymc_marketing.mmm.mmm import MMM, BaseMMM
from pymc_marketing.mmm.preprocessing import MaxAbsScaleTarget

Expand Down Expand Up @@ -164,10 +167,18 @@ def test_plots(self, plotting_mmm, func_plot_name, kwargs_plot) -> None:
plt.close("all")


@pytest.fixture(
scope="module",
params=[LogisticSaturation(), MichaelisMentenSaturation()],
ids=["LogisticSaturation", "MichaelisMentenSaturation"],
)
def saturation(request):
return request.param


@pytest.fixture(scope="module")
def mock_mmm() -> MMM:
def mock_mmm(saturation) -> MMM:
adstock = GeometricAdstock(l_max=4)
saturation = LogisticSaturation()
return MMM(
date_column="date",
channel_columns=["channel_1", "channel_2"],
Expand Down

0 comments on commit 245f771

Please sign in to comment.