diff --git a/requirements.txt b/requirements.txt index f149a60..e387fee 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,7 +13,7 @@ click==8.1.7 ; python_version >= "3.10" and python_version < "4.0" colorama==0.4.6 ; python_version >= "3.10" and python_version < "4.0" and platform_system == "Windows" dataclasses-json==0.6.7 ; python_version >= "3.10" and python_version < "4.0" distro==1.9.0 ; python_version >= "3.10" and python_version < "4.0" -duckduckgo-search==6.3.7 ; python_version >= "3.10" and python_version < "4.0" +duckduckgo-search==7.2.1 ; python_version >= "3.10" and python_version < "4.0" exceptiongroup==1.2.2 ; python_version >= "3.10" and python_version < "3.11" frozenlist==1.4.1 ; python_version >= "3.10" and python_version < "4.0" greenlet==3.1.1 ; python_version < "3.13" and (platform_machine == "aarch64" or platform_machine == "ppc64le" or platform_machine == "x86_64" or platform_machine == "amd64" or platform_machine == "AMD64" or platform_machine == "win32" or platform_machine == "WIN32") and python_version >= "3.10" @@ -28,9 +28,10 @@ jsonpointer==3.0.0 ; python_version >= "3.10" and python_version < "4.0" langchain-community==0.3.13 ; python_version >= "3.10" and python_version < "4.0" langchain-core==0.3.28 ; python_version >= "3.10" and python_version < "4.0" langchain-openai==0.2.14 ; python_version >= "3.10" and python_version < "4.0" -langchain-text-splitters==0.3.5 ; python_version >= "3.10" and python_version < "4.0" +langchain-text-splitters==0.3.4 ; python_version >= "3.10" and python_version < "4.0" langchain==0.3.13 ; python_version >= "3.10" and python_version < "4.0" langsmith==0.1.147 ; python_version >= "3.10" and python_version < "4.0" +lxml==5.3.0 ; python_version >= "3.10" and python_version < "4.0" marshmallow==3.22.0 ; python_version >= "3.10" and python_version < "4.0" multidict==6.1.0 ; python_version >= "3.10" and python_version < "4.0" mypy-extensions==1.0.0 ; python_version >= "3.10" and python_version < "4.0" @@ -38,7 +39,7 @@ numpy==1.26.4 ; python_version >= "3.10" and python_version < "4.0" openai==1.58.1 ; python_version >= "3.10" and python_version < "4.0" orjson==3.10.7 ; python_version >= "3.10" and python_version < "4.0" and platform_python_implementation != "PyPy" packaging==24.1 ; python_version >= "3.10" and python_version < "4.0" -primp==0.8.1 ; python_version >= "3.10" and python_version < "4.0" +primp==0.10.0 ; python_version >= "3.10" and python_version < "4.0" propcache==0.2.0 ; python_version >= "3.10" and python_version < "4.0" pydantic-core==2.23.4 ; python_version >= "3.10" and python_version < "4.0" pydantic-settings==2.5.2 ; python_version >= "3.10" and python_version < "4.0" diff --git a/src/sc_system_ai/agents/classify_agent.py b/src/sc_system_ai/agents/classify_agent.py index a10d377..f03aac5 100644 --- a/src/sc_system_ai/agents/classify_agent.py +++ b/src/sc_system_ai/agents/classify_agent.py @@ -1,14 +1,17 @@ -from collections.abc import Iterator -from typing import Any, cast +from collections.abc import AsyncIterator +from typing import cast from langchain_openai import AzureChatOpenAI # from sc_system_ai.agents.tools import magic_function -from sc_system_ai.agents.search_school_data_agent import SearchSchoolDataAgentResponse from sc_system_ai.agents.tools.calling_dummy_agent import calling_dummy_agent -from sc_system_ai.agents.tools.calling_search_school_data_agent import calling_search_school_data_agent +from sc_system_ai.agents.tools.calling_search_school_data_agent import ( + CallingSearchSchoolDataAgent, + calling_search_school_data_agent, +) +from sc_system_ai.agents.tools.calling_small_talk_agent import calling_small_talk_agent from sc_system_ai.agents.tools.classify_role import classify_role -from sc_system_ai.template.agent import Agent, AgentResponse +from sc_system_ai.template.agent import Agent, AgentResponse, StreamingAgentResponse from sc_system_ai.template.ai_settings import llm from sc_system_ai.template.calling_agent import CallingAgent from sc_system_ai.template.user_prompts import User @@ -17,14 +20,14 @@ # magic_function, classify_role, calling_dummy_agent, - calling_search_school_data_agent + calling_search_school_data_agent, + calling_small_talk_agent, ] classify_agent_info = """あなたの役割は適切なエージェントを選択し処理を引き継ぐことです。 +あなたがユーザーと会話を行ってはいけません。 ユーザーの入力、会話の流れから適切なエージェントを選択してください。 引き継いだエージェントが処理を完了するまで、そのエージェントがユーザーと会話を続けるようにしてください。 - -適切なエージェントの選択、呼び出しができなかった場合は、そのままユーザーとの会話を続けてください。 """ # agentクラスの作成 @@ -35,14 +38,10 @@ def __init__( self, llm: AzureChatOpenAI = llm, user_info: User | None = None, - is_streaming: bool = True, - return_length: int = 5 ): super().__init__( llm=llm, user_info=user_info if user_info is not None else User(), - is_streaming=is_streaming, - return_length=return_length ) self.assistant_info = classify_agent_info super().set_assistant_info(self.assistant_info) @@ -53,16 +52,33 @@ def set_tools(self, tools: list) -> None: for tool in tools: if isinstance(tool, CallingAgent): tool.set_user_info(self.user_info) - super().set_tools(tools) - def invoke(self, message: str) -> Iterator[str | AgentResponse | SearchSchoolDataAgentResponse]: - if self.is_streaming: - yield from super().invoke(message) - else: - # ツールの出力をそのまま返却 - resp = cast(dict[str, Any], next(super().invoke(message))) - yield resp["output"] + def invoke(self, message: str) -> AgentResponse: + # toolの出力がAgentReaponseで返って来るので整形 + for tool in self.tool.tools: + if isinstance(tool, CallingAgent): + tool.cancel_streaming() + resp = super().invoke(message) + resp.document_id = self._doc_id_checker() + return resp + + def _doc_id_checker(self) -> list[str] | None: + """ + ドキュメントIDが存在するか確認する + """ + for tool in self.tool.tools: + if isinstance(tool, CallingSearchSchoolDataAgent): + if tool.document_id is not None: + return tool.document_id + return None + + async def stream(self, message: str, return_length: int = 5) -> AsyncIterator[StreamingAgentResponse]: + for tool in self.tool.tools: + if isinstance(tool, CallingAgent): + tool.setup_streaming(self.queue) + async for output in super().stream(message, return_length): + yield output if __name__ == "__main__": @@ -79,7 +95,7 @@ def invoke(self, message: str) -> Iterator[str | AgentResponse | SearchSchoolDat user_info.conversations.add_conversations_list(history) while True: - classify_agent = ClassifyAgent(user_info=user_info, is_streaming=False) + classify_agent = ClassifyAgent(user_info=user_info) # classify_agent.display_agent_info() # print(main_agent.get_agent_prompt()) # classify_agent.display_agent_prompt() @@ -89,7 +105,7 @@ def invoke(self, message: str) -> Iterator[str | AgentResponse | SearchSchoolDat break # 通常の呼び出し - resp = next(classify_agent.invoke(user)) + resp = classify_agent.invoke(user) print(resp) # ストリーミング呼び出し @@ -97,9 +113,9 @@ def invoke(self, message: str) -> Iterator[str | AgentResponse | SearchSchoolDat # print(output) # resp = classify_agent.get_response() - if type(resp) is dict: + if type(resp) is AgentResponse: new_conversation = [ ("human", user), - ("ai", resp["output"]) + ("ai", cast(str,resp.output)) ] user_info.conversations.add_conversations_list(new_conversation) diff --git a/src/sc_system_ai/agents/dummy_agent.py b/src/sc_system_ai/agents/dummy_agent.py index b864df3..93c7363 100644 --- a/src/sc_system_ai/agents/dummy_agent.py +++ b/src/sc_system_ai/agents/dummy_agent.py @@ -1,4 +1,5 @@ # ダミーのエージェント +from typing import cast from langchain_openai import AzureChatOpenAI @@ -49,14 +50,10 @@ def __init__( self, llm: AzureChatOpenAI = llm, user_info: User | None = None, - is_streaming: bool = True, - return_length: int = 5 ): super().__init__( llm=llm, user_info=user_info if user_info is not None else User(), - is_streaming=is_streaming, - return_length=return_length ) self.assistant_info = dummy_agent_info super().set_assistant_info(self.assistant_info) @@ -96,9 +93,9 @@ def __init__( print(output) resp = dummy_agent.get_response() - if type(resp) is dict: + if resp.error is not None: new_conversation = [ ("human", user), - ("ai", resp["output"]) + ("ai", cast(str, resp.output)) ] user_info.conversations.add_conversations_list(new_conversation) diff --git a/src/sc_system_ai/agents/main_agent.py b/src/sc_system_ai/agents/main_agent.py index 92888e2..ab76f31 100644 --- a/src/sc_system_ai/agents/main_agent.py +++ b/src/sc_system_ai/agents/main_agent.py @@ -16,14 +16,10 @@ def __init__( self, llm: AzureChatOpenAI = llm, user_info: User | None = None, - is_streaming: bool = True, - return_length: int = 5 ): super().__init__( llm=llm, user_info=user_info if user_info is not None else User(), - is_streaming=is_streaming, - return_length=return_length ) self.assistant_info = main_agent_info super().set_assistant_info(self.assistant_info) @@ -43,9 +39,9 @@ def __init__( user_info = User(name=user_name, major=user_major) user_info.conversations.add_conversations_list(history) - main_agent = MainAgent(user_info=user_info, is_streaming=False) + main_agent = MainAgent(user_info=user_info) main_agent.display_agent_info() # print(main_agent.get_agent_prompt()) main_agent.display_agent_prompt() - print(next(main_agent.invoke("magic function に3をいれて"))) + print(main_agent.invoke("magic function に3をいれて")) diff --git a/src/sc_system_ai/agents/search_school_data_agent.py b/src/sc_system_ai/agents/search_school_data_agent.py index d780c31..e544439 100644 --- a/src/sc_system_ai/agents/search_school_data_agent.py +++ b/src/sc_system_ai/agents/search_school_data_agent.py @@ -1,11 +1,10 @@ -from collections.abc import Iterator -from typing import cast +from collections.abc import AsyncIterator from langchain_openai import AzureChatOpenAI # from sc_system_ai.agents.tools import magic_function from sc_system_ai.agents.tools.search_school_data import search_school_database_cosmos -from sc_system_ai.template.agent import Agent, AgentResponse +from sc_system_ai.template.agent import Agent, AgentResponse, StreamingAgentResponse from sc_system_ai.template.ai_settings import llm from sc_system_ai.template.user_prompts import User @@ -19,9 +18,6 @@ ## 学校の情報 """ -class SearchSchoolDataAgentResponse(AgentResponse): - document_id: list[str] - # agentクラスの作成 class SearchSchoolDataAgent(Agent): @@ -29,32 +25,34 @@ def __init__( self, llm: AzureChatOpenAI = llm, user_info: User | None = None, - is_streaming: bool = True, - return_length: int = 5 ): super().__init__( llm=llm, user_info=user_info if user_info is not None else User(), - is_streaming=is_streaming, - return_length=return_length ) self.assistant_info = search_school_data_agent_info - def invoke(self, message: str) -> Iterator[SearchSchoolDataAgentResponse]: - # Agentクラスのストリーミングを改修後にストリーミング実装 - self.cancel_streaming() + def _add_search_result(self, message: str) -> list[str]: search = search_school_database_cosmos(message) ids = [] for doc in search: self.assistant_info += f"### {doc.metadata['title']}\n" + doc.page_content + "\n" ids.append(doc.metadata["id"]) super().set_assistant_info(self.assistant_info) + return ids + + def invoke(self, message: str) -> AgentResponse: + # Agentクラスのストリーミングを改修後にストリーミング実装 + ids = self._add_search_result(message) + resp = super().invoke(message) + resp.document_id = ids + return resp - resp = cast(AgentResponse, next(super().invoke(message))) - yield { - **resp, - "document_id": ids - } + async def stream(self, message: str, return_length: int = 5) -> AsyncIterator[StreamingAgentResponse]: + ids = self._add_search_result(message) + async for resp in super().stream(message, return_length): + yield resp + self.result.document_id = ids if __name__ == "__main__": from sc_system_ai.logging_config import setup_logging @@ -68,5 +66,5 @@ def invoke(self, message: str) -> Iterator[SearchSchoolDataAgentResponse]: ] user_info = User(name=user_name, major=user_major) user_info.conversations.add_conversations_list(history) - agent = SearchSchoolDataAgent(user_info=user_info, is_streaming=False) - print(next(agent.invoke("京都テックについて教えて"))) + agent = SearchSchoolDataAgent(user_info=user_info) + print(agent.invoke("京都テックについて教えて")) diff --git a/src/sc_system_ai/agents/small_talk_agent.py b/src/sc_system_ai/agents/small_talk_agent.py new file mode 100644 index 0000000..38b5897 --- /dev/null +++ b/src/sc_system_ai/agents/small_talk_agent.py @@ -0,0 +1,49 @@ +from langchain_openai import AzureChatOpenAI + +from sc_system_ai.agents.tools import magic_function +from sc_system_ai.template.agent import Agent +from sc_system_ai.template.ai_settings import llm +from sc_system_ai.template.user_prompts import User + +main_agent_tools = [magic_function] +main_agent_info = """あなたの役割はユーザーと雑談を行うことです。 +ユーザーが楽しめるような会話になるようにしてください。 +""" + +# agentクラスの作成 + + +class SmallTalkAgent(Agent): + def __init__( + self, + llm: AzureChatOpenAI = llm, + user_info: User | None = None, + ): + super().__init__( + llm=llm, + user_info=user_info if user_info is not None else User(), + ) + self.assistant_info = main_agent_info + super().set_assistant_info(self.assistant_info) + super().set_tools(main_agent_tools) + + +if __name__ == "__main__": + from sc_system_ai.logging_config import setup_logging + setup_logging() + # ユーザー情報 + user_name = "hogehoge" + user_major = "fugafuga専攻" + history = [ + ("human", "こんにちは!"), + ("ai", "本日はどのようなご用件でしょうか?") + ] + user_info = User(name=user_name, major=user_major) + user_info.conversations.add_conversations_list(history) + + agent = SmallTalkAgent(user_info=user_info) + agent.display_agent_info() + # print(main_agent.get_agent_prompt()) + agent.display_agent_prompt() + print(agent.invoke("magic function に3をいれて")) + diff --git a/src/sc_system_ai/agents/tools/calling_search_school_data_agent.py b/src/sc_system_ai/agents/tools/calling_search_school_data_agent.py index 2ce47f1..c4aaaf1 100644 --- a/src/sc_system_ai/agents/tools/calling_search_school_data_agent.py +++ b/src/sc_system_ai/agents/tools/calling_search_school_data_agent.py @@ -10,6 +10,9 @@ class CallingSearchSchoolDataAgent(CallingAgent): + # tool側でidを保持する + document_id: list[str] | None = None + def __init__(self) -> None: super().__init__() self.set_tool_info( @@ -18,6 +21,11 @@ def __init__(self) -> None: agent=SearchSchoolDataAgent ) + def _run(self, user_input: str) -> str: + resp = super()._run(user_input) + self.document_id = self.response.document_id if self.response is not None else None + return resp + calling_search_school_data_agent = CallingSearchSchoolDataAgent() if __name__ == "__main__": diff --git a/src/sc_system_ai/agents/tools/calling_small_talk_agent.py b/src/sc_system_ai/agents/tools/calling_small_talk_agent.py new file mode 100644 index 0000000..888eef4 --- /dev/null +++ b/src/sc_system_ai/agents/tools/calling_small_talk_agent.py @@ -0,0 +1,26 @@ +import logging + +from sc_system_ai.agents.small_talk_agent import SmallTalkAgent +from sc_system_ai.template.calling_agent import CallingAgent +from sc_system_ai.template.user_prompts import User + +logger = logging.getLogger(__name__) + + +class CallingSmallTalkAgent(CallingAgent): + def __init__(self) -> None: + super().__init__() + self.set_tool_info( + name="calling_small_talk_agent", + description="雑談エージェントを呼び出すツール", + agent=SmallTalkAgent, + ) + +calling_small_talk_agent = CallingSmallTalkAgent() + +if __name__ == "__main__": + from sc_system_ai.logging_config import setup_logging + setup_logging() + + calling_small_talk_agent.set_user_info(User(name="hogehoge", major="fugafuga専攻")) + print(calling_small_talk_agent.invoke({"user_input": "こんにちは"})) diff --git a/src/sc_system_ai/main.py b/src/sc_system_ai/main.py index 6af0a7c..aa0deb8 100644 --- a/src/sc_system_ai/main.py +++ b/src/sc_system_ai/main.py @@ -58,9 +58,9 @@ """ import logging -from collections.abc import Iterator +from collections.abc import AsyncIterator from importlib import import_module -from typing import Literal, TypedDict, cast +from typing import Literal, TypedDict from sc_system_ai.template.agent import Agent from sc_system_ai.template.ai_settings import llm @@ -68,13 +68,18 @@ logger = logging.getLogger(__name__) -AGENT = Literal["classify", "dummy", "search_school_data"] +AGENT = Literal["classify", "dummy", "search_school_data", "small_talk"] class Response(TypedDict): output: str | None error: str | None document_id: list[str] | None +class StreamResponse(TypedDict): + output: str | None + error: str | None + status: str | None + class Chat: """Chatクラス ユーザー情報と会話履歴を保持し、エージェントとのチャットを行うクラス @@ -114,16 +119,12 @@ def __init__( user_name: str, user_major: str, conversation: list[tuple[str, str]] | None = None, - is_streaming: bool = True, - return_length: int = 5 ) -> None: self.user = User(name=user_name, major=user_major) if conversation is None: conversation = [] self.user.conversations.add_conversations_list(conversation) - self.is_streaming = is_streaming - self.return_length = return_length self._agent: Agent | None = None @property @@ -144,7 +145,7 @@ def invoke( self, message: str, command: AGENT = "classify" - ) -> Iterator[Response]: + ) -> Response: """エージェントを呼び出し、チャットを行う関数 Args: @@ -164,13 +165,46 @@ def invoke( - dummy: ダミーエージェント """ self._call_agent(command) - if self.is_streaming: - for resp in self.agent.invoke(message): - if type(resp) is str: - yield self._create_response({"output": resp}) - else: - resp = next(self.agent.invoke(message)) - yield self._create_response(cast(dict, resp)) + resp = self.agent.invoke(message) + return { + "output": resp.output, + "error": resp.error, + "document_id": resp.document_id + } + + async def stream( + self, + message: str, + return_length: int = 5, + command: AGENT = "classify" + ) -> AsyncIterator[StreamResponse]: + """エージェントを呼び出し、ストリーミングチャットを行う関数 + + Args: + message (str): メッセージ + return_length (int, optional): ストリーミングモード時の返答数. デフォルトは5 + command (AGENT, optional): 呼び出すエージェント。デフォルトでは分類エージェントを呼び出します。 + + Returns: + Iterator[str]: エージェントからの返答 + + コマンドでエージェントを指定して、エージェントを呼び出す場合。 + ```python + for resp in chat.stream(message="私の名前と専攻は何ですか?", command="dummy"): + print(resp) + ``` + + 呼び出し可能なエージェント: + - classify: 分類エージェント + - dummy: ダミーエージェント + """ + self._call_agent(command) + async for resp in self.agent.stream(message, return_length): + yield { + "output": resp.output, + "error": resp.error, + "status": resp.status + } def _call_agent(self, command: AGENT) -> None: try: @@ -182,21 +216,11 @@ def _call_agent(self, command: AGENT) -> None: self.agent = agent_class( llm=llm, user_info=self.user, - is_streaming=self.is_streaming, - return_length=self.return_length ) except (ModuleNotFoundError, AttributeError, ValueError): logger.error(f"エージェントが見つかりません: {command}") raise ValueError(f"エージェントが見つかりません: {command}") from None - def _create_response(self, resp: dict) -> Response: - return { - "output": resp.get("output"), - "error": resp.get("error"), - "document_id": resp.get("document_id") - } - - def static_chat() -> None: # ユーザー情報 @@ -212,16 +236,16 @@ def static_chat() -> None: user.conversations.add_conversations_list(conversation) # エージェントの設定 - agent = Agent(llm=llm, user_info=user, is_streaming=False) + agent = Agent(llm=llm, user_info=user) # メッセージを送信 message = "私の名前と専攻は何ですか?" - resp = next(agent.invoke(message)) + resp = agent.invoke(message) print(resp) -def streaming_chat() -> None: +async def streaming_chat() -> None: # ユーザー情報 user_name = "hogehoge" user_major = "fugafuga専攻" @@ -235,15 +259,17 @@ def streaming_chat() -> None: user.conversations.add_conversations_list(conversation) # エージェントの設定 - agent = Agent(llm=llm, user_info=user, is_streaming=True) + agent = Agent(llm=llm, user_info=user) # メッセージを送信 message = "私の名前と専攻は何ですか?" - for resp in agent.invoke(message): + async for resp in agent.stream(message): print(resp) if __name__ == "__main__": + import asyncio + from sc_system_ai.logging_config import setup_logging setup_logging() @@ -257,7 +283,6 @@ def streaming_chat() -> None: ("human", "こんにちは!"), ("ai", "本日はどのようなご用件でしょうか?") ], - is_streaming=False, ) message = "私の名前と専攻は何ですか?" @@ -267,11 +292,12 @@ def streaming_chat() -> None: # pass # # 通常呼び出し - resp = next(chat.invoke(message=message, command="dummy")) - print(resp) + # resp = chat.invoke(message=message, command="dummy") + # print(resp) # ストリーミング呼び出し - # chat.is_streaming = True - # for r in chat.invoke(message=message, command="dummy"): - # print(r) - # chat.agent.get_response() + async def stream() -> None: + async for r in chat.stream(message="京都テックについて教えて"): + print(r) + asyncio.run(stream()) + diff --git a/src/sc_system_ai/template/agent.py b/src/sc_system_ai/template/agent.py index 30841f3..7a79ff3 100644 --- a/src/sc_system_ai/template/agent.py +++ b/src/sc_system_ai/template/agent.py @@ -7,16 +7,17 @@ """ import logging -from collections.abc import Iterator +from collections.abc import AsyncIterator from queue import Queue from threading import Thread -from typing import Any, TypedDict, TypeGuard +from typing import Any, Literal from langchain.agents import AgentExecutor, create_tool_calling_agent from langchain.tools import BaseTool from langchain_community.tools import DuckDuckGoSearchRun from langchain_core.messages import AIMessage, HumanMessage from langchain_openai import AzureChatOpenAI +from pydantic import BaseModel from sc_system_ai.agents.tools import magic_function, search_duckduckgo from sc_system_ai.template.ai_settings import llm @@ -39,39 +40,20 @@ class ToolManager: """ def __init__( self, + queue: Queue, tools: list | None = None, - is_streaming: bool = True, - queue: Queue | None = None, ): self.tools: list[BaseTool] = [] - self._is_streaming = is_streaming - self.queue = queue if queue is not None else Queue() - + self.queue = queue + self.handler = StreamingToolHandler(self.queue) if tools is not None: self.set_tools(tools) - @property - def is_streaming(self) -> bool: - return self._is_streaming - - @is_streaming.setter - def is_streaming(self, is_streaming: bool) -> None: - self._is_streaming = is_streaming - - if self._is_streaming: - self.tools = self.setup_streaming(self.tools) - else: - self.cancel_streaming() - - def setup_streaming(self, tools: list[BaseTool]) -> list[BaseTool]: + def setup_streaming(self) -> None: """ストリーミングのセットアップを行う関数""" - self.handler = StreamingToolHandler(self.queue) - - for tool in tools: + for tool in self.tools: tool.callbacks = [self.handler] - return tools - def cancel_streaming(self) -> None: """ストリーミングのセットアップを解除する関数""" for tool in self.tools: @@ -110,14 +92,22 @@ def _tool_checker(self, tool: Any) -> bool: tools: {tools} ------------------- """ - # Agentのレスポンスの型 -class AgentResponse(TypedDict, total=False): +class BaseAgentResponse(BaseModel): + """Agentのレスポンスの型""" + output: str | None = None + error: str | None = None + +class AgentResponse(BaseAgentResponse): """Agentのレスポンスの型""" - chat_history: list[HumanMessage | AIMessage] - messages: str - output: str - error: str + chat_history: list[HumanMessage | AIMessage] | None = None + messages: str | None = None + document_id: list[str] | None = None + + +class StreamingAgentResponse(BaseAgentResponse): + """Agentのストリーミングレスポンスの型""" + status: Literal["processing", "completed", "error"] | None = None # Agentクラスの作成 class Agent: @@ -134,65 +124,34 @@ def __init__( self, llm: AzureChatOpenAI = llm, user_info: User | None = None, - is_streaming: bool = True, - return_length: int = 5 ): self.llm = llm self.user_info = user_info if user_info is not None else User() self.result: AgentResponse - self._is_streaming = is_streaming - self._return_length = return_length self.queue: Queue = Queue() + self.handler = StreamingAgentHandler(self.queue) # assistant_infoとtoolsは各エージェントで設定する self.assistant_info = "" - self.tool = ToolManager(tools=template_tools, is_streaming=self._is_streaming, queue=self.queue) + self.tool = ToolManager(tools=template_tools, queue=self.queue) self.prompt_template = PromptTemplate(assistant_info=self.assistant_info, user_info=self.user_info) - if self._is_streaming: - self.setup_streaming() - self.get_agent_info() - @property - def is_streaming(self) -> bool: - return self._is_streaming - - @is_streaming.setter - def is_streaming(self, is_streaming: bool) -> None: - self._is_streaming = is_streaming - self.tool.is_streaming = is_streaming - - if self._is_streaming: - self.setup_streaming() - else: - self.cancel_streaming() - - @property - def return_length(self) -> int: - return self._return_length - - @return_length.setter - def return_length(self, return_length: int) -> None: - if return_length <= 0: - raise ValueError("return_lengthは1以上の整数である必要があります。") - self._return_length = return_length - def setup_streaming(self) -> None: """ストリーミング時のセットアップを行う関数""" self.clear_queue() - self.handler = StreamingAgentHandler(self.queue) - - # llmの設定 self.llm.streaming = True self.llm.callbacks = [self.handler] + self.tool.setup_streaming() def cancel_streaming(self) -> None: """ストリーミング時のセットアップを解除する関数""" self.llm.streaming = False self.llm.callbacks = None + self.tool.cancel_streaming() def clear_queue(self) -> None: """キューをクリアする関数""" @@ -208,7 +167,7 @@ def set_tools(self, tools: list) -> None: """ツールを設定する関数""" self.tool.set_tools(tools) - def invoke(self, message: str) -> Iterator[str | AgentResponse]: + def invoke(self, message: str) -> AgentResponse: """ エージェントを実行する関数 @@ -226,82 +185,119 @@ def invoke(self, message: str) -> Iterator[str | AgentResponse]: resp = next(agent.invoke("user message")) ``` """ + self.cancel_streaming() + self._invoke(message, False) + return self.get_response() + + async def stream( + self, + message: str, + return_length: int = 5 + ) -> AsyncIterator[StreamingAgentResponse]: + """ + エージェントをストリーミングで実行する関数 + + Args: + message (str): ユーザーからのメッセージ + + ```python + for output in agent.stream("user message"): + print(output) + ``` + """ + self.setup_streaming() + phrase = "" + thread = Thread(target=self._invoke, args=(message, True,)) + thread.start() + try: + while True: + if self.queue.empty(): + continue + + token = self.handler.queue.get() + if token is None: + logger.debug("エージェントの実行が終了しました。") + break + phrase += token + if len(phrase) >= return_length: + yield StreamingAgentResponse( + output=phrase, error=None, status="processing" + ) + phrase = "" + except Exception as e: + logger.error(f"エラーが発生しました:{e}") + yield StreamingAgentResponse( + output=None, error=f"エラーが発生しました:{e}", status="error" + ) + + if thread and thread.is_alive(): + thread.join() + yield StreamingAgentResponse(output=phrase, error=None, status="completed") + + def _invoke(self, message: str, streaming: bool) -> None: agent = create_tool_calling_agent( llm=self.llm, tools=self.tool.tools, prompt=self.prompt_template.full_prompt ) - self.agent_executor = AgentExecutor( + agent_executor = AgentExecutor( agent=agent, tools=self.tool.tools, - callbacks= [self.handler] if self._is_streaming else None + callbacks= [self.handler] if streaming else None ) - - if self._is_streaming: - yield from self._streaming_invoke(message) - else: - self._invoke(message) - if "error" in self.result: - yield self.result["error"] - else: - yield self.result - - def _invoke(self, message: str) -> None: try: # エージェントの実行 logger.info("エージェントの実行を開始します。\n-------------------\n") logger.debug(f"最終的なプロンプト: {self.prompt_template.full_prompt.messages}") - resp = self.agent_executor.invoke({ + resp = agent_executor.invoke({ "chat_history": self.user_info.conversations.format_conversation(), "messages": message, }) - if self._response_checker(resp): - self.result = resp + if "output" in resp: + self.result = AgentResponse( + chat_history=resp.get("chat_history"), + messages=resp.get("messages"), + output=resp.get("output"), + ) else: logger.error("エージェントの実行結果取得に失敗しました。") logger.debug(f"エージェントの実行結果: {resp}") raise RuntimeError("エージェントの実行結果取得に失敗しました。") except Exception as e: logger.error(f"エージェントの実行に失敗しました。エラー内容: {e}") - self.result = {"error": f"エージェントの実行に失敗しました。エラー内容: {e}"} - - def _response_checker(self, response: Any) -> TypeGuard[AgentResponse]: - """レスポンスの型チェック""" - if type(response) is dict: - if all(key in response for key in ["chat_history", "messages", "output"]): - return True - return False + self.result = AgentResponse(error=f"エージェントの実行に失敗しました。エラー内容: {e}") - def _streaming_invoke(self, message: str) -> Iterator[str]: - phrase = "" - thread = Thread(target=self._invoke, args=(message,)) - - thread.start() - try: - while True: - token = self.handler.queue.get() - if token is None: - break - - phrase += token - if len(phrase) >= self._return_length: - yield phrase - phrase = "" - except Exception as e: - logger.error(f"エラーが発生しました:{e}") - finally: - yield phrase + async def stream_on_tool(self, message: str) -> None: + """ツール上でストリーミングでエージェントを実行する関数""" + self.handler.queue = self.queue + self.setup_streaming() + agent = create_tool_calling_agent( + llm=self.llm, + tools=self.tool.tools, + prompt=self.prompt_template.full_prompt + ) + agent_executor = AgentExecutor( + agent=agent, + tools=self.tool.tools, + callbacks= [self.handler], + ) + resp = await agent_executor.ainvoke({ + "chat_history": self.user_info.conversations.format_conversation(), + "messages": message, + }) + self.result = AgentResponse( + chat_history=resp.get("chat_history"), + messages=resp.get("messages"), + output=resp.get("output"), + ) - #クリーンアップ - if thread and thread.is_alive(): - thread.join() def get_response(self) -> AgentResponse: """エージェントのレスポンスを取得する関数""" try: resp = self.result except AttributeError: - return {"error": "エージェントの実行結果がありません。"} + return AgentResponse(error="エージェントの実行結果がありません。") else: return resp @@ -330,6 +326,8 @@ def display_agent_prompt(self) -> None: if __name__ == "__main__": + import asyncio + from sc_system_ai.logging_config import setup_logging setup_logging() # ユーザー情報 @@ -346,15 +344,15 @@ def display_agent_prompt(self) -> None: agent = Agent( user_info=user_info, llm=llm, - is_streaming=False ) agent.assistant_info = "あなたは優秀な校正者です。" agent.tool.set_tools(tools) - result = next(agent.invoke("magic function に3")) - print(result) + # result = agent.invoke("magic function に3") + # print(result) + + async def main() -> None: + async for output in agent.stream("magic function に3", 5): + print(output.output) - agent.is_streaming = True - for output in agent.invoke("magic function に3"): - print(output) - print(agent.get_response()) + asyncio.run(main()) diff --git a/src/sc_system_ai/template/calling_agent.py b/src/sc_system_ai/template/calling_agent.py index fadce2d..cc5603d 100644 --- a/src/sc_system_ai/template/calling_agent.py +++ b/src/sc_system_ai/template/calling_agent.py @@ -1,10 +1,12 @@ +import asyncio import logging -from typing import Any, cast +from queue import Queue +from typing import cast from langchain_core.tools import BaseTool from pydantic import BaseModel, ConfigDict, Field -from sc_system_ai.template.agent import Agent +from sc_system_ai.template.agent import Agent, AgentResponse from sc_system_ai.template.user_prompts import User logger = logging.getLogger(__name__) @@ -50,6 +52,12 @@ def __init__(self): user_info: User = Field(description="ユーザー情報", default=User()) agent: type[Agent] = Agent + # AgentResponseを保持する変数 + response: AgentResponse | None = None + + # ストリーミングのセットアップ + queue: Queue = Queue() + is_streaming: bool = False def __init__(self) -> None: super().__init__() @@ -57,21 +65,29 @@ def __init__(self) -> None: def _run( self, user_input: str, - ) -> dict[str, Any]: + ) -> str: logger.info(f"Calling Agent Toolが次の値で呼び出されました: {user_input}") # エージェントの呼び出し try: - agent = self.agent(user_info=self.user_info, is_streaming=False) + agent = self.agent(user_info=self.user_info) + agent.queue = self.queue except Exception as e: logger.error(f"エージェントの呼び出しに失敗しました: {e}") raise e else: logger.debug(f"エージェントの呼び出しに成功しました: {self.agent}") - resp = next(agent.invoke(user_input)) + if self.is_streaming: + asyncio.run(agent.stream_on_tool(user_input)) + resp = agent.get_response() + else: + resp = agent.invoke(user_input) + self.response = resp + if resp.error is not None: + return resp.error + return cast(str, resp.output) - return cast(dict[str, Any], resp) def set_user_info(self, user_info: User) -> None: """ユーザー情報の設定 @@ -102,6 +118,16 @@ def set_tool_info( self.description = description self.agent = agent + def setup_streaming(self, queue: Queue) -> None: + """ストリーミングのセットアップ""" + self.is_streaming = True + self.queue = queue + + def cancel_streaming(self) -> None: + """ストリーミングのキャンセル""" + self.is_streaming = False + self.queue = Queue() + calling_agent = CallingAgent() diff --git a/studies/streaming_resp.py b/studies/streaming_resp.py index 14556db..1fe94e3 100644 --- a/studies/streaming_resp.py +++ b/studies/streaming_resp.py @@ -1,12 +1,14 @@ -from dotenv import load_dotenv +import asyncio from queue import Queue from threading import Thread +from typing import Any, AsyncIterator +from dotenv import load_dotenv +from langchain.agents import AgentExecutor, create_tool_calling_agent from langchain.callbacks.base import BaseCallbackHandler -from langchain.agents import create_tool_calling_agent, AgentExecutor -from sc_system_ai.template.ai_settings import llm from sc_system_ai.agents.tools.magic_function import magic_function +from sc_system_ai.template.ai_settings import llm from sc_system_ai.template.system_prompt import PromptTemplate from sc_system_ai.template.user_prompts import User @@ -16,21 +18,22 @@ class StreamingHandler(BaseCallbackHandler): def __init__(self, queue: Queue): super().__init__() self.queue = queue - - def on_llm_new_token(self, token, **kwargs): + + def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + print(token) if token: self.queue.put(token) - def on_agent_finish(self, response, **kwargs): + def on_agent_finish(self, finish, **kwargs: Any) -> None: print("end") self.queue.put(None) - def on_llm_error(self, error, **kwargs): + def on_llm_error(self, **kwargs: Any) -> None: self.queue.put(None) -def main(): - queue = Queue() +async def main() -> AsyncIterator[str]: + queue: Queue = Queue() handler = StreamingHandler(queue) llm.streaming=True @@ -46,9 +49,9 @@ def main(): prompt=prompt.full_prompt ) agent_executor = AgentExecutor(agent=agent, tools=tool, callbacks=[handler]) - def task(): + def task() -> None: result = agent_executor.invoke({ - "messages": "magic functionに3を入れて" + "messages": "悲しい歌を歌ってください" }) print(result) @@ -56,21 +59,23 @@ def task(): thread.start() resp = "" + while True: + if queue.empty(): + continue + token = queue.get() + if token is None: + break + resp += token + if len(resp) > 1: + yield resp + resp = "" + thread.join() + yield resp + +async def job() -> None: + async for i in main(): + print(i) + +if __name__ == "__main__": + asyncio.run(job()) - try: - while True: - token = queue.get() - if token is None: - break - - resp += token - if len(resp) >= 5: - yield resp - resp = "" - finally: - yield resp - if thread and thread.is_alive(): - thread.join() - -for output in main(): - print(output)