Pqa tool (#91)
SamCox822 authored Feb 27, 2024
1 parent 9228878 commit 759b34e
Showing 11 changed files with 223 additions and 44 deletions.
3 changes: 0 additions & 3 deletions .env.example
@@ -4,8 +4,5 @@
# OpenAI API Key
OPENAI_API_KEY=YOUR_OPENAI_API_KEY_GOES_HERE # pragma: allowlist secret

# PQA API Key
PQA_API_KEY=YOUR_PQA_API_KEY_GOES_HERE # pragma: allowlist secret

# Serp API key
SERP_API_KEY=YOUR_SERP_API_KEY_GOES_HERE # pragma: allowlist secret
4 changes: 2 additions & 2 deletions .github/workflows/build.yml
@@ -13,10 +13,10 @@ jobs:

steps:
- uses: actions/checkout@v2
- name: Set up Python "3.9"
- name: Set up Python "3.11"
uses: actions/setup-python@v2
with:
python-version: "3.9"
python-version: "3.11"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
5 changes: 2 additions & 3 deletions .github/workflows/tests.yml
@@ -23,10 +23,10 @@ jobs:
environment-file: environment.yaml
python-version: ${{ matrix.python-version }}
auto-activate-base: true
- name: Install openmm pdbfixer mdanalysis with conda
- name: Install pdbfixer with conda
shell: bash -l {0}
run: |
conda install -c conda-forge openmm pdbfixer mdanalysis
conda install -c conda-forge pdbfixer
- name: Install dependencies
shell: bash -l {0}
run: |
@@ -45,6 +45,5 @@ jobs:
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SEMANTIC_SCHOLAR_API_KEY: ${{ secrets.SEMANTIC_SCHOLAR_API_KEY }}
PQA_API_KEY : ${{ secrets.PQA_API_TOKEN }}
run: |
pytest -m "not skip" tests
3 changes: 0 additions & 3 deletions .secrets.baseline
@@ -3,8 +3,5 @@
# Rule for detecting OpenAI API keys
OpenAI API Key: \b[secrets]{3}_[a-zA-Z0-9]{32}\b

# Rule for detecting pqa API keys
PQA API Key: "pqa[a-zA-Z0-9-._]+"

# Rule for detecting serp API keys
# Serp API Key: "[a-zA-Z0-9]{64}"
4 changes: 2 additions & 2 deletions README.md
@@ -7,8 +7,8 @@ To use the OpenMM features in the agent, please set up a conda environment, foll
- Create conda environment: `conda env create -n mdagent -f environment.yaml`
- Activate your environment: `conda activate mdagent`

If you already have a conda environment, you can install the necessary dependencies with the following steps.
- Install the necessary conda dependencies: `conda install -c conda-forge openmm pdbfixer mdanalysis`
If you already have a conda environment, you can install pdbfixer, a necessary dependency, with the following step.
- Install the necessary conda dependency: `conda install -c conda-forge pdbfixer`


## Installation
1 change: 1 addition & 0 deletions dev-requirements.txt
@@ -1,2 +1,3 @@
pre-commit
pytest
pytest-mock
85 changes: 67 additions & 18 deletions mdagent/tools/base_tools/util_tools/search_tools.py
@@ -1,26 +1,75 @@
import pqapi
import os
import re

import langchain
import paperqa
import paperscraper
from langchain.base_language import BaseLanguageModel
from langchain.tools import BaseTool
from pypdf.errors import PdfReadError


def paper_scraper(search: str, pdir: str = "query") -> dict:
try:
return paperscraper.search_papers(search, pdir=pdir)
except KeyError:
return {}


def paper_search(llm, query):
prompt = langchain.prompts.PromptTemplate(
input_variables=["question"],
template="""
I would like to find scholarly papers to answer
this question: {question}. Your response must be at
most 10 words long.
'A search query that would bring up papers that can answer
this question would be: '""",
)

query_chain = langchain.chains.llm.LLMChain(llm=llm, prompt=prompt)
if not os.path.isdir("./query"): # todo: move to ckpt
os.mkdir("query/")
search = query_chain.run(query)
print("\nSearch:", search)
papers = paper_scraper(search, pdir=f"query/{re.sub(' ', '', search)}")
return papers


def scholar2result_llm(llm, query, k=5, max_sources=2):
"""Useful to answer questions that require
technical knowledge. Ask a specific question."""
papers = paper_search(llm, query)
if len(papers) == 0:
return "Not enough papers found"
docs = paperqa.Docs(llm=llm)
not_loaded = 0
for path, data in papers.items():
try:
docs.add(path, data["citation"])
except (ValueError, FileNotFoundError, PdfReadError):
not_loaded += 1

print(f"\nFound {len(papers.items())} papers but couldn't load {not_loaded}")
answer = docs.query(query, k=k, max_sources=max_sources).formatted_answer
return answer


class Scholar2ResultLLM(BaseTool):
name = "LiteratureSearch"
description = """Input a specific question,
returns an answer from literature search."""
description = (
"Useful to answer questions that require technical "
"knowledge. Ask a specific question."
)
llm: BaseLanguageModel = None

pqa_key: str = ""

def __init__(self, pqa_key: str):
def __init__(self, llm):
super().__init__()
self.pqa_key = pqa_key
self.llm = llm

def _run(self, question: str) -> str:
"""Use the tool"""
try:
response = pqapi.agent_query("default", question)
return response.answer
except Exception:
return "Literature search failed."

async def _arun(self, question: str) -> str:
"""Use the tool asynchronously"""
raise NotImplementedError
def _run(self, query) -> str:
return scholar2result_llm(self.llm, query)

async def _arun(self, query) -> str:
"""Use the tool asynchronously."""
raise NotImplementedError("this tool does not support async")
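For reference, here is a minimal sketch of exercising the rewritten tool directly, mirroring the unit test added in tests/test_fxns.py below; it assumes OPENAI_API_KEY is set and that paper-scraper can reach the network.

```python
# Sketch only -- mirrors the new test in tests/test_fxns.py; needs OPENAI_API_KEY
# and network access for paper-scraper downloads.
from langchain.chat_models import ChatOpenAI

from mdagent.tools.base_tools import Scholar2ResultLLM

llm = ChatOpenAI()                 # the LLM used for query generation and paper-qa
tool = Scholar2ResultLLM(llm=llm)  # no PQA_API_KEY required anymore
answer = tool._run("Are there studies showing that masks reduce the spread of COVID-19?")
print(answer)  # formatted answer assembled by paper-qa from the scraped papers
# Note: paper_search() writes downloaded papers under ./query/ (see the TODO above).
```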
9 changes: 1 addition & 8 deletions mdagent/tools/maketools.py
@@ -76,6 +76,7 @@ def make_all_tools(

# add base tools
base_tools = [
Scholar2ResultLLM(llm=llm),
CleaningToolFunction(path_registry=path_instance),
ListRegistryPaths(path_registry=path_instance),
ProteinName2PDBTool(path_registry=path_instance),
@@ -108,14 +109,6 @@ def make_all_tools(
learned_tools = get_learned_tools(subagent_settings.ckpt_dir)

all_tools += base_tools + subagents_tools + learned_tools

# add other tools depending on api keys
os.getenv("SERP_API_KEY")
pqa_key = os.getenv("PQA_API_KEY")
# if serp_key:
# all_tools.append(SerpGitTool(serp_key)) # github issues search
if pqa_key:
all_tools.append(Scholar2ResultLLM(pqa_key)) # literature search
return all_tools


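The net effect above is that LiteratureSearch is now registered unconditionally in base_tools and only needs the agent's LLM; the PQA_API_KEY gate is gone. A hypothetical check is sketched below (make_all_tools' full parameter list is truncated in this diff, so the call is an assumption):

```python
# Hypothetical sketch -- make_all_tools' parameters are not fully shown in this diff.
from langchain.chat_models import ChatOpenAI

from mdagent.tools.maketools import make_all_tools

tools = make_all_tools(llm=ChatOpenAI())
assert any(t.name == "LiteratureSearch" for t in tools)  # present without PQA_API_KEY
```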
121 changes: 121 additions & 0 deletions notebooks/lit_search.ipynb
@@ -0,0 +1,121 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/samcox/anaconda3/envs/mda_feb21/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"from mdagent import MDAgent"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"#until we update to new version\n",
"import nest_asyncio\n",
"nest_asyncio.apply()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"mda = MDAgent()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"prompt = \"Are there any studies that show that the use of a mask can reduce the spread of COVID-19?\""
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\"Masks COVID-19 transmission reduction studies\"\n",
"Search: \"Masks COVID-19 transmission reduction studies\"\n",
"\n",
"Found 14 papers but couldn't load 0\n",
"Yes, there are studies that show that the use of a mask can reduce the spread of COVID-19. The review by Howard et al. (2021) indicates that mask-wearing reduces the transmissibility of COVID-19 by limiting the spread of infected respiratory particles. This conclusion is supported by evidence from both laboratory and clinical studies."
]
}
],
"source": [
"answer = mda.run(prompt)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Yes, there are studies that show that the use of a mask can reduce the spread of COVID-19. The review by Howard et al. (2021) indicates that mask-wearing reduces the transmissibility of COVID-19 by limiting the spread of infected respiratory particles. This conclusion is supported by evidence from both laboratory and clinical studies.'"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"answer"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "mdagent",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.8"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
9 changes: 4 additions & 5 deletions setup.py
@@ -17,22 +17,21 @@
license="MIT",
packages=find_packages(),
install_requires=[
"paper-scraper @ git+https://github.com/blackadad/paper-scraper.git",
"chromadb==0.3.29",
"google-search-results",
"langchain==0.0.336",
"langchain_experimental",
"matplotlib",
"nbformat",
"openai",
"paper-qa",
"python-dotenv",
"pqapi",
"requests",
"rmrkl",
"tiktoken",
"rdkit",
"streamlit",
"paper-qa",
"openmm",
"MDAnalysis",
"paper-scraper @ git+https://github.com/blackadad/paper-scraper.git",
],
test_suite="tests",
long_description=long_description,
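One readability note on the reshuffled install_requires: several PyPI names differ from the import names used in the code, so a quick (assumed) sanity check in an installed environment looks like this:

```python
# Sketch: PyPI names in install_requires vs. the import names used in mdagent.
import MDAnalysis    # pip package "MDAnalysis"
import openmm        # pip package "openmm"
import paperqa       # pip package "paper-qa"
import paperscraper  # pip package "paper-scraper" (installed from the git URL above)
# "pqapi" is dropped from install_requires, matching the removed import in search_tools.py.
```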
23 changes: 23 additions & 0 deletions tests/test_fxns.py
@@ -5,9 +5,11 @@
from unittest.mock import MagicMock, mock_open, patch

import pytest
from langchain.chat_models import ChatOpenAI

from mdagent.tools.base_tools import (
CleaningTools,
Scholar2ResultLLM,
SimulationFunctions,
VisFunctions,
get_pdb,
@@ -438,3 +440,24 @@ def test_init_path_registry(path_registry_with_mocked_fs):
# you may need to check the internal state or the contents of the JSON file.
# For example:
assert "water_000000" in path_registry_with_mocked_fs.list_path_names()


@pytest.fixture
def questions():
qs = [
"What are the effects of norhalichondrin B in mammals?",
]
return qs[0]


@pytest.mark.skip(reason="This requires an API call")
def test_litsearch(questions):
llm = ChatOpenAI()

searchtool = Scholar2ResultLLM(llm=llm)
for q in questions:
ans = searchtool._run(q)
assert isinstance(ans, str)
assert len(ans) > 0
if os.path.exists("../query"):
os.rmdir("../query")
