From 655c36b1a1397792fa981eed25328bb442f6b444 Mon Sep 17 00:00:00 2001 From: Kenyon Ng Date: Mon, 4 Sep 2023 23:50:47 +1000 Subject: [PATCH] Moving max_num_doublings of NUTS from build_kernel to kernel (#566) Making the argument consistent with 'num_integration_steps' in HMC --- blackjax/mcmc/nuts.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/blackjax/mcmc/nuts.py b/blackjax/mcmc/nuts.py index e09841ccf..b185cff27 100644 --- a/blackjax/mcmc/nuts.py +++ b/blackjax/mcmc/nuts.py @@ -77,7 +77,6 @@ class NUTSInfo(NamedTuple): def build_kernel( integrator: Callable = integrators.velocity_verlet, divergence_threshold: int = 1000, - max_num_doublings: int = 10, ): """Build an iterative NUTS kernel. @@ -108,10 +107,6 @@ def build_kernel( divergence_threshold The absolute difference in energy above which we consider a transition "divergent". - max_num_doublings - The maximum number of times we expand the trajectory by - doubling the number of steps if the trajectory does not - turn onto itself. """ @@ -121,6 +116,7 @@ def kernel( logdensity_fn: Callable, step_size: float, inverse_mass_matrix: Array, + max_num_doublings: int = 10, ) -> tuple[hmc.HMCState, NUTSInfo]: """Generate a new sample with the NUTS kernel.""" @@ -224,7 +220,7 @@ def __new__( # type: ignore[misc] divergence_threshold: int = 1000, integrator: Callable = integrators.velocity_verlet, ) -> SamplingAlgorithm: - kernel = cls.build_kernel(integrator, divergence_threshold, max_num_doublings) + kernel = cls.build_kernel(integrator, divergence_threshold) def init_fn(position: ArrayLikeTree): return cls.init(position, logdensity_fn) @@ -236,6 +232,7 @@ def step_fn(rng_key: PRNGKey, state): logdensity_fn, step_size, inverse_mass_matrix, + max_num_doublings, ) return SamplingAlgorithm(init_fn, step_fn)