Skip to content

Commit

Permalink
add gaussianity
Browse files Browse the repository at this point in the history
  • Loading branch information
phinate committed Nov 29, 2021
1 parent bd65370 commit fd9f88f
Show file tree
Hide file tree
Showing 9 changed files with 114 additions and 11 deletions.
12 changes: 10 additions & 2 deletions src/relaxed/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
from relaxed._version import version as __version__

__all__ = ("__version__", "hist", "cramer_rao_uncert", "fisher_info", "mle", "infer")
__all__ = (
"__version__",
"hist",
"cramer_rao_uncert",
"fisher_info",
"mle",
"infer",
"gaussianity",
)

from relaxed import infer, mle
from relaxed.ops import cramer_rao_uncert, fisher_info, hist
from relaxed.ops import cramer_rao_uncert, fisher_info, gaussianity, hist
4 changes: 2 additions & 2 deletions src/relaxed/infer/hypothesis_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@

from functools import partial

import jax
import jax.numpy as jnp
import pyhf
from chex import Array
from jax import jit

from ..mle import fit, fixed_poi_fit


@partial(jax.jit, static_argnames=["model", "return_mle_pars"]) # forward pass
@partial(jit, static_argnames=["model", "return_mle_pars"]) # forward pass
def hypotest(
test_poi: float,
data: Array,
Expand Down
4 changes: 2 additions & 2 deletions src/relaxed/mle/constrained_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from functools import partial
from typing import TYPE_CHECKING, Callable, cast

import jax
import jax.numpy as jnp
from chex import Array
from jax import jit

if TYPE_CHECKING:
import pyhf
Expand All @@ -33,7 +33,7 @@ def fit_objective(
return fit_objective


@partial(jax.jit, static_argnames=["model"]) # forward pass
@partial(jit, static_argnames=["model"]) # forward pass
def fixed_poi_fit(
data: Array,
model: pyhf.Model,
Expand Down
4 changes: 2 additions & 2 deletions src/relaxed/mle/global_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from functools import partial
from typing import TYPE_CHECKING, Callable, cast

import jax
from chex import Array
from jax import jit

if TYPE_CHECKING:
import pyhf
Expand All @@ -24,7 +24,7 @@ def fit_objective(lhood_pars_to_optimize: Array) -> float: # NLL
return fit_objective


@partial(jax.jit, static_argnames=["model"])
@partial(jit, static_argnames=["model"])
def fit(
data: Array,
model: pyhf.Model,
Expand Down
3 changes: 2 additions & 1 deletion src/relaxed/mle/minimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@
import jaxopt
import optax
from chex import Array
from jax import jit


# try wrapping obj with closure_convert
@partial(jax.jit, static_argnames=["objective_fn"]) # forward pass
@partial(jit, static_argnames=["objective_fn"]) # forward pass
def _minimize(
objective_fn: Callable[..., float], init_pars: Array, lr: float, *obj_args: Any
) -> Array:
Expand Down
2 changes: 2 additions & 0 deletions src/relaxed/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
"hist",
"fisher_info",
"cramer_rao_uncert",
"gaussianity",
)

from relaxed.ops.fisher_information import cramer_rao_uncert, fisher_info
from relaxed.ops.histograms import hist
from relaxed.ops.likelihood_gaussianity import gaussianity
4 changes: 2 additions & 2 deletions src/relaxed/ops/histograms.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@

from functools import partial

import jax
import jax.numpy as jnp
import jax.scipy as jsp
from chex import Array
from jax import jit


@partial(jax.jit, static_argnames=["density", "reflect_infinities"])
@partial(jit, static_argnames=["density", "reflect_infinities"])
def hist(
events: Array,
bins: Array,
Expand Down
65 changes: 65 additions & 0 deletions src/relaxed/ops/likelihood_gaussianity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from __future__ import annotations

__all__ = ("gaussianity",)

from functools import partial
from typing import TYPE_CHECKING

import jax.numpy as jnp
import jax.scipy as jsp
from chex import Array
from jax import jit, vmap
from jax.random import PRNGKey, multivariate_normal

if TYPE_CHECKING:
import pyhf


def gaussian_logpdf(
    bestfit_pars: Array,
    data: Array,
    cov: Array,
) -> Array:
    """Log-density of a multivariate normal with mean `bestfit_pars` and
    covariance `cov`, evaluated at `data`. Returned as a length-1 array so it
    composes with vmap the same way as `pyhf.Model.logpdf`."""
    logdensity = jsp.stats.multivariate_normal.logpdf(data, bestfit_pars, cov)
    return logdensity.reshape(1)


@partial(jit, static_argnames=["model", "rng_key", "n_samples"])
def gaussianity(
    model: pyhf.Model,
    bestfit_pars: Array,
    cov_approx: Array,
    observed_data: Array,
    rng_key: PRNGKey,
    n_samples: int = 1000,
) -> Array:
    """Measure how Gaussian the model's likelihood is around the best-fit point.

    Draws `n_samples` parameter points from a multivariate normal centred on
    `bestfit_pars` with covariance `cov_approx`, evaluates the model's
    negative log-likelihood and the Gaussian approximation's negative
    log-density at each point (both shifted so they are zero at
    `bestfit_pars`), and returns the mean squared difference. A result of 0
    means the two agree exactly over the sampled region, i.e. the likelihood
    is Gaussian there.

    Args:
        model: pyhf-style model exposing ``logpdf(pars, data)``; static for jit.
        bestfit_pars: MLE parameter point shared by both densities.
        cov_approx: covariance of the Gaussian approximation (e.g. inverse
            Fisher information).
        observed_data: data at which the model likelihood is evaluated.
        rng_key: jax PRNG key for the parameter-space samples.
            NOTE(review): declared static for jit, but jit static arguments
            must be hashable and raw PRNGKey arrays generally are not —
            confirm this traces on the targeted jax version.
        n_samples: number of parameter-space samples (static for jit).

    Returns:
        Scalar array: mean squared NLL difference (NaN samples are ignored
        via `nanmean`).
    """
    # Sample parameter points from the Gaussian approximation itself, so the
    # comparison is weighted toward the region where the approximation claims
    # most of the probability mass.
    gaussian_parspace_samples = multivariate_normal(
        key=rng_key,
        mean=bestfit_pars,
        cov=cov_approx,
        shape=(n_samples,),
    )

    # Model NLL at each sampled point, origin-shifted so it is 0 at the MLE;
    # centring on the best-fit value makes the two curves directly comparable.
    relative_nlls_model = vmap(
        lambda pars, data: -(
            model.logpdf(pars, data)[0] - model.logpdf(bestfit_pars, data)[0]
        ),  # scale origin to bestfit pars
        in_axes=(0, None),
    )(gaussian_parspace_samples, observed_data)

    # Gaussian NLL, also origin-shifted. The sampled point enters as the
    # *mean* and `bestfit_pars` as the evaluation point; by the symmetry of
    # the Gaussian this equals evaluating at the sampled point with mean
    # `bestfit_pars`, and the second term is the normalisation constant.
    relative_nlls_gaussian = vmap(
        lambda pars, data: -(
            gaussian_logpdf(pars, data, cov_approx)[0]
            - gaussian_logpdf(bestfit_pars, data, cov_approx)[0]
        ),  # data fixes the lhood shape
        in_axes=(0, None),
    )(gaussian_parspace_samples, bestfit_pars)

    # Mean squared discrepancy between the two shifted NLL surfaces.
    diffs = relative_nlls_model - relative_nlls_gaussian
    return jnp.nanmean(diffs ** 2, axis=0)
27 changes: 27 additions & 0 deletions tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,30 @@ def model(pars, data):
return relaxed.cramer_rao_uncert(model, pars * x, data * x)

jacrev(pipeline)(4.0) # just check you can calc it w/o exception


def test_gaussianity():
    """Smoke-test `relaxed.gaussianity` on a simple two-bin counting model."""
    pyhf.set_backend("jax")
    model = pyhf.simplemodels.uncorrelated_background([5, 5], [50, 50], [5, 5])
    init_pars = jnp.asarray(model.config.suggested_init())
    observed = jnp.asarray(model.expected_data(init_pars))
    fisher = relaxed.fisher_info(
        lambda d, p: model.logpdf(d, p)[0], init_pars, observed
    )
    cov_approx = jnp.linalg.inv(fisher)
    relaxed.gaussianity(model, init_pars, cov_approx, observed, PRNGKey(0))


def test_gaussianity_grad(example_model):
    """Verify `relaxed.gaussianity` is differentiable end-to-end via jacrev."""

    def pipeline(x):
        def logpdf(pars, data):
            return example_model.logpdf(pars, data)[0]

        init = example_model.config.suggested_init()
        expected = example_model.expected_data(init)
        cov_approx = jnp.linalg.inv(relaxed.fisher_info(logpdf, init, expected))
        return relaxed.gaussianity(
            example_model, init * x, cov_approx * x, expected * x, PRNGKey(0)
        )

    jacrev(pipeline)(4.0)  # just check you can calc it w/o exception

0 comments on commit fd9f88f

Please sign in to comment.