Skip to content

Commit

Permalink
New integrator, and add some metadata to integrators.py (#681)
Browse files Browse the repository at this point in the history
* TESTS

* TESTS

* UPDATE DOCSTRING

* ADD STREAMING VERSION

* ADD PRECONDITIONING TO MCLMC

* ADD PRECONDITIONING TO TUNING FOR MCLMC

* UPDATE GITIGNORE

* UPDATE GITIGNORE

* UPDATE TESTS

* UPDATE TESTS

* ADD DOCSTRING

* ADD TEST

* STREAMING AVERAGE

* ADD TEST

* REFACTOR RUN_INFERENCE_ALGORITHM

* UPDATE DOCSTRING

* Precommit

* CLEAN TESTS

* GITIGNORE

* PRECOMMIT CLEAN UP

* FIX SPELLING, ADD OMELYAN, EXPORT COEFFICIENTS

* TEMPORARILY ADD BENCHMARKS

* ADD INITIAL_POSITION

* FIX TEST

* CLEAN UP

* REMOVE BENCHMARKS

* ADD TEST

* REMOVE BENCHMARKS

* BUG FIX

* CHANGE PRECISION

* CHANGE PRECISION

* ADD OMELYAN TEST

* RENAME O

* UPDATE STREAMING AVG

* UPDATE PR

* RENAME STD_MAT

* MERGE MAIN

* REMOVE COEFFICIENT EXPORTS
  • Loading branch information
reubenharry authored May 27, 2024
1 parent 5831740 commit 20666de
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 21 deletions.
2 changes: 0 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# Created by https://www.gitignore.io/api/python
# Edit at https://www.gitignore.io/?templates=python

explore.py

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
46 changes: 32 additions & 14 deletions blackjax/mcmc/integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,15 @@

__all__ = [
"mclachlan",
"omelyan",
"velocity_verlet",
"yoshida",
"implicit_midpoint",
"isokinetic_leapfrog",
"with_isokinetic_maruyama",
"isokinetic_velocity_verlet",
"isokinetic_mclachlan",
"isokinetic_omelyan",
"isokinetic_yoshida",
"implicit_midpoint",
]


Expand Down Expand Up @@ -70,7 +73,7 @@ def generalized_two_stage_integrator(
.. math:: \\frac{d}{dt}f = (O_1+O_2)f
The leapfrog operator can be seen as approximating :math:`e^{\\epsilon(O_1 + O_2)}`
The velocity_verlet operator can be seen as approximating :math:`e^{\\epsilon(O_1 + O_2)}`
by :math:`e^{\\epsilon O_1/2}e^{\\epsilon O_2}e^{\\epsilon O_1/2}`.
In a standard Hamiltonian, the forms of :math:`e^{\\epsilon O_2}` and
Expand Down Expand Up @@ -210,7 +213,7 @@ def format_euclidean_state_output(
return IntegratorState(position, momentum, logdensity, logdensity_grad)


def generate_euclidean_integrator(cofficients):
def generate_euclidean_integrator(coefficients):
"""Generate symplectic integrator for solving a Hamiltonian system.
The resulting integrator is volume-preserve and preserves the symplectic structure
Expand All @@ -225,7 +228,7 @@ def euclidean_integrator(
one_step = generalized_two_stage_integrator(
momentum_update_fn,
position_update_fn,
cofficients,
coefficients,
format_output_fn=format_euclidean_state_output,
)
return one_step
Expand All @@ -251,8 +254,8 @@ def euclidean_integrator(
of the kinetic energy. We are trading accuracy in exchange, and it is not
clear whether this is the right tradeoff.
"""
velocity_verlet_cofficients = [0.5, 1.0, 0.5]
velocity_verlet = generate_euclidean_integrator(velocity_verlet_cofficients)
velocity_verlet_coefficients = [0.5, 1.0, 0.5]
velocity_verlet = generate_euclidean_integrator(velocity_verlet_coefficients)

"""
Two-stage palindromic symplectic integrator derived in :cite:p:`blanes2014numerical`.
Expand All @@ -268,8 +271,8 @@ def euclidean_integrator(
b1 = 0.1931833275037836
a1 = 0.5
b2 = 1 - 2 * b1
mclachlan_cofficients = [b1, a1, b2, a1, b1]
mclachlan = generate_euclidean_integrator(mclachlan_cofficients)
mclachlan_coefficients = [b1, a1, b2, a1, b1]
mclachlan = generate_euclidean_integrator(mclachlan_coefficients)

"""
Three stages palindromic symplectic integrator derived in :cite:p:`mclachlan1995numerical`
Expand All @@ -284,8 +287,20 @@ def euclidean_integrator(
a1 = 0.29619504261126
b2 = 0.5 - b1
a2 = 1 - 2 * a1
yoshida_cofficients = [b1, a1, b2, a2, b2, a1, b1]
yoshida = generate_euclidean_integrator(yoshida_cofficients)
yoshida_coefficients = [b1, a1, b2, a2, b2, a1, b1]
yoshida = generate_euclidean_integrator(yoshida_coefficients)

"""11 stage Omelyan integrator [I.P. Omelyan, I.M. Mryglod and R. Folk, Comput. Phys. Commun. 151 (2003) 272.],
4MN5FV in [Takaishi, Tetsuya, and Philippe De Forcrand. "Testing and tuning symplectic integrators for the hybrid Monte Carlo algorithm in lattice QCD." Physical Review E 73.3 (2006): 036706.]
popular in LQCD"""
b1 = 0.08398315262876693
a1 = 0.2539785108410595
b2 = 0.6822365335719091
a2 = -0.03230286765269967
b3 = 0.5 - b1 - b2
a3 = 1 - 2 * (a1 + a2)
omelyan_coefficients = [b1, a1, b2, a2, b3, a3, b3, a2, b2, a1, b1]
omelyan = generate_euclidean_integrator(omelyan_coefficients)


# Intergrators with non Euclidean updates
Expand Down Expand Up @@ -372,9 +387,12 @@ def isokinetic_integrator(
return isokinetic_integrator


isokinetic_leapfrog = generate_isokinetic_integrator(velocity_verlet_cofficients)
isokinetic_yoshida = generate_isokinetic_integrator(yoshida_cofficients)
isokinetic_mclachlan = generate_isokinetic_integrator(mclachlan_cofficients)
isokinetic_velocity_verlet = generate_isokinetic_integrator(
velocity_verlet_coefficients
)
isokinetic_yoshida = generate_isokinetic_integrator(yoshida_coefficients)
isokinetic_mclachlan = generate_isokinetic_integrator(mclachlan_coefficients)
isokinetic_omelyan = generate_isokinetic_integrator(omelyan_coefficients)


def partially_refresh_momentum(momentum, rng_key, step_size, L):
Expand Down
2 changes: 1 addition & 1 deletion blackjax/mcmc/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def buildtree_integrate(
"""
if tree_depth == 0:
# Base case - take one leapfrog step in the direction v.
# Base case - take one velocity_verlet step in the direction v.
next_state = integrator(initial_state, direction * step_size)
new_proposal = generate_proposal(initial_energy, next_state)
is_diverging = -new_proposal.weight > divergence_threshold
Expand Down
12 changes: 8 additions & 4 deletions tests/mcmc/test_integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,13 +136,15 @@ def kinetic_energy(p, position=None):
"velocity_verlet": {"algorithm": integrators.velocity_verlet, "precision": 1e-4},
"mclachlan": {"algorithm": integrators.mclachlan, "precision": 1e-4},
"yoshida": {"algorithm": integrators.yoshida, "precision": 1e-4},
"omelyan": {"algorithm": integrators.omelyan, "precision": 1e-4},
"implicit_midpoint": {
"algorithm": integrators.implicit_midpoint,
"precision": 1e-4,
},
"isokinetic_leapfrog": {"algorithm": integrators.isokinetic_leapfrog},
"isokinetic_velocity_verlet": {"algorithm": integrators.isokinetic_velocity_verlet},
"isokinetic_mclachlan": {"algorithm": integrators.isokinetic_mclachlan},
"isokinetic_yoshida": {"algorithm": integrators.isokinetic_yoshida},
"isokinetic_omelyan": {"algorithm": integrators.isokinetic_omelyan},
}


Expand All @@ -168,6 +170,7 @@ class IntegratorTest(chex.TestCase):
"velocity_verlet",
"mclachlan",
"yoshida",
"omelyan",
"implicit_midpoint",
],
)
Expand Down Expand Up @@ -241,13 +244,13 @@ def test_esh_momentum_update(self, dims):
np.testing.assert_array_almost_equal(next_momentum, next_momentum1)

@chex.all_variants(with_pmap=False)
def test_isokinetic_leapfrog(self):
def test_isokinetic_velocity_verlet(self):
cov = jnp.asarray([[1.0, 0.5, 0.1], [0.5, 2.0, -0.1], [0.1, -0.1, 3.0]])
logdensity_fn = lambda x: stats.multivariate_normal.logpdf(
x, jnp.zeros([3]), cov
)

step = self.variant(integrators.isokinetic_leapfrog(logdensity_fn))
step = self.variant(integrators.isokinetic_velocity_verlet(logdensity_fn))

rng = jax.random.key(4263456)
key0, key1 = jax.random.split(rng, 2)
Expand Down Expand Up @@ -296,9 +299,10 @@ def test_isokinetic_leapfrog(self):
@chex.all_variants(with_pmap=False)
@parameterized.parameters(
[
"isokinetic_leapfrog",
"isokinetic_velocity_verlet",
"isokinetic_mclachlan",
"isokinetic_yoshida",
"isokinetic_omelyan",
],
)
def test_isokinetic_integrator(self, integrator_name):
Expand Down

0 comments on commit 20666de

Please sign in to comment.