diff --git a/docs/api_reference.rst b/docs/api_reference.rst
index 5d1f40d1..f2219e08 100644
--- a/docs/api_reference.rst
+++ b/docs/api_reference.rst
@@ -23,7 +23,10 @@ Inference
 .. autosummary::
    :toctree: generated/
 
+   find_MAP
    fit
+   fit_laplace
+   fit_pathfinder
 
 Distributions
diff --git a/pymc_extras/__init__.py b/pymc_extras/__init__.py
index 5ef7a58e..0ff07bd6 100644
--- a/pymc_extras/__init__.py
+++ b/pymc_extras/__init__.py
@@ -15,9 +15,7 @@
 from pymc_extras import gp, statespace, utils
 from pymc_extras.distributions import *
-from pymc_extras.inference.find_map import find_MAP
-from pymc_extras.inference.fit import fit
-from pymc_extras.inference.laplace import fit_laplace
+from pymc_extras.inference import find_MAP, fit, fit_laplace, fit_pathfinder
 from pymc_extras.model.marginal.marginal_model import (
     MarginalModel,
     marginalize,
diff --git a/pymc_extras/inference/__init__.py b/pymc_extras/inference/__init__.py
index ac65fdae..a01fdd5c 100644
--- a/pymc_extras/inference/__init__.py
+++ b/pymc_extras/inference/__init__.py
@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+from pymc_extras.inference.find_map import find_MAP
 from pymc_extras.inference.fit import fit
+from pymc_extras.inference.laplace import fit_laplace
+from pymc_extras.inference.pathfinder.pathfinder import fit_pathfinder
 
-__all__ = ["fit"]
+__all__ = ["fit", "fit_pathfinder", "fit_laplace", "find_MAP"]
diff --git a/pymc_extras/inference/fit.py b/pymc_extras/inference/fit.py
index 60d89777..5b83ff1f 100644
--- a/pymc_extras/inference/fit.py
+++ b/pymc_extras/inference/fit.py
@@ -16,7 +16,8 @@
 def fit(method: str, **kwargs) -> az.InferenceData:
     """
-    Fit a model with an inference algorithm
+    Fit a model with an inference algorithm.
+    See :func:`fit_pathfinder` and :func:`fit_laplace` for more details.
 
     Parameters
     ----------
@@ -24,11 +25,11 @@
         Which inference method to run.
         Supported: pathfinder or laplace
 
-    kwargs are passed on.
+    kwargs: keyword arguments are passed on to the inference method.
 
     Returns
     -------
-    arviz.InferenceData
+    :class:`~arviz.InferenceData`
     """
     if method == "pathfinder":
         from pymc_extras.inference.pathfinder import fit_pathfinder
diff --git a/pymc_extras/inference/laplace.py b/pymc_extras/inference/laplace.py
index 78b5100e..bc35d926 100644
--- a/pymc_extras/inference/laplace.py
+++ b/pymc_extras/inference/laplace.py
@@ -509,7 +509,7 @@ def fit_laplace(
 
     Returns
     -------
-    idata: az.InferenceData
+    :class:`~arviz.InferenceData`
         An InferenceData object containing the approximated posterior samples.
 
     Examples
diff --git a/pymc_extras/inference/pathfinder/pathfinder.py b/pymc_extras/inference/pathfinder/pathfinder.py
index 8a73549f..846c00fa 100644
--- a/pymc_extras/inference/pathfinder/pathfinder.py
+++ b/pymc_extras/inference/pathfinder/pathfinder.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import collections
 import logging
 import time
@@ -24,7 +25,6 @@
 import arviz as az
 import filelock
-import jax
 import numpy as np
 import pymc as pm
 import pytensor
@@ -43,7 +43,6 @@
     find_rng_nodes,
     reseed_rngs,
 )
-from pymc.sampling.jax import get_jaxified_graph
 from pymc.util import (
     CustomProgress,
     RandomSeed,
@@ -64,6 +63,7 @@
 # TODO: change to typing.Self after Python versions greater than 3.10
 from typing_extensions import Self
 
+from pymc_extras.inference.laplace import add_data_to_inferencedata
 from pymc_extras.inference.pathfinder.importance_sampling import (
     importance_sampling as _importance_sampling,
 )
@@ -99,6 +99,8 @@ def get_jaxified_logp_of_ravel_inputs(model: Model, jacobian: bool = True) -> Ca
         A JAX function that computes the log-probability of a PyMC model with ravelled inputs.
     """
+    from pymc.sampling.jax import get_jaxified_graph
+
     # TODO: JAX: test if we should get jaxified graph of dlogp as well
     new_logprob, new_input = pm.pytensorf.join_nonshared_inputs(
         model.initial_point(), (model.logp(jacobian=jacobian),), model.value_vars, ()
@@ -218,6 +220,10 @@ convert_flat_trace_to_idata(
         result = [res.reshape(num_paths, num_pdraws, *res.shape[2:]) for res in result]
 
     elif inference_backend == "blackjax":
+        import jax
+
+        from pymc.sampling.jax import get_jaxified_graph
+
         jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
         result = jax.vmap(jax.vmap(jax_fn))(
             *jax.device_put(list(trace.values()), jax.devices(postprocessing_backend)[0])
@@ -1627,6 +1633,7 @@ fit_pathfinder(
     inference_backend: Literal["pymc", "blackjax"] = "pymc",
     pathfinder_kwargs: dict = {},
     compile_kwargs: dict = {},
+    initvals: dict | None = None,
 ) -> az.InferenceData:
     """
     Fit the Pathfinder Variational Inference algorithm.
@@ -1662,12 +1669,12 @@
     importance_sampling : str, None, optional
         Method to apply sampling based on log importance weights (logP - logQ).
         Options are:
-        "psis" : Pareto Smoothed Importance Sampling (default)
-                Recommended for more stable results.
-        "psir" : Pareto Smoothed Importance Resampling
-                Less stable than PSIS.
-        "identity" : Applies log importance weights directly without resampling.
-        None : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is number of model parameters. Other methods return samples of size (num_draws, N).
+
+        - "psis" : Pareto Smoothed Importance Sampling (default). Usually most stable.
+        - "psir" : Pareto Smoothed Importance Resampling. Less stable than PSIS.
+        - "identity" : Applies log importance weights directly without resampling.
+        - None : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is number of model parameters. Other methods return samples of size (num_draws, N).
+
     progressbar : bool, optional
         Whether to display a progress bar (default is True). Setting this to False will likely reduce the computation time.
     random_seed : RandomSeed, optional
@@ -1682,10 +1689,13 @@
         Additional keyword arguments for the Pathfinder algorithm.
     compile_kwargs
         Additional keyword arguments for the PyTensor compiler. If not provided, the default linker is "cvm_nogc".
+    initvals : dict | None, optional
+        Initial values for the model parameters, as str:ndarray key-value pairs. Partial initialization is permitted.
+        If None, the model's default initial values are used.
 
     Returns
     -------
-    arviz.InferenceData
+    :class:`~arviz.InferenceData`
         The inference data containing the results of the Pathfinder algorithm.
 
     References
@@ -1695,6 +1705,14 @@
     model = modelcontext(model)
 
+    if initvals is not None:
+        model = pm.model.fgraph.clone_model(model)  # Create a clone of the model
+        for (
+            rv_name,
+            ivals,
+        ) in initvals.items():  # Set the initial values for the variables in the clone
+            model.set_initval(model.named_vars[rv_name], ivals)
+
     valid_importance_sampling = {"psis", "psir", "identity", None}
 
     if importance_sampling is not None:
@@ -1734,6 +1752,7 @@
         pathfinder_samples = mp_result.samples
     elif inference_backend == "blackjax":
         import blackjax
+        import jax
 
         if version.parse(blackjax.__version__).major < 1:
             raise ImportError("fit_pathfinder requires blackjax 1.0 or above")
@@ -1772,4 +1791,7 @@
         model=model,
         importance_sampling=importance_sampling,
     )
+
+    idata = add_data_to_inferencedata(idata, progressbar, model, compile_kwargs)
+
     return idata
diff --git a/tests/test_pathfinder.py b/tests/test_pathfinder.py
index b2f4b815..b0776a1c 100644
--- a/tests/test_pathfinder.py
+++ b/tests/test_pathfinder.py
@@ -200,3 +200,15 @@ def test_pathfinder_importance_sampling(importance_sampling):
     assert idata.posterior["mu"].shape == (1, num_draws)
     assert idata.posterior["tau"].shape == (1, num_draws)
     assert idata.posterior["theta"].shape == (1, num_draws, 8)
+
+
+def test_pathfinder_initvals():
+    # Run a model with an ordered transform that will fail unless initvals are in place
+    with pm.Model() as mdl:
+        pm.Normal("ordered", size=10, transform=pm.distributions.transforms.ordered)
+        idata = pmx.fit_pathfinder(initvals={"ordered": np.linspace(0, 1, 10)})
+
+    # Check that the samples are ordered to make sure the transform was applied
+    assert np.all(
+        idata.posterior["ordered"][..., 1:].values > idata.posterior["ordered"][..., :-1].values
+    )
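
Reviewer note: a minimal end-to-end sketch of the new `initvals` parameter, mirroring the added test and the documented behavior of `fit`/`fit_pathfinder`. The `pmx` alias for `pymc_extras` is borrowed from the test suite and `random_seed` is an existing documented argument; everything else comes straight from this diff.

    import numpy as np
    import pymc as pm
    import pymc_extras as pmx

    with pm.Model():
        # Per the test above, the model's default initial point is invalid
        # under the ordered transform, so Pathfinder needs user-supplied
        # initial values to start from a valid point.
        pm.Normal("ordered", size=10, transform=pm.distributions.transforms.ordered)

        # Partial initialization: only "ordered" is given here; any other
        # model variables would keep their default initial values.
        idata = pmx.fit_pathfinder(
            initvals={"ordered": np.linspace(0, 1, 10)},
            random_seed=123,
        )

    # Equivalent call through the generic dispatcher, which forwards kwargs:
    # idata = pmx.fit(method="pathfinder", initvals={"ordered": np.linspace(0, 1, 10)})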