Skip to content

Commit

Permalink
Update to use langgraph and langchain 0.1
Browse files Browse the repository at this point in the history
- for now still using pregel api, that will change in a future pr
  • Loading branch information
nfcampos committed Jan 19, 2024
1 parent da6bd8c commit d62775d
Show file tree
Hide file tree
Showing 17 changed files with 304 additions and 148 deletions.
17 changes: 8 additions & 9 deletions backend/app/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
from agent_executor.checkpoint import RedisCheckpoint
from langchain.schema.messages import AnyMessage
from langchain.utilities.redis import get_client
from permchain.channels import Topic
from permchain.channels.base import ChannelsManager, create_checkpoint
from langgraph.channels import Topic
from langgraph.channels.base import ChannelsManager
from langgraph.checkpoint.base import empty_checkpoint
from redis.client import Redis as RedisType

from app.schema import Assistant, AssistantWithoutUserId, Thread, ThreadWithoutUserId
Expand Down Expand Up @@ -148,9 +149,8 @@ def get_thread(user_id: str, thread_id: str) -> Thread | None:
def get_thread_messages(user_id: str, thread_id: str):
"""Get all messages for a thread."""
client = RedisCheckpoint()
checkpoint = client.get(
{"configurable": {"user_id": user_id, "thread_id": thread_id}}
)
config = {"configurable": {"user_id": user_id, "thread_id": thread_id}}
checkpoint = client.get(config) or empty_checkpoint()
# TODO replace hardcoded messages channel with
# channel extracted from agent
with ChannelsManager(
Expand All @@ -163,16 +163,15 @@ def post_thread_messages(user_id: str, thread_id: str, messages: Sequence[AnyMes
"""Add messages to a thread."""
client = RedisCheckpoint()
config = {"configurable": {"user_id": user_id, "thread_id": thread_id}}
checkpoint = client.get(config)
checkpoint = client.get(config) or empty_checkpoint()
# TODO replace hardcoded messages channel with
# channel extracted from agent
with ChannelsManager(
{"messages": Topic(AnyMessage, accumulate=True)}, checkpoint
) as channels:
channels["messages"].update(messages)
checkpoint = {
k: v for k, v in create_checkpoint(channels).items() if k == "messages"
}
checkpoint["channel_versions"]["messages"] += 1
checkpoint["channel_values"]["messages"] = channels["messages"].checkpoint()
client.put(config, checkpoint)


Expand Down
17 changes: 11 additions & 6 deletions backend/packages/agent-executor/agent_executor/checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import os
import pickle
from functools import partial
from typing import Any, Mapping
from typing import Any

from langchain.pydantic_v1 import Field
from langchain.schema.runnable import RunnableConfig
from langchain.schema.runnable.utils import ConfigurableFieldSpec
from langchain.utilities.redis import get_client
from permchain.checkpoint.base import BaseCheckpointAdapter
from langgraph.checkpoint import BaseCheckpointSaver
from langgraph.checkpoint.base import Checkpoint
from redis.client import Redis as RedisType


Expand All @@ -26,7 +27,7 @@ def _load(mapping: dict[bytes, bytes]) -> dict:
}


class RedisCheckpoint(BaseCheckpointAdapter):
class RedisCheckpoint(BaseCheckpointSaver):
client: RedisType = Field(
default_factory=partial(get_client, os.environ.get("REDIS_URL"))
)
Expand Down Expand Up @@ -60,8 +61,12 @@ def _hash_key(self, config: RunnableConfig) -> str:
config["configurable"]["user_id"], config["configurable"]["thread_id"]
)

def get(self, config: RunnableConfig) -> Mapping[str, Any] | None:
return _load(self.client.hgetall(self._hash_key(config)))
def get(self, config: RunnableConfig) -> Checkpoint | None:
value = _load(self.client.hgetall(self._hash_key(config)))
if value.get("v") == 1:
return value
else:
return None

def put(self, config: RunnableConfig, checkpoint: Mapping[str, Any]) -> None:
def put(self, config: RunnableConfig, checkpoint: Checkpoint) -> None:
return self.client.hmset(self._hash_key(config), _dump(checkpoint))
11 changes: 6 additions & 5 deletions backend/packages/agent-executor/agent_executor/dnd.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
from langchain.utils.openai_functions import convert_pydantic_to_openai_function
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage
from permchain import BaseCheckpointAdapter, Channel, Pregel
from permchain.channels import LastValue, Topic
from langgraph.checkpoint import BaseCheckpointSaver
from langgraph.channels import LastValue, Topic
from langgraph.pregel import Channel, Pregel

character_system_msg = """You are a dungeon master for a game of dungeons and dragons.
Expand Down Expand Up @@ -86,7 +87,7 @@ def _maybe_update_character(message: AnyMessage):
)


def create_dnd_bot(llm: BaseChatModel, checkpoint: BaseCheckpointAdapter):
def create_dnd_bot(llm: BaseChatModel, checkpoint: BaseCheckpointSaver):
character_model = llm.bind(
functions=[convert_pydantic_to_openai_function(CharacterNotebook)],
)
Expand Down Expand Up @@ -122,7 +123,7 @@ def _route_to_chain(_input):
| _route_to_chain
)
dnd = Pregel(
chains={"executor": executor, "update_state": state_chain},
nodes={"executor": executor, "update_state": state_chain},
channels={
"messages": Topic(AnyMessage, accumulate=True),
"character": LastValue(str),
Expand All @@ -131,6 +132,6 @@ def _route_to_chain(_input):
},
input=["messages"],
output=["messages"],
checkpoint=checkpoint,
checkpointer=checkpoint,
)
return dnd
8 changes: 4 additions & 4 deletions backend/packages/agent-executor/agent_executor/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
"""
from typing import List

from langchain.document_loaders import Blob
from langchain.document_loaders.base import BaseBlobParser
from langchain.schema import Document
from langchain.schema.vectorstore import VectorStore
from langchain_community.document_loaders import Blob
from langchain_community.document_loaders.base import BaseBlobParser
from langchain_core.documents import Document
from langchain_core.vectorstores import VectorStore
from langchain.text_splitter import TextSplitter


Expand Down
12 changes: 6 additions & 6 deletions backend/packages/agent-executor/agent_executor/permchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
RunnablePassthrough,
)
from langchain.tools import BaseTool
from permchain import Channel, Pregel, ReservedChannels
from permchain.channels import Topic
from permchain.checkpoint.base import BaseCheckpointAdapter
from langgraph.channels import Topic
from langgraph.checkpoint import BaseCheckpointSaver
from langgraph.pregel import Channel, Pregel, ReservedChannels


def _create_agent_message(
Expand Down Expand Up @@ -74,7 +74,7 @@ async def _arun_tool(
def get_agent_executor(
tools: list[BaseTool],
agent: Runnable[dict[str, list[AnyMessage]], AgentAction | AgentFinish],
checkpoint: BaseCheckpointAdapter,
checkpoint: BaseCheckpointSaver,
) -> Pregel:
tool_map = {tool.name: tool for tool in tools}
tool_lambda = RunnableLambda(_run_tool, _arun_tool).bind(tools=tool_map)
Expand Down Expand Up @@ -118,9 +118,9 @@ def route_last_message(input: dict[str, bool | Sequence[AnyMessage]]) -> Runnabl
)

return Pregel(
chains={"executor": executor},
nodes={"executor": executor},
channels={"messages": Topic(AnyMessage, accumulate=True)},
input=["messages"],
output=["messages"],
checkpoint=checkpoint,
checkpointer=checkpoint,
)
6 changes: 3 additions & 3 deletions backend/packages/agent-executor/agent_executor/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@

from typing import Any, BinaryIO, List, Optional

from langchain.document_loaders.blob_loaders.schema import Blob
from langchain.schema.runnable import RunnableConfig, RunnableSerializable
from langchain.schema.vectorstore import VectorStore
from langchain_community.document_loaders.blob_loaders import Blob
from langchain_core.runnables import RunnableConfig, RunnableSerializable
from langchain_core.vectorstores import VectorStore
from langchain.text_splitter import TextSplitter

from agent_executor.ingest import ingest_blob
Expand Down
2 changes: 1 addition & 1 deletion backend/packages/agent-executor/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ authors = [
readme = "README.md"

[tool.poetry.dependencies]
python = "^3.8.1"
python = "^3.9.0"
langchain = ">=0.0.333"
python-magic = "^0.4.27"

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os

from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser
from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
from langchain_openai import AzureChatOpenAI, ChatOpenAI
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.tools.render import format_tool_to_openai_function

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os

import boto3
from langchain.chat_models import BedrockChat, ChatAnthropic
from langchain_community.chat_models import BedrockChat, ChatAnthropic
from langchain.schema.messages import AIMessage, HumanMessage
from langchain.tools.render import render_text_description

Expand Down
6 changes: 3 additions & 3 deletions backend/packages/gizmo-agent/gizmo_agent/ingest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import os

from agent_executor.upload import IngestRunnable
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema.runnable import ConfigurableField
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores.redis import Redis
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores.redis import Redis
from langchain_core.runnables import ConfigurableField

index_schema = {
"tag": [{"name": "namespace"}],
Expand Down
7 changes: 3 additions & 4 deletions backend/packages/gizmo-agent/gizmo_agent/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,14 @@
from agent_executor.checkpoint import RedisCheckpoint
from agent_executor.dnd import create_dnd_bot
from agent_executor.permchain import get_agent_executor
from langchain.chat_models import ChatOpenAI
from langchain_openai import ChatOpenAI
from langchain.pydantic_v1 import BaseModel, Field
from langchain.schema.messages import AnyMessage
from langchain.schema.runnable import (
from langchain_core.messages import AnyMessage
from langchain_core.runnables import (
ConfigurableField,
ConfigurableFieldMultiOption,
RunnableBinding,
)

from gizmo_agent.agent_types import (
GizmoAgentType,
get_openai_function_agent,
Expand Down
18 changes: 11 additions & 7 deletions backend/packages/gizmo-agent/gizmo_agent/tools.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
from enum import Enum

from langchain.pydantic_v1 import BaseModel, Field
from langchain.retrievers import KayAiRetriever, PubMedRetriever, WikipediaRetriever
from langchain.retrievers.you import YouRetriever
from langchain.tools import ArxivQueryRun, DuckDuckGoSearchRun
from langchain_community.retrievers import (
KayAiRetriever,
PubMedRetriever,
WikipediaRetriever,
)
from langchain_community.retrievers.you import YouRetriever
from langchain_community.tools import ArxivQueryRun, DuckDuckGoSearchRun
from langchain.tools.retriever import create_retriever_tool
from langchain.tools.tavily_search import TavilyAnswer, TavilySearchResults
from langchain.utilities import ArxivAPIWrapper
from langchain.utilities.tavily_search import TavilySearchAPIWrapper
from langchain.vectorstores.redis import RedisFilter
from langchain_community.tools.tavily_search import TavilyAnswer, TavilySearchResults
from langchain_community.utilities.arxiv import ArxivAPIWrapper
from langchain_community.utilities.tavily_search import TavilySearchAPIWrapper
from langchain_community.vectorstores.redis import RedisFilter

from gizmo_agent.ingest import vstore

Expand Down
Loading

0 comments on commit d62775d

Please sign in to comment.