Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

michaelis_menten transformation as pt.TensorVariable #1054

Merged
merged 8 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/notebooks/mmm/mmm_components.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@
"\n",
"alpha = 1\n",
"lam = 1 / 10\n",
"yy = saturation.function(xx, alpha=alpha, lam=lam)\n",
"yy = saturation.function(xx, alpha=alpha, lam=lam).eval()\n",
"\n",
"fig, ax = plt.subplots()\n",
"fig.suptitle(\"Example Saturation Curve\")\n",
Expand Down
5 changes: 4 additions & 1 deletion pymc_marketing/mmm/components/saturation.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def function(self, x, b):
"""

import numpy as np
import pytensor.tensor as pt
import xarray as xr
from pydantic import Field, InstanceOf, validate_call

Expand Down Expand Up @@ -337,7 +338,9 @@ class MichaelisMentenSaturation(SaturationTransformation):

lookup_name = "michaelis_menten"

function = michaelis_menten
def function(self, x, alpha, lam):
"""Michaelis-Menten saturation function."""
return pt.as_tensor_variable(michaelis_menten(x, alpha, lam))

default_priors = {
"alpha": Prior("Gamma", mu=2, sigma=1),
Expand Down
17 changes: 14 additions & 3 deletions tests/mmm/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@
from matplotlib import pyplot as plt

from pymc_marketing.mmm.components.adstock import GeometricAdstock
from pymc_marketing.mmm.components.saturation import LogisticSaturation
from pymc_marketing.mmm.components.saturation import (
LogisticSaturation,
MichaelisMentenSaturation,
)
from pymc_marketing.mmm.mmm import MMM, BaseMMM
from pymc_marketing.mmm.preprocessing import MaxAbsScaleTarget

Expand Down Expand Up @@ -220,10 +223,18 @@ def test_plots(self, plotting_mmm, func_plot_name, kwargs_plot) -> None:
plt.close("all")


@pytest.fixture(
scope="module",
params=[LogisticSaturation(), MichaelisMentenSaturation()],
ids=["LogisticSaturation", "MichaelisMentenSaturation"],
)
def saturation(request):
return request.param


@pytest.fixture(scope="module")
def mock_mmm() -> MMM:
def mock_mmm(saturation) -> MMM:
adstock = GeometricAdstock(l_max=4)
saturation = LogisticSaturation()
return MMM(
date_column="date",
channel_columns=["channel_1", "channel_2"],
Expand Down
Loading