Skip to content

Commit

Permalink
Merge pull request #58 from ant-xuexiao/init_stream_agent
Browse files Browse the repository at this point in the history
feat: init stream agent
  • Loading branch information
xingwanying authored Mar 27, 2024
2 parents 970fc7f + 16379ce commit 8ccbbf4
Show file tree
Hide file tree
Showing 12 changed files with 167 additions and 33 deletions.
11 changes: 6 additions & 5 deletions lui/src/services/ChatController.ts
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
import { IPrompt } from 'lui/interface';
import { ChatMessage } from '@ant-design/pro-chat';

/**
* Chat api
* @param message IPrompt
*/
export async function streamChat(messages: IPrompt[]): Promise<Response> {
return fetch('http://localhost:8000/api/chat/stream', {
export async function streamChat(
messages: ChatMessage<Record<string, any>>[],
): Promise<Response> {
return fetch('http://127.0.0.1:8000/api/chat/stream', {
method: 'POST',
credentials: 'include',
headers: {
'Content-Type': 'application/json',
connection: 'keep-alive',
'keep-alive': 'timeout=5',
},
body: JSON.stringify({
messages: messages,
input_data: messages[0].content,
}),
});
}
20 changes: 1 addition & 19 deletions lui/src/utils/chatTranslator.ts
Original file line number Diff line number Diff line change
@@ -1,19 +1,3 @@
const convertChunkToJson = (rawData: string) => {
const regex = /data: (.*?})\s*$/;
const match = rawData.match(regex);
if (match && match[1]) {
try {
return JSON.parse(match[1]);
} catch (e) {
console.error('Parsing error:', e);
return null;
}
} else {
console.error('No valid JSON found in input');
return null;
}
};

export const handleStream = async (response: Response) => {
const reader = response.body!.getReader();
const decoder = new TextDecoder('utf-8');
Expand All @@ -30,9 +14,7 @@ export const handleStream = async (response: Response) => {
return;
}
const chunk = decoder.decode(value, { stream: true });
const message = convertChunkToJson(chunk);

controller.enqueue(encoder.encode(message.data));
controller.enqueue(encoder.encode(chunk));
push();
})
.catch((err) => {
Expand Down
Binary file modified server/__pycache__/data_class.cpython-311.pyc
Binary file not shown.
Binary file removed server/__pycache__/index.cpython-311.pyc
Binary file not shown.
127 changes: 127 additions & 0 deletions server/agent/stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
from typing import Any, AsyncIterator, List, Literal
from langchain.agents import AgentExecutor, tool
from langchain.agents.format_scratchpad.openai_tools import (
format_to_openai_tool_messages,
)
from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
from langchain.prompts import MessagesPlaceholder
from langchain_community.tools.convert_to_openai import format_tool_to_openai_tool
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from tools import issue


prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"You are very powerful assistant. When you are asked questions, you can determine whether to use the corresponding tools based on the descriptions of the actions. There may be two situations:"
"1. Talk with the user as normal. "
"2. If they ask you about issues, use a tool",
),
("user", "{input}"),
MessagesPlaceholder(variable_name="agent_scratchpad"),
]
)



TOOL_MAPPING = {
"create_issue": issue.create_issue,
"get_issues_list": issue.get_issues_list,
"get_issues_by_number": issue.get_issues_by_number
}
TOOLS = ["create_issue", "get_issues_list", "get_issues_by_number"]


def _create_agent_with_tools(openai_api_key: str ) -> AgentExecutor:
openai_api_key=openai_api_key
llm = ChatOpenAI(model="gpt-4", temperature=0.2, streaming=True)
tools = []

for requested_tool in TOOLS:
if requested_tool not in TOOL_MAPPING:
raise ValueError(f"Unknown tool: {requested_tool}")
tools.append(TOOL_MAPPING[requested_tool])

if tools:
llm_with_tools = llm.bind(
tools=[format_tool_to_openai_tool(tool) for tool in tools]
)
else:
llm_with_tools = llm

