Skip to content

Commit

Permalink
few updates to lit search tool
Browse files Browse the repository at this point in the history
  • Loading branch information
SamCox822 committed Feb 21, 2024
1 parent f73a662 commit dfac744
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 14 deletions.
19 changes: 11 additions & 8 deletions mdagent/tools/base_tools/util_tools/search_tools.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,32 @@
import os
import re
from langchain.base_language import BaseLanguageModel

import langchain
import paperqa
import paperscraper
from langchain.base_language import BaseLanguageModel
from pypdf.errors import PdfReadError


def paper_scraper(search:str, pdir:str="query") -> dict:
def paper_scraper(search: str, pdir: str = "query") -> dict:
try:
return paperscraper.search_papers(search, pdir=pdir)
except KeyError:
return {}



def paper_search(llm, query):
prompt = langchain.prompts.PromptTemplate(
input_variables=["question"],
template="""
I would like to find scholarly papers to answer
this question: {question}.
'A search query that would bring up papers that can answer
this question would be: '""",)

this question would be: '""",
)

query_chain = langchain.chains.llm.LLMChain(llm=llm, prompt=prompt)
if not os.path.isdir("./query"): #todo: move to ckpt
if not os.path.isdir("./query"): # todo: move to ckpt
os.mkdir("query/")

search = query_chain.run(query)
Expand Down Expand Up @@ -51,7 +54,7 @@ def scholar2result_llm(llm, query):


class Scholar2ResultLLM:
name = "Literature Search"
name = "LiteratureSearch"
description = (
"Useful to answer questions that require technical ",
"knowledge. Ask a specific question.",
Expand All @@ -66,4 +69,4 @@ def _run(self, query) -> str:

async def _arun(self, query) -> str:
"""Use the tool asynchronously."""
raise NotImplementedError("this tool does not support async")
raise NotImplementedError("this tool does not support async")
1 change: 0 additions & 1 deletion mdagent/tools/maketools.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
SimulationOutputFigures,
SmallMolPDB,
VisualizeProtein,
Scholar2ResultLLM,
)
from .subagent_tools import RetryExecuteSkill, SkillRetrieval, WorkflowPlan

Expand Down
12 changes: 7 additions & 5 deletions tests/test_tools.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import os
import warnings
from unittest.mock import MagicMock, mock_open, patch
from langchain.chat_models import ChatOpenAI

import pytest
from mdagent.tools.base_tools import Scholar2ResultLLM
from mdagent.tools.base_tools import VisFunctions, get_pdb
from langchain.chat_models import ChatOpenAI

from mdagent.tools.base_tools import Scholar2ResultLLM, VisFunctions, get_pdb
from mdagent.tools.base_tools.analysis_tools.plot_tools import plot_data, process_csv
from mdagent.utils import PathRegistry

Expand Down Expand Up @@ -113,13 +114,15 @@ def test_getpdb(fibronectin, get_registry):
name, _ = get_pdb(fibronectin, get_registry)
assert name.endswith(".pdb")


@pytest.fixture
def questions():
qs = [
"What are the effects of norhalichondrin B in mammals?",
]
return qs[0]


@pytest.mark.skip(reason="This requires an API call")
def test_litsearch(questions):
llm = ChatOpenAI()
Expand All @@ -129,6 +132,5 @@ def test_litsearch(questions):
ans = searchtool._run(q)
assert isinstance(ans, str)
assert len(ans) > 0
#then if query folder exists one step back, delete it
if os.path.exists("../query"):
os.rmdir("../query")
os.rmdir("../query")

0 comments on commit dfac744

Please sign in to comment.