Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
aloctavodia committed Oct 25, 2024
1 parent 34a8998 commit a5181ab
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 10 deletions.
3 changes: 2 additions & 1 deletion preliz/ppls/agnostic.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ def posterior_to_prior(model, idata, alternative=None, engine="auto"):
Fit a posterior from a model to its prior
The fit is based on maximum likelihood of each posterior marginal to the prior
in the model. It is expected that the posterior was computed from the model.
in the model. Thus possible correlations between parameters in the posteriors
will not be preserved. It is expected that the posterior was computed from the model.
Parameters
----------
Expand Down
7 changes: 7 additions & 0 deletions preliz/tests/test_mle.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,13 @@ def test_recover_right():
assert idx[0] == 0
assert ax is None

dists = [pz.Normal(), pz.Gamma(), pz.Poisson()]
sample = pz.Normal(10, 0.5).rvs((2, 10000), random_state=123)
idx, ax = pz.mle(dists, sample, plot=0)
all(d.params[0].size == 2 for d in dists)
assert idx[0] == 0
assert ax is None

plt.figure()
sample = Gamma(2, 10).rvs(10000)
idx, ax = pz.mle(dists, sample)
Expand Down
19 changes: 10 additions & 9 deletions preliz/tests/test_posterior_to_prior.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import pandas as pd
import numpy as np
import pymc as pm
import bambi as bmb
import preliz as pz


data = pz.Normal(0, 1).rvs(200)
SEED = 2945

data = pz.Normal(0, 1).rvs(200, random_state=SEED)

with pm.Model() as model:
a = pm.Normal("a", mu=0, sigma=1)
b = pm.HalfNormal("b", sigma=1)
y = pm.Normal("y", mu=a, sigma=b, observed=data)
idata = pm.sample(tune=200, draws=500, random_seed=2945)
b = pm.HalfNormal("b", sigma=[1, 1], shape=2)
y = pm.Normal("y", mu=a, sigma=b[0], observed=data) # pylint:disable = unsubscriptable-object
idata = pm.sample(tune=200, draws=500, random_seed=SEED)


def test_p2p_pymc():
Expand All @@ -25,14 +26,14 @@ def test_p2p_pymc():

bmb_data = pd.DataFrame(
{
"y": np.random.normal(size=117),
"x": np.random.normal(size=117),
"x1": np.random.normal(size=117),
"y": pz.Normal(0, 1).rvs(117, random_state=SEED),
"x": pz.Normal(0, 1).rvs(117, random_state=SEED),
"x1": pz.Normal(0, 1).rvs(117, random_state=SEED),
}
)
bmb_prior = {"Intercept": bmb.Prior("Normal", mu=0, sigma=1)}
bmb_model = bmb.Model("y ~ x + x1", bmb_data, priors=bmb_prior)
bmb_idata = bmb_model.fit(tune=200, draws=200, random_seed=2945)
bmb_idata = bmb_model.fit(tune=200, draws=200, random_seed=SEED)


def test_p2p_bambi():
Expand Down

0 comments on commit a5181ab

Please sign in to comment.