Commit

expanding llm providers, updating litsearch

qcampbel committed Aug 7, 2024
1 parent 7f56fe5 commit c1ffe0b
Showing 4 changed files with 39 additions and 16 deletions.
18 changes: 17 additions & 1 deletion mdagent/tools/base_tools/util_tools/search_tools.py
@@ -1,3 +1,4 @@
+import logging
 import os
 import re
 from typing import Optional
@@ -14,6 +15,16 @@
 from mdagent.utils import PathRegistry
 
 
+def configure_logging(path):
+    # to log all runtime errors from paperscraper, which can be VERY noisy
+    log_file = os.path.join(path, "scraping_errors.log")
+    logging.basicConfig(
+        filename=log_file,
+        level=logging.ERROR,
+        format="%(asctime)s:%(levelname)s:%(message)s",
+    )
+
+
 def paper_scraper(search: str, pdir: str = "query") -> dict:
     try:
         return paperscraper.search_papers(search, pdir=pdir)
@@ -36,6 +47,7 @@ def paper_search(llm, query, path_registry):
     query_chain = prompt | llm | StrOutputParser()
     if not os.path.isdir(path):
         os.mkdir(path)
+    configure_logging(path)
     search = query_chain.invoke(query)
     print("\nSearch:", search)
     papers = paper_scraper(search, pdir=f"{path}/{re.sub(' ', '', search)}")
@@ -45,10 +57,14 @@
 def scholar2result_llm(llm, query, path_registry, k=5, max_sources=2):
     """Useful to answer questions that require
     technical knowledge. Ask a specific question."""
+    if llm.model_name.startswith("gpt"):
+        docs = paperqa.Docs(llm=llm.model_name)
+    else:
+        docs = paperqa.Docs()  # default gpt model
+
     papers = paper_search(llm, query, path_registry)
     if len(papers) == 0:
         return "Failed. Not enough papers found"
-    docs = paperqa.Docs(llm=llm.model_name)
     not_loaded = 0
     for path, data in papers.items():
         try:
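
For context, a minimal sketch of what the new configure_logging hook does at runtime (the directory name is illustrative): runtime errors raised while scraping are appended to scraping_errors.log inside the query directory instead of flooding the console.

import logging
import os

path = "query"  # illustrative; paper_search creates the real directory
os.makedirs(path, exist_ok=True)
logging.basicConfig(
    filename=os.path.join(path, "scraping_errors.log"),
    level=logging.ERROR,
    format="%(asctime)s:%(levelname)s:%(message)s",
)
logging.error("paperscraper: failed to fetch a paper")  # written to the log file, not stdout
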
6 changes: 5 additions & 1 deletion mdagent/tools/maketools.py
@@ -1,3 +1,5 @@
+import os
+
 import streamlit as st
 from dotenv import load_dotenv
 from langchain import agents
@@ -70,8 +72,10 @@ def make_all_tools(
     # all_tools += [PythonREPLTool()]
     all_tools += [
         ModifyBaseSimulationScriptTool(path_registry=path_instance, llm=llm),
-        Scholar2ResultLLM(llm=llm, path_registry=path_instance),
     ]
+    if "OPENAI_API_KEY" in os.environ:
+        # LiteratureSearch only works with OpenAI
+        all_tools += [Scholar2ResultLLM(llm=llm, path_registry=path_instance)]
     if human:
         all_tools += [agents.load_tools(["human"], llm)[0]]
 
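
A minimal sketch of the guard's effect (register_tools and the tool names are hypothetical stand-ins for make_all_tools and its tool instances): the literature-search tool is registered only when an OpenAI key is present, since its paper-qa backend defaults to OpenAI models.

import os

def register_tools(base_tools):
    tools = list(base_tools)
    if "OPENAI_API_KEY" in os.environ:
        # literature search is OpenAI-only, so it is skipped when no key is set
        tools.append("Scholar2ResultLLM")  # stand-in for the real tool instance
    return tools

print(register_tools(["SimulationTool"]))  # without the key, the list is unchanged
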
29 changes: 16 additions & 13 deletions mdagent/utils/makellm.py
@@ -1,8 +1,9 @@
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
-from langchain_openai import ChatOpenAI
 
 
 def _make_llm(model, temp, verbose):
+    from langchain_openai import ChatOpenAI
+
     if model.startswith("gpt-3.5-turbo") or model.startswith("gpt-4"):
         llm = ChatOpenAI(
             temperature=temp,
@@ -11,25 +12,27 @@ def _make_llm(model, temp, verbose):
             streaming=True if verbose else False,
             callbacks=[StreamingStdOutCallbackHandler()] if verbose else None,
         )
-    elif model.startswith("llama"):
+    elif model.startswith("accounts/fireworks"):
         from langchain_fireworks import ChatFireworks
 
         llm = ChatFireworks(
             temperature=temp,
-            model_name=f"accounts/fireworks/models/{model}",
+            model_name=model,
             request_timeout=1000,
             streaming=True if verbose else False,
             callbacks=[StreamingStdOutCallbackHandler()] if verbose else None,
         )
+    elif model.startswith("together/"):
+        # user needs to add 'together/' prefix to use TogetherAI provider
+        from langchain_together import ChatTogether
+
+        llm = ChatTogether(
+            temperature=temp,
+            model=model.replace("together/", ""),
+            request_timeout=1000,
+            streaming=True if verbose else False,
+            callbacks=[StreamingStdOutCallbackHandler()] if verbose else None,
+        )
-    # elif model.startswith("Meta-Llama"):
-    #     from langchain_together import ChatTogether
-    #     llm = ChatTogether(
-    #         temperature=temp,
-    #         model=f"meta-llama/{model}",
-    #         request_timeout=1000,
-    #         streaming=True if verbose else False,
-    #         callbacks=[StreamingStdOutCallbackHandler()] if verbose else None,
-    #     )
     else:
-        raise ValueError(f"Invalid or Unsupported model name: {model}")
+        raise ValueError(f"Unrecognized or Unsupported model name: {model}")
     return llm
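
A sketch of the resulting prefix dispatch (pick_provider is a hypothetical helper and the model names are illustrative): Fireworks users now pass the full accounts/fireworks/... model path, and TogetherAI users add a together/ prefix that is stripped before the name reaches the provider.

def pick_provider(model: str) -> str:
    # mirrors the branch order in _make_llm
    if model.startswith("gpt-3.5-turbo") or model.startswith("gpt-4"):
        return "ChatOpenAI"
    elif model.startswith("accounts/fireworks"):
        return "ChatFireworks"
    elif model.startswith("together/"):
        return "ChatTogether"
    raise ValueError(f"Unrecognized or Unsupported model name: {model}")

print(pick_provider("gpt-4"))  # -> ChatOpenAI
print(pick_provider("accounts/fireworks/models/llama-v3p1-8b-instruct"))  # -> ChatFireworks
print(pick_provider("together/meta-llama/Llama-3-8b-chat-hf"))  # -> ChatTogether
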
2 changes: 1 addition & 1 deletion setup.py
@@ -24,7 +24,7 @@
         "langchain-community",
         "langchain-fireworks==0.1.7",
         "langchain-openai==0.1.19",
-        # "langchain-together==0.1.4",
+        "langchain-together==0.1.4",
         "matplotlib",
         "nbformat",
         "openai",
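
With langchain-together now an unconditional dependency, a fresh install pulls it in automatically; on an existing environment the pinned version can be added directly (assuming a standard pip/PyPI setup):

pip install langchain-together==0.1.4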
