-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
223 additions
and
44 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
pre-commit | ||
pytest | ||
pytest-mock |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,26 +1,75 @@ | ||
import pqapi | ||
import os | ||
import re | ||
|
||
import langchain | ||
import paperqa | ||
import paperscraper | ||
from langchain.base_language import BaseLanguageModel | ||
from langchain.tools import BaseTool | ||
from pypdf.errors import PdfReadError | ||
|
||
|
||
def paper_scraper(search: str, pdir: str = "query") -> dict:
    """Look up papers for ``search`` through ``paperscraper``.

    Args:
        search: Query string handed to ``paperscraper.search_papers``.
        pdir: Forwarded unchanged as the ``pdir`` keyword of
            ``paperscraper.search_papers``.

    Returns:
        The result of ``paperscraper.search_papers``, or an empty dict when
        that call raises ``KeyError``.
    """
    try:
        found = paperscraper.search_papers(search, pdir=pdir)
    except KeyError:
        # A KeyError from the scraper is reported to the caller as "no papers".
        found = {}
    return found
|
||
|
||
def paper_search(llm, query):
    """Ask the LLM for a short search string and scrape papers for it.

    Args:
        llm: Language model used to build the ``LLMChain`` that rewrites
            ``query`` into a search string.
        query: The question to find scholarly papers for.

    Returns:
        The result of ``paper_scraper`` for the generated search string.
    """
    prompt = langchain.prompts.PromptTemplate(
        input_variables=["question"],
        template="""
        I would like to find scholarly papers to answer
        this question: {question}. Your response must be at
        most 10 words long.
        'A search query that would bring up papers that can answer
        this question would be: '""",
    )

    query_chain = langchain.chains.llm.LLMChain(llm=llm, prompt=prompt)
    # exist_ok=True replaces the separate isdir()/mkdir() pair, which could
    # fail if the directory appeared between the check and the creation.
    os.makedirs("query", exist_ok=True)  # todo: move to ckpt
    search = query_chain.run(query)
    print("\nSearch:", search)
    # Each search gets its own sub-directory, named after the search string
    # with the spaces removed.
    subdir = search.replace(" ", "")
    papers = paper_scraper(search, pdir=f"query/{subdir}")
    return papers
|
||
|
||
def scholar2result_llm(llm, query, k=5, max_sources=2):
    """Answer ``query`` from papers found by a literature search.

    Args:
        llm: Language model passed to ``paper_search`` and ``paperqa.Docs``.
        query: The question to answer.
        k: Forwarded to ``docs.query`` as ``k``.
        max_sources: Forwarded to ``docs.query`` as ``max_sources``.

    Returns:
        The ``formatted_answer`` of the paperqa query, or the string
        ``"Not enough papers found"`` when the search returned nothing.
    """
    papers = paper_search(llm, query)
    if not papers:
        return "Not enough papers found"

    docs = paperqa.Docs(llm=llm)
    not_loaded = 0
    for path, data in papers.items():
        try:
            docs.add(path, data["citation"])
        except (ValueError, FileNotFoundError, PdfReadError):
            # Papers that cannot be added are skipped and only counted.
            not_loaded += 1

    print(f"\nFound {len(papers)} papers but couldn't load {not_loaded}")
    return docs.query(query, k=k, max_sources=max_sources).formatted_answer
|
||
|
||
class Scholar2ResultLLM(BaseTool):
    """Tool that answers a question by delegating to ``scholar2result_llm``."""

    name = "LiteratureSearch"
    description = (
        "Useful to answer questions that require technical "
        "knowledge. Ask a specific question."
    )
    # Language model handed to scholar2result_llm on every call.
    llm: BaseLanguageModel = None

    def __init__(self, llm):
        """Create the tool and store ``llm`` for use by ``_run``."""
        super().__init__()
        self.llm = llm

    def _run(self, query) -> str:
        """Use the tool: return ``scholar2result_llm(self.llm, query)``."""
        return scholar2result_llm(self.llm, query)

    async def _arun(self, query) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError("this tool does not support async")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"/Users/samcox/anaconda3/envs/mda_feb21/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", | ||
" from .autonotebook import tqdm as notebook_tqdm\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"from mdagent import MDAgent" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"#until we update to new version\n", | ||
"import nest_asyncio\n", | ||
"nest_asyncio.apply()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"mda = MDAgent()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"prompt = \"Are there any studies that show that the use of a mask can reduce the spread of COVID-19?\"" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"\"Masks COVID-19 transmission reduction studies\"\n", | ||
"Search: \"Masks COVID-19 transmission reduction studies\"\n", | ||
"\n", | ||
"Found 14 papers but couldn't load 0\n", | ||
"Yes, there are studies that show that the use of a mask can reduce the spread of COVID-19. The review by Howard et al. (2021) indicates that mask-wearing reduces the transmissibility of COVID-19 by limiting the spread of infected respiratory particles. This conclusion is supported by evidence from both laboratory and clinical studies." | ||
] | ||
} | ||
], | ||
"source": [ | ||
"answer = mda.run(prompt)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"'Yes, there are studies that show that the use of a mask can reduce the spread of COVID-19. The review by Howard et al. (2021) indicates that mask-wearing reduces the transmissibility of COVID-19 by limiting the spread of infected respiratory particles. This conclusion is supported by evidence from both laboratory and clinical studies.'" | ||
] | ||
}, | ||
"execution_count": 7, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"answer" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "mdagent", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.8" | ||
}, | ||
"orig_nbformat": 4 | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters