From f72b46d18364647efe49c3add55e99a0501bcef3 Mon Sep 17 00:00:00 2001 From: qcampbel Date: Tue, 1 Oct 2024 11:25:13 -0400 Subject: [PATCH 1/6] update literature search --- mdagent/agent/agent.py | 3 +- .../base_tools/util_tools/search_tools.py | 102 ++++++------------ mdagent/utils/path_registry.py | 15 ++- setup.py | 2 +- 4 files changed, 43 insertions(+), 79 deletions(-) diff --git a/mdagent/agent/agent.py b/mdagent/agent/agent.py index 2adc7330..039f2e1d 100644 --- a/mdagent/agent/agent.py +++ b/mdagent/agent/agent.py @@ -46,6 +46,7 @@ def __init__( uploaded_files=[], # user input files to add to path registry run_id="", use_memory=False, + paper_dir="ckpt/paper_collection", # papers for pqa, relative path within repo ): self.llm = _make_llm(model, temp, streaming) if tools_model is None: @@ -53,7 +54,7 @@ def __init__( self.tools_llm = _make_llm(tools_model, temp, streaming) self.use_memory = use_memory - self.path_registry = PathRegistry.get_instance(ckpt_dir=ckpt_dir) + self.path_registry = PathRegistry.get_instance(ckpt_dir, paper_dir) self.ckpt_dir = self.path_registry.ckpt_dir self.memory = MemoryManager(self.path_registry, self.tools_llm, run_id=run_id) self.run_id = self.memory.run_id diff --git a/mdagent/tools/base_tools/util_tools/search_tools.py b/mdagent/tools/base_tools/util_tools/search_tools.py index 1015d69d..c839c257 100644 --- a/mdagent/tools/base_tools/util_tools/search_tools.py +++ b/mdagent/tools/base_tools/util_tools/search_tools.py @@ -1,90 +1,44 @@ -import logging -import os -import re from typing import Optional -import langchain import nest_asyncio import paperqa -import paperscraper from langchain.base_language import BaseLanguageModel from langchain.tools import BaseTool -from langchain_core.output_parsers import StrOutputParser -from pypdf.errors import PdfReadError from mdagent.utils import PathRegistry -def configure_logging(path): - # to log all runtime errors from paperscraper, which can be VERY noisy - log_file = os.path.join(path, "scraping_errors.log") - logging.basicConfig( - filename=log_file, - level=logging.ERROR, - format="%(asctime)s:%(levelname)s:%(message)s", - ) - - -def paper_scraper(search: str, pdir: str = "query") -> dict: - try: - return paperscraper.search_papers(search, pdir=pdir) - except KeyError: - return {} - - -def paper_search(llm, query, path_registry): - prompt = langchain.prompts.PromptTemplate( - input_variables=["question"], - template=""" - I would like to find scholarly papers to answer - this question: {question}. Your response must be at - most 10 words long. - 'A search query that would bring up papers that can answer - this question would be: '""", - ) - - path = f"{path_registry.ckpt_files}/query" - query_chain = prompt | llm | StrOutputParser() - if not os.path.isdir(path): - os.mkdir(path) - configure_logging(path) - search = query_chain.invoke(query) - print("\nSearch:", search) - papers = paper_scraper(search, pdir=f"{path}/{re.sub(' ', '', search)}") - return papers - - -def scholar2result_llm(llm, query, path_registry, k=5, max_sources=2): - """Useful to answer questions that require - technical knowledge. Ask a specific question.""" - if llm.model_name.startswith("gpt"): - docs = paperqa.Docs(llm=llm.model_name) +def scholar2result_llm(llm, query, path_registry): + paper_directory = path_registry.ckpt_papers + if paper_directory is None: + raise ValueError("The 'paper_dir' is None and wasn't set from the start.") + print("Paper Directory", paper_directory) + llm_name = llm.model_name + if llm_name.startswith("gpt") or llm_name.startswith("claude"): + settings = paperqa.Settings( + llm=llm_name, + summary_llm=llm_name, + temperature=llm.temperature, + paper_directory=paper_directory, + ) else: - docs = paperqa.Docs() # uses default gpt model in paperqa - - papers = paper_search(llm, query, path_registry) - if len(papers) == 0: - return "Failed. Not enough papers found" - not_loaded = 0 - for path, data in papers.items(): - try: - docs.add(path, data["citation"]) - except (ValueError, FileNotFoundError, PdfReadError): - not_loaded += 1 - - print( - f"\nFound {len(papers)} papers" - + (f" but couldn't load {not_loaded}" if not_loaded > 0 else "") - ) - answer = docs.query(query, k=k, max_sources=max_sources).formatted_answer - return "Succeeded. " + answer + settings = paperqa.Settings( + temperature=llm.temperature, # uses default gpt model in paperqa + paper_directory=paper_directory, + ) + response = paperqa.ask(query, settings=settings) + answer = response.answer.formatted_answer + if "I cannot answer." in answer: + answer += f" Check to ensure there's papers in {paper_directory}" + print(answer) + return answer class Scholar2ResultLLM(BaseTool): name = "LiteratureSearch" description = ( - "Useful to answer questions that require technical " - "knowledge. Ask a specific question." + "Useful to answer questions that may be found in literature. " + "Ask a specific question as the input." ) llm: BaseLanguageModel = None path_registry: Optional[PathRegistry] @@ -96,7 +50,11 @@ def __init__(self, llm, path_registry): def _run(self, query) -> str: nest_asyncio.apply() - return scholar2result_llm(self.llm, query, self.path_registry) + try: + return scholar2result_llm(self.llm, query, self.path_registry) + except Exception as e: + print(e) + return f"Failed. {type(e).__name__}: {e}" async def _arun(self, query) -> str: """Use the tool asynchronously.""" diff --git a/mdagent/utils/path_registry.py b/mdagent/utils/path_registry.py index 9ee4e109..37f942fb 100644 --- a/mdagent/utils/path_registry.py +++ b/mdagent/utils/path_registry.py @@ -22,19 +22,24 @@ class PathRegistry: @classmethod # set ckpt_dir to None by default - def get_instance(cls, ckpt_dir=None): + def get_instance(cls, ckpt_dir=None, paper_dir=None): # todo: use same ckpt if run_id is given if not cls.instance or ckpt_dir is not None: - cls.instance = cls(ckpt_dir) + cls.instance = cls(ckpt_dir, paper_dir) return cls.instance - def __init__(self, ckpt_dir: str = "ckpt"): - self._set_ckpt(ckpt_dir) + def __init__( + self, ckpt_dir: str = "ckpt", paper_dir: str = "ckpt/paper_collection" + ): + self._set_ckpt(ckpt_dir, paper_dir) self._make_all_dirs() self._init_path_registry() - def _set_ckpt(self, ckpt: str): + def _set_ckpt(self, ckpt: str, paper_dir: str): self.ckpt_dir = self.set_ckpt.set_ckpt_subdir(ckpt_dir=ckpt) + if paper_dir is not None: + paper_dir = os.path.join(self.set_ckpt.find_root_dir(), paper_dir) + self.ckpt_papers = paper_dir def _make_all_dirs(self): self.json_file_path = os.path.join(self.ckpt_dir, "paths_registry.json") diff --git a/setup.py b/setup.py index 85ae8841..92506aa3 100644 --- a/setup.py +++ b/setup.py @@ -23,7 +23,7 @@ "matplotlib", "nbformat", "openai", - "paper-qa==4.0.0rc8 ", + "paper-qa==5.0.6", "paper-scraper @ git+https://github.com/blackadad/paper-scraper.git", "pandas", "pydantic>=2.6", From dd3404c6c519a7b2d2bec029fa67d5fe8d3f17d3 Mon Sep 17 00:00:00 2001 From: qcampbel Date: Tue, 1 Oct 2024 11:26:02 -0400 Subject: [PATCH 2/6] fixed typo --- mdagent/agent/prompt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mdagent/agent/prompt.py b/mdagent/agent/prompt.py index dbfbd669..9dfb91f7 100644 --- a/mdagent/agent/prompt.py +++ b/mdagent/agent/prompt.py @@ -1,7 +1,7 @@ from langchain.prompts import PromptTemplate structured_prompt = PromptTemplate( - input_variables=["input, context"], + input_variables=["input", "context"], template=""" You are an expert molecular dynamics scientist, and your task is to respond to the question or From d794216c47f86d3f6ebe49b64d64d3b6283b32b3 Mon Sep 17 00:00:00 2001 From: qcampbel Date: Tue, 1 Oct 2024 12:37:24 -0400 Subject: [PATCH 3/6] also accept absolute paths --- mdagent/agent/agent.py | 2 +- .../base_tools/util_tools/search_tools.py | 5 +++- mdagent/utils/path_registry.py | 25 +++++++++++++------ 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/mdagent/agent/agent.py b/mdagent/agent/agent.py index 039f2e1d..80f143f9 100644 --- a/mdagent/agent/agent.py +++ b/mdagent/agent/agent.py @@ -46,7 +46,7 @@ def __init__( uploaded_files=[], # user input files to add to path registry run_id="", use_memory=False, - paper_dir="ckpt/paper_collection", # papers for pqa, relative path within repo + paper_dir=None, # papers for pqa, relative path within repo ): self.llm = _make_llm(model, temp, streaming) if tools_model is None: diff --git a/mdagent/tools/base_tools/util_tools/search_tools.py b/mdagent/tools/base_tools/util_tools/search_tools.py index c839c257..20424e07 100644 --- a/mdagent/tools/base_tools/util_tools/search_tools.py +++ b/mdagent/tools/base_tools/util_tools/search_tools.py @@ -11,7 +11,10 @@ def scholar2result_llm(llm, query, path_registry): paper_directory = path_registry.ckpt_papers if paper_directory is None: - raise ValueError("The 'paper_dir' is None and wasn't set from the start.") + raise ValueError( + "'paper_dir' is None. To use this tool, the user " + "must provide a directory with PDFs at the start." + ) print("Paper Directory", paper_directory) llm_name = llm.model_name if llm_name.startswith("gpt") or llm_name.startswith("claude"): diff --git a/mdagent/utils/path_registry.py b/mdagent/utils/path_registry.py index 37f942fb..c7ebbcc9 100644 --- a/mdagent/utils/path_registry.py +++ b/mdagent/utils/path_registry.py @@ -2,6 +2,7 @@ import os from datetime import datetime from enum import Enum +from typing import Optional from mdagent.utils.set_ckpt import SetCheckpoint @@ -28,18 +29,26 @@ def get_instance(cls, ckpt_dir=None, paper_dir=None): cls.instance = cls(ckpt_dir, paper_dir) return cls.instance - def __init__( - self, ckpt_dir: str = "ckpt", paper_dir: str = "ckpt/paper_collection" - ): - self._set_ckpt(ckpt_dir, paper_dir) + def __init__(self, ckpt_dir: str = "ckpt", paper_dir=None): + self._set_ckpt(ckpt_dir) + self._set_paper_dir(paper_dir) self._make_all_dirs() self._init_path_registry() - def _set_ckpt(self, ckpt: str, paper_dir: str): + def _set_ckpt(self, ckpt: str): self.ckpt_dir = self.set_ckpt.set_ckpt_subdir(ckpt_dir=ckpt) - if paper_dir is not None: - paper_dir = os.path.join(self.set_ckpt.find_root_dir(), paper_dir) - self.ckpt_papers = paper_dir + + def _set_paper_dir(self, paper_dir: Optional[str]): + if paper_dir is None: + self.ckpt_papers = None + return + absolute_path = os.path.abspath(paper_dir) + if not os.path.exists(absolute_path) or not os.path.isdir(absolute_path): + raise ValueError( + f"Invalid paper directory: '{absolute_path}' either doesn't exist " + "or isn't a directory." + ) + self.ckpt_papers = absolute_path def _make_all_dirs(self): self.json_file_path = os.path.join(self.ckpt_dir, "paths_registry.json") From 0e662540243d1c9b12daa37500ff39bd1c1bf0ec Mon Sep 17 00:00:00 2001 From: qcampbel Date: Tue, 1 Oct 2024 13:02:07 -0400 Subject: [PATCH 4/6] removed paperscraper --- setup.py | 1 - 1 file changed, 1 deletion(-) diff --git a/setup.py b/setup.py index 92506aa3..d23addcc 100644 --- a/setup.py +++ b/setup.py @@ -24,7 +24,6 @@ "nbformat", "openai", "paper-qa==5.0.6", - "paper-scraper @ git+https://github.com/blackadad/paper-scraper.git", "pandas", "pydantic>=2.6", "python-dotenv", From a8187129cc06e38af4b1b70159c7998ab034b17a Mon Sep 17 00:00:00 2001 From: qcampbel Date: Wed, 2 Oct 2024 13:03:48 -0400 Subject: [PATCH 5/6] fixed rgy bug with saving path in path registry --- mdagent/tools/base_tools/analysis_tools/rgy.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mdagent/tools/base_tools/analysis_tools/rgy.py b/mdagent/tools/base_tools/analysis_tools/rgy.py index 71f58b40..61ad7698 100644 --- a/mdagent/tools/base_tools/analysis_tools/rgy.py +++ b/mdagent/tools/base_tools/analysis_tools/rgy.py @@ -69,6 +69,7 @@ def plot_rgy(self) -> str: if plot_name.endswith(".png"): plot_name = plot_name.split(".png")[0] plot_path = f"{self.path_registry.ckpt_figures}/{plot_name}" + print("plot_path", plot_path) plt.plot(rg_per_frame) plt.xlabel("Frame") plt.ylabel("Radius of Gyration (nm)") @@ -77,7 +78,7 @@ def plot_rgy(self) -> str: plt.savefig(f"{plot_path}") self.path_registry.map_path( plot_id, - plot_path, + plot_path + ".png", description=f"Plot of radii of gyration over time for {self.traj_file}", ) plt.close() From cbcbb5d72895937f0f0dcc5ab2c99679091830a4 Mon Sep 17 00:00:00 2001 From: qcampbel Date: Wed, 2 Oct 2024 14:31:32 -0400 Subject: [PATCH 6/6] updated uniprot unit test --- tests/test_preprocess/test_uniprot.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/test_preprocess/test_uniprot.py b/tests/test_preprocess/test_uniprot.py index 370dda9b..64db4887 100644 --- a/tests/test_preprocess/test_uniprot.py +++ b/tests/test_preprocess/test_uniprot.py @@ -487,14 +487,10 @@ def test_get_ids(query_uniprot): "P68871", "P02089", "P02070", - "O13163", "P02008", "B3EWR7", - "P04244", "P02094", - "P83479", "P01966", - "O93349", "P68872", "P69905", "P02088",