diff --git a/lui/src/services/ChatController.ts b/lui/src/services/ChatController.ts
index 553b3bc9..352215fd 100644
--- a/lui/src/services/ChatController.ts
+++ b/lui/src/services/ChatController.ts
@@ -1,20 +1,21 @@
-import { IPrompt } from 'lui/interface';
+import { ChatMessage } from '@ant-design/pro-chat';
 
 /**
  * Chat api
  * @param message IPrompt
  */
-export async function streamChat(messages: IPrompt[]): Promise<Response> {
-  return fetch('http://localhost:8000/api/chat/stream', {
+export async function streamChat(
+  messages: ChatMessage<Record<string, any>>[],
+): Promise<Response> {
+  return fetch('http://127.0.0.1:8000/api/chat/stream', {
     method: 'POST',
-    credentials: 'include',
     headers: {
       'Content-Type': 'application/json',
       connection: 'keep-alive',
       'keep-alive': 'timeout=5',
     },
     body: JSON.stringify({
-      messages: messages,
+      input_data: messages[0].content,
     }),
   });
 }
diff --git a/lui/src/utils/chatTranslator.ts b/lui/src/utils/chatTranslator.ts
index 7756fec7..26b6fed0 100644
--- a/lui/src/utils/chatTranslator.ts
+++ b/lui/src/utils/chatTranslator.ts
@@ -1,19 +1,3 @@
-const convertChunkToJson = (rawData: string) => {
-  const regex = /data: (.*?})\s*$/;
-  const match = rawData.match(regex);
-  if (match && match[1]) {
-    try {
-      return JSON.parse(match[1]);
-    } catch (e) {
-      console.error('Parsing error:', e);
-      return null;
-    }
-  } else {
-    console.error('No valid JSON found in input');
-    return null;
-  }
-};
-
 export const handleStream = async (response: Response) => {
   const reader = response.body!.getReader();
   const decoder = new TextDecoder('utf-8');
@@ -30,9 +14,7 @@ export const handleStream = async (response: Response) => {
             return;
           }
           const chunk = decoder.decode(value, { stream: true });
-          const message = convertChunkToJson(chunk);
-
-          controller.enqueue(encoder.encode(message.data));
+          controller.enqueue(encoder.encode(chunk));
           push();
         })
         .catch((err) => {
diff --git a/server/__pycache__/data_class.cpython-311.pyc b/server/__pycache__/data_class.cpython-311.pyc
index 806197c9..1b2e919e 100644
Binary files a/server/__pycache__/data_class.cpython-311.pyc and b/server/__pycache__/data_class.cpython-311.pyc differ
diff --git a/server/__pycache__/index.cpython-311.pyc b/server/__pycache__/index.cpython-311.pyc
deleted file mode 100644
index f6d9c009..00000000
Binary files a/server/__pycache__/index.cpython-311.pyc and /dev/null differ
diff --git a/server/agent/stream.py b/server/agent/stream.py
new file mode 100644
index 00000000..a0808a54
--- /dev/null
+++ b/server/agent/stream.py
@@ -0,0 +1,127 @@
+from typing import Any, AsyncIterator, List, Literal
+from langchain.agents import AgentExecutor, tool
+from langchain.agents.format_scratchpad.openai_tools import (
+    format_to_openai_tool_messages,
+)
+from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
+from langchain.prompts import MessagesPlaceholder
+from langchain_community.tools.convert_to_openai import format_tool_to_openai_tool
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_openai import ChatOpenAI
+from tools import issue
+
+
+prompt = ChatPromptTemplate.from_messages(
+    [
+        (
+            "system",
+            "You are a very powerful assistant. When you are asked questions, you can determine whether to use the corresponding tools based on the descriptions of the actions. There may be two situations:"
+            "1. Talk with the user as normal. "
+            "2. If they ask you about issues, use a tool",
+        ),
+        ("user", "{input}"),
+        MessagesPlaceholder(variable_name="agent_scratchpad"),
+    ]
+)
+
+
+TOOL_MAPPING = {
+    "create_issue": issue.create_issue,
+    "get_issues_list": issue.get_issues_list,
+    "get_issues_by_number": issue.get_issues_by_number
+}
+TOOLS = ["create_issue", "get_issues_list", "get_issues_by_number"]
+
+
+def _create_agent_with_tools(openai_api_key: str) -> AgentExecutor:
+    # Pass the provided key to the model instead of relying on the OPENAI_API_KEY env var.
+    llm = ChatOpenAI(model="gpt-4", temperature=0.2, streaming=True, openai_api_key=openai_api_key)
+    tools = []
+
+    for requested_tool in TOOLS:
+        if requested_tool not in TOOL_MAPPING:
+            raise ValueError(f"Unknown tool: {requested_tool}")
+        tools.append(TOOL_MAPPING[requested_tool])
+
+    if tools:
+        llm_with_tools = llm.bind(
+            tools=[format_tool_to_openai_tool(tool) for tool in tools]
+        )
+    else:
+        llm_with_tools = llm
+
+    agent = (
+        {
+            "input": lambda x: x["input"],
+            "agent_scratchpad": lambda x: format_to_openai_tool_messages(
+                x["intermediate_steps"]
+            ),
+        }
+        | prompt
+        | llm_with_tools
+        | OpenAIToolsAgentOutputParser()
+    )
+    agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True).with_config(
+        {"run_name": "agent"}
+    )
+    return agent_executor
+
+
+async def agent_chat(input_data: str, openai_api_key) -> AsyncIterator[str]:
+    try:
+        agent_executor = _create_agent_with_tools(openai_api_key)
+        async for event in agent_executor.astream_events(
+            {
+                "input": input_data,
+            },
+            version="v1",
+        ):
+            kind = event["event"]
+            if kind == "on_chain_start":
+                if (
+                    event["name"] == "agent"
+                ):  # matches `.with_config({"run_name": "agent"})` in agent_executor
+                    yield "\n"
+                    yield (
+                        f"Starting agent: {event['name']} "
+                        f"with input: {event['data'].get('input')}"
+                    )
+                    yield "\n"
+            elif kind == "on_chain_end":
+                if (
+                    event["name"] == "agent"
+                ):  # matches `.with_config({"run_name": "agent"})` in agent_executor
+                    yield "\n"
+                    yield (
+                        f"Done agent: {event['name']} "
+                        f"with output: {event['data'].get('output')['output']}"
+                    )
+                    yield "\n"
+            if kind == "on_chat_model_stream":
+                content = event["data"]["chunk"].content
+                if content:
+                    # Empty content in the context of OpenAI means
+                    # that the model is asking for a tool to be invoked.
+                    # So we only print non-empty content
+                    yield f"{content}"
+            elif kind == "on_tool_start":
+                yield "\n"
+                yield (
+                    f"Starting tool: {event['name']} "
+                    f"with inputs: {event['data'].get('input')}"
+                )
+                yield "\n"
+            elif kind == "on_tool_end":
+                yield "\n"
+                yield (
+                    f"Done tool: {event['name']} "
+                    f"with output: {event['data'].get('output')}"
+                )
+                yield "\n"
+
+
+    except Exception as e:
+        yield f"data: {str(e)}\n\n"
+
diff --git a/server/data_class.py b/server/data_class.py
index b159a9fe..6fbe9170 100644
--- a/server/data_class.py
+++ b/server/data_class.py
@@ -13,3 +13,6 @@ class Message(BaseModel):
 class ChatData(BaseModel):
     messages: list[Message] = []
     prompt: str = None
+
+class DataItem(BaseModel):
+    input_data: str
diff --git a/server/langchain_api/__pycache__/chat.cpython-311.pyc b/server/langchain_api/__pycache__/chat.cpython-311.pyc
index a133e4e9..cefaad03 100644
Binary files a/server/langchain_api/__pycache__/chat.cpython-311.pyc and b/server/langchain_api/__pycache__/chat.cpython-311.pyc differ
diff --git a/server/langchain_api/chat.py b/server/langchain_api/chat.py
index 6eb1b04b..5a323e1b 100644
--- a/server/langchain_api/chat.py
+++ b/server/langchain_api/chat.py
@@ -1,4 +1,4 @@
-from langchain.chat_models import ChatOpenAI
+from langchain_community.chat_models import ChatOpenAI
 from langchain.prompts import (
     ChatPromptTemplate,
     AIMessagePromptTemplate,
diff --git a/server/main.py b/server/main.py
index c2fda1c0..8a9354c7 100644
--- a/server/main.py
+++ b/server/main.py
@@ -1,26 +1,45 @@
 import os
 from fastapi import FastAPI
-
-from data_class import DalleData, ChatData
+from fastapi.responses import StreamingResponse
+from fastapi.middleware.cors import CORSMiddleware
+from data_class import DalleData, ChatData, DataItem
 from openai_api import dalle
 from langchain_api import chat
+from agent import stream
 
 open_api_key = os.getenv("OPENAI_API_KEY")
 
-app = FastAPI()
+app = FastAPI(
+    title="Bo-meta Server",
+    version="1.0",
+    description="Agent Chat APIs"
+    )
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+    expose_headers=["*"],
+)
 
 @app.get("/")
 def read_root():
     return {"Hello": "World"}
 
-
 @app.post("/api/dall-e")
 def run_img_generator(input_data: DalleData):
     result = dalle.img_generator(input_data, open_api_key)
     return result
 
-
 @app.post("/api/chat")
 def run_langchain_chat(input_data: ChatData):
     result = chat.langchain_chat(input_data, open_api_key)
     return result
+
+
+@app.post("/api/chat/stream", response_class=StreamingResponse)
+async def run_agent_chat(input_data: DataItem):
+    result = stream.agent_chat(input_data.input_data, open_api_key)
+    return StreamingResponse(result, media_type="text/event-stream")
diff --git a/server/openai_api/__pycache__/dalle.cpython-311.pyc b/server/openai_api/__pycache__/dalle.cpython-311.pyc
index 589e7ed0..d7dd70f8 100644
Binary files a/server/openai_api/__pycache__/dalle.cpython-311.pyc and b/server/openai_api/__pycache__/dalle.cpython-311.pyc differ
diff --git a/server/openai_api/dalle.py b/server/openai_api/dalle.py
index 7c92ea7f..e606e8b1 100644
--- a/server/openai_api/dalle.py
+++ b/server/openai_api/dalle.py
@@ -1,5 +1,5 @@
 import os
-from langchain.chat_models import ChatOpenAI
+from langchain_community.chat_models import ChatOpenAI
 from langchain.chains import LLMChain
 from langchain.prompts import PromptTemplate
 from langchain.utilities.dalle_image_generator import DallEAPIWrapper
diff --git a/server/requirements.txt b/server/requirements.txt
index ea0502a4..cce560a8 100644
--- a/server/requirements.txt
+++ b/server/requirements.txt
@@ -1,9 +1,11 @@
 fastapi==0.100.1
 uvicorn[standard]==0.23.2
 python-dotenv==1.0.0
-langchain
 openai
 mangum
-langsmith
+langserve
+langchain_community
+langchain
 langchain-openai
 PyGithub
+python-multipart