diff --git a/mdagent/mainagent/agent.py b/mdagent/mainagent/agent.py index 61bff83d..cb9c81bc 100644 --- a/mdagent/mainagent/agent.py +++ b/mdagent/mainagent/agent.py @@ -45,8 +45,10 @@ def __init__( subagents_model="gpt-4-1106-preview", ckpt_dir="ckpt", resume=False, + learn=True, top_k_tools=20, # set "all" if you want to use all tools (& skills if resume) use_human_tool=False, + curriculum=True, uploaded_files=[], # user input files to add to path registry ): if path_registry is None: @@ -69,7 +71,11 @@ def __init__( callbacks=[StreamingStdOutCallbackHandler()], ) - # assign prompt + if learn: + self.skip_subagents = False + else: + self.skip_subagents = True + if agent_type == "Structured": self.prompt = structured_prompt elif agent_type == "OpenAIFunctionsAgent": @@ -83,6 +89,7 @@ def __init__( verbose=verbose, ckpt_dir=ckpt_dir, resume=resume, + curriculum=curriculum, ) def _initialize_tools_and_agent(self, user_input=None): @@ -97,6 +104,7 @@ def _initialize_tools_and_agent(self, user_input=None): llm=self.tools_llm, subagent_settings=self.subagents_settings, human=self.use_human_tool, + skip_subagents=self.skip_subagents, ) else: # retrieve all tools, including new tools if any @@ -104,6 +112,7 @@ def _initialize_tools_and_agent(self, user_input=None): self.tools_llm, subagent_settings=self.subagents_settings, human=self.use_human_tool, + skip_subagents=self.skip_subagents, ) return AgentExecutor.from_agent_and_tools( tools=self.tools, diff --git a/mdagent/subagents/subagent_setup.py b/mdagent/subagents/subagent_setup.py index 1da9ebeb..751d8fa5 100644 --- a/mdagent/subagents/subagent_setup.py +++ b/mdagent/subagents/subagent_setup.py @@ -15,6 +15,7 @@ def __init__( ckpt_dir="ckpt", resume=False, retrieval_top_k=5, + curriculum=True, ): self.path_registry = path_registry self.subagents_model = subagents_model @@ -24,6 +25,7 @@ def __init__( self.ckpt_dir = ckpt_dir self.resume = resume self.retrieval_top_k = retrieval_top_k + self.curriculum = curriculum class SubAgentInitializer: @@ -40,6 +42,7 @@ def __init__(self, settings: Optional[SubAgentSettings] = None): self.ckpt_dir = settings.ckpt_dir self.resume = settings.resume self.retrieval_top_k = settings.retrieval_top_k + self.curriculum = settings.curriculum def create_action(self, **overrides): params = { @@ -61,6 +64,8 @@ def create_critic(self, **overrides): return Critic(**params) def create_curriculum(self, **overrides): + if not self.curriculum: + return None params = { "model": self.subagents_model, "temp": self.temp, diff --git a/mdagent/tools/base_tools/__init__.py b/mdagent/tools/base_tools/__init__.py index b1e99c06..404fd3ca 100644 --- a/mdagent/tools/base_tools/__init__.py +++ b/mdagent/tools/base_tools/__init__.py @@ -13,7 +13,12 @@ RemoveWaterCleaningTool, SpecializedCleanTool, ) -from .preprocess_tools.pdb_tools import Name2PDBTool, PackMolTool, get_pdb +from .preprocess_tools.pdb_tools import ( + PackMolTool, + ProteinName2PDBTool, + SmallMolPDB, + get_pdb, +) from .simulation_tools.create_simulation import ModifyBaseSimulationScriptTool from .simulation_tools.setup_and_run import ( InstructionSummary, @@ -32,9 +37,10 @@ "InstructionSummary", "ListRegistryPaths", "MapPath2Name", - "Name2PDBTool", + "ProteinName2PDBTool", "PackMolTool", "PPIDistance", + "SmallMolPDB", "VisualizeProtein", "RMSDCalculator", "RemoveWaterCleaningTool", diff --git a/mdagent/tools/base_tools/preprocess_tools/__init__.py b/mdagent/tools/base_tools/preprocess_tools/__init__.py index 5c5650e2..b45ebce3 100644 --- a/mdagent/tools/base_tools/preprocess_tools/__init__.py +++ b/mdagent/tools/base_tools/preprocess_tools/__init__.py @@ -5,15 +5,16 @@ RemoveWaterCleaningTool, SpecializedCleanTool, ) -from .pdb_tools import Name2PDBTool, PackMolTool, get_pdb +from .pdb_tools import PackMolTool, ProteinName2PDBTool, SmallMolPDB, get_pdb __all__ = [ "AddHydrogensCleaningTool", "CleaningTools", - "Name2PDBTool", + "ProteinName2PDBTool", "PackMolTool", "RemoveWaterCleaningTool", "SpecializedCleanTool", "get_pdb", "CleaningToolFunction", + "SmallMolPDB", ] diff --git a/mdagent/tools/base_tools/preprocess_tools/clean_tools.py b/mdagent/tools/base_tools/preprocess_tools/clean_tools.py index 023d37d6..c5605012 100644 --- a/mdagent/tools/base_tools/preprocess_tools/clean_tools.py +++ b/mdagent/tools/base_tools/preprocess_tools/clean_tools.py @@ -296,17 +296,17 @@ def _run(self, **input_args) -> str: file_description = "Cleaned File: " CleaningTools() try: - pdbfile = self.path_registry.get_mapped_path(pdbfile_id) - if "/" in pdbfile: - pdbfile = pdbfile.split("/")[-1] - - name = pdbfile.split("_")[0] - end = pdbfile.split(".")[1] + pdbfile_path = self.path_registry.get_mapped_path(pdbfile_id) + if "/" in pdbfile_path: + pdbfile = pdbfile_path.split("/")[-1] + else: + pdbfile = pdbfile_path + name, end = pdbfile.split(".") except Exception as e: print(f"error retrieving from path_registry, trying to read file {e}") return "File not found in path registry. " - fixer = PDBFixer(filename=pdbfile) + fixer = PDBFixer(filename=pdbfile_path) try: fixer.findMissingResidues() except Exception: @@ -353,7 +353,7 @@ def _run(self, **input_args) -> str: file_mode = "w" if add_hydrogens else "a" file_name = self.path_registry.write_file_name( type=FileType.PROTEIN, - protein_name=name, + protein_name=name.split("_")[0], description="Clean", file_format=end, ) diff --git a/mdagent/tools/base_tools/preprocess_tools/pdb_tools.py b/mdagent/tools/base_tools/preprocess_tools/pdb_tools.py index 24171be2..7c9a5b2c 100644 --- a/mdagent/tools/base_tools/preprocess_tools/pdb_tools.py +++ b/mdagent/tools/base_tools/preprocess_tools/pdb_tools.py @@ -10,6 +10,7 @@ from langchain.tools import BaseTool from pdbfixer import PDBFixer from pydantic import BaseModel, Field, ValidationError, root_validator +from rdkit import Chem from mdagent.utils import FileType, PathRegistry @@ -64,17 +65,18 @@ def get_pdb(query_string, path_registry=None): return None -class Name2PDBTool(BaseTool): +class ProteinName2PDBTool(BaseTool): name = "PDBFileDownloader" - description = """This tool downloads PDB (Protein Data Bank) or - CIF (Crystallographic Information File) files using - commercial chemical names. It’s ideal for situations where - you need to directly retrieve these file using a chemical’s - commercial name. When a specific file type, either PDB or CIF, - is requested, add file type to the query string with space. - Input: Commercial name of the chemical or file without - file extension - Output: Corresponding PDB or CIF file""" + description = ( + "This tool downloads PDB (Protein Data Bank) or" + "CIF (Crystallographic Information File) files using" + "a protein's common name (NOT a small molecule)." + "When a specific file type, either PDB or CIF," + "is requested, add file type to the query string with space." + "Input: Commercial name of the protein or file without" + "file extension" + "Output: Corresponding PDB or CIF file" + ) path_registry: Optional[PathRegistry] def __init__(self, path_registry: Optional[PathRegistry]): @@ -106,7 +108,8 @@ async def _arun(self, query) -> str: """validate_pdb_format: validates a pdb file against the pdb format specification packmol_wrapper: takes in a list of pdb files, a - list of number of molecules and a list of instructions and returns a packed pdb file + list of number of molecules, a list of instructions, and a list of small molecules + and returns a packed pdb file Molecule: class that represents a molecule (helpful for packmol PackmolBox: class that represents a box of molecules (helpful for packmol) summarize_errors: function that summarizes the errors found by validate_pdb_format @@ -131,6 +134,9 @@ def validate_pdb_format(fhandle): - 1 if error was found, 0 if no errors were found. - List of error messages encountered. """ + # check if filename is in directory + if not os.path.exists(fhandle): + return (1, ["File not found. Packmol failed to write the file."]) errors = [] _fmt_check = ( ("Atm. Num.", (slice(6, 11), re.compile(r"[\d\s]+"))), @@ -164,7 +170,7 @@ def _make_pointer(column): if not line: continue - if line[0:6] in ("ATOM ", "HETATM"): + if line[0:6] in ["ATOM ", "HETATM"]: # ... [rest of the code unchanged here] linelen = len(line) if linelen < 80: @@ -258,8 +264,9 @@ def summarize_errors(errors): class Molecule: - def __init__(self, filename, number_of_molecules=1, instructions=None): + def __init__(self, filename, file_id, number_of_molecules=1, instructions=None): self.filename = filename + self.id = file_id self.number_of_molecules = number_of_molecules self.instructions = instructions if instructions else [] self.load() @@ -280,6 +287,7 @@ def __init__( self.molecules = [] self.file_number = 1 self.file_description = file_description + self.final_name = None def add_molecule(self, molecule): self.molecules.append(molecule) @@ -288,16 +296,33 @@ def add_molecule(self, molecule): def generate_input_header(self): # Generate the header of the input file in .inp format + orig_pdbs_ids = [ + f"{molecule.number_of_molecules}_{molecule.id}" + for molecule in self.molecules + ] + + _final_name = f'{"_and_".join(orig_pdbs_ids)}' - while os.path.exists(f"packed_structures_v{self.file_number}.pdb"): + self.file_description = ( + "Packed Structures of the following molecules:\n" + + "\n".join( + [ + f"Molecule ID: {molecule.id}, " + f"Number of Molecules: {molecule.number_of_molecules}" + for molecule in self.molecules + ] + ) + ) + while os.path.exists(f"files/pdb/{_final_name}_v{self.file_number}.pdb"): self.file_number += 1 + self.final_name = f"{_final_name}_v{self.file_number}.pdb" with open("packmol.inp", "w") as out: out.write("##Automatically generated by LangChain\n") out.write("tolerance 2.0\n") out.write("filetype pdb\n") out.write( - f"output packed_structures_v{self.file_number}.pdb\n" + f"output {self.final_name}\n" ) # this is the name of the final file out.close() @@ -323,28 +348,41 @@ def run_packmol(self, PathRegistry): cmd = "packmol < packmol.inp" result = subprocess.run(cmd, shell=True, text=True, capture_output=True) if result.returncode != 0: + print("Packmol failed to run with 'packmol < packmol.inp' command") result = subprocess.run( "./" + cmd, shell=True, text=True, capture_output=True ) + if result.returncode != 0: + print("Packmol failed to run with './packmol < packmol.inp' command") + return ( + "Packmol failed to run. Please check the input file and try again." + ) - PathRegistry.map_path( - f"packed_structures_v{self.file_number}.pdb", - f"packed_structures_v{self.file_number}.pdb", - self.file_description, - ) - print(result.stdout) # validate final pdb - pdb_validation = validate_pdb_format("packed_structures.pdb") + pdb_validation = validate_pdb_format(f"{self.final_name}") if pdb_validation[0] == 0: # delete .inp files - os.remove("packmol.inp") - return "PDB file validated successfully" + # os.remove("packmol.inp") + for molecule in self.molecules: + os.remove(molecule.filename) + # name of packed pdb file + time_stamp = PathRegistry.get_timestamp()[-6:] + os.rename(self.final_name, f"files/pdb/{self.final_name}") + PathRegistry.map_path( + f"PACKED_{time_stamp}", + f"files/pdb/{self.final_name}", + self.file_description, + ) + # move file to files/pdb + print("successfull!") + return f"PDB file validated successfully. FileID: PACKED_{time_stamp}" elif pdb_validation[0] == 1: # format pdb_validation[1] list of errors errors = summarize_errors(pdb_validation[1]) # delete .inp files - os.remove("packmol.inp") + # os.remove("packmol.inp") + print("errors:", f"{errors}") return "PDB file not validated, errors found {}".format(("\n").join(errors)) @@ -355,6 +393,7 @@ def run_packmol(self, PathRegistry): def packmol_wrapper( PathRegistry, pdbfiles: List, + files_id: List, number_of_molecules: List, instructions: List[List], ): @@ -364,16 +403,19 @@ def packmol_wrapper( # create a box box = PackmolBox() # add molecules to the box - for pdbfile, number_of_molecules, instructions in zip( - pdbfiles, number_of_molecules, instructions - ): - molecule = Molecule(pdbfile, number_of_molecules, instructions) + for ( + pdbfile, + file_id, + number_of_molecules, + instructions, + ) in zip(pdbfiles, files_id, number_of_molecules, instructions): + molecule = Molecule(pdbfile, file_id, number_of_molecules, instructions) box.add_molecule(molecule) # generate input header box.generate_input_header() # generate input # run packmol - + print("Packing:", box.file_description, "\nThe file name is:", box.final_name) return box.run_packmol(PathRegistry) @@ -382,82 +424,59 @@ def packmol_wrapper( class PackmolInput(BaseModel): - pdbfiles: typing.Optional[typing.List[str]] = Field( - ..., description="List of PDB files to pack into a box" + pdbfiles_id: typing.Optional[typing.List[str]] = Field( + ..., description="List of PDB files id (path_registry) to pack into a box" + ) + small_molecules: typing.Optional[typing.List[str]] = Field( + [], + description=( + "List of small molecules to be packed in the system. " + "Examples: water, benzene, toluene, etc." + ), ) + number_of_molecules: typing.Optional[typing.List[int]] = Field( - ..., description="List of number of molecules to pack into a box" + ..., + description=( + "List of number of instances of each species to pack into the box. " + "One number per species (either protein or small molecule) " + ), ) instructions: typing.Optional[typing.List[List[str]]] = Field( ..., - description="""List of instructions for each molecule. - One List per Molecule. - Every instruction should be one string like: - 'inside box 0. 0. 0. 90. 90. 90.'""", + description=( + "List of instructions for each species. " + "One List per Molecule. " + "Every instruction should be one string like:\n" + "'inside box 0. 0. 0. 90. 90. 90.'" + ), ) - @root_validator - def validate_input(cls, values: Union[str, Dict[str, Any]]) -> Dict: - # check if is only a string - print("values", values) - if isinstance(values, str): - print("values is a string", values) - raise ValidationError("Input must be a dictionary") - pdbfiles = values.get("pdbfiles", []) - number_of_molecules = values.get("number_of_molecules", []) - instructions = values.get("instructions", []) - - if not (len(pdbfiles) == len(number_of_molecules) == len(instructions)): - return { - "error": """The lengths of pdbfiles, number_of_molecules, - and instructions must be equal to use this tool.""" - } - - for instruction in instructions: - if len(instruction) != 1: - return { - "error": """Each instruction must be a single string. - If necessary, use newlines in a instruction string.""" - } - if instruction[0].split(" ")[0] not in [ - "inside", - "center", - "outside", - "fixed", - ]: - return { - "error": """The first word of each instruction must be one of - 'inside' or 'center' or 'outside' or 'fixed' \n - examples: center \n fixed 0. 0. 0. 0. 0. 0., - inside box -10. 0. 0. 10. 10. 10. \n """ - } - # Further validation, e.g., checking if files exist - for pdbfile in pdbfiles: - if not os.path.exists(pdbfile): - # look for files in the current directory - # that match some part of the pdbfile - possible_files = [] - for file in os.listdir(): - if pdbfile in file: - possible_files.append(file) - if len(possible_files) > 0: - return { - "error": f"""PDB file {pdbfile} does not exist in the current - directory, maybe you wanted one of:{','.join(possible_files)}.""" - } - if len(possible_files) == 0: - return { - "error": f"""PDB file {pdbfile} does not exist - in the current directory. - Make sure the pdbfiles are correct.""" - } - return values - class PackMolTool(BaseTool): name: str = "packmol_tool" - description: str = """Useful when you need to create a box - of different types of molecules molecules""" + description: str = ( + "Useful when you need to create a box " + "of different types of chemical species.\n" + "Three different examples:\n" + "pdbfiles_id: ['1a2b_123456']\n" + "small_molecules: ['water'] \n" + "number_of_molecules: [1, 1000]\n" + "instructions: [['fixed 0. 0. 0. 0. 0. 0. \n centerofmass'], " + "['inside box 0. 0. 0. 90. 90. 90.']]\n" + "will pack 1 molecule of 1a2b_123456 at the origin " + "and 1000 molecules of water. \n" + "pdbfiles_id: ['1a2b_123456']\n" + "number_of_molecules: [1]\n" + "instructions: [['fixed 0. 0. 0. 0. 0. 0.' \n center]]\n" + "This will fix the barocenter of protein 1a2b_123456 at " + "the center of the box with no rotation.\n" + "pdbfiles_id: ['1a2b_123456']\n" + "number_of_molecules: [1]\n" + "instructions: [['outside sphere 2.30 3.40 4.50 8.0]]\n" + "This will place the protein 1a2b_123456 outside a sphere " + "centered at 2.30 3.40 4.50 with radius 8.0\n" + ) args_schema: Type[BaseModel] = PackmolInput @@ -467,18 +486,60 @@ def __init__(self, path_registry: typing.Optional[PathRegistry]): super().__init__() self.path_registry = path_registry + def _get_sm_pdbs(self, small_molecules): + all_files = self.path_registry.list_path_names() + for molecule in small_molecules: + # check path registry for molecule.pdb + if molecule not in all_files: + # download molecule using small_molecule_pdb from MolPDB + molpdb = MolPDB() + molpdb.small_molecule_pdb(molecule, self.path_registry) + print("Small molecules PDBs created successfully") + def _run(self, **values) -> str: """use the tool.""" if self.path_registry is None: # this should not happen raise ValidationError("Path registry not initialized") - + try: + values = self.validate_input(values) + except ValidationError as e: + return str(e) error_msg = values.get("error", None) - pdbfiles = values.get("pdbfiles", []) + if error_msg: + print("Error in Packmol inputs:", error_msg) + return f"Error in inputs: {error_msg}" + print("Starting Packmol Tool!") + pdbfile_ids = values.get("pdbfiles_id", []) + pdbfiles = [ + self.path_registry.get_mapped_path(pdbfile) for pdbfile in pdbfile_ids + ] + pdbfile_names = [pdbfile.split("/")[-1] for pdbfile in pdbfiles] + # copy them to the current directory with temp_ names + + pdbfile_names = [f"temp_{pdbfile_name}" for pdbfile_name in pdbfile_names] number_of_molecules = values.get("number_of_molecules", []) instructions = values.get("instructions", []) - if error_msg: - return error_msg + small_molecules = values.get("small_molecules", []) + # make sure small molecules are all downloaded + self._get_sm_pdbs(small_molecules) + small_molecules_files = [ + self.path_registry.get_mapped_path(sm) for sm in small_molecules + ] + small_molecules_file_names = [ + small_molecule.split("/")[-1] for small_molecule in small_molecules_files + ] + small_molecules_file_names = [ + f"temp_{small_molecule_file_name}" + for small_molecule_file_name in small_molecules_file_names + ] + # append small molecules to pdbfiles + pdbfiles.extend(small_molecules_files) + pdbfile_names.extend(small_molecules_file_names) + pdbfile_ids.extend(small_molecules) + + for pdbfile, pdbfile_name in zip(pdbfiles, pdbfile_names): + os.system(f"cp {pdbfile} {pdbfile_name}") # check if packmol is installed cmd = "command -v packmol" result = subprocess.run(cmd, shell=True, text=True, capture_output=True) @@ -487,16 +548,133 @@ def _run(self, **values) -> str: "./" + cmd, shell=True, text=True, capture_output=True ) if result.returncode != 0: - return """Packmol is not installed. Please install packmol - at 'https://m3g.github.io/packmol/download.shtml' and try again.""" + return ( + "Packmol is not installed. Please install" + "packmol at " + "'https://m3g.github.io/packmol/download.shtml'" + "and try again." + ) return packmol_wrapper( self.path_registry, - pdbfiles=pdbfiles, + pdbfiles=pdbfile_names, + files_id=pdbfile_ids, number_of_molecules=number_of_molecules, instructions=instructions, ) + def validate_input(cls, values: Union[str, Dict[str, Any]]) -> Dict: + # check if is only a string + if isinstance(values, str): + print("values is a string", values) + raise ValidationError("Input must be a dictionary") + pdbfiles = values.get("pdbfiles_id", []) + small_molecules = values.get("small_molecules", []) + number_of_molecules = values.get("number_of_molecules", []) + instructions = values.get("instructions", []) + number_of_species = len(pdbfiles) + len(small_molecules) + + if not number_of_species == len(number_of_molecules): + if not number_of_species == len(instructions): + return { + "error": ( + "The length of number_of_molecules AND instructions " + "must be equal to the number of species in the system. " + f"You have {number_of_species} " + f"from {len(pdbfiles)} pdbfiles and {len(small_molecules)} " + "small molecules" + ) + } + return { + "error": ( + "The length of number_of_molecules must be equal to the " + f"number of species in the system. You have {number_of_species} " + f"from {len(pdbfiles)} pdbfiles and {len(small_molecules)} " + "small molecules" + ) + } + elif not number_of_species == len(instructions): + return { + "error": ( + "The length of instructions must be equal to the " + f"number of species in the system. You have {number_of_species} " + f"from {len(pdbfiles)} pdbfiles and {len(small_molecules)} " + "small molecules" + ) + } + + molPDB = MolPDB() + for instruction in instructions: + if len(instruction) != 1: + return { + "error": ( + "Each instruction must be a single string. " + "If necessary, use newlines in a instruction string." + ) + } + # TODO enhance this validation with more packmol instructions + first_word = instruction[0].split(" ")[0] + if first_word == "center": + if len(instruction[0].split(" ")) == 1: + return { + "error": ( + "The instruction 'center' must be accompanied by more " + "instructions. Example 'fixed 0. 0. 0. 0. 0. 0.' " + "The complete instruction would be: 'center \n fixed 0. 0. " + "0. 0. 0. 0.' with a newline separating the two " + "instructions." + ) + } + elif first_word not in [ + "inside", + "outside", + "fixed", + ]: + return { + "error": ( + "The first word of each instruction must be one of " + "'inside' or 'outside' or 'fixed' \n" + "examples: center \n fixed 0. 0. 0. 0. 0. 0.,\n" + "inside box -10. 0. 0. 10. 10. 10. \n" + ) + } + + # Further validation, e.g., checking if files exist + registry = PathRegistry() + file_ids = registry.list_path_names() + + for pdbfile_id in pdbfiles: + if "_" not in pdbfile_id: + return { + "error": ( + f"{pdbfile_id} is not a valid pdbfile_id in the path_registry" + ) + } + if pdbfile_id not in file_ids: + # look for files in the current directory + # that match some part of the pdbfile + ids_w_description = registry.list_path_names_and_descriptions() + + return { + "error": ( + f"PDB file ID {pdbfile_id} does not exist " + "in the path registry.\n" + f"This are the files IDs: {ids_w_description} " + ) + } + for small_molecule in small_molecules: + if small_molecule not in file_ids: + result = molPDB.small_molecule_pdb(small_molecule, registry) + if "successfully" not in result: + return { + "error": ( + f"{small_molecule} could not be converted to a pdb " + "file. Try with a different name, or with the SMILES " + "of the small molecule" + ) + } + return values + async def _arun(self, values: str) -> str: """Use the tool asynchronously.""" raise NotImplementedError("custom_search does not support async") @@ -552,9 +730,11 @@ def _atoms_have_elements(self, pdbfile): print(elements) if len(elements) != len(atoms): print( - f"""No elements in the ATOM records there are - {len(elements)} elements and {len(atoms)} - atoms records""" + ( + "No elements in the ATOM records there are" + "{len(elements)} elements and {len(atoms)}" + "atoms records" + ) ) return False elements = list(set(elements)) @@ -665,14 +845,15 @@ def pdb_summarizer(pdb_file): pdb.num_of_residues = pdb._num_of_dif_residues(pdb_file) pdb.HETATM_tempFact = pdb._hetatm_have_tempFactor(pdb_file) - output = f"""PDB file: {pdb_file} has the following properties: - Number of residues: {pdb.num_of_residues} - Are elements identifiers present: {pdb.atoms} - Are HETATM elements identifiers present: {pdb.HETATM} - Are residue names present: {pdb.residues} - Are box dimensions present: {pdb.box} - Non-standard residues: {pdb.HETATM} - """ + output = ( + f"PDB file: {pdb_file} has the following properties:" + "Number of residues: {pdb.num_of_residues}" + "Are elements identifiers present: {pdb.atoms}" + "Are HETATM elements identifiers present: {pdb.HETATM}" + "Are residue names present: {pdb.residues}" + "Are box dimensions present: {pdb.box}" + "Non-standard residues: {pdb.HETATM}" + ) return output @@ -720,8 +901,10 @@ def fix_element_column(pdb_file, custom_element_dict=None): ), pdb._hetatm_have_elements(pdb_file) if atoms_have_elems and HETATM_have_elems: f.close() - return """Element's column already filled with - elements, no fix needed for elements""" + return ( + "Element's column already filled with" + "elements, no fix needed for elements" + ) print("I closed the initial file") f.close() @@ -774,8 +957,9 @@ class FixElementColumnArgs(BaseTool): pdb_file: str = Field(..., description="PDB file to be fixed") custom_element_dict: dict = Field( None, - description="""Custom element dictionary. If None, - the default dictionary is used""", + description=( + "Custom element dictionary. If None," "the default dictionary is used" + ), ) @@ -849,8 +1033,10 @@ def fix_temp_factor_column(pdb_file, bfactor=1.00, only_fill=True): if atoms_have_bfactor and HETATM_have_bfactor and only_fill: # print("Im closing the file temp factor") f.close() - return """TempFact column filled with bfactor already, - no fix needed for temp factor""" + return ( + "TempFact column filled with bfactor already," + "no fix needed for temp factor" + ) f.close() # fix element column records = ("TITLE", "HEADER", "REMARK", "CRYST1", "HET", "LINK", "SEQRES") @@ -901,10 +1087,12 @@ class FixTempFactorColumnArgs(BaseTool): bfactor: float = Field(1.0, description="Bfactor value to use") only_fill: bool = Field( True, - description="""Only fill empty bfactor columns. - Avoids replacing existing values. - False if you want to replace all values - with the bfactor value""", + description=( + "Only fill empty bfactor columns." + "Avoids replacing existing values." + "False if you want to replace all values" + "with the bfactor value" + ), ) @@ -962,8 +1150,10 @@ def fix_occupancy_columns(pdb_file, occupancy=1.0, only_fill=True): ), pdb._hetatom_have_occupancy(file_name) if atoms_have_bfactor and HETATM_have_bfactor and only_fill: f.close() - return """Occupancy column filled with occupancy - already, no fix needed for occupancy""" + return ( + "Occupancy column filled with occupancy" + "already, no fix needed for occupancy" + ) f.close() # fix element column records = ("TITLE", "HEADER", "REMARK", "CRYST1", "HET", "LINK", "SEQRES") @@ -1013,10 +1203,12 @@ class FixOccupancyColumnArgs(BaseTool): occupancy: float = Field(1.0, description="Occupancy value to be set") only_fill: bool = Field( True, - description="""Only fill empty occupancy columns. - Avoids replacing existing values. - False if you want to replace all - values with the occupancy value""", + description=( + "Only fill empty occupancy columns." + "Avoids replacing existing values." + "False if you want to replace all" + "values with the occupancy value" + ), ) @@ -1046,23 +1238,29 @@ class PDBFilesFixInp(BaseModel): pdbfile: str = Field(..., description="PDB file to be fixed") ElemColum: typing.Optional[bool] = Field( False, - description="""List of fixes to be applied. If None, a - validation of what fixes are needed is performed.""", + description=( + "List of fixes to be applied. If None, a" + "validation of what fixes are needed is performed." + ), ) tempFactor: typing.Optional[typing.Tuple[float, bool]] = Field( (...), - description="""Tuple of ( float, bool) - first arg is the - value to be set as the tempFill, and third arg indicates - if only empty TempFactor columns have to be filled""", + description=( + "Tuple of ( float, bool)" + "first arg is the" + "value to be set as the tempFill, and third arg indicates" + "if only empty TempFactor columns have to be filled" + ), ) Occupancy: typing.Optional[typing.Tuple[float, bool]] = Field( (...), - description="""Tuple of (bool, float, bool) - where first arg indicates if Occupancy - fix has to be applied, second arg is the - value to be set, and third arg indicates - if only empty Occupancy columns have to be filled""", + description=( + "Tuple of (bool, float, bool)" + "where first arg indicates if Occupancy" + "fix has to be applied, second arg is the" + "value to be set, and third arg indicates" + "if only empty Occupancy columns have to be filled" + ), ) @root_validator @@ -1085,18 +1283,22 @@ def validate_input(cls, values: Union[str, Dict[str, Any]]) -> Dict: if occupancy: if len(occupancy) != 2: return { - "error": """if you want to fix the occupancy - column argument must be a tuple of (bool, float)""" + "error": ( + "if you want to fix the occupancy" + "column argument must be a tuple of (bool, float)" + ) } if not isinstance(occupancy[0], float): return {"error": "occupancy first arg must be a float"} if not isinstance(occupancy[1], bool): - return {"error": """occupancy second arg must be a bool"""} + return {"error": "occupancy second arg must be a bool"} if tempFactor: if len(tempFactor != 2): return { - "error": """if you want to fix the tempFactor - column argument must be a tuple of (float, bool)""" + "error": ( + "if you want to fix the tempFactor" + "column argument must be a tuple of (float, bool)" + ) } if not isinstance(tempFactor[0], bool): return {"error": "occupancy first arg must be a float"} @@ -1113,9 +1315,9 @@ class FixPDBFile(BaseTool): description: str = "Fixes PDB files columns if needed" args_schema: Type[BaseModel] = PDBFilesFixInp - path_registry: typing.Optional[PathRegistry] + path_registry: Optional[PathRegistry] - def __init__(self, path_registry: typing.Optional[PathRegistry]): + def __init__(self, path_registry: Optional[PathRegistry]): super().__init__() self.path_registry = path_registry @@ -1162,3 +1364,123 @@ def _run(self, query: Dict): return "PDB file fixed" else: return "PDB not fully fixed" + + +class MolPDB: + def is_smiles(self, text: str) -> bool: + try: + m = Chem.MolFromSmiles(text, sanitize=False) + if m is None: + return False + return True + except Exception: + return False + + def largest_mol( + self, smiles: str + ) -> ( + str + ): # from https://github.com/ur-whitelab/chemcrow-public/blob/main/chemcrow/utils.py + ss = smiles.split(".") + ss.sort(key=lambda a: len(a)) + while not self.is_smiles(ss[-1]): + rm = ss[-1] + ss.remove(rm) + return ss[-1] + + def molname2smiles( + self, query: str + ) -> ( + str + ): # from https://github.com/ur-whitelab/chemcrow-public/blob/main/chemcrow/tools/databases.py + url = " https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{}/{}" + r = requests.get(url.format(query, "property/IsomericSMILES/JSON")) + # convert the response to a json object + data = r.json() + # return the SMILES string + try: + smi = data["PropertyTable"]["Properties"][0]["IsomericSMILES"] + except KeyError: + return ( + "Could not find a molecule matching the text." + "One possible cause is that the input is incorrect, " + "input one molecule at a time." + ) + # remove salts + return Chem.CanonSmiles(self.largest_mol(smi)) + + def smiles2name(self, smi: str) -> str: + try: + smi = Chem.MolToSmiles(Chem.MolFromSmiles(smi), canonical=True) + except Exception: + return "Invalid SMILES string" + # query the PubChem database + r = requests.get( + "https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/smiles/" + + smi + + "/synonyms/JSON" + ) + data = r.json() + try: + name = data["InformationList"]["Information"][0]["Synonym"][0] + except KeyError: + return "Unknown Molecule" + return name + + def small_molecule_pdb(self, mol_str: str, path_registry) -> str: + # takes in molecule name or smiles (converts to smiles if name) + # writes pdb file name.pdb (gets name from smiles if possible) + # output is done message + ps = Chem.SmilesParserParams() + ps.removeHs = False + try: + if self.is_smiles(mol_str): + m = Chem.MolFromSmiles(mol_str) + mol_name = self.smiles2name(mol_str) + else: # if input is not smiles, try getting smiles + smi = self.molname2smiles(mol_str) + m = Chem.MolFromSmiles(smi) + mol_name = mol_str + try: # only if needed + m = Chem.AddHs(m) + except Exception: # TODO: we should be more specific here + pass + Chem.AllChem.EmbedMolecule(m) + file_name = f"files/pdb/{mol_name}.pdb" + Chem.MolToPDBFile(m, file_name) + # add to path registry + if path_registry: + _ = path_registry.map_path( + mol_name, file_name, f"pdb file for the small molecule {mol_name}" + ) + return ( + f"PDB file for {mol_str} successfully created and saved to {file_name}." + ) + except Exception: # TODO: we should be more specific here + print( + "There was an error getting pdb. Please input a single molecule name." + f"{mol_str},{mol_name}, {smi}" + ) + return ( + "There was an error getting pdb. Please input a single molecule name." + ) + + +class SmallMolPDB(BaseTool): + name = "SmallMoleculePDB" + description = ( + "Creates a PDB file for a small molecule" + "Use this tool when you need to use a small molecule in a simulation." + "Input can be a molecule name or a SMILES string." + ) + path_registry: Optional[PathRegistry] + + def __init__(self, path_registry: Optional[PathRegistry]): + super().__init__() + self.path_registry = path_registry + + def _run(self, mol_str: str) -> str: + """use the tool.""" + mol_pdb = MolPDB() + output = mol_pdb.small_molecule_pdb(mol_str, self.path_registry) + return output diff --git a/mdagent/tools/base_tools/simulation_tools/setup_and_run.py b/mdagent/tools/base_tools/simulation_tools/setup_and_run.py index 19ffaac6..808d9ca1 100644 --- a/mdagent/tools/base_tools/simulation_tools/setup_and_run.py +++ b/mdagent/tools/base_tools/simulation_tools/setup_and_run.py @@ -387,6 +387,15 @@ def _setup_and_run_simulation(self, query, PathRegistry): simulation = Simulation(modeller.topology, system, integrator) simulation.context.setPositions(modeller.positions) simulation.minimizeEnergy() + # save initial positions to registry + file_name = "initial_positions.pdb" + with open(file_name, "w") as f: + PDBFile.writeFile( + simulation.topology, + simulation.context.getState(getPositions=True).getPositions(), + f, + ) + print("Initial Positions saved to initial_positions.pdb") simulation.reporters.append(PDBReporter(f"{name}.pdb", 1000)) # reporter_args = {"reportInterval": 1000} reporter_args = {} @@ -579,6 +588,7 @@ class SetUpandRunFunctionInput(BaseModel): "constraints": "None", "rigidWater": False, "constraintTolerance": None, + "solvate": False, }, description=( "Parameters for the openmm system. " @@ -593,6 +603,8 @@ class SetUpandRunFunctionInput(BaseModel): "None, HBonds, AllBonds or OnlyWater." "For rigidWater, you can choose from the following:\n" "True, False.\n" + "Finally, if you want to solvate the system, before the simulation," + "you can set solvate to True.\n" "Example1:\n" "{'nonbondedMethod': 'NoCutoff',\n" "'constraints': 'None',\n" @@ -602,7 +614,8 @@ class SetUpandRunFunctionInput(BaseModel): "'nonbondedCutoff': 1.0,\n" "'constraints': 'HBonds',\n" "'rigidWater': True,\n" - "'constraintTolerance': 0.00001} " + "'constraintTolerance': 0.00001,\n" + "'solvate': True} " ), ) integrator_params: Dict[str, Any] = Field( @@ -670,6 +683,7 @@ def __init__( "constraints": AllBonds, "rigidWater": True, "constraintTolerance": 0.000001, + "solvate": False, } self.sim_params = self.params.get("simmulation_params", None) if self.sim_params is None: @@ -688,7 +702,7 @@ def setup_system(self): print("Building system...") st.markdown("Building system", unsafe_allow_html=True) self.pdb_id = self.params["pdb_id"] - self.pdb_path = self.path_registry.get_mapped_path(name=self.pdb_id) + self.pdb_path = self.path_registry.get_mapped_path(self.pdb_id) self.pdb = PDBFile(self.pdb_path) self.forcefield = ForceField(*self.params["forcefield_files"]) self.system = self._create_system(self.pdb, self.forcefield, **self.sys_params) @@ -735,12 +749,12 @@ def create_simulation(self): print("Creating simulation...") st.markdown("Creating simulation", unsafe_allow_html=True) self.simulation = Simulation( - self.pdb.topology, + self.modeller.topology, self.system, self.integrator, Platform.getPlatformByName("CPU"), ) - self.simulation.context.setPositions(self.pdb.positions) + self.simulation.context.setPositions(self.modeller.positions) # TEMPORARY FILE MANAGEMENT OR PATH REGISTRY MAPPING if self.save: @@ -759,8 +773,7 @@ def create_simulation(self): Sim_id=self.sim_id, term="txt", ) - traj_id = self.path_registry.get_fileid(trajectory_name, FileType.RECORD) - log_id = self.path_registry.get_fileid(log_name, FileType.RECORD) + traj_desc = ( f"Simulation trajectory for protein {self.pdb_id}" f" and simulation {self.sim_id}" @@ -787,8 +800,8 @@ def create_simulation(self): ) ) self.registry_records = [ - (traj_id, f"files/records/{trajectory_name}", traj_desc), - (log_id, f"files/records/{log_name}", log_desc), + ("holder", f"files/records/{trajectory_name}", traj_desc), + ("holder", f"files/records/{log_name}", log_desc), ] # TODO add checkpoint too? @@ -821,6 +834,7 @@ def _create_system( constraints="None", rigidWater=False, constraintTolerance=None, + solvate=False, **kwargs, ): # Create a dictionary to hold system parameters @@ -850,8 +864,26 @@ def _create_system( # if use_constraint_tolerance: # constraintTolerance = system_params.pop('constraintTolerance') - - system = forcefield.createSystem(pdb.topology, **system_params) + self.modeller = Modeller(pdb.topology, pdb.positions) + if solvate: + try: + self.modeller.addSolvent(forcefield) + except ValueError as e: + print("Error adding solvent", type(e).__name__, "–", e) + if "No Template for" in str(e): + raise ValueError(str(e)) + except AttributeError as e: + print("Error adding solvent: ", type(e).__name__, "–", e) + print("Trying to add solvent with 1 nm padding") + if "NoneType" and "value_in_unit" in str(e): + try: + self.modeller.addSolvent(forcefield, padding=1 * nanometers) + except Exception as e: + print("Error adding solvent", type(e).__name__, "–", e) + raise (e) + system = forcefield.createSystem(self.modeller.topology, **system_params) + else: + system = forcefield.createSystem(self.modeller.topology, **system_params) return system @@ -876,9 +908,10 @@ def unit_to_string(unit): nonbondedCutoff = unit_to_string(nbCo) constraints = self.sys_params.get("constraints", "None") rigidWater = self.sys_params.get("rigidWater", False) - ewaldErrorTolerance = {self.sys_params.get("ewaldErrorTolerance", 0.0005)} + ewaldErrorTolerance = self.sys_params.get("ewaldErrorTolerance", 0.0005) constraintTolerance = self.sys_params.get("constraintTolerance", None) hydrogenMass = self.sys_params.get("hydrogenMass", None) + solvate = self.sys_params.get("solvate", False) integrator_type = self.int_params.get("integrator_type", "LangevinMiddle") friction = self.int_params.get("Friction", 1.0 / picoseconds) @@ -956,46 +989,51 @@ def unit_to_string(unit): # Simulate print('Building system...') - topology = pdb.topology - positions = pdb.positions + modeller = Modeller(pdb.topology, pdb.positions) """ + if solvate: + script_content += ( + """modeller.addSolvent(forcefield, padding=1*nanometers)""" + ) + if nonbondedMethod == NoCutoff: if hydrogenMass: script_content += """ - system = forcefield.createSystem(topology, nonbondedMethod=nonbondedMethod, - constraints=constraints, rigidWater=rigidWater, hydrogenMass=hydrogenMass) + system = forcefield.createSystem(modeller.topology, + nonbondedMethod=nonbondedMethod, constraints=constraints, + rigidWater=rigidWater, hydrogenMass=hydrogenMass) """ else: script_content += """ - system = forcefield.createSystem(topology, nonbondedMethod=nonbondedMethod, - constraints=constraints, rigidWater=rigidWater) + system = forcefield.createSystem(modeller.topology, + nonbondedMethod=nonbondedMethod, constraints=constraints, + rigidWater=rigidWater) """ if nonbondedMethod == CutoffNonPeriodic or nonbondedMethod == CutoffPeriodic: if hydrogenMass: script_content += """ - system = forcefield.createSystem(topology, - nonbondedMethod=nonbondedMethod, - nonbondedCutoff=nonbondedCutoff, constraints=constraints, - rigidWater=rigidWater, hydrogenMass=hydrogenMass) + system = forcefield.createSystem(modeller.topology, + nonbondedMethod=nonbondedMethod, nonbondedCutoff=nonbondedCutoff, + constraints=constraints, rigidWater=rigidWater, + hydrogenMass=hydrogenMass) """ else: script_content += """ - system = forcefield.createSystem(topology, - nonbondedMethod=nonbondedMethod, - nonbondedCutoff=nonbondedCutoff, constraints=constraints, - rigidWater=rigidWater) + system = forcefield.createSystem(modeller.topology, + nonbondedMethod=nonbondedMethod, nonbondedCutoff=nonbondedCutoff, + constraints=constraints, rigidWater=rigidWater) """ if nonbondedMethod == PME: if hydrogenMass: script_content += """ - system = forcefield.createSystem(topology, + system = forcefield.createSystem(modeller.topology, nonbondedMethod=nonbondedMethod, nonbondedCutoff=nonbondedCutoff, ewaldErrorTolerance=ewaldErrorTolerance, constraints=constraints, rigidWater=rigidWater, hydrogenMass=hydrogenMass) """ else: script_content += """ - system = forcefield.createSystem(topology, + system = forcefield.createSystem(modeller.topology, nonbondedMethod=nonbondedMethod, nonbondedCutoff=nonbondedCutoff, ewaldErrorTolerance=ewaldErrorTolerance, constraints=constraints, rigidWater=rigidWater) @@ -1009,14 +1047,14 @@ def unit_to_string(unit): script_content += """ integrator = LangevinMiddleIntegrator(temperature, friction, dt) integrator.setConstraintTolerance(constraintTolerance) - simulation = Simulation(topology, system, integrator, platform) - simulation.context.setPositions(positions) + simulation = Simulation(modeller.topology, system, integrator, platform) + simulation.context.setPositions(modeller.positions) """ if integrator_type == "LangevinMiddle" and constraints == "None": script_content += """ integrator = LangevinMiddleIntegrator(temperature, friction, dt) - simulation = Simulation(topology, system, integrator, platform) - simulation.context.setPositions(positions) + simulation = Simulation(modeller.topology, system, integrator, platform) + simulation.context.setPositions(modeller.positions) """ script_content += """ @@ -1065,6 +1103,16 @@ def run(self): self.simulation.minimizeEnergy() print("Minimization complete!") + top_name = f"files/pdb/{self.sim_id}_initial_positions.pdb" + top_description = f"Initial positions for simulation {self.sim_id}" + with open(top_name, "w") as f: + PDBFile.writeFile( + self.simulation.topology, + self.simulation.context.getState(getPositions=True).getPositions(), + f, + ) + self.path_registry.map_path(f"top_{self.sim_id}", top_name, top_description) + print("Initial Positions saved to initial_positions.pdb") st.markdown("Minimization complete! Equilibrating...", unsafe_allow_html=True) print("Equilibrating...") _temp = self.int_params["Temperature"] @@ -1148,7 +1196,14 @@ def _run(self, **input_args): print("simulation set!") st.markdown("simulation set!", unsafe_allow_html=True) except ValueError as e: - return str(e) + f"This were the inputs {input_args}" + msg = str(e) + f"This were the inputs {input_args}" + if "No template for" in msg: + msg += ( + "This error is likely due to non standard residues " + "in the protein, if you havent done it yet, try " + "cleaning the pdb file using the cleaning tool" + ) + return msg except FileNotFoundError: return f"File not found, check File id. This were the inputs {input_args}" except OpenMMException as e: @@ -1181,8 +1236,19 @@ def _run(self, **input_args): for record in records: os.rename(record[1].split("/")[-1], f"{record[1]}") for record in records: + record[0] = self.path_registry.get_fileid( # Step necessary here to + record[1].split("/")[-1], # avoid id being repeated + FileType.RECORD, + ) self.path_registry.map_path(*record) - return "Simulation done!" + return ( + "Simulation done! \n Summary: \n" + "Record files written to files/records/ with IDs and descriptions: " + f"{[(record[0],record[2]) for record in records]}\n" + "Standalone script written to files/simulations/ with ID: " + f"{sim_id}.\n" + f"The initial topology file ID is top_{sim_id} saved in files/pdb/" + ) except Exception as e: print(f"An exception was found: {str(e)}.") return f"An exception was found trying to write the filenames: {str(e)}." @@ -1383,9 +1449,12 @@ def _process_parameters(self, user_params, param_type="system_params"): try: processed_params[key] = float(value) except TypeError as e: - error_msg += f"""Invalid ewaldErrorTolerance: {e}. - If you are using null or None, just dont include - as part of the parameters.\n""" + error_msg += ( + f"Invalid ewaldErrorTolerance: {e}. " + "If you are using null or None, " + "just dont include it " + "as part of the parameters.\n" + ) if key == "constraints": try: if type(value) == str: @@ -1398,14 +1467,19 @@ def _process_parameters(self, user_params, param_type="system_params"): elif value == "HAngles": processed_params[key] = HAngles else: - error_msg += f"""Invalid constraints. got {value}. - Try using None, HBonds, AllBonds, - HAngles""" + error_msg += ( + f"Invalid constraints: Got {value}. " + "Try using None, HBonds, AllBonds or " + "HAngles\n" + ) else: processed_params[key] = value except TypeError as e: - error_msg += f"""Invalid constraints: {e}. If you are using - null or None, just dont include as part of the parameters.\n""" + error_msg += ( + f"Invalid constraints: {e}. If you are using " + "null or None, just dont include as " + "part of the parameters.\n" + ) if key == "rigidWater" or key == "rigidwater": if type(value) == bool: processed_params[key] = value @@ -1414,17 +1488,42 @@ def _process_parameters(self, user_params, param_type="system_params"): elif value == "False": processed_params[key] = False else: - error_msg += f"""Invalid rigidWater. got {value}. - Try using True or False.\n""" + error_msg += ( + f"Invalid rigidWater: got {value}. " + "Try using True or False.\n" + ) if key == "constraintTolerance" or key == "constrainttolerance": try: processed_params[key] = float(value) except ValueError as e: - error_msg += f"Invalid constraintTolerance. {e}." + error_msg += f"Invalid constraintTolerance: {e}." except TypeError as e: - error_msg += f"""Invalid constraintTolerance. {e}. If - constraintTolerance is null or None, - just dont include as part of the parameters.\n""" + error_msg += ( + f"Invalid constraintTolerance: {e}. If " + "constraintTolerance is null or None, " + "just dont include as part of " + "the parameters.\n" + ) + if key == "solvate": + try: + if type(value) == bool: + processed_params[key] = value + elif value == "True": + processed_params[key] = True + elif value == "False": + processed_params[key] = False + else: + error_msg += ( + f"Invalid solvate: got {value}. " + "Use either True or False.\n" + ) + except TypeError as e: + error_msg += ( + f"Invalid solvate: {e}. If solvate is null or " + "None, just dont include as part of " + "the parameters.\n" + ) + return processed_params, error_msg if param_type == "integrator_params": for key, value in user_params.items(): @@ -1438,9 +1537,11 @@ def _process_parameters(self, user_params, param_type="system_params"): elif value == "Brownian" or value == BrownianIntegrator: processed_params[key] = "Brownian" else: - error_msg += f"""\nInvalid integrator_type. got {value}. - Try using LangevinMiddle, Langevin, - Verlet, or Brownian.""" + error_msg += ( + f"Invalid integrator_type: got {value}. " + "Try using LangevinMiddle, Langevin, " + "Verlet, or Brownian.\n" + ) if key == "Temperature" or key == "temperature": temperature, msg = self.parse_temperature(value) processed_params[key] = temperature @@ -1469,8 +1570,10 @@ def _process_parameters(self, user_params, param_type="system_params"): elif value == "NVE": processed_params[key] = "NVE" else: - error_msg += f"""Invalid Ensemble. got {value}. - Try using NPT, NVT, or NVE.""" + error_msg += ( + f"Invalid Ensemble. got {value}. " + "Try using NPT, NVT, or NVE.\n" + ) if key == "Number of Steps" or key == "number of steps": processed_params[key] = int(value) @@ -1501,6 +1604,7 @@ def check_system_params(cls, values): "constraints": AllBonds, "rigidWater": True, "constraintTolerance": 0.00001, + "solvate": False, } integrator_params = values.get("integrator_params") if integrator_params: @@ -1618,7 +1722,7 @@ def check_system_params(cls, values): if file not in FORCEFIELD_LIST: error_msg += "The forcefield file is not present" - save = values.get("final", False) + save = values.get("save", True) if type(save) != bool: error_msg += "save must be a boolean value" diff --git a/mdagent/tools/maketools.py b/mdagent/tools/maketools.py index c3878b57..84a4e779 100644 --- a/mdagent/tools/maketools.py +++ b/mdagent/tools/maketools.py @@ -19,13 +19,14 @@ CleaningToolFunction, ListRegistryPaths, ModifyBaseSimulationScriptTool, - Name2PDBTool, PackMolTool, PPIDistance, + ProteinName2PDBTool, RMSDCalculator, Scholar2ResultLLM, SetUpandRunFunction, SimulationOutputFigures, + SmallMolPDB, VisualizeProtein, ) from .subagent_tools import RetryExecuteSkill, SkillRetrieval, WorkflowPlan @@ -80,8 +81,9 @@ def make_all_tools( CheckDirectoryFiles(), ListRegistryPaths(path_registry=path_instance), # MapPath2Name(path_registry=path_instance), - Name2PDBTool(path_registry=path_instance), + ProteinName2PDBTool(path_registry=path_instance), PackMolTool(path_registry=path_instance), + SmallMolPDB(path_registry=path_instance), VisualizeProtein(path_registry=path_instance), PPIDistance(), RMSDCalculator(), @@ -89,18 +91,19 @@ def make_all_tools( ModifyBaseSimulationScriptTool(path_registry=path_instance, llm=llm), SimulationOutputFigures(), ] - - # tools using subagents if subagent_settings is None: subagent_settings = SubAgentSettings(path_registry=path_instance) + + # tools using subagents subagents_tools = [] if not skip_subagents: subagents_tools = [ CreateNewTool(subagent_settings=subagent_settings), RetryExecuteSkill(subagent_settings=subagent_settings), SkillRetrieval(subagent_settings=subagent_settings), - WorkflowPlan(subagent_settings=subagent_settings), ] + if subagent_settings.curriculum: + WorkflowPlan(subagent_settings=subagent_settings) # add 'learned' tools here # disclaimer: assume they don't need path_registry @@ -125,7 +128,7 @@ def get_tools( llm: BaseLanguageModel, subagent_settings: Optional[SubAgentSettings] = None, top_k_tools=15, - subagents_required=True, + skip_subagents=False, human=False, ): if subagent_settings: @@ -134,7 +137,7 @@ def get_tools( ckpt_dir = "ckpt" retrieved_tools = [] - if subagents_required: + if not skip_subagents: # add subagents-related tools by default retrieved_tools = [ CreateNewTool(subagent_settings=subagent_settings), @@ -191,7 +194,8 @@ class CreateNewToolInputSchema(BaseModel): orig_prompt: str = Field(description="Full user prompt you got from the beginning.") curr_tools: str = Field( description="""List of all tools you have access to. Such as - this tool, 'ExecuteSkill', 'SkillRetrieval', and maybe `Name2PDBTool`, etc.""" + this tool, 'ExecuteSkill', + 'SkillRetrieval', and maybe `ProteinName2PDBTool`, etc.""" ) execute: Optional[bool] = Field( True, diff --git a/mdagent/tools/subagent_tools.py b/mdagent/tools/subagent_tools.py index 410c7310..f369157e 100644 --- a/mdagent/tools/subagent_tools.py +++ b/mdagent/tools/subagent_tools.py @@ -97,7 +97,8 @@ class WorkflowPlanInputSchema(BaseModel): ) curr_tools: str = Field( description="""List of all tools you have access to. Such as - this tool, 'ExecuteSkill', 'SkillRetrieval', and maybe `Name2PDBTool`, etc.""" + this tool, 'ExecuteSkill', + 'SkillRetrieval', and maybe `ProteinName2PDBTool`, etc.""" ) files: str = Field(description="List of all files you have access to.") # ^ would be nice if MDAgent could give files in case user provides unmapped files diff --git a/mdagent/utils/path_registry.py b/mdagent/utils/path_registry.py index 702833ca..46f730af 100644 --- a/mdagent/utils/path_registry.py +++ b/mdagent/utils/path_registry.py @@ -10,6 +10,8 @@ class FileType(Enum): PROTEIN = 1 SIMULATION = 2 RECORD = 3 + SOLVENT = 4 + UNKNOWN = 5 class PathRegistry: @@ -23,6 +25,71 @@ def get_instance(cls): def __init__(self): self.json_file_path = "paths_registry.json" + self._init_path_registry() + + def _init_path_registry(self): + base_directory = "files" + subdirectories = ["pdb", "records", "simulations", "solvents"] + existing_registry = self._load_existing_registry() + file_names_in_registry = [] + if existing_registry != {}: + for _, registry in existing_registry.items(): + file_names_in_registry.append(registry["name"]) + else: + with open(self.json_file_path, "w") as json_file: + json.dump({}, json_file) + for subdir in subdirectories: + subdir_path = os.path.join(base_directory, subdir) + if os.path.exists(subdir_path): + for file_name in os.listdir(subdir_path): + if file_name not in file_names_in_registry: + file_type = self._determine_file_type(subdir) + file_id = self.get_fileid(file_name, file_type) + # TODO get descriptions from file names if possible + # TODO make this a method. In theory, previous downlaods + # or simulation files should be already registered + if file_type == FileType.PROTEIN: + name_parts = file_name.split("_") + protein_name = name_parts[0] + status = name_parts[1] + description = ( + f"Protein {protein_name} pdb file. " + "downloaded from RCSB Protein Data Bank. " + + ( + "Preprocessed for simulation." + if status == "Clean" + else "" + ) + ) + elif file_type == FileType.SOLVENT: + name_parts = file_name.split("_") + solvent_name = name_parts[0] + description = f"Solvent {solvent_name} pdb file. " + else: + description = "Auto-Registered during registry init." + self.map_path( + file_id, subdir_path + "/" + file_name, description + ) + + def _load_existing_registry(self): + if self._check_for_json(): + with open(self.json_file_path, "r") as json_file: + return json.load(json_file) + return {} + + def _determine_file_type(self, subdir): + # Implement logic to determine the file type based on the subdir name + # Example: + if subdir == "pdb": + return FileType.PROTEIN + elif subdir == "records": + return FileType.RECORD + elif subdir == "simulations": + return FileType.SIMULATION + elif subdir == "solvents": + return FileType.SOLVENT + else: + return FileType.UNKNOWN # or some default value def _get_full_path(self, file_path): return os.path.abspath(file_path) @@ -59,18 +126,21 @@ def _check_json_content(self, name): def map_path(self, file_id, path, description=None): description = description or "No description provided" full_path = self._get_full_path(path) - path_dict = {file_id: {"path": full_path, "description": description}} + file_name = os.path.basename(full_path) + path_dict = { + file_id: {"path": full_path, "name": file_name, "description": description} + } self._save_mapping_to_json(path_dict) saved = self._check_json_content(file_id) return f"Path {'successfully' if saved else 'not'} mapped to name: {file_id}" # this if we want to get the path. not use as often - def get_mapped_path(self, name): + def get_mapped_path(self, fileid): if not self._check_for_json(): return "The JSON file does not exist." with open(self.json_file_path, "r") as json_file: data = json.load(json_file) - return data.get(name, {}).get("path", "Name not found in path registry.") + return data.get(fileid, {}).get("path", "Name not found in path registry.") def _clear_json(self): if self._check_for_json(): @@ -79,27 +149,27 @@ def _clear_json(self): return "JSON file cleared" return "JSON file does not exist" - def _remove_path_from_json(self, name): + def _remove_path_from_json(self, fileid): if not self._check_for_json(): return "JSON file does not exist" with open(self.json_file_path, "r") as json_file: data = json.load(json_file) - if name in data: - del data[name] + if fileid in data: + del data[fileid] with open(self.json_file_path, "w") as json_file: json.dump(data, json_file, indent=4) - return f"Path {name} removed from registry" - return f"Path {name} not found in registry" + return f"File {fileid} removed from registry" + return f"Path {fileid} not found in registry" def list_path_names(self): if not self._check_for_json(): return "JSON file does not exist" with open(self.json_file_path, "r") as json_file: data = json.load(json_file) - names = [key for key in data.keys()] + filesids = [key for key in data.keys()] return ( - "Names found in registry: " + ", ".join(names) - if names + "Names found in registry: " + ", ".join(filesids) + if filesids else "No names found. The JSON file is empty or does not" "contain name mappings." ) @@ -109,14 +179,15 @@ def list_path_names_and_descriptions(self): return "JSON file does not exist" with open(self.json_file_path, "r") as json_file: data = json.load(json_file) - names = [key for key in data.keys()] + filesids = [key for key in data.keys()] descriptions = [data[key]["description"] for key in data.keys()] - names_w_descriptions = [ - f"{name}: {description}" for name, description in zip(names, descriptions) + fileid_w_descriptions = [ + f"{fileid}: {description}" + for fileid, description in zip(filesids, descriptions) ] return ( - "Files found in registry: " + ", ".join(names_w_descriptions) - if names + "Files found in registry: " + ", ".join(fileid_w_descriptions) + if filesids else "No names found. The JSON file is empty or does not" "contain name mappings." ) @@ -134,20 +205,34 @@ def get_fileid(self, file_name: str, type: FileType): # Split the filename on underscores parts, ending = file_name.split(".") parts_list = parts.split("_") - + current_ids = self.list_path_names() # Extract the timestamp (assuming it's always in the second to last part) timestamp_part = parts_list[-1] # Get the last 6 digits of the timestamp - timestamp_digits = timestamp_part[-6:] + timestamp_digits = ( + timestamp_part[-6:] if timestamp_part.isnumeric() else "000000" + ) if type == FileType.PROTEIN: # Extract the PDB ID (assuming it's always the first part) pdb_id = parts_list[0] return pdb_id + "_" + timestamp_digits if type == FileType.SIMULATION: - return "sim" + "_" + timestamp_digits + num = 0 + sim_id = "sim" + f"{num}" + "_" + timestamp_digits + while sim_id in current_ids: + num += 1 + sim_id = "sim" + f"{num}" + "_" + timestamp_digits + return sim_id if type == FileType.RECORD: - return "rec" + "_" + timestamp_digits + num = 0 + rec_id = "rec" + f"{num}" + "_" + timestamp_digits + while rec_id in current_ids: + num += 1 + rec_id = "rec" + f"{num}" + "_" + timestamp_digits + return rec_id + if type == FileType.SOLVENT: + return parts + "_" + timestamp_digits def write_file_name(self, type: FileType, **kwargs): time_stamp = self.get_timestamp() diff --git a/setup.py b/setup.py index a25ff0db..474d02e8 100644 --- a/setup.py +++ b/setup.py @@ -31,6 +31,7 @@ "requests", "rmrkl", "tiktoken", + "rdkit", "streamlit", ], test_suite="tests", diff --git a/st_app.py b/st_app.py index ab21a360..a200527c 100644 --- a/st_app.py +++ b/st_app.py @@ -17,11 +17,11 @@ # Streamlit app st.title("MDAgent") -# option = st.selectbox("Choose an option:", ["Explore & Learn", "Use Learned Skills"]) -# if option == "Explore & Learn": -# explore = True -# else: -# explore = False +option = st.selectbox("Choose an option:", ["Explore & Learn", "Use Learned Skills"]) +if option == "Explore & Learn": + learn = True +else: + learn = False resume_op = st.selectbox("Resume:", ["False", "True"]) if resume_op == "True": @@ -29,6 +29,7 @@ else: resume = False + # for now I'm just going to allow pdb and cif files - we can add more later uploaded_files = st.file_uploader( "Upload a .pdb or .cif file", type=["pdb", "cif"], accept_multiple_files=True @@ -45,7 +46,7 @@ else: uploaded_file = [] -mdagent = MDAgent(resume=resume, uploaded_files=uploaded_file) +mdagent = MDAgent(resume=resume, uploaded_files=uploaded_file, learn=learn) def generate_response(prompt): diff --git a/tests/test_agent.py b/tests/test_agent.py index b24bc209..a31555fb 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -3,6 +3,7 @@ import pytest +from mdagent.mainagent.agent import MDAgent from mdagent.subagents.agents.action import Action from mdagent.subagents.agents.skill import SkillManager from mdagent.subagents.subagent_fxns import Iterator @@ -231,3 +232,17 @@ def test_update_skill_library(skill_manager): path="/mock_dir/code/test_function.py", description="Code for new tool test_function", ) + + +def test_mdagent_learn_init(): + mdagent_skill = MDAgent(learn=False) + assert mdagent_skill.skip_subagents is True + mdagent_learn = MDAgent(learn=True) + assert mdagent_learn.skip_subagents is False + + +def test_mdagent_curriculum(): + mdagent_curr = MDAgent(curriculum=True) + mdagent_no_curr = MDAgent(curriculum=False) + assert mdagent_curr.subagents_settings.curriculum is True + assert mdagent_no_curr.subagents_settings.curriculum is False diff --git a/tests/test_fxns.py b/tests/test_fxns.py index 67f6ae34..19b528e9 100644 --- a/tests/test_fxns.py +++ b/tests/test_fxns.py @@ -1,5 +1,6 @@ import json import os +import time import warnings from unittest.mock import MagicMock, mock_open, patch @@ -12,6 +13,7 @@ get_pdb, ) from mdagent.tools.base_tools.analysis_tools.plot_tools import plot_data, process_csv +from mdagent.tools.base_tools.preprocess_tools.pdb_tools import MolPDB, PackMolTool from mdagent.utils import FileType, PathRegistry warnings.filterwarnings("ignore", category=DeprecationWarning, module="pkg_resources") @@ -39,6 +41,11 @@ def cleaning_fxns(): return CleaningTools() +@pytest.fixture +def molpdb(): + return MolPDB() + + # Test simulation tools @pytest.fixture def sim_fxns(): @@ -62,6 +69,11 @@ def get_registry(): return PathRegistry() +@pytest.fixture +def packmol(get_registry): + return PackMolTool(get_registry) + + def test_process_csv(): mock_csv_content = "Time,Value1,Value2\n1,10,20\n2,15,25" mock_reader = MagicMock() @@ -133,7 +145,7 @@ def test_add_hydrogens_and_remove_water(path_to_cif, cleaning_fxns, get_registry @patch("os.path.exists") @patch("os.listdir") -def test_extract_parameters_path(mock_listdir, mock_exists, sim_fxns, get_registry): +def test_extract_parameters_path(mock_listdir, mock_exists, sim_fxns): # Test when parameters.json exists mock_exists.return_value = True assert sim_fxns._extract_parameters_path() == "simulation_parameters_summary.json" @@ -245,10 +257,17 @@ def test_map_path(): mock_json_data = { "existing_name": { "path": "existing/path", + "name": "path", "description": "Existing description", } } - new_path_dict = {"new_name": {"path": "new/path", "description": "New description"}} + new_path_dict = { + "new_name": { + "path": "new/path", + "name": "path", + "description": "New description", + } + } updated_json_data = {**mock_json_data, **new_path_dict} path_registry = PathRegistry() @@ -281,3 +300,117 @@ def test_map_path(): # Check the result message assert result == "Path successfully mapped to name: new_name" + + +def test_small_molecule_pdb(molpdb, get_registry): + # Test with a valid SMILES string + valid_smiles = "C1=CC=CC=C1" # Benzene + expected_output = ( + "PDB file for C1=CC=CC=C1 successfully created and saved to " + "files/pdb/benzene.pdb." + ) + assert molpdb.small_molecule_pdb(valid_smiles, get_registry) == expected_output + assert os.path.exists("files/pdb/benzene.pdb") + os.remove("files/pdb/benzene.pdb") # Clean up + + # test with invalid SMILES string and invalid molecule name + invalid_smiles = "C1=CC=CC=C1X" + invalid_name = "NotAMolecule" + expected_output = ( + "There was an error getting pdb. Please input a single molecule name." + ) + assert molpdb.small_molecule_pdb(invalid_smiles, get_registry) == expected_output + assert molpdb.small_molecule_pdb(invalid_name, get_registry) == expected_output + + # test with valid molecule name + valid_name = "water" + expected_output = ( + "PDB file for water successfully created and " "saved to files/pdb/water.pdb." + ) + assert molpdb.small_molecule_pdb(valid_name, get_registry) == expected_output + assert os.path.exists("files/pdb/water.pdb") + os.remove("files/pdb/water.pdb") # Clean up + + +def test_packmol_sm_download_called(packmol): + path_registry = PathRegistry() + path_registry._remove_path_from_json("water") + path_registry._remove_path_from_json("benzene") + path_registry.map_path("1A3N_144150", "files/pdb/1A3N_144150.pdb", "pdb") + with patch( + "mdagent.tools.base_tools.preprocess_tools.pdb_tools.PackMolTool._get_sm_pdbs", + new=MagicMock(), + ) as mock_get_sm_pdbs: + test_values = { + "pdbfiles_id": ["1A3N_144150"], + "small_molecules": ["water", "benzene"], + "number_of_molecules": [1, 10, 10], + "instructions": [ + ["inside box 0. 0. 0. 100. 100. 100."], + ["inside box 0. 0. 0. 100. 100. 100."], + ["inside box 0. 0. 0. 100. 100. 100."], + ], + } + + packmol._run(**test_values) + + mock_get_sm_pdbs.assert_called_with(["water", "benzene"]) + + +def test_packmol_download_only(packmol): + path_registry = PathRegistry() + path_registry._remove_path_from_json("water") + path_registry._remove_path_from_json("benzene") + small_molecules = ["water", "benzene"] + packmol._get_sm_pdbs(small_molecules) + assert os.path.exists("files/pdb/water.pdb") + assert os.path.exists("files/pdb/benzene.pdb") + os.remove("files/pdb/water.pdb") + os.remove("files/pdb/benzene.pdb") + + +def test_packmol_download_only_once(packmol): + path_registry = PathRegistry() + path_registry._remove_path_from_json("water") + small_molecules = ["water"] + packmol._get_sm_pdbs(small_molecules) + assert os.path.exists("files/pdb/water.pdb") + water_time = os.path.getmtime("files/pdb/water.pdb") + time.sleep(5) + + # Call the function again with the same molecule + packmol._get_sm_pdbs(small_molecules) + water_time_after = os.path.getmtime("files/pdb/water.pdb") + + assert water_time == water_time_after + # Clean up + os.remove("files/pdb/water.pdb") + + +mocked_files = {"files/solvents": ["water.pdb"]} + + +def mock_exists(path): + return path in mocked_files + + +def mock_listdir(path): + return mocked_files.get(path, []) + + +@pytest.fixture +def path_registry_with_mocked_fs(): + with patch("os.path.exists", side_effect=mock_exists): + with patch("os.listdir", side_effect=mock_listdir): + registry = PathRegistry() + registry.get_timestamp = lambda: "20240109" + return registry + + +def test_init_path_registry(path_registry_with_mocked_fs): + # This test will run with the mocked file system + # Here, you can assert if 'water.pdb' under 'solvents' is registered correctly + # Depending on how your PathRegistry class stores the registry, + # you may need to check the internal state or the contents of the JSON file. + # For example: + assert "water_000000" in path_registry_with_mocked_fs.list_path_names()