Skip to content

Make fit_pathfinder more similar to fit_laplace and pm.sample #447

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/api_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ Inference
.. autosummary::
:toctree: generated/

find_MAP
fit
fit_laplace
fit_pathfinder


Distributions
Expand Down
4 changes: 1 addition & 3 deletions pymc_extras/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@

from pymc_extras import gp, statespace, utils
from pymc_extras.distributions import *
from pymc_extras.inference.find_map import find_MAP
from pymc_extras.inference.fit import fit
from pymc_extras.inference.laplace import fit_laplace
from pymc_extras.inference import find_MAP, fit, fit_laplace, fit_pathfinder
from pymc_extras.model.marginal.marginal_model import (
MarginalModel,
marginalize,
Expand Down
6 changes: 4 additions & 2 deletions pymc_extras/inference/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.


from pymc_extras.inference.find_map import find_MAP
from pymc_extras.inference.fit import fit
from pymc_extras.inference.laplace import fit_laplace
from pymc_extras.inference.pathfinder.pathfinder import fit_pathfinder

__all__ = ["fit"]
__all__ = ["fit", "fit_pathfinder", "fit_laplace", "find_MAP"]
7 changes: 4 additions & 3 deletions pymc_extras/inference/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,20 @@

def fit(method: str, **kwargs) -> az.InferenceData:
"""
Fit a model with an inference algorithm
Fit a model with an inference algorithm.
See :func:`fit_pathfinder` and :func:`fit_laplace` for more details.

Parameters
----------
method : str
Which inference method to run.
Supported: pathfinder or laplace

kwargs are passed on.
kwargs: keyword arguments are passed on to the inference method.

Returns
-------
arviz.InferenceData
:class:`~arviz.InferenceData`
"""
if method == "pathfinder":
from pymc_extras.inference.pathfinder import fit_pathfinder
Expand Down
2 changes: 1 addition & 1 deletion pymc_extras/inference/laplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,7 @@ def fit_laplace(

Returns
-------
idata: az.InferenceData
:class:`~arviz.InferenceData`
An InferenceData object containing the approximated posterior samples.

Examples
Expand Down
40 changes: 31 additions & 9 deletions pymc_extras/inference/pathfinder/pathfinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import collections
import logging
import time
Expand All @@ -24,7 +25,6 @@

import arviz as az
import filelock
import jax
import numpy as np
import pymc as pm
import pytensor
Expand All @@ -43,7 +43,6 @@
find_rng_nodes,
reseed_rngs,
)
from pymc.sampling.jax import get_jaxified_graph
from pymc.util import (
CustomProgress,
RandomSeed,
Expand All @@ -64,6 +63,7 @@
# TODO: change to typing.Self after Python versions greater than 3.10
from typing_extensions import Self

from pymc_extras.inference.laplace import add_data_to_inferencedata
from pymc_extras.inference.pathfinder.importance_sampling import (
importance_sampling as _importance_sampling,
)
Expand Down Expand Up @@ -99,6 +99,8 @@ def get_jaxified_logp_of_ravel_inputs(model: Model, jacobian: bool = True) -> Ca
A JAX function that computes the log-probability of a PyMC model with ravelled inputs.
"""

from pymc.sampling.jax import get_jaxified_graph

# TODO: JAX: test if we should get jaxified graph of dlogp as well
new_logprob, new_input = pm.pytensorf.join_nonshared_inputs(
model.initial_point(), (model.logp(jacobian=jacobian),), model.value_vars, ()
Expand Down Expand Up @@ -218,6 +220,10 @@ def convert_flat_trace_to_idata(
result = [res.reshape(num_paths, num_pdraws, *res.shape[2:]) for res in result]

elif inference_backend == "blackjax":
import jax

from pymc.sampling.jax import get_jaxified_graph

jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
result = jax.vmap(jax.vmap(jax_fn))(
*jax.device_put(list(trace.values()), jax.devices(postprocessing_backend)[0])
Expand Down Expand Up @@ -1627,6 +1633,7 @@ def fit_pathfinder(
inference_backend: Literal["pymc", "blackjax"] = "pymc",
pathfinder_kwargs: dict = {},
compile_kwargs: dict = {},
initvals: dict | None = None,
) -> az.InferenceData:
"""
Fit the Pathfinder Variational Inference algorithm.
Expand Down Expand Up @@ -1662,12 +1669,12 @@ def fit_pathfinder(
importance_sampling : str, None, optional
Method to apply sampling based on log importance weights (logP - logQ).
Options are:
"psis" : Pareto Smoothed Importance Sampling (default)
Recommended for more stable results.
"psir" : Pareto Smoothed Importance Resampling
Less stable than PSIS.
"identity" : Applies log importance weights directly without resampling.
None : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is number of model parameters. Other methods return samples of size (num_draws, N).

- "psis" : Pareto Smoothed Importance Sampling (default). Usually most stable.
- "psir" : Pareto Smoothed Importance Resampling. Less stable than PSIS.
- "identity" : Applies log importance weights directly without resampling.
- None : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is number of model parameters. Other methods return samples of size (num_draws, N).

progressbar : bool, optional
Whether to display a progress bar (default is True). Setting this to False will likely reduce the computation time.
random_seed : RandomSeed, optional
Expand All @@ -1682,10 +1689,13 @@ def fit_pathfinder(
Additional keyword arguments for the Pathfinder algorithm.
compile_kwargs
Additional keyword arguments for the PyTensor compiler. If not provided, the default linker is "cvm_nogc".
initvals: dict | None = None
Initial values for the model parameters, given as a dict mapping variable names (str) to arrays (ndarray). Partial initialization is permitted: variables not listed keep their default initial values.
If None, the model's default initial values are used.

Returns
-------
arviz.InferenceData
:class:`~arviz.InferenceData`
The inference data containing the results of the Pathfinder algorithm.

References
Expand All @@ -1695,6 +1705,14 @@ def fit_pathfinder(

model = modelcontext(model)

if initvals is not None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't mutate the model, make a copy perhaps if there's no better way to just forward the initvals

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would normally agree. However, I tried it, but model.copy() does not produce a working model sometimes - most notably when any transformations are used.

Should I use some other copy function?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for that @zaxtax
Updated the code. Also added a test to check if initvals are used correctly.

model = pm.model.fgraph.clone_model(model) # Create a clone of the model
for (
rv_name,
ivals,
) in initvals.items(): # Set the initial values for the variables in the clone
model.set_initval(model.named_vars[rv_name], ivals)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might need to ensure that ivals is a support point for the RV. For example, x ~ Uniform(-1, 1) would have nan initial values with model.set_initval(model.named_vars["x"], 2)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While in the ideal world, I would agree, in practice
a) It is very nontrivial to do as I understand, as the limits are not specified anywhere where they are easy to take
b) pm.sample does no such checks, and the goal of this PR is to be compatible with that

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, seems fair enough. Thanks for this submission @velochy


valid_importance_sampling = {"psis", "psir", "identity", None}

if importance_sampling is not None:
Expand Down Expand Up @@ -1734,6 +1752,7 @@ def fit_pathfinder(
pathfinder_samples = mp_result.samples
elif inference_backend == "blackjax":
import blackjax
import jax

if version.parse(blackjax.__version__).major < 1:
raise ImportError("fit_pathfinder requires blackjax 1.0 or above")
Expand Down Expand Up @@ -1772,4 +1791,7 @@ def fit_pathfinder(
model=model,
importance_sampling=importance_sampling,
)

idata = add_data_to_inferencedata(idata, progressbar, model, compile_kwargs)

return idata
12 changes: 12 additions & 0 deletions tests/test_pathfinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,3 +200,15 @@ def test_pathfinder_importance_sampling(importance_sampling):
assert idata.posterior["mu"].shape == (1, num_draws)
assert idata.posterior["tau"].shape == (1, num_draws)
assert idata.posterior["theta"].shape == (1, num_draws, 8)


def test_pathfinder_initvals():
    """Check that ``initvals`` passed to ``fit_pathfinder`` are actually applied.

    The model uses an ordered transform; the default initial point (all zeros)
    is invalid under that constraint, so the fit only succeeds if the supplied
    strictly increasing initial values are used.
    """
    # NOTE: the context-manager binding was unused, so no name is kept.
    with pm.Model():
        pm.Normal("ordered", size=10, transform=pm.distributions.transforms.ordered)
        idata = pmx.fit_pathfinder(initvals={"ordered": np.linspace(0, 1, 10)})

    # Posterior draws must be strictly increasing along the last axis,
    # confirming the ordered transform (and hence the initvals) took effect.
    assert np.all(
        idata.posterior["ordered"][..., 1:].values > idata.posterior["ordered"][..., :-1].values
    )
Loading