From 5247b7b6868c6116e76dcb10ae657d0ca049b70b Mon Sep 17 00:00:00 2001 From: Colin Carroll Date: Wed, 29 Nov 2023 09:08:52 -0800 Subject: [PATCH] Add MEADS and ChEES samplers from blackjax. These samplers pass the current MCMC tests (which just require sampling from a 1D Gaussian), but do poorly enough on actual problems that there is probably a bug somewhere. I *think* this is a blackjax problem, but exposing these methods may help track down where the problem is. PiperOrigin-RevId: 586360980 --- bayeux/_src/bayeux.py | 3 +- bayeux/_src/mcmc/blackjax.py | 228 ++++++++++++++++++++++++++------- bayeux/_src/optimize/shared.py | 10 +- bayeux/_src/shared.py | 11 ++ bayeux/mcmc/__init__.py | 7 +- 5 files changed, 201 insertions(+), 58 deletions(-) diff --git a/bayeux/_src/bayeux.py b/bayeux/_src/bayeux.py index b6a4f48..aedb2db 100644 --- a/bayeux/_src/bayeux.py +++ b/bayeux/_src/bayeux.py @@ -29,8 +29,7 @@ "transform_fn", "inverse_transform_fn", "inverse_log_det_jacobian", - "initial_state", -) + "initial_state",) class _Namespace: diff --git a/bayeux/_src/mcmc/blackjax.py b/bayeux/_src/mcmc/blackjax.py index 1e2f207..5cabe4a 100644 --- a/bayeux/_src/mcmc/blackjax.py +++ b/bayeux/_src/mcmc/blackjax.py @@ -19,15 +19,21 @@ from bayeux._src import shared import blackjax import jax +import jax.numpy as jnp +import optax _ADAPT_FNS = { "window": blackjax.window_adaptation, "pathfinder": blackjax.pathfinder_adaptation, + "chees": blackjax.chees_adaptation, + "meads": blackjax.meads_adaptation, } _ALGORITHMS = { "hmc": blackjax.hmc, + "ghmc": blackjax.ghmc, + "dynamic_hmc": blackjax.dynamic_hmc, "nuts": blackjax.nuts, } @@ -51,9 +57,15 @@ class _BlackjaxSampler(shared.Base): def get_kwargs(self, **kwargs): adapt_fn = _ADAPT_FNS[self.adapt_fn] algorithm = _ALGORITHMS[self.algorithm] - return {adapt_fn: get_adaptation_kwargs(adapt_fn, algorithm, kwargs), - algorithm: get_algorithm_kwargs(algorithm, kwargs), - "extra_parameters": get_extra_kwargs(kwargs)} + extra_parameters = 
get_extra_kwargs(kwargs) + constrained_log_density = self.constrained_log_density() + adaptation_kwargs, run_kwargs = get_adaptation_kwargs( + adapt_fn, algorithm, constrained_log_density, extra_parameters | kwargs) + return {adapt_fn: adaptation_kwargs, + "adapt.run": run_kwargs, + algorithm: get_algorithm_kwargs( + algorithm, constrained_log_density, kwargs), + "extra_parameters": extra_parameters} def __call__(self, seed, **kwargs): init_key, sample_key = jax.random.split(seed) @@ -62,7 +74,24 @@ def __call__(self, seed, **kwargs): init_key, num_chains=kwargs["extra_parameters"]["num_chains"]) return _sample_blackjax( - log_density=self.constrained_log_density(), + initial_state=self.inverse_transform_fn(initial_state), + algorithm=_ALGORITHMS[self.algorithm], + transform_fn=self.transform_fn, + adapt_fn=_ADAPT_FNS[self.adapt_fn], + seed=sample_key, + kwargs=kwargs) + + +class _BlackjaxDynamicSampler(_BlackjaxSampler): + """Base class for blackjax samplers.""" + + def __call__(self, seed, **kwargs): + init_key, sample_key = jax.random.split(seed) + kwargs = self.get_kwargs(**kwargs) + initial_state = self.get_initial_state( + init_key, num_chains=kwargs["extra_parameters"]["num_chains"]) + + return _sample_blackjax_dynamic( initial_state=self.inverse_transform_fn(initial_state), algorithm=_ALGORITHMS[self.algorithm], transform_fn=self.transform_fn, @@ -77,6 +106,18 @@ class HMC(_BlackjaxSampler): algorithm = "hmc" +class CheesHMC(_BlackjaxDynamicSampler): + name = "blackjax_chees_hmc" + adapt_fn = "chees" + algorithm = "dynamic_hmc" + + +class MeadsHMC(_BlackjaxDynamicSampler): + name = "blackjax_meads_hmc" + adapt_fn = "meads" + algorithm = "ghmc" + + class HMCPathfinder(_BlackjaxSampler): name = "blackjax_hmc_pathfinder" adapt_fn = "pathfinder" @@ -95,23 +136,27 @@ class NUTSPathfinder(_BlackjaxSampler): algorithm = "nuts" -def _blackjax_inference_loop( +def _blackjax_adapt( seed, - init_position, adapt_fn, + kwarg_dict, + **kwargs): + adapt = 
adapt_fn(**kwarg_dict[adapt_fn]) + (last_state, parameters), _ = adapt.run( + rng_key=seed, **kwargs, + **kwarg_dict["adapt.run"]) + return last_state, parameters + + +def _blackjax_inference( + seed, + adapt_state, + adapt_parameters, algorithm, - log_density, num_draws, - num_adapt_draws, kwargs): - """Constructs and runs inference loop.""" - adapt_seed, inference_seed = jax.random.split(seed) - adapt = adapt_fn(logdensity_fn=log_density, **kwargs[adapt_fn]) - (last_state, parameters), _ = adapt.run( - rng_key=adapt_seed, position=init_position, num_steps=num_adapt_draws) - - algorithm_kwargs = kwargs[algorithm] | parameters - kernel = algorithm(log_density, **algorithm_kwargs).step + algorithm_kwargs = kwargs[algorithm] | adapt_parameters + kernel = algorithm(**algorithm_kwargs).step @jax.jit def inference_loop(rng_key): @@ -121,14 +166,37 @@ def one_step(state, rng_key): return state, (state, info) keys = jax.random.split(rng_key, num_draws) - _, (states, infos) = jax.lax.scan(one_step, last_state, keys) + _, (states, infos) = jax.lax.scan(one_step, adapt_state, keys) return states, infos - return inference_loop(inference_seed) + # Functions returned by chees adaptation. + adapt_parameters.pop("next_random_arg_fn", None) + adapt_parameters.pop("integration_steps_fn", None) + return inference_loop(seed), adapt_parameters -def _blackjax_stats_to_dict(sample_stats, potential_energy): +def _blackjax_inference_loop( + seed, + init_position, + adapt_fn, + algorithm, + num_draws, + kwargs): + """Constructs and runs inference loop.""" + adapt_seed, inference_seed = jax.random.split(seed) + adapt_state, adapt_parameters = _blackjax_adapt( + adapt_seed, adapt_fn, kwarg_dict=kwargs, position=init_position) + return _blackjax_inference( + inference_seed, + adapt_state, + adapt_parameters, + algorithm, + num_draws, + kwargs) + + +def _blackjax_stats_to_dict(sample_stats, potential_energy, adapt_parameters): """Extract ArviZ compatible stats from blackjax sampler. 
Adapted from https://github.com/pymc-devs/pymc @@ -136,6 +204,7 @@ def _blackjax_stats_to_dict(sample_stats, potential_energy): Args: sample_stats: Blackjax NUTSInfo object containing sampler statistics. potential_energy: Potential energy values of sampled positions. + adapt_parameters: Parameters from adaptation. Returns: Dictionary of sampler statistics. @@ -148,8 +217,15 @@ def _blackjax_stats_to_dict(sample_stats, potential_energy): "acceptance_rate": "acceptance_rate", # naming here depends "acceptance_probability": "acceptance_rate", # on blackjax version } - converted_stats = {} - converted_stats["lp"] = potential_energy + converted_stats = {"lp": potential_energy} + step_size = adapt_parameters.get("step_size", None) + if step_size is not None: + if jnp.ndim(step_size) == 0: + converted_stats["step_size"] = jnp.full_like(potential_energy, step_size) + else: + converted_stats["step_size"] = jnp.repeat( + step_size[..., None], repeats=jnp.shape(potential_energy)[-1], axis=-1 + ) for old_name, new_name in rename_key.items(): value = getattr(sample_stats, old_name, None) if value is not None: @@ -157,18 +233,26 @@ def _blackjax_stats_to_dict(sample_stats, potential_energy): return converted_stats -def get_adaptation_kwargs(adaptation_algorithm, algorithm, kwargs): +def get_adaptation_kwargs(adaptation_algorithm, algorithm, log_density, kwargs): """Sets defaults and merges user-provided adaptation keywords.""" adaptation_kwargs, adaptation_required = shared.get_default_signature( adaptation_algorithm) adaptation_kwargs.update( {k: kwargs[k] for k in adaptation_required if k in kwargs}) - adaptation_required.remove("logdensity_fn") - adaptation_required.remove("extra_parameters") - adaptation_required.remove("algorithm") - adaptation_kwargs["algorithm"] = algorithm - adaptation_kwargs = ( - get_algorithm_kwargs(algorithm, kwargs) | adaptation_kwargs) + if "logdensity_fn" in adaptation_required: + adaptation_kwargs["logdensity_fn"] = log_density + 
adaptation_required.remove("logdensity_fn") + elif "logprob_fn" in adaptation_required: + adaptation_kwargs["logprob_fn"] = log_density + adaptation_required.remove("logprob_fn") + + adaptation_required.discard("extra_parameters") + if "algorithm" in adaptation_required: + adaptation_required.remove("algorithm") + adaptation_kwargs["algorithm"] = algorithm + adaptation_kwargs = ( + get_algorithm_kwargs(algorithm, log_density, kwargs) | adaptation_kwargs + ) adaptation_required = adaptation_required - adaptation_kwargs.keys() @@ -183,7 +267,7 @@ def get_adaptation_kwargs(adaptation_algorithm, algorithm, kwargs): ) # step_size will get adapted -- maybe warn if this is set manually, and # suggest setting init_step_size instead? - adaptation_kwargs.pop("step_size") + adaptation_kwargs.pop("step_size", None) # blackjax doesn't have a pleasant way to accept this argument -- # window_adaptation calls `algorithm.build_kernel()` with no arguments, but # it should probably take the below arguments: @@ -191,13 +275,27 @@ def get_adaptation_kwargs(adaptation_algorithm, algorithm, kwargs): adaptation_kwargs.pop("integrator", None) adaptation_kwargs.pop("max_num_doublings", None) - return adaptation_kwargs + adapt = adaptation_algorithm(**adaptation_kwargs) + run_kwargs, run_required = shared.get_default_signature(adapt.run) + run_required.remove("rng_key") + run_kwargs.update({k: kwargs[k] for k in run_required if k in kwargs}) + if "optim" in run_required: + run_kwargs["optim"] = optax.adam(learning_rate=0.01) + run_required.remove("optim") + if "step_size" in run_required: + run_kwargs["step_size"] = 0.001 + run_required.remove("step_size") + run_kwargs["num_steps"] = kwargs.get("num_adapt_draws", + run_kwargs["num_steps"]) + + return adaptation_kwargs, run_kwargs -def get_algorithm_kwargs(algorithm, kwargs): +def get_algorithm_kwargs(algorithm, log_density, kwargs): """Sets defaults and merges user-provided keywords for sampling.""" algorithm_kwargs, algorithm_required = 
shared.get_default_signature(algorithm) kwargs_with_defaults = { + "logdensity_fn": log_density, "step_size": 0.01, "num_integration_steps": 8, } | kwargs @@ -208,7 +306,10 @@ def get_algorithm_kwargs(algorithm, kwargs): if k in kwargs_with_defaults }) algorithm_required.remove("logdensity_fn") - algorithm_required.remove("inverse_mass_matrix") + algorithm_required.discard("inverse_mass_matrix") + algorithm_required.discard("alpha") + algorithm_required.discard("delta") + algorithm_required.discard("momentum_inverse_scale") algorithm_required = algorithm_required - algorithm_kwargs.keys() if algorithm_required: @@ -223,9 +324,54 @@ def get_algorithm_kwargs(algorithm, kwargs): return algorithm_kwargs +def _sample_blackjax_dynamic( + *, + initial_state, + algorithm, + seed, + transform_fn, + adapt_fn, + kwargs): + """Constructs and runs blackjax sampler.""" + extra_parameters = kwargs.pop("extra_parameters") + num_draws = extra_parameters["num_draws"] + num_chains = extra_parameters["num_chains"] + chain_method = extra_parameters["chain_method"] + num_adapt_draws = extra_parameters["num_adapt_draws"] + + adapt_seed, seed = jax.random.split(seed) + adapt_state, adapt_parameters = _blackjax_adapt( + seed=adapt_seed, + adapt_fn=adapt_fn, + kwarg_dict=kwargs, + positions=initial_state, + ) + sampler = functools.partial( + _blackjax_inference, + adapt_parameters=adapt_parameters, + algorithm=algorithm, + num_draws=num_draws, + kwargs=kwargs) + map_seed = jax.random.split(seed, num_chains) + mapped_sampler = shared.map_fn(chain_method, sampler) + + (states, stats), adapt_parameters = mapped_sampler(map_seed, adapt_state) + draws = transform_fn(states.position) + if extra_parameters["return_pytree"]: + return draws + else: + potential_energy = states.logdensity + sample_stats = _blackjax_stats_to_dict( + stats, potential_energy, adapt_parameters) + if hasattr(draws, "_asdict"): + draws = draws._asdict() + elif not isinstance(draws, dict): + draws = {"var0": draws} + return 
az.from_dict(posterior=draws, sample_stats=sample_stats) + + def _sample_blackjax( *, - log_density, initial_state, algorithm, seed, @@ -240,29 +386,21 @@ def _sample_blackjax( num_adapt_draws = extra_parameters["num_adapt_draws"] sampler = functools.partial( _blackjax_inference_loop, - log_density=log_density, algorithm=algorithm, adapt_fn=adapt_fn, num_draws=num_draws, - num_adapt_draws=num_adapt_draws, kwargs=kwargs) map_seed = jax.random.split(seed, num_chains) - if chain_method == "parallel": - mapped_sampler = jax.pmap(sampler) - elif chain_method == "vectorized": - mapped_sampler = jax.vmap(sampler) - elif chain_method == "sequential": - mapped_sampler = functools.partial(jax.tree_map, sampler) - else: - raise ValueError(f"Chain method {chain_method} not supported.") + mapped_sampler = shared.map_fn(chain_method, sampler) - states, stats = mapped_sampler(map_seed, initial_state) + (states, stats), adapt_parameters = mapped_sampler(map_seed, initial_state) draws = transform_fn(states.position) if extra_parameters["return_pytree"]: return draws else: potential_energy = states.logdensity - sample_stats = _blackjax_stats_to_dict(stats, potential_energy) + sample_stats = _blackjax_stats_to_dict( + stats, potential_energy, adapt_parameters) if hasattr(draws, "_asdict"): draws = draws._asdict() elif not isinstance(draws, dict): diff --git a/bayeux/_src/optimize/shared.py b/bayeux/_src/optimize/shared.py index c0e07a9..d58d40f 100644 --- a/bayeux/_src/optimize/shared.py +++ b/bayeux/_src/optimize/shared.py @@ -14,11 +14,9 @@ """Shared functions for optimizers.""" import collections -import functools from bayeux._src import debug from bayeux._src import shared -import jax OptimizerResults = collections.namedtuple("OptimizerResults", @@ -88,13 +86,7 @@ def transformed_negative_log_prob(self): return lambda x: -self.log_density(self.transform_fn(x)) def _map_optimizer(self, chain_method, fit): - if chain_method == "parallel": - return jax.pmap(fit) - elif chain_method 
== "vectorized": - return jax.vmap(fit) - elif chain_method == "sequential": - return functools.partial(jax.tree_map, fit) - raise ValueError(f"Chain method {chain_method} not supported.") + return shared.map_fn(chain_method, fit) def _prep_args(self, seed, kwargs): num_particles = kwargs["extra_parameters"]["num_particles"] diff --git a/bayeux/_src/shared.py b/bayeux/_src/shared.py index ac98c4f..d16f181 100644 --- a/bayeux/_src/shared.py +++ b/bayeux/_src/shared.py @@ -15,6 +15,7 @@ """Shared functionality for MCMC sampling.""" import dataclasses +import functools import inspect from typing import Callable, Optional @@ -26,6 +27,16 @@ import oryx +def map_fn(chain_method, fn): + if chain_method == "parallel": + return jax.pmap(fn) + elif chain_method == "vectorized": + return jax.vmap(fn) + elif chain_method == "sequential": + return functools.partial(jax.tree_map, fn) + raise ValueError(f"Chain method {chain_method} not supported.") + + def _default_init( *, initial_state, diff --git a/bayeux/mcmc/__init__.py b/bayeux/mcmc/__init__.py index cd25ac3..629c46f 100644 --- a/bayeux/mcmc/__init__.py +++ b/bayeux/mcmc/__init__.py @@ -19,12 +19,15 @@ __all__ = [] if importlib.util.find_spec("blackjax") is not None: + from bayeux._src.mcmc.blackjax import CheesHMC as CheesHMCblackjax from bayeux._src.mcmc.blackjax import HMC as HMCblackjax from bayeux._src.mcmc.blackjax import HMCPathfinder as HMC_Pathfinder_blackjax + from bayeux._src.mcmc.blackjax import MeadsHMC as MeadsHMCblackjax from bayeux._src.mcmc.blackjax import NUTS as NUTSblackjax from bayeux._src.mcmc.blackjax import NUTSPathfinder as NUTS_Pathfinder_blackjax - __all__.extend(["HMCblackjax", "NUTSblackjax", - "HMC_Pathfinder_blackjax", "NUTS_Pathfinder_blackjax"]) + __all__.extend(["HMCblackjax", "CheesHMCblackjax", "MeadsHMCblackjax", + "NUTSblackjax", "HMC_Pathfinder_blackjax", + "NUTS_Pathfinder_blackjax"]) if importlib.util.find_spec("numpyro") is not None: from bayeux._src.mcmc.numpyro import HMC as 
HMCnumpyro