Skip to content

Commit

Permalink
69 dealing with simulfiles (#70)
Browse files Browse the repository at this point in the history
 1. Add FileType.RECORD support to write_file_name and get_fileid in the path registry (write the file name and get its ID).
 2. Add handling of temporary files in setup and run: if the simulation is final, the record files are saved in the path registry; if not, they are deleted.
  • Loading branch information
Jgmedina95 authored Jan 23, 2024
1 parent 6fefec3 commit 5b08022
Show file tree
Hide file tree
Showing 4 changed files with 210 additions and 72 deletions.
5 changes: 2 additions & 3 deletions mdagent/tools/base_tools/preprocess_tools/clean_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,6 @@ class CleaningToolFunctionInput(BaseModel):
"""Input model for CleaningToolFunction"""

pdb_id: str = Field(..., description="ID of the pdb/cif file in the path registry")
output_path: Optional[str] = Field(..., description="Path to the output file")
replace_nonstandard_residues: bool = Field(
True, description="Whether to replace nonstandard residues with standard ones. "
)
Expand Down Expand Up @@ -301,7 +300,7 @@ def _run(self, **input_args) -> str:
pdbfile_name = pdbfile.split("/")[-1]
name = pdbfile_name.split("_")[0]
end = pdbfile_name.split(".")[1]
print(f"pdbfile: {pdbfile}", f"name: {name}", f"end: {end}")

except Exception as e:
print(f"error retrieving from path_registry, trying to read file {e}")
return "File not found in path registry. "
Expand Down Expand Up @@ -384,7 +383,7 @@ def _run(self, **input_args) -> str:
self.path_registry.map_path(
file_id, f"{directory}/{file_name}", file_description
)
return f"{file_id} written to {directory}/{file_name}"
return f"File cleaned!\nFile ID:{file_id}\nPath:{directory}/{file_name}"
except FileNotFoundError:
return "Check your file path. File not found."
except Exception as e:
Expand Down
250 changes: 188 additions & 62 deletions mdagent/tools/base_tools/simulation_tools/setup_and_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,22 @@ async def _arun(self, query: str) -> str:
class SetUpandRunFunctionInput(BaseModel):
pdb_id: str
forcefield_files: List[str]
final: bool = Field(
False,
description=(
(
"Set to 'True' when the simulation is the desired final version. "
"Determines whether the simulation is the primary one "
"intended for final use. If set to 'False' (default), "
"the simulation is considered as being in a testing "
"or preliminary scripting stage, utilizing default parameters. "
"This setting is ideal for initial experimentation or "
"basic script development before customizing the "
"script for final use."
)
),
)

system_params: Dict[str, Any] = Field(
{
"nonbondedMethod": "NoCutoff",
Expand All @@ -559,29 +575,30 @@ class SetUpandRunFunctionInput(BaseModel):
"rigidWater": False,
"constraintTolerance": None,
},
description="""Parameters for the openmm system.
For nonbondedMethod, you can choose from the following:
NoCutoff, CutoffNonPeriodic, CutoffPeriodic, Ewald, PME.
If anything but NoCutoff is chosen,
you have to include a nonbondedCutoff
and a constrainTolerance.
If PME is chosen,
you have to include an ewaldErrorTolerance too.
For constraints, you can choose from the following:
None, HBonds, AllBonds or OnlyWater.
For rigidWater, you can choose from the following:
True, False.
Example1:
{"nonbondedMethod": 'NoCutoff',
"constraints": 'None',
"rigidWater": False}
Example2:
{"nonbondedMethod": 'CutoffPeriodic',
"nonbondedCutoff": 1.0,
"constraints": 'HBonds',
"rigidWater": True,
"constraintTolerance": 0.00001}
""",
description=(
"Parameters for the openmm system. "
"For nonbondedMethod, you can choose from the following:\n"
"NoCutoff, CutoffNonPeriodic, CutoffPeriodic, Ewald, PME. "
"If anything but NoCutoff is chosen,"
"you have to include a nonbondedCutoff"
"and a constrainTolerance.\n"
"If PME is chosen,"
"you have to include an ewaldErrorTolerance too."
"For constraints, you can choose from the following:\n"
"None, HBonds, AllBonds or OnlyWater."
"For rigidWater, you can choose from the following:\n"
"True, False.\n"
"Example1:\n"
"{'nonbondedMethod': 'NoCutoff',\n"
"'constraints': 'None',\n"
"'rigidWater': False}\n"
"Example2:\n"
"{'nonbondedMethod': 'CutoffPeriodic',\n"
"'nonbondedCutoff': 1.0,\n"
"'constraints': 'HBonds',\n"
"'rigidWater': True,\n"
"'constraintTolerance': 0.00001} "
),
)
integrator_params: Dict[str, Any] = Field(
{
Expand Down Expand Up @@ -618,9 +635,17 @@ class SetUpandRunFunctionInput(BaseModel):

class OpenMMSimulation:
def __init__(
self, input_params: SetUpandRunFunctionInput, path_registry: PathRegistry
self,
input_params: SetUpandRunFunctionInput,
path_registry: PathRegistry,
final: bool,
sim_id: str,
pdb_id: str,
):
self.params = input_params
self.final = final
self.sim_id = sim_id
self.pdb_id = pdb_id
self.int_params = self.params.get("integrator_params", None)
if self.int_params is None:
self.int_params = {
Expand Down Expand Up @@ -709,23 +734,74 @@ def create_simulation(self):
)
self.simulation.context.setPositions(self.pdb.positions)

# Add reporters for output
self.simulation.reporters.append(
DCDReporter(
"trajectory.dcd",
self.sim_params["record_interval_steps"],
# TEMPORARY FILE MANAGEMENT OR PATH REGISTRY MAPPING
if self.final:
trajectory_name = self.path_registry.write_file_name(
type=FileType.RECORD,
record_type="TRAJ",
protein_file_id=self.pdb_id,
Sim_id=self.sim_id,
term="dcd",
)
)
self.simulation.reporters.append(
StateDataReporter(
"log.txt",
self.sim_params["record_interval_steps"],
step=True,
potentialEnergy=True,
temperature=True,
separator="\t",

log_name = self.path_registry.write_file_name(
type=FileType.RECORD,
record_type="LOG",
protein_file_id=self.pdb_id,
Sim_id=self.sim_id,
term="txt",
)
traj_id = self.path_registry.get_fileid(trajectory_name, FileType.RECORD)
log_id = self.path_registry.get_fileid(log_name, FileType.RECORD)
traj_desc = (
f"Simulation trajectory for protein {self.pdb_id}"
f" and simulation {self.sim_id}"
)
log_desc = (
f"Simulation state log for protein {self.pdb_id} "
f"and simulation {self.sim_id}"
)

self.simulation.reporters.append(
DCDReporter(
f"{trajectory_name}",
self.sim_params["record_interval_steps"],
)
)
self.simulation.reporters.append(
StateDataReporter(
f"{log_name}",
self.sim_params["record_interval_steps"],
step=True,
potentialEnergy=True,
temperature=True,
separator="\t",
)
)
self.registry_records = [
(traj_id, f"files/records/{trajectory_name}", traj_desc),
(log_id, f"files/records/{log_name}", log_desc),
]

# TODO add checkpoint too?

else:
self.simulation.reporters.append(
DCDReporter(
"temp_trajectory.dcd",
self.sim_params["record_interval_steps"],
)
)
self.simulation.reporters.append(
StateDataReporter(
"temp_log.txt",
self.sim_params["record_interval_steps"],
step=True,
potentialEnergy=True,
temperature=True,
separator="\t",
)
)
)

def _create_system(
self,
Expand Down Expand Up @@ -858,13 +934,13 @@ def unit_to_string(unit):
steps = {self.sim_params.get("Number of Steps", record_interval_steps)}
equilibrationSteps = 1000
platform = Platform.getPlatformByName('CPU')
dcdReporter = DCDReporter('trajectory.dcd', 10000)
dcdReporter = DCDReporter('trajectory.dcd', 1000)
dataReporter = StateDataReporter('log.txt', {record_interval_steps},
totalSteps=steps,
step=True, speed=True, progress=True, elapsedTime=True, remainingTime=True,
potentialEnergy=True, temperature=True, volume=True, density=True,
separator='\t')
checkpointReporter = CheckpointReporter('checkpoint.chk', 10000)
checkpointReporter = CheckpointReporter('checkpoint.chk', 5000)
# Minimize and Equilibrate
# ... code for minimization and equilibration ...
Expand Down Expand Up @@ -989,16 +1065,26 @@ def run(self):
self.simulation.currentStep = 0
self.simulation.step(self.sim_params["Number of Steps"])
print("Done!")
if not self.final:
if os.path.exists("temp_trajectory.dcd"):
os.remove("temp_trajectory.dcd")
if os.path.exists("temp_log.txt"):
os.remove("temp_log.txt")
if os.path.exists("temp_checkpoint.chk"):
os.remove("temp_checkpoint.chk")

return "Simulation done!"


class SetUpandRunFunction(BaseTool):
name: str = "SetUpandRunFunction"
description: str = """This tool will set up and run a short simulation of a protein.
Then will write a standalone script that can be used
to reproduce the simulation or change accordingly for
a more elaborate simulation. It only runs short simulations because,
if there are errors you can try again changing the input"""
description: str = (
"This tool will set up and run a short simulation of a protein. "
"Then will write a standalone script that can be used "
"to reproduce the simulation or change accordingly for "
"a more elaborate simulation. It only runs short simulations because, "
"if there are errors, you can try again changing the input"
)

args_schema: Type[BaseModel] = SetUpandRunFunctionInput

Expand All @@ -1009,17 +1095,43 @@ def _run(self, **input_args):
print("Path registry not initialized")
return "Path registry not initialized"
input = self.check_system_params(input_args)

error = input.get("error", None)
if error:
print(f"error found: {error}")
return error

try:
pdb_id = input["pdb_id"]
# check if pdb_id is in the registry or as 1XYZ_112233 format
if pdb_id not in self.path_registry.list_path_names():
return "No pdb_id found in input, use the file id not the file name"
except KeyError:
print("whoops no pdb_id found in input,", input)
return "No pdb_id found in input"
try:
Simulation = OpenMMSimulation(input, self.path_registry)
final = input["final"] # either this simulation
# the final one or not for this system
except KeyError:
final = False
print(
"No 'final' key found in input, setting to False. "
"Record files will be deleted after script is written."
)
try:
file_name = self.path_registry.write_file_name(
type=FileType.SIMULATION,
type_of_sim=input["simmulation_params"]["Ensemble"],
protein_file_id=pdb_id,
)

sim_id = self.path_registry.get_fileid(file_name, FileType.SIMULATION)
except Exception as e:
print(f"An exception was found: {str(e)}.")
return f"An exception was found trying to write the filenames: {str(e)}."
try:
Simulation = OpenMMSimulation(
input, self.path_registry, final, sim_id, pdb_id
)
print("simulation set!")
except ValueError as e:
return str(e) + f"This were the inputs {input_args}"
Expand All @@ -1030,27 +1142,36 @@ def _run(self, **input_args):
try:
Simulation.run()
except Exception as e:
return f"""An exception was found: {str(e)}. Not a problem, thats one
purpose of this tool: to run a short simulation to check for correct
initialization. \n\n Try a) with different parameters like
nonbondedMethod, constraints, etc or b) clean file inputs depending on error
"""
try:
file_name = self.path_registry.write_file_name(
type=FileType.SIMULATION,
type_of_sim=input["simmulation_params"]["Ensemble"],
protein_file_id=pdb_id,
return (
f"An exception was found: {str(e)}. Not a problem, thats one "
"purpose of this tool: to run a short simulation to check for correct "
"initialization. "
""
"Try a) with different parameters like "
"nonbondedMethod, constraints, etc \n or\n"
"b) clean file inputs depending on error "
)
file_id = self.path_registry.get_fileid(file_name, FileType.SIMULATION)
try:
Simulation.write_standalone_script(filename=file_name)
self.path_registry.map_path(
file_id, file_name, f"Basic Simulation of Protein {pdb_id}"
sim_id,
f"files/simulations/{file_name}",
f"Basic Simulation of Protein {pdb_id}",
)
if final:
records = Simulation.registry_records
# move record files to files/records/
print(os.listdir("."))
if not os.path.exists("files/records"):
os.makedirs("files/records")
for record in records:
os.rename(record[1].split("/")[-1], f"{record[1]}")
for record in records:
self.path_registry.map_path(*record)
return "Simulation done!"
except Exception as e:
print(f"An exception was found: {str(e)}.")
return f"""An exception was found trying to write the filenames: {str(e)}.
"""
return f"An exception was found trying to write the filenames: {str(e)}."

def _parse_cutoff(self, cutoff):
# Check if cutoff is already an OpenMM Quantity (has a unit)
Expand Down Expand Up @@ -1481,6 +1602,10 @@ def check_system_params(cls, values):
if file not in FORCEFIELD_LIST:
error_msg += "The forcefield file is not present"

final = values.get("final", False)
if type(final) != bool:
error_msg += "final must be a boolean value"

if error_msg != "":
return {
"error": error_msg
Expand All @@ -1489,6 +1614,7 @@ def check_system_params(cls, values):
values = {
"pdb_id": pdb_id,
"forcefield_files": forcefield_files,
"final": final,
"system_params": system_params,
"integrator_params": integrator_params,
"simmulation_params": simmulation_params,
Expand Down
Loading

0 comments on commit 5b08022

Please sign in to comment.