Skip to content

Commit

Permalink
updated literature search
Browse files Browse the repository at this point in the history
  • Loading branch information
qcampbel committed Oct 1, 2024
1 parent 0db390f commit 933aa85
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 79 deletions.
3 changes: 2 additions & 1 deletion mdagent/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,15 @@ def __init__(
uploaded_files=[], # user input files to add to path registry
run_id="",
use_memory=False,
paper_dir="ckpt/paper_collection", # papers for pqa, relative path within repo
):
self.llm = _make_llm(model, temp, streaming)
if tools_model is None:
tools_model = model
self.tools_llm = _make_llm(tools_model, temp, streaming)

self.use_memory = use_memory
self.path_registry = PathRegistry.get_instance(ckpt_dir=ckpt_dir)
self.path_registry = PathRegistry.get_instance(ckpt_dir, paper_dir)
self.ckpt_dir = self.path_registry.ckpt_dir
self.memory = MemoryManager(self.path_registry, self.tools_llm, run_id=run_id)
self.run_id = self.memory.run_id
Expand Down
163 changes: 91 additions & 72 deletions mdagent/tools/base_tools/util_tools/search_tools.py
Original file line number Diff line number Diff line change
@@ -1,90 +1,109 @@
import logging
import os
import re
from typing import Optional

import langchain
import nest_asyncio
import paperqa
import paperscraper
from langchain.base_language import BaseLanguageModel
from langchain.tools import BaseTool
from langchain_core.output_parsers import StrOutputParser
from pypdf.errors import PdfReadError

from mdagent.utils import PathRegistry


def configure_logging(path):
    """Send runtime errors to ``scraping_errors.log`` inside *path*.

    paperscraper can emit a very large volume of error messages at
    runtime, so they are captured in a file rather than printed.
    """
    logging.basicConfig(
        filename=os.path.join(path, "scraping_errors.log"),
        level=logging.ERROR,
        format="%(asctime)s:%(levelname)s:%(message)s",
    )


def paper_scraper(search: str, pdir: str = "query") -> dict:
    """Run a paperscraper search for *search*, downloading into *pdir*.

    Returns the mapping of downloaded paper paths to metadata; an empty
    dict when paperscraper raises ``KeyError`` (e.g. nothing found).
    """
    try:
        results = paperscraper.search_papers(search, pdir=pdir)
    except KeyError:
        results = {}
    return results


def paper_search(llm, query, path_registry):
    """Turn *query* into a short search string via *llm* and scrape papers.

    Parameters
    ----------
    llm : langchain language model used to condense the question into a
        <=10-word search query.
    query : the user's question.
    path_registry : PathRegistry; downloads land under
        ``<ckpt_files>/query``.

    Returns
    -------
    dict mapping downloaded paper paths to their metadata; empty when the
    scrape fails or finds nothing.
    """
    prompt = langchain.prompts.PromptTemplate(
        input_variables=["question"],
        template="""
        I would like to find scholarly papers to answer
        this question: {question}. Your response must be at
        most 10 words long.
        'A search query that would bring up papers that can answer
        this question would be: '""",
    )

    path = f"{path_registry.ckpt_files}/query"
    query_chain = prompt | llm | StrOutputParser()
    # makedirs with exist_ok avoids the check-then-create race of the old
    # isdir()/mkdir() pair and also creates missing parent directories.
    os.makedirs(path, exist_ok=True)
    configure_logging(path)
    search = query_chain.invoke(query)
    print("\nSearch:", search)
    # str.replace is the idiomatic (and regex-free) way to strip spaces.
    papers = paper_scraper(search, pdir=f"{path}/{search.replace(' ', '')}")
    return papers


def scholar2result_llm(llm, query, path_registry, k=5, max_sources=2):
"""Useful to answer questions that require
technical knowledge. Ask a specific question."""
if llm.model_name.startswith("gpt"):
docs = paperqa.Docs(llm=llm.model_name)
# from pypdf.errors import PdfReadError


# def configure_logging(path):
# # to log all runtime errors from paperscraper, which can be VERY noisy
# log_file = os.path.join(path, "scraping_errors.log")
# logging.basicConfig(
# filename=log_file,
# level=logging.ERROR,
# format="%(asctime)s:%(levelname)s:%(message)s",
# )


# def paper_scraper(search: str, pdir: str = "query") -> dict:
# try:
# return paperscraper.search_papers(search, pdir=pdir)
# except KeyError:
# return {}


# def paper_search(llm, query, path_registry):
# prompt = langchain.prompts.PromptTemplate(
# input_variables=["question"],
# template="""
# I would like to find scholarly papers to answer
# this question: {question}. Your response must be at
# most 10 words long.
# 'A search query that would bring up papers that can answer
# this question would be: '""",
# )

# path = f"{path_registry.ckpt_files}/query"
# query_chain = prompt | llm | StrOutputParser()
# if not os.path.isdir(path):
# os.mkdir(path)
# configure_logging(path)
# search = query_chain.invoke(query)
# print("\nSearch:", search)
# papers = paper_scraper(search, pdir=f"{path}/{re.sub(' ', '', search)}")
# return papers


