upgraded langchain, added and sync llms in tools
qcampbel committed Aug 7, 2024
1 parent b9b790b commit a97652c
Showing 11 changed files with 92 additions and 139 deletions.
26 changes: 11 additions & 15 deletions mdagent/agent/agent.py
@@ -3,8 +3,6 @@
from dotenv import load_dotenv
from langchain.agents import AgentExecutor, OpenAIFunctionsAgent
from langchain.agents.structured_chat.base import StructuredChatAgent
- from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
- from langchain.chat_models import ChatOpenAI

from ..tools import get_tools, make_all_tools
from ..utils import PathRegistry, SetCheckpoint, _make_llm
@@ -38,7 +36,7 @@ def __init__(
tools=None,
agent_type="OpenAIFunctionsAgent", # this can also be structured_chat
model="gpt-4-1106-preview", # current name for gpt-4 turbo
tools_model="gpt-4-1106-preview",
tools_model=None,
temp=0.1,
verbose=True,
ckpt_dir="ckpt",
@@ -48,10 +46,15 @@
run_id="",
use_memory=True,
):
+ self.llm = _make_llm(model, temp, verbose)
+ if tools_model is None:
+ tools_model = model
+ self.tools_llm = _make_llm(tools_model, temp, verbose)

self.use_memory = use_memory
self.path_registry = PathRegistry.get_instance(ckpt_dir=ckpt_dir)
self.ckpt_dir = self.path_registry.ckpt_dir
- self.memory = MemoryManager(self.path_registry, run_id=run_id)
+ self.memory = MemoryManager(self.path_registry, self.tools_llm, run_id=run_id)
self.run_id = self.memory.run_id

self.uploaded_files = uploaded_files
@@ -60,18 +63,9 @@

self.agent = None
self.agent_type = agent_type
- self.user_tools = tools
- self.tools_llm = _make_llm(tools_model, temp, verbose)
self.top_k_tools = top_k_tools
self.use_human_tool = use_human_tool

- self.llm = ChatOpenAI(
- temperature=temp,
- model=model,
- client=None,
- streaming=True,
- callbacks=[StreamingStdOutCallbackHandler()],
- )
+ self.user_tools = tools

def _initialize_tools_and_agent(self, user_input=None):
"""Retrieve tools and initialize the agent."""
@@ -89,6 +83,7 @@ def _initialize_tools_and_agent(self, user_input=None):
# retrieve all tools, including new tools if any
self.tools = make_all_tools(
self.tools_llm,
+ top_k_tools=self.top_k_tools,
human=self.use_human_tool,
)
return AgentExecutor.from_agent_and_tools(
@@ -97,6 +92,7 @@
self.llm,
self.tools,
),
+ verbose=self.verbose,
handle_parsing_errors=True,
)

@@ -107,7 +103,7 @@ def run(self, user_input, callbacks=None):
elif self.agent_type == "OpenAIFunctionsAgent":
self.prompt = openaifxn_prompt.format(input=user_input, context=run_memory)
self.agent = self._initialize_tools_and_agent(user_input)
- model_output = self.agent.run(self.prompt, callbacks=callbacks)
+ model_output = self.agent.invoke(self.prompt, callbacks=callbacks)
if self.use_memory:
self.memory.generate_agent_summary(model_output)
print("Your run id is: ", self.run_id)
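The net effect of the constructor changes above: both LLMs are now built through `_make_llm`, and the tools model falls back to the main model instead of a hard-coded default. A minimal usage sketch, assuming the class in this file is the package's `MDAgent`:

```python
# Hypothetical usage of the new tools_model fallback; the prompt string is
# illustrative.
from mdagent import MDAgent

# tools_llm silently reuses the main model...
agent = MDAgent(model="gpt-4-1106-preview")

# ...unless a cheaper tools model is requested explicitly.
agent_cheap_tools = MDAgent(model="gpt-4-1106-preview", tools_model="gpt-3.5-turbo")

output = agent.run("Download PDB 1ZNI and report its residue count")
```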
18 changes: 4 additions & 14 deletions mdagent/agent/memory.py
@@ -3,10 +3,8 @@
import random
import string

- from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
- from langchain.chains import LLMChain
- from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
+ from langchain_core.output_parsers import StrOutputParser

from mdagent.utils import PathRegistry

@@ -32,8 +30,7 @@ class MemoryManager:
def __init__(
self,
path_registry: PathRegistry,
model="gpt-3.5-turbo",
temp=0.1,
llm,
run_id="",
):
self.path_registry = path_registry
@@ -46,14 +43,7 @@ def __init__(
else:
pull_mem = True

- llm = ChatOpenAI(
- temperature=temp,
- model=model,
- client=None,
- streaming=True,
- callbacks=[StreamingStdOutCallbackHandler()],
- )
- self.llm_agent_trace = LLMChain(llm=llm, prompt=agent_summary_template)
+ self.llm_agent_trace = agent_summary_template | llm | StrOutputParser()

self._make_all_dirs()
if pull_mem:
@@ -138,7 +128,7 @@ def generate_agent_summary(self, agent_trace):
Returns:
- None
"""
- llm_out = self.llm_agent_trace({"agent_trace": agent_trace})["text"]
+ llm_out = self.llm_agent_trace.invoke({"agent_trace": agent_trace})
key_str = f"{self.run_id}.{self.get_summary_number()}"
run_summary = {key_str: llm_out}
self._write_to_json(run_summary, self.agent_trace_summary)
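The pattern replacing `LLMChain` throughout this commit is LangChain's runnable (LCEL) composition: pipe a prompt into the model and a `StrOutputParser`, then call `.invoke()`. A self-contained sketch of the new chain, with an illustrative template standing in for the repo's `agent_summary_template`:

```python
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import ChatOpenAI

# Illustrative stand-in for the repo's agent_summary_template.
agent_summary_template = PromptTemplate.from_template(
    "Summarize the following agent trace:\n{agent_trace}"
)
llm = ChatOpenAI(model="gpt-4-1106-preview", temperature=0.1)

# Old style: LLMChain(llm=llm, prompt=...)({"agent_trace": ...})["text"]
# New style: the chain returns the parsed string directly.
chain = agent_summary_template | llm | StrOutputParser()
summary = chain.invoke({"agent_trace": "Thought: ...\nAction: ..."})
```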
80 changes: 7 additions & 73 deletions mdagent/agent/prompt.py
@@ -14,8 +14,13 @@
Complete format:
Thought: (reflect on your progress and decide what " "to do next)
- Action: (the action name, should be the name of a tool)
- Action Input: (the input string to the action)
+ Action:
+ ```
+ {{
+ action: (the action name, should be the name of a tool),
+ action_input: (the input string to the action)
+ }}
+ '''
OR
@@ -41,77 +46,6 @@
Question: {input} """,
)


- modular_analysis_prompt = PromptTemplate(
- input_variables=[
- "Main_Task",
- "Subtask_types",
- "Proteins",
- "Parameters",
- "UserProposedPlan",
- "context",
- ],
- template="""
- Approach the molecular dynamics inquiry by dissecting it into its modular
- components:
- Main Task: {Main_Task}
- Subtasks: {Subtask_types}
- Target Proteins: {Proteins}
- Parameters: {Parameters}
- Initial Plan Proposed by User: {UserProposedPlan}
- The Main Task is the user's request.
- The Subtasks are (some of/all) the individual steps that may need to be taken
- to complete the Main Task; Preprocessing/Preparation usually involves
- cleaning the initial pdb file (adding hydrogens, removing/adding water, etc.)
- or making the required box for the simulation, Simulation involves running the
- simulation and/or modifying the simulation script, Postprocessing involves
- analyzing the results of the simulation (either using provided tools or figuring
- it out on your own). Finally, Question is used if the user query is more
- of a question than a request for a specific task.
- the Target Proteins are the protein(s) that the user wants to focus on,
- the Parameters are the 'special' conditions that the user wants to set and use
- for the simulation, preprocessing and or analysis.
- Sometimes users already have an idea of what is needed to be done.
- Initial Plan Proposed by User is the user's initial plan for the simulation. You
- can use this as a guide to understand what the user wants to do. You can also
- modify it if you think is necessary.
- You can only respond with a single complete
- 'Thought, Action, Action Input' format
- OR a single 'Final Answer' format.
- Complete format:
- Thought: (reflect on your progress and decide what " "to do next)
- Action: (the action name, should be the name of a tool)
- Action Input: (the input string to the action)
- OR
- Final Answer: (the final answer to the original input
- question)
- Use the tools provided, using the most specific tool
- available for each action.
- Your final answer should contain all information
- necessary to answer the question and subquestions.
- Your thought process should be clean and clear,
- and you must explicitly state the actions you are taking.
- If you are asked to continue
- or reference previous runs,
- the context will be provided to you.
- If context is provided, you should assume
- you are continuing a chat.
- Here is the input:
- Previous Context: {context}
- """,
- )

openaifxn_prompt = PromptTemplate(
input_variables=["input", "context"],
template="""
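For reference, a reply that follows the new structured-chat format carries the action as a JSON-style blob rather than separate `Action:`/`Action Input:` lines, roughly like this (the tool name is illustrative, not necessarily one registered by this commit):

```python
# Illustrative action blob in the new structured-chat format.
action_blob = {
    "action": "PDBFileDownloader",
    "action_input": "download the PDB file for 1ZNI",
}
```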
mdagent/tools/base_tools/simulation_tools/create_simulation.py
@@ -2,9 +2,9 @@
from typing import Optional

from langchain.base_language import BaseLanguageModel
- from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.tools import BaseTool
+ from langchain_core.output_parsers import StrOutputParser
from pydantic import BaseModel, Field

from mdagent.utils import FileType, PathRegistry
@@ -48,7 +48,7 @@ def _prompt_summary(self, query: str):
prompt = PromptTemplate(
template=prompt_template, input_variables=["base_script", "query"]
)
- llm_chain = LLMChain(prompt=prompt, llm=self.llm)
+ llm_chain = prompt | self.llm | StrOutputParser()

return llm_chain.invoke(query)

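One caveat with the pipe-style chain above: `langchain-core` only auto-wraps a bare string input when the prompt declares exactly one variable. Since this prompt declares both `base_script` and `query`, the `invoke` call presumably needs a mapping, along these lines:

```python
# Sketch only: a multi-variable prompt must be invoked with a dict covering
# every input variable; base_script and query come from the method's scope.
summary = llm_chain.invoke({"base_script": base_script, "query": query})
```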
25 changes: 8 additions & 17 deletions mdagent/tools/base_tools/util_tools/git_issues_tool.py
@@ -2,27 +2,17 @@

import requests
import tiktoken
- from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.tools import BaseTool
+ from langchain_core.output_parsers import StrOutputParser
from serpapi import GoogleSearch

- from mdagent.utils import _make_llm


class GitToolFunctions:
"""Class to store the functions of the tool."""

- def __init__(
- self,
- model: str = "gpt-3.5-turbo-16k",
- temp: float = 0.05,
- verbose: bool = False,
- ):
- self.model = model
- self.temp = temp
- self.verbose = verbose
- self.llm = _make_llm(model=self.model, temp=self.temp, verbose=self.verbose)
+ def __init__(self, llm):
+ self.llm = llm

def _prompt_summary(self, query: str, output: str):
prompt_template = """You're receiving the following github issues and comments.
@@ -54,9 +44,9 @@ def _prompt_summary(self, query: str, output: str):
prompt = PromptTemplate(
template=prompt_template, input_variables=["query", "output"]
)
- llm_chain = LLMChain(prompt=prompt, llm=self.llm)
+ llm_chain = prompt | self.llm | StrOutputParser()

- return llm_chain.run({"query": query, "output": output})
+ return llm_chain.invoke({"query": query, "output": output})

"""Function to get the number of requests remaining for the Github API """

@@ -80,12 +70,13 @@ class SerpGitTool(BaseTool):
Input: """
serp_key: Optional[str]

- def __init__(self, serp_key):
+ def __init__(self, serp_key, llm):
super().__init__()
self.serp_key = serp_key
+ self.llm = llm

def _run(self, query: str):
- fxns = GitToolFunctions()
+ fxns = GitToolFunctions(self.llm)
# print("this is the key", self.serp_key)
params = {
"engine": "google",
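These changes swap per-tool model construction for dependency injection: `GitToolFunctions` and `SerpGitTool` now receive whatever LLM the agent already built. A hypothetical wiring sketch (the environment-variable name and model are assumptions):

```python
import os

from mdagent.tools.base_tools.util_tools.git_issues_tool import SerpGitTool
from mdagent.utils import _make_llm

# The shared tools LLM, built once by the agent instead of inside each tool.
llm = _make_llm("gpt-4-1106-preview", 0.1, False)

serp_tool = SerpGitTool(serp_key=os.environ.get("SERP_API_KEY"), llm=llm)
summary = serp_tool._run("OpenMM simulation blows up with NaN coordinates")
```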
23 changes: 20 additions & 3 deletions mdagent/tools/base_tools/util_tools/search_tools.py
@@ -1,3 +1,4 @@
+ import logging
import os
import re
from typing import Optional
@@ -8,11 +9,22 @@
import paperscraper
from langchain.base_language import BaseLanguageModel
from langchain.tools import BaseTool
+ from langchain_core.output_parsers import StrOutputParser
from pypdf.errors import PdfReadError

from mdagent.utils import PathRegistry


+ def configure_logging(path):
+ # to log all runtime errors from paperscraper, which can be VERY noisy
+ log_file = os.path.join(path, "scraping_errors.log")
+ logging.basicConfig(
+ filename=log_file,
+ level=logging.ERROR,
+ format="%(asctime)s:%(levelname)s:%(message)s",
+ )


def paper_scraper(search: str, pdir: str = "query") -> dict:
try:
return paperscraper.search_papers(search, pdir=pdir)
@@ -32,10 +44,11 @@ def paper_search(llm, query, path_registry):
)

path = f"{path_registry.ckpt_files}/query"
- query_chain = langchain.chains.llm.LLMChain(llm=llm, prompt=prompt)
+ query_chain = prompt | llm | StrOutputParser()
if not os.path.isdir(path):
os.mkdir(path)
- search = query_chain.run(query)
+ configure_logging(path)
+ search = query_chain.invoke(query)
print("\nSearch:", search)
papers = paper_scraper(search, pdir=f"{path}/{re.sub(' ', '', search)}")
return papers
@@ -44,10 +57,14 @@
def scholar2result_llm(llm, query, path_registry, k=5, max_sources=2):
"""Useful to answer questions that require
technical knowledge. Ask a specific question."""
if llm.model_name.startswith("gpt"):
docs = paperqa.Docs(llm=llm.model_name)
else:
docs = paperqa.Docs() # uses default gpt model in paperqa

papers = paper_search(llm, query, path_registry)
if len(papers) == 0:
return "Failed. Not enough papers found"
- docs = paperqa.Docs(llm=llm.model_name)
not_loaded = 0
for path, data in papers.items():
try:
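Two details here are easy to miss: `configure_logging` sends paperscraper's noisy runtime errors to a log file under the query directory, and `paperqa.Docs` now only receives a model name when the LLM is an OpenAI GPT model, since `model_name` is how LangChain's OpenAI chat models expose their model string. A condensed sketch of that selection logic, with a `getattr` guard added as a defensive assumption for LLMs lacking the attribute:

```python
import paperqa

def make_docs(llm):
    # ChatOpenAI exposes .model_name; other providers may not, hence the guard.
    model_name = getattr(llm, "model_name", "")
    if model_name.startswith("gpt"):
        return paperqa.Docs(llm=model_name)
    return paperqa.Docs()  # falls back to paperqa's default model
```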
10 changes: 6 additions & 4 deletions mdagent/tools/maketools.py
@@ -1,9 +1,11 @@
+ import os

import streamlit as st
from dotenv import load_dotenv
from langchain import agents
from langchain.base_language import BaseLanguageModel
- from langchain.embeddings.openai import OpenAIEmbeddings
- from langchain.vectorstores import Chroma
+ from langchain_chroma import Chroma
+ from langchain_openai import OpenAIEmbeddings

from mdagent.utils import PathRegistry

@@ -70,8 +72,9 @@ def make_all_tools(
# all_tools += [PythonREPLTool()]
all_tools += [
ModifyBaseSimulationScriptTool(path_registry=path_instance, llm=llm),
- Scholar2ResultLLM(llm=llm, path_registry=path_instance),
]
if "OPENAI_API_KEY" in os.environ:
all_tools += [Scholar2ResultLLM(llm=llm, path_registry=path_instance)]
if human:
all_tools += [agents.load_tools(["human"], llm)[0]]

@@ -151,7 +154,6 @@ def get_tools(
ids=[tool.name],
metadatas=[{"tool_name": tool.name, "index": i}],
)
- vectordb.persist()

# retrieve 'k' tools
k = min(top_k_tools, vectordb._collection.count())
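The dropped `vectordb.persist()` call follows from the library migration: with `langchain_chroma` (backed by Chroma >= 0.4), a store created with a `persist_directory` writes to disk automatically. A sketch of the updated usage, with an illustrative collection name and path:

```python
from langchain_chroma import Chroma
from langchain_openai import OpenAIEmbeddings

vectordb = Chroma(
    collection_name="mdagent_tools",          # illustrative name
    embedding_function=OpenAIEmbeddings(),
    persist_directory="ckpt/tools_vectordb",  # illustrative path
)

# add_texts writes through to persist_directory; langchain_chroma has no
# separate persist() method, and none is needed.
vectordb.add_texts(
    texts=["Example tool description"],
    ids=["ExampleTool"],
    metadatas=[{"tool_name": "ExampleTool", "index": 0}],
)
```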
(diff for the remaining changed files not shown)
