Skip to content

Commit

Permalink
Adding Simulation and MultistateSimulation reporter (#18)
Browse files Browse the repository at this point in the history
* Add MBAREstimator class and MultistateReporter class

* Add MBAR class and update free energy estimators

* Update LangevinIntegrator class in integrators.py

* Update MCMCMove and LangevinDynamicsMove constructors

* Fix MBAREstimator initialization in MultiStateSampler

* Refactor reporters and tests

* Add new reporters and update tests

* Wrap and rebuild neighborlist in LangevinIntegrator

* Refactor code to transpose u_kn array in MultiStateSampler and _SimulationReporter

* Refactor _SimulationReporter class to improve code readability and maintainability

* Refactor code and add random seed functionality

* Fix reporter visibility and add test for multistate reporter
  • Loading branch information
wiederm authored Jan 18, 2024
1 parent c82cd8f commit ab0d114
Show file tree
Hide file tree
Showing 20 changed files with 1,151 additions and 503 deletions.
6 changes: 3 additions & 3 deletions Examples/LJ_langevin.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,20 +43,20 @@
# build the neighbor list from the sampler state
nbr_list.build_from_state(sampler_state)

from chiron.reporters import SimulationReporter
from chiron.reporters import _SimulationReporter

# initialize a reporter to save the simulation data
filename = "test_lj.h5"
import os

if os.path.isfile(filename):
os.remove(filename)
reporter = SimulationReporter("test_lj.h5", lj_fluid.topology, 1)
reporter = _SimulationReporter("test_lj.h5", lj_fluid.topology, 1)

from chiron.integrators import LangevinIntegrator

# initialize the Langevin integrator
integrator = LangevinIntegrator(reporter=reporter, save_frequency=100)
integrator = LangevinIntegrator(reporter=reporter, report_frequency=100)
print("init_energy: ", lj_potential.compute_energy(sampler_state.x0, nbr_list))

integrator.run(
Expand Down
6 changes: 3 additions & 3 deletions Examples/LJ_mcmove.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,23 +45,23 @@
# build the neighbor list from the sampler state
nbr_list.build_from_state(sampler_state)

from chiron.reporters import SimulationReporter
from chiron.reporters import _SimulationReporter

# initialize a reporter to save the simulation data
filename = "test_lj.h5"
import os

if os.path.isfile(filename):
os.remove(filename)
reporter = SimulationReporter("test_mc_lj.h5", lj_fluid.topology, 1)
reporter = _SimulationReporter("test_mc_lj.h5", lj_fluid.topology, 1)

from chiron.mcmc import MetropolisDisplacementMove

mc_move = MetropolisDisplacementMove(
seed=1234,
displacement_sigma=0.01 * unit.nanometer,
nr_of_moves=1000,
simulation_reporter=reporter,
reporter=reporter,
)

mc_move.run(sampler_state, thermodynamic_state, nbr_list, True)
File renamed without changes.
131 changes: 86 additions & 45 deletions chiron/integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
from jax import random
from openmm import unit
from .states import SamplerState, ThermodynamicState
from .reporters import SimulationReporter
from .reporters import LangevinDynamicsReporter
from typing import Optional
from .potential import NeuralNetworkPotential
from .neighbors import PairsBase


class LangevinIntegrator:
Expand All @@ -24,8 +26,8 @@ def __init__(
self,
stepsize=1.0 * unit.femtoseconds,
collision_rate=1.0 / unit.picoseconds,
save_frequency: int = 100,
reporter: Optional[SimulationReporter] = None,
report_frequency: int = 100,
reporter: Optional[LangevinDynamicsReporter] = None,
save_traj_in_memory: bool = False,
) -> None:
"""
Expand All @@ -37,29 +39,33 @@ def __init__(
Time step of integration with units of time. Default is 1.0 * unit.femtoseconds.
collision_rate : unit.Quantity, optional
Collision rate for the Langevin dynamics, with units 1/time. Default is 1.0 / unit.picoseconds.
save_frequency : int, optional
report_frequency : int, optional
Frequency of saving the simulation data. Default is 100.
reporter : SimulationReporter, optional
Reporter object for saving the simulation data. Default is None.
save_traj_in_memory : bool
Whether to save the trajectory in memory. For debugging purposes only.
save_traj_in_memory: bool
Flag indicating whether to save the trajectory in memory.
Default is False. NOTE: Only for debugging purposes.
"""
from loguru import logger as log

self.kB = unit.BOLTZMANN_CONSTANT_kB * unit.AVOGADRO_CONSTANT_NA
log.info(f"stepsize = {stepsize}")
log.info(f"collision_rate = {collision_rate}")
log.info(f"save_frequency = {save_frequency}")
log.info(f"report_frequency = {report_frequency}")

self.stepsize = stepsize
self.collision_rate = collision_rate
if reporter is not None:
log.info(f"Using reporter {reporter} saving to {reporter.filename}")
if reporter:
log.info(
f"Using reporter {reporter} saving trajectory to {reporter.xtc_file_path}"
)
log.info(f"and logging to {reporter.log_file_path}")
self.reporter = reporter
self.save_frequency = save_frequency
self.report_frequency = report_frequency
self.velocities = None
self.save_traj_in_memory = save_traj_in_memory
self.traj = []
self.velocities = None

def set_velocities(self, vel: unit.Quantity) -> None:
"""
Expand All @@ -77,8 +83,7 @@ def run(
sampler_state: SamplerState,
thermodynamic_state: ThermodynamicState,
n_steps: int = 5_000,
key=random.PRNGKey(0),
nbr_list=None,
nbr_list: Optional[PairsBase] = None,
progress_bar=False,
):
"""
Expand All @@ -92,9 +97,7 @@ def run(
The thermodynamic state of the system, including temperature and potential.
n_steps : int, optional
Number of simulation steps to perform.
key : jax.random.PRNGKey, optional
Random key for generating random numbers.
nbr_list : NeighborListNsqrd, optional
nbr_list : PairBase, optional
Neighbor list for the system.
progress_bar : bool, optional
Flag indicating whether to display a progress bar during integration.
Expand All @@ -116,8 +119,11 @@ def run(
log.debug("Running Langevin dynamics")
log.debug(f"n_steps = {n_steps}")
log.debug(f"temperature = {temperature}")
log.debug(f"Using seed: {key}")

# Initialize the random number generator
key = sampler_state.new_PRNG_key

# Convert to dimensionless quantities
kbT_unitless = (self.kB * temperature).value_in_unit_system(unit.md_unit_system)
mass_unitless = jnp.array(mass.value_in_unit_system(unit.md_unit_system))[
:, None
Expand All @@ -127,22 +133,24 @@ def run(
collision_rate_unitless = self.collision_rate.value_in_unit_system(
unit.md_unit_system
)
a = jnp.exp((-collision_rate_unitless * stepsize_unitless))
b = jnp.sqrt(1 - jnp.exp(-2 * collision_rate_unitless * stepsize_unitless))

# Initialize velocities
if self.velocities is None:
v0 = sigma_v * random.normal(key, x0.shape)
else:
v0 = self.velocities.value_in_unit_system(unit.md_unit_system)
# Convert to dimensionless quantities
a = jnp.exp((-collision_rate_unitless * stepsize_unitless))
b = jnp.sqrt(1 - jnp.exp(-2 * collision_rate_unitless * stepsize_unitless))

x = x0
v = v0

if nbr_list is not None:
nbr_list.build_from_state(sampler_state)

F = potential.compute_force(x, nbr_list)

# propagation loop
for step in tqdm(range(n_steps)) if self.progress_bar else range(n_steps):
key, subkey = random.split(key)
# v
Expand All @@ -151,46 +159,79 @@ def run(
x += (stepsize_unitless * 0.5) * v

if nbr_list is not None:
x = nbr_list.space.wrap(x)
# check if we need to rebuild the neighborlist after moving the particles
if nbr_list.check(x):
nbr_list.build(x, self.box_vectors)
x = self._wrap_and_rebuild_neighborlist(x, nbr_list)
# o
random_noise_v = random.normal(subkey, x.shape)
v = (a * v) + (b * sigma_v * random_noise_v)

x += (stepsize_unitless * 0.5) * v
if nbr_list is not None:
x = nbr_list.space.wrap(x)
# check if we need to rebuild the neighborlist after moving the particles
if nbr_list.check(x):
nbr_list.build(x, self.box_vectors)
x = self._wrap_and_rebuild_neighborlist(x, nbr_list)

F = potential.compute_force(x, nbr_list)
# v
v += (stepsize_unitless * 0.5) * F / mass_unitless

if step % self.save_frequency == 0:
# log.debug(f"Saving at step {step}")
# check if reporter is attribute of the class
# log.debug(f"step {step} energy {potential.compute_energy(x, nbr_list)}")
# log.debug(f"step {step} force {F}")

if step % self.report_frequency == 0:
if hasattr(self, "reporter") and self.reporter is not None:
d = {
"traj": x,
"energy": potential.compute_energy(x, nbr_list),
"step": step,
}
if nbr_list is not None:
d["box_vectors"] = nbr_list.space.box_vectors

# log.debug(d)
self.reporter.report(d)
self._report(x, potential, nbr_list, step)

if self.save_traj_in_memory:
self.traj.append(x)

log.debug("Finished running Langevin dynamics")
# save the final state of the simulation in the sampler_state object
sampler_state.x0 = x
# self.reporter.close()
sampler_state.v0 = v

def _wrap_and_rebuild_neighborlist(self, x: jnp.array, nbr_list: PairsBase):
"""
Wrap the coordinates and rebuild the neighborlist if necessary.
Parameters
----------
x: jnp.array
The coordinates of the particles.
nbr_list: PairsBsse
The neighborlist object.
"""

x = nbr_list.space.wrap(x)
# check if we need to rebuild the neighborlist after moving the particles
if nbr_list.check(x):
nbr_list.build(x, self.box_vectors)
return x

def _report(
self,
x: jnp.array,
potential: NeuralNetworkPotential,
nbr_list: PairsBase,
step: int,
):
"""
Reports the trajectory, energy, step, and box vectors (if available) to the reporter.
Parameters
----------
x : jnp.array
current coordinate set
potential: NeuralNetworkPotential
potential used to compute the energy and force
nbr_list: PairsBase
The neighbor list
step: int
The current time step.
Returns:
None
"""
d = {
"positions": x,
"potential_energy": potential.compute_energy(x, nbr_list),
"step": step,
}
if nbr_list is not None:
d["box_vectors"] = nbr_list.space.box_vectors

self.reporter.report(d)
Loading

0 comments on commit ab0d114

Please sign in to comment.