diff --git a/mdagent/tools/base_tools/preprocess_tools/pdb_tools.py b/mdagent/tools/base_tools/preprocess_tools/pdb_tools.py index bbb17550..e7d2018b 100644 --- a/mdagent/tools/base_tools/preprocess_tools/pdb_tools.py +++ b/mdagent/tools/base_tools/preprocess_tools/pdb_tools.py @@ -322,18 +322,23 @@ def run_packmol(self, PathRegistry): result = subprocess.run( "./" + cmd, shell=True, text=True, capture_output=True ) + if result.returncode != 0: + return ( + "Packmol failed to run. Please check the input file and try again." + ) PathRegistry.map_path( f"packed_structures_v{self.file_number}.pdb", f"packed_structures_v{self.file_number}.pdb", self.file_description, ) - print(result.stdout) # validate final pdb - pdb_validation = validate_pdb_format("packed_structures.pdb") + pdb_validation = validate_pdb_format(f"packed_structures{self.file_number}.pdb") if pdb_validation[0] == 0: # delete .inp files os.remove("packmol.inp") + for molecule in self.molecules: + os.remove(molecule.filename) return "PDB file validated successfully" elif pdb_validation[0] == 1: # format pdb_validation[1] list of errors @@ -394,10 +399,81 @@ class PackmolInput(BaseModel): ), ) - @root_validator + +class PackMolTool(BaseTool): + name: str = "packmol_tool" + description: str = ( + "Useful when you need to create a box " + "of different types of molecules.\n" + "Three different examples:\n" + "pdbfiles_id: ['1a2b_123456', 'water_000000']\n" + "number_of_molecules: [1, 1000]\n" + "instructions: [['inside box 0. 0. 0. 90. 90. 90.'], " + "['inside box 0. 0. 0. 90. 90. 90.']]\n" + "will pack 1 molecule of 1a2b_123456 and 1000 molecules of water_000000. \n" + "pdbfiles_id: ['1a2b_123456']\n" + "number_of_molecules: [1]\n" + "instructions: [['center\n fixed 0. 0. 0. 0. 0. 0.']]\n" + "This will fix the center of protein 1a2b_123456 at " + "the center of the box with no rotation.\n" + "pdbfiles_id: ['1a2b_123456']\n" + "number_of_molecules: [1]\n" + "instructions: [['outside sphere 2.30 3.40 4.50 8.0]]\n" + "This will place the protein 1a2b_123456 outside a sphere " + "centered at 2.30 3.40 4.50 with radius 8.0\n" + ) + + args_schema: Type[BaseModel] = PackmolInput + + path_registry: typing.Optional[PathRegistry] + + def __init__(self, path_registry: typing.Optional[PathRegistry]): + super().__init__() + self.path_registry = path_registry + + def _run(self, **values) -> str: + """use the tool.""" + + if self.path_registry is None: # this should not happen + raise ValidationError("Path registry not initialized") + try: + values = self.validate_input(values) + except ValidationError as e: + return str(e) + error_msg = values.get("error", None) + pdbfiles = values.get("pdbfiles_id", []) + pdbfiles = [self.path_registry.get_mapped_path(pdbfile) for pdbfile in pdbfiles] + pdbfile_names = [pdbfile.split("/")[-1] for pdbfile in pdbfiles] + # copy them to the current directory with temp_ names + for pdbfile, pdbfile_name in zip(pdbfiles, pdbfile_names): + os.system(f"cp {pdbfile} temp_{pdbfile_name}") + pdbfile_names = [f"temp_{pdbfile_name}" for pdbfile_name in pdbfile_names] + number_of_molecules = values.get("number_of_molecules", []) + instructions = values.get("instructions", []) + if error_msg: + return error_msg + # check if packmol is installed + cmd = "command -v packmol" + result = subprocess.run(cmd, shell=True, text=True, capture_output=True) + if result.returncode != 0: + result = subprocess.run( + "./" + cmd, shell=True, text=True, capture_output=True + ) + if result.returncode != 0: + return ( + "Packmol is not installed. Please install packmol " + "at 'https://m3g.github.io/packmol/download.shtml' and try again." + ) + + return packmol_wrapper( + self.path_registry, + pdbfiles=pdbfile_names, + number_of_molecules=number_of_molecules, + instructions=instructions, + ) + def validate_input(cls, values: Union[str, Dict[str, Any]]) -> Dict: # check if is only a string - print("values", values) if isinstance(values, str): print("values is a string", values) raise ValidationError("Input must be a dictionary") @@ -408,7 +484,7 @@ def validate_input(cls, values: Union[str, Dict[str, Any]]) -> Dict: if not (len(pdbfiles) == len(number_of_molecules) == len(instructions)): return { "error": ( - " The lengths of pdbfiles, number_of_molecules, " + "The lengths of pdbfiles, number_of_molecules, " "and instructions must be equal to use this tool." ) } @@ -441,6 +517,12 @@ def validate_input(cls, values: Union[str, Dict[str, Any]]) -> Dict: file_ids = registry.list_path_names() for pdbfile_id in pdbfiles: + if "_" not in pdbfile_id: + return { + "error": ( + f"{pdbfile_id} is not a valid pdbfile_id" "in the path_registry" + ) + } if pdbfile_id not in file_ids: # look for files in the current directory # that match some part of the pdbfile @@ -455,55 +537,6 @@ def validate_input(cls, values: Union[str, Dict[str, Any]]) -> Dict: } return values - -class PackMolTool(BaseTool): - name: str = "packmol_tool" - description: str = ( - "Useful when you need to create a box " - "of different types of molecules molecules" - ) - - args_schema: Type[BaseModel] = PackmolInput - - path_registry: typing.Optional[PathRegistry] - - def __init__(self, path_registry: typing.Optional[PathRegistry]): - super().__init__() - self.path_registry = path_registry - - def _run(self, **values) -> str: - """use the tool.""" - - if self.path_registry is None: # this should not happen - raise ValidationError("Path registry not initialized") - - error_msg = values.get("error", None) - pdbfiles = values.get("pdbfiles", []) - pdbfiles = [self.path_registry.get_mapped_path(pdbfile) for pdbfile in pdbfiles] - number_of_molecules = values.get("number_of_molecules", []) - instructions = values.get("instructions", []) - if error_msg: - return error_msg - # check if packmol is installed - cmd = "command -v packmol" - result = subprocess.run(cmd, shell=True, text=True, capture_output=True) - if result.returncode != 0: - result = subprocess.run( - "./" + cmd, shell=True, text=True, capture_output=True - ) - if result.returncode != 0: - return ( - "Packmol is not installed. Please install packmol " - "at 'https://m3g.github.io/packmol/download.shtml' and try again." - ) - - return packmol_wrapper( - self.path_registry, - pdbfiles=pdbfiles, - number_of_molecules=number_of_molecules, - instructions=instructions, - ) - async def _arun(self, values: str) -> str: """Use the tool asynchronously.""" raise NotImplementedError("custom_search does not support async") diff --git a/mdagent/utils/path_registry.py b/mdagent/utils/path_registry.py index ac2abbd0..76bc4158 100644 --- a/mdagent/utils/path_registry.py +++ b/mdagent/utils/path_registry.py @@ -40,9 +40,7 @@ def _init_path_registry(self): json.dump({}, json_file) for subdir in subdirectories: subdir_path = os.path.join(base_directory, subdir) - print("subdir_path: ", subdir_path) if os.path.exists(subdir_path): - print("this subdir exists: ", subdir_path) for file_name in os.listdir(subdir_path): if file_name not in file_names_in_registry: file_type = self._determine_file_type(subdir)