diff --git a/mdagent/agent/agent.py b/mdagent/agent/agent.py index 2adc7330..80f143f9 100644 --- a/mdagent/agent/agent.py +++ b/mdagent/agent/agent.py @@ -46,6 +46,7 @@ def __init__( uploaded_files=[], # user input files to add to path registry run_id="", use_memory=False, + paper_dir=None, # papers for pqa, relative path within repo ): self.llm = _make_llm(model, temp, streaming) if tools_model is None: @@ -53,7 +54,7 @@ def __init__( self.tools_llm = _make_llm(tools_model, temp, streaming) self.use_memory = use_memory - self.path_registry = PathRegistry.get_instance(ckpt_dir=ckpt_dir) + self.path_registry = PathRegistry.get_instance(ckpt_dir, paper_dir) self.ckpt_dir = self.path_registry.ckpt_dir self.memory = MemoryManager(self.path_registry, self.tools_llm, run_id=run_id) self.run_id = self.memory.run_id diff --git a/mdagent/agent/prompt.py b/mdagent/agent/prompt.py index dbfbd669..9dfb91f7 100644 --- a/mdagent/agent/prompt.py +++ b/mdagent/agent/prompt.py @@ -1,7 +1,7 @@ from langchain.prompts import PromptTemplate structured_prompt = PromptTemplate( - input_variables=["input, context"], + input_variables=["input", "context"], template=""" You are an expert molecular dynamics scientist, and your task is to respond to the question or diff --git a/mdagent/tools/base_tools/analysis_tools/rgy.py b/mdagent/tools/base_tools/analysis_tools/rgy.py index 71f58b40..61ad7698 100644 --- a/mdagent/tools/base_tools/analysis_tools/rgy.py +++ b/mdagent/tools/base_tools/analysis_tools/rgy.py @@ -69,6 +69,7 @@ def plot_rgy(self) -> str: if plot_name.endswith(".png"): plot_name = plot_name.split(".png")[0] plot_path = f"{self.path_registry.ckpt_figures}/{plot_name}" + print("plot_path", plot_path) plt.plot(rg_per_frame) plt.xlabel("Frame") plt.ylabel("Radius of Gyration (nm)") @@ -77,7 +78,7 @@ def plot_rgy(self) -> str: plt.savefig(f"{plot_path}") self.path_registry.map_path( plot_id, - plot_path, + plot_path + ".png", description=f"Plot of radii of gyration over time for {self.traj_file}", ) plt.close() diff --git a/mdagent/tools/base_tools/util_tools/search_tools.py b/mdagent/tools/base_tools/util_tools/search_tools.py index 1015d69d..20424e07 100644 --- a/mdagent/tools/base_tools/util_tools/search_tools.py +++ b/mdagent/tools/base_tools/util_tools/search_tools.py @@ -1,90 +1,47 @@ -import logging -import os -import re from typing import Optional -import langchain import nest_asyncio import paperqa -import paperscraper from langchain.base_language import BaseLanguageModel from langchain.tools import BaseTool -from langchain_core.output_parsers import StrOutputParser -from pypdf.errors import PdfReadError from mdagent.utils import PathRegistry -def configure_logging(path): - # to log all runtime errors from paperscraper, which can be VERY noisy - log_file = os.path.join(path, "scraping_errors.log") - logging.basicConfig( - filename=log_file, - level=logging.ERROR, - format="%(asctime)s:%(levelname)s:%(message)s", - ) - - -def paper_scraper(search: str, pdir: str = "query") -> dict: - try: - return paperscraper.search_papers(search, pdir=pdir) - except KeyError: - return {} - - -def paper_search(llm, query, path_registry): - prompt = langchain.prompts.PromptTemplate( - input_variables=["question"], - template=""" - I would like to find scholarly papers to answer - this question: {question}. Your response must be at - most 10 words long. - 'A search query that would bring up papers that can answer - this question would be: '""", - ) - - path = f"{path_registry.ckpt_files}/query" - query_chain = prompt | llm | StrOutputParser() - if not os.path.isdir(path): - os.mkdir(path) - configure_logging(path) - search = query_chain.invoke(query) - print("\nSearch:", search) - papers = paper_scraper(search, pdir=f"{path}/{re.sub(' ', '', search)}") - return papers - - -def scholar2result_llm(llm, query, path_registry, k=5, max_sources=2): - """Useful to answer questions that require - technical knowledge. Ask a specific question.""" - if llm.model_name.startswith("gpt"): - docs = paperqa.Docs(llm=llm.model_name) +def scholar2result_llm(llm, query, path_registry): + paper_directory = path_registry.ckpt_papers + if paper_directory is None: + raise ValueError( + "'paper_dir' is None. To use this tool, the user " + "must provide a directory with PDFs at the start." + ) + print("Paper Directory", paper_directory) + llm_name = llm.model_name + if llm_name.startswith("gpt") or llm_name.startswith("claude"): + settings = paperqa.Settings( + llm=llm_name, + summary_llm=llm_name, + temperature=llm.temperature, + paper_directory=paper_directory, + ) else: - docs = paperqa.Docs() # uses default gpt model in paperqa - - papers = paper_search(llm, query, path_registry) - if len(papers) == 0: - return "Failed. Not enough papers found" - not_loaded = 0 - for path, data in papers.items(): - try: - docs.add(path, data["citation"]) - except (ValueError, FileNotFoundError, PdfReadError): - not_loaded += 1 - - print( - f"\nFound {len(papers)} papers" - + (f" but couldn't load {not_loaded}" if not_loaded > 0 else "") - ) - answer = docs.query(query, k=k, max_sources=max_sources).formatted_answer - return "Succeeded. " + answer + settings = paperqa.Settings( + temperature=llm.temperature, # uses default gpt model in paperqa + paper_directory=paper_directory, + ) + response = paperqa.ask(query, settings=settings) + answer = response.answer.formatted_answer + if "I cannot answer." in answer: + answer += f" Check to ensure there's papers in {paper_directory}" + print(answer) + return answer class Scholar2ResultLLM(BaseTool): name = "LiteratureSearch" description = ( - "Useful to answer questions that require technical " - "knowledge. Ask a specific question." + "Useful to answer questions that may be found in literature. " + "Ask a specific question as the input." ) llm: BaseLanguageModel = None path_registry: Optional[PathRegistry] @@ -96,7 +53,11 @@ def __init__(self, llm, path_registry): def _run(self, query) -> str: nest_asyncio.apply() - return scholar2result_llm(self.llm, query, self.path_registry) + try: + return scholar2result_llm(self.llm, query, self.path_registry) + except Exception as e: + print(e) + return f"Failed. {type(e).__name__}: {e}" async def _arun(self, query) -> str: """Use the tool asynchronously.""" diff --git a/mdagent/utils/path_registry.py b/mdagent/utils/path_registry.py index 9ee4e109..c7ebbcc9 100644 --- a/mdagent/utils/path_registry.py +++ b/mdagent/utils/path_registry.py @@ -2,6 +2,7 @@ import os from datetime import datetime from enum import Enum +from typing import Optional from mdagent.utils.set_ckpt import SetCheckpoint @@ -22,20 +23,33 @@ class PathRegistry: @classmethod # set ckpt_dir to None by default - def get_instance(cls, ckpt_dir=None): + def get_instance(cls, ckpt_dir=None, paper_dir=None): # todo: use same ckpt if run_id is given if not cls.instance or ckpt_dir is not None: - cls.instance = cls(ckpt_dir) + cls.instance = cls(ckpt_dir, paper_dir) return cls.instance - def __init__(self, ckpt_dir: str = "ckpt"): + def __init__(self, ckpt_dir: str = "ckpt", paper_dir=None): self._set_ckpt(ckpt_dir) + self._set_paper_dir(paper_dir) self._make_all_dirs() self._init_path_registry() def _set_ckpt(self, ckpt: str): self.ckpt_dir = self.set_ckpt.set_ckpt_subdir(ckpt_dir=ckpt) + def _set_paper_dir(self, paper_dir: Optional[str]): + if paper_dir is None: + self.ckpt_papers = None + return + absolute_path = os.path.abspath(paper_dir) + if not os.path.exists(absolute_path) or not os.path.isdir(absolute_path): + raise ValueError( + f"Invalid paper directory: '{absolute_path}' either doesn't exist " + "or isn't a directory." + ) + self.ckpt_papers = absolute_path + def _make_all_dirs(self): self.json_file_path = os.path.join(self.ckpt_dir, "paths_registry.json") self.ckpt_files = os.path.join(self.ckpt_dir, "files") diff --git a/setup.py b/setup.py index 85ae8841..d23addcc 100644 --- a/setup.py +++ b/setup.py @@ -23,8 +23,7 @@ "matplotlib", "nbformat", "openai", - "paper-qa==4.0.0rc8 ", - "paper-scraper @ git+https://github.com/blackadad/paper-scraper.git", + "paper-qa==5.0.6", "pandas", "pydantic>=2.6", "python-dotenv", diff --git a/tests/test_preprocess/test_uniprot.py b/tests/test_preprocess/test_uniprot.py index 370dda9b..64db4887 100644 --- a/tests/test_preprocess/test_uniprot.py +++ b/tests/test_preprocess/test_uniprot.py @@ -487,14 +487,10 @@ def test_get_ids(query_uniprot): "P68871", "P02089", "P02070", - "O13163", "P02008", "B3EWR7", - "P04244", "P02094", - "P83479", "P01966", - "O93349", "P68872", "P69905", "P02088",