# def scholar2result_llm(llm, query, path_registry, k=5, max_sources=2):
# """Useful to answer questions that require
# technical knowledge. Ask a specific question."""
# if llm.model_name.startswith("gpt"):
# docs = paperqa.Docs(llm=llm.model_name)
# else:
# docs = paperqa.Docs() # uses default gpt model in paperqa

# papers = paper_search(llm, query, path_registry)
# if len(papers) == 0:
# return "Failed. Not enough papers found"
# not_loaded = 0
# for path, data in papers.items():
# try:
# docs.add(path, data["citation"])
# except (ValueError, FileNotFoundError, PdfReadError):
# not_loaded += 1

# print(
# f"\nFound {len(papers)} papers"
# + (f" but couldn't load {not_loaded}" if not_loaded > 0 else "")
# )
# answer = docs.query(query, k=k, max_sources=max_sources).formatted_answer
# return "Succeeded. " + answer


def scholar2result_llm(llm, query, path_registry):
    """Answer *query* from the local paper collection using paper-qa.

    Parameters
    ----------
    llm : langchain language model; its ``model_name`` and ``temperature``
        are forwarded to paper-qa when the model family is supported.
    query : the question to answer.
    path_registry : PathRegistry; ``ckpt_papers`` is the directory of PDFs
        paper-qa indexes.

    Returns
    -------
    str: paper-qa's formatted answer, with a hint appended when it could
    not answer (usually because the paper directory is empty).
    """
    paper_directory = path_registry.ckpt_papers
    print("Paper Directory", paper_directory)
    llm_name = llm.model_name
    if llm_name.startswith("gpt") or llm_name.startswith("claude"):
        # paper-qa understands these model families directly; reuse the
        # agent's model and temperature for both answering and summarizing.
        settings = paperqa.Settings(
            llm=llm_name,
            summary_llm=llm_name,
            temperature=llm.temperature,
            paper_directory=paper_directory,
        )
    else:
        # Unknown model family: fall back to paper-qa's default LLM.
        settings = paperqa.Settings(
            temperature=llm.temperature,
            paper_directory=paper_directory,
        )
    response = paperqa.ask(query, settings=settings)
    answer = response.answer.formatted_answer
    if "I cannot answer." in answer:
        answer += f" Check to ensure there's papers in {paper_directory}"
    print(answer)
    return answer


class Scholar2ResultLLM(BaseTool):
name = "LiteratureSearch"
description = (
"Useful to answer questions that require technical "
"knowledge. Ask a specific question."
"Useful to answer questions that may be found in literature. "
"Ask a specific question as the input."
)
llm: BaseLanguageModel = None
path_registry: Optional[PathRegistry]
Expand Down
14 changes: 9 additions & 5 deletions mdagent/utils/path_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,23 @@ class PathRegistry:

@classmethod
def get_instance(cls, ckpt_dir=None, paper_dir=None):
    """Return the shared PathRegistry, creating/replacing it when needed.

    A new instance is built when none exists yet or when an explicit
    ``ckpt_dir`` is supplied (which resets the singleton).
    NOTE(review): a ``paper_dir`` passed without ``ckpt_dir`` is ignored
    when an instance already exists — confirm that is intended.
    """
    # todo: use same ckpt if run_id is given
    if not cls.instance or ckpt_dir is not None:
        cls.instance = cls(ckpt_dir, paper_dir)
    return cls.instance

def __init__(
    self, ckpt_dir: str = "ckpt", paper_dir: str = "ckpt/paper_collection"
):
    """Set up checkpoint and paper directories and load the registry.

    Parameters
    ----------
    ckpt_dir : checkpoint directory name (subdirectory is resolved by
        ``_set_ckpt``).
    paper_dir : paper-collection path relative to the repo root, used by
        the literature-search tool.
    """
    self._set_ckpt(ckpt_dir, paper_dir)
    self._make_all_dirs()
    self._init_path_registry()

def _set_ckpt(self, ckpt: str, paper_dir: str):
    """Resolve and store ``ckpt_dir`` and (when given) ``ckpt_papers``.

    NOTE(review): when ``paper_dir`` is None, ``self.ckpt_papers`` is
    never assigned, so later reads of it would raise AttributeError —
    confirm callers always pass a paper_dir or guard the attribute.
    """
    self.ckpt_dir = self.set_ckpt.set_ckpt_subdir(ckpt_dir=ckpt)
    if paper_dir is not None:
        # paper_dir is relative to the repo root, not to ckpt_dir.
        self.ckpt_papers = os.path.join(self.set_ckpt.find_root_dir(), paper_dir)

def _make_all_dirs(self):
self.json_file_path = os.path.join(self.ckpt_dir, "paths_registry.json")
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
"matplotlib",
"nbformat",
"openai",
"paper-qa==4.0.0rc8 ",
"paper-qa==5.0.6",
"paper-scraper @ git+https://github.com/blackadad/paper-scraper.git",
"pandas",
"pydantic>=2.6",
Expand Down

0 comments on commit 933aa85

Please sign in to comment.