diff --git a/mdagent/agent/agent.py b/mdagent/agent/agent.py index 2adc7330..039f2e1d 100644 --- a/mdagent/agent/agent.py +++ b/mdagent/agent/agent.py @@ -46,6 +46,7 @@ def __init__( uploaded_files=[], # user input files to add to path registry run_id="", use_memory=False, + paper_dir="ckpt/paper_collection", # papers for pqa, relative path within repo ): self.llm = _make_llm(model, temp, streaming) if tools_model is None: @@ -53,7 +54,7 @@ def __init__( self.tools_llm = _make_llm(tools_model, temp, streaming) self.use_memory = use_memory - self.path_registry = PathRegistry.get_instance(ckpt_dir=ckpt_dir) + self.path_registry = PathRegistry.get_instance(ckpt_dir, paper_dir) self.ckpt_dir = self.path_registry.ckpt_dir self.memory = MemoryManager(self.path_registry, self.tools_llm, run_id=run_id) self.run_id = self.memory.run_id diff --git a/mdagent/tools/base_tools/util_tools/search_tools.py b/mdagent/tools/base_tools/util_tools/search_tools.py index 1015d69d..90273145 100644 --- a/mdagent/tools/base_tools/util_tools/search_tools.py +++ b/mdagent/tools/base_tools/util_tools/search_tools.py @@ -1,90 +1,109 @@ -import logging -import os -import re from typing import Optional -import langchain import nest_asyncio import paperqa -import paperscraper from langchain.base_language import BaseLanguageModel from langchain.tools import BaseTool -from langchain_core.output_parsers import StrOutputParser -from pypdf.errors import PdfReadError from mdagent.utils import PathRegistry - -def configure_logging(path): - # to log all runtime errors from paperscraper, which can be VERY noisy - log_file = os.path.join(path, "scraping_errors.log") - logging.basicConfig( - filename=log_file, - level=logging.ERROR, - format="%(asctime)s:%(levelname)s:%(message)s", - ) - - -def paper_scraper(search: str, pdir: str = "query") -> dict: - try: - return paperscraper.search_papers(search, pdir=pdir) - except KeyError: - return {} - - -def paper_search(llm, query, path_registry): - prompt = langchain.prompts.PromptTemplate( - input_variables=["question"], - template=""" - I would like to find scholarly papers to answer - this question: {question}. Your response must be at - most 10 words long. - 'A search query that would bring up papers that can answer - this question would be: '""", - ) - - path = f"{path_registry.ckpt_files}/query" - query_chain = prompt | llm | StrOutputParser() - if not os.path.isdir(path): - os.mkdir(path) - configure_logging(path) - search = query_chain.invoke(query) - print("\nSearch:", search) - papers = paper_scraper(search, pdir=f"{path}/{re.sub(' ', '', search)}") - return papers - - -def scholar2result_llm(llm, query, path_registry, k=5, max_sources=2): - """Useful to answer questions that require - technical knowledge. Ask a specific question.""" - if llm.model_name.startswith("gpt"): - docs = paperqa.Docs(llm=llm.model_name) +# from pypdf.errors import PdfReadError + + +# def configure_logging(path): +# # to log all runtime errors from paperscraper, which can be VERY noisy +# log_file = os.path.join(path, "scraping_errors.log") +# logging.basicConfig( +# filename=log_file, +# level=logging.ERROR, +# format="%(asctime)s:%(levelname)s:%(message)s", +# ) + + +# def paper_scraper(search: str, pdir: str = "query") -> dict: +# try: +# return paperscraper.search_papers(search, pdir=pdir) +# except KeyError: +# return {} + + +# def paper_search(llm, query, path_registry): +# prompt = langchain.prompts.PromptTemplate( +# input_variables=["question"], +# template=""" +# I would like to find scholarly papers to answer +# this question: {question}. Your response must be at +# most 10 words long. +# 'A search query that would bring up papers that can answer +# this question would be: '""", +# ) + +# path = f"{path_registry.ckpt_files}/query" +# query_chain = prompt | llm | StrOutputParser() +# if not os.path.isdir(path): +# os.mkdir(path) +# configure_logging(path) +# search = query_chain.invoke(query) +# print("\nSearch:", search) +# papers = paper_scraper(search, pdir=f"{path}/{re.sub(' ', '', search)}") +# return papers + + +# def scholar2result_llm(llm, query, path_registry, k=5, max_sources=2): +# """Useful to answer questions that require +# technical knowledge. Ask a specific question.""" +# if llm.model_name.startswith("gpt"): +# docs = paperqa.Docs(llm=llm.model_name) +# else: +# docs = paperqa.Docs() # uses default gpt model in paperqa + +# papers = paper_search(llm, query, path_registry) +# if len(papers) == 0: +# return "Failed. Not enough papers found" +# not_loaded = 0 +# for path, data in papers.items(): +# try: +# docs.add(path, data["citation"]) +# except (ValueError, FileNotFoundError, PdfReadError): +# not_loaded += 1 + +# print( +# f"\nFound {len(papers)} papers" +# + (f" but couldn't load {not_loaded}" if not_loaded > 0 else "") +# ) +# answer = docs.query(query, k=k, max_sources=max_sources).formatted_answer +# return "Succeeded. " + answer + + +def scholar2result_llm(llm, query, path_registry): + paper_directory = path_registry.ckpt_papers + print("Paper Directory", paper_directory) + llm_name = llm.model_name + if llm_name.startswith("gpt") or llm_name.startswith("claude"): + settings = paperqa.Settings( + llm=llm_name, + summary_llm=llm_name, + temperature=llm.temperature, + paper_directory=paper_directory, + ) else: - docs = paperqa.Docs() # uses default gpt model in paperqa - - papers = paper_search(llm, query, path_registry) - if len(papers) == 0: - return "Failed. Not enough papers found" - not_loaded = 0 - for path, data in papers.items(): - try: - docs.add(path, data["citation"]) - except (ValueError, FileNotFoundError, PdfReadError): - not_loaded += 1 - - print( - f"\nFound {len(papers)} papers" - + (f" but couldn't load {not_loaded}" if not_loaded > 0 else "") - ) - answer = docs.query(query, k=k, max_sources=max_sources).formatted_answer - return "Succeeded. " + answer + settings = paperqa.Settings( + temperature=llm.temperature, # uses default gpt model in paperqa + paper_directory=paper_directory, + ) + response = paperqa.ask(query, settings=settings) + answer = response.answer.formatted_answer + if "I cannot answer." in answer: + answer += f" Check to ensure there's papers in {paper_directory}" + print(answer) + return answer class Scholar2ResultLLM(BaseTool): name = "LiteratureSearch" description = ( - "Useful to answer questions that require technical " - "knowledge. Ask a specific question." + "Useful to answer questions that may be found in literature. " + "Ask a specific question as the input." ) llm: BaseLanguageModel = None path_registry: Optional[PathRegistry] diff --git a/mdagent/utils/path_registry.py b/mdagent/utils/path_registry.py index 9ee4e109..e47b2f56 100644 --- a/mdagent/utils/path_registry.py +++ b/mdagent/utils/path_registry.py @@ -22,19 +22,23 @@ class PathRegistry: @classmethod # set ckpt_dir to None by default - def get_instance(cls, ckpt_dir=None): + def get_instance(cls, ckpt_dir=None, paper_dir=None): # todo: use same ckpt if run_id is given if not cls.instance or ckpt_dir is not None: - cls.instance = cls(ckpt_dir) + cls.instance = cls(ckpt_dir, paper_dir) return cls.instance - def __init__(self, ckpt_dir: str = "ckpt"): - self._set_ckpt(ckpt_dir) + def __init__( + self, ckpt_dir: str = "ckpt", paper_dir: str = "ckpt/paper_collection" + ): + self._set_ckpt(ckpt_dir, paper_dir) self._make_all_dirs() self._init_path_registry() - def _set_ckpt(self, ckpt: str): + def _set_ckpt(self, ckpt: str, paper_dir: str): self.ckpt_dir = self.set_ckpt.set_ckpt_subdir(ckpt_dir=ckpt) + if paper_dir is not None: + self.ckpt_papers = os.path.join(self.set_ckpt.find_root_dir(), paper_dir) def _make_all_dirs(self): self.json_file_path = os.path.join(self.ckpt_dir, "paths_registry.json") diff --git a/setup.py b/setup.py index 85ae8841..92506aa3 100644 --- a/setup.py +++ b/setup.py @@ -23,7 +23,7 @@ "matplotlib", "nbformat", "openai", - "paper-qa==4.0.0rc8 ", + "paper-qa==5.0.6", "paper-scraper @ git+https://github.com/blackadad/paper-scraper.git", "pandas", "pydantic>=2.6",