Skip to content

Commit

Permalink
Refactoring integrator tests and adding simple implicit integrator test
Browse files Browse the repository at this point in the history
  • Loading branch information
matt-graham committed Jan 8, 2020
1 parent 93b8a85 commit b3d0e8f
Showing 1 changed file with 52 additions and 27 deletions.
79 changes: 52 additions & 27 deletions tests/test_integrators.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import mici.integrators as integrators
import mici.systems as systems
import mici.states as states
from mici.states import ChainState
from mici.errors import IntegratorError

SEED = 3046987125
Expand All @@ -12,8 +12,11 @@

class IntegratorTestCase(object):

def __init__(self):
self.rng = np.random.RandomState(SEED)
def __init__(self, integrator, states, h_diff_tol=5e-3, rng=None):
self.integrator = integrator
self.states = states
self.h_diff_tol = h_diff_tol
self.rng = np.random.RandomState(SEED) if rng is None else rng

def integrate_with_reversal(self, init_state, n_step):
state = init_state
Expand Down Expand Up @@ -113,58 +116,80 @@ class TestLeapfrogIntegratorWithEuclideanMetricSystemLinear(
LinearSystemIntegratorTestCase):

def __init__(self):
super().__init__()
system = systems.EuclideanMetricSystem(
lambda q: 0.5 * np.sum(q**2), grad_neg_log_dens=lambda q: q)
self.integrator = integrators.LeapfrogIntegrator(system, 0.5)
self.states = {size: [
states.ChainState(pos=q, mom=p, dir=1)
for q, p in self.rng.standard_normal((N_STATE, 2, size))]
integrator = integrators.LeapfrogIntegrator(system, 0.5)
rng = np.random.RandomState(SEED)
states = {size: [
ChainState(pos=q, mom=p, dir=1)
for q, p in rng.standard_normal((N_STATE, 2, size))]
for size in SIZES}
self.h_diff_tol = 5e-3
h_diff_tol = 5e-3
super().__init__(integrator, states, h_diff_tol, rng)


class TestLeapfrogIntegratorWithEuclideanMetricSystemNonLinear(
IntegratorTestCase):

def __init__(self):
super().__init__()
system = systems.EuclideanMetricSystem(
lambda q: 0.25 * np.sum(q**4), grad_neg_log_dens=lambda q: q**3)
self.integrator = integrators.LeapfrogIntegrator(system, 0.1)
self.states = {size: [
states.ChainState(pos=q, mom=p, dir=1)
for q, p in self.rng.standard_normal((N_STATE, 2, size))]
integrator = integrators.LeapfrogIntegrator(system, 0.1)
rng = np.random.RandomState(SEED)
states = {size: [
ChainState(pos=q, mom=p, dir=1)
for q, p in rng.standard_normal((N_STATE, 2, size))]
for size in SIZES}
self.h_diff_tol = 2e-2
h_diff_tol = 2e-2
super().__init__(integrator, states, h_diff_tol, rng)


class TestLeapfrogIntegratorWithGaussianEuclideanMetricLinearSystem(
LinearSystemIntegratorTestCase):

def __init__(self):
super().__init__()
system = systems.GaussianEuclideanMetricSystem(
lambda q: 0, grad_neg_log_dens=lambda q: 0 * q)
self.integrator = integrators.LeapfrogIntegrator(system, 0.5)
self.states = {size: [
states.ChainState(pos=q, mom=p, dir=1)
for q, p in self.rng.standard_normal((N_STATE, 2, size))]
integrator = integrators.LeapfrogIntegrator(system, 0.5)
rng = np.random.RandomState(SEED)
states = {size: [
ChainState(pos=q, mom=p, dir=1)
for q, p in rng.standard_normal((N_STATE, 2, size))]
for size in SIZES}
self.h_diff_tol = 1e-10
h_diff_tol = 1e-10
super().__init__(integrator, states, h_diff_tol, rng)


class TestLeapfrogIntegratorWithGaussianEuclideanMetricLinearSystem(
IntegratorTestCase):

def __init__(self):
super().__init__()
system = systems.GaussianEuclideanMetricSystem(
lambda q: 0.125 * np.sum(q**4),
grad_neg_log_dens=lambda q: 0.5 * q**3)
self.integrator = integrators.LeapfrogIntegrator(system, 0.1)
self.states = {size: [
states.ChainState(pos=q, mom=p, dir=1)
for q, p in self.rng.standard_normal((N_STATE, 2, size))]
integrator = integrators.LeapfrogIntegrator(system, 0.1)
rng = np.random.RandomState(SEED)
states = {size: [
ChainState(pos=q, mom=p, dir=1)
for q, p in rng.standard_normal((N_STATE, 2, size))]
for size in SIZES}
h_diff_tol = 1e-2
super().__init__(integrator, states, h_diff_tol, rng)


class TestImplicitLeapfrogIntegratorWithRiemannianMetricSystemLinear(
LinearSystemIntegratorTestCase):

def __init__(self):
system = systems.DenseRiemannianMetricSystem(
lambda q: 0.5 * np.sum(q**2), grad_neg_log_dens=lambda q: q,
metric_func=lambda q: np.identity(q.shape[0]),
vjp_metric_func=lambda q: lambda m: np.zeros_like(q))
integrator = integrators.ImplicitLeapfrogIntegrator(system, 0.5)
rng = np.random.RandomState(SEED)
states = {size: [
ChainState(pos=q, mom=p, dir=1)
for q, p in rng.standard_normal((N_STATE, 2, size))]
for size in SIZES}
self.h_diff_tol = 1e-2
h_diff_tol = 5e-3
super().__init__(integrator, states, h_diff_tol, rng)

0 comments on commit b3d0e8f

Please sign in to comment.