feat: support multiple llm (#231)
- Support switching models (e.g. `gemini-1.5-flash`)
RaoHai authored Aug 23, 2024
2 parents a203793 + 17fc0a6 commit 9f9707e
Showing 17 changed files with 203 additions and 65 deletions.
1 change: 1 addition & 0 deletions .github/workflows/aws-preview.yml
@@ -57,6 +57,7 @@ jobs:
GithubAppsClientId=${{ secrets.X_GITHUB_APPS_CLIENT_ID }} \
GithubAppsClientSecret=${{ secrets.X_GITHUB_APPS_CLIENT_SECRET }} \
OpenAIAPIKey=${{ secrets.OPENAI_API_KEY }} \
GeminiAPIKey=${{ secrets.GEMINI_API_KEY }} \
SupabaseServiceKey=${{ secrets.SUPABASE_SERVICE_KEY }} \
SupabaseUrl=${{ secrets.SUPABASE_URL }} \
TavilyAPIKey=${{ secrets.TAVILY_API_KEY }} \
1 change: 1 addition & 0 deletions .github/workflows/aws-prod.yml
@@ -51,6 +51,7 @@ jobs:
GithubAppsClientId=${{ secrets.X_GITHUB_APPS_CLIENT_ID }} \
GithubAppsClientSecret=${{ secrets.X_GITHUB_APPS_CLIENT_SECRET }} \
OpenAIAPIKey=${{ secrets.OPENAI_API_KEY }} \
GeminiAPIKey=${{ secrets.GEMINI_API_KEY }} \
SupabaseServiceKey=${{ secrets.SUPABASE_SERVICE_KEY }} \
SupabaseUrl=${{ secrets.SUPABASE_URL }} \
TavilyAPIKey=${{ secrets.TAVILY_API_KEY }} \
5 changes: 5 additions & 0 deletions petercat_utils/data_class.py
@@ -24,6 +24,10 @@ class ImageURLContentBlock(BaseModel):
image_url: ImageURL
type: Literal["image_url"]

class ImageRawURLContentBlock(BaseModel):
image_url: str
type: Literal["image_url"]


class TextContentBlock(BaseModel):
text: str
@@ -42,6 +46,7 @@ class Message(BaseModel):

class ChatData(BaseModel):
messages: List[Message] = []
llm: Optional[str] = "openai"
prompt: Optional[str] = None
bot_id: Optional[str] = None

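With the new `llm` field, a request can select its backend per call. A minimal sketch of a payload (the exact `Message` and `TextContentBlock` shapes are inferred from how `agent/base.py` consumes them, so treat them as assumptions):

```python
# Minimal sketch: picking the backend per request via ChatData.llm.
# Message/TextContentBlock shapes are inferred from this diff, not verified.
from petercat_utils.data_class import ChatData, Message, TextContentBlock

chat = ChatData(
    llm="gemini",  # defaults to "openai" when omitted
    messages=[
        Message(
            role="user",
            content=[TextContentBlock(text="What does this repo do?", type="text")],
        )
    ],
    bot_id="example-bot-id",  # hypothetical id
)
```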
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "petercat_utils"
version = "0.1.28"
version = "0.1.30"
description = ""
authors = ["raoha.rh <[email protected]>"]
readme = "README.md"
@@ -24,7 +24,7 @@ md_report_color = "auto"
python = "^3.8"
langchain_community = "^0.2.11"
langchain_openai = "^0.1.20"
langchain_core = "0.2.28"
langchain_core = "^0.2.28"
langchain = "^0.2.12"
supabase = "2.6.0"
pydantic = "2.7.0"
35 changes: 10 additions & 25 deletions server/agent/base.py
@@ -1,50 +1,42 @@
import json
from typing import AsyncIterator, Dict, Callable, Optional
from langchain.agents import AgentExecutor
from agent.llm.base import BaseLLMClient
from petercat_utils.data_class import ChatData, Message
from langchain.agents.format_scratchpad.openai_tools import (
format_to_openai_tool_messages,
)
from langchain_core.messages import AIMessage, FunctionMessage, HumanMessage
from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
from langchain.prompts import MessagesPlaceholder
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.utilities.tavily_search import TavilySearchAPIWrapper
from langchain_community.tools.tavily_search.tool import TavilySearchResults
from langchain_openai import ChatOpenAI
from petercat_utils import get_env_variable

OPEN_API_KEY = get_env_variable("OPENAI_API_KEY")

TAVILY_API_KEY = get_env_variable("TAVILY_API_KEY")


class AgentBuilder:

def __init__(
self,
chat_model: BaseLLMClient,
prompt: str,
tools: Dict[str, Callable],
enable_tavily: Optional[bool] = True,
temperature: Optional[int] = 0.2,
max_tokens: Optional[int] = 1500,
streaming: Optional[bool] = False,
):
"""
@class `Build AgentExecutor based on tools and prompt`
@param chat_model: BaseLLMClient, the injected LLM backend
@param prompt: str
@param tools: Dict[str, Callable]
@param enable_tavily: Optional[bool] If set True, enables the Tavily tool
@param temperature: Optional[int]
@param max_tokens: Optional[int]
@param streaming: Optional[bool]
"""
self.prompt = prompt
self.tools = tools
self.enable_tavily = enable_tavily
self.temperature = temperature
self.max_tokens = max_tokens
self.streaming = streaming
self.chat_model = chat_model
self.agent_executor = self._create_agent_with_tools()

def init_tavily_tools(self):
@@ -54,21 +46,16 @@ def init_tavily_tools(self):
return [tavily_tool]

def _create_agent_with_tools(self) -> AgentExecutor:
llm = ChatOpenAI(
model_name="gpt-4o",
temperature=self.temperature,
streaming=self.streaming,
max_tokens=self.max_tokens,
openai_api_key=OPEN_API_KEY,
)
llm = self.chat_model.get_client()

tools = self.init_tavily_tools() if self.enable_tavily else []

for tool in self.tools.values():
tools.append(tool)

if tools:
llm = llm.bind_tools([convert_to_openai_tool(tool) for tool in tools])
parsed_tools = self.chat_model.get_tools(tools)
llm = llm.bind_tools(parsed_tools)

self.prompt = self.get_prompt()
agent = (
@@ -102,13 +89,11 @@ def get_prompt(self):
]
)

@staticmethod
def chat_history_transform(messages: list[Message]):
def chat_history_transform(self, messages: list[Message]):
transformed_messages = []
for message in messages:
print("message", message)
if message.role == "user":
transformed_messages.append(HumanMessage(content=message.content))
transformed_messages.append(HumanMessage(self.chat_model.parse_content(content=message.content)))
elif message.role == "assistant":
transformed_messages.append(AIMessage(content=message.content))
else:
@@ -120,7 +105,7 @@ async def run_stream_chat(self, input_data: ChatData) -> AsyncIterator[str]:
messages = input_data.messages
async for event in self.agent_executor.astream_events(
{
"input": messages[len(messages) - 1].content,
"input": self.chat_model.parse_content(messages[len(messages) - 1].content),
"chat_history": self.chat_history_transform(messages),
},
version="v1",
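The net effect of this refactor: `AgentBuilder` no longer hard-codes `ChatOpenAI`. The caller injects a `BaseLLMClient`, and the builder asks it for the chat model, the tool format, and content parsing. A minimal sketch of the new wiring (the prompt text and empty tool map are placeholders):

```python
# Minimal sketch of the new wiring; prompt and tools are placeholders.
from agent.base import AgentBuilder
from agent.llm import get_llm

agent = AgentBuilder(
    chat_model=get_llm(llm="gemini", streaming=True),  # backend injected, not hard-coded
    prompt="You are a helpful assistant for this repository.",
    tools={},  # e.g. {"search_knowledge": knowledge_tool}
    enable_tavily=False,
)
# agent.run_stream_chat(chat_data) then streams events from the chosen model.
```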
9 changes: 7 additions & 2 deletions server/agent/bot_builder.py
@@ -1,4 +1,5 @@
from typing import AsyncIterator, Optional
from agent.llm import get_llm
from petercat_utils.data_class import ChatData

from agent.base import AgentBuilder
@@ -13,10 +14,14 @@


def agent_stream_chat(
input_data: ChatData, user_id: str, bot_id: Optional[str] = None
input_data: ChatData,
user_id: str,
bot_id: Optional[str] = None,
llm: Optional[str] = "openai"
) -> AsyncIterator[str]:
prompt = generate_prompt_by_user_id(user_id, bot_id)
agent = AgentBuilder(
prompt=prompt, tools=TOOL_MAPPING, enable_tavily=False, streaming=True
chat_model=get_llm(llm=llm),
prompt=prompt, tools=TOOL_MAPPING, enable_tavily=False
)
return agent.run_stream_chat(input_data)
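`agent_stream_chat` now threads the caller's `llm` choice into the factory. A hypothetical call (the `user_id` and `bot_id` values are placeholders):

```python
# Hypothetical call; identifiers are placeholders, chat is a ChatData
# like the sketch shown earlier.
stream = agent_stream_chat(
    input_data=chat,
    user_id="user-123",
    bot_id=None,
    llm="gemini",
)
```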
23 changes: 23 additions & 0 deletions server/agent/llm/__init__.py
@@ -0,0 +1,23 @@
from typing import Optional

from agent.llm.base import BaseLLMClient
from agent.llm.gemini import GeminiClient
from agent.llm.openai import OpenAIClient
from petercat_utils.utils.env import get_env_variable

OPEN_API_KEY = get_env_variable("OPENAI_API_KEY")
GEMINI_API_KEY = get_env_variable("GEMINI_API_KEY")

def get_llm(
llm: str = 'openai',
temperature: Optional[float] = 0.2,
max_tokens: Optional[int] = 1500,
streaming: Optional[bool] = False
) -> BaseLLMClient:

match llm:
case "openai":
return OpenAIClient(temperature=temperature, streaming=streaming, max_tokens=max_tokens)
case "gemini":
return GeminiClient(temperature=temperature, streaming=streaming, max_tokens=max_tokens)
return None
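`get_llm` is a plain string-keyed factory, so the HTTP layer can pass the user's choice straight through:

```python
# Usage sketch; "openai" and "gemini" are the two keys this PR wires up.
from agent.llm import get_llm

client = get_llm(llm="gemini", temperature=0.2, max_tokens=1500, streaming=True)
model = client.get_client()  # a LangChain BaseChatModel (ChatGoogleGenerativeAI here)
```

Note that an unrecognized `llm` string falls through to `return None`, so a bad value only surfaces when the client is first used.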
22 changes: 22 additions & 0 deletions server/agent/llm/base.py
@@ -0,0 +1,22 @@

from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from langchain_core.language_models import BaseChatModel

from petercat_utils.data_class import MessageContent

class BaseLLMClient(ABC):
def __init__(self, temperature: Optional[float] = 0.2, max_tokens: Optional[int] = 1500, streaming: Optional[bool] = False):
pass

@abstractmethod
def get_client(self) -> BaseChatModel:
pass

@abstractmethod
def get_tools(self, tools: List[Any]) -> List[Dict[str, Any]]:
pass

@abstractmethod
def parse_content(self, content: List[MessageContent]) -> List[MessageContent]:
pass
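Any additional backend plugs in by implementing these three hooks. A hypothetical sketch using LangChain's Anthropic integration (this client, and the `"claude"` branch it would need in `get_llm`, are assumptions, not part of this PR):

```python
# Hypothetical sketch: a third backend behind the same interface.
# ChatAnthropic and the model name are assumptions, not part of this PR.
from typing import Any, List, Optional

from langchain_anthropic import ChatAnthropic
from langchain_core.utils.function_calling import convert_to_openai_tool

from agent.llm.base import BaseLLMClient
from petercat_utils.data_class import MessageContent


class ClaudeClient(BaseLLMClient):
    def __init__(self, temperature: Optional[float] = 0.2, max_tokens: Optional[int] = 1500, streaming: Optional[bool] = False):
        self._client = ChatAnthropic(
            model="claude-3-5-sonnet-20240620",
            temperature=temperature,
            max_tokens=max_tokens,
            streaming=streaming,
        )

    def get_client(self):
        return self._client

    def get_tools(self, tools: List[Any]):
        # LangChain's Anthropic integration accepts the OpenAI tool schema.
        return [convert_to_openai_tool(tool) for tool in tools]

    def parse_content(self, content: List[MessageContent]):
        return content  # no content rewriting needed for this backend
```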
40 changes: 40 additions & 0 deletions server/agent/llm/gemini.py
@@ -0,0 +1,40 @@
from typing import Any, List, Optional
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_google_genai.chat_models import convert_to_genai_function_declarations

from agent.llm.base import BaseLLMClient
from petercat_utils.data_class import ImageRawURLContentBlock, MessageContent
from petercat_utils.utils.env import get_env_variable

GEMINI_API_KEY = get_env_variable("GEMINI_API_KEY")

def parse_gemini_input(message: MessageContent):
match message.type:
case "image_url":
return ImageRawURLContentBlock(image_url=message.image_url.url, type="image_url")
case _:
return message

class GeminiClient(BaseLLMClient):
_client: ChatGoogleGenerativeAI

def __init__(self, temperature: Optional[float] = 0.2, max_tokens: Optional[int] = 1500, streaming: Optional[bool] = False):
self._client = ChatGoogleGenerativeAI(
model="gemini-1.5-flash",
temperature=temperature,
streaming=streaming,
max_tokens=max_tokens,
google_api_key=GEMINI_API_KEY,
)

def get_client(self):
return self._client

def get_tools(self, tools: List[Any]):
return [convert_to_genai_function_declarations(tool) for tool in tools]

def parse_content(self, content: List[MessageContent]):
result = [parse_gemini_input(message=message) for message in content]
print(f"parse_content, content={content}, result={result}")
return result
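The Gemini path differs from the OpenAI one only in how image parts are shaped: OpenAI nests the URL inside an object, while this client hands Gemini the raw URL string. Roughly, the normalization does the following (the `ImageURL.url` field is inferred from its use above):

```python
# Illustration of the flattening parse_gemini_input performs on image blocks.
from petercat_utils.data_class import (
    ImageRawURLContentBlock,
    ImageURL,
    ImageURLContentBlock,
)

openai_style = ImageURLContentBlock(
    image_url=ImageURL(url="https://example.com/diagram.png"),
    type="image_url",
)
gemini_style = ImageRawURLContentBlock(
    image_url=openai_style.image_url.url,  # nested object -> raw string
    type="image_url",
)
```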
32 changes: 32 additions & 0 deletions server/agent/llm/openai.py
@@ -0,0 +1,32 @@
from typing import Any, List, Optional
from langchain_openai import ChatOpenAI
from langchain_core.utils.function_calling import convert_to_openai_tool

from agent.llm.base import BaseLLMClient
from petercat_utils.data_class import MessageContent
from petercat_utils.utils.env import get_env_variable


OPEN_API_KEY = get_env_variable("OPENAI_API_KEY")

class OpenAIClient(BaseLLMClient):
_client: ChatOpenAI

def __init__(self, temperature: Optional[float] = 0.2, max_tokens: Optional[int] = 1500, streaming: Optional[bool] = False):
self._client = ChatOpenAI(
model_name="gpt-4o",
temperature=temperature,
streaming=streaming,
max_tokens=max_tokens,
openai_api_key=OPEN_API_KEY,
)

def get_client(self):
return self._client

def get_tools(self, tools: List[Any]):
return [convert_to_openai_tool(tool) for tool in tools]

def parse_content(self, content: List[MessageContent]):
print(f"parse_content: {content}")
return content
54 changes: 20 additions & 34 deletions server/agent/qa_chat.py
@@ -1,58 +1,44 @@
from typing import AsyncIterator, Optional
from petercat_utils import get_client
from petercat_utils.data_class import ChatData

from agent.base import AgentBuilder
from agent.llm import get_llm
from dao.botDAO import BotDAO
from models.bot import Bot
from prompts.bot_template import generate_prompt_by_repo_name
from petercat_utils.data_class import ChatData

from tools import issue, sourcecode, knowledge, git_info


def get_tools(bot_id: str, token: Optional[str]):
def get_tools(bot: Bot, token: Optional[str]):
issue_tools = issue.factory(access_token=token)
return {
"search_knowledge": knowledge.factory(bot_id=bot_id),
"search_knowledge": knowledge.factory(bot_id=bot.id),
"create_issue": issue_tools["create_issue"],
"get_issues": issue_tools["get_issues"],
"search_issues": issue_tools["search_issues"],
"search_code": sourcecode.search_code,
"search_repo": git_info.search_repo,
}


def init_prompt(input_data: ChatData):
if input_data.prompt:
prompt = input_data.prompt
elif input_data.bot_id:
try:
supabase = get_client()
res = (
supabase.table("bots")
.select("prompt")
.eq("id", input_data.bot_id)
.execute()
)
prompt = res.data[0]["prompt"]
except Exception as e:
print(e)
prompt = generate_prompt_by_repo_name("ant-design")
else:
prompt = generate_prompt_by_repo_name("ant-design")

return prompt


def agent_stream_chat(input_data: ChatData, user_token: str) -> AsyncIterator[str]:
bot_dao = BotDAO()
bot = bot_dao.get_bot(input_data.bot_id)

agent = AgentBuilder(
prompt=init_prompt(input_data),
tools=get_tools(bot_id=input_data.bot_id, token=user_token),
streaming=True,
chat_model=get_llm(bot.llm),
prompt=bot.prompt or generate_prompt_by_repo_name("ant-design"),
tools=get_tools(bot=bot, token=user_token),
)
return agent.run_stream_chat(input_data)


def agent_chat(input_data: ChatData, user_token: Optional[str]) -> AsyncIterator[str]:
def agent_chat(input_data: ChatData, user_token: Optional[str], llm: Optional[str] = "openai") -> AsyncIterator[str]:
bot_dao = BotDAO()
bot = bot_dao.get_bot(input_data.bot_id)

agent = AgentBuilder(
prompt=init_prompt(input_data),
tools=get_tools(input_data.bot_id, token=user_token),
chat_model=get_llm(bot.llm),
prompt=bot.prompt or generate_prompt_by_repo_name("ant-design"),
tools=get_tools(bot, token=user_token),
)
return agent.run_chat(input_data)