diff --git a/src/relaxed/__init__.py b/src/relaxed/__init__.py
index 4d754a4..f309aa4 100644
--- a/src/relaxed/__init__.py
+++ b/src/relaxed/__init__.py
@@ -1,6 +1,14 @@
 from relaxed._version import version as __version__
 
-__all__ = ("__version__", "hist", "cramer_rao_uncert", "fisher_info", "mle", "infer")
+__all__ = (
+    "__version__",
+    "hist",
+    "cramer_rao_uncert",
+    "fisher_info",
+    "mle",
+    "infer",
+    "gaussianity",
+)
 
 from relaxed import infer, mle
-from relaxed.ops import cramer_rao_uncert, fisher_info, hist
+from relaxed.ops import cramer_rao_uncert, fisher_info, gaussianity, hist
diff --git a/src/relaxed/infer/hypothesis_test.py b/src/relaxed/infer/hypothesis_test.py
index a713741..f721296 100644
--- a/src/relaxed/infer/hypothesis_test.py
+++ b/src/relaxed/infer/hypothesis_test.py
@@ -5,15 +5,15 @@
 
 from functools import partial
 
-import jax
 import jax.numpy as jnp
 import pyhf
 from chex import Array
+from jax import jit
 
 from ..mle import fit, fixed_poi_fit
 
 
-@partial(jax.jit, static_argnames=["model", "return_mle_pars"])  # forward pass
+@partial(jit, static_argnames=["model", "return_mle_pars"])  # forward pass
 def hypotest(
     test_poi: float,
     data: Array,
diff --git a/src/relaxed/mle/constrained_fit.py b/src/relaxed/mle/constrained_fit.py
index a70096a..e87b756 100644
--- a/src/relaxed/mle/constrained_fit.py
+++ b/src/relaxed/mle/constrained_fit.py
@@ -5,9 +5,9 @@
 from functools import partial
 from typing import TYPE_CHECKING, Callable, cast
 
-import jax
 import jax.numpy as jnp
 from chex import Array
+from jax import jit
 
 if TYPE_CHECKING:
     import pyhf
@@ -33,7 +33,7 @@ def fit_objective(
     return fit_objective
 
 
-@partial(jax.jit, static_argnames=["model"])  # forward pass
+@partial(jit, static_argnames=["model"])  # forward pass
 def fixed_poi_fit(
     data: Array,
     model: pyhf.Model,
diff --git a/src/relaxed/mle/global_fit.py b/src/relaxed/mle/global_fit.py
index a5bef5e..f3ad618 100644
--- a/src/relaxed/mle/global_fit.py
+++ b/src/relaxed/mle/global_fit.py
@@ -5,8 +5,8 @@
 from functools import partial
 from typing import TYPE_CHECKING, Callable, cast
 
-import jax
 from chex import Array
+from jax import jit
 
 if TYPE_CHECKING:
     import pyhf
@@ -24,7 +24,7 @@ def fit_objective(lhood_pars_to_optimize: Array) -> float:  # NLL
     return fit_objective
 
 
-@partial(jax.jit, static_argnames=["model"])
+@partial(jit, static_argnames=["model"])
 def fit(
     data: Array,
     model: pyhf.Model,
diff --git a/src/relaxed/mle/minimize.py b/src/relaxed/mle/minimize.py
index 4082890..3052454 100644
--- a/src/relaxed/mle/minimize.py
+++ b/src/relaxed/mle/minimize.py
@@ -10,10 +10,11 @@
 import jaxopt
 import optax
 from chex import Array
+from jax import jit
 
 
 # try wrapping obj with closure_convert
-@partial(jax.jit, static_argnames=["objective_fn"])  # forward pass
+@partial(jit, static_argnames=["objective_fn"])  # forward pass
 def _minimize(
     objective_fn: Callable[..., float], init_pars: Array, lr: float, *obj_args: Any
 ) -> Array:
diff --git a/src/relaxed/ops/__init__.py b/src/relaxed/ops/__init__.py
index 2b86a52..bedf249 100644
--- a/src/relaxed/ops/__init__.py
+++ b/src/relaxed/ops/__init__.py
@@ -2,7 +2,9 @@
     "hist",
     "fisher_info",
     "cramer_rao_uncert",
+    "gaussianity",
 )
 
 from relaxed.ops.fisher_information import cramer_rao_uncert, fisher_info
 from relaxed.ops.histograms import hist
+from relaxed.ops.likelihood_gaussianity import gaussianity
diff --git a/src/relaxed/ops/histograms.py b/src/relaxed/ops/histograms.py
index 0e57b6b..ccee602 100644
--- a/src/relaxed/ops/histograms.py
+++ b/src/relaxed/ops/histograms.py
@@ -5,13 +5,13 @@
 
 from functools import partial
 
-import jax
 import jax.numpy as jnp
 import jax.scipy as jsp
 from chex import Array
+from jax import jit
 
 
-@partial(jax.jit, static_argnames=["density", "reflect_infinities"])
+@partial(jit, static_argnames=["density", "reflect_infinities"])
 def hist(
     events: Array,
     bins: Array,
diff --git a/src/relaxed/ops/likelihood_gaussianity.py b/src/relaxed/ops/likelihood_gaussianity.py
new file mode 100644
index 0000000..c8356b0
--- /dev/null
+++ b/src/relaxed/ops/likelihood_gaussianity.py
@@ -0,0 +1,65 @@
+from __future__ import annotations
+
+__all__ = ("gaussianity",)
+
+from functools import partial
+from typing import TYPE_CHECKING
+
+import jax.numpy as jnp
+import jax.scipy as jsp
+from chex import Array
+from jax import jit, vmap
+from jax.random import PRNGKey, multivariate_normal
+
+if TYPE_CHECKING:
+    import pyhf
+
+
+def gaussian_logpdf(
+    bestfit_pars: Array,
+    data: Array,
+    cov: Array,
+) -> Array:
+    return jsp.stats.multivariate_normal.logpdf(data, bestfit_pars, cov).reshape(
+        1,
+    )
+
+
+@partial(jit, static_argnames=["model", "n_samples"])  # rng_key is traced
+def gaussianity(
+    model: pyhf.Model,
+    bestfit_pars: Array,
+    cov_approx: Array,
+    observed_data: Array,
+    rng_key: PRNGKey,
+    n_samples: int = 1000,
+) -> Array:
+    # - compare the likelihood of the fitted model with a gaussian approximation
+    #   that has the same MLE (bestfit_pars)
+    # - do this across a number of points in parameter space (sampled from the
+    #   gaussian approximation) and take the mean squared difference
+    # - centre the values wrt the best-fit values to scale the differences
+    gaussian_parspace_samples = multivariate_normal(
+        key=rng_key,
+        mean=bestfit_pars,
+        cov=cov_approx,
+        shape=(n_samples,),
+    )
+
+    relative_nlls_model = vmap(
+        lambda pars, data: -(
+            model.logpdf(pars, data)[0] - model.logpdf(bestfit_pars, data)[0]
+        ),  # scale origin to bestfit pars
+        in_axes=(0, None),
+    )(gaussian_parspace_samples, observed_data)
+
+    relative_nlls_gaussian = vmap(
+        lambda pars, data: -(
+            gaussian_logpdf(pars, data, cov_approx)[0]
+            - gaussian_logpdf(bestfit_pars, data, cov_approx)[0]
+        ),  # data fixes the lhood shape
+        in_axes=(0, None),
+    )(gaussian_parspace_samples, bestfit_pars)
+
+    diffs = relative_nlls_model - relative_nlls_gaussian
+    return jnp.nanmean(diffs ** 2, axis=0)
diff --git a/tests/test_ops.py b/tests/test_ops.py
index fd207b3..2c06d07 100644
--- a/tests/test_ops.py
+++ b/tests/test_ops.py
@@ -163,3 +163,30 @@ def model(pars, data):
         return relaxed.cramer_rao_uncert(model, pars * x, data * x)
 
     jacrev(pipeline)(4.0)  # just check you can calc it w/o exception
+
+
+def test_gaussianity():
+    """Check that the gaussianity metric can be evaluated without error."""
+    pyhf.set_backend("jax")
+    m = pyhf.simplemodels.uncorrelated_background([5, 5], [50, 50], [5, 5])
+    pars = jnp.asarray(m.config.suggested_init())
+    data = jnp.asarray(m.expected_data(pars))
+    cov_approx = jnp.linalg.inv(
+        relaxed.fisher_info(lambda p, d: m.logpdf(p, d)[0], pars, data)
+    )
+    relaxed.gaussianity(m, pars, cov_approx, data, PRNGKey(0))
+
+
+def test_gaussianity_grad(example_model):
+    def pipeline(x):
+        def model(pars, data):
+            return example_model.logpdf(pars, data)[0]
+
+        pars = jnp.asarray(example_model.config.suggested_init())
+        data = jnp.asarray(example_model.expected_data(pars))
+        cov_approx = jnp.linalg.inv(relaxed.fisher_info(model, pars, data))
+        return relaxed.gaussianity(
+            example_model, pars * x, cov_approx * x, data * x, PRNGKey(0)
+        )
+
+    jacrev(pipeline)(4.0)  # just check you can calc it w/o exception
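
For anyone wanting to try the new op outside the test suite, here is a minimal usage sketch distilled from test_gaussianity. It is not part of the patch: the suggested-init (Asimov-like) parameters stand in for real best-fit parameters, and the inverse Fisher information serves as the Gaussian covariance approximation, exactly as in the test. The returned value is the mean squared difference of relative NLLs between the model and its Gaussian approximation, so smaller means "more Gaussian".

# usage sketch, assuming this PR is applied; mirrors test_gaussianity
import jax.numpy as jnp
import pyhf
from jax.random import PRNGKey

import relaxed

pyhf.set_backend("jax")
model = pyhf.simplemodels.uncorrelated_background([5, 5], [50, 50], [5, 5])
pars = jnp.asarray(model.config.suggested_init())  # stand-in for MLE pars
data = jnp.asarray(model.expected_data(pars))

# Gaussian covariance approximation: inverse Fisher information at the best fit
cov_approx = jnp.linalg.inv(
    relaxed.fisher_info(lambda p, d: model.logpdf(p, d)[0], pars, data)
)

# mean squared relative-NLL difference over n_samples draws from the
# Gaussian approximation; smaller values indicate a more Gaussian likelihood
score = relaxed.gaussianity(model, pars, cov_approx, data, PRNGKey(0), n_samples=1000)
print(score)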