Skip to content

Commit

Permalink
Refactor multi-state sampler and add test for minimize method
Browse files Browse the repository at this point in the history
  • Loading branch information
wiederm committed Dec 29, 2023
1 parent d03879c commit 551ce2b
Showing 1 changed file with 28 additions and 4 deletions.
32 changes: 28 additions & 4 deletions chiron/tests/test_multistate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
@pytest.fixture
def ho_multistate_sampler() -> MultiStateSampler:
"""
Create a MultiStateSampler object for performing multistate simulations for a harmonic oscillator.
Create a multi-state sampler for a harmonic oscillator system.
Returns:
MultiStateSampler: The created MultiStateSampler object.
MultiStateSampler: The multi-state sampler object.
"""
import math
from openmm import unit
Expand All @@ -29,10 +29,17 @@ def ho_multistate_sampler() -> MultiStateSampler:
/ (math.e - 1.0)
for i in range(n_replicas)
]
import jax.numpy as jnp

ho_potential = HarmonicOscillatorPotential(ho.topology)
x0s = [
unit.Quantity(jnp.array([[x0, 0.0, 0.0]]), unit.angstrom)
for x0 in jnp.linspace(0.0, 1.0, n_replicas)
]
thermodynamic_states = [
ThermodynamicState(ho_potential, temperature=T) for T in temperatures
ThermodynamicState(
HarmonicOscillatorPotential(ho.topology, x0=x0), temperature=T
)
for T, x0 in zip(temperatures, x0s)
]
sampler_state = [SamplerState(ho.positions) for _ in temperatures]

Expand Down Expand Up @@ -68,4 +75,21 @@ def test_multistate_class(ho_multistate_sampler):


def test_multistate_minimize(ho_multistate_sampler):
"""
Test function for the `minimize` method of the `ho_multistate_sampler` object.
It checks if the sampler states are correctly minimized.
Parameters
----------
ho_multistate_sampler: The `ho_multistate_sampler` object to be tested.
"""

import numpy as np

ho_multistate_sampler.minimize()

assert np.allclose(
ho_multistate_sampler.sampler_states[0].x0, np.array([[0.0, 0.0, 0.0]])
)
assert np.allclose(ho_multistate_sampler.sampler_states[1].x0, np.array([[0.05, 0.0, 0.0]]), atol=1e-2)
assert np.allclose(ho_multistate_sampler.sampler_states[2].x0, np.array([[0.1, 0.0, 0.0]]), atol=1e-2)

0 comments on commit 551ce2b

Please sign in to comment.