Make fit_pathfinder more similar to fit_laplace and pm.sample #447
```diff
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

 import collections
 import logging
 import time
```
```diff
@@ -24,7 +25,6 @@

 import arviz as az
 import filelock
-import jax
 import numpy as np
 import pymc as pm
 import pytensor
```
```diff
@@ -43,7 +43,6 @@
     find_rng_nodes,
     reseed_rngs,
 )
-from pymc.sampling.jax import get_jaxified_graph
 from pymc.util import (
     CustomProgress,
     RandomSeed,
```
```diff
@@ -64,6 +63,7 @@
 # TODO: change to typing.Self after Python versions greater than 3.10
 from typing_extensions import Self

+from pymc_extras.inference.laplace import add_data_to_inferencedata
 from pymc_extras.inference.pathfinder.importance_sampling import (
     importance_sampling as _importance_sampling,
 )
```
```diff
@@ -99,6 +99,8 @@ def get_jaxified_logp_of_ravel_inputs(model: Model, jacobian: bool = True) -> Callable:
         A JAX function that computes the log-probability of a PyMC model with ravelled inputs.
     """

+    from pymc.sampling.jax import get_jaxified_graph
+
     # TODO: JAX: test if we should get jaxified graph of dlogp as well
     new_logprob, new_input = pm.pytensorf.join_nonshared_inputs(
         model.initial_point(), (model.logp(jacobian=jacobian),), model.value_vars, ()
```
```diff
@@ -218,6 +220,10 @@ def convert_flat_trace_to_idata(
         result = [res.reshape(num_paths, num_pdraws, *res.shape[2:]) for res in result]

     elif inference_backend == "blackjax":
+        import jax
+
+        from pymc.sampling.jax import get_jaxified_graph
+
         jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
         result = jax.vmap(jax.vmap(jax_fn))(
             *jax.device_put(list(trace.values()), jax.devices(postprocessing_backend)[0])
```
```diff
@@ -1627,6 +1633,7 @@ def fit_pathfinder(
     inference_backend: Literal["pymc", "blackjax"] = "pymc",
     pathfinder_kwargs: dict = {},
     compile_kwargs: dict = {},
+    initvals: dict | None = None,
 ) -> az.InferenceData:
     """
     Fit the Pathfinder Variational Inference algorithm.
```
```diff
@@ -1662,12 +1669,12 @@
     importance_sampling : str, None, optional
         Method to apply sampling based on log importance weights (logP - logQ).
         Options are:
-        "psis" : Pareto Smoothed Importance Sampling (default)
-            Recommended for more stable results.
-        "psir" : Pareto Smoothed Importance Resampling
-            Less stable than PSIS.
-        "identity" : Applies log importance weights directly without resampling.
-        None : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is number of model parameters. Other methods return samples of size (num_draws, N).
+
+        - "psis" : Pareto Smoothed Importance Sampling (default). Usually most stable.
+        - "psir" : Pareto Smoothed Importance Resampling. Less stable than PSIS.
+        - "identity" : Applies log importance weights directly without resampling.
+        - None : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is number of model parameters. Other methods return samples of size (num_draws, N).
+
     progressbar : bool, optional
         Whether to display a progress bar (default is True). Setting this to False will likely reduce the computation time.
     random_seed : RandomSeed, optional
```
```diff
@@ -1682,10 +1689,13 @@
         Additional keyword arguments for the Pathfinder algorithm.
     compile_kwargs
         Additional keyword arguments for the PyTensor compiler. If not provided, the default linker is "cvm_nogc".
+    initvals : dict | None = None
+        Initial values for the model parameters, as str:ndarray key-value pairs. Partial initialization is permitted.
+        If None, the model's default initial values are used.

     Returns
     -------
-    arviz.InferenceData
+    :class:`~arviz.InferenceData`
         The inference data containing the results of the Pathfinder algorithm.

     References
```
```diff
@@ -1695,6 +1705,14 @@

     model = modelcontext(model)

+    if initvals is not None:
+        model = pm.model.fgraph.clone_model(model)  # Create a clone of the model
+        for (
+            rv_name,
+            ivals,
+        ) in initvals.items():  # Set the initial values for the variables in the clone
+            model.set_initval(model.named_vars[rv_name], ivals)
+
     valid_importance_sampling = {"psis", "psir", "identity", None}

     if importance_sampling is not None:
```

Review comment:
> Might need to ensure that `ivals` is a support point for the RV. For example, `x ~ Uniform(-1, 1)` would have

Reply:
> While in the ideal world, I would agree, in practice

Reply:
> Yup, seems fair enough. Thanks for this submission @velochy
```diff
@@ -1734,6 +1752,7 @@
         pathfinder_samples = mp_result.samples
     elif inference_backend == "blackjax":
         import blackjax
+        import jax

         if version.parse(blackjax.__version__).major < 1:
             raise ImportError("fit_pathfinder requires blackjax 1.0 or above")
```
```diff
@@ -1772,4 +1791,7 @@
         model=model,
         importance_sampling=importance_sampling,
     )

+    idata = add_data_to_inferencedata(idata, progressbar, model, compile_kwargs)
+
     return idata
```
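In line with the PR title, routing the result through `add_data_to_inferencedata` means the returned InferenceData should now carry the model's data, as `fit_laplace` and `pm.sample` do. A hedged check (group names follow ArviZ conventions and are not verified against this diff):

```python
idata = fit_pathfinder(model=model)  # toy model from the earlier sketch
print(idata.groups())  # expected to include "observed_data" after this change
```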
Review comment:
> Shouldn't mutate the model; make a copy, perhaps, if there's no better way to just forward the initvals.

Reply:
> I would normally agree. However, I tried it, and model.copy() does not produce a working model sometimes, most notably when any transformations are used. Should I use some other copy function?
Reply:
> @velochy you want the `clone_model` function: https://github.com/pymc-devs/pymc/blob/2842401f95de74ab37b7750cff455af28cddaffa/pymc/model/fgraph.py#L374
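For reference, `clone_model` can be used roughly as below. A sketch only: the transformed-variable name `sigma_log__` is pymc's naming convention, assumed here, and the contrast with model.copy() is based on the report above:

```python
import pymc as pm
from pymc.model.fgraph import clone_model

with pm.Model() as m:
    sigma = pm.HalfNormal("sigma", 1.0)  # carries a log transform

m2 = clone_model(m)  # graph-level clone; transforms survive
m2.set_initval(m2.named_vars["sigma"], 0.5)

# The clone picks up the new initval; the original model is untouched.
print(m.initial_point()["sigma_log__"])   # default
print(m2.initial_point()["sigma_log__"])  # log(0.5)
```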
Reply:
> Thank you for that @zaxtax. Updated the code. Also added a test to check that initvals are used correctly.
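A sketch of what such a test might look like (the actual test in this PR may differ; the top-level import is an assumption):

```python
import numpy as np
import pymc as pm

from pymc_extras import fit_pathfinder  # assumed top-level export


def test_fit_pathfinder_accepts_partial_initvals():
    # Smoke-level check: fitting with a partial initvals dict completes
    # and returns posterior draws for the initialized variable.
    with pm.Model() as model:
        mu = pm.Normal("mu", 0.0, 10.0)
        pm.Normal("y", mu, 1.0, observed=np.zeros(10))

    idata = fit_pathfinder(model=model, initvals={"mu": np.array(1.0)}, random_seed=42)
    assert "mu" in idata.posterior
```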