add diverging state (#330)
fehiepsi authored and neerajprad committed Sep 8, 2019
1 parent 2ddb42c commit 1a5c16a
Showing 2 changed files with 32 additions and 8 deletions.
20 changes: 12 additions & 8 deletions numpyro/mcmc.py
@@ -28,7 +28,7 @@
 from numpyro.util import cond, copy_docs_from, fori_collect, fori_loop, identity
 
 HMCState = namedtuple('HMCState', ['i', 'z', 'z_grad', 'potential_energy', 'num_steps', 'accept_prob',
-                                   'mean_accept_prob', 'adapt_state', 'rng'])
+                                   'mean_accept_prob', 'diverging', 'adapt_state', 'rng'])
 """
 A :func:`~collections.namedtuple` consisting of the following fields:
@@ -42,6 +42,7 @@
    does not correspond to the proposal if it is rejected.
  - **mean_accept_prob** - Mean acceptance probability until current iteration
    during warmup adaptation or sampling (for diagnostics).
+ - **diverging** - A boolean value to indicate whether the current trajectory is diverging.
  - **adapt_state** - A ``AdaptState`` namedtuple which contains adaptation information
    during warmup:
@@ -163,6 +164,7 @@ def hmc(potential_fn, kinetic_fn=None, algo='NUTS'):
     momentum_generator = None
     wa_update = None
     wa_steps = None
+    max_delta_energy = 1000.
     if algo not in {'HMC', 'NUTS'}:
         raise ValueError('`algo` must be one of `HMC` or `NUTS`.')
 
@@ -235,7 +237,7 @@ def init_kernel(init_params,
         r = momentum_generator(wa_state.mass_matrix_sqrt, rng)
         vv_state = vv_init(z, r)
         hmc_state = HMCState(0, vv_state.z, vv_state.z_grad, vv_state.potential_energy, 0, 0., 0.,
-                             wa_state, rng_hmc)
+                             False, wa_state, rng_hmc)
 
         # TODO: Remove; this should be the responsibility of the MCMC class.
         if run_warmup and num_warmup > 0:
@@ -259,23 +261,25 @@ def _hmc_next(step_size, inverse_mass_matrix, vv_state, rng):
         delta_energy = energy_new - energy_old
         delta_energy = np.where(np.isnan(delta_energy), np.inf, delta_energy)
         accept_prob = np.clip(np.exp(-delta_energy), a_max=1.0)
+        diverging = delta_energy > max_delta_energy
         transition = random.bernoulli(rng, accept_prob)
         vv_state = cond(transition,
                         vv_state_new, lambda state: state,
                         vv_state, lambda state: state)
-        return vv_state, num_steps, accept_prob
+        return vv_state, num_steps, accept_prob, diverging
 
     def _nuts_next(step_size, inverse_mass_matrix, vv_state, rng):
         binary_tree = build_tree(vv_update, kinetic_fn, vv_state,
                                  inverse_mass_matrix, step_size, rng,
+                                 max_delta_energy=max_delta_energy,
                                  max_tree_depth=max_treedepth)
         accept_prob = binary_tree.sum_accept_probs / binary_tree.num_proposals
         num_steps = binary_tree.num_proposals
         vv_state = IntegratorState(z=binary_tree.z_proposal,
                                    r=vv_state.r,
                                    potential_energy=binary_tree.z_proposal_pe,
                                    z_grad=binary_tree.z_proposal_grad)
-        return vv_state, num_steps, accept_prob
+        return vv_state, num_steps, accept_prob, binary_tree.diverging
 
     _next = _nuts_next if algo == 'NUTS' else _hmc_next
 
@@ -292,9 +296,9 @@ def sample_kernel(hmc_state):
         rng, rng_momentum, rng_transition = random.split(hmc_state.rng, 3)
         r = momentum_generator(hmc_state.adapt_state.mass_matrix_sqrt, rng_momentum)
         vv_state = IntegratorState(hmc_state.z, r, hmc_state.potential_energy, hmc_state.z_grad)
-        vv_state, num_steps, accept_prob = _next(hmc_state.adapt_state.step_size,
-                                                 hmc_state.adapt_state.inverse_mass_matrix,
-                                                 vv_state, rng_transition)
+        vv_state, num_steps, accept_prob, diverging = _next(hmc_state.adapt_state.step_size,
+                                                            hmc_state.adapt_state.inverse_mass_matrix,
+                                                            vv_state, rng_transition)
         # not update adapt_state after warmup phase
         adapt_state = cond(hmc_state.i < wa_steps,
                            (hmc_state.i, accept_prob, vv_state.z, hmc_state.adapt_state),
@@ -307,7 +311,7 @@
         mean_accept_prob = hmc_state.mean_accept_prob + (accept_prob - hmc_state.mean_accept_prob) / n
 
         return HMCState(itr, vv_state.z, vv_state.z_grad, vv_state.potential_energy, num_steps,
-                        accept_prob, mean_accept_prob, adapt_state, rng)
+                        accept_prob, mean_accept_prob, diverging, adapt_state, rng)
 
     # Make `init_kernel` and `sample_kernel` visible from the global scope once
     # `hmc` is called for sphinx doc generation.
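
For reference, the divergence flag added above is computed from the change in Hamiltonian energy of a proposed transition: a transition is flagged as divergent when that change exceeds the fixed max_delta_energy threshold of 1000 (NUTS gets the equivalent flag from build_tree via binary_tree.diverging). The short standalone sketch below restates that criterion outside the sampler; it is illustrative only and not code from the commit, and the name MAX_DELTA_ENERGY is introduced here for clarity.

# Illustrative sketch (not from the commit): how a transition is flagged as divergent.
import jax.numpy as np

MAX_DELTA_ENERGY = 1000.  # mirrors the fixed `max_delta_energy` threshold added in hmc()

def accept_prob_and_diverging(energy_old, energy_new):
    delta_energy = energy_new - energy_old
    # A NaN energy (e.g. from a numerically unstable integrator step) is treated as an
    # infinite energy increase, so the proposal is rejected and flagged as divergent.
    delta_energy = np.where(np.isnan(delta_energy), np.inf, delta_energy)
    accept_prob = np.clip(np.exp(-delta_energy), a_max=1.0)
    diverging = delta_energy > MAX_DELTA_ENERGY
    return accept_prob, diverging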
20 changes: 20 additions & 0 deletions test/test_mcmc_interface.py
@@ -243,6 +243,26 @@ def model(data):
     assert_allclose(np.mean(samples['std']), true_std, rtol=0.05)
 
 
+@pytest.mark.parametrize('kernel_cls', [HMC, NUTS])
+@pytest.mark.parametrize('adapt_step_size', [True, False])
+def test_diverging(kernel_cls, adapt_step_size):
+    data = random.normal(random.PRNGKey(0), (1000,))
+
+    def model(data):
+        loc = numpyro.sample('loc', dist.Normal(0., 1.))
+        numpyro.sample('obs', dist.Normal(loc, 1), obs=data)
+
+    kernel = kernel_cls(model, step_size=10., adapt_step_size=adapt_step_size, adapt_mass_matrix=False)
+    num_warmup = num_samples = 1000
+    mcmc = MCMC(kernel, num_warmup, num_samples)
+    mcmc.run(random.PRNGKey(1), data, collect_fields=('z', 'diverging'), collect_warmup=True)
+    num_divergences = mcmc.get_samples()[1].sum()
+    if adapt_step_size:
+        assert num_divergences <= num_warmup
+    else:
+        assert_allclose(num_divergences, num_warmup + num_samples)
+
+
 @pytest.mark.parametrize('use_init_params', [False, True])
 @pytest.mark.parametrize('chain_method', ['parallel', 'sequential', 'vectorized'])
 @pytest.mark.filterwarnings("ignore:There are not enough devices:UserWarning")
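
As a usage note, the new field can be collected alongside the posterior samples, as the test above does. A minimal sketch follows, assuming the import paths of this commit (the HMC/NUTS/MCMC classes live in numpyro.mcmc here; later releases expose them from numpyro.infer) and a toy model introduced only for illustration.

# Usage sketch (assumptions: numpyro.mcmc import path as of this commit; toy model below).
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.mcmc import MCMC, NUTS

def model(data):
    loc = numpyro.sample('loc', dist.Normal(0., 1.))
    numpyro.sample('obs', dist.Normal(loc, 1.), obs=data)

data = random.normal(random.PRNGKey(0), (1000,))
mcmc = MCMC(NUTS(model), 1000, 1000)  # num_warmup, num_samples
# Request the `diverging` field of HMCState in addition to the samples.
mcmc.run(random.PRNGKey(1), data, collect_fields=('z', 'diverging'))
samples, diverging = mcmc.get_samples()
print('post-warmup divergences:', int(diverging.sum()))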
