From c389a4512c100c7d6698c1741c245ab11c0e6b94 Mon Sep 17 00:00:00 2001
From: yingying
Date: Wed, 27 Mar 2024 22:45:33 +0800
Subject: [PATCH] feat: add chat history

---
 lui/src/Chat/index.tsx                        |  8 ++-
 lui/src/services/ChatController.ts            |  9 ++--
 server/__pycache__/data_class.cpython-311.pyc | Bin 1300 -> 1074 bytes
 server/agent/stream.py                        | 47 +++++++++++-------
 server/data_class.py                          |  3 --
 .../__pycache__/chat.cpython-311.pyc          | Bin 2702 -> 2702 bytes
 server/main.py                                |  4 +-
 .../__pycache__/dalle.cpython-311.pyc         | Bin 1689 -> 1689 bytes
 8 files changed, 42 insertions(+), 29 deletions(-)

diff --git a/lui/src/Chat/index.tsx b/lui/src/Chat/index.tsx
index 31dc9d57..0bbe1101 100644
--- a/lui/src/Chat/index.tsx
+++ b/lui/src/Chat/index.tsx
@@ -61,7 +61,13 @@ const Chat: FC = memo(({ helloMessage }) => {
         return [];
       }}
       request={async (messages) => {
-        const response = await streamChat(messages);
+        const newMessages = messages.map((message) => {
+          return {
+            role: message.role,
+            content: message.content as string,
+          };
+        });
+        const response = await streamChat(newMessages);
         return handleStream(response);
       }}
       inputAreaProps={{ className: 'userInputBox h-24 !important' }}
diff --git a/lui/src/services/ChatController.ts b/lui/src/services/ChatController.ts
index 352215fd..c22d0439 100644
--- a/lui/src/services/ChatController.ts
+++ b/lui/src/services/ChatController.ts
@@ -1,12 +1,10 @@
-import { ChatMessage } from '@ant-design/pro-chat';
+import { IPrompt } from 'lui/interface';
 
 /**
  * Chat api
  * @param message IPrompt
  */
-export async function streamChat(
-  messages: ChatMessage<Record<string, any>>[],
-): Promise<Response> {
+export async function streamChat(messages: IPrompt[]): Promise<Response> {
   return fetch('http://127.0.0.1:8000/api/chat/stream', {
     method: 'POST',
     headers: {
@@ -15,7 +13,8 @@
       'keep-alive': 'timeout=5',
     },
     body: JSON.stringify({
-      input_data: messages[0].content,
+      messages: messages,
+      prompt: '',
     }),
   });
 }
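
The two frontend hunks above change the wire contract: streamChat no longer
posts a single input_data string but the whole message list, matching the
server's ChatData model. For illustration, a minimal sketch of the same
request made from Python (assumes the third-party requests package and a
server running locally; the sample messages are made up):

    import requests

    # New request shape introduced by this patch: the full chat history is
    # sent under "messages", plus a "prompt" field, replacing the old
    # {"input_data": "..."} body.
    payload = {
        "messages": [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi! How can I help?"},
            {"role": "user", "content": "What did I just say?"},
        ],
        "prompt": "",
    }

    # The endpoint responds with text/event-stream, so read the body
    # incrementally instead of waiting for the full response.
    with requests.post(
        "http://127.0.0.1:8000/api/chat/stream", json=payload, stream=True
    ) as response:
        for chunk in response.iter_content(chunk_size=None, decode_unicode=True):
            print(chunk, end="", flush=True)
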
diff --git a/server/__pycache__/data_class.cpython-311.pyc b/server/__pycache__/data_class.cpython-311.pyc
index 1b2e919e0d7294c3d0ec6d2bf29288b4ce2a3a27..f1e86a44337570330335062ed2a0a55235eabc97 100644
Binary files a/server/__pycache__/data_class.cpython-311.pyc and b/server/__pycache__/data_class.cpython-311.pyc differ
diff --git a/server/agent/stream.py b/server/agent/stream.py
--- a/server/agent/stream.py
+++ b/server/agent/stream.py
@@ ... @@ def _create_agent_with_tools(openai_api_key: str ) -> AgentExecutor:
     if tools:
         llm_with_tools = llm.bind(
-            tools=[format_tool_to_openai_tool(tool) for tool in tools]
+            tools=[convert_to_openai_tool(tool) for tool in tools]
         )
     else:
         llm_with_tools = llm
@@ -57,6 +59,7 @@ def _create_agent_with_tools(openai_api_key: str ) -> AgentExecutor:
             "agent_scratchpad": lambda x: format_to_openai_tool_messages(
                 x["intermediate_steps"]
             ),
+            "chat_history": lambda x: x["chat_history"],
         }
         | prompt
         | llm_with_tools
@@ -68,13 +71,28 @@ def _create_agent_with_tools(openai_api_key: str ) -> AgentExecutor:
     return agent_executor
 
 
+def chat_history_transform(messages: list[Message]):
+    transformed_messages = []
+    for message in messages:
+        print('message', message)
+        if message.role == "user":
+            transformed_messages.append(HumanMessage(content=message.content))
+        elif message.role == "assistant":
+            transformed_messages.append(AIMessage(content=message.content))
+        else:
+            transformed_messages.append(FunctionMessage(content=message.content, name=message.role))
+    return transformed_messages
+
-async def agent_chat(input_data: str, openai_api_key) -> AsyncIterator[str]:
+async def agent_chat(input_data: ChatData, openai_api_key) -> AsyncIterator[str]:
     try:
+        messages = input_data.messages
         agent_executor = _create_agent_with_tools(openai_api_key)
+        print(chat_history_transform(messages))
         async for event in agent_executor.astream_events(
             {
-                "input": input_data,
+                "input": messages[-1].content,
+                "chat_history": chat_history_transform(messages),
             },
             version="v1",
         ):
@@ -82,29 +100,22 @@ async def agent_chat(input_data: str, openai_api_key) -> AsyncIterator[str]:
             if kind == "on_chain_start":
                 if (
                     event["name"] == "agent"
-                ):  # matches `.with_config({"run_name": "Agent"})` in agent_executor
-                    yield "\n"
-                    yield (
+                ):
+                    print(
                         f"Starting agent: {event['name']} "
                         f"with input: {event['data'].get('input')}"
                     )
-                    yield "\n"
             elif kind == "on_chain_end":
                 if (
                     event["name"] == "agent"
-                ):  # matches `.with_config({"run_name": "Agent"})` in agent_executor
-                    yield "\n"
-                    yield (
+                ):
+                    print(
                         f"Done agent: {event['name']} "
                         f"with output: {event['data'].get('output')['output']}"
                     )
-                    yield "\n"
             if kind == "on_chat_model_stream":
                 content = event["data"]["chunk"].content
                 if content:
-                    # Empty content in the context of OpenAI means
-                    # that the model is asking for a tool to be invoked.
-                    # So we only print non-empty content
                     yield f"{content}"
             elif kind == "on_tool_start":
                 yield "\n"
diff --git a/server/data_class.py b/server/data_class.py
index 6fbe9170..b159a9fe 100644
--- a/server/data_class.py
+++ b/server/data_class.py
@@ -13,6 +13,3 @@ class Message(BaseModel):
 class ChatData(BaseModel):
     messages: list[Message] = []
     prompt: str = None
-
-class DataItem(BaseModel):
-    input_data: str
diff --git a/server/langchain_api/__pycache__/chat.cpython-311.pyc b/server/langchain_api/__pycache__/chat.cpython-311.pyc
index cefaad0360c36534d853a3ce11ccd92be330f37b..99fe28e1cf5af508dbd9d31bfbdf8148221d1930 100644
Binary files a/server/langchain_api/__pycache__/chat.cpython-311.pyc and b/server/langchain_api/__pycache__/chat.cpython-311.pyc differ
diff --git a/server/main.py b/server/main.py
index 8a9354c7..d1320c1d 100644
--- a/server/main.py
+++ b/server/main.py
@@ -2,7 +2,7 @@
 from fastapi import FastAPI
 from fastapi.responses import StreamingResponse
 from fastapi.middleware.cors import CORSMiddleware
-from data_class import DalleData, ChatData, DataItem
+from data_class import DalleData, ChatData
 from openai_api import dalle
 from langchain_api import chat
 from agent import stream
@@ -40,6 +40,6 @@ def run_langchain_chat(input_data: ChatData):
 
 
 @app.post("/api/chat/stream", response_class=StreamingResponse)
-async def run_agent_chat(input_data: DataItem):
+async def run_agent_chat(input_data: ChatData):
     result = stream.agent_chat(input_data, open_api_key)
     return StreamingResponse(result, media_type="text/event-stream")
diff --git a/server/openai_api/__pycache__/dalle.cpython-311.pyc b/server/openai_api/__pycache__/dalle.cpython-311.pyc
index d7dd70f8fa09992588c5c18ea4f85df461ef9671..7f8bb0a8994a9efcdc7e1219d755b683bd889872 100644
Binary files a/server/openai_api/__pycache__/dalle.cpython-311.pyc and b/server/openai_api/__pycache__/dalle.cpython-311.pyc differ
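
Taken together, the server now rebuilds LangChain chat history from the
posted messages and feeds only the newest message to the agent as its
input. A hypothetical sketch of that flow without the HTTP layer (assumes
it runs from the server/ directory so data_class and agent.stream resolve,
and that LangChain's standard message classes are the ones stream.py
imports):

    from langchain_core.messages import AIMessage, HumanMessage

    from agent.stream import chat_history_transform
    from data_class import ChatData, Message

    # Build the same payload shape the frontend now sends.
    chat = ChatData(
        messages=[
            Message(role="user", content="Hi"),
            Message(role="assistant", content="Hello! How can I help?"),
            Message(role="user", content="Summarize our conversation."),
        ],
        prompt="",
    )

    # chat_history_transform maps roles onto LangChain message types.
    history = chat_history_transform(chat.messages)
    assert isinstance(history[0], HumanMessage)
    assert isinstance(history[1], AIMessage)

    # agent_chat then splits the same list: the newest message becomes the
    # agent's "input" while the full transformed list is passed as
    # "chat_history", so the latest user turn appears in both places.
    print(chat.messages[-1].content)  # -> "Summarize our conversation."

Because chat_history includes the newest message as well, the prompt sees
the current question twice; whether to trim it from the history is a
design choice this patch leaves open.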