agent = (
{
"input": lambda x: x["input"],
"agent_scratchpad": lambda x: format_to_openai_tool_messages(
x["intermediate_steps"]
),
}
| prompt
| llm_with_tools
| OpenAIToolsAgentOutputParser()
)
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True).with_config(
{"run_name": "agent"}
)
return agent_executor



async def agent_chat(input_data: str, openai_api_key) -> AsyncIterator[str]:
try:
agent_executor = _create_agent_with_tools(openai_api_key)
async for event in agent_executor.astream_events(
{
"input": input_data,
},
version="v1",
):
kind = event["event"]
if kind == "on_chain_start":
if (
event["name"] == "agent"
): # matches `.with_config({"run_name": "Agent"})` in agent_executor
yield "\n"
yield (
f"Starting agent: {event['name']} "
f"with input: {event['data'].get('input')}"
)
yield "\n"
elif kind == "on_chain_end":
if (
event["name"] == "agent"
): # matches `.with_config({"run_name": "Agent"})` in agent_executor
yield "\n"
yield (
f"Done agent: {event['name']} "
f"with output: {event['data'].get('output')['output']}"
)
yield "\n"
if kind == "on_chat_model_stream":
content = event["data"]["chunk"].content
if content:
# Empty content in the context of OpenAI means
# that the model is asking for a tool to be invoked.
# So we only print non-empty content
yield f"{content}"
elif kind == "on_tool_start":
yield "\n"
yield (
f"Starting tool: {event['name']} "
f"with inputs: {event['data'].get('input')}"
)
yield "\n"
elif kind == "on_tool_end":
yield "\n"
yield (
f"Done tool: {event['name']} "
f"with output: {event['data'].get('output')}"
)
yield "\n"


except Exception as e:
yield f"data: {str(e)}\n\n"

3 changes: 3 additions & 0 deletions server/data_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,6 @@ class Message(BaseModel):
class ChatData(BaseModel):
messages: list[Message] = []
prompt: str = None

class DataItem(BaseModel):
input_data: str
Binary file modified server/langchain_api/__pycache__/chat.cpython-311.pyc
Binary file not shown.
2 changes: 1 addition & 1 deletion server/langchain_api/chat.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from langchain.chat_models import ChatOpenAI
from langchain_community.chat_models import ChatOpenAI
from langchain.prompts import (
ChatPromptTemplate,
AIMessagePromptTemplate,
Expand Down
29 changes: 24 additions & 5 deletions server/main.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,45 @@
import os
from fastapi import FastAPI

from data_class import DalleData, ChatData
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from data_class import DalleData, ChatData, DataItem
from openai_api import dalle
from langchain_api import chat
from agent import stream

open_api_key = os.getenv("OPENAI_API_KEY")

app = FastAPI()
app = FastAPI(
title="Bo-meta Server",
version="1.0",
description="Agent Chat APIs"
)

app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
expose_headers=["*"],
)

@app.get("/")
def read_root():
return {"Hello": "World"}


@app.post("/api/dall-e")
def run_img_generator(input_data: DalleData):
result = dalle.img_generator(input_data, open_api_key)
return result


@app.post("/api/chat")
def run_langchain_chat(input_data: ChatData):
result = chat.langchain_chat(input_data, open_api_key)
return result


@app.post("/api/chat/stream", response_class=StreamingResponse)
async def run_agent_chat(input_data: DataItem):
result = stream.agent_chat(input_data, open_api_key)
return StreamingResponse(result, media_type="text/event-stream")
Binary file modified server/openai_api/__pycache__/dalle.cpython-311.pyc
Binary file not shown.
2 changes: 1 addition & 1 deletion server/openai_api/dalle.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from langchain.chat_models import ChatOpenAI
from langchain_community.chat_models import ChatOpenAI
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.utilities.dalle_image_generator import DallEAPIWrapper
Expand Down
6 changes: 4 additions & 2 deletions server/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
fastapi==0.100.1
uvicorn[standard]==0.23.2
python-dotenv==1.0.0
langchain
openai
mangum
langsmith
langserve
langchain_community
langchain
langchain-openai
PyGithub
python-multipart

0 comments on commit 8ccbbf4

Please sign in to comment.