diff --git a/.env.example b/.env.example index fdee9af0..014b5807 100644 --- a/.env.example +++ b/.env.example @@ -5,7 +5,5 @@ # OpenAI API Key OPENAI_API_KEY=YOUR_OPENAI_API_KEY_GOES_HERE # pragma: allowlist secret -# PQA API Key to use LiteratureSearch tool (optional) -- it also requires OpenAI key -PQA_API_KEY=YOUR_PQA_API_KEY_GOES_HERE # pragma: allowlist secret # Optional: add TogetherAI, Fireworks, or Anthropic API key here to use their models diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bba975c9..bcd60379 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,28 +10,28 @@ repos: - id: mixed-line-ending - id: check-added-large-files - repo: https://github.com/astral-sh/ruff-pre-commit - rev: "v0.0.270" + rev: "v0.7.1" hooks: - id: ruff args: [ --fix, --exit-non-zero-on-fix ] - repo: https://github.com/psf/black - rev: "23.3.0" + rev: "24.10.0" hooks: - id: black language_version: python3 - repo: https://github.com/pre-commit/mirrors-mypy - rev: "v1.3.0" + rev: "v1.13.0" hooks: - id: mypy args: [--pretty, --ignore-missing-imports] additional_dependencies: [types-requests] - repo: https://github.com/PyCQA/isort - rev: "5.12.0" + rev: "5.13.2" hooks: - id: isort args: [--profile=black] - repo: https://github.com/Yelp/detect-secrets - rev: v1.0.3 + rev: v1.5.0 hooks: - id: detect-secrets args: [--exclude-files, ".github/workflows/"] diff --git a/mdagent/tools/base_tools/__init__.py b/mdagent/tools/base_tools/__init__.py index 23a1fd21..04c4fa35 100644 --- a/mdagent/tools/base_tools/__init__.py +++ b/mdagent/tools/base_tools/__init__.py @@ -4,11 +4,7 @@ from .analysis_tools.plot_tools import SimulationOutputFigures from .analysis_tools.ppi_tools import PPIDistance from .analysis_tools.rdf_tool import RDFTool -from .analysis_tools.rgy import ( - RadiusofGyrationAverage, - RadiusofGyrationPerFrame, - RadiusofGyrationPlot, -) +from .analysis_tools.rgy import RadiusofGyrationTool from .analysis_tools.rmsd_tools import ComputeLPRMSD, ComputeRMSD, ComputeRMSF from .analysis_tools.sasa import SolventAccessibleSurfaceArea from .analysis_tools.secondary_structure import ( @@ -80,9 +76,7 @@ "PCATool", "PPIDistance", "ProteinName2PDBTool", - "RadiusofGyrationAverage", - "RadiusofGyrationPerFrame", - "RadiusofGyrationPlot", + "RadiusofGyrationTool", "RDFTool", "RMSDCalculator", "Scholar2ResultLLM", diff --git a/mdagent/tools/base_tools/analysis_tools/__init__.py b/mdagent/tools/base_tools/analysis_tools/__init__.py index 81562527..26d58784 100644 --- a/mdagent/tools/base_tools/analysis_tools/__init__.py +++ b/mdagent/tools/base_tools/analysis_tools/__init__.py @@ -3,7 +3,7 @@ from .pca_tools import PCATool from .plot_tools import SimulationOutputFigures from .ppi_tools import PPIDistance -from .rgy import RadiusofGyrationAverage, RadiusofGyrationPerFrame, RadiusofGyrationPlot +from .rgy import RadiusofGyrationTool from .rmsd_tools import ComputeLPRMSD, ComputeRMSD, ComputeRMSF from .sasa import SolventAccessibleSurfaceArea from .vis_tools import VisFunctions, VisualizeProtein @@ -17,9 +17,7 @@ "MomentOfInertia", "PCATool", "PPIDistance", - "RadiusofGyrationAverage", - "RadiusofGyrationPerFrame", - "RadiusofGyrationPlot", + "RadiusofGyrationTool", "RMSDCalculator", "SimulationOutputFigures", "SolventAccessibleSurfaceArea", diff --git a/mdagent/tools/base_tools/analysis_tools/plot_tools.py b/mdagent/tools/base_tools/analysis_tools/plot_tools.py index 81899b9a..dfd5d693 100644 --- a/mdagent/tools/base_tools/analysis_tools/plot_tools.py +++ b/mdagent/tools/base_tools/analysis_tools/plot_tools.py @@ -124,7 +124,7 @@ def _run(self, file_id: str) -> str: plotting_tools._find_file(file_id) plotting_tools.process_csv() plot_result = plotting_tools.plot_data() - if type(plot_result) == str: + if isinstance(plot_result, str): return "Succeeded. IDs of figures created: " + plot_result else: return "Failed. No figures created." diff --git a/mdagent/tools/base_tools/analysis_tools/rdf_tool.py b/mdagent/tools/base_tools/analysis_tools/rdf_tool.py index 2e6fd5d4..e6fa24cb 100644 --- a/mdagent/tools/base_tools/analysis_tools/rdf_tool.py +++ b/mdagent/tools/base_tools/analysis_tools/rdf_tool.py @@ -159,7 +159,7 @@ def validate_input(self, input): ) if stride: - if type(stride) != int: + if not isinstance(stride, int): try: stride = int(stride) if stride <= 0: diff --git a/mdagent/tools/base_tools/analysis_tools/rgy.py b/mdagent/tools/base_tools/analysis_tools/rgy.py index 61ad7698..5976d9f5 100644 --- a/mdagent/tools/base_tools/analysis_tools/rgy.py +++ b/mdagent/tools/base_tools/analysis_tools/rgy.py @@ -14,6 +14,7 @@ def __init__(self, path_registry): self.top_file = "" self.traj_file = "" self.traj = None + self.rgy_file = "" def _load_traj(self, top_file: str, traj_file: str): self.traj_file = traj_file @@ -25,38 +26,36 @@ def _load_traj(self, top_file: str, traj_file: str): traj_required=True, ) - def rgy_per_frame(self, force_recompute: bool = False) -> str: + def rgy_per_frame(self) -> str: rg_per_frame = md.compute_rg(self.traj) self.rgy_file = ( f"{self.path_registry.ckpt_figures}/radii_of_gyration_{self.traj_file}.csv" ) rgy_id = f"rgy_{self.traj_file}" - if rgy_id in self.path_registry.list_path_names() and force_recompute is False: - print("RGY already computed, skipping re-compute") - # todo -> maybe allow re-compute & save under different id/path - else: - np.savetxt( - self.rgy_file, - rg_per_frame, - delimiter=",", - header="Radius of Gyration (nm)", - ) - self.path_registry.map_path( - f"rgy_{self.traj_file}", - self.rgy_file, - description=f"Radii of gyration per frame for {self.traj_file}", - ) + np.savetxt( + self.rgy_file, + rg_per_frame, + delimiter=",", + header="Radius of Gyration (nm)", + ) + self.path_registry.map_path( + f"rgy_{self.traj_file}", + self.rgy_file, + description=f"Radii of gyration per frame for {self.traj_file}", + ) return f"Radii of gyration saved to {self.rgy_file} with id {rgy_id}." def rgy_average(self) -> str: - _ = self.rgy_per_frame() + if not self.rgy_file: + _ = self.rgy_per_frame() rg_per_frame = np.loadtxt(self.rgy_file, delimiter=",", skiprows=1) avg_rg = rg_per_frame.mean() return f"Average radius of gyration: {avg_rg:.2f} nm" def plot_rgy(self) -> str: - _ = self.rgy_per_frame() + if not self.rgy_file: + _ = self.rgy_per_frame() rg_per_frame = np.loadtxt(self.rgy_file, delimiter=",", skiprows=1) fig_analysis = f"rgy_{self.traj_file}" plot_name = self.path_registry.write_file_name( @@ -66,9 +65,8 @@ def plot_rgy(self) -> str: plot_id = self.path_registry.get_fileid( file_name=plot_name, type=FileType.FIGURE ) - if plot_name.endswith(".png"): - plot_name = plot_name.split(".png")[0] plot_path = f"{self.path_registry.ckpt_figures}/{plot_name}" + plot_path = plot_path if plot_path.endswith(".png") else plot_path + ".png" print("plot_path", plot_path) plt.plot(rg_per_frame) plt.xlabel("Frame") @@ -78,106 +76,51 @@ def plot_rgy(self) -> str: plt.savefig(f"{plot_path}") self.path_registry.map_path( plot_id, - plot_path + ".png", + plot_path, description=f"Plot of radii of gyration over time for {self.traj_file}", ) plt.close() plt.clf() - return "Plot saved as: " + f"{plot_name}.png with plot ID {plot_id}" - - -class RadiusofGyrationAverage(BaseTool): - name = "RadiusofGyrationAverage" - description = """This tool calculates the average radius of gyration - for a trajectory. Give this tool BOTH the trajectory file ID and the - topology file ID.""" + return "Plot saved as: " + f"{plot_name} with plot ID {plot_id}" - path_registry: Optional[PathRegistry] + def compute_plot_return_avg(self) -> str: + rgy_per_frame = self.rgy_per_frame() + avg_rgy = self.rgy_average() + plot_rgy = self.plot_rgy() + return rgy_per_frame + plot_rgy + avg_rgy - def __init__(self, path_registry): - super().__init__() - self.path_registry = path_registry - def _run(self, traj_file: str, top_file: str) -> str: - """use the tool.""" - RGY = RadiusofGyration(self.path_registry) - try: - RGY._load_traj(top_file=top_file, traj_file=traj_file) - except Exception as e: - return f"Error loading traj: {e}" - try: - return "Succeeded. " + RGY.rgy_average() - except ValueError as e: - return f"Failed. ValueError: {e}" - except Exception as e: - return f"Failed. {type(e).__name__}: {e}" - - async def _arun(self, query: str) -> str: - """Use the tool asynchronously.""" - raise NotImplementedError("custom_search does not support async") - - -class RadiusofGyrationPerFrame(BaseTool): - name = "RadiusofGyrationPerFrame" - description = """This tool calculates the radius of gyration - at each frame of a given trajectory. +class RadiusofGyrationTool(BaseTool): + name = "RadiusofGyrationTool" + description = """This tool calculates and plots + the radius of gyration + at each frame of a given trajectory and retuns the average. Give this tool BOTH the trajectory file ID and the - topology file ID. - The tool will save the radii of gyration to a csv file and - map it to the registry.""" + topology file ID.""" path_registry: Optional[PathRegistry] + rgy: Optional[RadiusofGyration] + load_traj: bool = True - def __init__(self, path_registry): + def __init__(self, path_registry, load_traj=True): super().__init__() self.path_registry = path_registry + self.rgy = RadiusofGyration(path_registry) + self.load_traj = load_traj # only for testing def _run(self, traj_file: str, top_file: str) -> str: """use the tool.""" - RGY = RadiusofGyration(self.path_registry) - try: - RGY._load_traj(top_file=top_file, traj_file=traj_file) - except Exception as e: - return f"Error loading traj: {e}" - try: - return "Succeeded. " + RGY.rgy_per_frame() - except ValueError as e: - return f"Failed. ValueError: {e}" - except Exception as e: - return f"Failed. {type(e).__name__}: {e}" - - async def _arun(self, query: str) -> str: - """Use the tool asynchronously.""" - raise NotImplementedError("custom_search does not support async") - - -class RadiusofGyrationPlot(BaseTool): - name = "RadiusofGyrationPlot" - description = """This tool calculates the radius of gyration - at each frame of a given trajectory file and plots it. - Give this tool BOTH the trajectory file ID and the - topology file ID. - The tool will save the plot to a png file and map it to the registry.""" - - path_registry: Optional[PathRegistry] + assert self.rgy is not None, "RadiusofGyration instance is not initialized" - def __init__(self, path_registry): - super().__init__() - self.path_registry = path_registry - - def _run(self, traj_file: str, top_file: str) -> str: - """use the tool.""" - RGY = RadiusofGyration(self.path_registry) - try: - RGY._load_traj(top_file=top_file, traj_file=traj_file) - except Exception as e: - return f"Error loading traj: {e}" + if self.load_traj: + try: + self.rgy._load_traj(top_file=top_file, traj_file=traj_file) + except Exception as e: + return f"Error loading traj: {e}" try: - return "Succeeded. " + RGY.plot_rgy() - except ValueError as e: - return f"Failed. ValueError: {e}" + return "Succeeded. " + self.rgy.compute_plot_return_avg() except Exception as e: - return f"Failed. {type(e).__name__}: {e}" + return f"Failed Computing RGY: {e}" async def _arun(self, query: str) -> str: """Use the tool asynchronously.""" diff --git a/mdagent/tools/base_tools/preprocess_tools/uniprot.py b/mdagent/tools/base_tools/preprocess_tools/uniprot.py index 28dfae69..309fd006 100644 --- a/mdagent/tools/base_tools/preprocess_tools/uniprot.py +++ b/mdagent/tools/base_tools/preprocess_tools/uniprot.py @@ -475,7 +475,7 @@ def get_sequence_info(self, query: str, primary_accession: str) -> dict: - 'crc64': The CRC64 hash of the protein sequence (probably not useful) - 'md5': The MD5 hash of the protein sequence (probably not useful) """ - seq_info = self.data = self.get_data(query, desired_field="sequence") + seq_info = self.get_data(query, desired_field="sequence") if not seq_info: return {} seq_info_specific = self._match_primary_accession(seq_info, primary_accession)[ @@ -693,9 +693,11 @@ def get_ids( if include_uniprotkbids: all_ids + [entry["uniProtkbId"] for entry in ids_] if ids_ else [] accession = self.get_data(query, desired_field="accession") - all_ids + [ - entry["primaryAccession"] for entry in accession - ] if accession else [] + ( + all_ids + [entry["primaryAccession"] for entry in accession] + if accession + else [] + ) if single_id: return [all_ids[0]] if all_ids else [] return list(set(all_ids)) diff --git a/mdagent/tools/base_tools/simulation_tools/setup_and_run.py b/mdagent/tools/base_tools/simulation_tools/setup_and_run.py index d6dfe023..ca1b3c30 100644 --- a/mdagent/tools/base_tools/simulation_tools/setup_and_run.py +++ b/mdagent/tools/base_tools/simulation_tools/setup_and_run.py @@ -265,7 +265,7 @@ def setup_system(self): raise ValueError(str(e)) else: raise ValueError( - f"Error building system. Please check the forcefield files {str(e)}" + f"Error building system. Please check the forcefield files {str(e)}. Included force fields are: {FORCEFIELD_LIST}" ) if self.sys_params.get("nonbondedMethod", None) in [ @@ -273,13 +273,21 @@ def setup_system(self): PME, ]: if self.sim_params["Ensemble"] == "NPT": - self.system.addForce( - MonteCarloBarostat( - self.int_params["Pressure"], - self.int_params["Temperature"], - self.sim_params.get("barostatInterval", 25), - ) + pressure = self.int_params.get("Pressure", 1.0) + + if "Pressure" not in self.int_params: + print( + "Warning: 'Pressure' not provided. ", + "Using default pressure of 1.0 atm.", + ) + + self.system.addForce( + MonteCarloBarostat( + pressure, + self.int_params["Temperature"], + self.sim_params.get("barostatInterval", 25), ) + ) def setup_integrator(self): print("Setting up integrator...") @@ -1219,7 +1227,7 @@ def _process_parameters(self, user_params, param_type="system_params"): ) if key == "constraints": try: - if type(value) == str: + if isinstance(value, str): if value == "None": processed_params[key] = None elif value == "HBonds": @@ -1243,7 +1251,7 @@ def _process_parameters(self, user_params, param_type="system_params"): "part of the parameters.\n" ) if key == "rigidWater" or key == "rigidwater": - if type(value) == bool: + if isinstance(value, bool): processed_params[key] = value elif value == "True": processed_params[key] = True @@ -1268,7 +1276,7 @@ def _process_parameters(self, user_params, param_type="system_params"): ) if key == "solvate": try: - if type(value) == bool: + if isinstance(value, bool): processed_params[key] = value elif value == "True": processed_params[key] = True @@ -1480,7 +1488,7 @@ def check_system_params(cls, values): # forcefield forcefield_files = values.get("forcefield_files") - if forcefield_files is None or forcefield_files is []: + if forcefield_files is None or forcefield_files == []: print("Setting default forcefields") forcefield_files = ["amber14-all.xml", "amber14/tip3pfb.xml"] elif len(forcefield_files) == 0: @@ -1489,10 +1497,12 @@ def check_system_params(cls, values): else: for file in forcefield_files: if file not in FORCEFIELD_LIST: - error_msg += "The forcefield file is not present" - + error_msg += ( + "The forcefield file is not present: forcefield files are: " + + str(FORCEFIELD_LIST) + ) save = values.get("save", True) - if type(save) != bool: + if not isinstance(save, bool): error_msg += "save must be a boolean value" if error_msg != "": @@ -1550,7 +1560,10 @@ def create_simulation_input(pdb_path, forcefield_files): Water_model = Forcefield_files[1] # check if they are part of the list if Forcefield not in FORCEFIELD_LIST: - raise Exception("Forcefield not recognized") + raise Exception( + "Forcefield not recognized: Possible forcefields are: " + + str(FORCEFIELD_LIST) + ) if Water_model not in FORCEFIELD_LIST: raise Exception("Water model not recognized") diff --git a/mdagent/tools/maketools.py b/mdagent/tools/maketools.py index a7f514a1..f831708e 100644 --- a/mdagent/tools/maketools.py +++ b/mdagent/tools/maketools.py @@ -46,9 +46,7 @@ PCATool, PPIDistance, ProteinName2PDBTool, - RadiusofGyrationAverage, - RadiusofGyrationPerFrame, - RadiusofGyrationPlot, + RadiusofGyrationTool, RDFTool, Scholar2ResultLLM, SetUpandRunFunction, @@ -74,7 +72,7 @@ def make_all_tools( all_tools += [ ModifyBaseSimulationScriptTool(path_registry=path_instance, llm=llm), ] - if "OPENAI_API_KEY" in os.environ and "PQA_API_KEY" in os.environ: + if path_instance.ckpt_papers: all_tools += [Scholar2ResultLLM(llm=llm, path_registry=path_instance)] if human: all_tools += [agents.load_tools(["human"], llm)[0]] @@ -99,9 +97,7 @@ def make_all_tools( PCATool(path_registry=path_instance), PPIDistance(path_registry=path_instance), ProteinName2PDBTool(path_registry=path_instance), - RadiusofGyrationAverage(path_registry=path_instance), - RadiusofGyrationPerFrame(path_registry=path_instance), - RadiusofGyrationPlot(path_registry=path_instance), + RadiusofGyrationTool(path_registry=path_instance), RDFTool(path_registry=path_instance), SetUpandRunFunction(path_registry=path_instance), SimulationOutputFigures(path_registry=path_instance), diff --git a/tests/test_analysis/test_rgy_tool.py b/tests/test_analysis/test_rgy_tool.py index 4e5033c2..c131d5be 100644 --- a/tests/test_analysis/test_rgy_tool.py +++ b/tests/test_analysis/test_rgy_tool.py @@ -1,6 +1,9 @@ import pytest -from mdagent.tools.base_tools.analysis_tools.rgy import RadiusofGyration +from mdagent.tools.base_tools.analysis_tools.rgy import ( + RadiusofGyration, + RadiusofGyrationTool, +) @pytest.fixture @@ -13,6 +16,24 @@ def rgy(get_registry, loaded_cif_traj): return rgy +@pytest.fixture +def rgy_tool(get_registry, loaded_cif_traj): + registry = get_registry("raw", False) + rgy_tool = RadiusofGyrationTool(path_registry=registry, load_traj=False) + rgy_tool.rgy.traj = loaded_cif_traj + rgy_tool.rgy.top_file = "test_top_dummy" + rgy_tool.rgy.traj_file = "test_traj_dummy" + return rgy_tool + + +def test_rgy_tool(rgy_tool): + output = rgy_tool._run(traj_file="test_top_dummy", top_file="test_traj_dummy") + assert "Radii of gyration saved to " in output + assert "Average radius of gyration: " in output + assert "Plot saved as: " in output + assert ".png" in output + + def test_rgy_per_frame(rgy): output = rgy.rgy_per_frame() assert "Radii of gyration saved to " in output @@ -26,3 +47,12 @@ def test_rgy_average(rgy): def test_plot_rgy(rgy): output = rgy.plot_rgy() assert "Plot saved as: " in output + assert ".png" in output + + +def test_compute_plot_return_avg_rgy(rgy): + output = rgy.compute_plot_return_avg() + assert "Radii of gyration saved to " in output + assert "Average radius of gyration: " in output + assert "Plot saved as: " in output + assert ".png" in output