diff --git a/pymc_extras/inference/laplace.py b/pymc_extras/inference/laplace.py
index d86f1fee5..94eb08fba 100644
--- a/pymc_extras/inference/laplace.py
+++ b/pymc_extras/inference/laplace.py
@@ -15,6 +15,7 @@
 import logging
 
+from collections.abc import Callable
 from functools import reduce
 from importlib.util import find_spec
 from itertools import product
@@ -39,6 +40,8 @@
 from pymc.model.transform.conditioning import remove_value_transforms
 from pymc.model.transform.optimization import freeze_dims_and_data
 from pymc.util import get_default_varnames
+from pytensor.tensor import TensorVariable
+from pytensor.tensor.optimize import minimize
 from scipy import stats
 
 from pymc_extras.inference.find_map import (
@@ -415,6 +418,69 @@ def sample_laplace_posterior(
     return idata
 
 
+def find_mode_and_hess(
+    x: TensorVariable,
+    model: pm.Model | None = None,
+    method: minimize_method = "BFGS",
+    use_jac: bool = True,
+    use_hess: bool = False,  # TODO: We could probably remove this arg and always pass True to the minimizer, but then it would warn whenever the Hessian does not need to be computed for a particular optimization routine.
+    optimizer_kwargs: dict | None = None,
+) -> Callable:
+    """
+    Return a function that estimates the mode and Hessian of a model by minimizing the negative log likelihood.
+    Wrapper for the (pytensor-native) scipy.optimize.minimize.
+
+    Parameters
+    ----------
+    x: TensorVariable
+        The parameter with respect to which the negative log likelihood is minimized (that is, the mode is found in x).
+    model: Model
+        PyMC model to use.
+    method: minimize_method
+        Which minimization algorithm to use.
+    use_jac: bool
+        If True, the minimizer will compute and store the Jacobian.
+    use_hess: bool
+        If True, the minimizer will compute and store the Hessian (note that the Hessian is computed explicitly even if this is False).
+    optimizer_kwargs: dict
+        Kwargs to pass to scipy.optimize.minimize.
+
+    Returns
+    -------
+    f: Callable
+        A compiled function that accepts values for the model's value variables as args and returns [mu, hess(mu)], where mu is the mode. The value passed for x is used as the initial guess for mu.
+    """
+    model = pm.modelcontext(model)
+
+    # Minimize the negative log likelihood
+    nll = -model.logp()
+    soln, _ = minimize(
+        objective=nll,
+        x=x,
+        method=method,
+        jac=use_jac,
+        hess=use_hess,
+        optimizer_kwargs=optimizer_kwargs,
+    )
+
+    # TODO: Jesse suggested using this graph_replace function, but it seems that "mode" here is a different type than soln:
+    #
+    # TypeError: Cannot convert Type Vector(float64, shape=(10,)) (of Variable MinimizeOp(method=BFGS, jac=True, hess=True, hessp=False).0) into Type Scalar(float64, shape=()). You can try to manually convert MinimizeOp(method=BFGS, jac=True, hess=True, hessp=False).0 into a Scalar(float64, shape=()).
+    #
+    # My understanding: for a function that evaluates the Hessian at x, we are replacing "x" in the hess graph with the subgraph that computes "x" (i.e. soln)?
+
+    # Obtain the Hessian (re-use the graph if it was already computed in minimize)
+    if use_hess:
+        # Note: this unpacking will need to be slightly more elaborate when use_jac is False (2 items to unpack instead of 3). Just a few if-blocks, but not implemented for now while we're debugging (see the sketch after the diff).
+        mode, _, hess = soln.owner.op.inner_outputs
+        hess = pytensor.graph.replace.graph_replace(hess, {mode: soln})
+    else:
+        hess = pytensor.gradient.hessian(nll, x)
+
+    args = model.continuous_value_vars + model.discrete_value_vars
+    return pytensor.function(args, [soln, hess])
+
+
 def fit_laplace(
     optimize_method: minimize_method | Literal["basinhopping"] = "BFGS",
     *,
diff --git a/tests/test_laplace.py b/tests/test_laplace.py
index 8f7a4c017..c59334812 100644
--- a/tests/test_laplace.py
+++ b/tests/test_laplace.py
@@ -21,6 +21,7 @@
 from pymc_extras.inference.find_map import GradientBackend, find_MAP
 from pymc_extras.inference.laplace import (
+    find_mode_and_hess,
     fit_laplace,
     fit_mvn_at_MAP,
     sample_laplace_posterior,
@@ -279,3 +280,35 @@ def test_laplace_scalar():
     assert idata_laplace.fit.covariance_matrix.shape == (1, 1)
 
     np.testing.assert_allclose(idata_laplace.fit.mean_vector.values.item(), data.mean(), atol=0.1)
+
+
+def test_find_mode_and_hess():
+    rng = np.random.default_rng(42)
+    n = 100
+    sigma_obs = rng.random()
+    sigma_mu = rng.random()
+
+    coords = {"city": ["A", "B", "C"], "obs_idx": np.arange(n)}
+    with pm.Model(coords=coords) as model:
+        obs_val = rng.normal(loc=3, scale=1.5, size=(n, 3))
+
+        mu = pm.Normal("mu", mu=1, sigma=sigma_mu, dims=["city"])
+        pm.Normal(
+            "obs",
+            mu=mu,
+            sigma=sigma_obs,
+            observed=obs_val,
+            dims=["obs_idx", "city"],
+        )
+
+        # find_mode_and_hess resolves the model from the context stack, so call it inside the model block
+        get_mode_and_hessian = find_mode_and_hess(
+            use_hess=False, x=model.rvs_to_values[mu], method="BFGS", optimizer_kwargs={"tol": 1e-8}
+        )
+
+    mode, hess = get_mode_and_hessian(mu=[1, 1, 1])
+
+    true_mode = obs_val.mean(axis=0)
+    true_hess = np.diag((1 / sigma_mu**2 + n / sigma_obs**2) * np.ones(3))
+
+    np.testing.assert_allclose(mode, true_mode, atol=0.1, rtol=0.1)
+    np.testing.assert_allclose(hess, true_hess, atol=0.1, rtol=0.1)
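
Sketch of the if-blocks mentioned in the unpacking TODO above. This assumes MinimizeOp's inner_outputs is laid out as [x, jac, hess], with jac present only when use_jac=True; that layout is inferred from the 3-vs-2 unpacking the comment describes, not confirmed against the pytensor API. Inside find_mode_and_hess, the use_hess branch might become:

    # Sketch only: the inner_outputs layout is an assumption, not confirmed API.
    if use_hess:
        if use_jac:
            mode, _, hess = soln.owner.op.inner_outputs  # [x, jac, hess]
        else:
            mode, hess = soln.owner.op.inner_outputs  # [x, hess]
        # Evaluate the stored Hessian graph at the solution by replacing the
        # symbolic input "mode" with the subgraph that computes it (soln).
        hess = pytensor.graph.replace.graph_replace(hess, {mode: soln})
    else:
        # Build the Hessian of the negative log likelihood directly.
        hess = pytensor.gradient.hessian(nll, x)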
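
Usage sketch for the new helper, mirroring test_find_mode_and_hess; the model and starting values here are illustrative only:

    import numpy as np
    import pymc as pm

    from pymc_extras.inference.laplace import find_mode_and_hess

    with pm.Model() as m:
        mu = pm.Normal("mu", mu=0, sigma=1, shape=(3,))
        pm.Normal("obs", mu=mu, sigma=1.0, observed=np.zeros((100, 3)))

        # Compile the function once, inside the model context
        # (find_mode_and_hess calls pm.modelcontext when model=None).
        f = find_mode_and_hess(x=m.rvs_to_values[mu])

    # Call it with a starting value for each value variable; it returns the
    # mode and the Hessian of the negative log likelihood at the mode.
    mode, hess = f(mu=[0.0, 0.0, 0.0])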