diff --git a/.github/workflows/CI.yaml b/.github/workflows/CI.yaml index 0e307f5..2917564 100644 --- a/.github/workflows/CI.yaml +++ b/.github/workflows/CI.yaml @@ -9,6 +9,7 @@ on: pull_request: branches: - "main" + - "multistage" schedule: # Weekly tests run on main by default: # Scheduled workflows run on the latest commit on the default or base branch. diff --git a/Examples/Idealgas.py b/Examples/Idealgas.py new file mode 100644 index 0000000..32ef6f0 --- /dev/null +++ b/Examples/Idealgas.py @@ -0,0 +1,150 @@ +from openmmtools.testsystems import IdealGas +from openmm import unit + +""" +This example explore an ideal gas system, where the particles are non-interacting. +This will use the MonteCarloBarostatMove to sample the volume of the system and +MonteCarloDisplacementMove to sample the particle positions. + +This utilizes the IdealGas example from openmmtools to initialize particle positions and topology. + +""" + +# Use the IdealGas example from openmmtools to initialize particle positions and topology +# For this example, the topology provides the masses for the particles + +n_particles = 216 +temperature = 298 * unit.kelvin +pressure = 1 * unit.atmosphere + +ideal_gas = IdealGas(nparticles=n_particles, temperature=temperature, pressure=pressure) + + +from chiron.potential import IdealGasPotential +from chiron.utils import PRNG, get_list_of_mass +import jax.numpy as jnp + +# particles are non interacting +cutoff = 0.0 * unit.nanometer +ideal_gas_potential = IdealGasPotential(ideal_gas.topology) + +from chiron.states import SamplerState, ThermodynamicState + +# define the thermodynamic state +thermodynamic_state = ThermodynamicState( + potential=ideal_gas_potential, + temperature=temperature, + pressure=pressure, +) + +PRNG.set_seed(1234) + + +# define the sampler state +sampler_state = SamplerState( + positions=ideal_gas.positions, + current_PRNG_key=PRNG.get_random_key(), + box_vectors=ideal_gas.system.getDefaultPeriodicBoxVectors(), +) + +from chiron.neighbors import PairListNsqrd, OrthogonalPeriodicSpace + +# define the pair list for an orthogonal periodic space +# since particles are non-interacting, this will not really do much +# but will be used to appropriately wrap particles in space +nbr_list = PairListNsqrd(OrthogonalPeriodicSpace(), cutoff=cutoff) +nbr_list.build_from_state(sampler_state) + +from chiron.reporters import MCReporter + +# initialize a reporter to save the simulation data +filename = "test_mc_ideal_gas.h5" +import os + +if os.path.isfile(filename): + os.remove(filename) +reporter = MCReporter(filename, 100) + + +from chiron.mcmc import ( + MonteCarloDisplacementMove, + MonteCarloBarostatMove, + MoveSchedule, + MCMCSampler, +) + +# initialize the displacement move +mc_barostat_move = MonteCarloBarostatMove( + volume_max_scale=0.2, + number_of_moves=10, + reporter=reporter, + autotune=True, + autotune_interval=100, +) + +# initialize the barostat move and the move schedule +metropolis_displacement_move = MonteCarloDisplacementMove( + displacement_sigma=0.1 * unit.nanometer, + number_of_moves=100, + autotune=True, + autotune_interval=100, +) + +# define the move schedule +move_set = MoveSchedule( + [ + ("MonteCarloDisplacementMove", metropolis_displacement_move), + ("MonteCarloBarostatMove", mc_barostat_move), + ] +) + +sampler = MCMCSampler(move_set) +sampler.run( + sampler_state, thermodynamic_state, n_iterations=10, nbr_list=nbr_list +) # how many times to repeat + +# get the volume from the reporter +volume = reporter.get_property("volume") +step = reporter.get_property("elapsed_step") + + +import matplotlib.pyplot as plt + +plt.plot(step, volume) +plt.show() + +# get expectations +ideal_volume = ideal_gas.get_volume_expectation(thermodynamic_state) +ideal_volume_std = ideal_gas.get_volume_standard_deviation(thermodynamic_state) + +print("ideal volume and standard deviation: ", ideal_volume, ideal_volume_std) + + +volume_mean = jnp.mean(jnp.array(volume)) * unit.nanometer**3 +volume_std = jnp.std(jnp.array(volume)) * unit.nanometer**3 + + +print("measured volume and standard deviation: ", volume_mean, volume_std) + +# get the masses of particles from the topology +masses = get_list_of_mass(ideal_gas.topology) + +sum_of_masses = jnp.sum(jnp.array(masses.value_in_unit(unit.amu))) * unit.amu + +ideal_density = sum_of_masses / unit.AVOGADRO_CONSTANT_NA / ideal_volume +measured_density = sum_of_masses / unit.AVOGADRO_CONSTANT_NA / volume_mean + +assert jnp.isclose( + ideal_density.value_in_unit(unit.kilogram / unit.meter**3), + measured_density.value_in_unit(unit.kilogram / unit.meter**3), + atol=1e-1, +) +# see if within 5% of ideal volume +assert ( + abs(ideal_volume - volume_mean) / ideal_volume < 0.05 +), f"Warning: {abs(ideal_volume - volume_mean) / ideal_volume} exceeds the 5% threshold" + +# see if within 10% of the ideal standard deviation of the volume +assert ( + abs(ideal_volume_std - volume_std) / ideal_volume_std < 0.1 +), f"Warning: {abs(ideal_volume_std - volume_std) / ideal_volume_std} exceeds the 10% threshold" diff --git a/Examples/LJ_MCMC.py b/Examples/LJ_MCMC.py new file mode 100644 index 0000000..b3cb2e6 --- /dev/null +++ b/Examples/LJ_MCMC.py @@ -0,0 +1,165 @@ +from openmm import unit +from openmm import app + +""" +This example explore a Lennard-Jones system, where a single bead represents a united atom methane molecule, +modeled with the UA-TraPPE force field. + + +""" +n_particles = 1100 +temperature = 140 * unit.kelvin +pressure = 13.00765 * unit.atmosphere +mass = unit.Quantity(16.04, unit.gram / unit.mole) + +# create the topology +lj_topology = app.Topology() +element = app.Element(1000, "CH4", "CH4", mass) +chain = lj_topology.addChain() +for i in range(n_particles): + residue = lj_topology.addResidue("CH4", chain) + lj_topology.addAtom("CH4", element, residue) + +import jax.numpy as jnp + +# these were generated in Mbuild using fill_box which wraps packmol +# a minimum spacing of 0.4 nm was used during construction. + +from chiron.utils import get_full_path + +positions = jnp.load(get_full_path("Examples/methane_coords.npy")) * unit.nanometer + +box_vectors = ( + jnp.array( + [ + [4.275021399280942, 0.0, 0.0], + [0.0, 4.275021399280942, 0.0], + [0.0, 0.0, 4.275021399280942], + ] + ) + * unit.nanometer +) + +from chiron.potential import LJPotential +from chiron.utils import PRNG +import jax.numpy as jnp + +# + +# initialize the LennardJones potential for UA-TraPPE methane +# +sigma = 0.373 * unit.nanometer +epsilon = 0.2941 * unit.kilocalories_per_mole +cutoff = 1.4 * unit.nanometer + +lj_potential = LJPotential(lj_topology, sigma=sigma, epsilon=epsilon, cutoff=cutoff) + +from chiron.states import SamplerState, ThermodynamicState + +# define the thermodynamic state +thermodynamic_state = ThermodynamicState( + potential=lj_potential, + temperature=temperature, + pressure=pressure, +) + +PRNG.set_seed(1234) + + +# define the sampler state +sampler_state = SamplerState( + positions=positions, current_PRNG_key=PRNG.get_random_key(), box_vectors=box_vectors +) + + +from chiron.neighbors import PairListNsqrd, OrthogonalPeriodicSpace + +# define the pair list for an orthogonal periodic space +# since particles are non-interacting, this will not really do much +# but will appropriately wrap particles in space +nbr_list = PairListNsqrd(OrthogonalPeriodicSpace(), cutoff=cutoff) +nbr_list.build_from_state(sampler_state) + +# CRI: minimizer is not working correctly on my mac +# from chiron.minimze import minimize_energy +# +# results = minimize_energy( +# sampler_state.positions, lj_potential.compute_energy, nbr_list, maxiter=100 +# ) +# +# min_x = results.params +# +# sampler_state.positions = min_x + +from chiron.reporters import MCReporter + +# initialize a reporter to save the simulation data +import os + + +filename_displacement = "test_mc_lj_disp.h5" + +if os.path.isfile(filename_displacement): + os.remove(filename_displacement) +reporter_displacement = MCReporter(filename_displacement, 10) + +from chiron.mcmc import MonteCarloDisplacementMove + +mc_displacement_move = MonteCarloDisplacementMove( + displacement_sigma=0.001 * unit.nanometer, + number_of_moves=100, + reporter=reporter_displacement, + report_interval=10, + autotune=True, + autotune_interval=100, +) + +filename_barostat = "test_mc_lj_barostat.h5" +if os.path.isfile(filename_barostat): + os.remove(filename_barostat) +reporter_barostat = MCReporter(filename_barostat, 1) + + +from chiron.mcmc import MonteCarloBarostatMove + +mc_barostat_move = MonteCarloBarostatMove( + volume_max_scale=0.1, + number_of_moves=10, + reporter=reporter_barostat, + report_interval=1, + autotune=True, + autotune_interval=50, +) + +from chiron.reporters import LangevinDynamicsReporter + +filename_langevin = "test_mc_lj_langevin.h5" + +if os.path.isfile(filename_langevin): + os.remove(filename_langevin) +reporter_langevin = LangevinDynamicsReporter(filename_langevin, 10) + +from chiron.mcmc import LangevinDynamicsMove + +langevin_dynamics_move = LangevinDynamicsMove( + timestep=1.0 * unit.femtoseconds, + collision_rate=1.0 / unit.picoseconds, + number_of_steps=1000, + reporter=reporter_langevin, + report_interval=10, +) + +from chiron.mcmc import MoveSchedule + +move_set = MoveSchedule( + [ + ("LangevinDynamicsMove", langevin_dynamics_move), + ("MonteCarloDisplacementMove", mc_displacement_move), + ("MonteCarloBarostatMove", mc_barostat_move), + ] +) + +from chiron.mcmc import MCMCSampler + +sampler = MCMCSampler(move_set) +sampler.run(sampler_state, thermodynamic_state, n_iterations=100, nbr_list=nbr_list) diff --git a/Examples/LJ_langevin.py b/Examples/LJ_langevin.py index d769b1a..8bb38d1 100644 --- a/Examples/LJ_langevin.py +++ b/Examples/LJ_langevin.py @@ -19,31 +19,47 @@ lj_fluid.topology, sigma=sigma, epsilon=epsilon, cutoff=cutoff ) + +from chiron.utils import PRNG + +PRNG.set_seed(1234) + from chiron.states import SamplerState, ThermodynamicState # define the sampler state sampler_state = SamplerState( - x0=lj_fluid.positions, box_vectors=lj_fluid.system.getDefaultPeriodicBoxVectors() + positions=lj_fluid.positions, + current_PRNG_key=PRNG.get_random_key(), + box_vectors=lj_fluid.system.getDefaultPeriodicBoxVectors(), ) # define the thermodynamic state thermodynamic_state = ThermodynamicState( - potential=lj_potential, temperature=300 * unit.kelvin + potential=lj_potential, + temperature=300 * unit.kelvin, ) + from chiron.neighbors import NeighborListNsqrd, OrthogonalPeriodicSpace -# define the neighbor list for an orthogonal periodic space +# Set up a neighbor list for an orthogonal periodic box with a cutoff of 3.0 * sigma and skin of 0.5 * sigma, +# where sigma = 0.34 nm. +# The class we instantiate, NeighborListNsqrd, uses an O(N^2) calculation to build the neighbor list, +# but uses a buffer (i.e., the skin) to avoid needing to perform the O(N^2) calculation at every step. +# With this routine, the calculation at each step between builds is O(N*n_max_neighbors). +# For the conditions considered here, n_max_neighbors is set to 180 (note this will increase if necessary) +# and thus there is ~5 reduction in computational cost compared to a brute force approach (i.e., PairListNsqrd). + skin = 0.5 * unit.nanometer nbr_list = NeighborListNsqrd( OrthogonalPeriodicSpace(), cutoff=cutoff, skin=skin, n_max_neighbors=180 ) -# build the neighbor list from the sampler state +# perform the initial build of the neighbor list from the sampler state nbr_list.build_from_state(sampler_state) -from chiron.reporters import _SimulationReporter +from chiron.reporters import LangevinDynamicsReporter # initialize a reporter to save the simulation data filename = "test_lj.h5" @@ -51,18 +67,25 @@ if os.path.isfile(filename): os.remove(filename) -reporter = _SimulationReporter("test_lj.h5", lj_fluid.topology, 1) +reporter = LangevinDynamicsReporter( + "test_lj.h5", + 1, + lj_fluid.topology, +) from chiron.integrators import LangevinIntegrator # initialize the Langevin integrator -integrator = LangevinIntegrator(reporter=reporter, report_frequency=100) -print("init_energy: ", lj_potential.compute_energy(sampler_state.x0, nbr_list)) +integrator = LangevinIntegrator(reporter=reporter, report_interval=100) +print("init_energy: ", lj_potential.compute_energy(sampler_state.positions, nbr_list)) -integrator.run( +# run the simulation +# note, typically we will not be calling the integrator directly, +# but instead using the LangevinDynamics Move in the MCMC Sampler. +updated_sampler_state, updated_nbr_list = integrator.run( sampler_state, thermodynamic_state, - n_steps=5000, + number_of_steps=1000, nbr_list=nbr_list, progress_bar=True, ) @@ -71,9 +94,11 @@ # read the data from the reporter with h5py.File("test_lj.h5", "r") as f: - energies = f["energy"][:] + energies = f["potential_energy"][:] steps = f["step"][:] +energies = reporter.get_property("potential_energy") +steps = reporter.get_property("step") # plot the energy import matplotlib.pyplot as plt diff --git a/Examples/LJ_mcmove.py b/Examples/LJ_mcmove.py index bc673f6..09fa6fa 100644 --- a/Examples/LJ_mcmove.py +++ b/Examples/LJ_mcmove.py @@ -20,10 +20,15 @@ ) from chiron.states import SamplerState, ThermodynamicState +from chiron.utils import PRNG + +PRNG.set_seed(1234) # define the sampler state sampler_state = SamplerState( - x0=lj_fluid.positions, box_vectors=lj_fluid.system.getDefaultPeriodicBoxVectors() + positions=lj_fluid.positions, + current_PRNG_key=PRNG.get_random_key(), + box_vectors=lj_fluid.system.getDefaultPeriodicBoxVectors(), ) # define the thermodynamic state @@ -39,29 +44,61 @@ nbr_list = NeighborListNsqrd( OrthogonalPeriodicSpace(), cutoff=cutoff, skin=skin, n_max_neighbors=180 ) -from chiron.neighbors import PairList +from chiron.neighbors import PairListNsqrd # build the neighbor list from the sampler state nbr_list.build_from_state(sampler_state) -from chiron.reporters import _SimulationReporter +from chiron.reporters import MCReporter # initialize a reporter to save the simulation data -filename = "test_lj.h5" +filename = "test_mc_lj.h5" import os if os.path.isfile(filename): os.remove(filename) -reporter = _SimulationReporter("test_mc_lj.h5", lj_fluid.topology, 1) +reporter = MCReporter(filename, 1) -from chiron.mcmc import MetropolisDisplacementMove +from chiron.mcmc import MonteCarloDisplacementMove -mc_move = MetropolisDisplacementMove( - seed=1234, +mc_move = MonteCarloDisplacementMove( displacement_sigma=0.01 * unit.nanometer, - nr_of_moves=1000, + number_of_moves=5000, reporter=reporter, + report_interval=1, + autotune=True, + autotune_interval=100, ) -mc_move.run(sampler_state, thermodynamic_state, nbr_list, True) +mc_move.update(sampler_state, thermodynamic_state, nbr_list) + +stats = mc_move.statistics +print(stats["n_accepted"] / stats["n_proposed"]) + + +acceptance_probability = reporter.get_property("acceptance_probability") +displacement_sigma = reporter.get_property("displacement_sigma") +potential_energy = reporter.get_property("potential_energy") +step = reporter.get_property("step") + +# plot the energy +import matplotlib.pyplot as plt + +plt.subplot(3, 1, 1) + +plt.plot(step, displacement_sigma) +plt.ylabel("displacement_sigma (nm)") + +plt.subplot(3, 1, 2) + +plt.plot(step, acceptance_probability) +plt.ylabel("acceptance_probability") + + +plt.subplot(3, 1, 3) + +plt.plot(step, potential_energy) +plt.xlabel("Step") +plt.ylabel("potential_energy (kj/mol)") +plt.show() diff --git a/Examples/methane_coords.npy b/Examples/methane_coords.npy new file mode 100644 index 0000000..b769491 Binary files /dev/null and b/Examples/methane_coords.npy differ diff --git a/chiron/integrators.py b/chiron/integrators.py index 0d77452..fa04063 100644 --- a/chiron/integrators.py +++ b/chiron/integrators.py @@ -5,7 +5,7 @@ from openmm import unit from .states import SamplerState, ThermodynamicState from .reporters import LangevinDynamicsReporter -from typing import Optional +from typing import Optional, Tuple from .potential import NeuralNetworkPotential from .neighbors import PairsBase @@ -24,9 +24,10 @@ class LangevinIntegrator: def __init__( self, - stepsize=1.0 * unit.femtoseconds, + timestep=1.0 * unit.femtoseconds, collision_rate=1.0 / unit.picoseconds, - report_frequency: int = 100, + refresh_velocities: bool = False, + report_interval: int = 100, reporter: Optional[LangevinDynamicsReporter] = None, save_traj_in_memory: bool = False, ) -> None: @@ -35,12 +36,14 @@ def __init__( Parameters ---------- - stepsize : unit.Quantity, optional + timestep : unit.Quantity, optional Time step of integration with units of time. Default is 1.0 * unit.femtoseconds. collision_rate : unit.Quantity, optional Collision rate for the Langevin dynamics, with units 1/time. Default is 1.0 / unit.picoseconds. - report_frequency : int, optional - Frequency of saving the simulation data. Default is 100. + refresh_velocities : bool, optional + Flag indicating whether to reinitialize the velocities each time the run function is called. Default is False. + report_interval : int, optional + Interval between saving the simulation data. Default is 100. reporter : SimulationReporter, optional Reporter object for saving the simulation data. Default is None. save_traj_in_memory: bool @@ -50,11 +53,11 @@ def __init__( from loguru import logger as log self.kB = unit.BOLTZMANN_CONSTANT_kB * unit.AVOGADRO_CONSTANT_NA - log.info(f"stepsize = {stepsize}") + log.info(f"timestep = {timestep}") log.info(f"collision_rate = {collision_rate}") - log.info(f"report_frequency = {report_frequency}") + log.info(f"report_interval = {report_interval}") - self.stepsize = stepsize + self.timestep = timestep self.collision_rate = collision_rate if reporter: log.info( @@ -62,30 +65,21 @@ def __init__( ) log.info(f"and logging to {reporter.log_file_path}") self.reporter = reporter - self.report_frequency = report_frequency + self.report_interval = report_interval self.velocities = None self.save_traj_in_memory = save_traj_in_memory self.traj = [] - - def set_velocities(self, vel: unit.Quantity) -> None: - """ - Set the initial velocities for the Langevin Integrator. - - Parameters - ---------- - vel : unit.Quantity - Velocities to be set for the integrator. - """ - self.velocities = vel + self.refresh_velocities = refresh_velocities + self._move_iteration = 0 def run( self, sampler_state: SamplerState, thermodynamic_state: ThermodynamicState, - n_steps: int = 5_000, + number_of_steps: int = 5_000, nbr_list: Optional[PairsBase] = None, progress_bar=False, - ): + ) -> Tuple[SamplerState, PairsBase]: """ Run the integrator to perform Langevin dynamics molecular dynamics simulation. @@ -95,13 +89,19 @@ def run( The initial state of the simulation, including positions. thermodynamic_state : ThermodynamicState The thermodynamic state of the system, including temperature and potential. - n_steps : int, optional + number_of_steps : int, optional Number of simulation steps to perform. nbr_list : PairBase, optional Neighbor list for the system. progress_bar : bool, optional Flag indicating whether to display a progress bar during integration. + Returns + ------- + sampler_state : SamplerState + The final state of the simulation, including positions, velocities, and current PRNG key. + nbr_list : PairBase + The neighbor list for the final state of the simulation. If the NeighborList object is None, the function returns None. """ from .utils import get_list_of_mass from tqdm import tqdm @@ -114,10 +114,10 @@ def run( self.box_vectors = sampler_state.box_vectors self.progress_bar = progress_bar temperature = thermodynamic_state.temperature - x0 = sampler_state.x0 + x0 = sampler_state.positions log.debug("Running Langevin dynamics") - log.debug(f"n_steps = {n_steps}") + log.debug(f"number_of_steps = {number_of_steps}") log.debug(f"temperature = {temperature}") # Initialize the random number generator @@ -129,18 +129,38 @@ def run( :, None ] sigma_v = jnp.sqrt(kbT_unitless / mass_unitless) - stepsize_unitless = self.stepsize.value_in_unit_system(unit.md_unit_system) + timestep_unitless = self.timestep.value_in_unit_system(unit.md_unit_system) collision_rate_unitless = self.collision_rate.value_in_unit_system( unit.md_unit_system ) - a = jnp.exp((-collision_rate_unitless * stepsize_unitless)) - b = jnp.sqrt(1 - jnp.exp(-2 * collision_rate_unitless * stepsize_unitless)) + a = jnp.exp((-collision_rate_unitless * timestep_unitless)) + b = jnp.sqrt(1 - jnp.exp(-2 * collision_rate_unitless * timestep_unitless)) # Initialize velocities - if self.velocities is None: - v0 = sigma_v * random.normal(key, x0.shape) - else: - v0 = self.velocities.value_in_unit_system(unit.md_unit_system) + if self.refresh_velocities: + # v0 = sigma_v * random.normal(key, positions.shape) + from .utils import initialize_velocities + + sampler_state.velocities = initialize_velocities( + temperature, potential.topology, key + ) + + elif sampler_state._velocities is None: + # v0 = sigma_v * random.normal(key, positions.shape) + from .utils import initialize_velocities + + sampler_state.velocities = initialize_velocities( + temperature, potential.topology, key + ) + elif sampler_state._velocities.shape[0] != sampler_state.positions.shape[0]: + from .utils import initialize_velocities + + sampler_state.velocities = initialize_velocities( + temperature, potential.topology, key + ) + + # extract the velocities from the sampler state + v0 = sampler_state.velocities x = x0 v = v0 @@ -151,56 +171,76 @@ def run( F = potential.compute_force(x, nbr_list) # propagation loop - for step in tqdm(range(n_steps)) if self.progress_bar else range(n_steps): + for step in ( + tqdm(range(number_of_steps)) + if self.progress_bar + else range(number_of_steps) + ): key, subkey = random.split(key) # v - v += (stepsize_unitless * 0.5) * F / mass_unitless + v += (timestep_unitless * 0.5) * F / mass_unitless # r - x += (stepsize_unitless * 0.5) * v + x += (timestep_unitless * 0.5) * v - if nbr_list is not None: - x = self._wrap_and_rebuild_neighborlist(x, nbr_list) - # o random_noise_v = random.normal(subkey, x.shape) v = (a * v) + (b * sigma_v * random_noise_v) - x += (stepsize_unitless * 0.5) * v + x += (timestep_unitless * 0.5) * v + if nbr_list is not None: - x = self._wrap_and_rebuild_neighborlist(x, nbr_list) + x, nbr_list = self._wrap_and_rebuild_neighborlist(x, nbr_list) F = potential.compute_force(x, nbr_list) # v - v += (stepsize_unitless * 0.5) * F / mass_unitless + v += (timestep_unitless * 0.5) * F / mass_unitless - if step % self.report_frequency == 0: + elapsed_step = step + self._move_iteration * number_of_steps + if (elapsed_step) % self.report_interval == 0: if hasattr(self, "reporter") and self.reporter is not None: - self._report(x, potential, nbr_list, step) + self._report( + x, potential, nbr_list, step, self._move_iteration, elapsed_step + ) if self.save_traj_in_memory: self.traj.append(x) log.debug("Finished running Langevin dynamics") - # save the final state of the simulation in the sampler_state object - sampler_state.x0 = x - sampler_state.v0 = v + + # return the final state of the simulation as a sampler_state object + import copy + + updated_sampler_state = copy.deepcopy(sampler_state) + + updated_sampler_state.positions = x + updated_sampler_state.velocities = v + updated_sampler_state.current_PRNG_key = key + + return updated_sampler_state, nbr_list def _wrap_and_rebuild_neighborlist(self, x: jnp.array, nbr_list: PairsBase): """ - Wrap the coordinates and rebuild the neighborlist if necessary. + Wrap the positions and rebuild the neighborlist if necessary. Parameters ---------- x: jnp.array - The coordinates of the particles. + The positions of the particles. nbr_list: PairsBsse The neighborlist object. + + Returns + ------- + x: jnp.array + The wrapped positions. + nbr_list: PairsBase + The neighborlist object; this may or may not have been rebuilt. """ - x = nbr_list.space.wrap(x) + x = nbr_list.space.wrap(x, self.box_vectors) # check if we need to rebuild the neighborlist after moving the particles if nbr_list.check(x): nbr_list.build(x, self.box_vectors) - return x + return x, nbr_list def _report( self, @@ -208,6 +248,8 @@ def _report( potential: NeuralNetworkPotential, nbr_list: PairsBase, step: int, + iteration: int, + elapsed_step: int, ): """ Reports the trajectory, energy, step, and box vectors (if available) to the reporter. @@ -221,7 +263,12 @@ def _report( nbr_list: PairsBase The neighbor list step: int - The current time step. + The current step in the move; this resets each iteration. + iteration: int + The number iterations the move has been called. + elapsed_step: int, + The total number of steps that have been taken in the simulation move. + Returns: None @@ -230,8 +277,10 @@ def _report( "positions": x, "potential_energy": potential.compute_energy(x, nbr_list), "step": step, + "iteration": iteration, + "elapsed_step": elapsed_step, } if nbr_list is not None: - d["box_vectors"] = nbr_list.space.box_vectors + d["box_vectors"] = nbr_list.box_vectors self.reporter.report(d) diff --git a/chiron/mcmc.py b/chiron/mcmc.py index 285cb34..64d71d2 100644 --- a/chiron/mcmc.py +++ b/chiron/mcmc.py @@ -2,49 +2,101 @@ from openmm import unit from typing import Tuple, List, Optional import jax.numpy as jnp -from chiron.reporters import LangevinDynamicsReporter, _SimulationReporter +from chiron.reporters import LangevinDynamicsReporter, _SimulationReporter, MCReporter +from .neighbors import PairsBase + +from abc import ABC, abstractmethod class MCMCMove: def __init__( self, - nr_of_moves: int, + number_of_moves: int, reporter: Optional[_SimulationReporter] = None, - report_frequency: Optional[int] = 100, + report_interval: Optional[int] = 100, ): """ Initialize a move within the molecular system. Parameters ---------- - nr_of_moves : int + number_of_moves : int Number of moves to be applied. reporter : _SimulationReporter, optional Reporter object for saving the simulation data. Default is None. - report_frequency : int, optional + report_interval : int, optional + Interval for saving the simulation data in the reporter. + Default is 100. + """ - self.nr_of_moves = nr_of_moves + self.number_of_moves = number_of_moves self.reporter = reporter - self.report_frequency = report_frequency + self.report_interval = report_interval + + # we need to keep track of which iteration we are on + self._move_iteration = 0 + + # we also need to keep track of attempts made (i.e., total elapsed steps), in case the number_of_moves is changed + self._number_of_attempts_made = 0 + from loguru import logger as log if self.reporter is not None: log.info( f"Using reporter {self.reporter} saving to {self.reporter.workdir}" ) - assert self.report_frequency is not None + assert self.report_interval is not None + + @abstractmethod + def update( + self, + sampler_state: SamplerState, + thermodynamic_state: ThermodynamicState, + nbr_list: Optional[PairsBase] = None, + ) -> Tuple[SamplerState, ThermodynamicState, Optional[PairsBase]]: + """ + Update the state of the system. + + Parameters + ---------- + sampler_state : SamplerState + The sampler state to run the integrator on. + thermodynamic_state : ThermodynamicState + The thermodynamic state to run the integrator on. + nbr_list : PairsBase, optional + The neighbor list to use for the simulation. + + Returns + ------- + sampler_state : SamplerState + The updated sampler state. + thermodynamic_state : ThermodynamicState + The updated thermodynamic state. + nbr_list: PairsBase + The updated neighbor/pair list. If no nbr_list is passed, this will be None. + + """ + pass + + @property + def number_of_attemps_made(self): + """ + Return the total number of steps that have been attempted in the move. + """ + return self._number_of_attempts_made class LangevinDynamicsMove(MCMCMove): def __init__( self, - stepsize=1.0 * unit.femtoseconds, - collision_rate=1.0 / unit.picoseconds, + timestep: unit.Quantity = 1.0 * unit.femtoseconds, + collision_rate: unit.Quantity = 1.0 / unit.picoseconds, + refresh_velocities: bool = False, reporter: Optional[LangevinDynamicsReporter] = None, - report_frequency: int = 100, - nr_of_steps=1_000, + report_interval: int = 100, + number_of_steps: int = 1_000, save_traj_in_memory: bool = False, ): """ @@ -52,17 +104,20 @@ def __init__( Parameters ---------- - stepsize : unit.Quantity + timestep : unit.Quantity Time step size for the integration. collision_rate : unit.Quantity Collision rate for the Langevin dynamics. + refresh_velocities : bool, optional + Whether to reinitialize the velocities each time the run function is called. + Default is False. reporter : LangevinDynamicsReporter, optional Reporter object for saving the simulation data. Default is None. - report_frequency : int - Frequency of saving the simulation data. + report_interval : int + Interval for saving the simulation data. Default is 100. - nr_of_steps : int, optional + number_of_steps : int, optional Number of steps to run the integrator for. Default is 1_000. save_traj_in_memory: bool @@ -70,30 +125,32 @@ def __init__( Default is False. NOTE: Only for debugging purposes. """ super().__init__( - nr_of_moves=nr_of_steps, + number_of_moves=number_of_steps, reporter=reporter, - report_frequency=report_frequency, + report_interval=report_interval, ) - self.stepsize = stepsize + self.timestep = timestep self.collision_rate = collision_rate self.save_traj_in_memory = save_traj_in_memory self.traj = [] from chiron.integrators import LangevinIntegrator self.integrator = LangevinIntegrator( - stepsize=self.stepsize, + timestep=self.timestep, collision_rate=self.collision_rate, - report_frequency=report_frequency, + refresh_velocities=refresh_velocities, + report_interval=report_interval, reporter=reporter, save_traj_in_memory=save_traj_in_memory, ) - def run( + def update( self, sampler_state: SamplerState, thermodynamic_state: ThermodynamicState, - ): + nbr_list: Optional[PairsBase] = None, + ) -> Tuple[SamplerState, ThermodynamicState, Optional[PairsBase]]: """ Run the integrator to perform molecular dynamics simulation. @@ -103,6 +160,17 @@ def run( The sampler state to run the integrator on. thermodynamic_state : ThermodynamicState The thermodynamic state to run the integrator on. + nbr_list : PairsBase, optional + The neighbor list to use for the simulation. + + Returns + ------- + sampler_state : SamplerState + The updated sampler state. + thermodynamic_state : ThermodynamicState + The thermodynamic state; note this is not modified by the Langevin dynamics algorithm. + nbr_list: PairsBase + The updated neighbor/pair list. If a nbr_list is not set, this will be None. """ assert isinstance( @@ -112,87 +180,837 @@ def run( thermodynamic_state, ThermodynamicState ), f"Thermodynamic state must be ThermodynamicState, not {type(thermodynamic_state)}" - self.integrator.run( + updated_sampler_state, updated_nbr_list = self.integrator.run( thermodynamic_state=thermodynamic_state, sampler_state=sampler_state, - n_steps=self.nr_of_moves, + number_of_steps=self.number_of_moves, + nbr_list=nbr_list, ) + # update the elapsed steps + self._number_of_attempts_made += self.number_of_moves if self.save_traj_in_memory: self.traj.append(self.integrator.traj) self.integrator.traj = [] + self._move_iteration += 1 + + # The thermodynamic_state will not change for the langevin move + return updated_sampler_state, thermodynamic_state, updated_nbr_list + class MCMove(MCMCMove): def __init__( - self, nr_of_moves: int, reporter: Optional[_SimulationReporter] + self, + number_of_moves: int, + reporter: Optional[_SimulationReporter], + report_interval: int = 1, + autotune: bool = False, + autotune_interval: int = 100, + acceptance_method: str = "Metropolis-Hastings", ) -> None: - super().__init__(nr_of_moves, reporter=reporter) + """ + Initialize the move. - def apply_move(self): + Parameters + ---------- + number_of_moves + Number of moves to be attempted in each call to update. + reporter + Reporter object for saving the simulation step data. + report_interval + Interval for saving the simulation data. + autotune + Whether to automatically tune the parameters of the MC move to achieve a target acceptance ratio. + For example, for a simple displacement move this would update the displacement_sigma. + autotune_interval + Frequency of autotuning the MC move parameters to achieve a target acceptance ratio. + acceptance_method + Methodology to use for accepting or rejecting the proposed state. + Default is "Metropolis-Hastings". """ - Apply a Monte Carlo move to the system. + super().__init__( + number_of_moves=number_of_moves, + reporter=reporter, + report_interval=report_interval, + ) + self.acceptance_method = acceptance_method # I think we should pass a class/function instead of a string, like space. - This method should be overridden by subclasses to define specific types of moves. + self.reset_statistics() + self.autotune = autotune + self.autotune_interval = autotune_interval - Raises - ------ - NotImplementedError - If the method is not implemented in subclasses. + def update( + self, + sampler_state: SamplerState, + thermodynamic_state: ThermodynamicState, + nbr_list: Optional[PairsBase] = None, + ) -> Tuple[SamplerState, ThermodynamicState, Optional[PairsBase]]: """ + Perform the defined move and update the state. - raise NotImplementedError("apply_move() must be implemented in subclasses") + Parameters + ---------- + sampler_state : SamplerState + The initial state of the simulation, including positions. + thermodynamic_state : ThermodynamicState + The thermodynamic state of the system, including temperature and potential. + nbr_list : PairBase, optional + Neighbor list for the system. - def compute_acceptance_probability( + Returns + ------- + sampler_state : SamplerState + The updated sampler state. + thermodynamic_state : ThermodynamicState + The updated thermodynamic state. + nbr_list: PairsBase + The updated neighbor/pair list. If a nbr_list is not set, this will be None. + """ + + self._current_reduced_potential = None + for i in range(self.number_of_moves): + sampler_state, thermodynamic_state, nbr_list = self._step( + sampler_state, + thermodynamic_state, + nbr_list, + ) + self._number_of_attempts_made += 1 + + # We should use self._number_of_attempts_made as the "step" otherwise, if we just used i, instances where + # self.report_interval > self.number_of_moves would only report on the + # first step, which might actually be more frequent than we specify + + if hasattr(self, "reporter"): + if self.reporter is not None: + if self._number_of_attempts_made % self.report_interval == 0: + self._report( + i, + self._move_iteration, + self._number_of_attempts_made, + self.n_accepted / self.n_proposed, + sampler_state, + thermodynamic_state, + nbr_list, + ) + if self.autotune: + # if we only used i, we might never actually update the parameters if we have a move that is called infrequently + if ( + self._number_of_attempts_made % self.autotune_interval == 0 + and self._number_of_attempts_made > 0 + ): + self._autotune() + # keep track of how many times this function has been called + self._move_iteration += 1 + + return sampler_state, thermodynamic_state, nbr_list + + @abstractmethod + def _report( self, - old_state: SamplerState, - new_state: SamplerState, + step: int, + iteration: int, + number_of_attempts_made: int, + acceptance_probability: float, + sampler_state: SamplerState, + thermodynamic_state: ThermodynamicState, + nbr_list: Optional[PairsBase] = None, ): """ - Compute the acceptance probability for a move from an old state to a new state. + Report the current state of the MC move. + + Since different moves will be modifying different quantities, + this needs to be defined for each move. Parameters ---------- - old_state : object - The state of the system before the move. - new_state : object - The state of the system after the move. + step : int + The current step of the simulation move. + iteration : int + The current iteration of the move sequence (i.e., how many times has this been called thus far). + number_of_attempts_made : int + The total number of steps that have been taken in the simulation move. step+ nr_moves*iteration + acceptance_probability : float + The acceptance probability of the move. + sampler_state : SamplerState + The sampler state of the system. + thermodynamic_state : ThermodynamicState + The thermodynamic state of the system. + nbr_list : Optional[PairBase]=None + The neighbor list or pair list for evaluating interactions in the system, default None + """ + pass + + @abstractmethod + def _autotune(self): + """ + This will autotune the move parameters to reach a target acceptance probability. + This will be specific to the type of move, e.g., a displacement_sigma for a displacement move + or a maximum volume change factor for a Monte Carlo barostat move. + + Since different moves will be modifying different quantities, this needs to be defined for each move. + + Note this will modify the class parameters in place. + """ + pass + + def _step( + self, + current_sampler_state: SamplerState, + current_thermodynamic_state: ThermodynamicState, + current_nbr_list: Optional[PairsBase] = None, + ) -> Tuple[SamplerState, ThermodynamicState, Optional[PairsBase]]: + """ + Performs an individual MC step. + + This will call the _propose function which will be specific to the type of move. + + Parameters + ---------- + current_sampler_state : SamplerState, required + Current sampler state. + current_thermodynamic_state : ThermodynamicState, required + Current thermodynamic state. + current_nbr_list : Optional[PairsBase] + Neighbor list associated with the current state. Returns ------- - float - Acceptance probability as a float. + sampler_state : SamplerState + The updated sampler state; if a move is rejected this will be unchanged. + Note, if the proposed move is rejected, the current PRNG key will be updated to ensure + that we are using a different random number for the next iteration. + thermodynamic_state : ThermodynamicState + The updated thermodynamic state; if a move is rejected this will be unchanged. + Note, many MC moves will not modify the thermodynamic state regardless of acceptance of the move. + nbr_list: PairsBase, optional + The updated neighbor/pair list. If a nbr_list is not set, this will be None. + If the move is rejected, this will correspond to the neighbor + + """ + + # if this is the first time we are calling this function during this iteration + # we will need to calculate the reduced potential for the current state + # this is toggled by the calculate_current_reduced_potential flag + # otherwise, we can use the one that was saved from the last step, for efficiency + if self._current_reduced_potential is None: + current_reduced_potential = ( + current_thermodynamic_state.get_reduced_potential( + current_sampler_state, current_nbr_list + ) + ) + # save the current_reduced_potential so we don't have to recalculate + # it on the next iteration if the move is rejected + self._current_reduced_potential = current_reduced_potential + else: + current_reduced_potential = self._current_reduced_potential + + # propose a new state and calculate the log proposal ratio + # this will be specific to the type of move + # in addition to the sampler_state, this will require/return the thermodynamic state + # for systems that e.g., make changes to particle identity. + # For efficiency, we will also return a copy of the nbr_list associated with the proposed state + # because if the move is rejected, we can move back the original state without having to rebuild the nbr_list + # if it were modified due to the proposed state. + ( + proposed_sampler_state, + proposed_thermodynamic_state, + proposed_reduced_potential, + log_proposal_ratio, + proposed_nbr_list, + ) = self._propose( + current_sampler_state, + current_thermodynamic_state, + current_reduced_potential, + current_nbr_list, + ) + + if jnp.isnan(proposed_reduced_potential): + decision = False + else: + # accept or reject the proposed state + decision = self._accept_or_reject( + log_proposal_ratio, + proposed_sampler_state.new_PRNG_key, + acceptance_method=self.acceptance_method, + ) + # a function that will update the statistics for the move + + self._update_statistics(decision) + + if decision: + # save the reduced potential of the accepted state so + # we don't have to recalculate it the next iteration + self._current_reduced_potential = proposed_reduced_potential + + # replace the current state with the proposed state + # not sure this needs to be a separate function but for simplicity in outlining the code it is fine + # or should this return the new sampler_state and thermodynamic_state? + + return ( + proposed_sampler_state, + proposed_thermodynamic_state, + proposed_nbr_list, + ) + else: + # if we reject the move, we need to update the current_PRNG key to ensure that + # we are using a different random number for the next iteration + # this is needed because the _step function returns a SamplerState instead of updating it in place + current_sampler_state._current_PRNG_key = ( + proposed_sampler_state._current_PRNG_key + ) + + return current_sampler_state, current_thermodynamic_state, current_nbr_list + + def _update_statistics(self, decision): + """ + Update the statistics for the move. + """ + if decision: + self.n_accepted += 1 + self.n_proposed += 1 + + @property + def statistics(self): + """The acceptance statistics as a dictionary.""" + return dict(n_accepted=self.n_accepted, n_proposed=self.n_proposed) + + @statistics.setter + def statistics(self, value): + self.n_accepted = value["n_accepted"] + self.n_proposed = value["n_proposed"] + + def reset_statistics(self): + """Reset the acceptance statistics.""" + self.n_accepted = 0 + self.n_proposed = 0 + + @abstractmethod + def _propose( + self, + current_sampler_state: SamplerState, + current_thermodynamic_state: ThermodynamicState, + current_reduced_potential: float, + current_nbr_list: Optional[PairsBase] = None, + ) -> Tuple[SamplerState, ThermodynamicState, float, float, Optional[PairsBase]]: """ - self._check_state_compatiblity(old_state, new_state) - old_system = self.system(old_state) - new_system = self.system(new_state) + Propose a new state and calculate the log proposal ratio. + + This will accept the relevant quantities for the current state, returning the proposed state quantities + and the log proposal ratio. + + This will need to be defined for each new move. + + Parameters + ---------- + current_sampler_state : SamplerState, required + Current sampler state. + current_thermodynamic_state : ThermodynamicState, required + Current thermodynamic state. + current_reduced_potential : float, required + Current reduced potential. + current_nbr_list : PairsBase, required + Neighbor list associated with the current state. + + Returns + ------- + proposed_sampler_state : SamplerState + Proposed sampler state. + proposed_thermodynamic_state : ThermodynamicState + Proposed thermodynamic state. + proposed_reduced_potential : float + Proposed reduced potential. + log_proposal_ratio : float + Log proposal ratio. + proposed_nbr_list : PairsBase + Proposed neighbor list. If not defined, this will be None. - energy_before_state_change = old_system.compute_energy(old_state.position) - energy_after_state_change = new_system.compute_energy(new_state.position) - # Implement the logic to compute the acceptance probability + """ pass - def accept_or_reject(self, probability): + def _accept_or_reject( + self, + log_proposal_ratio, + key, + acceptance_method, + ): + """ + Accept or reject the proposed state with a given methodology. + """ + # define the acceptance probability + if acceptance_method == "Metropolis-Hastings": + import jax.random as jrandom + + compare_to = jrandom.uniform(key) + if -log_proposal_ratio <= 0.0 or compare_to < jnp.exp(log_proposal_ratio): + return True + else: + return False + + +class MonteCarloDisplacementMove(MCMove): + """ + A Monte Carlo move that randomly displaces particles in the system. + + For each move, all particles will be randomly displaced at once, where the random displacement is drawn from + a normal distribution. The standard deviation of the distribution is defined by the `displacement_sigma` parameter. + + Displacements can be restricted to a subset of particles by defining the `atom_subset` parameter, which is a list of + particle indices that will be allowed to move. If `atom_subset` is not defined, all particles will be displaced. + + Note, the displacement moves are applied on a per-particle basis; this does not support collective moves. + + The value of the `displacement_sigma` can be autotuned to achieve a target acceptance ratio between 0.4 and 0.6, + by setting the autotune parameter to True. The frequency of autotuning is defined by setting `autotune_interval`. + + + """ + + def __init__( + self, + displacement_sigma=1.0 * unit.nanometer, + number_of_moves: int = 100, + atom_subset: Optional[List[int]] = None, + report_interval: int = 1, + reporter: Optional[MCReporter] = None, + autotune: bool = False, + autotune_interval: int = 100, + acceptance_method="Metropolis-Hastings", + ): """ - Decide whether to accept or reject the move based on the acceptance probability. + Initialize the Displacement Move class. Parameters ---------- - probability : float - Acceptance probability. + displacement_sigma : float or unit.Quantity, optional + The standard deviation of the displacement for each move. Default is 1.0 nm. + number_of_moves : int, optional + The number of move attempts to perform. Default is 100. + For a given move, all particles will be randomly displaced at once (unless atom_subset is), + rather than moving each particle one at a time. + atom_subset : list of int, optional + A list of particle indices that represent a subset of all particles. + If defined, only those particles in the list will have their positions random displaced. + Default is None. + reporter : SimulationReporter, optional + The reporter to write the data to. Default is None. + autotune : bool, optional + Whether to autotune the displacement_sigma of the move to achieve an acceptance ratio between 0.4 and 0.6. + Default is False. + autotune_interval : int, optional + Frequency of autotuning displacement_sigma of the move. Default is 100. + acceptance_method : str, optional + Methodology to use for accepting or rejecting the proposed state. + Default is "Metropolis-Hastings". Returns ------- - bool - Boolean indicating if the move is accepted. + None """ - import jax.numpy as jnp + super().__init__( + number_of_moves=number_of_moves, + reporter=reporter, + report_interval=report_interval, + autotune=autotune, + autotune_interval=autotune_interval, + acceptance_method=acceptance_method, + ) + self.displacement_sigma = displacement_sigma + + self.atom_subset = atom_subset + self.atom_subset_mask = None + + def _report( + self, + step: int, + iteration: int, + number_of_attempts_made: int, + acceptance_probability: float, + sampler_state: SamplerState, + thermodynamic_state: ThermodynamicState, + nbr_list: Optional[PairsBase] = None, + ): + """ + Report the current state of the MC displacement move. + + Parameters + ---------- + step : int + The current step of the simulation move. + iteration : int + The current iteration of the move sequence (i.e., how many times has this been called thus far). + number_of_attempts_made : int + The total number of steps that have been taken in the simulation move. step+ nr_moves*iteration + acceptance_probability : float + The acceptance probability of the move. + sampler_state : SamplerState + The sampler state of the system. + thermodynamic_state : ThermodynamicState + The thermodynamic state of the system. + nbr_list : Optional[PairBase]=None + The neighbor list or pair list for evaluating interactions in the system, default None - return jnp.random.rand() < probability + """ + potential = thermodynamic_state.potential.compute_energy( + sampler_state.positions, nbr_list + ) + self.reporter.report( + { + "step": step, + "iteration": iteration, + "number_of_attempts_made": number_of_attempts_made, + "potential_energy": potential, + "displacement_sigma": self.displacement_sigma.value_in_unit_system( + unit.md_unit_system + ), + "acceptance_probability": acceptance_probability, + } + ) + + def _autotune(self): + """ + Update the displacement_sigma to reach a target acceptance probability between 0.4 and 0.6. + """ + acceptance_ratio = self.n_accepted / self.n_proposed + if acceptance_ratio > 0.6: + self.displacement_sigma *= 1.1 + elif acceptance_ratio < 0.4: + self.displacement_sigma /= 1.1 + + def _propose( + self, + current_sampler_state: SamplerState, + current_thermodynamic_state: ThermodynamicState, + current_reduced_potential: float, + current_nbr_list: Optional[PairsBase] = None, + ) -> Tuple[SamplerState, ThermodynamicState, float, float, Optional[PairsBase]]: + """ + Implements the logic specific to displacement moves. + + Parameters + ---------- + current_sampler_state : SamplerState, required + Current sampler state. + current_thermodynamic_state : ThermodynamicState, required + Current thermodynamic state. + current_reduced_potential : float, required + Current reduced potential. + current_nbr_list : Optional[PairsBase] + Neighbor list associated with the current state. + + Returns + ------- + proposed_sampler_state : SamplerState + Proposed sampler state. + proposed_thermodynamic_state : ThermodynamicState + Proposed thermodynamic state. + proposed_reduced_potential : float + Proposed reduced potential. + log_proposal_ratio : float + Log proposal ratio. + proposed_nbr_list : PairsBase + Proposed neighbor list. If not defined, this will be None. + """ + + # create a mask for the atom subset: if a value of the mask is 0 + # the particle won't move; if 1 the particle will be moved + if self.atom_subset is not None and self.atom_subset_mask is None: + import jax.numpy as jnp + + self.atom_subset_mask = jnp.zeros(current_sampler_state.n_particles) + for atom in self.atom_subset: + self.atom_subset_mask = self.atom_subset_mask.at[atom].set(1) + + key = current_sampler_state.new_PRNG_key + + nr_of_atoms = current_sampler_state.n_particles + + unitless_displacement_sigma = self.displacement_sigma.value_in_unit_system( + unit.md_unit_system + ) + import jax.random as jrandom + + scaled_displacement_vector = ( + jrandom.normal(key, shape=(nr_of_atoms, 3)) * unitless_displacement_sigma + ) + import copy + + proposed_sampler_state = copy.deepcopy(current_sampler_state) + + if self.atom_subset is not None: + proposed_sampler_state.positions = ( + proposed_sampler_state.positions + + scaled_displacement_vector * self.atom_subset_mask + ) + else: + proposed_sampler_state.positions = ( + proposed_sampler_state.positions + scaled_displacement_vector + ) + + # after proposing a move we need to wrap particles and see if we need to rebuild the neighborlist + if current_nbr_list is not None: + proposed_sampler_state.positions = current_nbr_list.space.wrap( + proposed_sampler_state.positions, + proposed_sampler_state.box_vectors, + ) + + # if we need to rebuild the neighbor the neighborlist + # we will make a copy and then build + if current_nbr_list.check(proposed_sampler_state.positions): + import copy + + proposed_nbr_list = copy.deepcopy(current_nbr_list) + + proposed_nbr_list.build( + proposed_sampler_state.positions, proposed_sampler_state.box_vectors + ) + # if we don't need to update the neighborlist, just make a new variable that refers to the original + else: + proposed_nbr_list = current_nbr_list + else: + proposed_nbr_list = None + + proposed_reduced_potential = current_thermodynamic_state.get_reduced_potential( + proposed_sampler_state, proposed_nbr_list + ) + + log_proposal_ratio = -proposed_reduced_potential + current_reduced_potential + + # since do not change the thermodynamic state we can return + # 'current_thermodynamic_state' rather than making a copy + return ( + proposed_sampler_state, + current_thermodynamic_state, + proposed_reduced_potential, + log_proposal_ratio, + proposed_nbr_list, + ) + + +class MonteCarloBarostatMove(MCMove): + """ + A Monte Carlo move that randomly changes the volume of the system. + + The volume change is drawn from a normal distribution with a mean of 0 and a standard deviation defined + by the product of the `volume_max_scale` parameter and the current volume. Particle positions are scaled + proportionately with the change in volume. This routine operates on a per-particle basis and does not support + collective moves (i.e., it is an "atomic" barostat move where particle center-of-mass positions are scaled; + it is not aware of "molecules" which would be scaled by the molecule center-of-mass). + + The `volume_max_scale` parameter can be autotuned to achieve a target acceptance ratio between 0.25 and 0.75, + by setting the autotune parameter to True. The frequency of autotuning is defined by setting `autotune_interval`. + Note, the maximum value of `volume_max_scale` is capped at 0.3 in the auto-tuning process. + + + """ + + def __init__( + self, + volume_max_scale=0.01, + number_of_moves: int = 100, + report_interval: int = 1, + reporter: Optional[LangevinDynamicsReporter] = None, + autotune: bool = False, + autotune_interval: int = 100, + acceptance_method="Metropolis-Hastings", + ): + """ + Initialize the Monte Carlo Barostat Move class. + + Parameters + ---------- + volume_max_scale : float, optional + The scaling factor multiplied by volume to set the maximum volume change allowed. + number_of_moves : int, optional + The number of volume update moves attempts to perform. Default is 100. + reporter : SimulationReporter, optional + The reporter to write the data to. Default is None. + autotune : bool, optional + Whether to autotune the volume_max_scale value of the move to achieve a target probability + between 0.25 and 0.75. Default is False. volume_max_scale is capped at 0.3 + autotune_interval : int, optional + Frequency of autotuning the volume_max_scale of the move. Default is 100. + acceptance_method : str, optional + Methodology to use for accepting or rejecting the proposed state. + Default is "Metropolis-Hastings". + + Returns + ------- + None + """ + super().__init__( + number_of_moves=number_of_moves, + reporter=reporter, + report_interval=report_interval, + autotune=autotune, + autotune_interval=autotune_interval, + acceptance_method=acceptance_method, + ) + self.volume_max_scale = volume_max_scale + + def _report( + self, + step: int, + iteration: int, + number_of_attempts_made: int, + acceptance_probability: float, + sampler_state: SamplerState, + thermodynamic_state: ThermodynamicState, + nbr_list: Optional[PairsBase] = None, + ): + """ + + Parameters + ---------- + step : int + The current step of the simulation move. + iteration : int + The current iteration of the move sequence (i.e., how many times has this been called thus far). + number_of_attempts_made : int + The total number of steps that have been taken in the simulation move. step+ nr_moves*iteration + acceptance_probability : float + The acceptance probability of the move. + sampler_state : SamplerState + The sampler state of the system. + thermodynamic_state : ThermodynamicState + The thermodynamic state of the system. + nbr_list : Optional[PairBase]=None + The neighbor list or pair list for evaluating interactions in the system, default None + """ + + potential = thermodynamic_state.potential.compute_energy( + sampler_state.positions, nbr_list + ) + volume = ( + sampler_state.box_vectors[0][0] + * sampler_state.box_vectors[1][1] + * sampler_state.box_vectors[2][2] + ) + self.reporter.report( + { + "step": step, + "iteration": iteration, + "number_of_attempts_made": number_of_attempts_made, + "potential_energy": potential, + "volume": volume, + "box_vectors": sampler_state.box_vectors, + "max_volume_scale": self.volume_max_scale, + "acceptance_probability": acceptance_probability, + } + ) + + def _autotune(self): + """ + Update the volume_max_scale parameter to ensure our acceptance probability is within the range of 0.25 to 0.75. + The maximum volume_max_scale will be capped at 0.3. + """ + acceptance_ratio = self.n_accepted / self.n_proposed + if acceptance_ratio < 0.25: + self.volume_max_scale /= 1.1 + elif acceptance_ratio > 0.75: + self.volume_max_scale = min(self.volume_max_scale * 1.1, 0.3) + + def _propose( + self, + current_sampler_state: SamplerState, + current_thermodynamic_state: ThermodynamicState, + current_reduced_potential: float, + current_nbr_list: Optional[PairsBase] = None, + ) -> Tuple[SamplerState, ThermodynamicState, float, float, Optional[PairsBase]]: + """ + Implement the logic specific to displacement changes. + + Parameters + ---------- + current_sampler_state : SamplerState, required + Current sampler state. + current_thermodynamic_state : ThermodynamicState, required + Current thermodynamic state. + current_reduced_potential : float, required + Current reduced potential. + current_nbr_list : PairsBase, optional + Neighbor list associated with the current state. + + Returns + ------- + proposed_sampler_state : SamplerState + Proposed sampler state. + proposed_thermodynamic_state : ThermodynamicState + Proposed thermodynamic state. + proposed_reduced_potential : float + Proposed reduced potential. + log_proposal_ratio : float + Log proposal ratio. + proposed_nbr_list : PairsBase + Proposed neighbor list. If not defined, this will be None. + + """ + from loguru import logger as log + + key = current_sampler_state.new_PRNG_key + + import jax.random as jrandom + + nr_of_atoms = current_sampler_state.n_particles + + initial_volume = ( + current_sampler_state.box_vectors[0][0] + * current_sampler_state.box_vectors[1][1] + * current_sampler_state.box_vectors[2][2] + ) + + # Calculate the maximum amount the volume can change by + delta_volume_max = self.volume_max_scale * initial_volume + + # Calculate the volume change by generating a random number between -1 and 1 + # and multiplying by the maximum allowed volume change, delta_volume_max + delta_volume = jrandom.uniform(key, minval=-1, maxval=1) * delta_volume_max + # calculate the new volume + proposed_volume = initial_volume + delta_volume + + # calculate the length scale factor for particle positions and box vectors + length_scaling_factor = jnp.power(proposed_volume / initial_volume, 1.0 / 3.0) + + import copy + + proposed_sampler_state = copy.deepcopy(current_sampler_state) + proposed_sampler_state.positions = ( + current_sampler_state.positions * length_scaling_factor + ) + + proposed_sampler_state.box_vectors = ( + current_sampler_state.box_vectors * length_scaling_factor + ) + + if current_nbr_list is not None: + proposed_nbr_list = copy.deepcopy(current_nbr_list) + # after scaling the box vectors and positions we should always rebuild the neighborlist + proposed_nbr_list.build( + proposed_sampler_state.positions, proposed_sampler_state.box_vectors + ) + + proposed_reduced_potential = current_thermodynamic_state.get_reduced_potential( + proposed_sampler_state, proposed_nbr_list + ) + # NPT acceptance criteria was originally defined in McDonald 1972, https://doi.org/10.1080/00268977200100031 + # (see equation 9). The acceptance probability is given by: + # ⎡−β (ΔU + PΔV ) + N ln(V new /V old )⎤ + log_proposal_ratio = -( + proposed_reduced_potential - current_reduced_potential + ) + nr_of_atoms * jnp.log(proposed_volume / initial_volume) + + # we do not change the thermodynamic state so we can return 'current_thermodynamic_state' + return ( + proposed_sampler_state, + current_thermodynamic_state, + proposed_reduced_potential, + log_proposal_ratio, + proposed_nbr_list, + ) class RotamerMove(MCMove): - def apply_move(self): + def _propose(self): """ Implement the logic specific to rotamer changes. """ @@ -200,7 +1018,7 @@ def apply_move(self): class ProtonationStateMove(MCMove): - def apply_move(self): + def _propose(self): """ Implement the logic specific to protonation state changes. """ @@ -208,7 +1026,7 @@ def apply_move(self): class TautomericStateMove(MCMove): - def apply_move(self): + def _propose(self): """ Implement the logic specific to tautomeric state changes. """ @@ -283,20 +1101,39 @@ def run( sampler_state: SamplerState, thermodynamic_state: ThermodynamicState, n_iterations: int = 1, + nbr_list: Optional[PairsBase] = None, ): """ Run the sampler for a specified number of iterations. Parameters ---------- + sampler_state : SamplerState + The initial state of the sampler. + thermodynamic_state : ThermodynamicState + The thermodynamic state of the system. n_iterations : int, optional Number of iterations of the sampler to run. + Default is 1. + nbr_list : PairsBase, optional + The neighbor list to use for the simulation. + + Returns + ------- + sampler_state : SamplerState + The updated sampler state. + thermodynamic_state : ThermodynamicState + The updated thermodynamic state. + nbr_list: PairsBase + The updated neighbor/pair list. If a nbr_list is not set, this will be None. + """ from loguru import logger as log from copy import deepcopy sampler_state = deepcopy(sampler_state) thermodynamic_state = deepcopy(thermodynamic_state) + nbr_list = deepcopy(nbr_list) log.info("Running MCMC sampler") log.info(f"move_schedule = {self.move.move_schedule}") @@ -304,7 +1141,10 @@ def run( log.info(f"Iteration {iteration + 1}/{n_iterations}") for move_name, move in self.move.move_schedule: log.debug(f"Performing: {move_name}") - move.run(sampler_state, thermodynamic_state) + + sampler_state, thermodynamic_state, nbr_list = move.update( + sampler_state, thermodynamic_state, nbr_list + ) log.info("Finished running MCMC sampler") log.debug("Closing reporter") @@ -312,298 +1152,4 @@ def run( if move.reporter is not None: move.reporter.flush_buffer() log.debug(f"Closed reporter {move.reporter.log_file_path}") - return sampler_state - - -from .neighbors import PairsBase - - -class MetropolizedMove(MCMove): - """A base class for metropolized moves. - - Only the proposal needs to be specified by subclasses through the method - _propose_positions(). - - Parameters - ---------- - atom_subset : slice or list of int, optional - If specified, the move is applied only to those atoms specified by these - indices. If None, the move is applied to all atoms (default is None). - - Attributes - ---------- - n_accepted : int - The number of proposals accepted. - n_proposed : int - The total number of attempted moves. - atom_subset - - Examples - -------- - TBC - """ - - def __init__( - self, - atom_subset: Optional[List[int]] = None, - nr_of_moves: int = 100, - reporter: Optional[_SimulationReporter] = None, - report_frequency: int = 1, - ): - self.n_accepted = 0 - self.n_proposed = 0 - self.atom_subset = atom_subset - super().__init__(nr_of_moves=nr_of_moves, reporter=reporter) - from loguru import logger as log - - self.report_frequency = report_frequency - log.debug(f"Atom subset is {atom_subset}.") - - @property - def statistics(self): - """The acceptance statistics as a dictionary.""" - return dict(n_accepted=self.n_accepted, n_proposed=self.n_proposed) - - @statistics.setter - def statistics(self, value): - self.n_accepted = value["n_accepted"] - self.n_proposed = value["n_proposed"] - - def apply( - self, - thermodynamic_state: ThermodynamicState, - sampler_state: SamplerState, - nbr_list=Optional[PairsBase], - ): - """Apply a metropolized move to the sampler state. - - Total number of acceptances and proposed move are updated. - - Parameters - ---------- - thermodynamic_state : ThermodynamicState - The thermodynamic state to use to apply the move. - sampler_state : SamplerState - The initial sampler state to apply the move to. This is modified. - nbr_list: Neighbor List or Pair List routine, - The routine to use to calculate the interacting atoms. - Default is None and will use an unoptimized pairlist without PBC - """ - import jax.numpy as jnp - from loguru import logger as log - - # Compute initial energy - initial_energy = thermodynamic_state.get_reduced_potential( - sampler_state, nbr_list - ) # NOTE: in kT - log.debug(f"Initial energy is {initial_energy} kT.") - - # Store initial positions of the atoms that are moved. - x0 = sampler_state.x0 - atom_subset = self.atom_subset - if atom_subset is None: - initial_positions = jnp.copy(x0) - else: - initial_positions = jnp.copy(sampler_state.x0[jnp.array(atom_subset)]) - log.debug(f"Initial positions are {initial_positions} nm.") - # Propose perturbed positions. Modifying the reference changes the sampler state. - proposed_positions = self._propose_positions(initial_positions) - - log.debug(f"Proposed positions are {proposed_positions} nm.") - # Compute the energy of the proposed positions. - if atom_subset is None: - sampler_state.x0 = proposed_positions - else: - sampler_state.x0 = sampler_state.x0.at[jnp.array(atom_subset)].set( - proposed_positions - ) - if nbr_list is not None: - if nbr_list.check(sampler_state.x0): - nbr_list.build(sampler_state.x0, sampler_state.box_vectors) - - proposed_energy = thermodynamic_state.get_reduced_potential( - sampler_state, nbr_list - ) # NOTE: in kT - # Accept or reject with Metropolis criteria. - delta_energy = proposed_energy - initial_energy - log.debug(f"Delta energy is {delta_energy} kT.") - import jax.random as jrandom - - self.key, subkey = jrandom.split(self.key) - - compare_to = jrandom.uniform(subkey) - if not jnp.isnan(proposed_energy) and ( - delta_energy <= 0.0 or compare_to < jnp.exp(-delta_energy) - ): - self.n_accepted += 1 - log.debug(f"Check suceeded: {compare_to=} < {jnp.exp(-delta_energy)}") - log.debug( - f"Move accepted. Energy change: {delta_energy:.3f} kT. Number of accepted moves: {self.n_accepted}." - ) - if self.n_proposed % self.report_frequency == 0: - self.reporter.report( - { - "energy": proposed_energy, # in kT - "step": self.n_proposed, - "traj": sampler_state.x0, - } - ) - else: - # Restore original positions. - if atom_subset is None: - sampler_state.x0 = initial_positions - else: - sampler_state.x0 = sampler_state.x0.at[jnp.array([atom_subset])].set( - initial_positions - ) - log.debug( - f"Move rejected. Energy change: {delta_energy:.3f} kT. Number of rejected moves: {self.n_proposed - self.n_accepted}." - ) - self.n_proposed += 1 - - def _propose_positions(self, positions: jnp.array): - """Return new proposed positions. - - These method must be implemented in subclasses. - - Parameters - ---------- - positions : nx3 jnp.ndarray - The original positions of the subset of atoms that these move - applied to. - - Returns - ------- - proposed_positions : nx3 jnp.ndarray - The new proposed positions. - - """ - raise NotImplementedError( - "This MetropolizedMove does not know how to propose new positions." - ) - - -class MetropolisDisplacementMove(MetropolizedMove): - """A metropolized move that randomly displace a subset of atoms. - - Parameters - ---------- - displacement_sigma : openmm.unit.Quantity - The standard deviation of the normal distribution used to propose the - random displacement (units of length, default is 1.0*nanometer). - atom_subset : slice or list of int, optional - If specified, the move is applied only to those atoms specified by these - indices. If None, the move is applied to all atoms (default is None). - - Attributes - ---------- - n_accepted : int - The number of proposals accepted. - n_proposed : int - The total number of attempted moves. - displacement_sigma - atom_subset - - See Also - -------- - MetropolizedMove - - """ - - def __init__( - self, - displacement_sigma=1.0 * unit.nanometer, - nr_of_moves: int = 100, - atom_subset: Optional[List[int]] = None, - reporter: Optional[LangevinDynamicsReporter] = None, - ): - """ - Initialize the MCMC class. - - Parameters - ---------- - seed : int, optional - The seed for the random number generator. Default is 1234. - displacement_sigma : float or unit.Quantity, optional - The standard deviation of the displacement for each move. Default is 1.0 nm. - nr_of_moves : int, optional - The number of moves to perform. Default is 100. - atom_subset : list of int, optional - A subset of atom indices to consider for the moves. Default is None. - reporter : SimulationReporter, optional - The reporter to write the data to. Default is None. - Returns - ------- - None - """ - super().__init__(nr_of_moves=nr_of_moves, reporter=reporter) - self.displacement_sigma = displacement_sigma - self.atom_subset = atom_subset - self.key = None - - def displace_positions( - self, positions: jnp.array, displacement_sigma=1.0 * unit.nanometer - ): - """Return the positions after applying a random displacement to them. - - Parameters - ---------- - positions : nx3 jnp.array unit.Quantity - The positions to displace. - displacement_sigma : openmm.unit.Quantity - The standard deviation of the normal distribution used to propose - the random displacement (units of length, default is 1.0*nanometer). - - Returns - ------- - rotated_positions : nx3 numpy.ndarray openmm.unit.Quantity - The displaced positions. - - """ - import jax.random as jrandom - - self.key, subkey = jrandom.split(self.key) - nr_of_atoms = positions.shape[0] - unitless_displacement_sigma = displacement_sigma.value_in_unit_system( - unit.md_unit_system - ) - displacement_vector = ( - jrandom.normal(subkey, shape=(nr_of_atoms, 3)) * 0.1 - ) # NOTE: convert from Angstrom to nm - scaled_displacement_vector = displacement_vector * unitless_displacement_sigma - updated_position = positions + scaled_displacement_vector - - return updated_position - - def _propose_positions(self, initial_positions: jnp.array) -> jnp.array: - """Implement MetropolizedMove._propose_positions for apply().""" - return self.displace_positions(initial_positions, self.displacement_sigma) - - def run( - self, - sampler_state: SamplerState, - thermodynamic_state: ThermodynamicState, - nbr_list=None, - progress_bar=True, - ): - from tqdm import tqdm - from loguru import logger as log - from jax import random - - self.key = sampler_state.new_PRNG_key - - for trials in ( - tqdm(range(self.nr_of_moves)) if progress_bar else range(self.nr_of_moves) - ): - self.apply(thermodynamic_state, sampler_state, nbr_list) - if trials % 100 == 0: - log.debug(f"Acceptance rate: {self.n_accepted / self.n_proposed}") - if self.reporter is not None: - self.reporter.report( - { - "Acceptance rate": self.n_accepted / self.n_proposed, - "step": self.n_proposed, - } - ) - - log.info(f"Acceptance rate: {self.n_accepted / self.n_proposed}") + return sampler_state, thermodynamic_state, nbr_list diff --git a/chiron/multistate.py b/chiron/multistate.py index 26f3ea3..e67d7bb 100644 --- a/chiron/multistate.py +++ b/chiron/multistate.py @@ -75,6 +75,7 @@ def __init__( self._neighborhoods = None self._n_accepted_matrix = None self._n_proposed_matrix = None + self._reporter = reporter # NOTE: reporter needs to be putlic, API change ahead self._metadata = None self._mcmc_sampler = copy.deepcopy(mcmc_sampler) @@ -322,14 +323,14 @@ def _minimize_replica( # Perform minimization minimized_state = minimize_energy( - sampler_state.x0, + sampler_state.positions, thermodynamic_state.potential.compute_energy, self.nbr_list, maxiter=max_iterations, ) # Update the sampler state - self._sampler_states[replica_id].x0 = minimized_state.params + self._sampler_states[replica_id].positions = minimized_state.params # Compute and log final energy final_energy = thermodynamic_state.get_reduced_potential(sampler_state) @@ -394,11 +395,17 @@ def _propagate_replica(self, replica_id: int): thermodynamic_state_id = self._replica_thermodynamic_states[replica_id] sampler_state = self._sampler_states[replica_id] thermodynamic_state = self._thermodynamic_states[thermodynamic_state_id] + mcmc_sampler = self._mcmc_sampler[thermodynamic_state_id] # Propagate using the mcmc sampler - self._sampler_states[replica_id] = mcmc_sampler.run(sampler_state, thermodynamic_state) + # NOTE this needs to be updated to support neighborlists + ( + self._sampler_states[replica_id], + self._thermodynamic_states[thermodynamic_state_id], + nbr_list, + ) = mcmc_sampler.run(sampler_state, thermodynamic_state) # Append the new state to the trajectory for analysis. - self._traj[replica_id].append(self._sampler_states[replica_id].x0) + self._traj[replica_id].append(self._sampler_states[replica_id].positions) def _perform_swap_proposals(self): """ @@ -577,9 +584,9 @@ def _report_positions(self): log.debug("Reporting positions...") # numpy array with shape (n_replicas, n_atoms, 3) - xyz = np.zeros((self.n_replicas, self._sampler_states[0].x0.shape[0], 3)) + xyz = np.zeros((self.n_replicas, self._sampler_states[0].positions.shape[0], 3)) for replica_id in range(self.n_replicas): - xyz[replica_id] = self._sampler_states[replica_id].x0 + xyz[replica_id] = self._sampler_states[replica_id].positions return {"positions": xyz} def _report(self, property: str) -> None: @@ -598,17 +605,29 @@ def _report(self, property: str) -> None: from loguru import logger as log log.debug(f"Reporting {property}...") - match property: - case "positions": - return self._report_positions() - case "states": - pass - case "u_kn": - return self._report_energy_matrix() - case "trajectory": - return - case "mixing_statistics": - return + if property == "positions": + return self._report_positions() + elif property == "states": + pass + elif property == "u_kn": + return self._report_energy_matrix() + elif property == "trajectory": + return + elif "mixing_statistics": + return + + # match isn't in python 3.9; we can discuss if we want to drop python 3.0 support or just keep the if/else structure + # match property: + # case "positions": + # return self._report_positions() + # case "states": + # pass + # case "u_kn": + # return self._report_energy_matrix() + # case "trajectory": + # return + # case "mixing_statistics": + # return def _report_iteration(self): """ diff --git a/chiron/neighbors.py b/chiron/neighbors.py index 9941e44..be80cf0 100644 --- a/chiron/neighbors.py +++ b/chiron/neighbors.py @@ -3,62 +3,36 @@ import jax import jax.numpy as jnp from functools import partial -from typing import Tuple, Union +from typing import Tuple, Union, Optional from .states import SamplerState from openmm import unit -# split out the displacement calculation from the neighborlist for flexibility +# split out the displacement calculation from the neighbor list and pair list for flexibility from abc import ABC, abstractmethod class Space(ABC): - def __init__( - self, box_vectors: Union[jnp.array, unit.Quantity, None] = None - ) -> None: - """ - Abstract base class for defining the simulation space. + """ + Abstract Base Class for different simulation spaces. - Parameters - ---------- - box_vectors: jnp.array, optional - Box vectors for the system. - """ - if box_vectors is not None: - if isinstance(box_vectors, unit.Quantity): - if not box_vectors.unit.is_compatible(unit.nanometer): - raise ValueError( - f"Box vectors require distance unit, not {box_vectors.unit}" - ) - self.box_vectors = box_vectors.value_in_unit_system(unit.md_unit_system) - elif isinstance(box_vectors, jnp.ndarray): - if box_vectors.shape != (3, 3): - raise ValueError( - f"box_vectors should be a 3x3 array, shape provided: {box_vectors.shape}" - ) - - self.box_vectors = box_vectors - else: - raise TypeError( - f"box_vectors must be a jnp.array or unit.Quantity, not {type(box_vectors)}" - ) + This class will define two functions: + - displacement, i.e., how to calculate the displacement vector and distance between two points + - wrap, i.e., how to wrap a particle in the box (i.e., apply boundary conditions). - @property - def box_vectors(self) -> jnp.array: - return self._box_vectors + Note, this class does not store the box_vectors; they will need to be passed to each function. - @box_vectors.setter - def box_vectors(self, box_vectors: jnp.array) -> None: - self._box_vectors = box_vectors + + """ @abstractmethod def displacement( - self, xyz_1: jnp.array, xyz_2: jnp.array + self, xyz_1: jnp.array, xyz_2: jnp.array, box_vectors: jnp.array ) -> Tuple[jnp.array, jnp.array]: pass @abstractmethod - def wrap(self, xyz: jnp.array) -> jnp.array: + def wrap(self, xyz: jnp.array, box_vectors: jnp.array) -> jnp.array: pass @@ -68,20 +42,9 @@ class OrthogonalPeriodicSpace(Space): """ - @property - def box_vectors(self) -> jnp.array: - return self._box_vectors - - @box_vectors.setter - def box_vectors(self, box_vectors: jnp.array) -> None: - self._box_vectors = box_vectors - self._box_lengths = jnp.array( - [box_vectors[0][0], box_vectors[1][1], box_vectors[2][2]] - ) - @partial(jax.jit, static_argnums=(0,)) def displacement( - self, xyz_1: jnp.array, xyz_2: jnp.array + self, xyz_1: jnp.array, xyz_2: jnp.array, box_vectors: jnp.array ) -> Tuple[jnp.array, jnp.array]: """ Calculate the periodic distance between two points. @@ -89,9 +52,10 @@ def displacement( Parameters ---------- xyz_1: jnp.array - Coordinates of the first point + Positions of the first point xyz_2: jnp.array - Coordinates of the second point + Positions of the second point + box_vectors: jnp.array Returns ------- @@ -101,36 +65,43 @@ def displacement( Distance between the two points """ - # calculate uncorrect r_ij + # calculate uncorrected r_ij r_ij = xyz_1 - xyz_2 - # calculated corrected displacement vector - r_ij = ( - jnp.mod(r_ij + self._box_lengths * 0.5, self._box_lengths) - - self._box_lengths * 0.5 + box_lengths = jnp.array( + [box_vectors[0][0], box_vectors[1][1], box_vectors[2][2]] ) + # calculated corrected displacement vector + # using modulus seems faster in JAX + r_ij = jnp.mod(r_ij + box_lengths * 0.5, box_lengths) - box_lengths * 0.5 # calculate the scalar distance dist = jnp.linalg.norm(r_ij, axis=-1) return r_ij, dist @partial(jax.jit, static_argnums=(0,)) - def wrap(self, xyz: jnp.array) -> jnp.array: + def wrap(self, xyz: jnp.array, box_vectors: jnp.array) -> jnp.array: """ - Wrap the coordinates of the system. + Wrap the positions of the system. Parameters ---------- xyz: jnp.array - Coordinates of the system + Positions of the system + box_vectors: jnp.array + Box vectors for the system Returns ------- jnp.array - Wrapped coordinates of the system + Wrapped positions of the system """ - xyz = xyz - jnp.floor(xyz / self._box_lengths) * self._box_lengths + box_lengths = jnp.array( + [box_vectors[0][0], box_vectors[1][1], box_vectors[2][2]] + ) + + xyz = xyz - jnp.floor(xyz / box_lengths) * box_lengths return xyz @@ -141,6 +112,7 @@ def displacement( self, xyz_1: jnp.array, xyz_2: jnp.array, + box_vectors: jnp.array, ) -> Tuple[jnp.array, jnp.array]: """ Calculate the periodic distance between two points. @@ -148,9 +120,11 @@ def displacement( Parameters ---------- xyz_1: jnp.array - Coordinates of the first point + Positions of the first point xyz_2: jnp.array - Coordinates of the second point + Positions of the second point + box_vectors: jnp.array + Box vectors for the system. Returns ------- @@ -169,20 +143,22 @@ def displacement( return r_ij, dist @partial(jax.jit, static_argnums=(0,)) - def wrap(self, xyz: jnp.array) -> jnp.array: + def wrap(self, xyz: jnp.array, box_vectors: jnp.array) -> jnp.array: """ - Wrap the coordinates of the system. - For the Non-periodic system, this does not alter the coordinates + Wrap the positions of the system. + For the Non-periodic system, this does not alter the positions Parameters ---------- xyz: jnp.array - Coordinates of the system + Positions of the system + box_vectors: jnp.array + Box vectors for the system Returns ------- jnp.array - Wrapped coordinates of the system + Wrapped positions of the system """ return xyz @@ -190,7 +166,7 @@ def wrap(self, xyz: jnp.array) -> jnp.array: class PairsBase(ABC): """ - Abstract Base Class for different algorithms that determine which particles are interacting. + Abstract Base Class for different algorithms that determine which particle pairs are interacting. Parameters ---------- @@ -207,46 +183,58 @@ class PairsBase(ABC): >>> import jax.numpy as jnp >>> >>> space = OrthogonalPeriodicSpace() # define the simulation space, in this case an orthogonal periodic space - >>> sampler_state = SamplerState(x0=jnp.array([[0.0, 0.0, 0.0], [2, 0.0, 0.0], [0.0, 2, 0.0]]), + >>> sampler_state = SamplerState(positions=jnp.array([[0.0, 0.0, 0.0], [2, 0.0, 0.0], [0.0, 2, 0.0]]), >>> box_vectors=jnp.array([[10, 0.0, 0.0], [0.0, 10, 0.0], [0.0, 0.0, 10]])) >>> >>> pair_list = PairsBase(space, cutoff=2.5*unit.nanometer) # initialize the pair list >>> pair_list.build_from_state(sampler_state) # build the pair list from the sampler state >>> - >>> coordinates = sampler_state.x0 # get the coordinates from the sampler state, without units attached + >>> positions = sampler_state.positions # get the positions from the sampler state, without units attached >>> >>> # the calculate function will produce information used to calculate the energy - >>> n_neighbors, padding_mask, dist, r_ij = pair_list.calculate(coordinates) + >>> n_neighbors, padding_mask, dist, r_ij = pair_list.calculate(positions) >>> """ def __init__( self, space: Space, - cutoff: unit.Quantity = unit.Quantity(1.2, unit.nanometer), + cutoff: Optional[unit.Quantity] = unit.Quantity(1.2, unit.nanometer), ): + """ + Initialize the PairsBase class + + Parameters + ---------- + space: Space + Class that defines how to calculate the displacement between two points and apply the boundary conditions + This should not be changed after initialization. + cutoff: unit.Quantity, default = 1.2 unit.nanometer + Cutoff distance for the neighborlist + + """ if not isinstance(space, Space): raise TypeError(f"space must be of type Space, found {type(space)}") if not cutoff.unit.is_compatible(unit.angstrom): raise ValueError( f"cutoff must be a unit.Quantity with units of distance, cutoff.unit = {cutoff.unit}" ) - self.cutoff = cutoff.value_in_unit_system(unit.md_unit_system) + self.cutoff = cutoff self.space = space @abstractmethod def build( self, - coordinates: Union[jnp.array, unit.Quantity], + positions: Union[jnp.array, unit.Quantity], box_vectors: Union[jnp.array, unit.Quantity], ): """ - Build list from an array of coordinates and array of box vectors. + Build list from an array of positions and array of box vectors. Parameters ---------- - coordinates: jnp.array or unit.Quantity - Shape[n_particles,3] array of particle coordinates, either with or without units attached. + positions: jnp.array or unit.Quantity + Shape[n_particles,3] array of particle positions, either with or without units attached. If the array is passed as a unit.Quantity, the units must be distances and will be converted to nanometers. box_vectors: jnp.array or unit.Quantity Shape[3,3] array of box vectors for the system, either with or without units attached. @@ -261,24 +249,36 @@ def build( def _validate_build_inputs( self, - coordinates: Union[jnp.array, unit.Quantity], + positions: Union[jnp.array, unit.Quantity], box_vectors: Union[jnp.array, unit.Quantity], ): """ Validate the inputs to the build function. + + This will raise ValueErrors if the inputs are not of the correct type or shape or compatible units + + Parameters + ---------- + positions: jnp.array or unit.Quantity + Shape[n_particles,3] array of particle positions, either with or without units attached. + If the array is passed as a unit.Quantity, the units must be distances and will be converted to nanometers. + box_vectors: jnp.array or unit.Quantity + Shape[3,3] array of box vectors for the system, either with or without units attached. + If the array is passed as a unit.Quantity, the units must be distances and will be converted to nanometers. + """ - if isinstance(coordinates, unit.Quantity): - if not coordinates.unit.is_compatible(unit.nanometer): + if isinstance(positions, unit.Quantity): + if not positions.unit.is_compatible(unit.nanometer): raise ValueError( - f"Coordinates require distance units, not {coordinates.unit}" + f"Positions require distance units, not {positions.unit}" ) - self.ref_coordinates = coordinates.value_in_unit_system(unit.md_unit_system) - if isinstance(coordinates, jnp.ndarray): - if coordinates.shape[1] != 3: + self.ref_positions = positions.value_in_unit_system(unit.md_unit_system) + if isinstance(positions, jnp.ndarray): + if positions.shape[1] != 3: raise ValueError( - f"coordinates should be a Nx3 array, shape provided: {coordinates.shape}" + f"positions should be a Nx3 array, shape provided: {positions.shape}" ) - self.ref_coordinates = coordinates + self.ref_positions = positions if isinstance(box_vectors, unit.Quantity): if not box_vectors.unit.is_compatible(unit.nanometer): raise ValueError( @@ -300,7 +300,7 @@ def build_from_state(self, sampler_state: SamplerState): Parameters ---------- sampler_state: SamplerState - SamplerState object containing the coordinates and box vectors + SamplerState object containing the positions and box vectors Returns ------- @@ -309,22 +309,22 @@ def build_from_state(self, sampler_state: SamplerState): if not isinstance(sampler_state, SamplerState): raise TypeError(f"Expected SamplerState, got {type(sampler_state)} instead") - coordinates = sampler_state.x0 + positions = sampler_state.positions if sampler_state.box_vectors is None: raise ValueError(f"SamplerState does not contain box vectors") box_vectors = sampler_state.box_vectors - self.build(coordinates, box_vectors) + self.build(positions, box_vectors) @abstractmethod - def calculate(self, coordinates: jnp.array): + def calculate(self, positions: jnp.array): """ - Calculate the neighbor list for the current state + Calculate the list of interacting particles for the current state Parameters ---------- - coordinates: jnp.array - Shape[N,3] array of particle coordinates + positions: jnp.array + Shape[N,3] array of particle positions Returns ------- @@ -343,15 +343,16 @@ def calculate(self, coordinates: jnp.array): pass @abstractmethod - def check(self, coordinates: jnp.array) -> bool: + def check(self, positions: jnp.array) -> bool: """ - Check if the internal variables need to be reset. E.g., rebuilding a neighborlist - Should do nothing for a simple pairlist. + Check if the internal variables need to be reset. E.g., rebuilding a neighborlist if particles moved to far, + or rebuilding if number of particles changes. + Parameters ---------- - coordinates: jnp.array - Array of particle coordinates + positions: jnp.array + Array of particle positions Returns ------- bool @@ -362,22 +363,69 @@ def check(self, coordinates: jnp.array) -> bool: class NeighborListNsqrd(PairsBase): """ - N^2 neighborlist implementation that returns the particle pair ids, displacement vectors, and distances. + A JAX based neighbor list implementation used to determine which pairs of particles are interacting + (i.e., those particles that fall within the specified cutoff). + + The neighbor list (i.e., list of particles within a distance of cutoff+skin of a given particle) is generated + within the `build` function using an O(N^2) calculation rather than using a spatial partitioning scheme + (e.g., cell-list). The `calculate` function that uses the neighbor list to determine which particle pairs are + interacting and determine the distances and displacement vectors between interacting pairs of particles for + use in the calculation of the interaction energies/forces. The routines are subject to the boundary conditions + specified by the Space class. + + Notes: + This neighbor list not include self-interactions and only includes unique pairs (i.e., no double-counting). + This is sometimes referred to as a "half" neighbor list. E.g. consider the pair of neighboring particles (A, B): + in the "half" neighbor list approach, B is in the neighbor list of A, but A is not in the neighbor list of B + as that pair is already accounted for. + . + The output of the `calculate` function is padded to a fixed size, `n_max_neighbors` (default=100), + to allow for efficient jitted computations in JAX. As such, values need to be masked using the `padding_mask` + array returned by the `calculate` function. The padding mask is an array of 1s and 0s, where 1 indicates an + interacting neighbor and 0 indicates the pair is either non-interacting or simply a padded value. + The `build` function will iteratively increase `n_max_neighbors` by 10 until we can store all neighbors. + + The `check` function, which indicates if the neighbor list should be rebuilt, will return True if: + - the number of particles changes + - any of the particles have moved more than half the skin distance from their reference positions (i.e., the + positions of particles when the neighbor list was last built). - Parameters - ---------- - space: Space - Class that defines how to calculate the displacement between two points and apply the boundary conditions - cutoff: float, default = 2.5 - Cutoff distance for the neighborlist - skin: float, default = 0.4 - Skin distance for the neighborlist - n_max_neighbors: int, default=200 - Maximum number of neighbors for each particle. Used for padding arrays for efficient jax computations - This will be checked and dynamically updated during the build stage - Examples - -------- + Parameters + ---------- + space: Space + Class that defines how to calculate the displacement between two points and apply the boundary conditions. + This should not be changed after initialization. + cutoff: unit.Quantity, default = 1.2 unit.nanometer + Cutoff distance for the neighborlist + skin: unit.Quantity, default = 0.4 unit.nanometer + Skin distance, i.e., buffer, for the neighborlist + Larger values of the skin will reduce the frequency of rebuilding the neighbor list, + but will increase the number of neighbors to consider. + n_max_neighbors: int, default=200 + Maximum number of neighbors for each particle. This is used for padding arrays for efficient jax computations + n_max_neighbors will be dynamically updated (in increments of 10) as part of the build function. + Examples + -------- + >>> from openmm import unit + >>> import jax.numpy as jnp + >>> + >>> from chiron.states import SamplerState + >>> sampler_state = SamplerState(positions=jnp.array([[0.0, 0.0, 0.0], [2, 0.0, 0.0], [0.0, 2, 0.0]])*unit.nanometer, + >>> box_vectors=jnp.array([[10, 0.0, 0.0], [0.0, 10, 0.0], [0.0, 0.0, 10]])*unit.nanometer) + >>> + >>> from chiron.neighbors import NeighborListNsqrd, OrthogonalPeriodicSpace + >>> nbr_list = NeighborListNsqrd(OrthogonalPeriodicSpace(), cutoff=1.2*unit.nanometer, skin=0.4*unit.nanometer) + >>> + >>> # build the neighborlist + >>> nbr_list.build_from_state(sampler_state) # build the pair list from the sampler state + >>> + >>> # calculate which particles are interacting along with their distances and displacement vectors + >>> n_neighbors, neighbor_list, padding_mask, dist, r_ij = nbr_list.calculate(sampler_state.positions) + >>> + >>> # check the neighborlist + >>> if nbr_list.check(sampler_state.positions): + >>> nbr_list.build_from_state(sampler_state) # rebuild the pair list from the sampler state """ @@ -390,27 +438,64 @@ def __init__( ): if not isinstance(space, Space): raise TypeError(f"space must be of type Space, found {type(space)}") - if not cutoff.unit.is_compatible(unit.angstrom): - raise ValueError( - f"cutoff must be a unit.Quantity with units of distance, cutoff.unit = {cutoff.unit}" - ) + if not skin.unit.is_compatible(unit.angstrom): raise ValueError( f"cutoff must be a unit.Quantity with units of distance, skin.unit = {skin.unit}" ) - self.cutoff = cutoff.value_in_unit_system(unit.md_unit_system) - self.skin = skin.value_in_unit_system(unit.md_unit_system) - self.cutoff_and_skin = self.cutoff + self.skin + self.cutoff = cutoff + self.skin = skin self.n_max_neighbors = n_max_neighbors self.space = space - # set a a simple variable to know if this has at least been built once as opposed to just initialized - # this does not imply that the neighborlist is up to date + # this variable will ensure that `calculate` will fail if we try to call it before building + # note: self.is_built=True does not imply that the neighborlist is up-to-date + self.is_built = False + + @property + def cutoff(self) -> unit.Quantity: + return self._cutoff + + @cutoff.setter + def cutoff(self, cutoff: unit.Quantity) -> None: + if not cutoff.unit.is_compatible(unit.nanometer): + raise ValueError( + f"cutoff must be a unit.Quantity with units of distance, cutoff.unit = {cutoff.unit}" + ) + self._cutoff = cutoff + + # if we change the cutoff or skin we need to rebuild + # we will set the variable to ensure that attempts to call the calculate function will fail if + # we have not rebuilt the neighbor list + self.is_built = False + + @property + def skin(self) -> unit.Quantity: + return self._skin + + @skin.setter + def skin(self, skin: unit.Quantity) -> None: + if not skin.unit.is_compatible(unit.nanometer): + raise ValueError( + f"skin must be a unit.Quantity with units of distance, skin.unit = {skin.unit}" + ) + self._skin = skin + + # if we change the cutoff or skin we need to rebuild + # we will set the variable to ensure that attempts to call the calculate function will fail if + # we have not rebuilt the neighbor list self.is_built = False - # note, we need to use the partial decorator in order to use the jit decorate - # so that it knows to ignore the `self` argument + # Note, we need to use the partial decorator and declare self as static in order to JIT a function within a class. + # This approach treats internal variables of the class as static within this function; e.g., if set self.cutoff = 2, + # called the function, then changed it to 3, the value of self.cutoff in this function would still be 2. + # Thus, we need to pass any variables that may change as arguments, rather than referencing self.variable_name. + # While we could create a custom pytree instead of declaring the class as static (allowing us to reference class + # variables directly within the JITTED function), any changes to those internal variables, say self.cutoff, + # would mean a change to the hash of any JITTEd function that depends on the variable, requiring JAX to recompile + # the function, which is a slow operation. As such, it is also more efficient to just pass variables as arguments. + @partial(jax.jit, static_argnums=(0,)) def _pairs_mask(self, particle_ids: jnp.array): """ @@ -444,9 +529,18 @@ def _pairs_mask(self, particle_ids: jnp.array): return temp_mask + # note: since n_max_neighbors dictates the output size, we will define it as a static argument + # to allow us to jit this function @partial(jax.jit, static_argnums=(0, 5)) def _build_neighborlist( - self, particle_i, reduction_mask, pid, coordinates, n_max_neighbors + self, + particle_i, + reduction_mask, + pid, + positions, + n_max_neighbors, + cutoff_and_skin, + box_vectors, ): """ Jitted function to build the neighbor list for a single particle @@ -454,13 +548,17 @@ def _build_neighborlist( Parameters ---------- particle_i: jnp.array - X,Y,Z coordinates of particle i + X,Y,Z positions of particle i reduction_mask: jnp.array Mask to exclude self-interactions and double counting of pairs - coordinates: jnp.array - X,Y,Z coordinates of all particles + positions: jnp.array + X,Y,Z positions of all particles n_max_neighbors: int Maximum number of neighbors for each particle. Used for padding arrays for efficient jax computations + cutoff_and_skin: float + Cutoff distance for the neighborlist plus the skin distance, in nanometers. + box_vectors: jnp.array + Box vectors for the system Returns ------- @@ -472,19 +570,27 @@ def _build_neighborlist( Number of neighbors for the particle """ - # calculate the displacement between particle i and all other particles - r_ij, dist = self.space.displacement(particle_i, coordinates) + # Calculate the displacement between particle i and all other particles + # NOTE: It would be safer to pass the displacement calculate as a callable function, instead of referencing + # self.space. If someone changes the boundary conditions (i.e., changes space in the class), + # self.space.displacement will not change since the self is marked as status. + # However, I ran into issues passing a function through vmap, and I haven't been able to figure out how to + # resolve it yet. I do not want to remove vmap, as that would require substantially changing the flow of + # the code. For now, I've noted in the docstring that space should not change after initialization -- CRI + r_ij, dist = self.space.displacement(particle_i, positions, box_vectors) - # neighbor_mask will be an array of length n_particles (i.e., length of coordinates) + # neighbor_mask will be an array of length n_particles (i.e., length of positions) # where each element is True if the particle is a neighbor, False if it is not # subject to both the cutoff+skin and the reduction mask that eliminates double counting and self-interactions neighbor_mask = jnp.where( - (dist < self.cutoff_and_skin) & (reduction_mask), True, False + (dist < cutoff_and_skin) & (reduction_mask), True, False ) # when we pad the neighbor list, we will use last particle id in the neighbor list - # this choice was made such that when we use the neighbor list in the masked energy calculat + # this choice was made such that when we use the neighbor list in the masked energy calculation # the padded values will result in reasonably well defined values fill_value = jnp.argmax(neighbor_mask) + # if the max value is the same as the particle of interest, which can occur if particle 0 has no neighbors + # we will just increment by 1 to avoid calculating a self interaction fill_value = jnp.where(fill_value == pid, fill_value + 1, fill_value) # count up the number of neighbors @@ -506,16 +612,16 @@ def _build_neighborlist( def build( self, - coordinates: Union[jnp.array, unit.Quantity], + positions: Union[jnp.array, unit.Quantity], box_vectors: Union[jnp.array, unit.Quantity], ): """ - Build the neighborlist from an array of coordinates and box vectors. + Build the neighbor list from an array of positions and box vectors. Parameters ---------- - coordinates: jnp.array - Shape[N,3] array of particle coordinates + positions: jnp.array + Shape[N,3] array of particle positions box_vectors: jnp.array Shape[3,3] array of box vectors @@ -525,14 +631,14 @@ def build( """ - # set our reference coordinates - # the call to x0 and box_vectors automatically convert these to jnp arrays in the correct unit system - if isinstance(coordinates, unit.Quantity): - if not coordinates.unit.is_compatible(unit.nanometer): + # set our reference positions + # the call to positions and box_vectors automatically convert these to jnp arrays in the correct unit system + if isinstance(positions, unit.Quantity): + if not positions.unit.is_compatible(unit.nanometer): raise ValueError( - f"Coordinates require distance units, not {coordinates.unit}" + f"Positions require distance units, not {positions.unit}" ) - coordinates = coordinates.value_in_unit_system(unit.md_unit_system) + positions = positions.value_in_unit_system(unit.md_unit_system) if isinstance(box_vectors, unit.Quantity): if not box_vectors.unit.is_compatible(unit.nanometer): @@ -546,16 +652,17 @@ def build( f"box_vectors should be a 3x3 array, shape provided: {box_vectors.shape}" ) - self.ref_coordinates = coordinates + self.ref_positions = positions self.box_vectors = box_vectors + cutoff_and_skin = self.cutoff + self.skin + # the neighborlist assumes that the box vectors do not change between building and calculating the neighbor list # changes to the box vectors require rebuilding the neighbor list - self.space.box_vectors = self.box_vectors # store the ids of all the particles self.particle_ids = jnp.array( - range(0, self.ref_coordinates.shape[0]), dtype=jnp.uint32 + range(0, self.ref_positions.shape[0]), dtype=jnp.uint32 ) # calculate which pairs to exclude @@ -569,13 +676,15 @@ def build( # n_neighbors: an array of shape (n_particles) where each element is the number of neighbors for that particle self.neighbor_mask, self.neighbor_list, self.n_neighbors = jax.vmap( - self._build_neighborlist, in_axes=(0, 0, 0, None, None) + self._build_neighborlist, in_axes=(0, 0, 0, None, None, None, None) )( - self.ref_coordinates, + self.ref_positions, reduction_mask, self.particle_ids, - self.ref_coordinates, + self.ref_positions, self.n_max_neighbors, + cutoff_and_skin.value_in_unit_system(unit.md_unit_system), + self.box_vectors, ) self.neighbor_list = self.neighbor_list.reshape(-1, self.n_max_neighbors) @@ -588,13 +697,15 @@ def build( self.n_max_neighbors = int(jnp.max(self.n_neighbors) + 10) self.neighbor_mask, self.neighbor_list, self.n_neighbors = jax.vmap( - self._build_neighborlist, in_axes=(0, 0, 0, None, None) + self._build_neighborlist, in_axes=(0, 0, 0, None, None, None, None) )( - self.ref_coordinates, + self.ref_positions, reduction_mask, self.particle_ids, - self.ref_coordinates, + self.ref_positions, self.n_max_neighbors, + cutoff_and_skin.value_in_unit_system(unit.md_unit_system), + self.box_vectors, ) self.neighbor_list = self.neighbor_list.reshape(-1, self.n_max_neighbors) @@ -603,7 +714,7 @@ def build( @partial(jax.jit, static_argnums=(0,)) def _calc_distance_per_particle( - self, particle1, neighbors, neighbor_mask, coordinates + self, particle1, neighbors, neighbor_mask, positions, cutoff, box_vectors ): """ Jitted function to calculate the distance between a particle and its neighbors @@ -616,8 +727,12 @@ def _calc_distance_per_particle( Array of particle ids for the neighbors of particle1 neighbor_mask: jnp.array Mask to exclude padding from the neighbor list of particle1 - coordinates: jnp.array - X,Y,Z coordinates of all particles + positions: jnp.array + X,Y,Z positions of all particles + cutoff: float + Cutoff distance for the neighborlist, in nanometers + box_vectors: jnp.array + Box vectors for the system Returns ------- @@ -635,26 +750,27 @@ def _calc_distance_per_particle( particles1 = jnp.repeat(particle1, neighbors.shape[0]) # calculate the displacement between particle i and all neighbors + # See note above: if self.space changes, it will not show up here because self is static. r_ij, dist = self.space.displacement( - coordinates[particles1], coordinates[neighbors] + positions[particles1], positions[neighbors], box_vectors ) # calculate the mask to determine if the particle is a neighbor # this will be done based on the interaction cutoff and using the neighbor_mask to exclude padding - mask = jnp.where((dist < self.cutoff) & (neighbor_mask), 1, 0) + mask = jnp.where((dist < cutoff) & (neighbor_mask), 1, 0) # calculate the number of pairs n_pairs = mask.sum() return n_pairs, mask, dist, r_ij - def calculate(self, coordinates: jnp.array): + def calculate(self, positions: jnp.array): """ Calculate the neighbor list for the current state Parameters ---------- - coordinates: jnp.array - Shape[N,3] array of particle coordinates + positions: jnp.array + Shape[N,3] array of particle positions Returns ------- @@ -669,20 +785,34 @@ def calculate(self, coordinates: jnp.array): r_ij: jnp.array Array of displacement vectors between each particle and its neighbors. Shape (n_particles, n_max_neighbors, 3) """ - # coordinates = sampler_state.x0 + # positions = sampler_state.positions # note, we assume the box vectors do not change between building and calculating the neighbor list # changes to the box vectors require rebuilding the neighbor list n_neighbors, padding_mask, dist, r_ij = jax.vmap( - self._calc_distance_per_particle, in_axes=(0, 0, 0, None) - )(self.particle_ids, self.neighbor_list, self.neighbor_mask, coordinates) + self._calc_distance_per_particle, in_axes=(0, 0, 0, None, None, None) + )( + self.particle_ids, + self.neighbor_list, + self.neighbor_mask, + positions, + self.cutoff.value_in_unit_system(unit.md_unit_system), + self.box_vectors, + ) # mask = mask.reshape(-1, self.n_max_neighbors) return n_neighbors, self.neighbor_list, padding_mask, dist, r_ij @partial(jax.jit, static_argnums=(0,)) - def _calculate_particle_displacement(self, particle, coordinates, ref_coordinates): + def _calculate_particle_displacement( + self, + particle: int, + positions: jnp.array, + ref_positions: jnp.array, + skin: float, + box_vectors: jnp.array, + ): """ - Calculate the displacement of a particle from the reference coordinates. + Calculate the displacement of a particle from the reference positions. If the displacement exceeds the half the skin distance, return True, otherwise return False. This function is designed to allow it to be jitted and vmapped over particle indices. @@ -691,50 +821,61 @@ def _calculate_particle_displacement(self, particle, coordinates, ref_coordinate ---------- particle: int Particle id - coordinates: jnp.array - Array of particle coordinates - ref_coordinates: jnp.array - Array of reference particle coordinates + positions: jnp.array + Array of particle positions + ref_positions: jnp.array + Array of reference particle positions + skin: float + Skin distance for the neighborlist, in nanometers + box_vectors: jnp.array + Box vectors for the system + Returns ------- bool True if the particle is outside the skin distance, False if it is not. """ - # calculate the displacement of a particle from the initial coordinates - + # calculate the displacement of a particle from the initial positions + # again, note that if self.space changes, it will not show up here because self is static. r_ij, displacement = self.space.displacement( - coordinates[particle], ref_coordinates[particle] + positions[particle], ref_positions[particle], box_vectors ) - status = jnp.where(displacement >= self.skin / 2.0, True, False) + status = jnp.where(displacement >= skin / 2.0, True, False) del displacement return status - def check(self, coordinates: jnp.array) -> bool: + def check(self, positions: jnp.array) -> bool: """ - Check if the neighbor list needs to be rebuilt based on displacement of the particles from the reference coordinates. + Check if the neighbor list needs to be rebuilt based on displacement of the particles from the reference positions. If a particle moves more than 0.5 skin distance, the neighborlist will be rebuilt. - Will also return True if the size of the coordinates array changes. + Will also return True if the size of the positions array changes. Note, this could also accept a user defined criteria for distance, but this is not implemented yet. Parameters ---------- - coordinates: jnp.array - Array of particle coordinates + positions: jnp.array + Array of particle positions Returns ------- bool True if the neighbor list needs to be rebuilt, False if it does not. """ - if self.ref_coordinates.shape[0] != coordinates.shape[0]: + if self.ref_positions.shape[0] != positions.shape[0]: return True status = jax.vmap( - self._calculate_particle_displacement, in_axes=(0, None, None) - )(self.particle_ids, coordinates, self.ref_coordinates) + self._calculate_particle_displacement, in_axes=(0, None, None, None, None) + )( + self.particle_ids, + positions, + self.ref_positions, + self.skin.value_in_unit_system(unit.md_unit_system), + self.box_vectors, + ) if jnp.any(status): del status return True @@ -743,56 +884,114 @@ def check(self, coordinates: jnp.array) -> bool: return False -class PairList(PairsBase): +class PairListNsqrd(PairsBase): """ - N^2 pairlist implementation that returns the particle pair ids, displacement vectors, and distances. + A class that implements a simple pair list using JAX that determine which pairs of particles are interacting. + This class can be defined with cutoff (i.e., only returning information about pairs separated by distances + less than the cutoff) or without a cutoff (i.e., information about all possible pairs are returned). + Note, in both cases, distances are calculated using the boundary conditions defined by the simulation Space class + and only unique pairs are returned (i.e., no double counting and no self-interactions). + + This performs an O(N^2) calculation each time the `calculate` function is called and thus will be inefficient + for all but very small system sizes. + + The calculate function will return various pieces of information about the interacting pairs + (e.g., number of neighbors, neighbor ids, distances, displacement vectors) that can be used to calculate the + interaction potential/force. For efficiency of the jitted functions, the `calculate` function array + sizes are fixed. For example, distance has shape (n_particles, n_particles-1), regardless of the number of particles + that are actually neighbors (note: self interactions are removed hence n_particles-1). The `padding_mask` array + returned by `calculate` is used to exclude those pairs that are not interacting. The `padding_mask` contains values + of 1 for interacting particles and 0 for non-interacting. Parameters ---------- space: Space Class that defines how to calculate the displacement between two points and apply the boundary conditions - cutoff: float, default = 2.5 - Cutoff distance for the pair list calculation + cutoff: Optional[unit.Quantity], default = None + Cutoff distance for the pair list calculation. If None, the pair list will be calculated without a cutoff, + applying the boundary conditions as defined in space. + Examples -------- - >>> from chiron.neighbors import PairList, OrthogonalPeriodicSpace - >>> from chiron.states import SamplerState >>> import jax.numpy as jnp + >>> import openmm.unit as unit >>> - >>> space = OrthogonalPeriodicSpace() - >>> pair_list = PairList(space, cutoff=2.5) - >>> sampler_state = SamplerState(x0=jnp.array([[0.0, 0.0, 0.0], [2, 0.0, 0.0], [0.0, 2, 0.0]]), + >>> from chiron.states import SamplerState + >>> sampler_state = SamplerState(positions=jnp.array([[0.0, 0.0, 0.0], [2, 0.0, 0.0], [0.0, 2, 0.0]]), >>> box_vectors=jnp.array([[10, 0.0, 0.0], [0.0, 10, 0.0], [0.0, 0.0, 10]])) + >>> + >>> from chiron.neighbors import PairListNsqrd, OrthogonalPeriodicSpace + >>> pair_list = PairListNsqrd(OrthogonalPeriodicSpace(), cutoff=1.2*unit.nanometer) >>> pair_list.build_from_state(sampler_state) >>> - >>> # mask and distances are of shape (n_particles, n_particles-1), - >>> displacement_vectors of shape (n_particles, n_particles-1, 3) - >>> # mask, is a bool array that is True if the particle is within the cutoff distance, False if it is not - >>> # n_pairs is of shape (n_particles) and is per row sum of the mask. The mask ensure we also do not double count pairs - >>> n_pairs, mask, distances, displacement_vectors = pair_list.calculate(sampler_state.x0) + >>> # n_pairs is of shape (n_particles) and is per row sum of the padding_mask. + >>> # pairs, padding mask and distances are of shape (n_particles, n_particles-1), + >>> # displacement_vectors are of shape (n_particles, n_particles-1, 3) + >>> # padding_mask, is a bool array that is True if the particle is within the cutoff distance, False if it is not + >>> n_pairs, pairs, padding_mask, distances, displacement_vectors = pair_list.calculate(sampler_state.positions) """ def __init__( self, space: Space, - cutoff: unit.Quantity = unit.Quantity(1.2, unit.nanometer), + cutoff: Optional[unit.Quantity] = None, ): + """ + Initialize the PairListNsqrd class + + Parameters + ---------- + space: Space + Class that defines how to calculate the displacement between two points and apply the boundary conditions. + This should not change after initialization. + cutoff: Optional[unit.Quantity], default = None + Cutoff distance for the pair list calculation. If None, the pair list will be calculated without a cutoff. + """ if not isinstance(space, Space): raise TypeError(f"space must be of type Space, found {type(space)}") - if not cutoff.unit.is_compatible(unit.angstrom): - raise ValueError( - f"cutoff must be a unit.Quantity with units of distance, cutoff.unit = {cutoff.unit}" - ) - self.cutoff = cutoff.value_in_unit_system(unit.md_unit_system) + # keeping this public in case we want to change it later + # validation is performed in the setter + self.cutoff = cutoff + self.space = space - # set a a simple variable to know if this has at least been built once as opposed to just initialized - # this does not imply that the neighborlist is up to date + # the init function does not setup the internal arrays we need to use calculate + # this is handled in the `build` function + # this variable can be used to check that the pair list has been built before trying to use it self.is_built = False - # note, we need to use the partial decorator in order to use the jit decorate - # so that it knows to ignore the `self` argument + @property + def cutoff(self): + """ + Cutoff distance for the pair list calculation. If None, the pair list will be calculated without a cutoff. + + Returns + ------- + cutoff: unit.Quantity + Cutoff distance for the pair list calculation. If None, the pair list will be calculated without a cutoff. + """ + return self._cutoff + + @cutoff.setter + def cutoff(self, cutoff): + if cutoff is not None: + if not cutoff.unit.is_compatible(unit.angstrom): + raise ValueError( + f"cutoff must be a unit.Quantity with units of distance, cutoff.unit = {cutoff.unit}" + ) + # Note, since this is just a simple pair list, we do not need to rebuild by changing the cutoff + self._cutoff = cutoff + + # Note, we need to use the partial decorator and declare self as static in order to JIT a function within a class. + # As mentioned in a comment above in the NeighborListNsqrd class, this approach treats internal variables of the + # class as static within this function; e.g., if set self.cutoff = 2, called the function, then changed it to 3, + # the value of self.cutoff in this function would still be 2. Thus, we need to pass any variables that may change + # as arguments, rather than referencing self.variable_name. While we could create a custom pytree instead of + # declaring the class as static (allowing us to reference class variables directly within the JITTED function), + # any changes to those internal variables, say self.cutoff, would mean a change to the hash of any JITTEd function + # that depends on the variable, requiring JAX to recompile the function, which is a slow operation. + # As such, it is also more efficient to just pass variables as arguments. @partial(jax.jit, static_argnums=(0,)) def _pairs_and_mask(self, particle_ids: jnp.array): """ @@ -826,12 +1025,14 @@ def _pairs_and_mask(self, particle_ids: jnp.array): particles_i = jnp.reshape(particle_ids, (particle_ids.shape[0], 1)) # create a mask to exclude self interactions and double counting temp_mask = particles_i != particles_j + # remove self interactions all_pairs = jax.vmap(self._remove_self_interactions, in_axes=(0, 0))( particles_j, temp_mask ) del temp_mask all_pairs = jnp.array(all_pairs[0], dtype=jnp.uint32) + # create the mask that will remove any double counting of pairs reduction_mask = jnp.where(particles_i < all_pairs, True, False) return all_pairs, reduction_mask @@ -844,16 +1045,16 @@ def _remove_self_interactions(self, particles, temp_mask): def build( self, - coordinates: Union[jnp.array, unit.Quantity], + positions: Union[jnp.array, unit.Quantity], box_vectors: Union[jnp.array, unit.Quantity], ): """ - Build the neighborlist from an array of coordinates and box vectors. + Build the list from an array of positions and box vectors. Parameters ---------- - coordinates: jnp.array - Shape[n_particles,3] array of particle coordinates + positions: jnp.array + Shape[n_particles,3] array of particle positions box_vectors: jnp.array Shape[3,3] array of box vectors @@ -863,27 +1064,80 @@ def build( """ - # set our reference coordinates - # this will set self.ref_coordinates=coordinates and self.box_vectors - self._validate_build_inputs(coordinates, box_vectors) + # validate the positions and box vectors + self._validate_build_inputs(positions, box_vectors) - self.n_particles = self.ref_coordinates.shape[0] + self.n_particles = self.ref_positions.shape[0] - # the neighborlist assumes that the box vectors do not change between building and calculating the neighbor list - # changes to the box vectors require rebuilding the neighbor list - self.space.box_vectors = self.box_vectors + # the PairsList assumes that the box vectors do not change between building and calculating the neighbor list # store the ids of all the particles - self.particle_ids = jnp.array(range(0, coordinates.shape[0]), dtype=jnp.uint32) + self.particle_ids = jnp.array(range(0, positions.shape[0]), dtype=jnp.uint32) # calculate which pairs to exclude self.all_pairs, self.reduction_mask = self._pairs_and_mask(self.particle_ids) self.is_built = True - @partial(jax.jit, static_argnums=(0,)) - def _calc_distance_per_particle( - self, particle1, neighbors, neighbor_mask, coordinates + @partial(jax.jit, static_argnums=(0)) + def _calc_distance_per_particle_with_cutoff( + self, particle1, neighbors, neighbor_mask, positions, cutoff, box_vectors + ): + """ + Jitted function to calculate the distance between a particle and all possible neighbors + + Parameters + ---------- + particle1: int + Particle id + neighbors: jnp.array + Array of particle ids for the possible particle pairs of particle1 + neighbor_mask: jnp.array + Mask to exclude double particles to prevent double counting + positions: jnp.array + X,Y,Z positions of all particles, shaped (n_particles, 3) + cutoff: float + Cutoff distance for the interaction. + box_vectors: jnp.array + Box vectors for the system + + Returns + ------- + n_pairs: int + Number of interacting pairs for the particle + mask: jnp.array + Mask to exclude padding particles not within the cutoff particle1. + If a particle is within the interaction cutoff, the mask is 1, otherwise it is 0 + Array has shape (n_particles, n_particles-1) as it excludes self interactions + dist: jnp.array + Array of distances between the particle and all other particles in the system. + Array has shape (n_particles, n_particles-1) as it excludes self interactions + r_ij: jnp.array + Array of displacement vectors between the particle and all other particles in the system. + Array has shape (n_particles, n_particles-1, 3) as it excludes self interactions + . + + """ + # repeat the particle id for each neighbor + particles1 = jnp.repeat(particle1, neighbors.shape[0]) + + # calculate the displacement between particle i and all neighbors + # See note above: if self.space changes, it will not show up here because self is static. + r_ij, dist = self.space.displacement( + positions[particles1], positions[neighbors], box_vectors + ) + # calculate the mask to determine if the particle is a neighbor + # this will be done based on the interaction cutoff and using the neighbor_mask to exclude padding + mask = jnp.where((dist < cutoff) & (neighbor_mask), 1, 0) + + # calculate the number of pairs + n_pairs = mask.sum() + + return n_pairs, mask, dist, r_ij + + @partial(jax.jit, static_argnums=(0)) + def _calc_distance_per_particle_no_cutoff( + self, particle1, neighbors, neighbor_mask, positions, box_vectors ): """ Jitted function to calculate the distance between a particle and all possible neighbors @@ -896,8 +1150,10 @@ def _calc_distance_per_particle( Array of particle ids for the possible particle pairs of particle1 neighbor_mask: jnp.array Mask to exclude double particles to prevent double counting - coordinates: jnp.array - X,Y,Z coordinates of all particles, shaped (n_particles, 3) + positions: jnp.array + X,Y,Z positions of all particles, shaped (n_particles, 3) + box_vectors: jnp.array + Box vectors of the system Returns ------- @@ -914,31 +1170,33 @@ def _calc_distance_per_particle( Array of displacement vectors between the particle and all other particles in the system. Array has shape (n_particles, n_particles-1, 3) as it excludes self interactions + """ # repeat the particle id for each neighbor particles1 = jnp.repeat(particle1, neighbors.shape[0]) # calculate the displacement between particle i and all neighbors + # See note above: if self.space changes, it will not show up here because self is static. r_ij, dist = self.space.displacement( - coordinates[particles1], coordinates[neighbors] + positions[particles1], positions[neighbors], box_vectors ) # calculate the mask to determine if the particle is a neighbor # this will be done based on the interaction cutoff and using the neighbor_mask to exclude padding - mask = jnp.where((dist < self.cutoff) & (neighbor_mask), 1, 0) + mask = jnp.where(neighbor_mask, 1, 0) # calculate the number of pairs n_pairs = mask.sum() return n_pairs, mask, dist, r_ij - def calculate(self, coordinates: jnp.array): + def calculate(self, positions: jnp.array): """ - Calculate the neighbor list for the current state + Calculate the list of neighbor pairs for the current state Parameters ---------- - coordinates: jnp.array - Shape[n_particles,3] array of particle coordinates + positions: jnp.array + Shape[n_particles,3] array of particle positions Returns ------- @@ -953,35 +1211,53 @@ def calculate(self, coordinates: jnp.array): r_ij: jnp.array Array of displacement vectors between particle pairs. Shape: (n_particles, n_particles-1, 3). """ - if coordinates.shape[0] != self.n_particles: + if positions.shape[0] != self.n_particles: raise ValueError( f"Number of particles cannot changes without rebuilding. " - f"Coordinates must have shape ({self.n_particles}, 3), found {coordinates.shape}" + f"Positions must have shape ({self.n_particles}, 3), found {positions.shape}" ) - # coordinates = self.space.wrap(coordinates) - - n_neighbors, padding_mask, dist, r_ij = jax.vmap( - self._calc_distance_per_particle, in_axes=(0, 0, 0, None) - )(self.particle_ids, self.all_pairs, self.reduction_mask, coordinates) - + # if we did not define a cutoff, we will + if self.cutoff is None: + n_neighbors, padding_mask, dist, r_ij = jax.vmap( + self._calc_distance_per_particle_no_cutoff, + in_axes=(0, 0, 0, None, None), + )( + self.particle_ids, + self.all_pairs, + self.reduction_mask, + positions, + self.box_vectors, + ) + else: + n_neighbors, padding_mask, dist, r_ij = jax.vmap( + self._calc_distance_per_particle_with_cutoff, + in_axes=(0, 0, 0, None, None, None), + )( + self.particle_ids, + self.all_pairs, + self.reduction_mask, + positions, + self.cutoff.value_in_unit_system(unit.md_unit_system), + self.box_vectors, + ) return n_neighbors, self.all_pairs, padding_mask, dist, r_ij - def check(self, coordinates: jnp.array) -> bool: + def check(self, positions: jnp.array) -> bool: """ Check if we need to reconstruct internal arrays. For a simple pairlist this will always return False, unless the number of particles change. Parameters ---------- - coordinates: jnp.array - Array of particle coordinates + positions: jnp.array + Array of particle positions Returns ------- bool True if we need to rebuild the neighbor list, False if we do not. """ - if coordinates.shape[0] != self.n_particles: + if positions.shape[0] != self.n_particles: return True else: return False diff --git a/chiron/potential.py b/chiron/potential.py index 6d9b415..1d2d340 100644 --- a/chiron/potential.py +++ b/chiron/potential.py @@ -63,6 +63,70 @@ def compute_pairlist(self, positions, cutoff) -> jnp.array: return distance[interacting_mask], displacement_vectors[interacting_mask], pairs +class IdealGasPotential(NeuralNetworkPotential): + def __init__( + self, + topology: Topology, + ): + """ + Initialize the Ideal Gas potential. + + Parameters + ---------- + topology : Topology + The topology of the system + + """ + + if not isinstance(topology, (Topology, property)) and topology is not None: + raise TypeError( + f"Topology must be a Topology object, a property, or None, got type(topology) = {type(topology)}" + ) + + self.topology = topology + + def compute_energy(self, positions: jnp.array, nbr_list=None, debug_mode=False): + """ + Compute the energy for an ideal gas, which is always 0. + + Parameters + ---------- + positions : jnp.array + The positions of the particles in the system + nbr_list : NeighborList, default=None + Instance of a neighbor list or pair list class to use. + If None, an unoptimized N^2 pairlist will be used without PBC conditions. + Returns + ------- + potential_energy : float + The total potential energy of the system. + + """ + # Compute the pair distances and displacement vectors + + return 0.0 + + def compute_force(self, positions: jnp.array, nbr_list=None) -> jnp.array: + """ + Compute the force for ideal gas particles, which is always 0. + + Parameters + ---------- + positions : jnp.array + The positions of the particles in the system + nbr_list : NeighborList, optional + Instance of the neighborlist class to use. By default, set to None, which will use an N^2 pairlist + + Returns + ------- + force : jnp.array + The forces on the particles in the system + + """ + + return 0.0 + + class LJPotential(NeuralNetworkPotential): def __init__( self, @@ -200,7 +264,7 @@ def compute_energy(self, positions: jnp.array, nbr_list=None, debug_mode=False): raise ValueError("Neighborlist must be built before use") # ensure that the cutoff in the neighbor list is the same as the cutoff in the potential - if nbr_list.cutoff != self.cutoff: + if nbr_list.cutoff.value_in_unit_system(unit.md_unit_system) != self.cutoff: raise ValueError( f"Neighborlist cutoff ({nbr_list.cutoff}) must be the same as the potential cutoff ({self.cutoff})" ) @@ -285,7 +349,7 @@ def __init__( The topology object representing the molecular system. k : unit.Quantity, optional The spring constant of the harmonic potential. Default is 1.0 kcal/mol/Å^2. - x0 : unit.Quantity, optional + positions : unit.Quantity, optional The equilibrium position of the harmonic potential. Default is [0.0,0.0,0.0] Å. U0 : unit.Quantity, optional The offset potential energy of the harmonic potential. Default is 0.0 kcal/mol. @@ -302,7 +366,9 @@ def __init__( if not isinstance(k, unit.Quantity): raise TypeError(f"k must be a unit.Quantity, type(k) = {type(k)}") if not isinstance(x0, unit.Quantity): - raise TypeError(f"x0 must be a unit.Quantity, type(x0) = {type(x0)}") + raise TypeError( + f"positions must be a unit.Quantity, type(positions) = {type(x0)}" + ) if not isinstance(U0, unit.Quantity): raise TypeError(f"U0 must be a unit.Quantity, type(U0) = {type(U0)}") @@ -312,9 +378,11 @@ def __init__( ) if not x0.unit.is_compatible(unit.angstrom): raise ValueError( - f"x0 must be a unit.Quantity with units of distance, x0.unit = {x0.unit}" + f"positions must be a unit.Quantity with units of distance, positions.unit = {x0.unit}" ) - assert x0.shape[1] == 3, f"x0 must be a NX3 vector, x0.shape = {x0.shape}" + assert ( + x0.shape[1] == 3 + ), f"positions must be a NX3 vector, positions.shape = {x0.shape}" if not U0.unit.is_compatible(unit.kilocalories_per_mole): raise ValueError( f"U0 must be a unit.Quantity with units of energy, U0.unit = {U0.unit}" @@ -324,9 +392,11 @@ def __init__( log.debug("Initializing HarmonicOscillatorPotential") log.debug(f"k = {k}") - log.debug(f"x0 = {x0}") + log.debug(f"positions = {x0}") log.debug(f"U0 = {U0}") - log.debug("Energy is calculate: U(x) = (K/2) * ( (x-x0)^2 + y^2 + z^2 ) + U0") + log.debug( + "Energy is calculate: U(x) = (K/2) * ( (x-positions)^2 + y^2 + z^2 ) + U0" + ) self.k = jnp.array( k.value_in_unit_system(unit.md_unit_system) ) # spring constant @@ -339,7 +409,7 @@ def __init__( self.topology = topology def compute_energy(self, positions: jnp.array, nbr_list=None): - # the functional form is given by U(x) = (K/2) * ( (x-x0)^2 + y^2 + z^2 ) + U0 + # the functional form is given by U(x) = (K/2) * ( (x-positions)^2 + y^2 + z^2 ) + U0 # https://github.com/choderalab/openmmtools/blob/main/openmmtools/testsystems.py#L695 # compute the displacement vectors diff --git a/chiron/reporters.py b/chiron/reporters.py index 27457a6..156e86a 100644 --- a/chiron/reporters.py +++ b/chiron/reporters.py @@ -367,7 +367,7 @@ def _write_to_trajectory(self, positions: np.ndarray) -> None: file_handler=self._write_xtc_file_handle, positions=positions, iteration=self.buffer.get("step"), - box_vecotrs=self.buffer.get("box_vectors"), + box_vectors=self.buffer.get("box_vectors"), ) def read_from_trajectory(self) -> np.ndarray: @@ -409,7 +409,7 @@ def _write_to_xtc( file_handler: md.formats.XTCTrajectoryFile, positions: np.ndarray, iteration: np.ndarray, - box_vecotrs: Optional[np.ndarray] = None, + box_vectors: Optional[np.ndarray] = None, ): """ Write position data to an XTC file. @@ -428,5 +428,5 @@ def _write_to_xtc( file_handler.write( positions, time=iteration, - box=box_vecotrs, + box=box_vectors, ) diff --git a/chiron/states.py b/chiron/states.py index 99459ae..6ebdca1 100644 --- a/chiron/states.py +++ b/chiron/states.py @@ -11,26 +11,40 @@ class SamplerState: Parameters ---------- - x0 : unit.Quantity + positions : unit.Quantity The current positions of the particles in the simulation. velocities : unit.Quantity, optional The velocities of the particles in the simulation. box_vectors : unit.Quantity, optional The box vectors defining the simulation's periodic boundary conditions. + Examples + -------- + + from chiron.states import SamplerState + from chiron.utils import PRNG + from openmmtools.testsystems import HarmonicOscillator + + ho = HarmonicOscillator() + PRNG.set_seed(1234) + + sampler_state = SamplerState(positions = ho.positions, PRNG.get_random_key()) + """ def __init__( self, - x0: unit.Quantity, + positions: unit.Quantity, current_PRNG_key: random.PRNGKey, velocities: Optional[unit.Quantity] = None, box_vectors: Optional[unit.Quantity] = None, ) -> None: # NOTE: all units are internally in the openMM units system as documented here: # http://docs.openmm.org/latest/userguide/theory/01_introduction.html#units - if not isinstance(x0, unit.Quantity): - raise TypeError(f"x0 must be a unit.Quantity, got {type(x0)} instead.") + if not isinstance(positions, unit.Quantity): + raise TypeError( + f"positions must be a unit.Quantity, got {type(positions)} instead." + ) if velocities is not None and not isinstance(velocities, unit.Quantity): raise TypeError( f"velocities must be a unit.Quantity, got {type(velocities)} instead." @@ -45,8 +59,10 @@ def __init__( raise TypeError( f"box_vectors must be a unit.Quantity or openMM box, got {type(box_vectors)} instead." ) - if not x0.unit.is_compatible(unit.nanometer): - raise ValueError(f"x0 must have units of distance, got {x0.unit} instead.") + if not positions.unit.is_compatible(unit.nanometer): + raise ValueError( + f"positions must have units of distance, got {positions.unit} instead." + ) if velocities is not None and not velocities.unit.is_compatible( unit.nanometer / unit.picosecond ): @@ -63,26 +79,27 @@ def __init__( raise ValueError( f"box_vectors must be a 3x3 array, got {box_vectors.shape} instead." ) - if velocities is not None and x0.shape != velocities.shape: + if velocities is not None and positions.shape != velocities.shape: raise ValueError( - f"x0 and velocities must have the same shape, got {x0.shape} and {velocities.shape} instead." + f"positions and velocities must have the same shape, got {positions.shape} and {velocities.shape} instead." ) if current_PRNG_key is None: raise ValueError(f"random_seed must be set.") - self._x0 = x0 + self._positions = positions self._velocities = velocities self._current_PRNG_key = current_PRNG_key self._box_vectors = box_vectors self._distance_unit = unit.nanometer + self._time_unit = unit.picosecond @property def n_particles(self) -> int: - return self._x0.shape[0] + return self._positions.shape[0] @property - def x0(self) -> jnp.array: - return self._convert_to_jnp(self._x0) + def positions(self) -> jnp.array: + return self._convert_to_jnp(self._positions) @property def velocities(self) -> jnp.array: @@ -96,17 +113,40 @@ def box_vectors(self) -> jnp.array: return None return self._convert_to_jnp(self._box_vectors) - @x0.setter - def x0(self, x0: Union[jnp.array, unit.Quantity]) -> None: + @positions.setter + def positions(self, x0: Union[jnp.array, unit.Quantity]) -> None: if isinstance(x0, unit.Quantity): - self._x0 = x0 + self._positions = x0 + else: + self._positions = unit.Quantity(x0, self._distance_unit) + + @box_vectors.setter + def box_vectors(self, box_vectors: Union[jnp.array, unit.Quantity]) -> None: + if isinstance(box_vectors, unit.Quantity): + self._box_vectors = box_vectors else: - self._x0 = unit.Quantity(x0, self._distance_unit) + self._box_vectors = unit.Quantity(box_vectors, self._distance_unit) + + @velocities.setter + def velocities(self, velocities: Union[jnp.array, unit.Quantity]) -> None: + if velocities.shape != self._positions.shape: + raise ValueError( + f"velocities must have the same shape as positions, got {velocities.shape} and {self._positions.shape} instead." + ) + if isinstance(velocities, unit.Quantity): + self._velocities = velocities + else: + self._velocities = unit.Quantity( + velocities, self._distance_unit / self._time_unit + ) @property def distance_unit(self) -> unit.Unit: return self._distance_unit + def velocity_unit(self) -> unit.Unit: + return self._distance_unit / self._time_unit + @property def new_PRNG_key(self) -> random.PRNGKey: key, subkey = random.split(self._current_PRNG_key) @@ -201,7 +241,7 @@ def __init__( from .utils import get_nr_of_particles self.nr_of_particles = get_nr_of_particles(self.potential.topology) - self._check_completness() + self._check_completeness() def check_variables(self) -> None: """ @@ -215,7 +255,7 @@ def check_variables(self) -> None: set_variables = [var for var in variables if getattr(self, var) is not None] return set_variables - def _check_completness(self): + def _check_completeness(self): # check which variables are set set_variables = self.check_variables() from loguru import logger as log @@ -242,7 +282,7 @@ def get_reduced_potential( ---------- sampler_state : SamplerState The sampler state for which to compute the reduced potential. - nbr_list : NeighborList or PairList, optional + nbr_list : NeighborList or PairListNsqrd, optional The neighbor list or pair list routine to use for calculating the reduced potential. Returns @@ -263,17 +303,25 @@ def get_reduced_potential( self.beta = 1.0 / ( unit.BOLTZMANN_CONSTANT_kB * (self.temperature * unit.kelvin) ) - # log.debug(f"sample state: {sampler_state.x0}") + # log.debug(f"sample state: {sampler_state.positions}") reduced_potential = ( unit.Quantity( - self.potential.compute_energy(sampler_state.x0, nbr_list), + self.potential.compute_energy(sampler_state.positions, nbr_list), unit.kilojoule_per_mole, ) ) / unit.AVOGADRO_CONSTANT_NA # log.debug(f"reduced potential: {reduced_potential}") if self.pressure is not None: - reduced_potential += self.pressure * self.volume + self.volume = ( + sampler_state.box_vectors[0][0] + * sampler_state.box_vectors[1][1] + * sampler_state.box_vectors[2][2] + ) * unit.nanometer**3 + from loguru import logger as log + + reduced_potential += self.pressure * self.volume + # add chemical potential return self.beta * reduced_potential def kT_to_kJ_per_mol(self, energy): @@ -295,7 +343,7 @@ def calculate_reduced_potential_at_states( The sampler state for which to compute the reduced potential. thermodynamic_states : list of ThermodynamicState The thermodynamic states for which to compute the reduced potential. - nbr_list : NeighborList or PairList, optional + nbr_list : NeighborList or PairListNsqrd, optional Returns ------- list of float diff --git a/chiron/tests/conftest.py b/chiron/tests/conftest.py index 742d5ab..9f1cca6 100644 --- a/chiron/tests/conftest.py +++ b/chiron/tests/conftest.py @@ -29,8 +29,8 @@ def provide_testsystems_and_potentials(): import jax.numpy as jnp hoa_potential = HarmonicOscillatorPotential( - ho.topology, - ho.K, + hoa.topology, + hoa.K, x0=unit.Quantity( jnp.array( [ @@ -54,5 +54,3 @@ def provide_testsystems_and_potentials(): (hoa, hoa_potential), ] return TESTSYSTEM_AND_POTENTIAL - - diff --git a/chiron/tests/test_convergence_tests.py b/chiron/tests/test_convergence_tests.py index 16cfad3..df30515 100644 --- a/chiron/tests/test_convergence_tests.py +++ b/chiron/tests/test_convergence_tests.py @@ -44,9 +44,17 @@ def test_convergence_of_MC_estimator(prep_temp_dir): from chiron.states import ThermodynamicState, SamplerState thermodynamic_state = ThermodynamicState( - harmonic_potential, temperature=300, volume=30 * (unit.angstrom**3) + harmonic_potential, + temperature=300 * unit.kelvin, + volume=30 * (unit.angstrom**3), + ) + from chiron.utils import PRNG + + PRNG.set_seed(1234) + + sampler_state = SamplerState( + positions=ho.positions, current_PRNG_key=PRNG.get_random_key() ) - sampler_state = SamplerState(ho.positions) from chiron.reporters import _SimulationReporter @@ -55,16 +63,16 @@ def test_convergence_of_MC_estimator(prep_temp_dir): simulation_reporter = _SimulationReporter(f"{prep_temp_dir}/test_{id}.h5") # Initalize the move set (here only LangevinDynamicsMove) - from chiron.mcmc import MetropolisDisplacementMove, MoveSchedule, MCMCSampler + from chiron.mcmc import MonteCarloDisplacementMove, MoveSchedule, MCMCSampler - mc_displacement_move = MetropolisDisplacementMove( - nr_of_moves=100_000, + mc_displacement_move = MonteCarloDisplacementMove( + number_of_moves=1_000, displacement_sigma=0.5 * unit.angstrom, atom_subset=[0], reporter=simulation_reporter, ) - move_set = MoveSchedule([("MetropolisDisplacementMove", mc_displacement_move)]) + move_set = MoveSchedule([("MonteCarloDisplacementMove", mc_displacement_move)]) # Initalize the sampler sampler = MCMCSampler(move_set, sampler_state, thermodynamic_state) @@ -82,7 +90,9 @@ def test_convergence_of_MC_estimator(prep_temp_dir): plt.plot(chiron_energy) print("Expectation values generated with chiron") - es = chiron_energy + import jax.numpy as jnp + + es = jnp.array(chiron_energy) print(es.mean(), es.std()) print("Expectation values from openmmtools") @@ -133,12 +143,16 @@ def test_langevin_dynamics_with_LJ_fluid(prep_temp_dir): ) print(lj_fluid.system.getDefaultPeriodicBoxVectors()) + from chiron.utils import PRNG + + PRNG.set_seed(1234) sampler_state = SamplerState( - x0=lj_fluid.positions, + positions=lj_fluid.positions, box_vectors=lj_fluid.system.getDefaultPeriodicBoxVectors(), + current_PRNG_key=PRNG.get_random_key(), ) - print(sampler_state.x0.shape) + print(sampler_state.positions.shape) print(sampler_state.box_vectors) nbr_list = NeighborListNsqrd( @@ -153,16 +167,137 @@ def test_langevin_dynamics_with_LJ_fluid(prep_temp_dir): potential=lj_potential, temperature=300 * unit.kelvin ) - from chiron.reporters import _SimulationReporter + from chiron.reporters import LangevinDynamicsReporter id = uuid.uuid4() - reporter = _SimulationReporter(f"{prep_temp_dir}/test_{id}.h5") + reporter = LangevinDynamicsReporter(f"{prep_temp_dir}/test_{id}.h5") - integrator = LangevinIntegrator(reporter=reporter, report_frequency=100) + integrator = LangevinIntegrator(reporter=reporter, report_interval=100) integrator.run( sampler_state, thermodynamic_state, - n_steps=2000, + number_of_steps=1000, nbr_list=nbr_list, progress_bar=True, ) + + +@pytest.mark.skip(reason="Tests takes too long") +@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Test takes too long.") +def test_ideal_gas(prep_temp_dir): + from openmmtools.testsystems import IdealGas + from openmm import unit + + n_particles = 216 + temperature = 298 * unit.kelvin + pressure = 1 * unit.atmosphere + mass = unit.Quantity(39.9, unit.gram / unit.mole) + + ideal_gas = IdealGas( + nparticles=n_particles, temperature=temperature, pressure=pressure + ) + + from chiron.potential import IdealGasPotential + from chiron.utils import PRNG + import jax.numpy as jnp + + # + cutoff = 0.0 * unit.nanometer + ideal_gas_potential = IdealGasPotential(ideal_gas.topology) + + from chiron.states import SamplerState, ThermodynamicState + + # define the thermodynamic state + thermodynamic_state = ThermodynamicState( + potential=ideal_gas_potential, + temperature=temperature, + pressure=pressure, + ) + + PRNG.set_seed(1234) + + # define the sampler state + sampler_state = SamplerState( + positions=ideal_gas.positions, + current_PRNG_key=PRNG.get_random_key(), + box_vectors=ideal_gas.system.getDefaultPeriodicBoxVectors(), + ) + + from chiron.neighbors import PairListNsqrd, OrthogonalPeriodicSpace + + # define the pair list for an orthogonal periodic space + # since particles are non-interacting, this will not really do much + # but will appropriately wrap particles in space + nbr_list = PairListNsqrd(OrthogonalPeriodicSpace(), cutoff=cutoff) + nbr_list.build_from_state(sampler_state) + + from chiron.reporters import MCReporter + + # initialize a reporter to save the simulation data + filename = "test_mc_ideal_gas.h5" + import os + + if os.path.isfile(filename): + os.remove(filename) + reporter = MCReporter(filename, 1) + + from chiron.mcmc import ( + MonteCarloDisplacementMove, + MonteCarloBarostatMove, + MoveSchedule, + MCMCSampler, + ) + + mc_displacement_move = MonteCarloDisplacementMove( + displacement_sigma=0.1 * unit.nanometer, + number_of_moves=10, + reporter=reporter, + autotune=True, + autotune_interval=100, + ) + + mc_barostat_move = MonteCarloBarostatMove( + volume_max_scale=0.2, + number_of_moves=100, + reporter=reporter, + autotune=True, + autotune_interval=100, + ) + move_set = MoveSchedule( + [ + ("MonteCarloDisplacementMove", mc_displacement_move), + ("MonteCarloBarostatMove", mc_barostat_move), + ] + ) + + sampler = MCMCSampler(move_set) + sampler.run( + sampler_state, thermodynamic_state, n_iterations=10, nbr_list=nbr_list + ) # how many times to repeat + + volume = reporter.get_property("volume") + + # get expectations + ideal_volume = ideal_gas.get_volume_expectation(thermodynamic_state) + ideal_volume_std = ideal_gas.get_volume_standard_deviation(thermodynamic_state) + + print(ideal_volume, ideal_volume_std) + + volume_mean = jnp.mean(jnp.array(volume)) * unit.nanometer**3 + volume_std = jnp.std(jnp.array(volume)) * unit.nanometer**3 + + print(volume_mean, volume_std) + + ideal_density = mass * n_particles / unit.AVOGADRO_CONSTANT_NA / ideal_volume + measured_density = mass * n_particles / unit.AVOGADRO_CONSTANT_NA / volume_mean + + assert jnp.isclose( + ideal_density.value_in_unit(unit.kilogram / unit.meter**3), + measured_density.value_in_unit(unit.kilogram / unit.meter**3), + atol=1e-1, + ) + # see if within 5% of ideal volume + assert abs(ideal_volume - volume_mean) / ideal_volume < 0.05 + + # see if within 10% of the ideal standard deviation of the volume + assert abs(ideal_volume_std - volume_std) / ideal_volume_std < 0.1 diff --git a/chiron/tests/test_integrators.py b/chiron/tests/test_integrators.py index 341c987..9ed5b4e 100644 --- a/chiron/tests/test_integrators.py +++ b/chiron/tests/test_integrators.py @@ -42,10 +42,12 @@ def test_langevin_dynamics(prep_temp_dir, provide_testsystems_and_potentials): reporter = LangevinDynamicsReporter() - integrator = LangevinIntegrator(reporter=reporter, report_frequency=1) - integrator.run( + integrator = LangevinIntegrator( + reporter=reporter, report_interval=1, refresh_velocities=True + ) + updated_sampler_state, updated_nbr_list = integrator.run( sampler_state, thermodynamic_state, - n_steps=20, + number_of_steps=20, ) i = i + 1 diff --git a/chiron/tests/test_mcmc.py b/chiron/tests/test_mcmc.py index 82ef9a0..bef08ba 100644 --- a/chiron/tests/test_mcmc.py +++ b/chiron/tests/test_mcmc.py @@ -41,7 +41,7 @@ def test_sample_from_harmonic_osciallator(prep_temp_dir): PRNG.set_seed(1234) sampler_state = SamplerState( - x0=ho.positions, current_PRNG_key=PRNG.get_random_key() + positions=ho.positions, current_PRNG_key=PRNG.get_random_key() ) from chiron.integrators import LangevinIntegrator @@ -53,13 +53,16 @@ def test_sample_from_harmonic_osciallator(prep_temp_dir): reporter = LangevinDynamicsReporter() integrator = LangevinIntegrator( - stepsize=2 * unit.femtosecond, reporter=reporter, report_frequency=1 + timestep=2 * unit.femtosecond, + reporter=reporter, + report_interval=1, + refresh_velocities=True, ) integrator.run( sampler_state, thermodynamic_state, - n_steps=5, + number_of_steps=5, ) integrator.reporter.flush_buffer() import jax.numpy as jnp @@ -122,7 +125,32 @@ def test_sample_from_harmonic_osciallator_with_MCMC_classes_and_LangevinDynamics BaseReporter.set_directory(prep_temp_dir) simulation_reporter = LangevinDynamicsReporter(1) - langevin_move = LangevinDynamicsMove(nr_of_steps=10, reporter=simulation_reporter) + + # the following will reinitialize the velocities for each iteration + langevin_move = LangevinDynamicsMove( + number_of_steps=10, refresh_velocities=True, reporter=simulation_reporter + ) + + move_set = MoveSchedule([("LangevinMove", langevin_move)]) + + # Initalize the sampler + sampler = MCMCSampler(move_set) + + # Run the sampler with the thermodynamic state and sampler state and return the sampler state + sampler.run( + sampler_state, thermodynamic_state, n_iterations=2 + ) # how many times to repeat + + # the following will use the initialize velocities function + from chiron.utils import initialize_velocities + + sampler_state.velocities = initialize_velocities( + thermodynamic_state.temperature, ho.topology, sampler_state._current_PRNG_key + ) + + langevin_move = LangevinDynamicsMove( + number_of_steps=10, refresh_velocities=False, reporter=simulation_reporter + ) move_set = MoveSchedule([("LangevinMove", langevin_move)]) @@ -146,7 +174,7 @@ def test_sample_from_harmonic_osciallator_with_MCMC_classes_and_MetropolisDispla """ from openmm import unit from chiron.potential import HarmonicOscillatorPotential - from chiron.mcmc import MetropolisDisplacementMove, MoveSchedule, MCMCSampler + from chiron.mcmc import MonteCarloDisplacementMove, MoveSchedule, MCMCSampler # Initalize the testsystem from openmmtools.testsystems import HarmonicOscillator @@ -178,14 +206,14 @@ def test_sample_from_harmonic_osciallator_with_MCMC_classes_and_MetropolisDispla BaseReporter.set_directory(wd) simulation_reporter = MCReporter(1) - mc_displacement_move = MetropolisDisplacementMove( - nr_of_moves=10, + mc_displacement_move = MonteCarloDisplacementMove( + number_of_moves=10, displacement_sigma=0.1 * unit.angstrom, atom_subset=[0], reporter=simulation_reporter, ) - move_set = MoveSchedule([("MetropolisDisplacementMove", mc_displacement_move)]) + move_set = MoveSchedule([("MonteCarloDisplacementMove", mc_displacement_move)]) # Initalize the sampler sampler = MCMCSampler(move_set) @@ -206,7 +234,7 @@ def test_sample_from_harmonic_osciallator_array_with_MCMC_classes_and_Metropolis sampler states, and uses the Metropolis displacement move in an MCMC sampling scheme. """ from openmm import unit - from chiron.mcmc import MetropolisDisplacementMove, MoveSchedule, MCMCSampler + from chiron.mcmc import MonteCarloDisplacementMove, MoveSchedule, MCMCSampler # Initalize the testsystem from openmmtools.testsystems import HarmonicOscillatorArray @@ -240,14 +268,14 @@ def test_sample_from_harmonic_osciallator_array_with_MCMC_classes_and_Metropolis simulation_reporter = MCReporter(1) - mc_displacement_move = MetropolisDisplacementMove( - nr_of_moves=10, + mc_displacement_move = MonteCarloDisplacementMove( + number_of_moves=10, displacement_sigma=0.1 * unit.angstrom, atom_subset=None, reporter=simulation_reporter, ) - move_set = MoveSchedule([("MetropolisDisplacementMove", mc_displacement_move)]) + move_set = MoveSchedule([("MonteCarloDisplacementMove", mc_displacement_move)]) # Initalize the sampler sampler = MCMCSampler(move_set) @@ -296,6 +324,134 @@ def test_thermodynamic_state_inputs(): ThermodynamicState(potential=harmonic_potential, pressure=100 * unit.atmosphere) +def test_mc_barostat_parameter_setting(): + import jax.numpy as jnp + from chiron.mcmc import MonteCarloBarostatMove + + barostat_move = MonteCarloBarostatMove( + volume_max_scale=0.1, + number_of_moves=1, + ) + + assert barostat_move.volume_max_scale == 0.1 + assert barostat_move.number_of_moves == 1 + + +def test_mc_barostat(prep_temp_dir): + import jax.numpy as jnp + + from chiron.reporters import MCReporter, BaseReporter + + wd = prep_temp_dir.join(f"_test_{uuid.uuid4()}") + BaseReporter.set_directory(wd) + simulation_reporter = MCReporter(1) + + from chiron.mcmc import MonteCarloBarostatMove + + barostat_move = MonteCarloBarostatMove( + volume_max_scale=0.1, + number_of_moves=10, + reporter=simulation_reporter, + report_interval=1, + ) + + from chiron.potential import IdealGasPotential + from openmm import unit + + positions = ( + jnp.array( + [ + [0, 0, 0], + [1, 0, 0], + [0, 1, 0], + [0, 0, 1], + [1, 1, 0], + [1, 0, 1], + [0, 1, 1], + [1, 1, 1], + ] + ) + * unit.nanometer + ) + box_vectors = ( + jnp.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) + * unit.nanometer + ) + volume = box_vectors[0][0] * box_vectors[1][1] * box_vectors[2][2] + + from openmm.app import Topology, Element + + topology = Topology() + element = Element.getBySymbol("Ar") + chain = topology.addChain() + residue = topology.addResidue("system", chain) + for i in range(positions.shape[0]): + topology.addAtom("Ar", element, residue) + + ideal_gas_potential = IdealGasPotential(topology) + + from chiron.states import SamplerState, ThermodynamicState + from chiron.utils import PRNG + + PRNG.set_seed(1234) + + # define the sampler state + sampler_state = SamplerState( + positions=positions, + box_vectors=box_vectors, + current_PRNG_key=PRNG.get_random_key(), + ) + + # define the thermodynamic state + thermodynamic_state = ThermodynamicState( + potential=ideal_gas_potential, + temperature=300 * unit.kelvin, + pressure=1.0 * unit.atmosphere, + ) + + from chiron.neighbors import PairListNsqrd, OrthogonalPeriodicSpace + + # since particles are non-interacting and we will not displacece them, the pair list basically + # does nothing in this case. + nbr_list = PairListNsqrd(OrthogonalPeriodicSpace(), cutoff=0 * unit.nanometer) + + sampler_state, thermodynamic_state, nbr_list = barostat_move.update( + sampler_state, thermodynamic_state, nbr_list + ) + potential_energies = simulation_reporter.get_property("potential_energy") + volumes = simulation_reporter.get_property("volume") + + # ideal gas treatment, so stored energy will only be a + # consequence of pressure, volume, and temperature + from loguru import logger as log + + log.debug(f"PE {potential_energies * unit.kilojoules_per_mole}") + log.debug(thermodynamic_state.pressure) + log.debug(thermodynamic_state.beta) + log.debug(volumes) + log.debug(volumes * unit.nanometer**3) + + # assert that the PE is always zero + assert potential_energies[0] == 0 + assert potential_energies[-1] == 0 + + # the reduced potential will only be a consequence of the pressure, volume, and temperature + + assert jnp.isclose( + thermodynamic_state.get_reduced_potential(sampler_state), + ( + thermodynamic_state.pressure + * thermodynamic_state.beta + * (volumes[-1] * unit.nanometer**3) + ), + 1e-3, + ) + + print(barostat_move.statistics["n_accepted"]) + assert barostat_move.statistics["n_proposed"] == 10 + assert barostat_move.statistics["n_accepted"] == 8 + + def test_sample_from_joint_distribution_of_two_HO_with_local_moves_and_MC_updates(): # define two harmonic oscillators with different spring constants and equilibrium positions # sample from the joint distribution of the two HO using local langevin moves diff --git a/chiron/tests/test_minization.py b/chiron/tests/test_minization.py index cf0cf4e..4faa41d 100644 --- a/chiron/tests/test_minization.py +++ b/chiron/tests/test_minization.py @@ -3,7 +3,7 @@ def test_minimization(): import jax.numpy as jnp from chiron.states import SamplerState - from chiron.neighbors import PairList, OrthogonalPeriodicSpace + from chiron.neighbors import PairListNsqrd, OrthogonalPeriodicSpace from openmm import unit # initialize testystem @@ -25,12 +25,14 @@ def test_minimization(): box_vectors=lj_fluid.system.getDefaultPeriodicBoxVectors(), ) # use parilist - nbr_list = PairList(OrthogonalPeriodicSpace(), cutoff=cutoff) + nbr_list = PairListNsqrd(OrthogonalPeriodicSpace(), cutoff=cutoff) nbr_list.build_from_state(sampler_state) # compute intial energy with and without pairlist - initial_e_with_nbr_list = lj_potential.compute_energy(sampler_state.x0, nbr_list) - initial_e_without_nbr_list = lj_potential.compute_energy(sampler_state.x0) + initial_e_with_nbr_list = lj_potential.compute_energy( + sampler_state.positions, nbr_list + ) + initial_e_without_nbr_list = lj_potential.compute_energy(sampler_state.positions) print(f"initial_e_with_nbr_list: {initial_e_with_nbr_list}") print(f"initial_e_without_nbr_list: {initial_e_without_nbr_list}") assert not jnp.isclose( @@ -38,7 +40,7 @@ def test_minimization(): ), "initial_e_with_nbr_list and initial_e_without_nbr_list should not be close" # minimize energy for 0 steps results = minimize_energy( - sampler_state.x0, lj_potential.compute_energy, nbr_list, maxiter=0 + sampler_state.positions, lj_potential.compute_energy, nbr_list, maxiter=0 ) # check that the minimization did not change the energy @@ -48,7 +50,7 @@ def test_minimization(): min_x, nbr_list ) after_0_steps_minimization_e_without_nbr_list = lj_potential.compute_energy( - sampler_state.x0 + sampler_state.positions ) print( f"after_0_steps_minimization_e_with_nbr_list: {after_0_steps_minimization_e_with_nbr_list}" @@ -67,7 +69,7 @@ def test_minimization(): # after 100 steps of minimization steps = 100 results = minimize_energy( - sampler_state.x0, lj_potential.compute_energy, nbr_list, maxiter=steps + sampler_state.positions, lj_potential.compute_energy, nbr_list, maxiter=steps ) min_x = results.params e_min = lj_potential.compute_energy(min_x, nbr_list) @@ -86,7 +88,7 @@ def test_minimize_two_particles(): import jax.numpy as jnp from chiron.states import SamplerState - from chiron.neighbors import PairList, OrthogonalPeriodicSpace + from chiron.neighbors import PairListNsqrd, OrthogonalPeriodicSpace from openmm import unit from chiron.potential import LJPotential @@ -103,13 +105,13 @@ def test_minimize_two_particles(): # define the sampler state sampler_state = SamplerState( - x0=coordinates * unit.nanometer, + positions=coordinates * unit.nanometer, current_PRNG_key=PRNG.get_random_key(), box_vectors=jnp.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) * unit.nanometer, ) - pair_list = PairList(OrthogonalPeriodicSpace(), cutoff=cutoff) + pair_list = PairListNsqrd(OrthogonalPeriodicSpace(), cutoff=cutoff) pair_list.build_from_state(sampler_state) e_start = lj_potential.compute_energy(coordinates, pair_list) diff --git a/chiron/tests/test_multistate.py b/chiron/tests/test_multistate.py index 0f153be..533d879 100644 --- a/chiron/tests/test_multistate.py +++ b/chiron/tests/test_multistate.py @@ -25,7 +25,9 @@ def setup_sampler() -> Tuple[NeighborListNsqrd, MultiStateSampler]: OrthogonalPeriodicSpace(), cutoff=cutoff, skin=skin, n_max_neighbors=180 ) - lang_move = LangevinDynamicsMove(stepsize=1.0 * unit.femtoseconds, nr_of_steps=100) + lang_move = LangevinDynamicsMove( + timestep=1.0 * unit.femtoseconds, number_of_steps=100 + ) BaseReporter.set_directory("multistate_test") reporter = MultistateReporter() reporter.reset_reporter_file() @@ -183,21 +185,24 @@ def test_multistate_minimize(ho_multistate_sampler_multiple_minima: MultiStateSa ho_multistate_sampler_multiple_minima.minimize() assert np.allclose( - ho_multistate_sampler_multiple_minima.sampler_states[0].x0, + ho_multistate_sampler_multiple_minima.sampler_states[0].positions, np.array([[0.0, 0.0, 0.0]]), ) assert np.allclose( - ho_multistate_sampler_multiple_minima.sampler_states[1].x0, + ho_multistate_sampler_multiple_minima.sampler_states[1].positions, np.array([[0.05, 0.0, 0.0]]), atol=1e-2, ) assert np.allclose( - ho_multistate_sampler_multiple_minima.sampler_states[2].x0, + ho_multistate_sampler_multiple_minima.sampler_states[2].positions, np.array([[0.1, 0.0, 0.0]]), atol=1e-2, ) +@pytest.mark.skip( + reason="Multistate code still needs to be modified in the multistage branch" +) def test_multistate_run(ho_multistate_sampler_multiple_ks: MultiStateSampler): """ Test function for running the multistate sampler. @@ -221,12 +226,13 @@ def test_multistate_run(ho_multistate_sampler_multiple_ks: MultiStateSampler): ho_sampler.run(n_iteratinos) # check that we have the correct number of iterations, replicas and states - assert ho_sampler.iteration == n_iteratinos - assert ho_sampler._iteration == n_iteratinos + assert ho_sampler.iteration == n_iterations + assert ho_sampler._iteration == n_iterations assert ho_sampler.n_replicas == 4 assert ho_sampler.n_states == 4 u_kn = ho_sampler._reporter.get_property("u_kn") + assert u_kn.shape == (n_iteratinos, 4, 4) # check that the free energies are correct print(ho_sampler.analytical_f_i) diff --git a/chiron/tests/test_pairs.py b/chiron/tests/test_pairs.py index fb2bf2c..60df5ad 100644 --- a/chiron/tests/test_pairs.py +++ b/chiron/tests/test_pairs.py @@ -2,7 +2,7 @@ import pytest from chiron.neighbors import ( NeighborListNsqrd, - PairList, + PairListNsqrd, OrthogonalPeriodicSpace, OrthogonalNonperiodicSpace, ) @@ -13,78 +13,47 @@ def test_orthogonal_periodic_displacement(): # test that the incorrect box shapes throw an exception - with pytest.raises(ValueError): - space = OrthogonalPeriodicSpace(jnp.array([10.0, 10.0, 10.0])) - # test that incorrect units throw an exception - with pytest.raises(ValueError): - space = OrthogonalPeriodicSpace( - unit.Quantity( - jnp.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]), - unit.radians, - ) - ) - - space = OrthogonalPeriodicSpace( - jnp.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) - ) - # test that the box vectors are set correctly - assert jnp.all( - space.box_vectors - == jnp.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) - ) - - # test that the box lengths for an orthogonal box are set correctly - assert jnp.all(space._box_lengths == jnp.array([10.0, 10.0, 10.0])) + space = OrthogonalPeriodicSpace() + box_vectors = jnp.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) # test calculation of the displacement_vector and distance between two points p1 = jnp.array([[0, 0, 0], [0, 0, 0]]) p2 = jnp.array([[1, 0, 0], [6, 0, 0]]) - r_ij, distance = space.displacement(p1, p2) + r_ij, distance = space.displacement(p1, p2, box_vectors) assert jnp.all(r_ij == jnp.array([[-1.0, 0.0, 0.0], [4.0, 0.0, 0.0]])) assert jnp.all(distance == jnp.array([1, 4])) # test that the periodic wrapping works as expected - wrapped_x = space.wrap(jnp.array([11, 0, 0])) + wrapped_x = space.wrap(jnp.array([11, 0, 0]), box_vectors) assert jnp.all(wrapped_x == jnp.array([1, 0, 0])) - wrapped_x = space.wrap(jnp.array([-1, 0, 0])) + wrapped_x = space.wrap(jnp.array([-1, 0, 0]), box_vectors) assert jnp.all(wrapped_x == jnp.array([9, 0, 0])) - wrapped_x = space.wrap(jnp.array([5, 0, 0])) + wrapped_x = space.wrap(jnp.array([5, 0, 0]), box_vectors) assert jnp.all(wrapped_x == jnp.array([5, 0, 0])) - wrapped_x = space.wrap(jnp.array([5, 12, -1])) + wrapped_x = space.wrap(jnp.array([5, 12, -1]), box_vectors) assert jnp.all(wrapped_x == jnp.array([5, 2, 9])) - # test the setter for the box vectors - space.box_vectors = jnp.array( - [[10.0, 0.0, 0.0], [0.0, 20.0, 0.0], [0.0, 0.0, 30.0]] - ) - assert jnp.all( - space._box_vectors - == jnp.array([[10.0, 0.0, 0.0], [0.0, 20.0, 0.0], [0.0, 0.0, 30.0]]) - ) - assert jnp.all(space._box_lengths == jnp.array([10.0, 20.0, 30.0])) - def test_orthogonal_nonperiodic_displacement(): - space = OrthogonalNonperiodicSpace( - jnp.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) - ) + space = OrthogonalNonperiodicSpace() + box_vectors = jnp.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) p1 = jnp.array([[0, 0, 0], [0, 0, 0]]) p2 = jnp.array([[1, 0, 0], [6, 0, 0]]) - r_ij, distance = space.displacement(p1, p2) + r_ij, distance = space.displacement(p1, p2, box_vectors) assert jnp.all(r_ij == jnp.array([[-1.0, 0.0, 0.0], [-6.0, 0.0, 0.0]])) assert jnp.all(distance == jnp.array([1, 6])) - wrapped_x = space.wrap(jnp.array([11, -1, 2])) + wrapped_x = space.wrap(jnp.array([11, -1, 2]), box_vectors) assert jnp.all(wrapped_x == jnp.array([11, -1, 2])) @@ -100,34 +69,33 @@ def test_neighborlist_pair(): PRNG.set_seed(1234) state = SamplerState( - x0=unit.Quantity(coordinates, unit.nanometer), + positions=unit.Quantity(coordinates, unit.nanometer), current_PRNG_key=PRNG.get_random_key(), box_vectors=unit.Quantity(box_vectors, unit.nanometer), ) space = OrthogonalPeriodicSpace() - cutoff = 1.1 - skin = 0.1 + cutoff = 1.1 * unit.nanometer + skin = 0.1 * unit.nanometer nbr_list = NeighborListNsqrd( space, - cutoff=unit.Quantity(cutoff, unit.nanometer), - skin=unit.Quantity(skin, unit.nanometer), + cutoff=cutoff, + skin=skin, n_max_neighbors=5, ) assert nbr_list.cutoff == cutoff assert nbr_list.skin == skin - assert nbr_list.cutoff_and_skin == cutoff + skin assert nbr_list.n_max_neighbors == 5 nbr_list.build_from_state(state) - assert jnp.all(nbr_list.ref_coordinates == coordinates) + assert jnp.all(nbr_list.ref_positions == coordinates) assert jnp.all(nbr_list.box_vectors == box_vectors) assert nbr_list.is_built == True - nbr_list.build(state.x0, state.box_vectors) + nbr_list.build(state.positions, state.box_vectors) - assert jnp.all(nbr_list.ref_coordinates == coordinates) + assert jnp.all(nbr_list.ref_positions == coordinates) assert jnp.all(nbr_list.box_vectors == box_vectors) assert nbr_list.is_built == True @@ -195,12 +163,12 @@ def test_neighborlist_pair(): def test_inputs(): space = OrthogonalPeriodicSpace() # every particle should interact with every other particle - cutoff = 2.1 - skin = 0.1 + cutoff = 2.1 * unit.nanometer + skin = 0.1 * unit.nanometer nbr_list = NeighborListNsqrd( space, - cutoff=unit.Quantity(cutoff, unit.nanometer), - skin=unit.Quantity(skin, unit.nanometer), + cutoff=cutoff, + skin=skin, n_max_neighbors=5, ) # check that the state is of the correct type @@ -213,7 +181,7 @@ def test_inputs(): PRNG.set_seed(1234) state = SamplerState( - x0=unit.Quantity(coordinates, unit.nanometer), + positions=unit.Quantity(coordinates, unit.nanometer), current_PRNG_key=PRNG.get_random_key(), box_vectors=None, ) @@ -247,24 +215,24 @@ def test_inputs(): with pytest.raises(TypeError): NeighborListNsqrd( 123, - cutoff=unit.Quantity(cutoff, unit.nanometer), - skin=unit.Quantity(skin, unit.nanometer), + cutoff=cutoff, + skin=skin, n_max_neighbors=5, ) # check units of cutoff with pytest.raises(ValueError): NeighborListNsqrd( space, - cutoff=unit.Quantity(cutoff, unit.radian), - skin=unit.Quantity(skin, unit.nanometer), + cutoff=unit.Quantity(123, unit.radian), + skin=unit.Quantity(123, unit.nanometer), n_max_neighbors=5, ) # check units of skin with pytest.raises(ValueError): NeighborListNsqrd( space, - cutoff=unit.Quantity(cutoff, unit.nanometer), - skin=unit.Quantity(skin, unit.radian), + cutoff=unit.Quantity(123, unit.nanometer), + skin=unit.Quantity(123, unit.radian), n_max_neighbors=5, ) @@ -287,19 +255,19 @@ def test_neighborlist_pair_multiple_particles(): PRNG.set_seed(1234) state = SamplerState( - x0=unit.Quantity(coordinates, unit.nanometer), + positions=unit.Quantity(coordinates, unit.nanometer), current_PRNG_key=PRNG.get_random_key(), box_vectors=unit.Quantity(box_vectors, unit.nanometer), ) space = OrthogonalPeriodicSpace() # every particle should interact with every other particle - cutoff = 2.1 - skin = 0.1 + cutoff = 2.1 * unit.nanometer + skin = 0.1 * unit.nanometer nbr_list = NeighborListNsqrd( space, - cutoff=unit.Quantity(cutoff, unit.nanometer), - skin=unit.Quantity(skin, unit.nanometer), + cutoff=cutoff, + skin=skin, n_max_neighbors=5, ) nbr_list.build_from_state(state) @@ -310,12 +278,12 @@ def test_neighborlist_pair_multiple_particles(): assert jnp.all(n_interacting == jnp.array([7, 6, 5, 4, 3, 2, 1, 0])) # every particle should be in the nieghbor list, but only a subset in the interacting range - cutoff = 1.1 - skin = 1.1 + cutoff = 1.1 * unit.nanometer + skin = 1.1 * unit.nanometer nbr_list = NeighborListNsqrd( space, - cutoff=unit.Quantity(cutoff, unit.nanometer), - skin=unit.Quantity(skin, unit.nanometer), + cutoff=cutoff, + skin=skin, n_max_neighbors=5, ) nbr_list.build_from_state(state) @@ -342,7 +310,7 @@ def test_neighborlist_pair_multiple_particles(): ) ) # test passing coordinates and box vectors directly - nbr_list.build(state.x0, state.box_vectors) + nbr_list.build(state.positions, state.box_vectors) assert jnp.all(nbr_list.n_neighbors == jnp.array([7, 6, 5, 4, 3, 2, 1, 0])) @@ -362,17 +330,16 @@ def test_pairlist_pair(): PRNG.set_seed(1234) state = SamplerState( - x0=unit.Quantity(coordinates, unit.nanometer), + positions=unit.Quantity(coordinates, unit.nanometer), current_PRNG_key=PRNG.get_random_key(), box_vectors=unit.Quantity(box_vectors, unit.nanometer), ) space = OrthogonalPeriodicSpace() - cutoff = 1.1 - skin = 0.1 - pair_list = PairList( + cutoff = 1.1 * unit.nanometer + pair_list = PairListNsqrd( space, - cutoff=unit.Quantity(cutoff, unit.nanometer), + cutoff=cutoff, ) assert pair_list.cutoff == cutoff @@ -382,7 +349,7 @@ def test_pairlist_pair(): assert jnp.all(pair_list.reduction_mask == jnp.array([[True], [False]])) assert pair_list.is_built == True - n_pairs, all_pairs, mask, dist, displacement = pair_list.calculate(coordinates) + n_pairs, all_pairs, mask, dist, displacement = pair_list.calculate(state.positions) assert jnp.all(n_pairs == jnp.array([1, 0])) assert jnp.all(all_pairs.shape == (2, 1)) @@ -394,10 +361,49 @@ def test_pairlist_pair(): assert pair_list.check(coordinates) == False - coordinates = coordinates = jnp.array([[0, 0, 0], [1, 0, 0], [1, 1, 0]]) + coordinates = jnp.array([[0, 0, 0], [1, 0, 0], [1, 1, 0]]) # we changed number of particles, and thus should rebuild assert pair_list.check(coordinates) == True + # test without using a cutoff + # this will be exactly the same as with a cutoff, given it is just two particles + cutoff = None + pair_list = PairListNsqrd( + space, + cutoff=None, + ) + pair_list.build_from_state(state) + + assert pair_list.cutoff == cutoff + n_pairs, all_pairs, mask, dist, displacement = pair_list.calculate(state.positions) + assert jnp.all(n_pairs == jnp.array([1, 0])) + assert jnp.all(all_pairs.shape == (2, 1)) + assert jnp.all(all_pairs == jnp.array([[1], [0]])) + assert jnp.all(mask == jnp.array([[1], [0]])) + assert jnp.all(dist == jnp.array([[1.0], [1.0]])) + assert displacement.shape == (2, 1, 3) + assert jnp.all(displacement == jnp.array([[[-1.0, 0.0, 0.0]], [[1.0, 0.0, 0.0]]])) + + # test the difference between a short cutoff with no interactions and the same + # system with no cutoff. + + # this test ultimately have no particles in the neighbor list + # because the cutoff is really short + cutoff = 0.5 * unit.nanometer + pair_list = PairListNsqrd(space, cutoff=cutoff) + + assert pair_list.cutoff == cutoff + pair_list.build_from_state(state) + n_pairs, all_pairs, mask, dist, displacement = pair_list.calculate(state.positions) + # the mask will all be false because the cutoff is too short + assert jnp.all(mask == jnp.array([[0], [0]])) + + # set the cutoff to None, and calculate all pairs in the box + pair_list.cutoff = None + n_pairs, all_pairs, mask, dist, displacement = pair_list.calculate(state.positions) + # the mask will have the single pair in the box be true + assert jnp.all(mask == jnp.array([[1], [0]])) + def test_pair_list_multiple_particles(): # test the pair list for multiple particles @@ -416,18 +422,18 @@ def test_pair_list_multiple_particles(): PRNG.set_seed(1234) state = SamplerState( - x0=unit.Quantity(coordinates, unit.nanometer), + positions=unit.Quantity(coordinates, unit.nanometer), current_PRNG_key=PRNG.get_random_key(), box_vectors=unit.Quantity(box_vectors, unit.nanometer), ) space = OrthogonalPeriodicSpace() # every particle should interact with every other particle - cutoff = 2.1 - skin = 0.1 - pair_list = PairList( + cutoff = 2.1 * unit.nanometer + skin = 0.1 * unit.nanometer + pair_list = PairListNsqrd( space, - cutoff=unit.Quantity(cutoff, unit.nanometer), + cutoff=cutoff, ) pair_list.build_from_state(state) @@ -454,8 +460,8 @@ def test_pair_list_multiple_particles(): # compare to nbr_list nbr_list = NeighborListNsqrd( space, - cutoff=unit.Quantity(cutoff, unit.nanometer), - skin=unit.Quantity(skin, unit.nanometer), + cutoff=cutoff, + skin=skin, n_max_neighbors=20, ) nbr_list.build_from_state(state) diff --git a/chiron/tests/test_potential.py b/chiron/tests/test_potential.py index 230fa1c..3013df1 100644 --- a/chiron/tests/test_potential.py +++ b/chiron/tests/test_potential.py @@ -178,7 +178,7 @@ def test_lennard_jones(): positions = jnp.array([[0, 0, 0], [i * 0.25 * 2 ** (1 / 6), 0, 0]]) state = SamplerState( - x0=unit.Quantity(positions, unit.nanometer), + positions=unit.Quantity(positions, unit.nanometer), current_PRNG_key=PRNG.get_random_key(), box_vectors=unit.Quantity(box_vectors, unit.nanometer), ) diff --git a/chiron/tests/test_states.py b/chiron/tests/test_states.py index 4499f18..e437fad 100644 --- a/chiron/tests/test_states.py +++ b/chiron/tests/test_states.py @@ -31,14 +31,14 @@ def test_initialize_state(): sampler_state = SamplerState(ho.positions, current_PRNG_key=PRNG.get_random_key()) assert jnp.allclose( - sampler_state.x0, + sampler_state.positions, jnp.array([[0.0, 0.0, 0.0]]), ) def test_sampler_state_conversion(): """Test converting a sampler state to jnp arrays. - Note, testing the conversion of x0, where internal unit length is nanometers + Note, testing the conversion of positions, where internal unit length is nanometers and thus output jnp.arrays (with units dropped) should reflect this. """ from chiron.states import SamplerState @@ -54,7 +54,7 @@ def test_sampler_state_conversion(): ) assert jnp.allclose( - sampler_state.x0, + sampler_state.positions, jnp.array([[10.0, 10.0, 10.0]]), ) @@ -64,7 +64,7 @@ def test_sampler_state_conversion(): ) assert jnp.allclose( - sampler_state.x0, + sampler_state.positions, jnp.array([[1.0, 1.0, 1.0]]), ) @@ -81,11 +81,11 @@ def test_sampler_state_inputs(): # test input of positions # should have units with pytest.raises(TypeError): - SamplerState(x0=jnp.array([1, 2, 3])) + SamplerState(positions=jnp.array([1, 2, 3])) # throw and error because of incompatible units with pytest.raises(ValueError): SamplerState( - x0=unit.Quantity(jnp.array([[1, 2, 3]]), unit.radians), + positions=unit.Quantity(jnp.array([[1, 2, 3]]), unit.radians), current_PRNG_key=PRNG.get_random_key(), ) @@ -93,14 +93,14 @@ def test_sampler_state_inputs(): # velocities should have units with pytest.raises(TypeError): SamplerState( - x0=unit.Quantity(jnp.array([[1, 2, 3]]), unit.nanometers), + positions=unit.Quantity(jnp.array([[1, 2, 3]]), unit.nanometers), current_PRNG_key=PRNG.get_random_key(), velocities=jnp.array([1, 2, 3]), ) # velocities should have units of distance/time with pytest.raises(ValueError): SamplerState( - x0=unit.Quantity(jnp.array([[1, 2, 3]]), unit.nanometers), + positions=unit.Quantity(jnp.array([[1, 2, 3]]), unit.nanometers), current_PRNG_key=PRNG.get_random_key(), velocities=unit.Quantity(jnp.array([1, 2, 3]), unit.nanometers), ) @@ -109,14 +109,14 @@ def test_sampler_state_inputs(): # box_vectors should have units with pytest.raises(TypeError): SamplerState( - x0=unit.Quantity(jnp.array([[1, 2, 3]]), unit.nanometers), + positions=unit.Quantity(jnp.array([[1, 2, 3]]), unit.nanometers), current_PRNG_key=PRNG.get_random_key(), box_vectors=jnp.array([1, 2, 3]), ) # box_vectors should have units of distance with pytest.raises(ValueError): SamplerState( - x0=unit.Quantity(jnp.array([[1, 2, 3]]), unit.nanometers), + positions=unit.Quantity(jnp.array([[1, 2, 3]]), unit.nanometers), current_PRNG_key=PRNG.get_random_key(), box_vectors=unit.Quantity( jnp.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), unit.radians @@ -125,7 +125,7 @@ def test_sampler_state_inputs(): # check to see that the size of the box vectors are correct with pytest.raises(ValueError): SamplerState( - x0=unit.Quantity(jnp.array([[1, 2, 3]]), unit.nanometers), + positions=unit.Quantity(jnp.array([[1, 2, 3]]), unit.nanometers), current_PRNG_key=PRNG.get_random_key(), box_vectors=unit.Quantity( jnp.array([[1, 0, 0], [0, 1, 0]]), unit.nanometers @@ -140,7 +140,7 @@ def test_sampler_state_inputs(): # check openmm_box conversion: state = SamplerState( - x0=unit.Quantity(jnp.array([[1, 2, 3]]), unit.nanometers), + positions=unit.Quantity(jnp.array([[1, 2, 3]]), unit.nanometers), current_PRNG_key=PRNG.get_random_key(), box_vectors=openmm_box, ) @@ -155,7 +155,7 @@ def test_sampler_state_inputs(): # openmm box vectors end up as a list with contents; check to make sure we capture an error if we pass a bad list with pytest.raises(TypeError): SamplerState( - x0=unit.Quantity(jnp.array([[1, 2, 3]]), unit.nanometers), + positions=unit.Quantity(jnp.array([[1, 2, 3]]), unit.nanometers), current_PRNG_key=PRNG.get_random_key(), box_vectors=[123], ) diff --git a/chiron/tests/test_testsystems.py b/chiron/tests/test_testsystems.py index 3506b77..5b16ef8 100644 --- a/chiron/tests/test_testsystems.py +++ b/chiron/tests/test_testsystems.py @@ -1,3 +1,13 @@ +import pytest + + +@pytest.fixture(scope="session") +def prep_temp_dir(tmpdir_factory): + """Create a temporary directory for the test.""" + tmpdir = tmpdir_factory.mktemp("test_testsystems") + return tmpdir + + def compute_openmm_reference_energy(testsystem, positions): from openmm import unit from openmm.app import Simulation @@ -186,7 +196,7 @@ def test_LJ_fluid(): PRNG.set_seed(1234) state = SamplerState( - x0=lj_openmm.positions, + positions=lj_openmm.positions, current_PRNG_key=PRNG.get_random_key(), box_vectors=lj_openmm.system.getDefaultPeriodicBoxVectors(), ) @@ -200,7 +210,7 @@ def test_LJ_fluid(): lj_openmm.topology, sigma=sigma, epsilon=epsilon, cutoff=cutoff ) - e_chiron_energy = lj_chiron.compute_energy(state.x0, nbr_list) + e_chiron_energy = lj_chiron.compute_energy(state.positions, nbr_list) e_openmm_energy = compute_openmm_reference_energy( lj_openmm, lj_openmm.positions ) diff --git a/chiron/tests/test_utils.py b/chiron/tests/test_utils.py index 59f6d10..7ce4474 100644 --- a/chiron/tests/test_utils.py +++ b/chiron/tests/test_utils.py @@ -64,11 +64,14 @@ def test_reporter(prep_temp_dir, ho_multistate_sampler_multiple_ks): reporter = LangevinDynamicsReporter("langevin_test") reporter.reset_reporter_file() - integrator = LangevinIntegrator(reporter=reporter, report_frequency=1) + integrator = LangevinIntegrator( + reporter=reporter, + report_interval=1, + ) integrator.run( sampler_state, thermodynamic_state, - n_steps=20, + number_of_steps=20, ) import numpy as np diff --git a/chiron/utils.py b/chiron/utils.py index a9a7da0..41daf33 100644 --- a/chiron/utils.py +++ b/chiron/utils.py @@ -11,20 +11,21 @@ def __init__(self) -> None: """ A PRNG class that can be used to generate random numbers in JAX. The intended use case is to initialize new PRN streams in the `SamplerState` class. - + Example: -------- from chiron.utils import PRNG from chiron.states import SamplerState from openmmtools.testsystems import HarmonicOscillator - + ho = HarmonicOscillator() PRNG.set_seed(1234) sampler_state = [SamplerState(ho.positions, PRNG.get_random_key()) for _ in x0s] - + """ - + pass + @classmethod def set_seed(cls, seed: int) -> None: cls._seed = seed @@ -37,6 +38,25 @@ def get_random_key(cls) -> int: return subkey +def get_full_path(relative_path: str) -> str: + """Get the fill path of a file that is defined relative to the chiron module root directory. + + Parameters + ---------- + relative_path : str + The relative path of the file. + + Returns + ------- + str + The full path of the file. + """ + from importlib.resources import files + + _MODULE_ROOT = files("chiron") + return f"{_MODULE_ROOT}/../{relative_path}" + + def get_data_file_path(relative_path: str) -> str: """Get the full path to one of the reference files in testsystems. In the source distribution, these files are in ``chiron/data/``, @@ -85,9 +105,40 @@ def get_nr_of_particles(topology: Topology) -> int: def get_list_of_mass(topology: Topology) -> unit.Quantity: """Get the mass of the system from the topology.""" - from simtk import unit + from openmm import unit mass = [] for atom in topology.atoms(): mass.append(atom.element.mass.value_in_unit(unit.amu)) return mass * unit.amu + + +def initialize_velocities( + temperature: unit.Quantity, topology: Topology, key +) -> unit.Quantity: + """Initialize the velocities from the Maxwell-Boltzmann distribution at the given temperature. + + Parameters + ---------- + temperature : unit.Quantity + The temperature of the system. + topology : Topology + The topology of the system. + key : int + The PRNG key. + + """ + from openmm import unit + import jax.numpy as jnp + + mass = get_list_of_mass(topology) + + kB = unit.BOLTZMANN_CONSTANT_kB * unit.AVOGADRO_CONSTANT_NA + + kbT_unitless = (kB * temperature).value_in_unit_system(unit.md_unit_system) + mass_unitless = jnp.array(mass.value_in_unit_system(unit.md_unit_system))[:, None] + sigma_v = jnp.sqrt(kbT_unitless / mass_unitless) + + v0 = sigma_v * random.normal(key, [len(mass), 3]) + + return v0 * unit.nanometer / unit.picosecond