Fix reading and writing thread messages
nfcampos committed Jan 26, 2024
1 parent 8447925 commit 6107c60
Showing 6 changed files with 72 additions and 102 deletions.
74 changes: 31 additions & 43 deletions backend/app/agent.py
@@ -5,7 +5,6 @@
 from app.agent_types.xml_agent import get_xml_agent_executor
 from app.agent_types.google_agent import get_google_agent_executor
 from app.llms import get_openai_llm, get_anthropic_llm, get_google_llm
-from langchain.pydantic_v1 import BaseModel, Field
 from langchain_core.messages import AnyMessage
 from langchain_core.runnables import (
     ConfigurableField,
@@ -35,6 +34,34 @@ class AgentType(str, Enum):
 DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant."
 
 
+def get_agent_executor(
+    tools: list,
+    agent: AgentType,
+    system_message: str,
+):
+    checkpointer = RedisCheckpoint()
+    if agent == AgentType.GPT_35_TURBO:
+        llm = get_openai_llm()
+        return get_openai_agent_executor(tools, llm, system_message, checkpointer)
+    elif agent == AgentType.GPT_4:
+        llm = get_openai_llm(gpt_4=True)
+        return get_openai_agent_executor(tools, llm, system_message, checkpointer)
+    elif agent == AgentType.AZURE_OPENAI:
+        llm = get_openai_llm(azure=True)
+        return get_openai_agent_executor(tools, llm, system_message, checkpointer)
+    elif agent == AgentType.CLAUDE2:
+        llm = get_anthropic_llm()
+        return get_xml_agent_executor(tools, llm, system_message, checkpointer)
+    elif agent == AgentType.BEDROCK_CLAUDE2:
+        llm = get_anthropic_llm(bedrock=True)
+        return get_xml_agent_executor(tools, llm, system_message, checkpointer)
+    elif agent == AgentType.GEMINI:
+        llm = get_google_llm()
+        return get_google_agent_executor(tools, llm, system_message, checkpointer)
+    else:
+        raise ValueError("Unexpected agent type")
+
+
 class ConfigurableAgent(RunnableBinding):
     tools: Sequence[str]
     agent: AgentType
@@ -66,39 +93,8 @@ def __init__(
                 _tools.append(get_retrieval_tool(assistant_id, retrieval_description))
             else:
                 _tools.append(TOOLS[_tool]())
-        if agent == AgentType.GPT_35_TURBO:
-            llm = get_openai_llm()
-            _agent = get_openai_agent_executor(
-                _tools, llm, system_message, RedisCheckpoint()
-            )
-        elif agent == AgentType.GPT_4:
-            llm = get_openai_llm(gpt_4=True)
-            _agent = get_openai_agent_executor(
-                _tools, llm, system_message, RedisCheckpoint()
-            )
-        elif agent == AgentType.AZURE_OPENAI:
-            llm = get_openai_llm(azure=True)
-            _agent = get_openai_agent_executor(
-                _tools, llm, system_message, RedisCheckpoint()
-            )
-        elif agent == AgentType.CLAUDE2:
-            llm = get_anthropic_llm()
-            _agent = get_xml_agent_executor(
-                _tools, llm, system_message, RedisCheckpoint()
-            )
-        elif agent == AgentType.BEDROCK_CLAUDE2:
-            llm = get_anthropic_llm(bedrock=True)
-            _agent = get_xml_agent_executor(
-                _tools, llm, system_message, RedisCheckpoint()
-            )
-        elif agent == AgentType.GEMINI:
-            llm = get_google_llm()
-            _agent = get_google_agent_executor(
-                _tools, llm, system_message, RedisCheckpoint()
-            )
-        else:
-            raise ValueError("Unexpected agent type")
-        agent_executor = _agent.with_config({"recursion_limit": 10})
+        _agent = get_agent_executor(_tools, agent, system_message)
+        agent_executor = _agent.with_config({"recursion_limit": 50})
         super().__init__(
             tools=tools,
             agent=agent,
@@ -110,14 +106,6 @@ def __init__(
         )
 
 
-class AgentInput(BaseModel):
-    messages: Sequence[AnyMessage] = Field(default_factory=list)
-
-
-class AgentOutput(BaseModel):
-    messages: Sequence[AnyMessage] = Field(..., extra={"widget": {"type": "chat"}})
-
-
 agent = (
     ConfigurableAgent(
         agent=AgentType.GEMINI,
@@ -142,7 +130,7 @@ class AgentOutput(BaseModel):
             id="retrieval_description", name="Retrieval Description"
         ),
     )
-    .with_types(input_type=AgentInput, output_type=AgentOutput)
+    .with_types(input_type=Sequence[AnyMessage], output_type=Sequence[AnyMessage])
 )
 
 if __name__ == "__main__":
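Note: the per-model dispatch now lives in the module-level `get_agent_executor` factory (which `storage.py` reuses below), and `.with_types` declares bare message sequences in place of the removed `AgentInput`/`AgentOutput` wrappers. A minimal sketch of invoking the agent under the new contract — the config keys follow the checkpoint code elsewhere in this commit, but the ids and message content are illustrative:

```python
# Sketch only: ids and message content are made up.
from langchain_core.messages import HumanMessage

from app.agent import agent

config = {"configurable": {"user_id": "user-123", "thread_id": "thread-456"}}

# Input is now a plain sequence of messages rather than {"messages": [...]};
# the output is likewise a sequence of messages.
messages = agent.invoke([HumanMessage(content="Hello!")], config)
```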
18 changes: 7 additions & 11 deletions backend/app/api/runs.py
@@ -30,18 +30,12 @@
 _serializer = WellKnownLCSerializer()
 
 
-class AgentInput(BaseModel):
-    """An input into an agent."""
-
-    messages: Sequence[AnyMessage] = Field(default_factory=list)
-
-
 class CreateRunPayload(BaseModel):
     """Payload for creating a run."""
 
     assistant_id: str
     thread_id: str
-    input: AgentInput = Field(default_factory=AgentInput)
+    input: Sequence[AnyMessage] = Field(default_factory=list)
 
 
 async def _run_input_and_config(request: Request, opengpts_user_id: OpengptsUserId):
@@ -73,6 +67,8 @@ async def _run_input_and_config(request: Request, opengpts_user_id: OpengptsUserId):
         },
     }
     try:
+        print(body["input"])
+        print(agent.get_input_schema(config))
         input_ = _unpack_input(agent.get_input_schema(config).validate(body["input"]))
     except ValidationError as e:
         raise RequestValidationError(e.errors(), body=body)
@@ -101,7 +97,7 @@ async def stream_run(
 ):
     """Create a run."""
     input_, config, messages = await _run_input_and_config(request, opengpts_user_id)
-    streamer = StreamMessagesHandler(messages + input_["messages"])
+    streamer = StreamMessagesHandler(messages + input_)
     event_aggregator = AsyncEventAggregatorCallback()
     config["callbacks"] = [streamer, event_aggregator]
 
@@ -110,11 +106,11 @@
     async def consume_astream() -> None:
         try:
             async for chunk in agent.astream(input_, config):
-                await streamer.send_stream.send(chunk)
+                # await streamer.send_stream.send(chunk)
                 # hack: function messages aren't generated by chat model
                 # so the callback handler doesn't know about them
-                if chunk["messages"]:
-                    message = chunk["messages"][-1]
+                if chunk.get("action"):
+                    message = chunk["action"]
                 if isinstance(message, FunctionMessage):
                     streamer.output[uuid4()] = ChatGeneration(message=message)
         except Exception as e:
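With `CreateRunPayload.input` now a bare message list, the request body for `/runs/stream` changes shape accordingly. A hedged sketch of the new payload — the host, port, and cookie-based auth are assumptions about a local dev setup, not part of this commit:

```python
# Illustrative client request; adjust the URL and auth to your deployment.
import requests

payload = {
    "assistant_id": "asst-123",  # made-up id
    "thread_id": "thread-456",   # made-up id
    # `input` is the message list itself now, not {"messages": [...]}.
    "input": [
        {
            "content": "What is task decomposition?",
            "additional_kwargs": {},
            "type": "human",
            "example": False,
        }
    ],
}

response = requests.post(
    "http://localhost:8100/runs/stream",  # assumed dev address
    json=payload,
    cookies={"opengpts_user_id": "user-123"},  # assumed auth cookie
    stream=True,
)
```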
38 changes: 17 additions & 21 deletions backend/app/storage.py
@@ -1,12 +1,11 @@
 import os
 from datetime import datetime
 from typing import List, Sequence
-
+from app.agent import AgentType, get_agent_executor
 import orjson
 from app.checkpoint import RedisCheckpoint
 from langchain.schema.messages import AnyMessage
 from langchain.utilities.redis import get_client
-from langgraph.channels import Topic
 from langgraph.channels.base import ChannelsManager
 from langgraph.checkpoint.base import empty_checkpoint
 from redis.client import Redis as RedisType
@@ -146,33 +145,30 @@ def get_thread(user_id: str, thread_id: str) -> Thread | None:
     return load(thread_hash_keys, values) if any(values) else None
 
 
+# TODO remove hardcoded channel name
+MESSAGES_CHANNEL_NAME = "__root__"
+
+
 def get_thread_messages(user_id: str, thread_id: str):
     """Get all messages for a thread."""
-    client = RedisCheckpoint()
     config = {"configurable": {"user_id": user_id, "thread_id": thread_id}}
-    checkpoint = client.get(config) or empty_checkpoint()
-    # TODO replace hardcoded messages channel with
-    # channel extracted from agent
-    with ChannelsManager(
-        {"messages": Topic(AnyMessage, accumulate=True)}, checkpoint
-    ) as channels:
-        return {k: v.get() for k, v in channels.items()}
+    app = get_agent_executor([], AgentType.GPT_35_TURBO, "")
+    checkpoint = app.checkpointer.get(config) or empty_checkpoint()
+    with ChannelsManager(app.channels, checkpoint) as channels:
+        return {"messages": channels[MESSAGES_CHANNEL_NAME].get()}
 
 
 def post_thread_messages(user_id: str, thread_id: str, messages: Sequence[AnyMessage]):
     """Add messages to a thread."""
-    client = RedisCheckpoint()
     config = {"configurable": {"user_id": user_id, "thread_id": thread_id}}
-    checkpoint = client.get(config) or empty_checkpoint()
-    # TODO replace hardcoded messages channel with
-    # channel extracted from agent
-    with ChannelsManager(
-        {"messages": Topic(AnyMessage, accumulate=True)}, checkpoint
-    ) as channels:
-        channels["messages"].update(messages)
-        checkpoint["channel_versions"]["messages"] += 1
-        checkpoint["channel_values"]["messages"] = channels["messages"].checkpoint()
-    client.put(config, checkpoint)
+    app = get_agent_executor([], AgentType.GPT_35_TURBO, "")
+    checkpoint = app.checkpointer.get(config) or empty_checkpoint()
+    with ChannelsManager(app.channels, checkpoint) as channels:
+        channel = channels[MESSAGES_CHANNEL_NAME]
+        channel.update([messages])
+        checkpoint["channel_values"][MESSAGES_CHANNEL_NAME] = channel.checkpoint()
+        checkpoint["channel_versions"][MESSAGES_CHANNEL_NAME] += 1
+    app.checkpointer.put(config, checkpoint)
 
 
 def put_thread(user_id: str, thread_id: str, *, assistant_id: str, name: str) -> Thread:
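Both storage functions now build a throwaway executor just to borrow its checkpointer and channel definitions, so the read and write paths share one source of truth instead of each hardcoding a `Topic(AnyMessage)` channel. A small round-trip sketch, with illustrative ids:

```python
# Round-trip sketch; ids and message content are made up.
from langchain.schema.messages import HumanMessage

from app.storage import get_thread_messages, post_thread_messages

post_thread_messages("user-123", "thread-456", [HumanMessage(content="hi")])

# The read path walks the same agent-derived channels, keyed by the
# (still hardcoded) MESSAGES_CHANNEL_NAME constant.
state = get_thread_messages("user-123", "thread-456")
print(state["messages"])
```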
13 changes: 4 additions & 9 deletions backend/app/stream.py
@@ -54,15 +54,10 @@ def on_llm_new_token(
             self.output[run_id] += chunk
         # Send the messages to the stream
         self.send_stream.send_nowait(
-            {
-                "messages": (
-                    self.messages
-                    + [
-                        map_chunk_to_msg(chunk.message)
-                        for chunk in self.output.values()
-                    ]
-                )
-            }
+            (
+                self.messages
+                + [map_chunk_to_msg(chunk.message) for chunk in self.output.values()]
+            )
         )
 
 
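Since the handler now emits the message list itself, stream consumers parse each event body as a list rather than unwrapping a `messages` key (the frontend change below does exactly this). A tiny sketch of the new event shape, with an illustrative payload:

```python
import json

# Before: '{"messages": [...]}'; after: the JSON-encoded list itself.
event_body = '[{"type": "ai", "content": "Hello!"}]'  # illustrative
messages = json.loads(event_body)
assert isinstance(messages, list)
```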
18 changes: 8 additions & 10 deletions frontend/src/App.tsx
@@ -24,16 +24,14 @@ function App() {
     )?.config;
     if (!config) return;
     await startStream(
-      {
-        messages: [
-          {
-            content: message,
-            additional_kwargs: {},
-            type: "human",
-            example: false,
-          },
-        ],
-      },
+      [
+        {
+          content: message,
+          additional_kwargs: {},
+          type: "human",
+          example: false,
+        },
+      ],
       chat.assistant_id,
       chat.thread_id
     );
13 changes: 5 additions & 8 deletions frontend/src/hooks/useStreamState.tsx
@@ -12,7 +12,7 @@ export interface StreamState {
 export interface StreamStateProps {
   stream: StreamState | null;
   startStream: (
-    input: { messages: Message[] },
+    input: Message[],
     assistant_id: string,
     thread_id: string
   ) => Promise<void>;
@@ -24,14 +24,10 @@ export function useStreamState(): StreamStateProps {
   const [controller, setController] = useState<AbortController | null>(null);
 
   const startStream = useCallback(
-    async (
-      input: { messages: Message[] },
-      assistant_id: string,
-      thread_id: string
-    ) => {
+    async (input: Message[], assistant_id: string, thread_id: string) => {
       const controller = new AbortController();
       setController(controller);
-      setCurrent({ status: "inflight", messages: input.messages, merge: true });
+      setCurrent({ status: "inflight", messages: input, merge: true });
 
       await fetchEventSource("/runs/stream", {
         signal: controller.signal,
@@ -41,7 +37,8 @@ export function useStreamState(): StreamStateProps {
         openWhenHidden: true,
         onmessage(msg) {
           if (msg.event === "data") {
-            const { messages } = JSON.parse(msg.data);
+            console.log(msg.data);
+            const messages = JSON.parse(msg.data);
             setCurrent((current) => ({
               status: "inflight",
               messages,
