Skip to content

Commit

Permalink
Add trajectory reporter to openmm workflow (#1053)
Browse files Browse the repository at this point in the history
* Add trajectory reporter to openmm workflow

* respond to janosh review

* fix test

* slightly stricter asserts in test_trajectory_reporter

---------

Co-authored-by: Janosh Riebesell <[email protected]>
  • Loading branch information
orionarcher and janosh authored Nov 13, 2024
1 parent 071d1c8 commit 0fb73a9
Show file tree
Hide file tree
Showing 4 changed files with 285 additions and 29 deletions.
43 changes: 26 additions & 17 deletions src/atomate2/openmm/jobs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@
from openmm.unit import angstrom, kelvin, picoseconds
from pymatgen.core import Structure

from atomate2.openmm.utils import increment_name, task_reports
from atomate2.openmm.utils import (
PymatgenTrajectoryReporter,
increment_name,
task_reports,
)

if TYPE_CHECKING:
from collections.abc import Callable
Expand Down Expand Up @@ -232,16 +236,18 @@ def make(

structure = self._create_structure(sim, prev_task)

task_doc = self._create_task_doc(
interchange, structure, elapsed_time, dir_name, prev_task
)

# leaving the MDAReporter makes the builders fail
for _ in range(len(sim.reporters)):
reporter = sim.reporters.pop()
if hasattr(reporter, "save"):
reporter.save()
del reporter
del sim

task_doc = self._create_task_doc(
interchange, structure, elapsed_time, dir_name, prev_task
)

# write out task_doc json to output dir
with open(dir_name / "taskdoc.json", "w") as file:
json.dump(task_doc.model_dump(), file, cls=MontyEncoder)
Expand Down Expand Up @@ -308,7 +314,7 @@ def _add_reporters(
if has_steps & (traj_interval > 0):
writer_kwargs = {}
# these are the only file types that support velocities
if traj_file_type in ["h5md", "nc", "ncdf"]:
if traj_file_type in ("h5md", "nc", "ncdf", "json"):
writer_kwargs["velocities"] = report_velocities
writer_kwargs["forces"] = False
elif report_velocities and traj_file_type != "trr":
Expand All @@ -330,17 +336,20 @@ def _add_reporters(
reportInterval=traj_interval,
enforcePeriodicBox=wrap_traj,
)
if report_velocities:
# assert package version

kwargs["writer_kwargs"] = writer_kwargs
warnings.warn(
"Reporting velocities is only supported with the"
"development version of MDAnalysis, >= 2.8.0, "
"proceed with caution.",
stacklevel=1,
)
traj_reporter = MDAReporter(**kwargs)
if traj_file_type == "json":
traj_reporter = PymatgenTrajectoryReporter(**kwargs)
else:
if report_velocities:
# assert package version

kwargs["writer_kwargs"] = writer_kwargs
warnings.warn(
"Reporting velocities is only supported with the"
"development version of MDAnalysis, >= 2.8.0, "
"proceed with caution.",
stacklevel=1,
)
traj_reporter = MDAReporter(**kwargs)

sim.reporters.append(traj_reporter)

Expand Down
149 changes: 146 additions & 3 deletions src/atomate2/openmm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
import numpy as np
import openmm.unit as omm_unit
from emmet.core.openmm import OpenMMInterchange
from openmm import LangevinMiddleIntegrator, XmlSerializer
from openmm.app import PDBFile
from openmm import LangevinMiddleIntegrator, State, XmlSerializer
from openmm.app import PDBFile, Simulation
from pymatgen.core.trajectory import Trajectory

if TYPE_CHECKING:
from emmet.core.openmm import OpenMMTaskDocument
Expand Down Expand Up @@ -73,7 +74,7 @@ def download_opls_xml(
submit_button.click()

# Wait for the second page to load
# time.sleep(2) # Adjust this delay as needed based on the loading time
# time.sleep(2) # Adjust this delay as needed based on loading time

# Find and click the "XML" button under Downloads and OpenMM
xml_button = driver.find_element(
Expand Down Expand Up @@ -171,3 +172,145 @@ def openff_to_openmm_interchange(
state=XmlSerializer.serialize(state),
topology=pdb,
)


class PymatgenTrajectoryReporter:
    """Reporter that creates a pymatgen Trajectory from an OpenMM simulation.

    Positions, velocities, box vectors and energies are accumulated in memory
    every time ``report`` is called. Nothing is written until ``save`` is
    called, which builds the Trajectory, stores it on ``self.trajectory`` and
    writes it to ``file`` as JSON.
    """

    def __init__(
        self,
        file: str | Path,
        reportInterval: int,  # noqa: N803
        enforcePeriodicBox: bool | None = None,  # noqa: N803
    ) -> None:
        """Initialize the reporter.

        Parameters
        ----------
        file : str | Path
            The file to write the trajectory to
        reportInterval : int
            The interval (in time steps) at which to save frames
        enforcePeriodicBox : bool | None
            Whether to wrap coordinates to the periodic box. If None, determined from
            simulation settings.
        """
        self._file = file
        self._reportInterval = reportInterval
        self._enforcePeriodicBox = enforcePeriodicBox
        self._topology = None
        self._nextModel = 0

        # Storage for trajectory data, one entry per reported frame
        self._positions: list[np.ndarray] = []
        self._velocities: list[np.ndarray] = []
        self._lattices: list[np.ndarray] = []
        self._frame_properties: list[dict] = []
        self._species: list[str] | None = None
        self._time_step: float | None = None

        # Populated by save(); stays None while no frames have been saved so the
        # attribute can always be read safely.
        self.trajectory: Trajectory | None = None

    def describeNextReport(  # noqa: N802
        self, simulation: Simulation
    ) -> tuple[int, bool, bool, bool, bool, bool]:
        """Get information about the next report this object will generate.

        Parameters
        ----------
        simulation : Simulation
            The Simulation to generate a report for

        Returns
        -------
        tuple[int, bool, bool, bool, bool, bool]
            A six element tuple. The first element is the number of steps until the
            next report. The remaining elements specify whether that report will
            require positions, velocities, forces, energies, and periodic box info.
        """
        steps = self._reportInterval - simulation.currentStep % self._reportInterval
        # positions, velocities and energies are requested; forces are not
        return steps, True, True, False, True, self._enforcePeriodicBox

    def report(self, simulation: Simulation, state: State) -> None:
        """Generate a report.

        Parameters
        ----------
        simulation : Simulation
            The Simulation to generate a report for
        state : State
            The current state of the simulation
        """
        if self._nextModel == 0:
            # First frame: capture the frame-independent metadata once.
            self._topology = simulation.topology
            self._species = [
                atom.element.symbol for atom in simulation.topology.atoms()
            ]
            # Time between stored frames (not between MD steps), in femtoseconds
            self._time_step = (
                simulation.integrator.getStepSize() * self._reportInterval
            ).value_in_unit(omm_unit.femtoseconds)

        # Get positions and velocities in Angstrom and Angstrom/fs
        positions = state.getPositions(asNumpy=True).value_in_unit(omm_unit.angstrom)
        velocities = state.getVelocities(asNumpy=True).value_in_unit(
            omm_unit.angstrom / omm_unit.femtosecond
        )
        box_vectors = state.getPeriodicBoxVectors(asNumpy=True).value_in_unit(
            omm_unit.angstrom
        )

        # Get energies in eV; dividing by Avogadro's constant turns the molar
        # energy into a plain energy before the unit conversion.
        kinetic_energy = (
            state.getKineticEnergy() / omm_unit.AVOGADRO_CONSTANT_NA
        ).value_in_unit(omm_unit.ev)
        potential_energy = (
            state.getPotentialEnergy() / omm_unit.AVOGADRO_CONSTANT_NA
        ).value_in_unit(omm_unit.ev)

        self._positions.append(positions)
        self._velocities.append(velocities)
        self._lattices.append(box_vectors)
        self._frame_properties.append(
            {
                "kinetic_energy": kinetic_energy,
                "potential_energy": potential_energy,
                "total_energy": kinetic_energy + potential_energy,
            }
        )

        self._nextModel += 1

    def save(self) -> None:
        """Build a pymatgen Trajectory from the accumulated frames and write it.

        Does nothing if no frames have been reported. Otherwise the Trajectory is
        stored on ``self.trajectory`` and written to ``self._file`` as JSON.
        """
        if not self._positions:
            return

        # One site-property dict per frame, velocities stored as per-site tuples
        site_properties = [
            {"velocities": [tuple(site_vel) for site_vel in frame_vel]}
            for frame_vel in self._velocities
        ]

        # NOTE(review): positions are passed as Cartesian Angstrom values and one
        # lattice is passed per frame without setting constant_lattice; pymatgen's
        # Trajectory appears to document coords as fractional and to default to a
        # constant lattice - confirm this is what downstream consumers expect.
        trajectory = Trajectory(
            species=self._species,
            coords=self._positions,
            lattice=self._lattices,
            frame_properties=self._frame_properties,
            site_properties=site_properties,
            time_step=self._time_step,
        )

        # Keep the trajectory on the instance so it remains accessible after the
        # reporter has been removed from the simulation.
        self.trajectory = trajectory

        # write out trajectory to a file
        with open(self._file, mode="w", encoding="utf-8") as file:
            file.write(trajectory.to_json())
51 changes: 45 additions & 6 deletions tests/openmm_md/jobs/test_core.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from collections.abc import Callable
from pathlib import Path

import numpy as np
from emmet.core.openmm import OpenMMInterchange
from emmet.core.openmm import OpenMMInterchange, OpenMMTaskDocument
from monty.serialization import loadfn
from openmm import XmlSerializer

from atomate2.openmm.jobs import (
Expand All @@ -12,7 +14,9 @@
)


def test_energy_minimization_maker(interchange, run_job):
def test_energy_minimization_maker(
interchange: OpenMMInterchange, run_job: Callable
) -> None:
state = XmlSerializer.deserialize(interchange.state)
start_positions = state.getPositions(asNumpy=True)

Expand All @@ -28,7 +32,7 @@ def test_energy_minimization_maker(interchange, run_job):
assert (Path(task_doc.calcs_reversed[0].output.dir_name) / "state.csv").exists()


def test_npt_maker(interchange, run_job):
def test_npt_maker(interchange: OpenMMInterchange, run_job: Callable) -> None:
state = XmlSerializer.deserialize(interchange.state)
start_positions = state.getPositions(asNumpy=True)
start_box = state.getPeriodicBoxVectors()
Expand All @@ -47,11 +51,11 @@ def test_npt_maker(interchange, run_job):
assert not np.all(new_box == start_box)


def test_nvt_maker(interchange, run_job):
def test_nvt_maker(interchange: OpenMMInterchange, run_job: Callable) -> None:
state = XmlSerializer.deserialize(interchange.state)
start_positions = state.getPositions(asNumpy=True)

maker = NVTMaker(n_steps=10, state_interval=1)
maker = NVTMaker(n_steps=10, state_interval=1, traj_interval=5)
base_job = maker.make(interchange)
task_doc = run_job(base_job)

Expand All @@ -70,7 +74,7 @@ def test_nvt_maker(interchange, run_job):
assert calc_output.steps_reported == list(range(1, 11))


def test_temp_change_maker(interchange, run_job):
def test_temp_change_maker(interchange: OpenMMInterchange, run_job: Callable):
state = XmlSerializer.deserialize(interchange.state)
start_positions = state.getPositions(asNumpy=True)

Expand All @@ -88,3 +92,38 @@ def test_temp_change_maker(interchange, run_job):
# test that temperature was updated correctly in the input
assert task_doc.calcs_reversed[0].input.temperature == 310
assert task_doc.calcs_reversed[0].input.starting_temperature == 298


def test_trajectory_reporter_json(
    interchange: OpenMMInterchange, tmp_path: Path, run_job: Callable
):
    """Run a short NVT job with a JSON trajectory and check what was written."""
    # Three steps with a frame saved at every step, written through the JSON
    # trajectory reporter
    nvt_settings = {
        "temperature": 300,
        "friction_coefficient": 1.0,
        "step_size": 0.002,
        "platform_name": "CPU",
        "traj_interval": 1,
        "n_steps": 3,
        "traj_file_type": "json",
    }
    finished_doc = run_job(NVTMaker(**nvt_settings).make(interchange))

    # Round-trip the task document through its JSON representation
    serialized = finished_doc.model_dump_json()
    restored_doc = OpenMMTaskDocument.model_validate_json(serialized)

    # Locate and load the trajectory file referenced by the restored document
    output = restored_doc.calcs_reversed[0].output
    trajectory = loadfn(Path(output.dir_name) / output.traj_file)

    assert len(trajectory) == 3
    assert trajectory.coords.max() < trajectory.lattice.max()
    assert "kinetic_energy" in trajectory.frame_properties[0]

    # The trajectory file is expected in the temporary directory
    assert (tmp_path / "trajectory.json").exists()
Loading

0 comments on commit 0fb73a9

Please sign in to comment.