From 7684f0d912738436c08a1252e79683273055344f Mon Sep 17 00:00:00 2001 From: Du Phan Date: Wed, 29 May 2019 20:59:55 -0400 Subject: [PATCH] a pass of hmc docs and comments (#60) * separate hmc utils * temporary save * revise docs * update docs for remaining utils * change returns to return * add missing dot --- notebooks/time_series_forecasting.ipynb | 22 +- numpyro/diagnostics.py | 2 +- numpyro/hmc_util.py | 262 ++++++++++++++++++------ numpyro/mcmc.py | 17 +- 4 files changed, 223 insertions(+), 80 deletions(-) diff --git a/notebooks/time_series_forecasting.ipynb b/notebooks/time_series_forecasting.ipynb index 0e52ddad9..447113945 100644 --- a/notebooks/time_series_forecasting.ipynb +++ b/notebooks/time_series_forecasting.ipynb @@ -118,17 +118,19 @@ "source": [ "The model we are going to use is called **Seasonal, Global Trend**, which when tested on 3003 time series of the [M-3 competition](https://forecasters.org/resources/time-series-data/m3-competition/), has been known to outperform other models originally participating in the competition:\n", "\n", - "$$\n", "\\begin{equation*}\n", + "\\begin{gathered}\n", "\\text{exp_val}_{t} = \\text{level}_{t-1} + \\text{coef_trend} \\times \\text{level}_{t-1}^{\\text{pow_trend}}\n", "+ \\text{s}_t \\times \\text{level}_{t-1}^{\\text{pow_season}}, \\\\\n", "\\sigma_{t} = \\sigma \\times \\text{exp_val}_{t}^{\\text{powx}} + \\text{offset}, \\\\\n", - "y_{t} \\sim \\text{StudentT}(\\nu, \\text{exp_val}_{t}, \\sigma_{t}), \\\\\n", + "y_{t} \\sim \\text{StudentT}(\\nu, \\text{exp_val}_{t}, \\sigma_{t}),\n", + "\\end{gathered}\n", "\\end{equation*}\n", - "$$\n", + "\n", "where `level` and `s` follows the following recursion rules:\n", - "$$\n", + "\n", "\\begin{equation*}\n", + "\\begin{gathered}\n", "\\text{level_p} =\n", " \\begin{cases}\n", " y_t - \\text{s}_t \\times \\text{level}_{t-1}^{\\text{pow_season}} & \\text{if } t \\le \\text{seasonality}, \\\\ \n", @@ -136,9 +138,9 @@ " \\end{cases} \\\\\n", "\\text{level}_{t} = \\text{level_sm} \\times \\text{level_p} + (1 - \\text{level_sm}) \\times \\text{level}_{t-1}, \\\\\n", "\\text{s}_{t + \\text{seasonality}} = \\text{s_sm} \\times \\frac{y_{t} - \\text{level}_{t}}{\\text{level}_{t-1}^{\\text{pow_trend}}}\n", - "+ (1 - \\text{s_sm}) \\times \\text{s}_{t}. \\\\\n", - "\\end{equation*}\n", - "$$" + "+ (1 - \\text{s_sm}) \\times \\text{s}_{t}.\n", + "\\end{gathered}\n", + "\\end{equation*}" ] }, { @@ -389,8 +391,8 @@ "metadata": {}, "source": [ "Given `samples` from `mcmc`, we want to do forecasting for the testing dataset `y_test`. First, we will make some utilities to do forecasting given a sample. Note that to retrieve the last `level` and last `s` value, we substitute a sample to the model:\n", - "```\n", - "... level, s = substitute(sgt, asample)(y, seasonality)`.\n", + "```python\n", + "... level, s = substitute(sgt, asample)(y, seasonality)\n", "```" ] }, @@ -531,7 +533,7 @@ "source": [ "## References\n", "\n", - "[1] `Rlgt: Bayesian Exponential Smoothing Models with Trend Modifications`,
    \n", + "[1] `Rlgt: Bayesian Exponential Smoothing Models with Trend Modifications`,
    \n", "Slawek Smyl, Christoph Bergmeir, Erwin Wibowo, To Wang Ng, Trustees of Columbia University" ] } diff --git a/numpyro/diagnostics.py b/numpyro/diagnostics.py index e9a24cb81..4a9eb3aa6 100644 --- a/numpyro/diagnostics.py +++ b/numpyro/diagnostics.py @@ -80,7 +80,7 @@ def autocorrelation(x, axis=0): """ Computes the autocorrelation of samples at dimension ``axis``. - :param numpy.array x: the input array. + :param numpy.ndarray x: the input array. :param int axis: the dimension to calculate autocorrelation. :return: autocorrelation of ``x``. :rtype: numpy.ndarray diff --git a/numpyro/hmc_util.py b/numpyro/hmc_util.py index 6575afe1b..c63f1db15 100644 --- a/numpyro/hmc_util.py +++ b/numpyro/hmc_util.py @@ -15,11 +15,11 @@ "ss_state", "mm_state", "window_idx", "rng"]) IntegratorState = laxtuple("IntegratorState", ["z", "r", "potential_energy", "z_grad"]) -_TreeInfo = laxtuple('_TreeInfo', ['z_left', 'r_left', 'z_left_grad', - 'z_right', 'r_right', 'z_right_grad', - 'z_proposal', 'z_proposal_pe', 'z_proposal_grad', - 'depth', 'weight', 'r_sum', 'turning', 'diverging', - 'sum_accept_probs', 'num_proposals']) +TreeInfo = laxtuple('TreeInfo', ['z_left', 'r_left', 'z_left_grad', + 'z_right', 'r_right', 'z_right_grad', + 'z_proposal', 'z_proposal_pe', 'z_proposal_grad', + 'depth', 'weight', 'r_sum', 'turning', 'diverging', + 'sum_accept_probs', 'num_proposals']) def _cholesky_inverse(matrix): @@ -32,28 +32,46 @@ def _cholesky_inverse(matrix): def dual_averaging(t0=10, kappa=0.75, gamma=0.05): """ - Dual Averaging is a scheme to solve convex optimization problems. It belongs - to a class of subgradient methods which uses subgradients to update parameters - (in primal space) of a model. Under some conditions, the averages of generated - parameters during the scheme are guaranteed to converge to an optimal value. - However, a counter-intuitive aspect of traditional subgradient methods is - "new subgradients enter the model with decreasing weights" (see :math:`[1]`). - Dual Averaging scheme solves that phenomenon by updating parameters using - weights equally for subgradients (which lie in a dual space), hence we have - the name "dual averaging". - This class implements a dual averaging scheme which is adapted for Markov chain - Monte Carlo (MCMC) algorithms. To be more precise, we will replace subgradients - by some statistics calculated during an MCMC trajectory. In addition, - introducing some free parameters such as ``t0`` and ``kappa`` is helpful and - still guarantees the convergence of the scheme. - - **References** - [1] `Primal-dual subgradient methods for convex problems`, - Yurii Nesterov - [2] `The No-U-turn sampler: adaptively setting path lengths in Hamiltonian Monte Carlo`, - Matthew D. Hoffman, Andrew Gelman + Dual Averaging is a scheme to solve convex optimization problems. It + belongs to a class of subgradient methods which uses subgradients (which + lie in a dual space) to update states (in primal space) of a model. Under + some conditions, the averages of generated parameters during the scheme are + guaranteed to converge to an optimal value. However, a counter-intuitive + aspect of traditional subgradient methods is "new subgradients enter the + model with decreasing weights" (see reference [1]). Dual Averaging scheme + resolves that issue by updating parameters using weights equally for + subgradients, hence we have the name "dual averaging". + + This class implements a dual averaging scheme which is adapted for Markov + chain Monte Carlo (MCMC) algorithms. 
+    To be more precise, we will replace subgradients by some statistics
+    calculated at the end of MCMC trajectories. Following [2], we introduce
+    some free parameters such as ``t0`` and ``kappa``, which are helpful and
+    still guarantee the convergence of the scheme.
+
+    **References:**
+
+    1. *Primal-dual subgradient methods for convex problems*,
+       Yurii Nesterov
+    2. *The No-U-turn sampler: adaptively setting path lengths in Hamiltonian Monte Carlo*,
+       Matthew D. Hoffman, Andrew Gelman
+
+    :param int t0: A free parameter introduced in reference [2] that stabilizes
+        the initial steps of the scheme. Defaults to 10.
+    :param float kappa: A free parameter introduced in reference [2] that
+        controls the weights of steps of the scheme. For a small ``kappa``, the
+        scheme will quickly forget states from early steps. This should be a
+        number in :math:`(0.5, 1]`. Defaults to 0.75.
+    :param float gamma: A free parameter introduced in reference [1] which
+        controls the speed of the convergence of the scheme. Defaults to 0.05.
+    :return: a (`init_fn`, `update_fn`) pair.
     """
     def init_fn(prox_center=0.):
+        """
+        :param float prox_center: A parameter introduced in reference [1] which
+            pulls the primal sequence towards it. Defaults to 0.
+        :return: initial state for the scheme.
+        """
         x_t = 0.
         x_avg = 0.  # average of primal sequence
         g_avg = 0.  # average of dual sequence
@@ -61,6 +79,12 @@ def init_fn(prox_center=0.):
         return x_t, x_avg, g_avg, t, prox_center

     def update_fn(g, state):
+        """
+        :param float g: The current subgradient or statistics calculated during
+            an MCMC trajectory.
+        :param state: Current state of the scheme.
+        :return: new state for the scheme.
+        """
         x_t, x_avg, g_avg, t, prox_center = state
         t = t + 1
         # g_avg = (g_1 + ... + g_t) / t
@@ -80,15 +104,24 @@ def update_fn(g, state):

 def welford_covariance(diagonal=True):
     """
-    Implements Welford's online method for estimating (co)variance (see :math:`[1]`).
-    Useful for adapting diagonal and dense mass structures for HMC.
+    Implements Welford's online method for estimating (co)variance. Useful for
+    adapting diagonal and dense mass structures for HMC. It is required that
+    each sample is a 1-dimensional array.
+
+    **References:**
+
+    1. *The Art of Computer Programming*,
+       Donald E. Knuth

-    **References**
-    [1] `The Art of Computer Programming`,
-    Donald E. Knuth
+    :param bool diagonal: If True, we estimate the variance of the samples.
+        Otherwise, we estimate the covariance of the samples. Defaults to True.
+    :return: a (`init_fn`, `update_fn`, `final_fn`) triple.
     """
     def init_fn(size):
-        # TODO: replace by a better pattern
+        """
+        :param int size: Size of each sample.
+        :return: initial state for the scheme.
+        """
         mean = np.zeros(size)
         if diagonal:
             m2 = np.zeros(size)
@@ -98,6 +131,11 @@ def init_fn(size):
         return mean, m2, n

     def update_fn(sample, state):
+        """
+        :param sample: A new sample.
+        :param state: Current state of the scheme.
+        :return: new state for the scheme.
+        """
         mean, m2, n = state
         n = n + 1
         delta_pre = sample - mean
@@ -110,9 +148,13 @@ def update_fn(sample, state):
         return mean, m2, n

     def final_fn(state, regularize=False):
+        """
+        :param state: Current state of the scheme.
+        :param bool regularize: Whether to adjust diagonal for numerical stability.
+        :return: a pair of estimated covariance and the square root of precision.
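+
+        A sketch of the full estimation cycle, with made-up samples (any
+        sequence of 1-dimensional arrays works the same way)::
+
+            init_fn, update_fn, final_fn = welford_covariance(diagonal=True)
+            state = init_fn(size=2)
+            for sample in np.array([[0., 1.], [1., 3.], [2., 5.]]):
+                state = update_fn(sample, state)
+            variance, inv_sqrt = final_fn(state)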
+ """ mean, m2, n = state - # TODO: when n=1, return 0; we temporarily do not check for that case - # because lax.cond is not yet available + # XXX it is not necessary to check for the case n=1 cov = m2 / (n - 1) if regularize: # Regularization from Stan @@ -121,7 +163,7 @@ def final_fn(state, regularize=False): if diagonal: cov = scaled_cov + shrinkage else: - cov = scaled_cov + shrinkage * np.identity(mean.shape[0], dtype=mean.dtype) + cov = scaled_cov + shrinkage * np.identity(mean.shape[0]) if np.ndim(cov) == 2: cov_inv_sqrt = _cholesky_inverse(cov) else: @@ -135,15 +177,30 @@ def velocity_verlet(potential_fn, kinetic_fn): r""" Second order symplectic integrator that uses the velocity verlet algorithm for position `z` and momentum `r`. + + :param potential_fn: Python callable that computes the potential energy + given input parameters. The input parameters to `potential_fn` can be + any python collection type. + :param kinetic_fn: Python callable that returns the kinetic energy given + inverse mass matrix and momentum. + :return: a pair of (`init_fn`, `update_fn`). """ def init_fn(z, r): - # TODO: init using the cache of potential_energy and z_grad? + """ + :param z: Position of the particle. + :param r: Momentum of the particle. + :return: initial state for the integrator. + """ potential_energy, z_grad = value_and_grad(potential_fn)(z) return IntegratorState(z, r, potential_energy, z_grad) def update_fn(step_size, inverse_mass_matrix, state): """ - Single step velocity verlet. + :param float step_size: Size of a single step. + :param inverse_mass_matrix: Inverse of mass matrix, which is used to + calculate kinetic energy. + :param state: Current state of the integrator. + :return: new state for the integrator. """ z, r, _, z_grad = state r = tree_multimap(lambda r, z_grad: r - 0.5 * step_size * z_grad, r, z_grad) # r(n+1/2) @@ -158,6 +215,25 @@ def update_fn(step_size, inverse_mass_matrix, state): def find_reasonable_step_size(potential_fn, kinetic_fn, momentum_generator, inverse_mass_matrix, position, rng, init_step_size): + """ + Finds a reasonable step size by tuning `init_step_size`. This function is used + to avoid working with a too large or too small step size in HMC. + + **References:** + + 1. *The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo*, + Matthew D. Hoffman, Andrew Gelman + + :param potential_fn: A callable to compute potential energy. + :param kinetic_fn: A callable to compute kinetic energy. + :param momentum_generator: A generator to get a random momentum variable. + :param inverse_mass_matrix: Inverse of mass matrix. + :param position: Current position of the particle. + :param jax.random.PRNGKey rng: Random key to be used as the source of randomness. + :param float init_step_size: Initial step size to be tuned. + :return: a reasonable value for step size. + :rtype: float + """ # We are going to find a step_size which make accept_prob (Metropolis correction) # near the target_accept_prob. If accept_prob:=exp(-delta_energy) is small, # then we have to decrease step_size; otherwise, increase step_size. @@ -194,6 +270,18 @@ def _cond_fn(state): def build_adaptation_schedule(num_steps): + """ + Builds a window adaptation schedule to be used during warmup phase of HMC. + + :param int num_steps: Number of warmup steps. + :return: a list of contiguous windows, each has attributes `start` and `end`, + where `start` is the starting index and `end` is the ending index of the window. + + **References:** + + 1. 
+       Stan Development Team
+    """
     adaptation_schedule = []
     # from Stan, for small num_steps
     if num_steps < 20:
@@ -240,12 +328,40 @@ def _identity_step_size(inverse_mass_matrix, z, rng, step_size):
 def warmup_adapter(num_adapt_steps, find_reasonable_step_size=_identity_step_size,
                    adapt_step_size=True, adapt_mass_matrix=True,
                    dense_mass=False, target_accept_prob=0.8):
+    """
+    A scheme to adapt tunable parameters, namely step size and mass matrix, during
+    the warmup phase of HMC.
+
+    :param int num_adapt_steps: Number of warmup steps.
+    :param find_reasonable_step_size: A callable to find a reasonable step size
+        at the beginning of each adaptation window.
+    :param bool adapt_step_size: A flag to decide if we want to adapt step_size
+        during warm-up phase using Dual Averaging scheme (defaults to ``True``).
+    :param bool adapt_mass_matrix: A flag to decide if we want to adapt mass
+        matrix during warm-up phase using Welford scheme (defaults to ``True``).
+    :param bool dense_mass: A flag to decide if mass matrix is dense or
+        diagonal (defaults to ``False``).
+    :param float target_accept_prob: Target acceptance probability for step size
+        adaptation using Dual Averaging. Increasing this value will lead to a smaller
+        step size, hence the sampling will be slower but more robust. Defaults to 0.8.
+    :return: a pair of (`init_fn`, `update_fn`).
+    """
     ss_init, ss_update = dual_averaging()
     mm_init, mm_update, mm_final = welford_covariance(diagonal=not dense_mass)
     adaptation_schedule = np.array(build_adaptation_schedule(num_adapt_steps))
     num_windows = len(adaptation_schedule)

     def init_fn(z, rng, step_size=1.0, inverse_mass_matrix=None, mass_matrix_size=None):
+        """
+        :param z: Initial position of the integrator.
+        :param jax.random.PRNGKey rng: Random key to be used as the source of randomness.
+        :param float step_size: Initial step size.
+        :param inverse_mass_matrix: Inverse of the initial mass matrix. If ``None``,
+            inverse of mass matrix will be an identity matrix whose size is decided
+            by the argument `mass_matrix_size`.
+        :param int mass_matrix_size: Size of the mass matrix.
+        :return: initial state of the adaptation scheme.
+        """
         rng, rng_ss = random.split(rng)
         if inverse_mass_matrix is None:
             assert mass_matrix_size is not None
@@ -285,6 +401,13 @@ def _update_at_window_end(z, rng_ss, state):
                           ss_state, mm_state, window_idx, rng)

     def update_fn(t, accept_prob, z, state):
+        """
+        :param int t: The current time step.
+        :param float accept_prob: Acceptance probability of the current trajectory.
+        :param z: New position drawn at the end of the current trajectory.
+        :param state: Current state of the adaptation scheme.
+        :return: new state of the adaptation scheme.
+        """
         step_size, inverse_mass_matrix, mass_matrix_sqrt, ss_state, mm_state, window_idx, rng = state
         rng, rng_ss = random.split(rng)

@@ -292,7 +415,6 @@ def update_fn(t, accept_prob, z, state):
         if adapt_step_size:
             ss_state = ss_update(target_accept_prob - accept_prob, ss_state)
             # note: at the end of warmup phase, use average of log step_size
-            # TODO: should we make sure that we won't update step_size if t >= num_steps?
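+            # ss_state is (x_t, x_avg, g_avg, t, prox_center) from dual_averaging,
+            # where x is the log step size; the np.where below picks the running
+            # average x_avg on the final warmup step and the current iterate otherwise.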
             log_step_size, log_step_size_avg, *_ = ss_state
             step_size = np.where(t == (num_adapt_steps - 1),
                                  np.exp(log_step_size_avg),
@@ -391,10 +513,10 @@ def _combine_tree(current_tree, new_tree, inverse_mass_matrix, going_right, rng,
         sum_accept_probs = current_tree.sum_accept_probs + new_tree.sum_accept_probs
         num_proposals = current_tree.num_proposals + new_tree.num_proposals

-    return _TreeInfo(z_left, r_left, z_left_grad, z_right, r_right, r_right_grad,
-                     z_proposal, z_proposal_pe, z_proposal_grad,
-                     tree_depth, tree_weight, r_sum, turning, diverging,
-                     sum_accept_probs, num_proposals)
+    return TreeInfo(z_left, r_left, z_left_grad, z_right, r_right, r_right_grad,
+                    z_proposal, z_proposal_pe, z_proposal_grad,
+                    tree_depth, tree_weight, r_sum, turning, diverging,
+                    sum_accept_probs, num_proposals)


 def _build_basetree(vv_update, kinetic_fn, z, r, z_grad, inverse_mass_matrix, step_size, going_right,
@@ -414,10 +536,10 @@ def _build_basetree(vv_update, kinetic_fn, z, r, z_grad, inverse_mass_matrix, st
     diverging = delta_energy > max_delta_energy
     accept_prob = np.clip(np.exp(-delta_energy), a_max=1.0)

-    return _TreeInfo(z_new, r_new, z_new_grad, z_new, r_new, z_new_grad,
-                     z_new, potential_energy_new, z_new_grad,
-                     depth=0, weight=tree_weight, r_sum=r_new, turning=False,
-                     diverging=diverging, sum_accept_probs=accept_prob, num_proposals=1)
+    return TreeInfo(z_new, r_new, z_new_grad, z_new, r_new, z_new_grad,
+                    z_new, potential_energy_new, z_new_grad,
+                    depth=0, weight=tree_weight, r_sum=r_new, turning=False,
+                    diverging=diverging, sum_accept_probs=accept_prob, num_proposals=1)


 def _get_leaf(tree, going_right):
@@ -516,32 +638,46 @@ def _body_fn(state):
         (basetree, False, r_ckpts, r_sum_ckpts, rng)
     )
     # update depth and turning condition
-    return _TreeInfo(tree.z_left, tree.r_left, tree.z_left_grad,
-                     tree.z_right, tree.r_right, tree.z_right_grad,
-                     tree.z_proposal, tree.z_proposal_pe, tree.z_proposal_grad,
-                     depth, tree.weight, tree.r_sum, turning, tree.diverging,
-                     tree.sum_accept_probs, tree.num_proposals)
+    return TreeInfo(tree.z_left, tree.r_left, tree.z_left_grad,
+                    tree.z_right, tree.r_right, tree.z_right_grad,
+                    tree.z_proposal, tree.z_proposal_pe, tree.z_proposal_grad,
+                    depth, tree.weight, tree.r_sum, turning, tree.diverging,
+                    tree.sum_accept_probs, tree.num_proposals)


 def build_tree(verlet_update, kinetic_fn, verlet_state, inverse_mass_matrix, step_size, rng,
                max_delta_energy=1000., max_tree_depth=10):
     """
+    Builds a binary tree from the `verlet_state`. This is used in the NUTS sampler.
+
     **References:**
-    [1] `The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo`,
-    Matthew D. Hoffman, Andrew Gelman
-    [2] `A Conceptual Introduction to Hamiltonian Monte Carlo`,
-    Michael Betancourt
+
+    1. *The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo*,
+       Matthew D. Hoffman, Andrew Gelman
+    2. *A Conceptual Introduction to Hamiltonian Monte Carlo*,
+       Michael Betancourt
+
+    :param verlet_update: A callable to get a new integrator state given a current
+        integrator state.
+    :param kinetic_fn: A callable to compute kinetic energy.
+    :param verlet_state: Initial integrator state.
+    :param inverse_mass_matrix: Inverse of the mass matrix.
+    :param float step_size: Step size for the current trajectory.
+    :param jax.random.PRNGKey rng: Random key to be used as the source of
+        randomness.
+    :param float max_delta_energy: A threshold to decide if the new state diverges
+        (based on the energy difference) too much from the initial integrator state.
+    :return: information of the tree.
+    :rtype: :data:`TreeInfo`
     """
     z, r, potential_energy, z_grad = verlet_state
     energy_current = potential_energy + kinetic_fn(inverse_mass_matrix, r)
-    r_ckpts = np.zeros((max_tree_depth, inverse_mass_matrix.shape[-1]),
-                       dtype=inverse_mass_matrix.dtype)
-    r_sum_ckpts = np.zeros((max_tree_depth, inverse_mass_matrix.shape[-1]),
-                           dtype=inverse_mass_matrix.dtype)
+    r_ckpts = np.zeros((max_tree_depth, inverse_mass_matrix.shape[-1]))
+    r_sum_ckpts = np.zeros((max_tree_depth, inverse_mass_matrix.shape[-1]))

-    tree = _TreeInfo(z, r, z_grad, z, r, z_grad, z, potential_energy, z_grad,
-                     depth=0, weight=0., r_sum=r, turning=False, diverging=False,
-                     sum_accept_probs=0., num_proposals=0)
+    tree = TreeInfo(z, r, z_grad, z, r, z_grad, z, potential_energy, z_grad,
+                    depth=0, weight=0., r_sum=r, turning=False, diverging=False,
+                    sum_accept_probs=0., num_proposals=0)

     def _cond_fn(state):
         tree, _ = state
@@ -606,8 +742,8 @@ def initialize_model(rng, model, *model_args, init_strategy='uniform', **model_k
         `prior` initializes the parameters by sampling from the prior for each
         of the sample sites.
     :param `**model_kwargs`: kwargs provided to the model.
-    :return: tuple of (`init_params`, `potential_fn`, `constrain_fn`)
-        `init_params` are values from the prior used to initiate MCMC.
+    :return: tuple of (`init_params`, `potential_fn`, `constrain_fn`),
+        `init_params` are values from the prior used to initiate MCMC,
         `constrain_fn` is a callable that uses inverse transforms to convert
         unconstrained HMC samples to constrained values that lie within the site's
         support.
diff --git a/numpyro/mcmc.py b/numpyro/mcmc.py
index 3e535dccc..a34997088 100644
--- a/numpyro/mcmc.py
+++ b/numpyro/mcmc.py
@@ -19,6 +19,7 @@
                                    'rng'])
 """
 A :func:`~collections.namedtuple` consisting of the following fields:
+
 - **i** - iteration. This is reset to 0 after warmup.
 - **z** - Python collection representing values (unconstrained samples from
   the posterior) at latent sites.
@@ -90,9 +91,12 @@ def hmc(potential_fn, kinetic_fn=None, algo='NUTS'):

     **References:**

-    1. *MCMC Using Hamiltonian Dynamics*, Radford M. Neal
+    1. *MCMC Using Hamiltonian Dynamics*,
+       Radford M. Neal
     2. *The No-U-turn sampler: adaptively setting path lengths in Hamiltonian Monte Carlo*,
        Matthew D. Hoffman, and Andrew Gelman.
+    3. *A Conceptual Introduction to Hamiltonian Monte Carlo*,
+       Michael Betancourt

     :param potential_fn: Python callable that computes the potential energy
         given input parameters. The input parameters to `potential_fn` can be
@@ -103,7 +107,7 @@ def hmc(potential_fn, kinetic_fn=None, algo='NUTS'):
         euclidean kinetic energy.
     :param str algo: Whether to run ``HMC`` with fixed number of steps or
         ``NUTS`` with adaptive path length. Default is ``NUTS``.
-    :return init_kernel, sample_kernel: Returns a tuple of callables, the first
+    :return: a tuple of callables (`init_kernel`, `sample_kernel`), the first
         one to initialize the sampler, and the second one to generate samples
         given an existing one.

@@ -132,14 +136,15 @@ def hmc(potential_fn, kinetic_fn=None, algo='NUTS'):
     ...     coefs = sample('beta', dist.Normal(coefs_mean, np.ones(3)))
     ...     return sample('y', dist.Bernoulli(logits=(coefs * data).sum(-1)), obs=labels)
     >>>
-    >>> init_params, potential_fn, constrain_fn = initialize_model(random.PRNGKey(0), model, data, labels)
+    >>> init_params, potential_fn, constrain_fn = initialize_model(random.PRNGKey(0),
+    ...                                                            model, data, labels)
     >>> init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')
     >>> hmc_state = init_kernel(init_params,
     ...                         trajectory_length=10,
     ...                         num_warmup=300)
-    >>> hmc_states = fori_collect(500, sample_kernel, hmc_state,
-    ...                           transform=lambda x: constrain_fn(x.z))
-    >>> print(np.mean(hmc_states['beta'], axis=0))  # doctest: +SKIP
+    >>> samples = fori_collect(500, sample_kernel, hmc_state,
+    ...                        transform=lambda state: constrain_fn(state.z))
+    >>> print(np.mean(samples['beta'], axis=0))  # doctest: +SKIP
     [0.9153987 2.0754058 2.9621222]
     """
     if kinetic_fn is None:
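
A minimal sketch of driving the (`init_fn`, `update_fn`) pair that `velocity_verlet` documents above; the potential, initial position/momentum, and step size here are made up for illustration:

```python
import jax.numpy as np
from numpyro.hmc_util import velocity_verlet

# Harmonic oscillator: potential 0.5 * z**2 with Euclidean kinetic energy.
def potential_fn(z):
    return 0.5 * np.sum(z ** 2)

def kinetic_fn(inverse_mass_matrix, r):
    return 0.5 * np.sum(inverse_mass_matrix * r ** 2)

init_fn, update_fn = velocity_verlet(potential_fn, kinetic_fn)
state = init_fn(z=np.array([1.]), r=np.array([0.5]))
inverse_mass_matrix = np.ones(1)
for _ in range(10):  # ten leapfrog steps of size 0.1
    state = update_fn(0.1, inverse_mass_matrix, state)
z, r, potential_energy, z_grad = state
```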