michaelis_menten transformation as pt.TensorVariable #1054

Merged Oct 17, 2024 (8 commits). Changes shown from 3 commits.
2 changes: 1 addition & 1 deletion docs/source/notebooks/mmm/mmm_components.ipynb
@@ -135,7 +135,7 @@
"\n",
"alpha = 1\n",
"lam = 1 / 10\n",
"yy = saturation.function(xx, alpha=alpha, lam=lam)\n",
"yy = saturation.function(xx, alpha=alpha, lam=lam).eval()\n",
"\n",
"fig, ax = plt.subplots()\n",
"fig.suptitle(\"Example Saturation Curve\")\n",
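Context for the notebook change: saturation.function now returns a symbolic pt.TensorVariable rather than a NumPy array, so the result must be evaluated before matplotlib can plot it. A minimal sketch of the underlying pattern (pure pytensor, not the notebook's exact imports):

```python
import numpy as np
import pytensor.tensor as pt

xx = np.linspace(0, 5, 100)
alpha, lam = 1, 1 / 10

# Symbolic Michaelis-Menten curve: alpha * x / (lam + x)
yy = pt.as_tensor_variable(alpha * xx / (lam + xx))

yy_values = yy.eval()  # materialize the graph into a NumPy array for plotting
```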
31 changes: 28 additions & 3 deletions pymc_marketing/mmm/transformers.py
@@ -826,10 +826,10 @@ def tanh_saturation_baselined(
     return gain * x0 * pt.tanh(x * pt.arctanh(r) / x0) / r


-def michaelis_menten(
+def michaelis_menten_function(
Contributor: This is only used in the functions we will deprecate, right?

Contributor Author: Yes, they are only used in the methods linked in #1055. If we remove those, there is no need to define the extra function.

Contributor: Mind closing that now?

Contributor Author: Hi @wd60622,

I'm having doubts about completely removing the functions in #1055.

As an example, take estimate_menten_parameters. It was introduced in #329 and never used outside of the current test. I am guessing the original purpose was to double-check that the introduced function gives correct results. If we remove the tests, we would never be checking that again.

As an alternative, I've simplified the PR to wrap only MichaelisMentenSaturation.function with pt.as_tensor_variable.

Note that the newly introduced fixture now passes the tests.

Contributor: If we need to remove tests, that is fine.

This is the only saturation transformation that needs to be defined with two functions (which I find silly). Just consolidate.
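A minimal sketch of the consolidated single function the reviewer suggests, assuming the pt.TensorLike hints already used by the neighboring hill_function; illustrative only, not the code merged in this PR:

```python
import pytensor.tensor as pt


def michaelis_menten(
    x: pt.TensorLike, alpha: pt.TensorLike, lam: pt.TensorLike
) -> pt.TensorVariable:
    """Michaelis-Menten saturation, alpha * x / (lam + x), as one symbolic function."""
    # as_tensor_variable coerces eager (float/ndarray) inputs and is a no-op
    # when the expression is already symbolic.
    return pt.as_tensor_variable(alpha * x / (lam + x))
```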

     x: float | np.ndarray | npt.NDArray[np.float64],
-    alpha: float | np.ndarray | npt.NDArray[np.float64],
-    lam: float | np.ndarray | npt.NDArray[np.float64],
+    alpha: float,
+    lam: float,
Contributor: I still don't know the best type hints here. But TensorVariable is also accepted here.

Contributor Author: Reverted changes to type hints.

 ) -> float | Any:
     r"""Evaluate the Michaelis-Menten function for given values of x, alpha, and lambda.

@@ -914,6 +914,31 @@
     return alpha * x / (lam + x)


+def michaelis_menten(
Contributor: Do we need two functions?

Contributor (wd60622), Sep 20, 2024: Can all of the SaturationTransformations be wrapped in as_tensor_variable instead of making these changes? i.e., here:

return self.function(x, **kwargs)

Contributor Author:

> Do we need two functions?

We don't need them, but it made my life easier to avoid the pesky L70.

> Can all of the SaturationTransformations be wrapped in as_tensor_variable instead of making these changes?

Do you have anything in mind? I can think of two ways:

  • Adding boilerplate to all the classes, perhaps defining a decorator to reduce the boilerplate.
  • Relying on Transformation._checks(), but there we would need to trust the type hints, using something like signature(class_function).return_annotation, and wrapping with pt.as_tensor_variable if needed.

Perhaps you have something less convoluted in mind?
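A minimal sketch of the decorator idea from the first bullet above; the name as_tensor_output is hypothetical and not part of the codebase:

```python
from functools import wraps

import pytensor.tensor as pt


def as_tensor_output(func):
    """Hypothetical decorator: coerce a transformation's return value to a TensorVariable."""

    @wraps(func)
    def wrapper(*args, **kwargs):
        return pt.as_tensor_variable(func(*args, **kwargs))

    return wrapper


# Usage sketch: each saturation function would gain one decorator line.
@as_tensor_output
def michaelis_menten_function(x, alpha, lam):
    return alpha * x / (lam + x)
```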

Contributor (wd60622), Sep 20, 2024: Can you not add it in the one location, the apply method (where I linked)? That is a common method that is not overridden by the subclasses.

Contributor: I think having it in apply would make sure that other custom saturation transformations do not run into the bug found here as well.

Contributor: It shouldn't hurt to have pt.as_tensor_variable wrap something that is already a TensorVariable; pytensor will optimize it away.
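A quick illustration of that point, assuming current pytensor behavior:

```python
import numpy as np
import pytensor.tensor as pt

x = pt.vector("x")
# Wrapping an existing TensorVariable returns it unchanged (no copy, no extra op).
assert pt.as_tensor_variable(x) is x

# Wrapping a NumPy array produces a TensorConstant, a TensorVariable subclass.
c = pt.as_tensor_variable(np.array([1.0, 2.0]))
print(type(c))
```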

Contributor Author:

> Can you not add it in the one location, the apply method (where I linked)?

The linked apply method needs to be called within a model context, which is not the case in the plotting function.

Contributor Author (PabloRoque), Sep 20, 2024: The proposed sample_curve might work though, since it opens a model context. Having a look.

Contributor Author (PabloRoque), Sep 20, 2024: Do we want to add another pm.Deterministic to the model? That would be the case if we used _sample_curve.

Contributor: We wouldn't want to use the private methods of the saturation transformation. Does sample_curve with fit_result not get us close to the data that is plotted (just not scaled on x and y)?

+    x: float | np.ndarray | npt.NDArray[np.float64],
+    alpha: float,
+    lam: float,
+) -> pt.TensorVariable:
+    r"""TensorVariable wrapper around the Michaelis-Menten transformation.
+
+    Parameters
+    ----------
+    x : float
+        The spend on a channel.
+    alpha : float
+        The maximum contribution a channel can make.
+    lam : float
+        The Michaelis constant for the given enzyme-substrate system.
+
+    Returns
+    -------
+    pt.TensorVariable
+        The value of the Michaelis-Menten function given the parameters, as a TensorVariable.
+
+    """
+    return pt.as_tensor_variable(michaelis_menten_function(x, alpha, lam))


 def hill_function(
     x: pt.TensorLike, slope: pt.TensorLike, kappa: pt.TensorLike
 ) -> pt.TensorVariable:
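A quick usage check of the new wrapper, based on the behavior this diff introduces (values illustrative):

```python
import numpy as np

from pymc_marketing.mmm.transformers import michaelis_menten

# Returns a symbolic TensorVariable; call .eval() for concrete values.
curve = michaelis_menten(x=np.linspace(0, 10, 5), alpha=2.0, lam=0.5)
print(curve.eval())
```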
6 changes: 4 additions & 2 deletions pymc_marketing/mmm/utils.py
@@ -22,7 +22,7 @@
 import xarray as xr
 from scipy.optimize import curve_fit, minimize_scalar

-from pymc_marketing.mmm.transformers import michaelis_menten
+from pymc_marketing.mmm.transformers import michaelis_menten_function


 def estimate_menten_parameters(
Contributor: Where is this still used? I thought there was a generalization of this? @cetagostini

Contributor: Created an issue here: #1055

@@ -67,7 +67,9 @@ def estimate_menten_parameters(
     # Initial guess for L and k
     initial_guess = [alpha_initial_estimate, lam_initial_estimate]
     # Curve fitting
-    popt, _ = curve_fit(michaelis_menten, x, y, p0=initial_guess, maxfev=maxfev)
+    popt, _ = curve_fit(
+        michaelis_menten_function, x, y, p0=initial_guess, maxfev=maxfev
+    )

     # Save the parameters
     return popt
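For context, scipy.optimize.curve_fit needs an eager callable that returns NumPy arrays, which is why the non-symbolic michaelis_menten_function is the one passed in. A self-contained sketch with synthetic data (values illustrative):

```python
import numpy as np
from scipy.optimize import curve_fit


def michaelis_menten_function(x, alpha, lam):
    # Eager NumPy implementation: alpha * x / (lam + x)
    return alpha * x / (lam + x)


rng = np.random.default_rng(0)
x = np.linspace(0.1, 10, 50)
y = michaelis_menten_function(x, alpha=3.0, lam=2.0) + rng.normal(0, 0.05, x.size)

popt, _ = curve_fit(michaelis_menten_function, x, y, p0=[y.max(), 1.0])
print(popt)  # recovered (alpha, lam), approximately (3.0, 2.0)
```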
2 changes: 1 addition & 1 deletion tests/mmm/test_transformers.py
@@ -475,7 +475,7 @@ def test_tanh_saturation_parameterization_transformation(self, x, b, c):
     ],
 )
 def test_michaelis_menten(self, x, alpha, lam, expected):
-    assert np.isclose(michaelis_menten(x, alpha, lam), expected, atol=0.01)
+    assert np.isclose(michaelis_menten(x, alpha, lam).eval(), expected, atol=0.01)

 @pytest.mark.parametrize(
     "sigma, beta, lam",