Skip to content

Commit

Permalink
Merge pull request #59 from ant-xuexiao/fix_chat_history
Browse files Browse the repository at this point in the history
feat: add chat history
  • Loading branch information
xingwanying authored Mar 27, 2024
2 parents 8ccbbf4 + c389a45 commit 2b232f1
Show file tree
Hide file tree
Showing 8 changed files with 42 additions and 29 deletions.
8 changes: 7 additions & 1 deletion lui/src/Chat/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,13 @@ const Chat: FC<ChatProps> = memo(({ helloMessage }) => {
return [];
}}
request={async (messages) => {
const response = await streamChat(messages);
const newMessages = messages.map((message) => {
return {
role: message.role,
content: message.content as string,
};
});
const response = await streamChat(newMessages);
return handleStream(response);
}}
inputAreaProps={{ className: 'userInputBox h-24 !important' }}
Expand Down
9 changes: 4 additions & 5 deletions lui/src/services/ChatController.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import { ChatMessage } from '@ant-design/pro-chat';
import { IPrompt } from 'lui/interface';

/**
* Chat api
* @param message IPrompt
*/
export async function streamChat(
messages: ChatMessage<Record<string, any>>[],
): Promise<Response> {
export async function streamChat(messages: IPrompt[]): Promise<Response> {
return fetch('http://127.0.0.1:8000/api/chat/stream', {
method: 'POST',
headers: {
Expand All @@ -15,7 +13,8 @@ export async function streamChat(
'keep-alive': 'timeout=5',
},
body: JSON.stringify({
input_data: messages[0].content,
messages: messages,
prompt: '',
}),
});
}
Binary file modified server/__pycache__/data_class.cpython-311.pyc
Binary file not shown.
47 changes: 29 additions & 18 deletions server/agent/stream.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from typing import Any, AsyncIterator, List, Literal
from langchain.agents import AgentExecutor, tool
from typing import AsyncIterator
from langchain.agents import AgentExecutor
from data_class import ChatData, Message
from langchain.agents.format_scratchpad.openai_tools import (
format_to_openai_tool_messages,
)
from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
from langchain.prompts import MessagesPlaceholder
from langchain_community.tools.convert_to_openai import format_tool_to_openai_tool
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from tools import issue
from langchain_core.messages import AIMessage, FunctionMessage, HumanMessage


prompt = ChatPromptTemplate.from_messages(
Expand All @@ -19,13 +21,13 @@
"1. Talk with the user as normal. "
"2. If they ask you about issues, use a tool",
),
MessagesPlaceholder(variable_name="chat_history"),
("user", "{input}"),
MessagesPlaceholder(variable_name="agent_scratchpad"),
]
)



TOOL_MAPPING = {
"create_issue": issue.create_issue,
"get_issues_list": issue.get_issues_list,
Expand All @@ -46,7 +48,7 @@ def _create_agent_with_tools(openai_api_key: str ) -> AgentExecutor:

if tools:
llm_with_tools = llm.bind(
tools=[format_tool_to_openai_tool(tool) for tool in tools]
tools=[convert_to_openai_tool(tool) for tool in tools]
)
else:
llm_with_tools = llm
Expand All @@ -57,6 +59,7 @@ def _create_agent_with_tools(openai_api_key: str ) -> AgentExecutor:
"agent_scratchpad": lambda x: format_to_openai_tool_messages(
x["intermediate_steps"]
),
"chat_history": lambda x: x["chat_history"],
}
| prompt
| llm_with_tools
Expand All @@ -68,43 +71,51 @@ def _create_agent_with_tools(openai_api_key: str ) -> AgentExecutor:
return agent_executor


def chat_history_transform(messages: list[Message]):
    """Convert stored chat ``Message`` records into LangChain message objects.

    Args:
        messages: Ordered chat history; each item carries a ``role`` and
            ``content`` attribute (see ``data_class.Message``).

    Returns:
        list: ``HumanMessage`` for ``"user"`` roles, ``AIMessage`` for
        ``"assistant"`` roles, and ``FunctionMessage`` for any other role.
    """
    transformed_messages = []
    for message in messages:
        if message.role == "user":
            transformed_messages.append(HumanMessage(content=message.content))
        elif message.role == "assistant":
            transformed_messages.append(AIMessage(content=message.content))
        else:
            # Catch-all for non-user/non-assistant roles (e.g. tool output).
            # NOTE(review): langchain-core's FunctionMessage requires a
            # ``name`` field — confirm this branch is ever reached, since
            # constructing it without ``name`` would raise a ValidationError.
            transformed_messages.append(FunctionMessage(content=message.content))
    return transformed_messages


async def agent_chat(input_data: str, openai_api_key) -> AsyncIterator[str]:
async def agent_chat(input_data: ChatData, openai_api_key) -> AsyncIterator[str]:
try:
messages = input_data.messages
agent_executor = _create_agent_with_tools(openai_api_key)
print(chat_history_transform(messages))
async for event in agent_executor.astream_events(
{
"input": input_data,
"input": messages[len(messages) - 1].content,
"chat_history": chat_history_transform(messages),
},
version="v1",
):
kind = event["event"]
if kind == "on_chain_start":
if (
event["name"] == "agent"
): # matches `.with_config({"run_name": "Agent"})` in agent_executor
yield "\n"
yield (
):
print(
f"Starting agent: {event['name']} "
f"with input: {event['data'].get('input')}"
)
yield "\n"
elif kind == "on_chain_end":
if (
event["name"] == "agent"
): # matches `.with_config({"run_name": "Agent"})` in agent_executor
yield "\n"
yield (
):
print (
f"Done agent: {event['name']} "
f"with output: {event['data'].get('output')['output']}"
)
yield "\n"
if kind == "on_chat_model_stream":
content = event["data"]["chunk"].content
if content:
# Empty content in the context of OpenAI means
# that the model is asking for a tool to be invoked.
# So we only print non-empty content
yield f"{content}"
elif kind == "on_tool_start":
yield "\n"
Expand Down
3 changes: 0 additions & 3 deletions server/data_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,3 @@ class Message(BaseModel):
class ChatData(BaseModel):
messages: list[Message] = []
prompt: str = None

class DataItem(BaseModel):
input_data: str
Binary file modified server/langchain_api/__pycache__/chat.cpython-311.pyc
Binary file not shown.
4 changes: 2 additions & 2 deletions server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from data_class import DalleData, ChatData, DataItem
from data_class import DalleData, ChatData
from openai_api import dalle
from langchain_api import chat
from agent import stream
Expand Down Expand Up @@ -40,6 +40,6 @@ def run_langchain_chat(input_data: ChatData):


@app.post("/api/chat/stream", response_class=StreamingResponse)
async def run_agent_chat(input_data: DataItem):
async def run_agent_chat(input_data: ChatData):
result = stream.agent_chat(input_data, open_api_key)
return StreamingResponse(result, media_type="text/event-stream")
Binary file modified server/openai_api/__pycache__/dalle.cpython-311.pyc
Binary file not shown.

0 comments on commit 2b232f1

Please sign in to comment.