diff --git a/chiron/tests/test_multistate.py b/chiron/tests/test_multistate.py index 23155db..4125eeb 100644 --- a/chiron/tests/test_multistate.py +++ b/chiron/tests/test_multistate.py @@ -5,10 +5,10 @@ @pytest.fixture def ho_multistate_sampler() -> MultiStateSampler: """ - Create a MultiStateSampler object for performing multistate simulations for a harmonic oscillator. + Create a multi-state sampler for a harmonic oscillator system. Returns: - MultiStateSampler: The created MultiStateSampler object. + MultiStateSampler: The multi-state sampler object. """ import math from openmm import unit @@ -29,10 +29,17 @@ def ho_multistate_sampler() -> MultiStateSampler: / (math.e - 1.0) for i in range(n_replicas) ] + import jax.numpy as jnp - ho_potential = HarmonicOscillatorPotential(ho.topology) + x0s = [ + unit.Quantity(jnp.array([[x0, 0.0, 0.0]]), unit.angstrom) + for x0 in jnp.linspace(0.0, 1.0, n_replicas) + ] thermodynamic_states = [ - ThermodynamicState(ho_potential, temperature=T) for T in temperatures + ThermodynamicState( + HarmonicOscillatorPotential(ho.topology, x0=x0), temperature=T + ) + for T, x0 in zip(temperatures, x0s) ] sampler_state = [SamplerState(ho.positions) for _ in temperatures] @@ -68,4 +75,21 @@ def test_multistate_class(ho_multistate_sampler): def test_multistate_minimize(ho_multistate_sampler): + """ + Test function for the `minimize` method of the `ho_multistate_sampler` object. + It checks if the sampler states are correctly minimized. + + Parameters + ---------- + ho_multistate_sampler: The `ho_multistate_sampler` object to be tested. + """ + + import numpy as np + ho_multistate_sampler.minimize() + + assert np.allclose( + ho_multistate_sampler.sampler_states[0].x0, np.array([[0.0, 0.0, 0.0]]) + ) + assert np.allclose(ho_multistate_sampler.sampler_states[1].x0, np.array([[0.05, 0.0, 0.0]]), atol=1e-2) + assert np.allclose(ho_multistate_sampler.sampler_states[2].x0, np.array([[0.1, 0.0, 0.0]]), atol=1e-2)