Skip to content

Commit f53445a

Browse files
committed
Added initvals to parameters, constants and observations to returnvalue for pathfinder and cleaned relevant docs a bit
1 parent 4d65ea0 commit f53445a

File tree

7 files changed

+49
-16
lines changed

7 files changed

+49
-16
lines changed

docs/api_reference.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@ Inference
2323
.. autosummary::
2424
:toctree: generated/
2525

26+
find_MAP
2627
fit
28+
fit_laplace
29+
fit_pathfinder
2730

2831

2932
Distributions

pymc_extras/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,7 @@
1515

1616
from pymc_extras import gp, statespace, utils
1717
from pymc_extras.distributions import *
18-
from pymc_extras.inference.find_map import find_MAP
19-
from pymc_extras.inference.fit import fit
20-
from pymc_extras.inference.laplace import fit_laplace
18+
from pymc_extras.inference import find_MAP, fit, fit_laplace, fit_pathfinder
2119
from pymc_extras.model.marginal.marginal_model import (
2220
MarginalModel,
2321
marginalize,

pymc_extras/inference/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
15+
from pymc_extras.inference.find_map import find_MAP
1616
from pymc_extras.inference.fit import fit
17+
from pymc_extras.inference.laplace import fit_laplace
18+
from pymc_extras.inference.pathfinder.pathfinder import fit_pathfinder
1719

18-
__all__ = ["fit"]
20+
__all__ = ["fit", "fit_pathfinder", "fit_laplace", "find_MAP"]

pymc_extras/inference/fit.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,20 @@
1616

1717
def fit(method: str, **kwargs) -> az.InferenceData:
1818
"""
19-
Fit a model with an inference algorithm
19+
Fit a model with an inference algorithm.
20+
See :func:`fit_pathfinder` and :func:`fit_laplace` for more details.
2021
2122
Parameters
2223
----------
2324
method : str
2425
Which inference method to run.
2526
Supported: pathfinder or laplace
2627
27-
kwargs are passed on.
28+
kwargs: keyword arguments are passed on to the inference method.
2829
2930
Returns
3031
-------
31-
arviz.InferenceData
32+
:class:`~arviz.InferenceData`
3233
"""
3334
if method == "pathfinder":
3435
from pymc_extras.inference.pathfinder import fit_pathfinder

pymc_extras/inference/laplace.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -509,7 +509,7 @@ def fit_laplace(
509509
510510
Returns
511511
-------
512-
idata: az.InferenceData
512+
:class:`~arviz.InferenceData`
513513
An InferenceData object containing the approximated posterior samples.
514514
515515
Examples

pymc_extras/inference/pathfinder/pathfinder.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
1516
import collections
1617
import logging
1718
import time
@@ -64,6 +65,7 @@
6465
# TODO: change to typing.Self after Python versions greater than 3.10
6566
from typing_extensions import Self
6667

68+
from pymc_extras.inference.laplace import add_data_to_inferencedata
6769
from pymc_extras.inference.pathfinder.importance_sampling import (
6870
importance_sampling as _importance_sampling,
6971
)
@@ -1627,6 +1629,7 @@ def fit_pathfinder(
16271629
inference_backend: Literal["pymc", "blackjax"] = "pymc",
16281630
pathfinder_kwargs: dict = {},
16291631
compile_kwargs: dict = {},
1632+
initvals: dict | None = None,
16301633
) -> az.InferenceData:
16311634
"""
16321635
Fit the Pathfinder Variational Inference algorithm.
@@ -1662,12 +1665,12 @@ def fit_pathfinder(
16621665
importance_sampling : str, None, optional
16631666
Method to apply sampling based on log importance weights (logP - logQ).
16641667
Options are:
1665-
"psis" : Pareto Smoothed Importance Sampling (default)
1666-
Recommended for more stable results.
1667-
"psir" : Pareto Smoothed Importance Resampling
1668-
Less stable than PSIS.
1669-
"identity" : Applies log importance weights directly without resampling.
1670-
None : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is number of model parameters. Other methods return samples of size (num_draws, N).
1668+
1669+
- **"psis"** : Pareto Smoothed Importance Sampling (default). Usually most stable.
1670+
- **"psir"** : Pareto Smoothed Importance Resampling. Less stable than PSIS.
1671+
- **"identity"** : Applies log importance weights directly without resampling.
1672+
- **None** : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is number of model parameters. Other methods return samples of size (num_draws, N).
1673+
16711674
progressbar : bool, optional
16721675
Whether to display a progress bar (default is True). Setting this to False will likely reduce the computation time.
16731676
random_seed : RandomSeed, optional
@@ -1682,10 +1685,13 @@ def fit_pathfinder(
16821685
Additional keyword arguments for the Pathfinder algorithm.
16831686
compile_kwargs
16841687
Additional keyword arguments for the PyTensor compiler. If not provided, the default linker is "cvm_nogc".
1688+
initvals: dict | None = None
1689+
Initial values for the model parameters, as str:ndarray key-value pairs. Paritial initialization is permitted.
1690+
If None, the model's default initial values are used.
16851691
16861692
Returns
16871693
-------
1688-
arviz.InferenceData
1694+
:class:`~arviz.InferenceData`
16891695
The inference data containing the results of the Pathfinder algorithm.
16901696
16911697
References
@@ -1695,6 +1701,14 @@ def fit_pathfinder(
16951701

16961702
model = modelcontext(model)
16971703

1704+
if initvals is not None:
1705+
model = pm.model.fgraph.clone_model(model) # Create a clone of the model
1706+
for (
1707+
rv_name,
1708+
ivals,
1709+
) in initvals.items(): # Set the initial values for the variables in the clone
1710+
model.set_initval(model.named_vars[rv_name], ivals)
1711+
16981712
valid_importance_sampling = {"psis", "psir", "identity", None}
16991713

17001714
if importance_sampling is not None:
@@ -1772,4 +1786,7 @@ def fit_pathfinder(
17721786
model=model,
17731787
importance_sampling=importance_sampling,
17741788
)
1789+
1790+
idata = add_data_to_inferencedata(idata, progressbar, model, compile_kwargs)
1791+
17751792
return idata

tests/test_pathfinder.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,3 +200,15 @@ def test_pathfinder_importance_sampling(importance_sampling):
200200
assert idata.posterior["mu"].shape == (1, num_draws)
201201
assert idata.posterior["tau"].shape == (1, num_draws)
202202
assert idata.posterior["theta"].shape == (1, num_draws, 8)
203+
204+
205+
def test_initvals():
206+
# Run a model with an ordered transform that will fail unless initvals are in place
207+
with pm.Model() as mdl:
208+
pm.Normal("ordered", size=10, transform=pm.distributions.transforms.ordered)
209+
idata = pmx.fit_pathfinder(initvals={"ordered": np.linspace(0, 1, 10)})
210+
211+
# Check that the samples are ordered to make sure transform was applied
212+
assert np.all(
213+
idata.posterior["ordered"][..., 1:].values > idata.posterior["ordered"][..., :-1].values
214+
)

0 commit comments

Comments
 (0)