diff --git a/Examples/LJ_langevin.py b/Examples/LJ_langevin.py index e2d66ca..d769b1a 100644 --- a/Examples/LJ_langevin.py +++ b/Examples/LJ_langevin.py @@ -43,7 +43,7 @@ # build the neighbor list from the sampler state nbr_list.build_from_state(sampler_state) -from chiron.reporters import SimulationReporter +from chiron.reporters import _SimulationReporter # initialize a reporter to save the simulation data filename = "test_lj.h5" @@ -51,12 +51,12 @@ if os.path.isfile(filename): os.remove(filename) -reporter = SimulationReporter("test_lj.h5", lj_fluid.topology, 1) +reporter = _SimulationReporter("test_lj.h5", lj_fluid.topology, 1) from chiron.integrators import LangevinIntegrator # initialize the Langevin integrator -integrator = LangevinIntegrator(reporter=reporter, save_frequency=100) +integrator = LangevinIntegrator(reporter=reporter, report_frequency=100) print("init_energy: ", lj_potential.compute_energy(sampler_state.x0, nbr_list)) integrator.run( diff --git a/Examples/LJ_mcmove.py b/Examples/LJ_mcmove.py index 0ed407e..bc673f6 100644 --- a/Examples/LJ_mcmove.py +++ b/Examples/LJ_mcmove.py @@ -45,7 +45,7 @@ # build the neighbor list from the sampler state nbr_list.build_from_state(sampler_state) -from chiron.reporters import SimulationReporter +from chiron.reporters import _SimulationReporter # initialize a reporter to save the simulation data filename = "test_lj.h5" @@ -53,7 +53,7 @@ if os.path.isfile(filename): os.remove(filename) -reporter = SimulationReporter("test_mc_lj.h5", lj_fluid.topology, 1) +reporter = _SimulationReporter("test_mc_lj.h5", lj_fluid.topology, 1) from chiron.mcmc import MetropolisDisplacementMove @@ -61,7 +61,7 @@ seed=1234, displacement_sigma=0.01 * unit.nanometer, nr_of_moves=1000, - simulation_reporter=reporter, + reporter=reporter, ) mc_move.run(sampler_state, thermodynamic_state, nbr_list, True) diff --git a/chiron/analysis.py b/chiron/analysis.py new file mode 100644 index 0000000..47fdb86 --- /dev/null +++ b/chiron/analysis.py @@ -0,0 +1,60 @@ +import numpy as np + + +class MBAREstimator: + def __init__(self) -> None: + """ + Initialize the MBAR analysis class. + + Returns: + - None + """ + self.mbar_f_k = None + self.mbar = None + + def initialize(self, u_kn: np.ndarray, N_k: np.ndarray): + """ + Initialize the analysis object. + + Parameters + ---------- + u_kn: np.ndarray + Array of dimensionless reduced potentials for each state. + N_k: np.ndarray + Array of number of samples for each state. + + """ + from pymbar import MBAR + from loguru import logger as log + + log.debug(f"{N_k=}") + self.mbar = MBAR(u_kn=u_kn, N_k=N_k) + + @property + def f_k(self): + """ + Free energy for each state. + + Returns + ------- + mbar.f_k. + """ + + from loguru import logger as log + + log.debug(self.mbar.f_k) + return self.mbar.f_k + + def get_free_energy_difference(self): + """ + Calculate the free energy difference between the endstates. + + Returns + ------- + float + """ + from loguru import logger as log + + log.debug(self.mbar.f_k[-1]) + self.f_k = self.mbar.f_k + return self.mbar_f_k[-1] diff --git a/chiron/data/test_md.h5 b/chiron/data/langevin_reporter.h5 similarity index 100% rename from chiron/data/test_md.h5 rename to chiron/data/langevin_reporter.h5 diff --git a/chiron/integrators.py b/chiron/integrators.py index b9120fd..0d77452 100644 --- a/chiron/integrators.py +++ b/chiron/integrators.py @@ -4,8 +4,10 @@ from jax import random from openmm import unit from .states import SamplerState, ThermodynamicState -from .reporters import SimulationReporter +from .reporters import LangevinDynamicsReporter from typing import Optional +from .potential import NeuralNetworkPotential +from .neighbors import PairsBase class LangevinIntegrator: @@ -24,8 +26,8 @@ def __init__( self, stepsize=1.0 * unit.femtoseconds, collision_rate=1.0 / unit.picoseconds, - save_frequency: int = 100, - reporter: Optional[SimulationReporter] = None, + report_frequency: int = 100, + reporter: Optional[LangevinDynamicsReporter] = None, save_traj_in_memory: bool = False, ) -> None: """ @@ -37,29 +39,33 @@ def __init__( Time step of integration with units of time. Default is 1.0 * unit.femtoseconds. collision_rate : unit.Quantity, optional Collision rate for the Langevin dynamics, with units 1/time. Default is 1.0 / unit.picoseconds. - save_frequency : int, optional + report_frequency : int, optional Frequency of saving the simulation data. Default is 100. reporter : SimulationReporter, optional Reporter object for saving the simulation data. Default is None. - save_traj_in_memory : bool - Whether to save the trajectory in memory. For debugging purposes only. + save_traj_in_memory: bool + Flag indicating whether to save the trajectory in memory. + Default is False. NOTE: Only for debugging purposes. """ from loguru import logger as log self.kB = unit.BOLTZMANN_CONSTANT_kB * unit.AVOGADRO_CONSTANT_NA log.info(f"stepsize = {stepsize}") log.info(f"collision_rate = {collision_rate}") - log.info(f"save_frequency = {save_frequency}") + log.info(f"report_frequency = {report_frequency}") self.stepsize = stepsize self.collision_rate = collision_rate - if reporter is not None: - log.info(f"Using reporter {reporter} saving to {reporter.filename}") + if reporter: + log.info( + f"Using reporter {reporter} saving trajectory to {reporter.xtc_file_path}" + ) + log.info(f"and logging to {reporter.log_file_path}") self.reporter = reporter - self.save_frequency = save_frequency + self.report_frequency = report_frequency + self.velocities = None self.save_traj_in_memory = save_traj_in_memory self.traj = [] - self.velocities = None def set_velocities(self, vel: unit.Quantity) -> None: """ @@ -77,8 +83,7 @@ def run( sampler_state: SamplerState, thermodynamic_state: ThermodynamicState, n_steps: int = 5_000, - key=random.PRNGKey(0), - nbr_list=None, + nbr_list: Optional[PairsBase] = None, progress_bar=False, ): """ @@ -92,9 +97,7 @@ def run( The thermodynamic state of the system, including temperature and potential. n_steps : int, optional Number of simulation steps to perform. - key : jax.random.PRNGKey, optional - Random key for generating random numbers. - nbr_list : NeighborListNsqrd, optional + nbr_list : PairBase, optional Neighbor list for the system. progress_bar : bool, optional Flag indicating whether to display a progress bar during integration. @@ -116,8 +119,11 @@ def run( log.debug("Running Langevin dynamics") log.debug(f"n_steps = {n_steps}") log.debug(f"temperature = {temperature}") - log.debug(f"Using seed: {key}") + # Initialize the random number generator + key = sampler_state.new_PRNG_key + + # Convert to dimensionless quantities kbT_unitless = (self.kB * temperature).value_in_unit_system(unit.md_unit_system) mass_unitless = jnp.array(mass.value_in_unit_system(unit.md_unit_system))[ :, None @@ -127,22 +133,24 @@ def run( collision_rate_unitless = self.collision_rate.value_in_unit_system( unit.md_unit_system ) + a = jnp.exp((-collision_rate_unitless * stepsize_unitless)) + b = jnp.sqrt(1 - jnp.exp(-2 * collision_rate_unitless * stepsize_unitless)) # Initialize velocities if self.velocities is None: v0 = sigma_v * random.normal(key, x0.shape) else: v0 = self.velocities.value_in_unit_system(unit.md_unit_system) - # Convert to dimensionless quantities - a = jnp.exp((-collision_rate_unitless * stepsize_unitless)) - b = jnp.sqrt(1 - jnp.exp(-2 * collision_rate_unitless * stepsize_unitless)) x = x0 v = v0 + if nbr_list is not None: nbr_list.build_from_state(sampler_state) F = potential.compute_force(x, nbr_list) + + # propagation loop for step in tqdm(range(n_steps)) if self.progress_bar else range(n_steps): key, subkey = random.split(key) # v @@ -151,46 +159,79 @@ def run( x += (stepsize_unitless * 0.5) * v if nbr_list is not None: - x = nbr_list.space.wrap(x) - # check if we need to rebuild the neighborlist after moving the particles - if nbr_list.check(x): - nbr_list.build(x, self.box_vectors) + x = self._wrap_and_rebuild_neighborlist(x, nbr_list) # o random_noise_v = random.normal(subkey, x.shape) v = (a * v) + (b * sigma_v * random_noise_v) x += (stepsize_unitless * 0.5) * v if nbr_list is not None: - x = nbr_list.space.wrap(x) - # check if we need to rebuild the neighborlist after moving the particles - if nbr_list.check(x): - nbr_list.build(x, self.box_vectors) + x = self._wrap_and_rebuild_neighborlist(x, nbr_list) F = potential.compute_force(x, nbr_list) # v v += (stepsize_unitless * 0.5) * F / mass_unitless - if step % self.save_frequency == 0: - # log.debug(f"Saving at step {step}") - # check if reporter is attribute of the class - # log.debug(f"step {step} energy {potential.compute_energy(x, nbr_list)}") - # log.debug(f"step {step} force {F}") - + if step % self.report_frequency == 0: if hasattr(self, "reporter") and self.reporter is not None: - d = { - "traj": x, - "energy": potential.compute_energy(x, nbr_list), - "step": step, - } - if nbr_list is not None: - d["box_vectors"] = nbr_list.space.box_vectors - - # log.debug(d) - self.reporter.report(d) + self._report(x, potential, nbr_list, step) + if self.save_traj_in_memory: self.traj.append(x) log.debug("Finished running Langevin dynamics") # save the final state of the simulation in the sampler_state object sampler_state.x0 = x - # self.reporter.close() + sampler_state.v0 = v + + def _wrap_and_rebuild_neighborlist(self, x: jnp.array, nbr_list: PairsBase): + """ + Wrap the coordinates and rebuild the neighborlist if necessary. + + Parameters + ---------- + x: jnp.array + The coordinates of the particles. + nbr_list: PairsBsse + The neighborlist object. + """ + + x = nbr_list.space.wrap(x) + # check if we need to rebuild the neighborlist after moving the particles + if nbr_list.check(x): + nbr_list.build(x, self.box_vectors) + return x + + def _report( + self, + x: jnp.array, + potential: NeuralNetworkPotential, + nbr_list: PairsBase, + step: int, + ): + """ + Reports the trajectory, energy, step, and box vectors (if available) to the reporter. + + Parameters + ---------- + x : jnp.array + current coordinate set + potential: NeuralNetworkPotential + potential used to compute the energy and force + nbr_list: PairsBase + The neighbor list + step: int + The current time step. + + Returns: + None + """ + d = { + "positions": x, + "potential_energy": potential.compute_energy(x, nbr_list), + "step": step, + } + if nbr_list is not None: + d["box_vectors"] = nbr_list.space.box_vectors + + self.reporter.report(d) diff --git a/chiron/mcmc.py b/chiron/mcmc.py index 2cb7d75..8063012 100644 --- a/chiron/mcmc.py +++ b/chiron/mcmc.py @@ -2,10 +2,16 @@ from openmm import unit from typing import Tuple, List, Optional import jax.numpy as jnp -from chiron.reporters import SimulationReporter +from chiron.reporters import LangevinDynamicsReporter, _SimulationReporter + class MCMCMove: - def __init__(self, nr_of_moves: int, seed: int): + def __init__( + self, + nr_of_moves: int, + reporter: Optional[_SimulationReporter] = None, + report_frequency: Optional[int] = 100, + ): """ Initialize a move within the molecular system. @@ -13,13 +19,22 @@ def __init__(self, nr_of_moves: int, seed: int): ---------- nr_of_moves : int Number of moves to be applied. - seed : int - Seed for random number generation. + reporter : _SimulationReporter, optional + Reporter object for saving the simulation data. + Default is None. + report_frequency : int, optional """ - import jax.random as jrandom self.nr_of_moves = nr_of_moves - self.key = jrandom.PRNGKey(seed) # 'seed' is an integer seed value + self.reporter = reporter + self.report_frequency = report_frequency + from loguru import logger as log + + if self.reporter is not None: + log.info( + f"Using reporter {self.reporter} saving to {self.reporter.workdir}" + ) + assert self.report_frequency is not None class LangevinDynamicsMove(MCMCMove): @@ -27,9 +42,9 @@ def __init__( self, stepsize=1.0 * unit.femtoseconds, collision_rate=1.0 / unit.picoseconds, - simulation_reporter: Optional[SimulationReporter] = None, + reporter: Optional[LangevinDynamicsReporter] = None, + report_frequency: int = 100, nr_of_steps=1_000, - seed: int = 1234, save_traj_in_memory: bool = False, ): """ @@ -41,13 +56,27 @@ def __init__( Time step size for the integration. collision_rate : unit.Quantity Collision rate for the Langevin dynamics. - nr_of_steps : int + reporter : LangevinDynamicsReporter, optional + Reporter object for saving the simulation data. + Default is None. + report_frequency : int + Frequency of saving the simulation data. + Default is 100. + nr_of_steps : int, optional Number of steps to run the integrator for. + Default is 1_000. + save_traj_in_memory: bool + Flag indicating whether to save the trajectory in memory. + Default is False. NOTE: Only for debugging purposes. """ - super().__init__(nr_of_steps, seed) + super().__init__( + nr_of_moves=nr_of_steps, + reporter=reporter, + report_frequency=report_frequency, + ) + self.stepsize = stepsize self.collision_rate = collision_rate - self.simulation_reporter = simulation_reporter self.save_traj_in_memory = save_traj_in_memory self.traj = [] from chiron.integrators import LangevinIntegrator @@ -55,7 +84,8 @@ def __init__( self.integrator = LangevinIntegrator( stepsize=self.stepsize, collision_rate=self.collision_rate, - reporter=self.simulation_reporter, + report_frequency=report_frequency, + reporter=reporter, save_traj_in_memory=save_traj_in_memory, ) @@ -67,8 +97,12 @@ def run( """ Run the integrator to perform molecular dynamics simulation. - Args: - state_variables (StateVariablesCollection): State variables of the system. + Parameters + ---------- + sampler_state : SamplerState + The sampler state to run the integrator on. + thermodynamic_state : ThermodynamicState + The thermodynamic state to run the integrator on. """ assert isinstance( @@ -82,38 +116,18 @@ def run( thermodynamic_state=thermodynamic_state, sampler_state=sampler_state, n_steps=self.nr_of_moves, - key=self.key, ) + if self.save_traj_in_memory: self.traj.append(self.integrator.traj) self.integrator.traj = [] class MCMove(MCMCMove): - def __init__(self, nr_of_moves: int, seed: int) -> None: - super().__init__(nr_of_moves, seed) - - def _check_state_compatiblity( - self, - old_state: SamplerState, - new_state: SamplerState, - ): - """ - Check if the states are compatible. - - Parameters - ---------- - old_state : StateVariablesCollection - The state of the system before the move. - new_state : StateVariablesCollection - The state of the system after the move. - - Raises - ------ - ValueError - If the states are not compatible. - """ - pass + def __init__( + self, nr_of_moves: int, reporter: Optional[_SimulationReporter] + ) -> None: + super().__init__(nr_of_moves, reporter=reporter) def apply_move(self): """ @@ -239,9 +253,9 @@ def _validate_sequence(self): raise ValueError(f"Move {move_name} in the sequence is not available.") -class MCMCSampler(object): +class MCMCSampler: """ - Basic Markov chain Monte Carlo Gibbs sampler. + Basic Markov chain Monte Carlo sampler. Parameters ---------- @@ -291,9 +305,13 @@ def run(self, n_iterations: int = 1): log.info("Finished running MCMC sampler") log.debug("Closing reporter") for _, move in self.move.move_schedule: - if move.simulation_reporter is not None: - move.simulation_reporter.close() - log.debug(f"Closed reporter {move.simulation_reporter.filename}") + if move.reporter is not None: + move.reporter.flush_buffer() + # TODO: flush reporter + log.debug(f"Closed reporter {move.reporter.log_file_path}") + + +from .neighbors import PairsBase class MetropolizedMove(MCMove): @@ -323,16 +341,18 @@ class MetropolizedMove(MCMove): def __init__( self, - seed: int = 1234, atom_subset: Optional[List[int]] = None, nr_of_moves: int = 100, + reporter: Optional[_SimulationReporter] = None, + report_frequency: int = 1, ): self.n_accepted = 0 self.n_proposed = 0 self.atom_subset = atom_subset - super().__init__(nr_of_moves=nr_of_moves, seed=seed) + super().__init__(nr_of_moves=nr_of_moves, reporter=reporter) from loguru import logger as log + self.report_frequency = report_frequency log.debug(f"Atom subset is {atom_subset}.") @property @@ -349,8 +369,7 @@ def apply( self, thermodynamic_state: ThermodynamicState, sampler_state: SamplerState, - reporter: SimulationReporter, - nbr_list=None, + nbr_list=Optional[PairsBase], ): """Apply a metropolized move to the sampler state. @@ -362,8 +381,6 @@ def apply( The thermodynamic state to use to apply the move. sampler_state : SamplerState The initial sampler state to apply the move to. This is modified. - reporter: SimulationReporter - The reporter to write the data to. nbr_list: Neighbor List or Pair List routine, The routine to use to calculate the interacting atoms. Default is None and will use an unoptimized pairlist without PBC @@ -376,9 +393,8 @@ def apply( sampler_state, nbr_list ) # NOTE: in kT log.debug(f"Initial energy is {initial_energy} kT.") - # Store initial positions of the atoms that are moved. - # We'll use this also to recover in case the move is rejected. + # Store initial positions of the atoms that are moved. x0 = sampler_state.x0 atom_subset = self.atom_subset if atom_subset is None: @@ -420,15 +436,14 @@ def apply( log.debug( f"Move accepted. Energy change: {delta_energy:.3f} kT. Number of accepted moves: {self.n_accepted}." ) - reporter.report( - { - "energy": thermodynamic_state.kT_to_kJ_per_mol( - proposed_energy - ).value_in_unit_system(unit.md_unit_system), - "step": self.n_proposed, - "traj": sampler_state.x0, - } - ) + if self.n_proposed % self.report_frequency == 0: + self.reporter.report( + { + "energy": proposed_energy, # in kT + "step": self.n_proposed, + "traj": sampler_state.x0, + } + ) else: # Restore original positions. if atom_subset is None: @@ -493,11 +508,10 @@ class MetropolisDisplacementMove(MetropolizedMove): def __init__( self, - seed: int = 1234, displacement_sigma=1.0 * unit.nanometer, nr_of_moves: int = 100, atom_subset: Optional[List[int]] = None, - simulation_reporter: Optional[SimulationReporter] = None, + reporter: Optional[LangevinDynamicsReporter] = None, ): """ Initialize the MCMC class. @@ -512,22 +526,16 @@ def __init__( The number of moves to perform. Default is 100. atom_subset : list of int, optional A subset of atom indices to consider for the moves. Default is None. - simulation_reporter : SimulationReporter, optional + reporter : SimulationReporter, optional The reporter to write the data to. Default is None. Returns ------- None """ - from loguru import logger as log - - super().__init__(nr_of_moves=nr_of_moves, seed=seed) + super().__init__(nr_of_moves=nr_of_moves, reporter=reporter) self.displacement_sigma = displacement_sigma self.atom_subset = atom_subset - self.simulation_reporter = simulation_reporter - if self.simulation_reporter is not None: - log.info( - f"Using reporter {self.simulation_reporter} saving to {self.simulation_reporter.filename}" - ) + self.key = None def displace_positions( self, positions: jnp.array, displacement_sigma=1.0 * unit.nanometer @@ -576,17 +584,18 @@ def run( ): from tqdm import tqdm from loguru import logger as log + from jax import random + + self.key = sampler_state.new_PRNG_key for trials in ( tqdm(range(self.nr_of_moves)) if progress_bar else range(self.nr_of_moves) ): - self.apply( - thermodynamic_state, sampler_state, self.simulation_reporter, nbr_list - ) + self.apply(thermodynamic_state, sampler_state, nbr_list) if trials % 100 == 0: log.debug(f"Acceptance rate: {self.n_accepted / self.n_proposed}") - if self.simulation_reporter is not None: - self.simulation_reporter.report( + if self.reporter is not None: + self.reporter.report( { "Acceptance rate": self.n_accepted / self.n_proposed, "step": self.n_proposed, diff --git a/chiron/multistate.py b/chiron/multistate.py index 13b7798..b9100ec 100644 --- a/chiron/multistate.py +++ b/chiron/multistate.py @@ -4,43 +4,65 @@ from openmm import unit import numpy as np from chiron.mcmc import MCMCMove +from chiron.reporters import MultistateReporter class MultiStateSampler: """ - Base class for samplers that sample multiple thermodynamic states using - one or more replicas. + A sampler for simulating multiple thermodynamic states using replicas. - This base class provides a general simulation facility for multistate from multiple - thermodynamic states, allowing any set of thermodynamic states to be specified. - If instantiated on its own, the thermodynamic state indices associated with each - state are specified and replica mixing does not change any thermodynamic states, + This class provides a general simulation facility for sampling from multiple + thermodynamic states. It allows specifying any set of thermodynamic states. + If instantiated on its own, the thermodynamic state indices associated with + each state are specified, and replica mixing does not change any thermodynamic states, meaning that each replica remains in its original thermodynamic state. + + Attributes + ---------- + n_states : int + Number of thermodynamic states (read-only). + n_replicas : int + Number of replicas (read-only). + iteration : int + Current iteration of the simulation (read-only). + mcmc_moves : List[MCMCMove] + MCMC moves used to propagate the simulation. + sampler_states : List[SamplerState] + Sampler states list at the current iteration. + is_periodic : bool + True if system is periodic, False if not, None if not initialized. + is_completed : bool + Check if the sampler has reached any stop target criteria (read-only). + + Methods + ------- + create(thermodynamic_states: List[ThermodynamicState], sampler_states: List[SamplerState], nbr_list: NeighborListNsqrd) + Creates a new multistate sampler simulation. + minimize(tolerance: unit.Quantity = 1.0 * unit.kilojoules_per_mole / unit.nanometers, max_iterations: int = 1000) + Minimizes all replicas in the sampler. + run(n_iterations: int = 10) + Executes the replica-exchange simulation for a specified number of iterations. + """ def __init__( self, - mcmc_moves=Union[MCMCMove, List[MCMCMove]], - online_analysis_interval=5, + mcmc_moves: Union[MCMCMove, List[MCMCMove]], + reporter: MultistateReporter, ): """ - Parameters - ---------- - mcmc_moves : MCMCMove or list of MCMCMove - The MCMCMove used to propagate the thermodynamic states. If a list of MCMCMoves, - they will be assigned to the correspondent thermodynamic state on - creation. + Initialize the MultiStateSampler. - Attributes + Parameters ---------- - n_replicas - n_states - mcmc_moves - sampler_states - is_completed + mcmc_moves : Union[MCMCMove, List[MCMCMove]] + The MCMCMove or list of MCMCMoves used to propagate the thermodynamic states. + reporter : MultistateReporter + The reporter used to store the simulation data. """ + import copy - from openmm import unit + from chiron.analysis import MBAREstimator # These will be set on initialization. See function # create() for explanation of single variables. @@ -50,34 +72,40 @@ def __init__( self._replica_thermodynamic_states = None self._iteration = None self._energy_thermodynamic_states = None - self._energy_thermodynamic_states_for_each_iteration = None self._neighborhoods = None - self._energy_unsampled_states = None self._n_accepted_matrix = None self._n_proposed_matrix = None - self._reporter = None + self._reporter = reporter # NOTE: reporter needs to be putlic, API change ahead self._metadata = None - self._online_analysis_interval = online_analysis_interval - self._timing_data = dict() - self.free_energy_estimator = None - self._traj = None - self._mcmc_moves = copy.deepcopy(mcmc_moves) - - self._last_mbar_f_k = None - self._last_err_free_energy = None + self._online_estimator = None + self._offline_estimator = MBAREstimator() @property - def n_states(self): - """The integer number of thermodynamic states (read-only).""" + def n_states(self) -> int: + """ + Get the number of thermodynamic states in the sampler. + + Returns + ------- + int + The number of thermodynamic states. + """ if self._thermodynamic_states is None: return 0 else: return len(self._thermodynamic_states) @property - def n_replicas(self): - """The integer number of replicas (read-only).""" + def n_replicas(self) -> int: + """ + Get the number of replicas in the sampler. + + Returns + ------- + int + The number of replicas. + """ if self._sampler_states is None: return 0 else: @@ -104,29 +132,37 @@ def mcmc_moves(self): return copy.deepcopy(self._mcmc_moves) @property - def sampler_states(self): - """A copy of the sampler states list at the current iteration. + def sampler_states(self) -> Optional[List[SamplerState]]: + """ + Get a copy of the sampler states list at the current iteration. - This can be set only before running. + This property can only be set before running the simulation. + + Returns + ------- + Optional[List[SamplerState]] + The list of sampler states at the current iteration, or None if not set. """ + if self._sampler_states is None: + return None import copy return copy.deepcopy(self._sampler_states) @property def is_periodic(self): - """Return True if system is periodic, False if not, and None if not initialized""" + """ + Determine if the system is periodic. + + Returns + ------- + Optional[bool] + True if the system is periodic, False if not, and None if not initialized. + """ if self._sampler_states is None: return None return self._thermodynamic_states[0].is_periodic - @property - def metadata(self): - """A copy of the metadata dictionary passed on creation (read-only).""" - import copy - - return copy.deepcopy(self._metadata) - @property def is_completed(self): """Check if we have reached any of the stop target criteria (read-only)""" @@ -134,59 +170,55 @@ def is_completed(self): def _compute_replica_energies(self, replica_id: int) -> np.ndarray: """ - Compute the energy for the replica in every ThermodynamicState. + Compute the energy of a replica across all thermodynamic states. Parameters ---------- replica_id : int - The ID of the replica to compute energies for. + The index of the replica for which to compute energies. Returns ------- np.ndarray - Array of energies for the specified replica across all thermodynamic states. + An array of energies for the replica across all thermodynamic states. """ - import jax.numpy as jnp from chiron.states import calculate_reduced_potential_at_states - # Only compute energies of the sampled states over neighborhoods. - thermodynamic_states = [ - self._thermodynamic_states[n] for n in range(self.n_states) - ] # Retrieve sampler state associated to this replica. sampler_state = self._sampler_states[replica_id] # Compute energy for all thermodynamic states. - return calculate_reduced_potential_at_states( - sampler_state, thermodynamic_states, self.nbr_list + energies = calculate_reduced_potential_at_states( + sampler_state, self._thermodynamic_states, self.nbr_list ) + return energies def create( self, thermodynamic_states: List[ThermodynamicState], sampler_states: List[SamplerState], nbr_list: NeighborListNsqrd, - metadata: Optional[dict] = None, ): - """Create new multistate sampler simulation. + """ + Create a new multistate sampler simulation. + Parameters + ---------- thermodynamic_states : List[ThermodynamicState] - List of ThermodynamicStates to simulate, with one replica allocated per state. + List of ThermodynamicStates to simulate, with one replica per state. sampler_states : List[SamplerState] - List of initial SamplerStates. The number of replicas is taken to be the number - of sampler states provided. + List of initial SamplerStates. The number of states is the number of replicas. nbr_list : NeighborListNsqrd - Neighbor list object to be used in the simulation. - metadata : dict, optional - Optional simulation metadata to be stored in the file. + Neighbor list object for the simulation. Raises ------ RuntimeError - If the lengths of thermodynamic_states and sampler_states are not equal. + If the lengths of `thermodynamic_states` and `sampler_states` are not equal. """ - # TODO: initialize reporter here - # TODO: consider unsampled thermodynamic states for reweighting schemes - self.free_energy_estimator = "mbar" + + self._online_estimator = None + + from chiron.reporters import MultistateReporter # Ensure the number of thermodynamic states matches the number of sampler states if len(thermodynamic_states) != len(sampler_states): @@ -194,9 +226,9 @@ def create( "Number of thermodynamic states and sampler states must be equal." ) - self._allocate_variables(thermodynamic_states, sampler_states) self.nbr_list = nbr_list - self._reporter = None + self._allocate_variables(thermodynamic_states, sampler_states) + self._reporter = MultistateReporter() def _allocate_variables( self, @@ -212,10 +244,6 @@ def _allocate_variables( A list of ThermodynamicState objects to be used in the sampler. sampler_states : List[SamplerState] A list of SamplerState objects for initializing the sampler. - unsampled_thermodynamic_states : Optional[List[ThermodynamicState]], optional - A list of additional ThermodynamicState objects that are not directly sampled but - for which energies will be computed for reweighting schemes. Defaults to None, - meaning no unsampled states are considered. Raises ------ @@ -225,40 +253,21 @@ def _allocate_variables( import copy import numpy as np - # Save thermodynamic states. This sets n_replicas. - self._thermodynamic_states = [ - copy.deepcopy(thermodynamic_state) - for thermodynamic_state in thermodynamic_states - ] - - # Deep copy sampler states. - self._sampler_states = [ - copy.deepcopy(sampler_state) for sampler_state in sampler_states - ] - + self._thermodynamic_states = copy.deepcopy(thermodynamic_states) + self._sampler_states = sampler_states assert len(self._thermodynamic_states) == len(self._sampler_states) - # Set initial thermodynamic state indices - initial_thermodynamic_states = np.arange( - len(self._thermodynamic_states), dtype=int - ) - self._replica_thermodynamic_states = np.array( - initial_thermodynamic_states, np.int64 + self._replica_thermodynamic_states = np.arange( + len(thermodynamic_states), dtype=int ) - # Reset statistics. - - # _n_accepted_matrix[i][j] is the number of swaps proposed between thermodynamic states i and j. - # _n_proposed_matrix[i][j] is the number of swaps proposed between thermodynamic states i and j. - # Allocate memory for energy matrix. energy_thermodynamic_states[k][l] - # is the reduced potential computed at the positions of SamplerState sampler_states[k] - # and ThermodynamicState thermodynamic_states[l]. - + # Initialize matrices for tracking acceptance and proposal statistics. self._n_accepted_matrix = np.zeros([self.n_states, self.n_states], np.int64) self._n_proposed_matrix = np.zeros([self.n_states, self.n_states], np.int64) self._energy_thermodynamic_states = np.zeros( [self.n_replicas, self.n_states], np.float64 ) self._traj = [[] for _ in range(self.n_replicas)] + # Ensure there is an MCMCMove for each thermodynamic state. from chiron.mcmc import MCMCMove @@ -288,10 +297,9 @@ def _minimize_replica( replica_id : int The index of the replica to minimize. tolerance : unit.Quantity, optional - The energy tolerance to which the system should be minimized. - Defaults to 1.0 kilojoules/mole/nanometers. + The energy tolerance for minimization (default: 1.0 kJ/mol/nm). max_iterations : int, optional - The maximum number of minimization iterations. Defaults to 1000. + Maximum number of minimization iterations (default: 1000). Notes ----- @@ -301,7 +309,6 @@ def _minimize_replica( from chiron.minimze import minimize_energy from loguru import logger as log - # Retrieve thermodynamic and sampler states. thermodynamic_state = self._thermodynamic_states[ self._replica_thermodynamic_states[replica_id] ] @@ -371,94 +378,90 @@ def minimize( def _propagate_replica(self, replica_id: int): """ - Propagate the state of a single replica. - - This method applies the MCMC move to the replica to change its state - according to the specified thermodynamic state. + Propagate the state of a single replica using its assigned MCMC move. Parameters ---------- replica_id : int The index of the replica to propagate. + Raises ------ RuntimeError If an error occurs during the propagation of the replica. """ - # Retrieve thermodynamic, sampler states, and MCMC move of this replica. + thermodynamic_state_id = self._replica_thermodynamic_states[replica_id] sampler_state = self._sampler_states[replica_id] - thermodynamic_state = self._thermodynamic_states[thermodynamic_state_id] mcmc_move = self._mcmc_moves[thermodynamic_state_id] - # Apply MCMC move. + # Apply the MCMC move to the replica. mcmc_move.run(sampler_state, thermodynamic_state) + # Append the new state to the trajectory for analysis. self._traj[replica_id].append(sampler_state.x0) def _perform_swap_proposals(self): """ Perform swap proposals between replicas. - Placeholder method for replica swapping logic. Subclasses should - override this method with specific swapping algorithms. + This method should be overridden by subclasses to implement specific swapping algorithms. Returns ------- np.ndarray An array of updated thermodynamic state indices for each replica. """ - # Placeholder implementation, should be overridden by subclasses # For this example, we'll just return the current state indices return self._replica_thermodynamic_states def _mix_replicas(self) -> np.ndarray: """ - Propose and execute swaps between replicas. + Propose and execute swaps between replicas to enhance sampling efficiency. - This method is responsible for enhancing sampling efficiency by proposing - swaps between different thermodynamic states of the replicas. The actual - swapping algorithm depends on the specific subclass implementation. + This method handles the logic for proposing swaps between different thermodynamic states + of the replicas. The specifics of the swapping algorithm depend on subclass implementations. Returns ------- np.ndarray - An array of updated thermodynamic state indices for each replica. + An array of updated thermodynamic state indices for each replica after swapping. """ from loguru import logger as log log.debug("Mixing replicas (does nothing for MultiStateSampler)...") - # Reset storage to keep track of swap attempts this iteration. + # Reset swap attempt counters for this iteration. self._n_accepted_matrix[:, :] = 0 self._n_proposed_matrix[:, :] = 0 - # Perform replica mixing (swap proposals and acceptances) - # The actual swapping logic would depend on subclass implementations - # Here, we assume a placeholder implementation + # Perform the swap proposals and acceptances. new_replica_states = self._perform_swap_proposals() # Calculate swap acceptance statistics n_swaps_proposed = self._n_proposed_matrix.sum() n_swaps_accepted = self._n_accepted_matrix.sum() swap_fraction_accepted = 0.0 - if n_swaps_proposed > 0: - swap_fraction_accepted = n_swaps_accepted / n_swaps_proposed + swap_fraction_accepted = ( + n_swaps_accepted / n_swaps_proposed if n_swaps_proposed > 0 else 0.0 + ) log.debug( f"Accepted {n_swaps_accepted}/{n_swaps_proposed} attempted swaps ({swap_fraction_accepted * 100.0:.1f}%)" ) + return new_replica_states def _propagate_replicas(self) -> None: """ Propagate all replicas through their respective MCMC moves. - This method iterates over all replicas and applies the corresponding MCMC move - to each one, based on its current thermodynamic state. + This method applies the corresponding MCMC move to each replica based on its + current thermodynamic state, thus advancing the state of each replica. """ from loguru import logger as log log.debug("Propagating all replicas...") + # Iterate over all replicas and propagate each one. for replica_id in range(self.n_replicas): self._propagate_replica(replica_id) @@ -466,9 +469,8 @@ def _compute_energies(self) -> None: """ Compute the energies of all replicas at all thermodynamic states. - This method calculates the energy for each replica in every thermodynamic state, - considering the defined neighborhoods to optimize the computation. The energies - are stored in the internal energy matrix of the sampler. + This method calculates the energy for each replica in every thermodynamic state. + The energies are stored in the internal energy matrix of the sampler. """ from loguru import logger as log @@ -476,9 +478,8 @@ def _compute_energies(self) -> None: # Initialize the energy matrix and neighborhoods self._energy_thermodynamic_states = np.zeros((self.n_replicas, self.n_states)) - # Calculate energies for each replica + # Calculate and store energies for each replica. for replica_id in range(self.n_replicas): - # Compute and store energies for the neighborhood states self._energy_thermodynamic_states[ replica_id, : ] = self._compute_replica_energies(replica_id) @@ -487,20 +488,18 @@ def _is_completed(self, iteration_limit: Optional[int] = None) -> bool: """ Determine if the sampling process has met its completion criteria. - This method checks if the simulation has reached a specified iteration limit - or any other predefined stopping condition. + Checks if the simulation has reached a specified iteration limit or any other + predefined stopping condition. Parameters ---------- iteration_limit : Optional[int], default=None - An optional iteration limit. If specified, the method checks if the - current iteration number has reached this limit. + An optional iteration limit to check against the current iteration number. Returns ------- bool - True if the simulation has completed based on the stopping criteria, - False otherwise. + True if the simulation has completed based on the stopping criteria, False otherwise. """ from loguru import logger as log @@ -517,115 +516,167 @@ def _is_completed(self, iteration_limit: Optional[int] = None) -> bool: def run(self, n_iterations: int = 10) -> None: """ - Execute the replica-exchange simulation. + Execute the replica-exchange simulation for a specified number of iterations. - Run the simulation for a specified number of iterations. If no number is - specified, it runs for the number of iterations set during the initialization - of the sampler. + Runs the simulation, performing replica propagation, mixing, and energy computation + for the specified number of iterations. Parameters ---------- n_iterations : int, default=10 - The number of iterations to run. - - Raises - ------ - RuntimeError - If an error occurs during the computation of energies. + The number of iterations to run the simulation. """ from loguru import logger as log - # If this is the first iteration, compute and store the - # starting energies of the minimized/equilibrated structures. - self.number_of_iterations = n_iterations - log.info("Running simulation...") - self._energy_thermodynamic_states_for_each_iteration_in_run = np.zeros( - [self.n_replicas, self.n_states, n_iterations + 1], np.float64 - ) - # Initialize energies if this is the first iteration + self.number_of_iterations = n_iterations + if self._iteration == 0: + # Initialize energies if this is the first iteration self._compute_energies() - # store energies for mbar analysis - self._energy_thermodynamic_states_for_each_iteration_in_run[ - :, :, self._iteration - ] = self._energy_thermodynamic_states - # TODO report energies - - iteration_limit = n_iterations + self._report_iteration() # start the sampling loop - log.debug(f"{iteration_limit=}") - while not self._is_completed(iteration_limit): - # Increment iteration counter. + while not self._is_completed(n_iterations): self._iteration += 1 - log.info("-" * 80) - log.info(f"Iteration {self._iteration}/{iteration_limit}") + log.info(f"Iteration {self._iteration}/{n_iterations}") log.info("-" * 80) - # Update thermodynamic states self._mix_replicas() - - # Propagate replicas. self._propagate_replicas() - - # Compute energies of all replicas at all states self._compute_energies() + self._report_iteration() + self._update_analysis() - # Add energies to the energy matrix - self._energy_thermodynamic_states_for_each_iteration_in_run[ - :, :, self._iteration - ] = self._energy_thermodynamic_states - # Write iteration to storage file - # TODO - # self._report_iteration() + self._reporter.flush_buffer() - # Update analysis - self._update_analysis() + def _report_energy_matrix(self): + """ + Report the energy matrix for each thermodynamic state. - def _report_iteration(self): - """Store positions, states, and energies of current iteration.""" + This method logs the energy per thermodynamic state, which is useful for analysis + and debugging purposes. + """ + from loguru import logger as log - # TODO: write energies + log.debug("Reporting energy per thermodynamic state...") + # NOTE: self._energy_thermodynamic_states is transposed from + # shape (n_replicas, n_states) to (n_states, n_replicas) + return {"u_kn": self._energy_thermodynamic_states.T} - # TODO: write trajectory + def _report_positions(self): + """ + Store and report the positions of all replicas at the current iteration. - # TODO: write mixing statistics + This method compiles and reports the position data for each replica, which + is critical for trajectory analysis. + """ + from loguru import logger as log - def _update_analysis(self): - """Update analysis of free energies""" + log.debug("Reporting positions...") + # numpy array with shape (n_replicas, n_atoms, 3) + xyz = np.zeros((self.n_replicas, self._sampler_states[0].x0.shape[0], 3)) + for replica_id in range(self.n_replicas): + xyz[replica_id] = self._sampler_states[replica_id].x0 + return {"positions": xyz} + + def _report(self, property: str) -> None: + """ + Report a specific property of the simulation. + + Depending on the specified property, this method delegates to the appropriate + internal reporting method. + + Parameters + ---------- + property : str + The property to report. Can be 'positions', 'states', 'energies', + 'trajectory', 'mixing_statistics', or 'all'. + """ from loguru import logger as log - if self._online_analysis_interval is None: - log.debug("No online analysis requested") - # Perform no analysis and exit function - return + log.debug(f"Reporting {property}...") + match property: + case "positions": + return self._report_positions() + case "states": + pass + case "u_kn": + return self._report_energy_matrix() + case "trajectory": + return + case "mixing_statistics": + return - # Perform offline free energy estimate if requested - if self.free_energy_estimator == "mbar": - self._last_err_free_energy = self._mbar_analysis() + def _report_iteration(self): + """ + Store and report various properties of the current iteration. - return + This method is called at each iteration to report essential simulation data, + such as positions, states, energies, and other properties defined in the reporter. + """ + from loguru import logger as log + + log.debug("Reporting data for current iteration...") + log.debug(self._reporter.properties_to_report) + prop = {} + for property in self._reporter.properties_to_report: + p = self._report(property) + if p: + prop.update(p) + self._reporter.report(prop) - def _mbar_analysis(self): + def _update_analysis(self): """ - Perform mbar analysis + Update the analysis of free energies based on the current simulation data. + + This method is responsible for updating the free energy estimates, either using + online or offline estimation methods, as configured in the sampler. """ - from pymbar import MBAR from loguru import logger as log - self._last_mbar_f_k_offline = np.zeros(len(self._thermodynamic_states)) + log.debug("Updating free energy analysis...") - log.debug( - f"{self._energy_thermodynamic_states_for_each_iteration_in_run.shape=}" - ) - log.debug(f"{self.n_states=}") - u_kn = self._energy_thermodynamic_states_for_each_iteration_in_run - log.debug(f"{self._iteration=}") - N_k = [self._iteration] * self.n_states - log.debug(f"{N_k=}") - mbar = MBAR(u_kn=u_kn, N_k=N_k) - log.debug(mbar.f_k) - self._last_mbar_f_k_offline = mbar.f_k + # Perform offline free energy estimate if requested + if self._offline_estimator: + log.debug("Performing offline free energy estimate...") + N_k = [self._iteration] * self.n_states + u_kn = self._reporter.get_property("u_kn") + self._offline_estimator.initialize( + u_kn=u_kn, + N_k=N_k, + ) + log.debug(self._offline_estimator.f_k) + elif self._online_estimator: + log.debug("Performing online free energy estimate...") + self._online_estimator.update() + else: + raise RuntimeError("No free energy estimator provided.") + + @property + def f_k(self) -> np.ndarray: + """ + Get the current free energy estimates. + + Returns the free energy estimates calculated by the sampler's free energy estimator. + The specific estimator used (online or offline) depends on the sampler configuration. + + Returns + ------- + np.ndarray + Array of free energy estimates for each thermodynamic state. + + Raises + ------ + RuntimeError + If no free energy estimator is found. + """ + + if self._offline_estimator: + return self._offline_estimator.f_k + elif self._online_estimator: + return self._online_estimator.f_k + else: + raise RuntimeError("No free energy estimator found.") diff --git a/chiron/reporters.py b/chiron/reporters.py index b2fae35..27457a6 100644 --- a/chiron/reporters.py +++ b/chiron/reporters.py @@ -5,31 +5,85 @@ import numpy as np from openmm.app import Topology +from typing import List -class SimulationReporter: - def __init__(self, filename: str, topology: Topology, buffer_size: int = 1): +class BaseReporter: + _directory = None + + @classmethod + def set_directory(cls, directory: str): """ - Initialize the SimulationReporter. + Set the base directory for saving reporter files. Parameters ---------- - filename : str - Name of the HDF5 file to write the simulation data. - topology: openmm.Topology - buffer_size : int, optional - Number of data points to buffer before writing to disk (default is 1). + directory : str + The path to the directory where files will be saved. + """ + cls._directory = directory + + @classmethod + def get_directory(cls): + """ + Get the current directory set for saving reporter files. + Returns + ------- + Path + The path to the directory where files will be saved. Defaults to the + current working directory if no directory has been set. + """ + from pathlib import Path + + if cls._directory is None: + log.debug( + f"No directory set, using current working directory: {Path.cwd()}" + ) + return Path.cwd() + return Path(cls._directory) + + +class _SimulationReporter: + def __init__(self, file_name: str, buffer_size: int = 10): + """ + Initialize the _SimulationReporter class. + + Parameters + ---------- + file_name : str + Name of the HDF5 file for writing simulation data. + buffer_size : int, optional + The size of the buffer before flushing data to disk (default is 10). """ - self.filename = filename + workdir = BaseReporter.get_directory() + self.file_path_base = workdir / f"{file_name}" + self.log_file_path = self.file_path_base.with_suffix(".h5") + self.workdir = workdir + self.report_iteration = 0 + import os + + os.makedirs(workdir, exist_ok=True) + + log.info(f"Writing simulation log data to {self.log_file_path}") + self.buffer_size = buffer_size - self.topology = topology self.buffer = {} - self.h5file = h5py.File(filename, "a") - log.info(f"Writing simulation data to {filename}") + + @property + def properties_to_report(self): + return self._default_properties + + @properties_to_report.setter + def properties_to_report(self, properties: List[str]): + self._default_properties = properties def get_available_keys(self): - return self.h5file.keys() + keys = [] + with h5py.File(self.log_file_path, "r") as h5file: + for key in h5file: + keys.append(key) + return keys def report(self, data_dict): """ @@ -40,17 +94,26 @@ def report(self, data_dict): data_dict : dict Dictionary containing data to report. Keys are data labels (e.g., 'energy'), and values are the data points (usually numpy arrays). - """ for key, value in data_dict.items(): if key not in self.buffer: + # new key shouldn't trigger a flush self.buffer[key] = [] self.buffer[key].append(value) - if len(self.buffer[key]) >= self.buffer_size: - self._write_to_disk(key) + self._flush_buffer_if_necessary() - def _write_to_disk(self, key:str): + def _flush_buffer_if_necessary(self): + """ + Flush the buffer to disk if it reaches the specified buffer size. + """ + # NOTE: we assume that every property is updated with the same frequency! + if all(len(self.buffer[key]) > self.buffer_size for key in self.buffer): + # flush and reset the buffer + log.debug(self.buffer) + self.flush_buffer() + + def _write_to_disk(self, key: str): """ Write buffered data of a given key to the HDF5 file. @@ -60,51 +123,222 @@ def _write_to_disk(self, key:str): The key of the data to write to disk. """ - data = np.array(self.buffer[key]) - if key in self.h5file: - dset = self.h5file[key] - dset.resize((dset.shape[0] + data.shape[0],) + data.shape[1:]) - dset[-data.shape[0] :] = data - else: - log.debug(f"Creating {key} in {self.filename}") - self.h5file.create_dataset( - key, data=data, maxshape=(None,) + data.shape[1:], chunks=True + log.debug(f"Writing {key} to file") + if key == "positions" and hasattr(self, "_write_to_trajectory"): + xyz = np.stack(self.buffer[key]) + self._write_to_trajectory( + positions=xyz, ) - self.buffer[key] = [] + with h5py.File(self.log_file_path, "a") as h5file: + if key in h5file: + data = np.array(self.buffer[key]) + dset = h5file[key] + dset.resize((dset.shape[0] + data.shape[0],) + data.shape[1:]) + dset[-data.shape[0] :] = data + else: + data = np.array(self.buffer[key]) + log.debug(f"Creating {key} in {self.log_file_path}") + h5file.create_dataset( + key, data=data, maxshape=(None,) + data.shape[1:], chunks=True + ) + + def reset_reporter_file(self): + # delete the reporter files + import os + + # if file exists, delete it + if os.path.exists(self.log_file_path): + log.debug(f"Deleting {self.log_file_path}") + os.remove(self.log_file_path) - def close(self): + def flush_buffer(self) -> None: """ - Write any remaining data in the buffer to disk and close the HDF5 file. + Write any remaining data in the buffer to disk. """ for key in self.buffer: if self.buffer[key]: self._write_to_disk(key) - self.h5file.close() + self._reset_buffer() - def get_property(self, name: str): + def _reset_buffer(self) -> None: """ - Get the property from the HDF5 file. + Reset the data buffer after writing to disk. + """ + self.buffer = {key: [] for key in self.buffer} + + def get_property(self, name: str) -> np.ndarray: + """ + Retrieve a specific property from the HDF5 file. Parameters ---------- name : str - Name of the property to get. + The name of the property to retrieve. Returns ------- np.ndarray - The property. + The retrieved property data, if available. + """ + if name == "positions" and hasattr(self, "read_from_trajectory"): + return self.read_from_trajectory() + + with h5py.File(self.log_file_path, "r") as h5file: + if name in h5file: + data = np.array(h5file[name]) + elif name in self.buffer and name not in h5file: + data = np.array(self.buffer[name]) + elif name not in h5file: + log.warning(f"{name} not in HDF5 file") + return None + + if name == "u_kn": + return np.transpose( + data, (2, 1, 0) + ) # shape: n_states, n_replicas, n_iterations + + else: + return data + + +from typing import Optional +import mdtraj as md + + +class MultistateReporter(_SimulationReporter): + _name = "multistate_reporter" + _default_properties = [ + "positions", + "box_vectors", + "u_kn", + "state_index", + "step", + ] + + def __init__( + self, + file_name: Optional[str] = None, + buffer_size: int = 1, + ) -> None: + """ + Initialize the MultistateReporter class. + + Parameters + ---------- + file_name : Optional[str], optional + Name of the file for storing multistate simulation data. If None, a + default name based on the reporter name is used. + buffer_size : int, optional + The size of the buffer before flushing data to disk (default is 1). + """ + + if file_name is None: + file_name = MultistateReporter.get_name() + + super().__init__(file_name=file_name, buffer_size=buffer_size) + self._replica_reporter = {} + + @classmethod + def get_name(cls): + return cls._name + + def _write_to_trajectory(self, positions: np.ndarray) -> None: + nr_of_frames, n_replicas, n_of_atoms, _ = positions.shape + + for replica_id in range(n_replicas): + # if file does not exist, create it + key = f"replica_{replica_id}" + if self._replica_reporter.get(key) is None: + self._replica_reporter[key] = LangevinDynamicsReporter(key) + + reporter = self._replica_reporter.get(key) + + for frame_id in range(nr_of_frames): + data = {"positions": positions[frame_id, replica_id]} + if self.buffer.get("box_vectors") is not None: + data["box_vectors"] = self.buffer.get("box_vectors")[frame_id] + reporter.report(data) + + def flush_buffer(self): + for reporter in self._replica_reporter.values(): + reporter.flush_buffer() + reporter._write_xtc_file_handle.flush() + + return super().flush_buffer() + + +from typing import Optional + +class MCReporter(_SimulationReporter): + _name = "mc_reporter" + + def __init__(self, file_name: Optional[str] = None, buffer_size: int = 1) -> None: + """ + Initialize the MCReporter class for Monte Carlo simulations. + + Parameters + ---------- + file_name : Optional[str], optional + The file name for storing simulation data. + buffer_size : int, optional + The size of the buffer before flushing data to disk. + """ + if file_name is None: + file_name = MCReporter.get_name() + + super().__init__(file_name=file_name, buffer_size=buffer_size) + + @classmethod + def get_name(cls): + return cls._name + + +class LangevinDynamicsReporter(_SimulationReporter): + _name = "langevin_reporter" + _default_properties = ["positions", "box_vectors", "potential_energy", "step"] + + def __init__( + self, + file_name: Optional[str] = None, + buffer_size: int = 1, + topology: Optional[Topology] = None, + ): + """ + Initialize the LangevinDynamicsReporter for Langevin dynamics simulations. + + Parameters + ---------- + file_name : Optional[str], optional + The file name for storing simulation data. + buffer_size : int, optional + The size of the buffer before flushing data to disk. + topology : Optional[Topology], optional + The system topology for generating trajectories. + """ + if file_name is None: + file_name = LangevinDynamicsReporter.get_name() + + super().__init__(file_name=file_name, buffer_size=buffer_size) + self.topology = topology + self._write_xtc_file_handle = None + self.xtc_file_path = f"{self.file_path_base}.xtc" + + @classmethod + def get_name(cls): + return cls._name + + def get_mdtraj_trajectory(self) -> md.Trajectory: """ - if name not in self.h5file: - log.debug(f"{name} not in HDF5 file") - return None - else: - return np.array(self.h5file[name]) + Generate an MDTraj trajectory object from the stored positions. - def get_mdtraj_trajectory(self): + Returns + ------- + md.Trajectory + The MDTraj trajectory object created from the stored position data. + """ import mdtraj as md return md.Trajectory( @@ -113,3 +347,86 @@ def get_mdtraj_trajectory(self): unitcell_lengths=self.get_property("box_vectors"), unitcell_angles=self.get_property("box_angles"), ) + + def _write_to_trajectory(self, positions: np.ndarray) -> None: + """ + Write position data to a trajectory file for molecular dynamics. + + Parameters + ---------- + positions : np.ndarray + The positions of particles to be written to the trajectory. + """ + if self._write_xtc_file_handle is None: + log.debug(f"Creating trajectory in {self.xtc_file_path}") + self._write_xtc_file_handle = md.formats.XTCTrajectoryFile( + self.xtc_file_path, mode="w" + ) + + LangevinDynamicsReporter._write_to_xtc( + file_handler=self._write_xtc_file_handle, + positions=positions, + iteration=self.buffer.get("step"), + box_vecotrs=self.buffer.get("box_vectors"), + ) + + def read_from_trajectory(self) -> np.ndarray: + """ + Read position data from a trajectory file. + + Returns + ------- + np.ndarray + The positions read from the trajectory file. + """ + # flush the write buffer + self._write_xtc_file_handle.flush() + with md.formats.XTCTrajectoryFile( + self.xtc_file_path, mode="r" + ) as _read_xtc_file_handle: + return LangevinDynamicsReporter._read_from_xtc(_read_xtc_file_handle) + + @classmethod + def _read_from_xtc(cls, file_handler) -> np.ndarray: + """ + Read data from an XTC file. + + Parameters + ---------- + file_handler : md.formats.XTCTrajectoryFile + The file handler for reading XTC files. + + Returns + ------- + np.ndarray + The data read from the XTC file. + """ + return file_handler.read() + + @classmethod + def _write_to_xtc( + cls, + file_handler: md.formats.XTCTrajectoryFile, + positions: np.ndarray, + iteration: np.ndarray, + box_vecotrs: Optional[np.ndarray] = None, + ): + """ + Write position data to an XTC file. + + Parameters + ---------- + file_handler : md.formats.XTCTrajectoryFile + The file handler for writing to XTC files. + positions : np.ndarray + The positions to be written. + iteration : np.ndarray + The iteration numbers corresponding to the positions. + box_vectors : Optional[np.ndarray], optional + Box vectors for each position frame. + """ + file_handler.write( + positions, + time=iteration, + box=box_vecotrs, + ) diff --git a/chiron/states.py b/chiron/states.py index 99ee8f9..99459ae 100644 --- a/chiron/states.py +++ b/chiron/states.py @@ -2,6 +2,7 @@ from typing import List, Optional, Union from jax import numpy as jnp from .potential import NeuralNetworkPotential +from jax import random class SamplerState: @@ -22,6 +23,7 @@ class SamplerState: def __init__( self, x0: unit.Quantity, + current_PRNG_key: random.PRNGKey, velocities: Optional[unit.Quantity] = None, box_vectors: Optional[unit.Quantity] = None, ) -> None: @@ -61,9 +63,16 @@ def __init__( raise ValueError( f"box_vectors must be a 3x3 array, got {box_vectors.shape} instead." ) + if velocities is not None and x0.shape != velocities.shape: + raise ValueError( + f"x0 and velocities must have the same shape, got {x0.shape} and {velocities.shape} instead." + ) + if current_PRNG_key is None: + raise ValueError(f"random_seed must be set.") self._x0 = x0 self._velocities = velocities + self._current_PRNG_key = current_PRNG_key self._box_vectors = box_vectors self._distance_unit = unit.nanometer @@ -98,6 +107,12 @@ def x0(self, x0: Union[jnp.array, unit.Quantity]) -> None: def distance_unit(self) -> unit.Unit: return self._distance_unit + @property + def new_PRNG_key(self) -> random.PRNGKey: + key, subkey = random.split(self._current_PRNG_key) + self._current_PRNG_key = key + return subkey + def _convert_to_jnp(self, array: unit.Quantity) -> jnp.array: """ Convert the sampler state to jnp arrays. diff --git a/chiron/tests/data/langevin_reporter.h5 b/chiron/tests/data/langevin_reporter.h5 new file mode 100644 index 0000000..e6e23c5 Binary files /dev/null and b/chiron/tests/data/langevin_reporter.h5 differ diff --git a/chiron/tests/test_convergence_tests.py b/chiron/tests/test_convergence_tests.py index 86c881c..16cfad3 100644 --- a/chiron/tests/test_convergence_tests.py +++ b/chiron/tests/test_convergence_tests.py @@ -48,11 +48,11 @@ def test_convergence_of_MC_estimator(prep_temp_dir): ) sampler_state = SamplerState(ho.positions) - from chiron.reporters import SimulationReporter + from chiron.reporters import _SimulationReporter id = uuid.uuid4() - simulation_reporter = SimulationReporter(f"{prep_temp_dir}/test_{id}.h5") + simulation_reporter = _SimulationReporter(f"{prep_temp_dir}/test_{id}.h5") # Initalize the move set (here only LangevinDynamicsMove) from chiron.mcmc import MetropolisDisplacementMove, MoveSchedule, MCMCSampler @@ -61,7 +61,7 @@ def test_convergence_of_MC_estimator(prep_temp_dir): nr_of_moves=100_000, displacement_sigma=0.5 * unit.angstrom, atom_subset=[0], - simulation_reporter=simulation_reporter, + reporter=simulation_reporter, ) move_set = MoveSchedule([("MetropolisDisplacementMove", mc_displacement_move)]) @@ -153,12 +153,12 @@ def test_langevin_dynamics_with_LJ_fluid(prep_temp_dir): potential=lj_potential, temperature=300 * unit.kelvin ) - from chiron.reporters import SimulationReporter + from chiron.reporters import _SimulationReporter id = uuid.uuid4() - reporter = SimulationReporter(f"{prep_temp_dir}/test_{id}.h5") + reporter = _SimulationReporter(f"{prep_temp_dir}/test_{id}.h5") - integrator = LangevinIntegrator(reporter=reporter, save_frequency=100) + integrator = LangevinIntegrator(reporter=reporter, report_frequency=100) integrator.run( sampler_state, thermodynamic_state, diff --git a/chiron/tests/test_integrators.py b/chiron/tests/test_integrators.py index 203b409..341c987 100644 --- a/chiron/tests/test_integrators.py +++ b/chiron/tests/test_integrators.py @@ -24,20 +24,28 @@ def test_langevin_dynamics(prep_temp_dir, provide_testsystems_and_potentials): # initialize states and integrator from chiron.integrators import LangevinIntegrator from chiron.states import SamplerState, ThermodynamicState + from chiron.utils import PRNG + + PRNG.set_seed(1234) thermodynamic_state = ThermodynamicState( potential=potential, temperature=300 * unit.kelvin ) - sampler_state = SamplerState(testsystem.positions) - from chiron.reporters import SimulationReporter + sampler_state = SamplerState(testsystem.positions, PRNG.get_random_key()) + + from chiron.reporters import LangevinDynamicsReporter + from chiron.reporters import BaseReporter + + # set up reporter directory + BaseReporter.set_directory(prep_temp_dir.join(f"test_{i}")) - reporter = SimulationReporter(f"{prep_temp_dir}/test{i}.h5", None, 1) + reporter = LangevinDynamicsReporter() - integrator = LangevinIntegrator(reporter=reporter) + integrator = LangevinIntegrator(reporter=reporter, report_frequency=1) integrator.run( sampler_state, thermodynamic_state, - n_steps=5, + n_steps=20, ) i = i + 1 diff --git a/chiron/tests/test_mcmc.py b/chiron/tests/test_mcmc.py index 63c3f92..21bf8b6 100644 --- a/chiron/tests/test_mcmc.py +++ b/chiron/tests/test_mcmc.py @@ -36,40 +36,47 @@ def test_sample_from_harmonic_osciallator(prep_temp_dir): thermodynamic_state = ThermodynamicState( potential=harmonic_potential, temperature=300 * unit.kelvin ) - sampler_state = SamplerState(x0=ho.positions) + from chiron.utils import PRNG + + PRNG.set_seed(1234) + + sampler_state = SamplerState( + x0=ho.positions, current_PRNG_key=PRNG.get_random_key() + ) from chiron.integrators import LangevinIntegrator - from chiron.reporters import SimulationReporter + from chiron.reporters import LangevinDynamicsReporter, BaseReporter id = uuid.uuid4() - h5_file = f"test_{id}.h5" - reporter = SimulationReporter(f"{prep_temp_dir}/{h5_file}", 1) + wd = prep_temp_dir.join(f"_test_{id}") + BaseReporter.set_directory(wd) + reporter = LangevinDynamicsReporter() integrator = LangevinIntegrator( - stepsize=0.2 * unit.femtosecond, reporter=reporter, save_frequency=1 + stepsize=2 * unit.femtosecond, reporter=reporter, report_frequency=1 ) - r = integrator.run( + integrator.run( sampler_state, thermodynamic_state, n_steps=5, ) - + integrator.reporter.flush_buffer() import jax.numpy as jnp import h5py - h5 = h5py.File(f"{prep_temp_dir}/{h5_file}", "r") + h5 = h5py.File(f"{wd}/{LangevinDynamicsReporter.get_name()}.h5", "r") keys = h5.keys() - assert "energy" in keys, "Energy not in keys" + assert "potential_energy" in keys, "Energy not in keys" assert "step" in keys, "Step not in keys" - assert "traj" in keys, "Traj not in keys" + assert "traj" not in keys, "Traj is not in keys" - energy = h5["energy"][:] + energy = h5["potential_energy"][:] print(energy) reference_energy = jnp.array( - [0.00019308, 0.00077772, 0.00174247, 0.00307798, 0.00479007] + [0.03551735, 0.1395877, 0.30911613, 0.5495938, 0.85149795] ) assert jnp.allclose(energy, reference_energy) @@ -99,23 +106,23 @@ def test_sample_from_harmonic_osciallator_with_MCMC_classes_and_LangevinDynamics # Initalize the sampler and thermodynamic state from chiron.states import ThermodynamicState, SamplerState + from chiron.utils import PRNG + PRNG.set_seed(1234) thermodynamic_state = ThermodynamicState( harmonic_potential, temperature=300 * unit.kelvin, volume=30 * (unit.angstrom**3), ) - sampler_state = SamplerState(ho.positions) + sampler_state = SamplerState(ho.positions, current_PRNG_key=PRNG.get_random_key()) # Initalize the move set (here only LangevinDynamicsMove) and reporter - from chiron.reporters import SimulationReporter + from chiron.reporters import LangevinDynamicsReporter, BaseReporter - simulation_reporter = SimulationReporter( - f"{prep_temp_dir}/test_{uuid.uuid4()}.h5", None, 1 - ) - langevin_move = LangevinDynamicsMove( - nr_of_steps=10, seed=1234, simulation_reporter=simulation_reporter - ) + BaseReporter.set_directory(prep_temp_dir) + + simulation_reporter = LangevinDynamicsReporter(1) + langevin_move = LangevinDynamicsMove(nr_of_steps=10, reporter=simulation_reporter) move_set = MoveSchedule([("LangevinMove", langevin_move)]) @@ -157,20 +164,23 @@ def test_sample_from_harmonic_osciallator_with_MCMC_classes_and_MetropolisDispla temperature=300 * unit.kelvin, volume=30 * (unit.angstrom**3), ) - sampler_state = SamplerState(ho.positions) + from chiron.utils import PRNG + + PRNG.set_seed(1234) + sampler_state = SamplerState(ho.positions, current_PRNG_key=PRNG.get_random_key()) # Initalize the move set and reporter - from chiron.reporters import SimulationReporter + from chiron.reporters import MCReporter, BaseReporter - simulation_reporter = SimulationReporter( - f"{prep_temp_dir}/test_{uuid.uuid4()}.h5", 1 - ) + wd = prep_temp_dir.join(f"_test_{uuid.uuid4()}") + BaseReporter.set_directory(wd) + simulation_reporter = MCReporter(1) mc_displacement_move = MetropolisDisplacementMove( nr_of_moves=10, displacement_sigma=0.1 * unit.angstrom, atom_subset=[0], - simulation_reporter=simulation_reporter, + reporter=simulation_reporter, ) move_set = MoveSchedule([("MetropolisDisplacementMove", mc_displacement_move)]) @@ -212,20 +222,25 @@ def test_sample_from_harmonic_osciallator_array_with_MCMC_classes_and_Metropolis temperature=300 * unit.kelvin, volume=30 * (unit.angstrom**3), ) - sampler_state = SamplerState(ho.positions) + + from chiron.utils import PRNG + + PRNG.set_seed(1234) + sampler_state = SamplerState(ho.positions, current_PRNG_key=PRNG.get_random_key()) # Initalize the move set and reporter - from chiron.reporters import SimulationReporter + from chiron.reporters import MCReporter, BaseReporter - simulation_reporter = SimulationReporter( - f"{prep_temp_dir}/test_{uuid.uuid4()}.h5", 1 - ) + wd = prep_temp_dir.join(f"_test_{uuid.uuid4()}") + BaseReporter.set_directory(wd) + + simulation_reporter = MCReporter(1) mc_displacement_move = MetropolisDisplacementMove( nr_of_moves=10, displacement_sigma=0.1 * unit.angstrom, atom_subset=None, - simulation_reporter=simulation_reporter, + reporter=simulation_reporter, ) move_set = MoveSchedule([("MetropolisDisplacementMove", mc_displacement_move)]) diff --git a/chiron/tests/test_minization.py b/chiron/tests/test_minization.py index 3d89854..cf0cf4e 100644 --- a/chiron/tests/test_minization.py +++ b/chiron/tests/test_minization.py @@ -16,8 +16,13 @@ def test_minimization(): cutoff = unit.Quantity(1.0, unit.nanometer) lj_potential = LJPotential(lj_fluid.topology, cutoff=cutoff) + from chiron.utils import PRNG + + PRNG.set_seed(1234) sampler_state = SamplerState( - lj_fluid.positions, box_vectors=lj_fluid.system.getDefaultPeriodicBoxVectors() + lj_fluid.positions, + current_PRNG_key=PRNG.get_random_key(), + box_vectors=lj_fluid.system.getDefaultPeriodicBoxVectors(), ) # use parilist nbr_list = PairList(OrthogonalPeriodicSpace(), cutoff=cutoff) @@ -92,10 +97,14 @@ def test_minimize_two_particles(): lj_potential = LJPotential(None, sigma=sigma, epsilon=epsilon, cutoff=cutoff) coordinates = jnp.array([[0.0, 0.0, 0.0], [0.9, 0.0, 0.0]]) + from chiron.utils import PRNG + + PRNG.set_seed(1234) # define the sampler state sampler_state = SamplerState( x0=coordinates * unit.nanometer, + current_PRNG_key=PRNG.get_random_key(), box_vectors=jnp.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) * unit.nanometer, ) diff --git a/chiron/tests/test_multistate.py b/chiron/tests/test_multistate.py index ce74e1d..ef31a84 100644 --- a/chiron/tests/test_multistate.py +++ b/chiron/tests/test_multistate.py @@ -14,6 +14,7 @@ def setup_sampler() -> Tuple[NeighborListNsqrd, MultiStateSampler]: from openmm import unit from chiron.mcmc import LangevinDynamicsMove from chiron.neighbors import NeighborListNsqrd, OrthogonalPeriodicSpace + from chiron.reporters import MultistateReporter, BaseReporter sigma = 0.34 * unit.nanometer cutoff = 3.0 * sigma @@ -23,9 +24,12 @@ def setup_sampler() -> Tuple[NeighborListNsqrd, MultiStateSampler]: OrthogonalPeriodicSpace(), cutoff=cutoff, skin=skin, n_max_neighbors=180 ) - move = LangevinDynamicsMove(stepsize=2.0 * unit.femtoseconds, nr_of_steps=500) + move = LangevinDynamicsMove(stepsize=1.0 * unit.femtoseconds, nr_of_steps=100) + BaseReporter.set_directory("multistate_test") + reporter = MultistateReporter() + reporter.reset_reporter_file() - multistate_sampler = MultiStateSampler(mcmc_moves=move) + multistate_sampler = MultiStateSampler(mcmc_moves=move, reporter=reporter) return nbr_list, multistate_sampler @@ -59,7 +63,11 @@ def ho_multistate_sampler_multiple_minima() -> MultiStateSampler: ) for x0 in x0s ] - sampler_state = [SamplerState(ho.positions) for _ in x0s] + from chiron.utils import PRNG + + PRNG.set_seed(1234) + + sampler_state = [SamplerState(ho.positions, PRNG.get_random_key()) for _ in x0s] nbr_list, multistate_sampler = setup_sampler() multistate_sampler.create( thermodynamic_states=thermodynamic_states, @@ -102,8 +110,14 @@ def ho_multistate_sampler_multiple_ks() -> MultiStateSampler: from loguru import logger as log log.info(f"Initialize harmonic oscillator with {n_states} states and ks {Ks}") + from chiron.utils import PRNG + + PRNG.set_seed(1234) - sampler_state = [SamplerState(ho.positions) for _ in sigmas] + sampler_state = [ + SamplerState(ho.positions, current_PRNG_key=PRNG.get_random_key()) + for _ in sigmas + ] import numpy as np f_i = np.array( @@ -198,7 +212,7 @@ def test_multistate_run(ho_multistate_sampler_multiple_ks: MultiStateSampler): print(f"Analytical free energy difference: {ho_sampler.delta_f_ij_analytical[0]}") - n_iteratinos = 25 + n_iteratinos = 250 ho_sampler.run(n_iteratinos) # check that we have the correct number of iterations, replicas and states @@ -207,11 +221,12 @@ def test_multistate_run(ho_multistate_sampler_multiple_ks: MultiStateSampler): assert ho_sampler.n_replicas == 4 assert ho_sampler.n_states == 4 + u_kn = ho_sampler.reporter.get_property("u_kn") + assert u_kn.shape == (n_iteratinos, 4, 4) # check that the free energies are correct print(ho_sampler.analytical_f_i) + # [ 0. , -0.28593054, -0.54696467, -0.78709279] print(ho_sampler.delta_f_ij_analytical) - print(ho_sampler._last_mbar_f_k_offline) + print(ho_sampler.f_k) - assert np.allclose( - ho_sampler.delta_f_ij_analytical[0], ho_sampler._last_mbar_f_k_offline, atol=0.1 - ) + assert np.allclose(ho_sampler.delta_f_ij_analytical[0], ho_sampler.f_k, atol=0.1) diff --git a/chiron/tests/test_pairs.py b/chiron/tests/test_pairs.py index 4802bb6..fb2bf2c 100644 --- a/chiron/tests/test_pairs.py +++ b/chiron/tests/test_pairs.py @@ -95,8 +95,13 @@ def test_neighborlist_pair(): coordinates = jnp.array([[0, 0, 0], [1, 0, 0]]) box_vectors = jnp.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) + from chiron.utils import PRNG + + PRNG.set_seed(1234) + state = SamplerState( x0=unit.Quantity(coordinates, unit.nanometer), + current_PRNG_key=PRNG.get_random_key(), box_vectors=unit.Quantity(box_vectors, unit.nanometer), ) @@ -203,8 +208,14 @@ def test_inputs(): nbr_list.build_from_state(123) coordinates = jnp.array([[1, 2, 3], [0, 0, 0]]) + from chiron.utils import PRNG + + PRNG.set_seed(1234) + state = SamplerState( - x0=unit.Quantity(coordinates, unit.nanometer), box_vectors=None + x0=unit.Quantity(coordinates, unit.nanometer), + current_PRNG_key=PRNG.get_random_key(), + box_vectors=None, ) # check that boxvectors are defined in the state @@ -271,8 +282,13 @@ def test_neighborlist_pair_multiple_particles(): coordinates = jnp.stack(coord_mesh.reshape(3, -1), axis=1, dtype=jnp.float32) box_vectors = jnp.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) + from chiron.utils import PRNG + + PRNG.set_seed(1234) + state = SamplerState( x0=unit.Quantity(coordinates, unit.nanometer), + current_PRNG_key=PRNG.get_random_key(), box_vectors=unit.Quantity(box_vectors, unit.nanometer), ) @@ -341,8 +357,13 @@ def test_pairlist_pair(): coordinates = jnp.array([[0, 0, 0], [1, 0, 0]]) box_vectors = jnp.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) + from chiron.utils import PRNG + + PRNG.set_seed(1234) + state = SamplerState( x0=unit.Quantity(coordinates, unit.nanometer), + current_PRNG_key=PRNG.get_random_key(), box_vectors=unit.Quantity(box_vectors, unit.nanometer), ) @@ -390,8 +411,13 @@ def test_pair_list_multiple_particles(): coordinates = jnp.stack(coord_mesh.reshape(3, -1), axis=1, dtype=jnp.float32) box_vectors = jnp.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) + from chiron.utils import PRNG + + PRNG.set_seed(1234) + state = SamplerState( x0=unit.Quantity(coordinates, unit.nanometer), + current_PRNG_key=PRNG.get_random_key(), box_vectors=unit.Quantity(box_vectors, unit.nanometer), ) diff --git a/chiron/tests/test_potential.py b/chiron/tests/test_potential.py index 8eb97fd..230fa1c 100644 --- a/chiron/tests/test_potential.py +++ b/chiron/tests/test_potential.py @@ -33,7 +33,9 @@ def test_neural_network_pairlist(): cutoffs = [0.2, 0.1] expected_pairs = [(1, 1), (0, 0)] for cutoff, expected in zip(cutoffs, expected_pairs): - distances, displacement_vectors, pairlist = nn_potential.compute_pairlist(positions, cutoff) + distances, displacement_vectors, pairlist = nn_potential.compute_pairlist( + positions, cutoff + ) assert pairlist[0].size == expected[0] and pairlist[1].size == expected[1] # Test with ethanol molecule @@ -44,7 +46,9 @@ def test_neural_network_pairlist(): # Test compute_pairlist method cutoff = 0.2 - distances, displacement_vectors, pairlist = nn_potential.compute_pairlist(positions, cutoff) + distances, displacement_vectors, pairlist = nn_potential.compute_pairlist( + positions, cutoff + ) print(pairlist) assert ( pairlist[0].size == 12 and pairlist[1].size == 12 @@ -92,8 +96,9 @@ def test_harmonic_oscillator_potential(): forces = harmonic_potential.compute_force(positions_without_unit) assert forces.shape == positions_without_unit.shape, "Forces shape mismatch." + def test_harmonic_oscillator_input_checking(): - #topology check + # topology check with pytest.raises(TypeError): HarmonicOscillatorPotential(1) with pytest.raises(TypeError): @@ -104,14 +109,15 @@ def test_harmonic_oscillator_input_checking(): HarmonicOscillatorPotential(None, U0=1.0) with pytest.raises(ValueError): - HarmonicOscillatorPotential(None, k=1.0*unit.nanometer) + HarmonicOscillatorPotential(None, k=1.0 * unit.nanometer) with pytest.raises(ValueError): - HarmonicOscillatorPotential(None, x0=1.0*unit.kilocalories_per_mole) + HarmonicOscillatorPotential(None, x0=1.0 * unit.kilocalories_per_mole) with pytest.raises(ValueError): - HarmonicOscillatorPotential(None, U0=1.0*unit.nanometer) + HarmonicOscillatorPotential(None, U0=1.0 * unit.nanometer) + def test_lj_input_checking(): - #topology check + # topology check with pytest.raises(TypeError): LJPotential(1) with pytest.raises(TypeError): @@ -122,28 +128,30 @@ def test_lj_input_checking(): LJPotential(None, cutoff=1.0) with pytest.raises(ValueError): - LJPotential(None, sigma=1.0*unit.kilocalories_per_mole) + LJPotential(None, sigma=1.0 * unit.kilocalories_per_mole) with pytest.raises(ValueError): - LJPotential(None, epsilon=1.0*unit.nanometer) + LJPotential(None, epsilon=1.0 * unit.nanometer) with pytest.raises(ValueError): - LJPotential(None, cutoff=1.0*unit.kilocalories_per_mole) + LJPotential(None, cutoff=1.0 * unit.kilocalories_per_mole) from chiron.neighbors import NeighborListNsqrd, OrthogonalPeriodicSpace + positions = jnp.array([[0, 0, 0], [1, 0, 0]]) box_vectors = jnp.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) - lj = LJPotential(None, sigma=1.0*unit.nanometer) - nbr_list = NeighborListNsqrd(OrthogonalPeriodicSpace(), cutoff=2.0*unit.nanometer) + lj = LJPotential(None, sigma=1.0 * unit.nanometer) + nbr_list = NeighborListNsqrd(OrthogonalPeriodicSpace(), cutoff=2.0 * unit.nanometer) - #capture the error associated with not building the neighborlist + # capture the error associated with not building the neighborlist with pytest.raises(ValueError): lj.compute_energy(positions, nbr_list) nbr_list.build(positions, box_vectors) - #capture the error associated cutoffs not matching + # capture the error associated cutoffs not matching with pytest.raises(ValueError): lj.compute_energy(positions, nbr_list) + def test_lennard_jones(): # This will evaluate two LJ particles to ensure the energy and force are correct from chiron.neighbors import NeighborListNsqrd, OrthogonalPeriodicSpace @@ -156,27 +164,46 @@ def test_lennard_jones(): box_vectors = jnp.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) space = OrthogonalPeriodicSpace() - lj_pot = LJPotential(None, unit.Quantity(sigma, unit.nanometer), unit.Quantity(epsilon, unit.kilojoules_per_mole), - unit.Quantity(cutoff, unit.nanometer)) + lj_pot = LJPotential( + None, + unit.Quantity(sigma, unit.nanometer), + unit.Quantity(epsilon, unit.kilojoules_per_mole), + unit.Quantity(cutoff, unit.nanometer), + ) + from chiron.utils import PRNG + + PRNG.set_seed(1234) for i in range(1, 11): positions = jnp.array([[0, 0, 0], [i * 0.25 * 2 ** (1 / 6), 0, 0]]) - state = SamplerState(x0=unit.Quantity(positions, unit.nanometer), box_vectors = unit.Quantity(box_vectors, - unit.nanometer)) - nbr_list = NeighborListNsqrd(space, cutoff = unit.Quantity(cutoff, unit.nanometer), skin=unit.Quantity(skin, unit.nanometer), n_max_neighbors=5) + state = SamplerState( + x0=unit.Quantity(positions, unit.nanometer), + current_PRNG_key=PRNG.get_random_key(), + box_vectors=unit.Quantity(box_vectors, unit.nanometer), + ) + nbr_list = NeighborListNsqrd( + space, + cutoff=unit.Quantity(cutoff, unit.nanometer), + skin=unit.Quantity(skin, unit.nanometer), + n_max_neighbors=5, + ) nbr_list.build_from_state(state) # first use the pairlist energy_chiron = lj_pot.compute_energy(positions) energy_chiron_nbr = lj_pot.compute_energy(positions, nbr_list) - displacement_vector = positions[0]-positions[1] + displacement_vector = positions[0] - positions[1] dist = jnp.linalg.norm(displacement_vector) - energy_analytical = 4.0*epsilon*((sigma/dist)**12-(sigma/dist)**6) + energy_analytical = 4.0 * epsilon * ((sigma / dist) ** 12 - (sigma / dist) ** 6) - assert jnp.isclose(energy_chiron, energy_analytical), "Energy from chiron using a pair list does not match the analytical energy calculation" - assert jnp.isclose(energy_chiron_nbr, energy_analytical), "Energy from chiron using a neighbor list does not match the analytical energy calculation" + assert jnp.isclose( + energy_chiron, energy_analytical + ), "Energy from chiron using a pair list does not match the analytical energy calculation" + assert jnp.isclose( + energy_chiron_nbr, energy_analytical + ), "Energy from chiron using a neighbor list does not match the analytical energy calculation" force_chiron = lj_pot.compute_force(positions) force_chiron_nbr = lj_pot.compute_force(positions, nbr_list) @@ -185,14 +212,19 @@ def test_lennard_jones(): force_chiron_analytical = lj_pot.compute_force_analytical(positions) force = ( - 24 - * (epsilon / (dist * dist)) - * (2 * (sigma / dist) ** 12 - (sigma / dist) ** 6) - ) * displacement_vector + 24 + * (epsilon / (dist * dist)) + * (2 * (sigma / dist) ** 12 - (sigma / dist) ** 6) + ) * displacement_vector forces_analytical = jnp.array([force, -force]) - assert jnp.allclose(force_chiron, forces_analytical, atol=1e-5), "Force from chiron using pair list does not match analytical force" - assert jnp.allclose(force_chiron_nbr, forces_analytical, atol=1e-5), "Force from chiron using neighbor list does not match analytical force" - assert jnp.allclose(force_chiron_analytical, forces_analytical, atol=1e-5), "Force from chiron analytical using pair list does not match analytical force" - + assert jnp.allclose( + force_chiron, forces_analytical, atol=1e-5 + ), "Force from chiron using pair list does not match analytical force" + assert jnp.allclose( + force_chiron_nbr, forces_analytical, atol=1e-5 + ), "Force from chiron using neighbor list does not match analytical force" + assert jnp.allclose( + force_chiron_analytical, forces_analytical, atol=1e-5 + ), "Force from chiron analytical using pair list does not match analytical force" diff --git a/chiron/tests/test_states.py b/chiron/tests/test_states.py index d94869a..4499f18 100644 --- a/chiron/tests/test_states.py +++ b/chiron/tests/test_states.py @@ -24,8 +24,11 @@ def test_initialize_state(): assert state.pressure is None assert state.volume == 30 * (unit.angstrom**3) assert state.nr_of_particles == 1 + from chiron.utils import PRNG - sampler_state = SamplerState(ho.positions) + PRNG.set_seed(1234) + + sampler_state = SamplerState(ho.positions, current_PRNG_key=PRNG.get_random_key()) assert jnp.allclose( sampler_state.x0, @@ -41,9 +44,13 @@ def test_sampler_state_conversion(): from chiron.states import SamplerState from openmm import unit import jax.numpy as jnp + from chiron.utils import PRNG + + PRNG.set_seed(1234) sampler_state = SamplerState( - unit.Quantity(jnp.array([[10.0, 10.0, 10.0]]), unit.nanometer) + unit.Quantity(jnp.array([[10.0, 10.0, 10.0]]), unit.nanometer), + current_PRNG_key=PRNG.get_random_key(), ) assert jnp.allclose( @@ -52,7 +59,8 @@ def test_sampler_state_conversion(): ) sampler_state = SamplerState( - unit.Quantity(jnp.array([[10.0, 10.0, 10.0]]), unit.angstrom) + unit.Quantity(jnp.array([[10.0, 10.0, 10.0]]), unit.angstrom), + current_PRNG_key=PRNG.get_random_key(), ) assert jnp.allclose( @@ -66,6 +74,9 @@ def test_sampler_state_inputs(): from openmm import unit import jax.numpy as jnp import pytest + from chiron.utils import PRNG + + PRNG.set_seed(1234) # test input of positions # should have units @@ -73,19 +84,24 @@ def test_sampler_state_inputs(): SamplerState(x0=jnp.array([1, 2, 3])) # throw and error because of incompatible units with pytest.raises(ValueError): - SamplerState(x0=unit.Quantity(jnp.array([[1, 2, 3]]), unit.radians)) + SamplerState( + x0=unit.Quantity(jnp.array([[1, 2, 3]]), unit.radians), + current_PRNG_key=PRNG.get_random_key(), + ) # test input of velocities # velocities should have units with pytest.raises(TypeError): SamplerState( x0=unit.Quantity(jnp.array([[1, 2, 3]]), unit.nanometers), + current_PRNG_key=PRNG.get_random_key(), velocities=jnp.array([1, 2, 3]), ) # velocities should have units of distance/time with pytest.raises(ValueError): SamplerState( x0=unit.Quantity(jnp.array([[1, 2, 3]]), unit.nanometers), + current_PRNG_key=PRNG.get_random_key(), velocities=unit.Quantity(jnp.array([1, 2, 3]), unit.nanometers), ) @@ -94,12 +110,14 @@ def test_sampler_state_inputs(): with pytest.raises(TypeError): SamplerState( x0=unit.Quantity(jnp.array([[1, 2, 3]]), unit.nanometers), + current_PRNG_key=PRNG.get_random_key(), box_vectors=jnp.array([1, 2, 3]), ) # box_vectors should have units of distance with pytest.raises(ValueError): SamplerState( x0=unit.Quantity(jnp.array([[1, 2, 3]]), unit.nanometers), + current_PRNG_key=PRNG.get_random_key(), box_vectors=unit.Quantity( jnp.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), unit.radians ), @@ -108,6 +126,7 @@ def test_sampler_state_inputs(): with pytest.raises(ValueError): SamplerState( x0=unit.Quantity(jnp.array([[1, 2, 3]]), unit.nanometers), + current_PRNG_key=PRNG.get_random_key(), box_vectors=unit.Quantity( jnp.array([[1, 0, 0], [0, 1, 0]]), unit.nanometers ), @@ -122,6 +141,7 @@ def test_sampler_state_inputs(): # check openmm_box conversion: state = SamplerState( x0=unit.Quantity(jnp.array([[1, 2, 3]]), unit.nanometers), + current_PRNG_key=PRNG.get_random_key(), box_vectors=openmm_box, ) assert jnp.allclose( @@ -135,7 +155,9 @@ def test_sampler_state_inputs(): # openmm box vectors end up as a list with contents; check to make sure we capture an error if we pass a bad list with pytest.raises(TypeError): SamplerState( - x0=unit.Quantity(jnp.array([[1, 2, 3]]), unit.nanometers), box_vectors=[123] + x0=unit.Quantity(jnp.array([[1, 2, 3]]), unit.nanometers), + current_PRNG_key=PRNG.get_random_key(), + box_vectors=[123], ) @@ -146,6 +168,7 @@ def test_reduced_potential(): from chiron.potential import HarmonicOscillatorPotential import jax.numpy as jnp from openmmtools.testsystems import HarmonicOscillator + from chiron.utils import PRNG ho = HarmonicOscillator() potential = HarmonicOscillatorPotential(topology=ho.topology, k=ho.K, U0=ho.U0) @@ -153,7 +176,7 @@ def test_reduced_potential(): state = ThermodynamicState( potential, temperature=300 * unit.kelvin, volume=30 * (unit.angstrom**3) ) - sampler_state = SamplerState(ho.positions) + sampler_state = SamplerState(ho.positions, current_PRNG_key=PRNG.get_random_key()) reduced_e = state.get_reduced_potential(sampler_state) assert reduced_e == 0.0 diff --git a/chiron/tests/test_testsystems.py b/chiron/tests/test_testsystems.py index 5f14269..3506b77 100644 --- a/chiron/tests/test_testsystems.py +++ b/chiron/tests/test_testsystems.py @@ -182,8 +182,12 @@ def test_LJ_fluid(): dispersion_correction=False, shift=False, ) + from chiron.utils import PRNG + + PRNG.set_seed(1234) state = SamplerState( x0=lj_openmm.positions, + current_PRNG_key=PRNG.get_random_key(), box_vectors=lj_openmm.system.getDefaultPeriodicBoxVectors(), ) diff --git a/chiron/tests/test_utils.py b/chiron/tests/test_utils.py index 2cef5b3..59f6d10 100644 --- a/chiron/tests/test_utils.py +++ b/chiron/tests/test_utils.py @@ -22,48 +22,113 @@ def test_get_list_of_mass(): assert np.isclose(c, expected[0]), "Incorrect masses returned" -def test_reporter(): - """Read in a reporter file and check its contend.""" - import h5py - import numpy as np - from chiron.utils import get_data_file_path +import pytest +from .test_multistate import ho_multistate_sampler_multiple_ks + + +@pytest.fixture(scope="session") +def prep_temp_dir(tmpdir_factory): + """Create a temporary directory for the test.""" + tmpdir = tmpdir_factory.mktemp("test_reporter") + return tmpdir - h5_file = "test_md.h5" - h5_test_file = get_data_file_path(h5_file) - print(h5_test_file) - # Read the h5 file manually and check values - h5 = h5py.File(h5_test_file, "r") - keys = h5.keys() +def test_reporter(prep_temp_dir, ho_multistate_sampler_multiple_ks): + from chiron.integrators import LangevinIntegrator + from chiron.potential import HarmonicOscillatorPotential + from openmm import unit + + from openmmtools.testsystems import HarmonicOscillator - assert "energy" in keys, "Energy not in keys" - assert "step" in keys, "Step not in keys" - assert "traj" in keys, "Traj not in keys" + ho = HarmonicOscillator() + potential = HarmonicOscillatorPotential(ho.topology) + from chiron.utils import PRNG - energy = h5["energy"][:5] - reference_energy = np.array( - [1.9328993e-06, 2.0289978e-02, 8.3407544e-02, 1.7832418e-01, 2.8428176e-01] + PRNG.set_seed(1234) + + from chiron.states import SamplerState, ThermodynamicState + + thermodynamic_state = ThermodynamicState( + potential=potential, temperature=300 * unit.kelvin ) - assert np.allclose( - energy, - reference_energy, - ), "Energy not correct" - h5.close() + sampler_state = SamplerState(ho.positions, PRNG.get_random_key()) + + from chiron.reporters import LangevinDynamicsReporter + from chiron.reporters import BaseReporter + + # set up reporter directory + BaseReporter.set_directory(prep_temp_dir) + + # test langevin reporter + reporter = LangevinDynamicsReporter("langevin_test") + reporter.reset_reporter_file() + + integrator = LangevinIntegrator(reporter=reporter, report_frequency=1) + integrator.run( + sampler_state, + thermodynamic_state, + n_steps=20, + ) + import numpy as np + + reporter.flush_buffer() + + # test for available keys + assert "potential_energy" in reporter.get_available_keys() + assert "step" in reporter.get_available_keys() + + # test for property + pot_energy = reporter.get_property("potential_energy") + np.allclose( + pot_energy, + np.array( + [ + 8.8336921e-05, + 3.5010747e-04, + 7.8302569e-04, + 1.4021739e-03, + 2.1981772e-03, + 3.1483083e-03, + 4.2442558e-03, + 5.4960307e-03, + 6.8922052e-03, + 8.4171966e-03, + 1.0099258e-02, + 1.1929392e-02, + 1.3859766e-02, + 1.5893064e-02, + 1.8023632e-02, + 2.0219875e-02, + 2.2491256e-02, + 2.4893485e-02, + 2.7451182e-02, + 3.0140089e-02, + ], + dtype=np.float32, + ), + ) - # Use the reporter class and check values - from chiron.reporters import SimulationReporter + # test that xtc and log file is written + import os - reporter = SimulationReporter(h5_test_file, None, 1) - assert np.allclose(reference_energy, reporter.get_property("energy")[:5]) - reporter.close() - # test the topology - from openmmtools.testsystems import HarmonicOscillatorArray + assert os.path.exists(reporter.xtc_file_path) + assert os.path.exists(reporter.log_file_path) - ho = HarmonicOscillatorArray() - topology = ho.topology - reporter = SimulationReporter(h5_test_file, topology, 1) - traj = reporter.get_mdtraj_trajectory() - import mdtraj as md + # test multistate reporter + ho_sampler = ho_multistate_sampler_multiple_ks + ho_sampler._reporter.reset_reporter_file() + ho_sampler.run(5) - assert isinstance(traj, md.Trajectory), "Trajectory not correct type" + assert len(ho_sampler._reporter._replica_reporter.keys()) == 4 + assert ho_sampler._reporter._replica_reporter.get("replica_0") + assert ho_sampler._reporter._default_properties == [ + "positions", + "box_vectors", + "u_kn", + "state_index", + "step", + ] + u_kn = ho_sampler._reporter.get_property("u_kn") + assert u_kn.shape == (4, 4, 6) + assert os.path.exists(ho_sampler._reporter.log_file_path) diff --git a/chiron/utils.py b/chiron/utils.py index abe2ec9..a9a7da0 100644 --- a/chiron/utils.py +++ b/chiron/utils.py @@ -1,5 +1,40 @@ from openmm.app import Topology from openmm import unit +from jax import random + + +class PRNG: + _key: random.PRNGKey + _seed: int + + def __init__(self) -> None: + """ + A PRNG class that can be used to generate random numbers in JAX. + The intended use case is to initialize new PRN streams in the `SamplerState` class. + + Example: + -------- + from chiron.utils import PRNG + from chiron.states import SamplerState + from openmmtools.testsystems import HarmonicOscillator + + ho = HarmonicOscillator() + PRNG.set_seed(1234) + sampler_state = [SamplerState(ho.positions, PRNG.get_random_key()) for _ in x0s] + + """ + + pass + @classmethod + def set_seed(cls, seed: int) -> None: + cls._seed = seed + cls._key = random.PRNGKey(seed) + + @classmethod + def get_random_key(cls) -> int: + key, subkey = random.split(cls._key) + cls._key = key + return subkey def get_data_file_path(relative_path: str) -> str: