Skip to content

Commit

Permalink
updated most files, still need tests and setup
Browse files Browse the repository at this point in the history
  • Loading branch information
SamCox822 committed Feb 22, 2024
1 parent c05de8b commit 68a1654
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 76 deletions.
113 changes: 64 additions & 49 deletions mdagent/tools/base_tools/preprocess_tools/pdb_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,16 @@
from pydantic import BaseModel, Field, ValidationError, root_validator
from rdkit import Chem

from mdagent.utils import FileType, PathRegistry
from mdagent.utils import FileType, PathRegistry, move_files_to_ckpt_path

from .elements import list_of_elements


def get_pdb(query_string, path_registry=None):
def get_pdb(query_string, path_registry):
"""
Search RCSB's protein data bank using the given query string
and return the path to pdb file in either CIF or PDB format
"""
if path_registry is None:
path_registry = PathRegistry.get_instance()
url = "https://search.rcsb.org/rcsbsearch/v2/query?json={search-request}"
query = {
"query": {
Expand Down Expand Up @@ -53,15 +51,15 @@ def get_pdb(query_string, path_registry=None):
file_format=filetype,
)
file_id = path_registry.get_fileid(filename, FileType.PROTEIN)
directory = "files/pdb"
directory = path_registry.get_current_ckpt() + "files/pdb"
# Create the directory if it does not exist
if not os.path.exists(directory):
os.makedirs(directory)

with open(f"{directory}/{filename}", "w") as file:
file_path = f"{directory}/{filename}"
with open(file_path, "w") as file:
file.write(pdb.text)

return filename, file_id
return file_path, filename, file_id
return None


Expand All @@ -88,13 +86,13 @@ def _run(self, query: str) -> str:
try:
if self.path_registry is None: # this should not happen
return "Path registry not initialized"
filename, pdbfile_id = get_pdb(query, self.path_registry)
file_path, filename, pdbfile_id = get_pdb(query, self.path_registry)
if pdbfile_id is None:
return "Name2PDB tool failed to find and download PDB file."
else:
self.path_registry.map_path(
pdbfile_id,
f"files/pdb/{filename}",
file_path,
f"PDB file downloaded from RSCB, PDBFile ID: {pdbfile_id}",
)
return f"Name2PDB tool successful. downloaded the PDB file:{pdbfile_id}"
Expand Down Expand Up @@ -282,12 +280,13 @@ def get_number_of_atoms(self):

class PackmolBox:
def __init__(
self, file_number=1, file_description="PDB file for simulation with: \n"
self, ckpt, file_number=1, file_description="PDB file for simulation with: \n"
):
self.molecules = []
self.file_number = 1
self.file_description = file_description
self.final_name = None
self.ckpt = ckpt

def add_molecule(self, molecule):
self.molecules.append(molecule)
Expand All @@ -313,7 +312,9 @@ def generate_input_header(self):
]
)
)
while os.path.exists(f"files/pdb/{_final_name}_v{self.file_number}.pdb"):
while os.path.exists(
f"{self.ckpt}/files/pdb/{_final_name}_v{self.file_number}.pdb"
):
self.file_number += 1

self.final_name = f"{_final_name}_v{self.file_number}.pdb"
Expand All @@ -338,7 +339,7 @@ def generate_input(self):
# Convert list of input data to a single string
return "\n".join(input_data)

def run_packmol(self, PathRegistry):
def run_packmol(self, path_registry):
# Use the generated input to execute Packmol
input_string = self.generate_input()
# Write the input to a file
Expand Down Expand Up @@ -366,11 +367,11 @@ def run_packmol(self, PathRegistry):
for molecule in self.molecules:
os.remove(molecule.filename)
# name of packed pdb file
time_stamp = PathRegistry.get_timestamp()[-6:]
os.rename(self.final_name, f"files/pdb/{self.final_name}")
PathRegistry.map_path(
time_stamp = path_registry.get_timestamp()[-6:]
os.rename(self.final_name, f"{self.ckpt}/files/pdb/{self.final_name}")
path_registry.map_path(
f"PACKED_{time_stamp}",
f"files/pdb/{self.final_name}",
f"{self.ckpt}/files/pdb/{self.final_name}",
self.file_description,
)
# move file to files/pdb
Expand All @@ -391,7 +392,7 @@ def run_packmol(self, PathRegistry):


def packmol_wrapper(
PathRegistry,
path_registry,
pdbfiles: List,
files_id: List,
number_of_molecules: List,
Expand All @@ -401,7 +402,7 @@ def packmol_wrapper(
of different types of molecules"""

# create a box
box = PackmolBox()
box = PackmolBox(ckpt=path_registry.get_current_ckpt())
# add molecules to the box
for (
pdbfile,
Expand All @@ -416,7 +417,7 @@ def packmol_wrapper(
# generate input
# run packmol
print("Packing:", box.file_description, "\nThe file name is:", box.final_name)
return box.run_packmol(PathRegistry)
return box.run_packmol(path_registry)


"""Args schema for packmol_wrapper tool. Useful for OpenAI functions"""
Expand Down Expand Up @@ -502,7 +503,7 @@ def _run(self, **values) -> str:
if self.path_registry is None: # this should not happen
raise ValidationError("Path registry not initialized")
try:
values = self.validate_input(values)
values = self.validate_input(values, self.path_registry)
except ValidationError as e:
return str(e)
error_msg = values.get("error", None)
Expand All @@ -517,7 +518,10 @@ def _run(self, **values) -> str:
pdbfile_names = [pdbfile.split("/")[-1] for pdbfile in pdbfiles]
# copy them to the current directory with temp_ names

pdbfile_names = [f"temp_{pdbfile_name}" for pdbfile_name in pdbfile_names]
pdbfile_names = [
f"{self.path_registry.get_current_ckpt()}/temp_{pdbfile_name}"
for pdbfile_name in pdbfile_names
]
number_of_molecules = values.get("number_of_molecules", [])
instructions = values.get("instructions", [])
small_molecules = values.get("small_molecules", [])
Expand All @@ -530,7 +534,7 @@ def _run(self, **values) -> str:
small_molecule.split("/")[-1] for small_molecule in small_molecules_files
]
small_molecules_file_names = [
f"temp_{small_molecule_file_name}"
f"{self.path_registry.get_current_ckpt()}/temp_{small_molecule_file_name}"
for small_molecule_file_name in small_molecules_file_names
]
# append small molecules to pdbfiles
Expand Down Expand Up @@ -563,7 +567,7 @@ def _run(self, **values) -> str:
instructions=instructions,
)

def validate_input(cls, values: Union[str, Dict[str, Any]]) -> Dict:
def validate_input(cls, values: Union[str, Dict[str, Any]], path_registry) -> Dict:
# check if is only a string
if isinstance(values, str):
print("values is a string", values)
Expand Down Expand Up @@ -640,8 +644,7 @@ def validate_input(cls, values: Union[str, Dict[str, Any]]) -> Dict:
}

# Further validation, e.g., checking if files exist
registry = PathRegistry.get_instance()
file_ids = registry.list_path_names()
file_ids = path_registry.list_path_names()

for pdbfile_id in pdbfiles:
if "_" not in pdbfile_id:
Expand All @@ -653,7 +656,7 @@ def validate_input(cls, values: Union[str, Dict[str, Any]]) -> Dict:
if pdbfile_id not in file_ids:
# look for files in the current directory
# that match some part of the pdbfile
ids_w_description = registry.list_path_names_and_descriptions()
ids_w_description = path_registry.list_path_names_and_descriptions()

return {
"error": (
Expand All @@ -664,7 +667,7 @@ def validate_input(cls, values: Union[str, Dict[str, Any]]) -> Dict:
}
for small_molecule in small_molecules:
if small_molecule not in file_ids:
result = molPDB.small_molecule_pdb(small_molecule, registry)
result = molPDB.small_molecule_pdb(small_molecule, path_registry)
if "successfully" not in result:
return {
"error": (
Expand Down Expand Up @@ -887,9 +890,9 @@ def fix_element_column(pdb_file, custom_element_dict=None):

# extract Title, Header, Remarks, and Cryst1 records
file_name = pdb_file.split(".")[0]
# check if theres a file-name-fixed.pdb file
if os.path.isfile(file_name + "-fixed.pdb"):
pdb_file = file_name + "-fixed.pdb"
# check if there's a file-name_fixed.pdb file
if os.path.isfile(file_name + "_fixed.pdb"):
pdb_file = file_name + "_fixed.pdb"
assert isinstance(pdb_file, str), "pdb_file must be a string"
with open(pdb_file, "r") as f:
print("I read the initial file")
Expand Down Expand Up @@ -924,7 +927,7 @@ def fix_element_column(pdb_file, custom_element_dict=None):
# join the lines
new_pdb = "".join(new_pdb)
# write new pdb file as pdb_file_fixed.pdb
new_pdb_file = file_name.split(".")[0] + "-fixed.pdb"
new_pdb_file = file_name.split(".")[0] + "_fixed.pdb"
print("name of fixed pdb file", new_pdb_file)
# write the unchanged records first and then the new pdb file
assert isinstance(new_pdb_file, str), "new_pdb_file must be a string"
Expand Down Expand Up @@ -1018,8 +1021,8 @@ def fix_temp_factor_column(pdb_file, bfactor=1.00, only_fill=True):
return "pdb_file must be a string"
file_name = pdb_file.split(".")[0]

if os.path.isfile(file_name + "-fixed.pdb"):
file_name = file_name + "-fixed.pdb"
if os.path.isfile(file_name + "_fixed.pdb"):
file_name = file_name + "_fixed.pdb"

assert isinstance(file_name, str), "pdb_file must be a string"
with open(file_name, "r") as f:
Expand Down Expand Up @@ -1052,12 +1055,12 @@ def fix_temp_factor_column(pdb_file, bfactor=1.00, only_fill=True):
new_pdb = _fix_temp_factor_column(pdb_file_lines, bfactor, only_fill)
# join the lines
new_pdb = "".join(new_pdb)
# write new pdb file as pdb_file-fixed.pdb
new_pdb_file = file_name + "-fixed.pdb"
# write new pdb file as pdb_file_fixed.pdb
new_pdb_file = file_name + "_fixed.pdb"
# organize columns HEADER, TITLE, REMARKS, CRYST1, ATOM, HETATM, CONECT, MASTER, END

assert isinstance(new_pdb_file, str), "new_pdb_file must be a string"
# write new pdb file as pdb_file-fixed.pdb
# write new pdb file as pdb_file_fixed.pdb
with open(new_pdb_file, "w") as f:
f.writelines(_unchanged_records)
f.write(new_pdb)
Expand Down Expand Up @@ -1137,8 +1140,8 @@ def fix_occupancy_columns(pdb_file, occupancy=1.0, only_fill=True):
# extract Title, Header, Remarks, and Cryst1 records
# get name from pdb_file
file_name = pdb_file.split(".")[0]
if os.path.isfile(file_name + "-fixed.pdb"):
file_name = file_name + "-fixed.pdb"
if os.path.isfile(file_name + "_fixed.pdb"):
file_name = file_name + "_fixed.pdb"

assert isinstance(pdb_file, str), "pdb_file must be a string"
with open(file_name, "r") as f:
Expand Down Expand Up @@ -1169,10 +1172,10 @@ def fix_occupancy_columns(pdb_file, occupancy=1.0, only_fill=True):
new_pdb = _fix_occupancy_column(pdb_file_lines, occupancy, only_fill)
# join the lines
new_pdb = "".join(new_pdb)
# write new pdb file as pdb_file-fixed.pdb
new_pdb_file = file_name + "-fixed.pdb"
# write new pdb file as pdb_file_fixed.pdb
new_pdb_file = file_name + "_fixed.pdb"

# write new pdb file as pdb_file-fixed.pdb
# write new pdb file as pdb_file_fixed.pdb
assert isinstance(new_pdb_file, str), "new_pdb_file must be a string"
with open(new_pdb_file, "w") as f:
f.writelines(_unchanged_records)
Expand Down Expand Up @@ -1264,7 +1267,7 @@ class PDBFilesFixInp(BaseModel):
)

@root_validator
def validate_input(cls, values: Union[str, Dict[str, Any]]) -> Dict:
def validate_input(cls, values: Union[str, Dict[str, Any]], path_registry) -> Dict:
if isinstance(values, str):
print("values is a string", values)
raise ValidationError("Input must be a dictionary")
Expand Down Expand Up @@ -1346,19 +1349,31 @@ def _run(self, query: Dict):
if "Occupancy" in error_set:
fix_occupancy_columns(pdbfile)

validate = validate_pdb_format(pdbfile + "-fixed.pdb")
validate = validate_pdb_format(pdbfile + "_fixed.pdb")
if validate[0] == 0:
name = pdbfile + "-fixed.pdb"
name = pdbfile + "_fixed.pdb"
if "ckpt" not in name:
ckpt = self.path_registry.get_current_ckpt()
# move file
path = move_files_to_ckpt_path(name, ckpt)
else:
path = name
description = "PDB file fixed"
self.path_registry.map_path(name, name, description)
self.path_registry.map_path(name, path, description)
return "PDB file fixed"
else:
return "PDB not fully fixed"
else:
apply_fixes(pdbfile, query)
validate = validate_pdb_format(pdbfile + "-fixed.pdb")
validate = validate_pdb_format(pdbfile + "_fixed.pdb")
if validate[0] == 0:
name = pdbfile + "-fixed.pdb"
name = pdbfile + "_fixed.pdb"
if "ckpt" not in name:
ckpt = self.path_registry.get_current_ckpt()
# move file
path = move_files_to_ckpt_path(name, ckpt)
else:
path = name
description = "PDB file fixed"
self.path_registry.map_path(name, name, description)
return "PDB file fixed"
Expand Down Expand Up @@ -1446,7 +1461,7 @@ def small_molecule_pdb(self, mol_str: str, path_registry) -> str:
except Exception: # TODO: we should be more specific here
pass
Chem.AllChem.EmbedMolecule(m)
file_name = f"files/pdb/{mol_name}.pdb"
file_name = f"{path_registry.get_current_ckpt()}/files/pdb/{mol_name}.pdb"
Chem.MolToPDBFile(m, file_name)
# add to path registry
if path_registry:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def _run(self, *args, **input):
type=FileType.SIMULATION, Sim_id=base_script_id, modified=True
)
file_id = self.path_registry.get_fileid(filename, type=FileType.SIMULATION)
directory = "files/simulations"
directory = f"{self.path_registry.get_current_ckpt()}/files/simulations"
if not os.path.exists(directory):
os.makedirs(directory)
with open(f"{directory}/{filename}", "w") as file:
Expand Down
Loading

0 comments on commit 68a1654

Please sign in to comment.