diff --git a/examples/hmm.py b/examples/hmm.py
index 4972f1642..1953080fe 100644
--- a/examples/hmm.py
+++ b/examples/hmm.py
@@ -8,11 +8,9 @@
 from jax.scipy.special import logsumexp
 
 import numpyro.distributions as dist
-from numpyro.diagnostics import summary
 from numpyro.handlers import sample
 from numpyro.hmc_util import initialize_model
-from numpyro.mcmc import hmc
-from numpyro.util import fori_collect
+from numpyro.mcmc import mcmc
 
 
 """
@@ -117,21 +115,6 @@ def semi_supervised_hmm(transition_prior, emission_prior,
     return sample('forward_log_prob', dist.Multinomial(logits=-log_prob), obs=0)
 
 
-def run_inference(transition_prior, emission_prior, supervised_categories, supervised_words,
-                  unsupervised_words, rng, args):
-    init_params, potential_fn, constrain_fn = initialize_model(
-        rng,
-        semi_supervised_hmm,
-        transition_prior, emission_prior, supervised_categories,
-        supervised_words, unsupervised_words,
-    )
-    init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')
-    hmc_state = init_kernel(init_params, args.num_warmup)
-    hmc_states = fori_collect(args.num_samples, sample_kernel, hmc_state,
-                              transform=lambda state: constrain_fn(state.z))
-    return hmc_states
-
-
 def print_results(posterior, transition_prob, emission_prob):
     header = semi_supervised_hmm.__name__ + ' - TRAIN'
     columns = ['', 'ActualProb', 'Pred(p25)', 'Pred(p50)', 'Pred(p75)']
@@ -165,11 +148,15 @@ def main(args):
         num_unsupervised_data=args.num_unsupervised,
     )
     print('Starting inference...')
-    zs = run_inference(transition_prior, emission_prior,
-                       supervised_categories, supervised_words, unsupervised_words,
-                       random.PRNGKey(2), args)
-    summary(zs)
-    print_results(zs, transition_prob, emission_prob)
+    init_params, potential_fn, constrain_fn = initialize_model(
+        random.PRNGKey(2),
+        semi_supervised_hmm,
+        transition_prior, emission_prior, supervised_categories,
+        supervised_words, unsupervised_words,
+    )
+    samples = mcmc(args.num_warmup, args.num_samples, init_params,
+                   potential_fn=potential_fn, constrain_fn=constrain_fn)
+    print_results(samples, transition_prob, emission_prob)
 
 
 if __name__ == '__main__':
diff --git a/examples/ucbadmit.py b/examples/ucbadmit.py
index bf31964e1..bb4d8885a 100644
--- a/examples/ucbadmit.py
+++ b/examples/ucbadmit.py
@@ -3,7 +3,7 @@
 import numpy as onp
 
 import jax.numpy as np
-import jax.random as random
+from jax import random, vmap
 from jax.config import config as jax_config
 
 import numpyro.distributions as dist
@@ -52,20 +52,18 @@
 """
 
 
-# TODO: Remove broadcasting logic when support for `pyro.plate` is available.
 def glmm(dept, male, applications, admit):
     v_mu = sample('v_mu', dist.Normal(0, np.array([4., 1.])))
     sigma = sample('sigma', dist.HalfNormal(np.ones(2)))
     L_Rho = sample('L_Rho', dist.LKJCholesky(2))
-    scale_tril = np.expand_dims(sigma, axis=-1) * L_Rho
+    scale_tril = sigma[..., np.newaxis] * L_Rho
 
     # non-centered parameterization
     num_dept = len(onp.unique(dept))
     z = sample('z', dist.Normal(np.zeros((num_dept, 2)), 1))
-    v = np.squeeze(np.matmul(np.expand_dims(scale_tril, axis=-3), np.expand_dims(z, axis=-1)),
-                   axis=-1)
+    v = np.dot(scale_tril, z.T).T
 
-    logits = v_mu[..., :1] + v[..., dept, 0] + (v_mu[..., 1:] + v[..., dept, 1]) * male
+    logits = v_mu[0] + v[dept, 0] + (v_mu[1] + v[dept, 1]) * male
     sample('admit', dist.Binomial(applications, logits=logits), obs=admit)
 
 
@@ -79,13 +77,10 @@ def run_inference(dept, male, applications, admit, rng, args):
     return hmc_states
 
 
-def predict(dept, male, applications, admit, z, rng):
-    header = glmm.__name__ + ' - TRAIN'
+def predict(dept, male, applications, z, rng):
     model = substitute(seed(glmm, rng), z)
     model_trace = trace(model).get_trace(dept, male, applications, admit=None)
-    predictions = model_trace['admit']['fn'].probs
-    probs = admit / applications
-    print_results('=' * 30 + header + '=' * 30, predictions, dept, male, probs)
+    return model_trace['admit']['fn'].probs
 
 
 def print_results(header, preds, dept, male, probs):
@@ -105,7 +100,10 @@ def main(args):
     dept, male, applications, admit = fetch_train()
     rng, rng_predict = random.split(random.PRNGKey(1))
     zs = run_inference(dept, male, applications, admit, rng, args)
-    predict(dept, male, applications, admit, zs, rng_predict)
+    rngs = random.split(rng_predict, args.num_samples)
+    pred_probs = vmap(lambda z, rng: predict(dept, male, applications, z, rng))(zs, rngs)
+    header = '=' * 30 + 'glmm - TRAIN' + '=' * 30
+    print_results(header, pred_probs, dept, male, admit / applications)
 
 
 if __name__ == '__main__':
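For readers following the API change: below is a minimal, self-contained sketch of the inference pattern both examples adopt after this diff, assuming the initialize_model/mcmc interface shown in the hunks above. The toy Normal model and its data are illustrative only and appear in neither example.

# Sketch of the new inference flow (assumes the numpyro version from this PR).
import jax.numpy as np
from jax import random

import numpyro.distributions as dist
from numpyro.handlers import sample
from numpyro.hmc_util import initialize_model
from numpyro.mcmc import mcmc


def model(data):
    # Toy model: infer the location `mu` of a Normal likelihood.
    mu = sample('mu', dist.Normal(0., 10.))
    sample('obs', dist.Normal(mu, 1.), obs=data)


data = np.array([1.2, 0.8, 1.5, 0.9])
# initialize_model returns initial parameters, the potential function, and a
# transform back to the constrained space, exactly as used in the hunks above.
init_params, potential_fn, constrain_fn = initialize_model(random.PRNGKey(0), model, data)
# mcmc runs warmup and sampling, returning samples in the constrained space,
# replacing the hand-rolled hmc/fori_collect loop that this diff removes.
samples = mcmc(500, 1000, init_params,
               potential_fn=potential_fn, constrain_fn=constrain_fn)

The ucbadmit.py hunks apply the same samples-as-pytree idea to prediction: rng_predict is split into one key per posterior draw, and vmap runs predict over the (sample, key) pairs, so the posterior predictive probabilities come from a single vectorized call instead of a Python loop.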