
michaelis_menten transformation as pt.TensorVariable #1054

Open · wants to merge 4 commits into base: main
Conversation

@PabloRoque (Contributor) commented Sep 20, 2024:

Make michaelis_menten consistent with the rest of the SaturationTransformation classes by wrapping its output as a pt.TensorVariable.

Description

  • Rename the old michaelis_menten to michaelis_menten_function. This avoids larger changes around mmm.utils.estimate_menten_parameters's L70.
  • Wrap the output of michaelis_menten_function as a pt.TensorVariable inside the new michaelis_menten.
  • Minor change in test_michaelis_menten: the TensorVariable now needs to be evaluated.
  • Minor changes to type hints.
  • Add .eval() calls in the example notebook.
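The split described in the bullets above can be sketched roughly as follows (a minimal sketch; the pytensor wrapper is shown as a comment since the exact signatures live in pymc_marketing.mmm.transformers):

```python
def michaelis_menten_function(x, alpha, lam):
    # Michaelis-Menten saturation curve: tends to alpha as x grows,
    # and equals alpha / 2 at x == lam (the half-saturation constant).
    return alpha * x / (lam + x)


# The new michaelis_menten in this PR then wraps the result, roughly:
#
#   def michaelis_menten(x, alpha, lam):
#       return pt.as_tensor_variable(michaelis_menten_function(x, alpha, lam))
```

Keeping the plain-Python version around lets scipy-based utilities such as estimate_menten_parameters keep calling it directly, while model code gets a TensorVariable.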

Related Issue

Checklist

Modules affected

  • MMM
  • CLV

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc-marketing--1054.org.readthedocs.build/en/1054/

@PabloRoque changed the title from "michaelis_menten tranformation as pt.TensorVariable" to "michaelis_menten transformation as pt.TensorVariable" on Sep 20, 2024

@@ -914,6 +914,31 @@ def michaelis_menten(
return alpha * x / (lam + x)


def michaelis_menten(
Contributor:
Do we need two functions?

@wd60622 (Contributor) commented Sep 20, 2024:

Can all of the SaturationTransformations be wrapped in as_tensor_variable instead of making these changes?

i.e., here:

return self.function(x, **kwargs)

Contributor Author:

> Do we need two functions?

We don't need both, but it made my life easier to avoid the pesky L70.

> Can all of the SaturationTransformations be wrapped in as_tensor_variable instead of making these changes?

Do you have anything in mind? I can think of two ways:

  • Adding boilerplate to all the classes. Perhaps defining a decorator to reduce boilerplate.
  • Rely on Transformation._checks(), but then we would need to trust the type hints, using something like signature(class_function).return_annotation, and wrap with pt.as_tensor_variable if needed.

Perhaps you have something less convoluted in mind?
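The "decorator to reduce boilerplate" option might look something like this sketch (returns_tensor is a hypothetical name, not from the codebase; in practice pt.as_tensor_variable would be passed as the wrapper):

```python
from functools import wraps


def returns_tensor(wrap):
    # Hypothetical decorator factory: 'wrap' is applied to whatever the
    # decorated function returns (pt.as_tensor_variable in pymc-marketing).
    def decorator(fn):
        @wraps(fn)
        def inner(*args, **kwargs):
            return wrap(fn(*args, **kwargs))
        return inner
    return decorator


# Illustrative usage with a plain converter standing in for pytensor:
@returns_tensor(tuple)
def scaled(xs, factor=2):
    return [x * factor for x in xs]
```

Each saturation function would then be decorated once instead of repeating the wrapping in every class.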

@wd60622 (Contributor) commented Sep 20, 2024:

Can you not add it in the one location, the apply method (where I linked)? That is a common method that is not overridden by the subclasses.

Contributor:

I think having it in apply would make sure that other custom saturation transformations do not encounter this bug as well.

Contributor:

It shouldn't hurt to have pt.as_tensor_variable wrap an already-existing TensorVariable; pytensor will optimize it away.

Contributor Author:

> Can you not add it in the one location, the apply method (where I linked)? That is a common method that is not overridden by the subclasses.

The linked apply method needs to be called within a model context, which is not the case in the plotting function.

@PabloRoque (Author) commented Sep 20, 2024:

The proposed sample_curve might work though, since it opens a model context. Having a look.

@PabloRoque (Author) commented Sep 20, 2024:

Do we want to add another pm.Deterministic to the model? That would be the case if we use _sample_curve.

Comment on lines +831 to +832
alpha: float,
lam: float,
Contributor:

I still don't know the best type hints here, but TensorVariable is also accepted.

@@ -22,7 +22,7 @@
import xarray as xr
from scipy.optimize import curve_fit, minimize_scalar

-from pymc_marketing.mmm.transformers import michaelis_menten
+from pymc_marketing.mmm.transformers import michaelis_menten_function


def estimate_menten_parameters(
Contributor:

Where is this still used? I thought there was a generalization of this? @cetagostini

Contributor:

Created an issue here: #1055

@wd60622 (Contributor) left a review comment:

Can you write a test for the case that discovered this bug?

@wd60622 (Contributor) commented Sep 20, 2024:

I am looking over the _plot_response_curve_fit internals. I think it would be better to rely on the SaturationTransformation.sample_curve method so that the solution doesn't directly call the function method.

There needs to be a solution that works for other custom saturation transformations.

I will create an issue for this; we can address it in the future: #1056

@PabloRoque (Author) commented:

> Can you write a test for the case that discovered this bug?

Added a saturation fixture; test_mmm_plots now includes MichaelisMentenSaturation. It would fail without the TensorVariable wrapping in the plot_direct_contribution_curves fixture, but passes with the fix in this PR (pending suggested changes).

@PabloRoque (Author) commented:

> SaturationTransformation.sample_curve

This will add 3 extra pm.Deterministic variables to the model. Are we happy with that behavior?

Development

Successfully merging this pull request may close these issues.

MMM.plot_direct_contribution_curves errors when using MichaelisMentenSaturation