Skip to content

Commit

Permalink
Writing script unit tests were missing (#99)
Browse files Browse the repository at this point in the history
  • Loading branch information
SamCox822 authored Mar 6, 2024
1 parent 657bd91 commit 8645e70
Show file tree
Hide file tree
Showing 2 changed files with 374 additions and 76 deletions.
205 changes: 129 additions & 76 deletions mdagent/tools/base_tools/simulation_tools/setup_and_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,19 +635,22 @@ def __init__(
self.save = save
self.sim_id = sim_id
self.pdb_id = pdb_id
self.int_params = self.params.get("integrator_params", None)
if self.int_params is None:
self.int_params = {
self.int_params = (
self.params.integrator_params
if self.params.integrator_params is not None
else {
"integrator_type": "LangevinMiddle",
"Temperature": 300 * kelvin,
"Friction": 1.0 / picoseconds,
"Timestep": 0.002 * picoseconds,
"Pressure": 1.0 * bar,
}
)

self.sys_params = self.params.get("system_params", None)
if self.sys_params is None:
self.sys_params = {
self.sys_params = (
self.params.system_params
if self.params.system_params is not None
else {
"nonbondedMethod": NoCutoff,
"nonbondedCutoff": 1 * nanometers,
"ewaldErrorTolerance": None,
Expand All @@ -656,26 +659,28 @@ def __init__(
"constraintTolerance": 0.000001,
"solvate": False,
}
self.sim_params = self.params.get("simulation_params", None)
if self.sim_params is None:
self.sim_params = {
)

self.sim_params = (
self.params.simulation_params
if self.params.simulation_params is not None
else {
"Ensemble": "NVT",
"Number of Steps": 5000,
"record_interval_steps": 100,
"record_params": ["step", "potentialEnergy", "temperature"],
}
)

self.path_registry = path_registry
self.setup_system()
self.setup_integrator()
self.create_simulation()

def setup_system(self):
print("Building system...")
st.markdown("Building system", unsafe_allow_html=True)
self.pdb_id = self.params["pdb_id"]
self.pdb_id = self.params.pdb_id
self.pdb_path = self.path_registry.get_mapped_path(self.pdb_id)
self.pdb = PDBFile(self.pdb_path)
self.forcefield = ForceField(*self.params["forcefield_files"])
self.forcefield = ForceField(*self.params.forcefield_files)
self.system = self._create_system(self.pdb, self.forcefield, **self.sys_params)

if self.sys_params.get("nonbondedMethod", None) in [
Expand Down Expand Up @@ -882,47 +887,32 @@ def _create_system(

return system

def write_standalone_script(self, filename="reproduce_simulation.py"):
"""Extracting parameters from the class instance
Inspired by the code snippet provided from openmm-setup
https://github.com/openmm/openmm-setup
"""

def unit_to_string(unit):
"""Needed to convert units to strings for the script
Otherwise internal __str()__ method makes the script
not runnable"""
return f"{unit.value_in_unit(unit.unit)}*{unit.unit.get_name()}"

pdb_path = self.pdb_path
forcefield_files = ", ".join(
f"'{file}'" for file in self.params["forcefield_files"]
)
nonbondedMethod = self.sys_params.get("nonbondedMethod", NoCutoff)
nbCo = self.sys_params.get("nonbondedCutoff", 1 * nanometers)
nonbondedCutoff = unit_to_string(nbCo)
constraints = self.sys_params.get("constraints", "None")
rigidWater = self.sys_params.get("rigidWater", False)
ewaldErrorTolerance = self.sys_params.get("ewaldErrorTolerance", 0.0005)
constraintTolerance = self.sys_params.get("constraintTolerance", None)
hydrogenMass = self.sys_params.get("hydrogenMass", None)
solvate = self.sys_params.get("solvate", False)

integrator_type = self.int_params.get("integrator_type", "LangevinMiddle")
friction = self.int_params.get("Friction", 1.0 / picoseconds)
friction = f"{friction.value_in_unit(friction.unit)}{friction.unit.get_name()}"
_temp = self.int_params.get("Temperature", 300 * kelvin)
Temperature = unit_to_string(_temp)

t_step = self.int_params.get("Timestep", 0.004 * picoseconds)
Time_step = unit_to_string(t_step)
press = self.int_params.get("Pressure", 1.0 * bar)
pressure = unit_to_string(press)
ensemble = self.sim_params.get("Ensemble", "NVT")
self.sim_params.get("Number of Steps", 10000)
record_interval_steps = self.sim_params.get("record_interval_steps", 1000)
def unit_to_string(self, unit):
"""Needed to convert units to strings for the script
Otherwise internal __str()__ method makes the script
not runnable"""
return f"{unit.value_in_unit(unit.unit)}*{unit.unit.get_name()}"

# Construct the script content
def _construct_script_content(
self,
pdb_path,
forcefield_files,
nonbonded_method,
constraints,
rigid_water,
constraint_tolerance,
nonbonded_cutoff,
ewald_error_tolerance,
hydrogen_mass,
time_step,
temperature,
friction,
ensemble,
pressure,
record_interval_steps,
solvate,
integrator_type,
):
script_content = f"""
# This script was generated by MDagent-Setup.
Expand All @@ -935,27 +925,27 @@ def unit_to_string(unit):
forcefield = ForceField({forcefield_files})
# System Configuration
nonbondedMethod = {nonbondedMethod}
nonbondedMethod = {nonbonded_method}
constraints = {constraints}
rigidWater = {rigidWater}
rigidWater = {rigid_water}
"""
if rigidWater and constraintTolerance is not None:
script_content += f"constraintTolerance = {constraintTolerance}\n"
if rigid_water and constraint_tolerance is not None:
script_content += f"constraintTolerance = {constraint_tolerance}\n"

# Conditionally add nonbondedCutoff

if nonbondedMethod != NoCutoff:
script_content += f"nonbondedCutoff = {nonbondedCutoff}\n"
if nonbondedMethod == PME:
script_content += f"ewaldErrorTolerance = {ewaldErrorTolerance}\n"
if hydrogenMass:
script_content += f"hydrogenMass = {hydrogenMass}\n"
if nonbonded_method != NoCutoff:
script_content += f"nonbondedCutoff = {nonbonded_cutoff}\n"
if nonbonded_method == PME:
script_content += f"ewaldErrorTolerance = {ewald_error_tolerance}\n"
if hydrogen_mass:
script_content += f"hydrogenMass = {hydrogen_mass}\n"

# ... other configurations ...
script_content += f"""
# Integration Options
dt = {Time_step}
temperature = {Temperature}
dt = {time_step}
temperature = {temperature}
friction = {friction}
"""
if ensemble == "NPT":
Expand Down Expand Up @@ -992,8 +982,8 @@ def unit_to_string(unit):
"""modeller.addSolvent(forcefield, padding=1*nanometers)"""
)

if nonbondedMethod == NoCutoff:
if hydrogenMass:
if nonbonded_method == NoCutoff:
if hydrogen_mass:
script_content += """
system = forcefield.createSystem(modeller.topology,
nonbondedMethod=nonbondedMethod, constraints=constraints,
Expand All @@ -1005,8 +995,8 @@ def unit_to_string(unit):
nonbondedMethod=nonbondedMethod, constraints=constraints,
rigidWater=rigidWater)
"""
if nonbondedMethod == CutoffNonPeriodic or nonbondedMethod == CutoffPeriodic:
if hydrogenMass:
if nonbonded_method == CutoffNonPeriodic or nonbonded_method == CutoffPeriodic:
if hydrogen_mass:
script_content += """
system = forcefield.createSystem(modeller.topology,
nonbondedMethod=nonbondedMethod, nonbondedCutoff=nonbondedCutoff,
Expand All @@ -1019,8 +1009,8 @@ def unit_to_string(unit):
nonbondedMethod=nonbondedMethod, nonbondedCutoff=nonbondedCutoff,
constraints=constraints, rigidWater=rigidWater)
"""
if nonbondedMethod == PME:
if hydrogenMass:
if nonbonded_method == PME:
if hydrogen_mass:
script_content += """
system = forcefield.createSystem(modeller.topology,
nonbondedMethod=nonbondedMethod,
Expand Down Expand Up @@ -1072,6 +1062,61 @@ def unit_to_string(unit):
simulation.currentStep = 0
simulation.step(steps)
"""
return script_content

def write_standalone_script(self, filename="reproduce_simulation.py"):
"""Extracting parameters from the class instance
Inspired by the code snippet provided from openmm-setup
https://github.com/openmm/openmm-setup
"""

pdb_path = self.pdb_path
forcefield_files = ", ".join(
f"'{file}'" for file in self.params["forcefield_files"]
)
nonbonded_method = self.sys_params.get("nonbondedMethod", NoCutoff)
nbCo = self.sys_params.get("nonbondedCutoff", 1 * nanometers)
nonbonded_cutoff = self.unit_to_string(nbCo)
constraints = self.sys_params.get("constraints", "None")
rigid_water = self.sys_params.get("rigidWater", False)
ewald_error_tolerance = self.sys_params.get("ewaldErrorTolerance", 0.0005)
constraint_tolerance = self.sys_params.get("constraintTolerance", None)
hydrogen_mass = self.sys_params.get("hydrogenMass", None)
solvate = self.sys_params.get("solvate", False)

integrator_type = self.int_params.get("integrator_type", "LangevinMiddle")
friction = self.int_params.get("Friction", 1.0 / picoseconds)
friction = f"{friction.value_in_unit(friction.unit)}{friction.unit.get_name()}"
_temp = self.int_params.get("Temperature", 300 * kelvin)
temperature = self.unit_to_string(_temp)

t_step = self.int_params.get("Timestep", 0.004 * picoseconds)
time_step = self.unit_to_string(t_step)
press = self.int_params.get("Pressure", 1.0 * bar)
pressure = self.unit_to_string(press)
ensemble = self.sim_params.get("Ensemble", "NVT")
self.sim_params.get("Number of Steps", 10000)
record_interval_steps = self.sim_params.get("record_interval_steps", 1000)

script_content = self._construct_script_content(
pdb_path,
forcefield_files,
nonbonded_method,
constraints,
rigid_water,
constraint_tolerance,
nonbonded_cutoff,
ewald_error_tolerance,
hydrogen_mass,
time_step,
temperature,
friction,
ensemble,
pressure,
record_interval_steps,
solvate,
integrator_type,
)

# Remove leading spaces for proper formatting
def remove_leading_spaces(text):
Expand Down Expand Up @@ -1148,6 +1193,10 @@ class SetUpandRunFunction(BaseTool):

path_registry: Optional[PathRegistry]

def __init__(self, path_registry: Optional[PathRegistry]):
super().__init__()
self.path_registry = path_registry

def _run(self, **input_args):
if self.path_registry is None:
return "Path registry not initialized"
Expand Down Expand Up @@ -1185,9 +1234,13 @@ def _run(self, **input_args):
print(f"An exception was found: {str(e)}.")
return f"An exception was found trying to write the filenames: {str(e)}."
try:
Simulation = OpenMMSimulation(
openmmsim = OpenMMSimulation(
input, self.path_registry, save, sim_id, pdb_id
)
openmmsim.setup_system()
openmmsim.setup_integrator()
openmmsim.create_simulation()

print("simulation set!")
st.markdown("simulation set!", unsafe_allow_html=True)
except ValueError as e:
Expand All @@ -1204,7 +1257,7 @@ def _run(self, **input_args):
except OpenMMException as e:
return f"OpenMM Exception: {str(e)}. This were the inputs {input_args}"
try:
Simulation.run()
openmmsim.run()
except Exception as e:
return (
f"An exception was found: {str(e)}. Not a problem, thats one "
Expand All @@ -1216,14 +1269,14 @@ def _run(self, **input_args):
"b) clean file inputs depending on error "
)
try:
Simulation.write_standalone_script(filename=file_name)
openmmsim.write_standalone_script(filename=file_name)
self.path_registry.map_path(
sim_id,
f"files/simulations/{file_name}",
f"Basic Simulation of Protein {pdb_id}",
)
if save:
records = Simulation.registry_records
records = openmmsim.registry_records
# move record files to files/records/
print(os.listdir("."))
if not os.path.exists("files/records"):
Expand Down
Loading

0 comments on commit 8645e70

Please sign in to comment.