From bdb14f947fa29971fff42c7109c94f0456398228 Mon Sep 17 00:00:00 2001 From: haruki26 Date: Sun, 26 Jan 2025 10:21:38 +0000 Subject: [PATCH 01/31] =?UTF-8?q?requirements=E3=81=AE=E6=9B=B4=E6=96=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- requirements.txt | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/requirements.txt b/requirements.txt index f149a60..e387fee 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,7 +13,7 @@ click==8.1.7 ; python_version >= "3.10" and python_version < "4.0" colorama==0.4.6 ; python_version >= "3.10" and python_version < "4.0" and platform_system == "Windows" dataclasses-json==0.6.7 ; python_version >= "3.10" and python_version < "4.0" distro==1.9.0 ; python_version >= "3.10" and python_version < "4.0" -duckduckgo-search==6.3.7 ; python_version >= "3.10" and python_version < "4.0" +duckduckgo-search==7.2.1 ; python_version >= "3.10" and python_version < "4.0" exceptiongroup==1.2.2 ; python_version >= "3.10" and python_version < "3.11" frozenlist==1.4.1 ; python_version >= "3.10" and python_version < "4.0" greenlet==3.1.1 ; python_version < "3.13" and (platform_machine == "aarch64" or platform_machine == "ppc64le" or platform_machine == "x86_64" or platform_machine == "amd64" or platform_machine == "AMD64" or platform_machine == "win32" or platform_machine == "WIN32") and python_version >= "3.10" @@ -28,9 +28,10 @@ jsonpointer==3.0.0 ; python_version >= "3.10" and python_version < "4.0" langchain-community==0.3.13 ; python_version >= "3.10" and python_version < "4.0" langchain-core==0.3.28 ; python_version >= "3.10" and python_version < "4.0" langchain-openai==0.2.14 ; python_version >= "3.10" and python_version < "4.0" -langchain-text-splitters==0.3.5 ; python_version >= "3.10" and python_version < "4.0" +langchain-text-splitters==0.3.4 ; python_version >= "3.10" and python_version < "4.0" langchain==0.3.13 ; python_version >= "3.10" and python_version < "4.0" langsmith==0.1.147 ; python_version >= "3.10" and python_version < "4.0" +lxml==5.3.0 ; python_version >= "3.10" and python_version < "4.0" marshmallow==3.22.0 ; python_version >= "3.10" and python_version < "4.0" multidict==6.1.0 ; python_version >= "3.10" and python_version < "4.0" mypy-extensions==1.0.0 ; python_version >= "3.10" and python_version < "4.0" @@ -38,7 +39,7 @@ numpy==1.26.4 ; python_version >= "3.10" and python_version < "4.0" openai==1.58.1 ; python_version >= "3.10" and python_version < "4.0" orjson==3.10.7 ; python_version >= "3.10" and python_version < "4.0" and platform_python_implementation != "PyPy" packaging==24.1 ; python_version >= "3.10" and python_version < "4.0" -primp==0.8.1 ; python_version >= "3.10" and python_version < "4.0" +primp==0.10.0 ; python_version >= "3.10" and python_version < "4.0" propcache==0.2.0 ; python_version >= "3.10" and python_version < "4.0" pydantic-core==2.23.4 ; python_version >= "3.10" and python_version < "4.0" pydantic-settings==2.5.2 ; python_version >= "3.10" and python_version < "4.0" From c61cde193b72af1e18beaae22b57b549c11f0acd Mon Sep 17 00:00:00 2001 From: haruki26 Date: Sun, 26 Jan 2025 12:47:57 +0000 Subject: [PATCH 02/31] =?UTF-8?q?invoke=E3=83=A1=E3=82=BD=E3=83=83?= =?UTF-8?q?=E3=83=89=E3=81=AE=E5=BD=B9=E5=89=B2=E3=82=92=E5=90=8C=E6=9C=9F?= =?UTF-8?q?=E3=83=AC=E3=82=B9=E3=83=9D=E3=83=B3=E3=82=B9=E3=81=AE=E3=81=BF?= =?UTF-8?q?=E3=81=AB=E9=99=90=E5=AE=9A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/sc_system_ai/template/agent.py | 23 ++++++++--------------- 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/src/sc_system_ai/template/agent.py b/src/sc_system_ai/template/agent.py index 30841f3..fe175e5 100644 --- a/src/sc_system_ai/template/agent.py +++ b/src/sc_system_ai/template/agent.py @@ -208,7 +208,7 @@ def set_tools(self, tools: list) -> None: """ツールを設定する関数""" self.tool.set_tools(tools) - def invoke(self, message: str) -> Iterator[str | AgentResponse]: + def invoke(self, message: str) -> AgentResponse: """ エージェントを実行する関数 @@ -236,15 +236,8 @@ def invoke(self, message: str) -> Iterator[str | AgentResponse]: tools=self.tool.tools, callbacks= [self.handler] if self._is_streaming else None ) - - if self._is_streaming: - yield from self._streaming_invoke(message) - else: - self._invoke(message) - if "error" in self.result: - yield self.result["error"] - else: - yield self.result + self._invoke(message) + return self.get_response() def _invoke(self, message: str) -> None: try: # エージェントの実行 @@ -351,10 +344,10 @@ def display_agent_prompt(self) -> None: agent.assistant_info = "あなたは優秀な校正者です。" agent.tool.set_tools(tools) - result = next(agent.invoke("magic function に3")) + result = agent.invoke("magic function に3") print(result) - agent.is_streaming = True - for output in agent.invoke("magic function に3"): - print(output) - print(agent.get_response()) + # agent.is_streaming = True + # for output in agent.invoke("magic function に3"): + # print(output) + # print(agent.get_response()) From 33a21c4bfd070b2fc46aba1a3246528cfa4d1f4f Mon Sep 17 00:00:00 2001 From: haruki26 Date: Sun, 26 Jan 2025 14:01:37 +0000 Subject: [PATCH 03/31] =?UTF-8?q?=E9=9D=9E=E5=90=8C=E6=9C=9F=E3=82=B8?= =?UTF-8?q?=E3=82=A7=E3=83=8D=E3=83=AC=E3=83=BC=E3=82=BF=E3=81=A7=E3=82=B9?= =?UTF-8?q?=E3=83=88=E3=83=AA=E3=83=BC=E3=83=9F=E3=83=B3=E3=82=B0=E5=AE=9F?= =?UTF-8?q?=E8=A3=85?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 挙動が怪しいので別の環境で要検証 --- studies/streaming_resp.py | 61 +++++++++++++++++++++------------------ 1 file changed, 33 insertions(+), 28 deletions(-) diff --git a/studies/streaming_resp.py b/studies/streaming_resp.py index 14556db..1fe94e3 100644 --- a/studies/streaming_resp.py +++ b/studies/streaming_resp.py @@ -1,12 +1,14 @@ -from dotenv import load_dotenv +import asyncio from queue import Queue from threading import Thread +from typing import Any, AsyncIterator +from dotenv import load_dotenv +from langchain.agents import AgentExecutor, create_tool_calling_agent from langchain.callbacks.base import BaseCallbackHandler -from langchain.agents import create_tool_calling_agent, AgentExecutor -from sc_system_ai.template.ai_settings import llm from sc_system_ai.agents.tools.magic_function import magic_function +from sc_system_ai.template.ai_settings import llm from sc_system_ai.template.system_prompt import PromptTemplate from sc_system_ai.template.user_prompts import User @@ -16,21 +18,22 @@ class StreamingHandler(BaseCallbackHandler): def __init__(self, queue: Queue): super().__init__() self.queue = queue - - def on_llm_new_token(self, token, **kwargs): + + def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + print(token) if token: self.queue.put(token) - def on_agent_finish(self, response, **kwargs): + def on_agent_finish(self, finish, **kwargs: Any) -> None: print("end") self.queue.put(None) - def on_llm_error(self, error, **kwargs): + def on_llm_error(self, **kwargs: Any) -> None: self.queue.put(None) -def main(): - queue = Queue() +async def main() -> AsyncIterator[str]: + queue: Queue = Queue() handler = StreamingHandler(queue) llm.streaming=True @@ -46,9 +49,9 @@ def main(): prompt=prompt.full_prompt ) agent_executor = AgentExecutor(agent=agent, tools=tool, callbacks=[handler]) - def task(): + def task() -> None: result = agent_executor.invoke({ - "messages": "magic functionに3を入れて" + "messages": "悲しい歌を歌ってください" }) print(result) @@ -56,21 +59,23 @@ def task(): thread.start() resp = "" + while True: + if queue.empty(): + continue + token = queue.get() + if token is None: + break + resp += token + if len(resp) > 1: + yield resp + resp = "" + thread.join() + yield resp + +async def job() -> None: + async for i in main(): + print(i) + +if __name__ == "__main__": + asyncio.run(job()) - try: - while True: - token = queue.get() - if token is None: - break - - resp += token - if len(resp) >= 5: - yield resp - resp = "" - finally: - yield resp - if thread and thread.is_alive(): - thread.join() - -for output in main(): - print(output) From 650a3ba9e5cd0048086ac891c7b71be90acd183b Mon Sep 17 00:00:00 2001 From: haruki26 Date: Sun, 26 Jan 2025 14:34:18 +0000 Subject: [PATCH 04/31] =?UTF-8?q?Agent=E3=82=AF=E3=83=A9=E3=82=B9=E3=81=AB?= =?UTF-8?q?stream=E3=83=A1=E3=82=BD=E3=83=83=E3=83=89=E3=82=92=E4=BD=9C?= =?UTF-8?q?=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/sc_system_ai/template/agent.py | 92 ++++++++++++++++++------------ 1 file changed, 54 insertions(+), 38 deletions(-) diff --git a/src/sc_system_ai/template/agent.py b/src/sc_system_ai/template/agent.py index fe175e5..86c58d8 100644 --- a/src/sc_system_ai/template/agent.py +++ b/src/sc_system_ai/template/agent.py @@ -7,7 +7,7 @@ """ import logging -from collections.abc import Iterator +from collections.abc import AsyncIterator from queue import Queue from threading import Thread from typing import Any, TypedDict, TypeGuard @@ -226,24 +226,61 @@ def invoke(self, message: str) -> AgentResponse: resp = next(agent.invoke("user message")) ``` """ + self.cancel_streaming() + self._invoke(message) + return self.get_response() + + async def stream(self, message: str, return_length: int) -> AsyncIterator[AgentResponse]: + """ + エージェントをストリーミングで実行する関数 + + Args: + message (str): ユーザーからのメッセージ + + ```python + for output in agent.stream("user message"): + print(output) + ``` + """ + self.setup_streaming() + phrase = "" + thread = Thread(target=self._invoke, args=(message,)) + thread.start() + try: + while True: + if self.queue.empty(): + continue + + token = self.handler.queue.get() + if token is None: + logger.debug("エージェントの実行が終了しました。") + break + phrase += token + if len(phrase) >= return_length: + yield {"output": phrase} + phrase = "" + except Exception as e: + logger.error(f"エラーが発生しました:{e}") + + if thread and thread.is_alive(): + thread.join() + yield {"output": phrase} + + def _invoke(self, message: str) -> None: agent = create_tool_calling_agent( llm=self.llm, tools=self.tool.tools, prompt=self.prompt_template.full_prompt ) - self.agent_executor = AgentExecutor( + agent_executor = AgentExecutor( agent=agent, tools=self.tool.tools, callbacks= [self.handler] if self._is_streaming else None ) - self._invoke(message) - return self.get_response() - - def _invoke(self, message: str) -> None: try: # エージェントの実行 logger.info("エージェントの実行を開始します。\n-------------------\n") logger.debug(f"最終的なプロンプト: {self.prompt_template.full_prompt.messages}") - resp = self.agent_executor.invoke({ + resp = agent_executor.invoke({ "chat_history": self.user_info.conversations.format_conversation(), "messages": message, }) @@ -265,30 +302,6 @@ def _response_checker(self, response: Any) -> TypeGuard[AgentResponse]: return True return False - def _streaming_invoke(self, message: str) -> Iterator[str]: - phrase = "" - thread = Thread(target=self._invoke, args=(message,)) - - thread.start() - try: - while True: - token = self.handler.queue.get() - if token is None: - break - - phrase += token - if len(phrase) >= self._return_length: - yield phrase - phrase = "" - except Exception as e: - logger.error(f"エラーが発生しました:{e}") - finally: - yield phrase - - #クリーンアップ - if thread and thread.is_alive(): - thread.join() - def get_response(self) -> AgentResponse: """エージェントのレスポンスを取得する関数""" try: @@ -323,6 +336,8 @@ def display_agent_prompt(self) -> None: if __name__ == "__main__": + import asyncio + from sc_system_ai.logging_config import setup_logging setup_logging() # ユーザー情報 @@ -339,15 +354,16 @@ def display_agent_prompt(self) -> None: agent = Agent( user_info=user_info, llm=llm, - is_streaming=False + is_streaming=True, ) agent.assistant_info = "あなたは優秀な校正者です。" agent.tool.set_tools(tools) - result = agent.invoke("magic function に3") - print(result) + # result = agent.invoke("magic function に3") + # print(result) + + async def main() -> None: + async for output in agent.stream("magic function に3", 5): + print(output) - # agent.is_streaming = True - # for output in agent.invoke("magic function に3"): - # print(output) - # print(agent.get_response()) + asyncio.run(main()) From 05b4e26413f9f6965e1f633d274468890547685f Mon Sep 17 00:00:00 2001 From: haruki26 Date: Sun, 26 Jan 2025 14:43:44 +0000 Subject: [PATCH 05/31] =?UTF-8?q?Agent=E3=82=AF=E3=83=A9=E3=82=B9=E3=82=92?= =?UTF-8?q?=E3=82=A4=E3=83=B3=E3=82=B9=E3=82=BF=E3=83=B3=E3=82=B9=E5=8C=96?= =?UTF-8?q?=E6=99=82=E3=81=AB=E3=82=B9=E3=83=88=E3=83=AA=E3=83=BC=E3=83=9F?= =?UTF-8?q?=E3=83=B3=E3=82=B0=E3=82=92=E5=88=A4=E6=96=AD=E3=81=97=E3=81=AA?= =?UTF-8?q?=E3=81=84=E3=82=88=E3=81=86=E3=81=AB=E4=BF=AE=E6=AD=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/sc_system_ai/template/agent.py | 50 ++++++------------------------ 1 file changed, 10 insertions(+), 40 deletions(-) diff --git a/src/sc_system_ai/template/agent.py b/src/sc_system_ai/template/agent.py index 86c58d8..07b325b 100644 --- a/src/sc_system_ai/template/agent.py +++ b/src/sc_system_ai/template/agent.py @@ -66,10 +66,8 @@ def is_streaming(self, is_streaming: bool) -> None: def setup_streaming(self, tools: list[BaseTool]) -> list[BaseTool]: """ストリーミングのセットアップを行う関数""" self.handler = StreamingToolHandler(self.queue) - for tool in tools: tool.callbacks = [self.handler] - return tools def cancel_streaming(self) -> None: @@ -134,52 +132,21 @@ def __init__( self, llm: AzureChatOpenAI = llm, user_info: User | None = None, - is_streaming: bool = True, - return_length: int = 5 ): self.llm = llm self.user_info = user_info if user_info is not None else User() self.result: AgentResponse - self._is_streaming = is_streaming - self._return_length = return_length self.queue: Queue = Queue() # assistant_infoとtoolsは各エージェントで設定する self.assistant_info = "" - self.tool = ToolManager(tools=template_tools, is_streaming=self._is_streaming, queue=self.queue) + self.tool = ToolManager(tools=template_tools, is_streaming=True, queue=self.queue) self.prompt_template = PromptTemplate(assistant_info=self.assistant_info, user_info=self.user_info) - if self._is_streaming: - self.setup_streaming() - self.get_agent_info() - @property - def is_streaming(self) -> bool: - return self._is_streaming - - @is_streaming.setter - def is_streaming(self, is_streaming: bool) -> None: - self._is_streaming = is_streaming - self.tool.is_streaming = is_streaming - - if self._is_streaming: - self.setup_streaming() - else: - self.cancel_streaming() - - @property - def return_length(self) -> int: - return self._return_length - - @return_length.setter - def return_length(self, return_length: int) -> None: - if return_length <= 0: - raise ValueError("return_lengthは1以上の整数である必要があります。") - self._return_length = return_length - def setup_streaming(self) -> None: """ストリーミング時のセットアップを行う関数""" self.clear_queue() @@ -227,10 +194,14 @@ def invoke(self, message: str) -> AgentResponse: ``` """ self.cancel_streaming() - self._invoke(message) + self._invoke(message, False) return self.get_response() - async def stream(self, message: str, return_length: int) -> AsyncIterator[AgentResponse]: + async def stream( + self, + message: str, + return_length: int = 5 + ) -> AsyncIterator[AgentResponse]: """ エージェントをストリーミングで実行する関数 @@ -244,7 +215,7 @@ async def stream(self, message: str, return_length: int) -> AsyncIterator[AgentR """ self.setup_streaming() phrase = "" - thread = Thread(target=self._invoke, args=(message,)) + thread = Thread(target=self._invoke, args=(message, True,)) thread.start() try: while True: @@ -266,7 +237,7 @@ async def stream(self, message: str, return_length: int) -> AsyncIterator[AgentR thread.join() yield {"output": phrase} - def _invoke(self, message: str) -> None: + def _invoke(self, message: str, streaming: bool) -> None: agent = create_tool_calling_agent( llm=self.llm, tools=self.tool.tools, @@ -275,7 +246,7 @@ def _invoke(self, message: str) -> None: agent_executor = AgentExecutor( agent=agent, tools=self.tool.tools, - callbacks= [self.handler] if self._is_streaming else None + callbacks= [self.handler] if streaming else None ) try: # エージェントの実行 logger.info("エージェントの実行を開始します。\n-------------------\n") @@ -354,7 +325,6 @@ def display_agent_prompt(self) -> None: agent = Agent( user_info=user_info, llm=llm, - is_streaming=True, ) agent.assistant_info = "あなたは優秀な校正者です。" agent.tool.set_tools(tools) From a74b05a00da1f0d6f029fd972c1b8955acc3ba16 Mon Sep 17 00:00:00 2001 From: haruki26 Date: Sun, 26 Jan 2025 14:46:18 +0000 Subject: [PATCH 06/31] =?UTF-8?q?ToolManager=E3=82=82=E5=90=8C=E6=A7=98?= =?UTF-8?q?=E3=81=AB=E5=A4=89=E6=9B=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/sc_system_ai/template/agent.py | 22 +++------------------- 1 file changed, 3 insertions(+), 19 deletions(-) diff --git a/src/sc_system_ai/template/agent.py b/src/sc_system_ai/template/agent.py index 07b325b..2c6bd0a 100644 --- a/src/sc_system_ai/template/agent.py +++ b/src/sc_system_ai/template/agent.py @@ -39,30 +39,14 @@ class ToolManager: """ def __init__( self, + queue: Queue, tools: list | None = None, - is_streaming: bool = True, - queue: Queue | None = None, ): self.tools: list[BaseTool] = [] - self._is_streaming = is_streaming - self.queue = queue if queue is not None else Queue() - + self.queue = queue if tools is not None: self.set_tools(tools) - @property - def is_streaming(self) -> bool: - return self._is_streaming - - @is_streaming.setter - def is_streaming(self, is_streaming: bool) -> None: - self._is_streaming = is_streaming - - if self._is_streaming: - self.tools = self.setup_streaming(self.tools) - else: - self.cancel_streaming() - def setup_streaming(self, tools: list[BaseTool]) -> list[BaseTool]: """ストリーミングのセットアップを行う関数""" self.handler = StreamingToolHandler(self.queue) @@ -141,7 +125,7 @@ def __init__( # assistant_infoとtoolsは各エージェントで設定する self.assistant_info = "" - self.tool = ToolManager(tools=template_tools, is_streaming=True, queue=self.queue) + self.tool = ToolManager(tools=template_tools, queue=self.queue) self.prompt_template = PromptTemplate(assistant_info=self.assistant_info, user_info=self.user_info) From 85a081235040e003ebf74cb8ede960c3d6f11e2c Mon Sep 17 00:00:00 2001 From: haruki26 Date: Sun, 26 Jan 2025 14:50:28 +0000 Subject: [PATCH 07/31] =?UTF-8?q?tool=E3=81=AEstreaming=E3=82=BB=E3=83=83?= =?UTF-8?q?=E3=83=88=E3=82=A2=E3=83=83=E3=83=97=E3=81=AE=E6=8C=99=E5=8B=95?= =?UTF-8?q?=E3=82=92=E5=A4=89=E6=9B=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/sc_system_ai/template/agent.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/sc_system_ai/template/agent.py b/src/sc_system_ai/template/agent.py index 2c6bd0a..6fadec4 100644 --- a/src/sc_system_ai/template/agent.py +++ b/src/sc_system_ai/template/agent.py @@ -44,15 +44,14 @@ def __init__( ): self.tools: list[BaseTool] = [] self.queue = queue + self.handler = StreamingToolHandler(self.queue) if tools is not None: self.set_tools(tools) - def setup_streaming(self, tools: list[BaseTool]) -> list[BaseTool]: + def setup_streaming(self) -> None: """ストリーミングのセットアップを行う関数""" - self.handler = StreamingToolHandler(self.queue) - for tool in tools: + for tool in self.tools: tool.callbacks = [self.handler] - return tools def cancel_streaming(self) -> None: """ストリーミングのセットアップを解除する関数""" From 302f9f3aca32d313ac7ad963948caeab4434c1e6 Mon Sep 17 00:00:00 2001 From: haruki26 Date: Sun, 26 Jan 2025 14:54:38 +0000 Subject: [PATCH 08/31] =?UTF-8?q?agent=E3=82=BB=E3=83=83=E3=83=88=E3=82=A2?= =?UTF-8?q?=E3=83=83=E3=83=97=E3=81=A8=E5=90=8C=E6=99=82=E3=81=AB=E3=83=84?= =?UTF-8?q?=E3=83=BC=E3=83=AB=E3=82=82=E8=A1=8C=E3=81=86=E3=82=88=E3=81=86?= =?UTF-8?q?=E3=81=AB=E4=BF=AE=E6=AD=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/sc_system_ai/template/agent.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/sc_system_ai/template/agent.py b/src/sc_system_ai/template/agent.py index 6fadec4..44c3869 100644 --- a/src/sc_system_ai/template/agent.py +++ b/src/sc_system_ai/template/agent.py @@ -121,6 +121,7 @@ def __init__( self.result: AgentResponse self.queue: Queue = Queue() + self.handler = StreamingAgentHandler(self.queue) # assistant_infoとtoolsは各エージェントで設定する self.assistant_info = "" @@ -133,16 +134,15 @@ def __init__( def setup_streaming(self) -> None: """ストリーミング時のセットアップを行う関数""" self.clear_queue() - self.handler = StreamingAgentHandler(self.queue) - - # llmの設定 self.llm.streaming = True self.llm.callbacks = [self.handler] + self.tool.setup_streaming() def cancel_streaming(self) -> None: """ストリーミング時のセットアップを解除する関数""" self.llm.streaming = False self.llm.callbacks = None + self.tool.cancel_streaming() def clear_queue(self) -> None: """キューをクリアする関数""" From 0408704a2d3b1346bb37acecb4a2e718f5712025 Mon Sep 17 00:00:00 2001 From: haruki26 Date: Sun, 26 Jan 2025 14:58:05 +0000 Subject: [PATCH 09/31] =?UTF-8?q?calling=5Fagent=E3=81=A7Agent=E3=82=AF?= =?UTF-8?q?=E3=83=A9=E3=82=B9=E3=81=AE=E4=BF=AE=E6=AD=A3=E3=81=AB=E5=AF=BE?= =?UTF-8?q?=E5=BF=9C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/sc_system_ai/template/calling_agent.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sc_system_ai/template/calling_agent.py b/src/sc_system_ai/template/calling_agent.py index fadce2d..9d8d66f 100644 --- a/src/sc_system_ai/template/calling_agent.py +++ b/src/sc_system_ai/template/calling_agent.py @@ -62,14 +62,14 @@ def _run( # エージェントの呼び出し try: - agent = self.agent(user_info=self.user_info, is_streaming=False) + agent = self.agent(user_info=self.user_info) except Exception as e: logger.error(f"エージェントの呼び出しに失敗しました: {e}") raise e else: logger.debug(f"エージェントの呼び出しに成功しました: {self.agent}") - resp = next(agent.invoke(user_input)) + resp = agent.invoke(user_input) return cast(dict[str, Any], resp) From 90cbb2ea206fb81b7720de655acdaa1a3ad5d544 Mon Sep 17 00:00:00 2001 From: haruki26 Date: Sun, 26 Jan 2025 15:17:07 +0000 Subject: [PATCH 10/31] =?UTF-8?q?agent=E7=BE=A4=E3=82=92=E6=94=B9=E4=BF=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/sc_system_ai/agents/classify_agent.py | 26 ++++++++----------- src/sc_system_ai/agents/dummy_agent.py | 4 --- src/sc_system_ai/agents/main_agent.py | 8 ++---- .../agents/search_school_data_agent.py | 23 ++++++---------- 4 files changed, 21 insertions(+), 40 deletions(-) diff --git a/src/sc_system_ai/agents/classify_agent.py b/src/sc_system_ai/agents/classify_agent.py index a10d377..32a882e 100644 --- a/src/sc_system_ai/agents/classify_agent.py +++ b/src/sc_system_ai/agents/classify_agent.py @@ -1,5 +1,4 @@ -from collections.abc import Iterator -from typing import Any, cast +from typing import cast from langchain_openai import AzureChatOpenAI @@ -35,14 +34,10 @@ def __init__( self, llm: AzureChatOpenAI = llm, user_info: User | None = None, - is_streaming: bool = True, - return_length: int = 5 ): super().__init__( llm=llm, user_info=user_info if user_info is not None else User(), - is_streaming=is_streaming, - return_length=return_length ) self.assistant_info = classify_agent_info super().set_assistant_info(self.assistant_info) @@ -53,16 +48,17 @@ def set_tools(self, tools: list) -> None: for tool in tools: if isinstance(tool, CallingAgent): tool.set_user_info(self.user_info) - super().set_tools(tools) - def invoke(self, message: str) -> Iterator[str | AgentResponse | SearchSchoolDataAgentResponse]: - if self.is_streaming: - yield from super().invoke(message) + def invoke(self, message: str) -> AgentResponse | SearchSchoolDataAgentResponse: + # toolの出力を整形 + resp = super().invoke(message) + if type(resp["output"]) is str: + return resp else: - # ツールの出力をそのまま返却 - resp = cast(dict[str, Any], next(super().invoke(message))) - yield resp["output"] + return cast( + AgentResponse | SearchSchoolDataAgentResponse, resp["output"] + ) if __name__ == "__main__": @@ -79,7 +75,7 @@ def invoke(self, message: str) -> Iterator[str | AgentResponse | SearchSchoolDat user_info.conversations.add_conversations_list(history) while True: - classify_agent = ClassifyAgent(user_info=user_info, is_streaming=False) + classify_agent = ClassifyAgent(user_info=user_info) # classify_agent.display_agent_info() # print(main_agent.get_agent_prompt()) # classify_agent.display_agent_prompt() @@ -89,7 +85,7 @@ def invoke(self, message: str) -> Iterator[str | AgentResponse | SearchSchoolDat break # 通常の呼び出し - resp = next(classify_agent.invoke(user)) + resp = classify_agent.invoke(user) print(resp) # ストリーミング呼び出し diff --git a/src/sc_system_ai/agents/dummy_agent.py b/src/sc_system_ai/agents/dummy_agent.py index b864df3..b17b432 100644 --- a/src/sc_system_ai/agents/dummy_agent.py +++ b/src/sc_system_ai/agents/dummy_agent.py @@ -49,14 +49,10 @@ def __init__( self, llm: AzureChatOpenAI = llm, user_info: User | None = None, - is_streaming: bool = True, - return_length: int = 5 ): super().__init__( llm=llm, user_info=user_info if user_info is not None else User(), - is_streaming=is_streaming, - return_length=return_length ) self.assistant_info = dummy_agent_info super().set_assistant_info(self.assistant_info) diff --git a/src/sc_system_ai/agents/main_agent.py b/src/sc_system_ai/agents/main_agent.py index 92888e2..ab76f31 100644 --- a/src/sc_system_ai/agents/main_agent.py +++ b/src/sc_system_ai/agents/main_agent.py @@ -16,14 +16,10 @@ def __init__( self, llm: AzureChatOpenAI = llm, user_info: User | None = None, - is_streaming: bool = True, - return_length: int = 5 ): super().__init__( llm=llm, user_info=user_info if user_info is not None else User(), - is_streaming=is_streaming, - return_length=return_length ) self.assistant_info = main_agent_info super().set_assistant_info(self.assistant_info) @@ -43,9 +39,9 @@ def __init__( user_info = User(name=user_name, major=user_major) user_info.conversations.add_conversations_list(history) - main_agent = MainAgent(user_info=user_info, is_streaming=False) + main_agent = MainAgent(user_info=user_info) main_agent.display_agent_info() # print(main_agent.get_agent_prompt()) main_agent.display_agent_prompt() - print(next(main_agent.invoke("magic function に3をいれて"))) + print(main_agent.invoke("magic function に3をいれて")) diff --git a/src/sc_system_ai/agents/search_school_data_agent.py b/src/sc_system_ai/agents/search_school_data_agent.py index d780c31..a5ff424 100644 --- a/src/sc_system_ai/agents/search_school_data_agent.py +++ b/src/sc_system_ai/agents/search_school_data_agent.py @@ -1,6 +1,3 @@ -from collections.abc import Iterator -from typing import cast - from langchain_openai import AzureChatOpenAI # from sc_system_ai.agents.tools import magic_function @@ -29,18 +26,14 @@ def __init__( self, llm: AzureChatOpenAI = llm, user_info: User | None = None, - is_streaming: bool = True, - return_length: int = 5 ): super().__init__( llm=llm, user_info=user_info if user_info is not None else User(), - is_streaming=is_streaming, - return_length=return_length ) self.assistant_info = search_school_data_agent_info - def invoke(self, message: str) -> Iterator[SearchSchoolDataAgentResponse]: + def invoke(self, message: str) -> SearchSchoolDataAgentResponse: # Agentクラスのストリーミングを改修後にストリーミング実装 self.cancel_streaming() search = search_school_database_cosmos(message) @@ -50,11 +43,11 @@ def invoke(self, message: str) -> Iterator[SearchSchoolDataAgentResponse]: ids.append(doc.metadata["id"]) super().set_assistant_info(self.assistant_info) - resp = cast(AgentResponse, next(super().invoke(message))) - yield { - **resp, - "document_id": ids - } + resp = super().invoke(message) + return SearchSchoolDataAgentResponse( + document_id=ids, + **resp + ) if __name__ == "__main__": from sc_system_ai.logging_config import setup_logging @@ -68,5 +61,5 @@ def invoke(self, message: str) -> Iterator[SearchSchoolDataAgentResponse]: ] user_info = User(name=user_name, major=user_major) user_info.conversations.add_conversations_list(history) - agent = SearchSchoolDataAgent(user_info=user_info, is_streaming=False) - print(next(agent.invoke("京都テックについて教えて"))) + agent = SearchSchoolDataAgent(user_info=user_info) + print(agent.invoke("京都テックについて教えて")) From 9f57b5973e8587acec3e1b911481bf7dc107572f Mon Sep 17 00:00:00 2001 From: haruki26 Date: Mon, 27 Jan 2025 10:50:29 +0000 Subject: [PATCH 11/31] =?UTF-8?q?main=E3=82=92=E4=BF=AE=E6=AD=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/sc_system_ai/main.py | 59 ++++++++++++++++++++++++++-------------- 1 file changed, 38 insertions(+), 21 deletions(-) diff --git a/src/sc_system_ai/main.py b/src/sc_system_ai/main.py index 6af0a7c..68d869e 100644 --- a/src/sc_system_ai/main.py +++ b/src/sc_system_ai/main.py @@ -58,7 +58,7 @@ """ import logging -from collections.abc import Iterator +from collections.abc import AsyncIterator from importlib import import_module from typing import Literal, TypedDict, cast @@ -114,16 +114,12 @@ def __init__( user_name: str, user_major: str, conversation: list[tuple[str, str]] | None = None, - is_streaming: bool = True, - return_length: int = 5 ) -> None: self.user = User(name=user_name, major=user_major) if conversation is None: conversation = [] self.user.conversations.add_conversations_list(conversation) - self.is_streaming = is_streaming - self.return_length = return_length self._agent: Agent | None = None @property @@ -144,7 +140,7 @@ def invoke( self, message: str, command: AGENT = "classify" - ) -> Iterator[Response]: + ) -> Response: """エージェントを呼び出し、チャットを行う関数 Args: @@ -164,12 +160,36 @@ def invoke( - dummy: ダミーエージェント """ self._call_agent(command) - if self.is_streaming: - for resp in self.agent.invoke(message): - if type(resp) is str: - yield self._create_response({"output": resp}) - else: - resp = next(self.agent.invoke(message)) + return self._create_response(cast(dict, self.agent.invoke(message))) + + async def stream( + self, + message: str, + return_length: int = 5, + command: AGENT = "classify" + ) -> AsyncIterator[Response]: + """エージェントを呼び出し、ストリーミングチャットを行う関数 + + Args: + message (str): メッセージ + return_length (int, optional): ストリーミングモード時の返答数. デフォルトは5 + command (AGENT, optional): 呼び出すエージェント。デフォルトでは分類エージェントを呼び出します。 + + Returns: + Iterator[str]: エージェントからの返答 + + コマンドでエージェントを指定して、エージェントを呼び出す場合。 + ```python + for resp in chat.stream(message="私の名前と専攻は何ですか?", command="dummy"): + print(resp) + ``` + + 呼び出し可能なエージェント: + - classify: 分類エージェント + - dummy: ダミーエージェント + """ + self._call_agent(command) + async for resp in self.agent.stream(message, return_length): yield self._create_response(cast(dict, resp)) def _call_agent(self, command: AGENT) -> None: @@ -182,8 +202,6 @@ def _call_agent(self, command: AGENT) -> None: self.agent = agent_class( llm=llm, user_info=self.user, - is_streaming=self.is_streaming, - return_length=self.return_length ) except (ModuleNotFoundError, AttributeError, ValueError): logger.error(f"エージェントが見つかりません: {command}") @@ -212,16 +230,16 @@ def static_chat() -> None: user.conversations.add_conversations_list(conversation) # エージェントの設定 - agent = Agent(llm=llm, user_info=user, is_streaming=False) + agent = Agent(llm=llm, user_info=user) # メッセージを送信 message = "私の名前と専攻は何ですか?" - resp = next(agent.invoke(message)) + resp = agent.invoke(message) print(resp) -def streaming_chat() -> None: +async def streaming_chat() -> None: # ユーザー情報 user_name = "hogehoge" user_major = "fugafuga専攻" @@ -235,11 +253,11 @@ def streaming_chat() -> None: user.conversations.add_conversations_list(conversation) # エージェントの設定 - agent = Agent(llm=llm, user_info=user, is_streaming=True) + agent = Agent(llm=llm, user_info=user) # メッセージを送信 message = "私の名前と専攻は何ですか?" - for resp in agent.invoke(message): + async for resp in agent.stream(message): print(resp) @@ -257,7 +275,6 @@ def streaming_chat() -> None: ("human", "こんにちは!"), ("ai", "本日はどのようなご用件でしょうか?") ], - is_streaming=False, ) message = "私の名前と専攻は何ですか?" @@ -267,7 +284,7 @@ def streaming_chat() -> None: # pass # # 通常呼び出し - resp = next(chat.invoke(message=message, command="dummy")) + resp = chat.invoke(message=message, command="dummy") print(resp) # ストリーミング呼び出し From 213673ac31f5d35876a0cb0f259fabc894a0608b Mon Sep 17 00:00:00 2001 From: haruki26 Date: Mon, 27 Jan 2025 11:00:06 +0000 Subject: [PATCH 12/31] =?UTF-8?q?=E6=A4=9C=E7=B4=A2=E3=82=A8=E3=83=BC?= =?UTF-8?q?=E3=82=B8=E3=82=A7=E3=83=B3=E3=83=88=E3=81=A7=E5=A4=89=E6=9B=B4?= =?UTF-8?q?=E3=81=AB=E5=AF=BE=E5=BF=9C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../agents/search_school_data_agent.py | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/src/sc_system_ai/agents/search_school_data_agent.py b/src/sc_system_ai/agents/search_school_data_agent.py index a5ff424..5ec64cb 100644 --- a/src/sc_system_ai/agents/search_school_data_agent.py +++ b/src/sc_system_ai/agents/search_school_data_agent.py @@ -1,3 +1,5 @@ +from collections.abc import AsyncIterator + from langchain_openai import AzureChatOpenAI # from sc_system_ai.agents.tools import magic_function @@ -33,22 +35,34 @@ def __init__( ) self.assistant_info = search_school_data_agent_info - def invoke(self, message: str) -> SearchSchoolDataAgentResponse: - # Agentクラスのストリーミングを改修後にストリーミング実装 - self.cancel_streaming() + def _add_search_result(self, message: str) -> list[str]: search = search_school_database_cosmos(message) ids = [] for doc in search: self.assistant_info += f"### {doc.metadata['title']}\n" + doc.page_content + "\n" ids.append(doc.metadata["id"]) super().set_assistant_info(self.assistant_info) + return ids + def invoke(self, message: str) -> SearchSchoolDataAgentResponse: + # Agentクラスのストリーミングを改修後にストリーミング実装 + ids = self._add_search_result(message) resp = super().invoke(message) return SearchSchoolDataAgentResponse( document_id=ids, **resp ) + async def stream(self, message: str, return_length: int = 5) -> AsyncIterator[AgentResponse]: + ids = self._add_search_result(message) + async for resp in super().stream(message, return_length): + yield resp + result = self.get_response() + self.result = SearchSchoolDataAgentResponse( + document_id=ids, + **result + ) + if __name__ == "__main__": from sc_system_ai.logging_config import setup_logging setup_logging() From f8a496de7b15a94f517a0b7f5db052ae2e9811c9 Mon Sep 17 00:00:00 2001 From: haruki26 Date: Mon, 27 Jan 2025 11:05:17 +0000 Subject: [PATCH 13/31] =?UTF-8?q?=E5=90=8C=E6=9C=9F=E3=81=A8=E3=82=B9?= =?UTF-8?q?=E3=83=88=E3=83=AA=E3=83=BC=E3=83=9F=E3=83=B3=E3=82=B0=E6=AF=8E?= =?UTF-8?q?=E3=81=AB=E3=83=AC=E3=82=B9=E3=83=9D=E3=83=B3=E3=82=B9=E3=81=AE?= =?UTF-8?q?=E5=9E=8B=E3=82=92=E7=B5=B1=E4=B8=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/sc_system_ai/template/agent.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/sc_system_ai/template/agent.py b/src/sc_system_ai/template/agent.py index 44c3869..6aa90a8 100644 --- a/src/sc_system_ai/template/agent.py +++ b/src/sc_system_ai/template/agent.py @@ -93,12 +93,18 @@ def _tool_checker(self, tool: Any) -> bool: """ # Agentのレスポンスの型 -class AgentResponse(TypedDict, total=False): +class AgentResponse(TypedDict): """Agentのレスポンスの型""" - chat_history: list[HumanMessage | AIMessage] - messages: str - output: str - error: str + chat_history: list[HumanMessage | AIMessage] | None + messages: str | None + output: str | None + error: str | None + document_id: list[str] | None + +class StreamingAgentResponse(TypedDict): + """Agentのストリーミングレスポンスの型""" + output: str | None + error: str | None # Agentクラスの作成 class Agent: From e0f0e5483c7302b0220f94605ca447652de13e24 Mon Sep 17 00:00:00 2001 From: haruki26 Date: Mon, 27 Jan 2025 13:05:29 +0000 Subject: [PATCH 14/31] =?UTF-8?q?=E3=83=AC=E3=82=B9=E3=83=9D=E3=83=B3?= =?UTF-8?q?=E3=82=B9=E3=81=AE=E5=9E=8B=E3=82=92BaseModel=E3=81=A7=E5=86=8D?= =?UTF-8?q?=E5=AE=9A=E7=BE=A9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/sc_system_ai/template/agent.py | 36 +++++++++++++++++------------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/src/sc_system_ai/template/agent.py b/src/sc_system_ai/template/agent.py index 6aa90a8..bfe5e82 100644 --- a/src/sc_system_ai/template/agent.py +++ b/src/sc_system_ai/template/agent.py @@ -10,13 +10,14 @@ from collections.abc import AsyncIterator from queue import Queue from threading import Thread -from typing import Any, TypedDict, TypeGuard +from typing import Any, Literal, TypeGuard from langchain.agents import AgentExecutor, create_tool_calling_agent from langchain.tools import BaseTool from langchain_community.tools import DuckDuckGoSearchRun from langchain_core.messages import AIMessage, HumanMessage from langchain_openai import AzureChatOpenAI +from pydantic import BaseModel from sc_system_ai.agents.tools import magic_function, search_duckduckgo from sc_system_ai.template.ai_settings import llm @@ -91,20 +92,22 @@ def _tool_checker(self, tool: Any) -> bool: tools: {tools} ------------------- """ - # Agentのレスポンスの型 -class AgentResponse(TypedDict): +class BaseAgentResponse(BaseModel): + """Agentのレスポンスの型""" + output: str | None = None + error: str | None = None + +class AgentResponse(BaseAgentResponse): """Agentのレスポンスの型""" - chat_history: list[HumanMessage | AIMessage] | None - messages: str | None - output: str | None - error: str | None - document_id: list[str] | None + chat_history: list[HumanMessage | AIMessage] | None = None + messages: str | None = None + document_id: list[str] | None = None + -class StreamingAgentResponse(TypedDict): +class StreamingAgentResponse(BaseAgentResponse): """Agentのストリーミングレスポンスの型""" - output: str | None - error: str | None + status: Literal["processing", "completed", "error"] | None = None # Agentクラスの作成 class Agent: @@ -125,7 +128,7 @@ def __init__( self.llm = llm self.user_info = user_info if user_info is not None else User() - self.result: AgentResponse + self.result = AgentResponse() self.queue: Queue = Queue() self.handler = StreamingAgentHandler(self.queue) @@ -190,7 +193,7 @@ async def stream( self, message: str, return_length: int = 5 - ) -> AsyncIterator[AgentResponse]: + ) -> AsyncIterator[StreamingAgentResponse]: """ エージェントをストリーミングで実行する関数 @@ -217,14 +220,15 @@ async def stream( break phrase += token if len(phrase) >= return_length: - yield {"output": phrase} + yield StreamingAgentResponse(output=phrase, error=None) phrase = "" except Exception as e: logger.error(f"エラーが発生しました:{e}") + yield StreamingAgentResponse(output=None, error=f"エラーが発生しました:{e}") if thread and thread.is_alive(): thread.join() - yield {"output": phrase} + yield StreamingAgentResponse(output=phrase, error=None) def _invoke(self, message: str, streaming: bool) -> None: agent = create_tool_calling_agent( @@ -253,7 +257,7 @@ def _invoke(self, message: str, streaming: bool) -> None: raise RuntimeError("エージェントの実行結果取得に失敗しました。") except Exception as e: logger.error(f"エージェントの実行に失敗しました。エラー内容: {e}") - self.result = {"error": f"エージェントの実行に失敗しました。エラー内容: {e}"} + self.result["error"] = f"エージェントの実行に失敗しました。エラー内容: {e}" def _response_checker(self, response: Any) -> TypeGuard[AgentResponse]: """レスポンスの型チェック""" From ba2ce2b24ec0d9982edf310f580be07d36f4a237 Mon Sep 17 00:00:00 2001 From: haruki26 Date: Mon, 27 Jan 2025 13:13:44 +0000 Subject: [PATCH 15/31] =?UTF-8?q?=E3=83=AC=E3=82=B9=E3=83=9D=E3=83=B3?= =?UTF-8?q?=E3=82=B9=E3=81=AE=E5=9E=8B=E3=82=92=E9=81=A9=E7=94=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/sc_system_ai/template/agent.py | 35 +++++++++++++++--------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/src/sc_system_ai/template/agent.py b/src/sc_system_ai/template/agent.py index bfe5e82..e763dfe 100644 --- a/src/sc_system_ai/template/agent.py +++ b/src/sc_system_ai/template/agent.py @@ -10,7 +10,7 @@ from collections.abc import AsyncIterator from queue import Queue from threading import Thread -from typing import Any, Literal, TypeGuard +from typing import Any, Literal from langchain.agents import AgentExecutor, create_tool_calling_agent from langchain.tools import BaseTool @@ -128,7 +128,7 @@ def __init__( self.llm = llm self.user_info = user_info if user_info is not None else User() - self.result = AgentResponse() + self.result: AgentResponse self.queue: Queue = Queue() self.handler = StreamingAgentHandler(self.queue) @@ -220,15 +220,19 @@ async def stream( break phrase += token if len(phrase) >= return_length: - yield StreamingAgentResponse(output=phrase, error=None) + yield StreamingAgentResponse( + output=phrase, error=None, status="processing" + ) phrase = "" except Exception as e: logger.error(f"エラーが発生しました:{e}") - yield StreamingAgentResponse(output=None, error=f"エラーが発生しました:{e}") + yield StreamingAgentResponse( + output=None, error=f"エラーが発生しました:{e}", status="error" + ) if thread and thread.is_alive(): thread.join() - yield StreamingAgentResponse(output=phrase, error=None) + yield StreamingAgentResponse(output=phrase, error=None, status="completed") def _invoke(self, message: str, streaming: bool) -> None: agent = create_tool_calling_agent( @@ -249,29 +253,26 @@ def _invoke(self, message: str, streaming: bool) -> None: "messages": message, }) - if self._response_checker(resp): - self.result = resp + if "output" in resp: + self.result = AgentResponse( + chat_history=resp.get("chat_history"), + messages=resp.get("messages"), + output=resp.get("output"), + ) else: logger.error("エージェントの実行結果取得に失敗しました。") logger.debug(f"エージェントの実行結果: {resp}") raise RuntimeError("エージェントの実行結果取得に失敗しました。") except Exception as e: logger.error(f"エージェントの実行に失敗しました。エラー内容: {e}") - self.result["error"] = f"エージェントの実行に失敗しました。エラー内容: {e}" - - def _response_checker(self, response: Any) -> TypeGuard[AgentResponse]: - """レスポンスの型チェック""" - if type(response) is dict: - if all(key in response for key in ["chat_history", "messages", "output"]): - return True - return False + self.result = AgentResponse(error=f"エージェントの実行に失敗しました。エラー内容: {e}") def get_response(self) -> AgentResponse: """エージェントのレスポンスを取得する関数""" try: resp = self.result except AttributeError: - return {"error": "エージェントの実行結果がありません。"} + return AgentResponse(error="エージェントの実行結果がありません。") else: return resp @@ -327,6 +328,6 @@ def display_agent_prompt(self) -> None: async def main() -> None: async for output in agent.stream("magic function に3", 5): - print(output) + print(output.output) asyncio.run(main()) From 108c1ad33ea609264c7e79c7cb7532c8ef281bd7 Mon Sep 17 00:00:00 2001 From: haruki26 Date: Mon, 27 Jan 2025 14:21:18 +0000 Subject: [PATCH 16/31] =?UTF-8?q?calling=5Fagent=E3=81=8Cstr=E3=82=92?= =?UTF-8?q?=E8=BF=94=E5=8D=B4=E3=81=99=E3=82=8B=E3=82=88=E3=81=86=E3=81=AB?= =?UTF-8?q?=E4=BF=AE=E6=AD=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/sc_system_ai/template/calling_agent.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/sc_system_ai/template/calling_agent.py b/src/sc_system_ai/template/calling_agent.py index 9d8d66f..418735e 100644 --- a/src/sc_system_ai/template/calling_agent.py +++ b/src/sc_system_ai/template/calling_agent.py @@ -1,5 +1,5 @@ import logging -from typing import Any, cast +from typing import cast from langchain_core.tools import BaseTool from pydantic import BaseModel, ConfigDict, Field @@ -57,7 +57,7 @@ def __init__(self) -> None: def _run( self, user_input: str, - ) -> dict[str, Any]: + ) -> str: logger.info(f"Calling Agent Toolが次の値で呼び出されました: {user_input}") # エージェントの呼び出し @@ -70,8 +70,9 @@ def _run( logger.debug(f"エージェントの呼び出しに成功しました: {self.agent}") resp = agent.invoke(user_input) - - return cast(dict[str, Any], resp) + if resp.error is not None: + return resp.error + return cast(str, resp.output) def set_user_info(self, user_info: User) -> None: """ユーザー情報の設定 From 02bf6b1b8793e70601286e0b49a107d1e00899e2 Mon Sep 17 00:00:00 2001 From: haruki26 Date: Mon, 27 Jan 2025 14:36:44 +0000 Subject: [PATCH 17/31] =?UTF-8?q?=E5=91=BC=E3=81=B3=E5=87=BA=E3=81=97?= =?UTF-8?q?=E3=83=84=E3=83=BC=E3=83=AB=E3=81=A7=E3=83=AC=E3=82=B9=E3=83=9D?= =?UTF-8?q?=E3=83=B3=E3=82=B9=E3=82=92=E4=BF=9D=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/sc_system_ai/template/calling_agent.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/sc_system_ai/template/calling_agent.py b/src/sc_system_ai/template/calling_agent.py index 418735e..da91a4c 100644 --- a/src/sc_system_ai/template/calling_agent.py +++ b/src/sc_system_ai/template/calling_agent.py @@ -4,7 +4,7 @@ from langchain_core.tools import BaseTool from pydantic import BaseModel, ConfigDict, Field -from sc_system_ai.template.agent import Agent +from sc_system_ai.template.agent import Agent, AgentResponse from sc_system_ai.template.user_prompts import User logger = logging.getLogger(__name__) @@ -50,6 +50,8 @@ def __init__(self): user_info: User = Field(description="ユーザー情報", default=User()) agent: type[Agent] = Agent + # AgentResponseを保持する変数 + response: AgentResponse | None = None def __init__(self) -> None: super().__init__() @@ -70,6 +72,7 @@ def _run( logger.debug(f"エージェントの呼び出しに成功しました: {self.agent}") resp = agent.invoke(user_input) + self.response = resp if resp.error is not None: return resp.error return cast(str, resp.output) From 4ce05557dc32fa3bf5c67bf2da7bb6b5df73eedd Mon Sep 17 00:00:00 2001 From: haruki26 Date: Mon, 27 Jan 2025 14:44:24 +0000 Subject: [PATCH 18/31] =?UTF-8?q?=E6=A4=9C=E7=B4=A2=E3=82=A8=E3=83=BC?= =?UTF-8?q?=E3=82=B8=E3=82=A7=E3=83=B3=E3=83=88=E3=81=A7=E3=83=AC=E3=82=B9?= =?UTF-8?q?=E3=83=9D=E3=83=B3=E3=82=B9=E3=81=AE=E5=9E=8B=E3=82=92=E9=81=A9?= =?UTF-8?q?=E7=94=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../agents/search_school_data_agent.py | 21 ++++++------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/src/sc_system_ai/agents/search_school_data_agent.py b/src/sc_system_ai/agents/search_school_data_agent.py index 5ec64cb..e544439 100644 --- a/src/sc_system_ai/agents/search_school_data_agent.py +++ b/src/sc_system_ai/agents/search_school_data_agent.py @@ -4,7 +4,7 @@ # from sc_system_ai.agents.tools import magic_function from sc_system_ai.agents.tools.search_school_data import search_school_database_cosmos -from sc_system_ai.template.agent import Agent, AgentResponse +from sc_system_ai.template.agent import Agent, AgentResponse, StreamingAgentResponse from sc_system_ai.template.ai_settings import llm from sc_system_ai.template.user_prompts import User @@ -18,9 +18,6 @@ ## 学校の情報 """ -class SearchSchoolDataAgentResponse(AgentResponse): - document_id: list[str] - # agentクラスの作成 class SearchSchoolDataAgent(Agent): @@ -44,24 +41,18 @@ def _add_search_result(self, message: str) -> list[str]: super().set_assistant_info(self.assistant_info) return ids - def invoke(self, message: str) -> SearchSchoolDataAgentResponse: + def invoke(self, message: str) -> AgentResponse: # Agentクラスのストリーミングを改修後にストリーミング実装 ids = self._add_search_result(message) resp = super().invoke(message) - return SearchSchoolDataAgentResponse( - document_id=ids, - **resp - ) + resp.document_id = ids + return resp - async def stream(self, message: str, return_length: int = 5) -> AsyncIterator[AgentResponse]: + async def stream(self, message: str, return_length: int = 5) -> AsyncIterator[StreamingAgentResponse]: ids = self._add_search_result(message) async for resp in super().stream(message, return_length): yield resp - result = self.get_response() - self.result = SearchSchoolDataAgentResponse( - document_id=ids, - **result - ) + self.result.document_id = ids if __name__ == "__main__": from sc_system_ai.logging_config import setup_logging From fca62add95439c033e3a1a282be7f723017c76a0 Mon Sep 17 00:00:00 2001 From: haruki26 Date: Mon, 27 Jan 2025 14:44:56 +0000 Subject: [PATCH 19/31] =?UTF-8?q?document=5Fid=E3=82=92=E4=BF=9D=E6=8C=81?= =?UTF-8?q?=E3=81=99=E3=82=8B=E5=A4=89=E6=95=B0=E3=82=92=E8=BF=BD=E5=8A=A0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../agents/tools/calling_search_school_data_agent.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/sc_system_ai/agents/tools/calling_search_school_data_agent.py b/src/sc_system_ai/agents/tools/calling_search_school_data_agent.py index 2ce47f1..c4aaaf1 100644 --- a/src/sc_system_ai/agents/tools/calling_search_school_data_agent.py +++ b/src/sc_system_ai/agents/tools/calling_search_school_data_agent.py @@ -10,6 +10,9 @@ class CallingSearchSchoolDataAgent(CallingAgent): + # tool側でidを保持する + document_id: list[str] | None = None + def __init__(self) -> None: super().__init__() self.set_tool_info( @@ -18,6 +21,11 @@ def __init__(self) -> None: agent=SearchSchoolDataAgent ) + def _run(self, user_input: str) -> str: + resp = super()._run(user_input) + self.document_id = self.response.document_id if self.response is not None else None + return resp + calling_search_school_data_agent = CallingSearchSchoolDataAgent() if __name__ == "__main__": From 398d58321440b2ed043c56e6df7b23a24d551f87 Mon Sep 17 00:00:00 2001 From: haruki26 Date: Mon, 27 Jan 2025 14:45:32 +0000 Subject: [PATCH 20/31] =?UTF-8?q?=E3=83=84=E3=83=BC=E3=83=AB=E3=81=8B?= =?UTF-8?q?=E3=82=89=E7=9B=B4=E6=8E=A5id=E3=82=92=E5=8F=96=E5=BE=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/sc_system_ai/agents/classify_agent.py | 33 ++++++++++++++--------- 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/src/sc_system_ai/agents/classify_agent.py b/src/sc_system_ai/agents/classify_agent.py index 32a882e..8074c4d 100644 --- a/src/sc_system_ai/agents/classify_agent.py +++ b/src/sc_system_ai/agents/classify_agent.py @@ -3,9 +3,11 @@ from langchain_openai import AzureChatOpenAI # from sc_system_ai.agents.tools import magic_function -from sc_system_ai.agents.search_school_data_agent import SearchSchoolDataAgentResponse from sc_system_ai.agents.tools.calling_dummy_agent import calling_dummy_agent -from sc_system_ai.agents.tools.calling_search_school_data_agent import calling_search_school_data_agent +from sc_system_ai.agents.tools.calling_search_school_data_agent import ( + CallingSearchSchoolDataAgent, + calling_search_school_data_agent, +) from sc_system_ai.agents.tools.classify_role import classify_role from sc_system_ai.template.agent import Agent, AgentResponse from sc_system_ai.template.ai_settings import llm @@ -20,6 +22,7 @@ ] classify_agent_info = """あなたの役割は適切なエージェントを選択し処理を引き継ぐことです。 +あなたがユーザーと会話を行ってはいけません。 ユーザーの入力、会話の流れから適切なエージェントを選択してください。 引き継いだエージェントが処理を完了するまで、そのエージェントがユーザーと会話を続けるようにしてください。 @@ -50,15 +53,21 @@ def set_tools(self, tools: list) -> None: tool.set_user_info(self.user_info) super().set_tools(tools) - def invoke(self, message: str) -> AgentResponse | SearchSchoolDataAgentResponse: - # toolの出力を整形 + def invoke(self, message: str) -> AgentResponse: + # toolの出力がAgentReaponseで返って来るので整形 resp = super().invoke(message) - if type(resp["output"]) is str: - return resp - else: - return cast( - AgentResponse | SearchSchoolDataAgentResponse, resp["output"] - ) + resp.document_id = self._doc_id_checker() + return resp + + def _doc_id_checker(self) -> list[str] | None: + """ + ドキュメントIDが存在するか確認する + """ + for tool in self.tool.tools: + if isinstance(tool, CallingSearchSchoolDataAgent): + if tool.document_id is not None: + return tool.document_id + return None if __name__ == "__main__": @@ -93,9 +102,9 @@ def invoke(self, message: str) -> AgentResponse | SearchSchoolDataAgentResponse: # print(output) # resp = classify_agent.get_response() - if type(resp) is dict: + if type(resp) is AgentResponse: new_conversation = [ ("human", user), - ("ai", resp["output"]) + ("ai", cast(str,resp.output)) ] user_info.conversations.add_conversations_list(new_conversation) From c59c63afd23e7ff6e26ccdabfcc039e5a6e54315 Mon Sep 17 00:00:00 2001 From: haruki26 Date: Mon, 27 Jan 2025 14:49:12 +0000 Subject: [PATCH 21/31] =?UTF-8?q?=E3=83=87=E3=83=90=E3=83=83=E3=82=B0?= =?UTF-8?q?=E7=94=A8=E3=81=AE=E3=82=B3=E3=83=BC=E3=83=89=E3=82=92=E4=BF=AE?= =?UTF-8?q?=E6=AD=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/sc_system_ai/agents/dummy_agent.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/sc_system_ai/agents/dummy_agent.py b/src/sc_system_ai/agents/dummy_agent.py index b17b432..93c7363 100644 --- a/src/sc_system_ai/agents/dummy_agent.py +++ b/src/sc_system_ai/agents/dummy_agent.py @@ -1,4 +1,5 @@ # ダミーのエージェント +from typing import cast from langchain_openai import AzureChatOpenAI @@ -92,9 +93,9 @@ def __init__( print(output) resp = dummy_agent.get_response() - if type(resp) is dict: + if resp.error is not None: new_conversation = [ ("human", user), - ("ai", resp["output"]) + ("ai", cast(str, resp.output)) ] user_info.conversations.add_conversations_list(new_conversation) From 270016147d330692c910c7dfa8a51d10b97a6906 Mon Sep 17 00:00:00 2001 From: haruki26 Date: Mon, 27 Jan 2025 15:05:45 +0000 Subject: [PATCH 22/31] =?UTF-8?q?=E9=9B=91=E8=AB=87=E3=82=A8=E3=83=BC?= =?UTF-8?q?=E3=82=B8=E3=82=A7=E3=83=B3=E3=83=88=E3=81=AE=E4=BD=9C=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/sc_system_ai/agents/small_talk_agent.py | 47 +++++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 src/sc_system_ai/agents/small_talk_agent.py diff --git a/src/sc_system_ai/agents/small_talk_agent.py b/src/sc_system_ai/agents/small_talk_agent.py new file mode 100644 index 0000000..47541fa --- /dev/null +++ b/src/sc_system_ai/agents/small_talk_agent.py @@ -0,0 +1,47 @@ +from langchain_openai import AzureChatOpenAI + +from sc_system_ai.agents.tools import magic_function +from sc_system_ai.template.agent import Agent +from sc_system_ai.template.ai_settings import llm +from sc_system_ai.template.user_prompts import User + +main_agent_tools = [magic_function] +main_agent_info = "あなたの役割はユーザーと雑談を行うことです。" + +# agentクラスの作成 + + +class SmallTalkAgent(Agent): + def __init__( + self, + llm: AzureChatOpenAI = llm, + user_info: User | None = None, + ): + super().__init__( + llm=llm, + user_info=user_info if user_info is not None else User(), + ) + self.assistant_info = main_agent_info + super().set_assistant_info(self.assistant_info) + super().set_tools(main_agent_tools) + + +if __name__ == "__main__": + from sc_system_ai.logging_config import setup_logging + setup_logging() + # ユーザー情報 + user_name = "hogehoge" + user_major = "fugafuga専攻" + history = [ + ("human", "こんにちは!"), + ("ai", "本日はどのようなご用件でしょうか?") + ] + user_info = User(name=user_name, major=user_major) + user_info.conversations.add_conversations_list(history) + + agent = SmallTalkAgent(user_info=user_info) + agent.display_agent_info() + # print(main_agent.get_agent_prompt()) + agent.display_agent_prompt() + print(agent.invoke("magic function に3をいれて")) + From dec228126731ee3f421381af7466ba3ce27bbbd3 Mon Sep 17 00:00:00 2001 From: haruki26 Date: Mon, 27 Jan 2025 15:07:49 +0000 Subject: [PATCH 23/31] =?UTF-8?q?=E3=83=97=E3=83=AD=E3=83=B3=E3=83=97?= =?UTF-8?q?=E3=83=88=E3=81=AE=E4=BF=AE=E6=AD=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/sc_system_ai/agents/small_talk_agent.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/sc_system_ai/agents/small_talk_agent.py b/src/sc_system_ai/agents/small_talk_agent.py index 47541fa..38b5897 100644 --- a/src/sc_system_ai/agents/small_talk_agent.py +++ b/src/sc_system_ai/agents/small_talk_agent.py @@ -6,7 +6,9 @@ from sc_system_ai.template.user_prompts import User main_agent_tools = [magic_function] -main_agent_info = "あなたの役割はユーザーと雑談を行うことです。" +main_agent_info = """あなたの役割はユーザーと雑談を行うことです。 +ユーザーが楽しめるような会話になるようにしてください。 +""" # agentクラスの作成 From 1d5a30e13f5b075ded9d3c1fe96992911e3a5be6 Mon Sep 17 00:00:00 2001 From: haruki26 Date: Mon, 27 Jan 2025 15:10:34 +0000 Subject: [PATCH 24/31] =?UTF-8?q?=E9=9B=91=E8=AB=87=E3=82=A8=E3=83=BC?= =?UTF-8?q?=E3=82=B8=E3=82=A7=E3=83=B3=E3=83=88=E3=81=AE=E5=91=BC=E3=81=B3?= =?UTF-8?q?=E5=87=BA=E3=81=97=E3=83=84=E3=83=BC=E3=83=AB=E3=82=92=E4=BD=9C?= =?UTF-8?q?=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../agents/tools/calling_small_talk_agent.py | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 src/sc_system_ai/agents/tools/calling_small_talk_agent.py diff --git a/src/sc_system_ai/agents/tools/calling_small_talk_agent.py b/src/sc_system_ai/agents/tools/calling_small_talk_agent.py new file mode 100644 index 0000000..888eef4 --- /dev/null +++ b/src/sc_system_ai/agents/tools/calling_small_talk_agent.py @@ -0,0 +1,26 @@ +import logging + +from sc_system_ai.agents.small_talk_agent import SmallTalkAgent +from sc_system_ai.template.calling_agent import CallingAgent +from sc_system_ai.template.user_prompts import User + +logger = logging.getLogger(__name__) + + +class CallingSmallTalkAgent(CallingAgent): + def __init__(self) -> None: + super().__init__() + self.set_tool_info( + name="calling_small_talk_agent", + description="雑談エージェントを呼び出すツール", + agent=SmallTalkAgent, + ) + +calling_small_talk_agent = CallingSmallTalkAgent() + +if __name__ == "__main__": + from sc_system_ai.logging_config import setup_logging + setup_logging() + + calling_small_talk_agent.set_user_info(User(name="hogehoge", major="fugafuga専攻")) + print(calling_small_talk_agent.invoke({"user_input": "こんにちは"})) From 37c3b72a48d56228006c82f35880eae3e9e1739e Mon Sep 17 00:00:00 2001 From: haruki26 Date: Mon, 27 Jan 2025 15:12:44 +0000 Subject: [PATCH 25/31] =?UTF-8?q?=E5=88=86=E9=A1=9E=E3=82=A8=E3=83=BC?= =?UTF-8?q?=E3=82=B8=E3=82=A7=E3=83=B3=E3=83=88=E3=81=AE=E5=91=BC=E3=81=B3?= =?UTF-8?q?=E5=87=BA=E3=81=97=E3=83=84=E3=83=BC=E3=83=AB=E3=81=AB=E8=BF=BD?= =?UTF-8?q?=E5=8A=A0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/sc_system_ai/agents/classify_agent.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/sc_system_ai/agents/classify_agent.py b/src/sc_system_ai/agents/classify_agent.py index 8074c4d..46045c8 100644 --- a/src/sc_system_ai/agents/classify_agent.py +++ b/src/sc_system_ai/agents/classify_agent.py @@ -8,6 +8,7 @@ CallingSearchSchoolDataAgent, calling_search_school_data_agent, ) +from sc_system_ai.agents.tools.calling_small_talk_agent import calling_small_talk_agent from sc_system_ai.agents.tools.classify_role import classify_role from sc_system_ai.template.agent import Agent, AgentResponse from sc_system_ai.template.ai_settings import llm @@ -18,7 +19,8 @@ # magic_function, classify_role, calling_dummy_agent, - calling_search_school_data_agent + calling_search_school_data_agent, + calling_small_talk_agent, ] classify_agent_info = """あなたの役割は適切なエージェントを選択し処理を引き継ぐことです。 From 55648f68359d6f7f17746660d910889d3f98eb47 Mon Sep 17 00:00:00 2001 From: haruki26 Date: Mon, 27 Jan 2025 15:14:21 +0000 Subject: [PATCH 26/31] =?UTF-8?q?=E3=82=A8=E3=83=B3=E3=83=88=E3=83=AA?= =?UTF-8?q?=E3=83=BC=E3=83=9D=E3=82=A4=E3=83=B3=E3=83=88=E3=81=AB=E8=BF=BD?= =?UTF-8?q?=E5=8A=A0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/sc_system_ai/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sc_system_ai/main.py b/src/sc_system_ai/main.py index 68d869e..a48c0b0 100644 --- a/src/sc_system_ai/main.py +++ b/src/sc_system_ai/main.py @@ -68,7 +68,7 @@ logger = logging.getLogger(__name__) -AGENT = Literal["classify", "dummy", "search_school_data"] +AGENT = Literal["classify", "dummy", "search_school_data", "small_talk"] class Response(TypedDict): output: str | None From e7761403ee4ff80ed9620f9a2fbdc70dd41bb9dc Mon Sep 17 00:00:00 2001 From: haruki26 Date: Mon, 27 Jan 2025 15:22:50 +0000 Subject: [PATCH 27/31] =?UTF-8?q?chat=E3=82=AF=E3=83=A9=E3=82=B9=E3=81=A7A?= =?UTF-8?q?gent=E3=82=AF=E3=83=A9=E3=82=B9=E3=81=AE=E3=83=AC=E3=82=B9?= =?UTF-8?q?=E3=83=9D=E3=83=B3=E3=82=B9=E5=A4=89=E6=9B=B4=E3=81=AB=E5=AF=BE?= =?UTF-8?q?=E5=BF=9C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/sc_system_ai/main.py | 45 ++++++++++++++++++++++++---------------- 1 file changed, 27 insertions(+), 18 deletions(-) diff --git a/src/sc_system_ai/main.py b/src/sc_system_ai/main.py index a48c0b0..c83447d 100644 --- a/src/sc_system_ai/main.py +++ b/src/sc_system_ai/main.py @@ -60,7 +60,7 @@ import logging from collections.abc import AsyncIterator from importlib import import_module -from typing import Literal, TypedDict, cast +from typing import Literal, TypedDict from sc_system_ai.template.agent import Agent from sc_system_ai.template.ai_settings import llm @@ -75,6 +75,11 @@ class Response(TypedDict): error: str | None document_id: list[str] | None +class StreamResponse(TypedDict): + output: str | None + error: str | None + status: str | None + class Chat: """Chatクラス ユーザー情報と会話履歴を保持し、エージェントとのチャットを行うクラス @@ -160,14 +165,19 @@ def invoke( - dummy: ダミーエージェント """ self._call_agent(command) - return self._create_response(cast(dict, self.agent.invoke(message))) + resp = self.agent.invoke(message) + return { + "output": resp.output, + "error": resp.error, + "document_id": resp.document_id + } async def stream( self, message: str, return_length: int = 5, command: AGENT = "classify" - ) -> AsyncIterator[Response]: + ) -> AsyncIterator[StreamResponse]: """エージェントを呼び出し、ストリーミングチャットを行う関数 Args: @@ -190,7 +200,11 @@ async def stream( """ self._call_agent(command) async for resp in self.agent.stream(message, return_length): - yield self._create_response(cast(dict, resp)) + yield { + "output": resp.output, + "error": resp.error, + "status": resp.status + } def _call_agent(self, command: AGENT) -> None: try: @@ -207,14 +221,6 @@ def _call_agent(self, command: AGENT) -> None: logger.error(f"エージェントが見つかりません: {command}") raise ValueError(f"エージェントが見つかりません: {command}") from None - def _create_response(self, resp: dict) -> Response: - return { - "output": resp.get("output"), - "error": resp.get("error"), - "document_id": resp.get("document_id") - } - - def static_chat() -> None: # ユーザー情報 @@ -262,6 +268,8 @@ async def streaming_chat() -> None: if __name__ == "__main__": + import asyncio + from sc_system_ai.logging_config import setup_logging setup_logging() @@ -284,11 +292,12 @@ async def streaming_chat() -> None: # pass # # 通常呼び出し - resp = chat.invoke(message=message, command="dummy") - print(resp) + # resp = chat.invoke(message=message, command="dummy") + # print(resp) # ストリーミング呼び出し - # chat.is_streaming = True - # for r in chat.invoke(message=message, command="dummy"): - # print(r) - # chat.agent.get_response() + async def stream() -> None: + async for r in chat.stream(message=message, command="dummy"): + print(r) + asyncio.run(stream()) + From 2d2d9ce1ee07baba1231963966dea624599429ed Mon Sep 17 00:00:00 2001 From: haruki26 Date: Tue, 28 Jan 2025 13:11:10 +0000 Subject: [PATCH 28/31] =?UTF-8?q?=E3=83=84=E3=83=BC=E3=83=AB=E4=B8=8A?= =?UTF-8?q?=E3=81=A7=E3=82=B9=E3=83=88=E3=83=AA=E3=83=BC=E3=83=9F=E3=83=B3?= =?UTF-8?q?=E3=82=B0=E3=82=92=E8=A1=8C=E3=81=86=E3=83=A1=E3=82=BD=E3=83=83?= =?UTF-8?q?=E3=83=89=E3=82=92=E4=BD=9C=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/sc_system_ai/template/agent.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/src/sc_system_ai/template/agent.py b/src/sc_system_ai/template/agent.py index e763dfe..7a79ff3 100644 --- a/src/sc_system_ai/template/agent.py +++ b/src/sc_system_ai/template/agent.py @@ -267,6 +267,31 @@ def _invoke(self, message: str, streaming: bool) -> None: logger.error(f"エージェントの実行に失敗しました。エラー内容: {e}") self.result = AgentResponse(error=f"エージェントの実行に失敗しました。エラー内容: {e}") + async def stream_on_tool(self, message: str) -> None: + """ツール上でストリーミングでエージェントを実行する関数""" + self.handler.queue = self.queue + self.setup_streaming() + agent = create_tool_calling_agent( + llm=self.llm, + tools=self.tool.tools, + prompt=self.prompt_template.full_prompt + ) + agent_executor = AgentExecutor( + agent=agent, + tools=self.tool.tools, + callbacks= [self.handler], + ) + resp = await agent_executor.ainvoke({ + "chat_history": self.user_info.conversations.format_conversation(), + "messages": message, + }) + self.result = AgentResponse( + chat_history=resp.get("chat_history"), + messages=resp.get("messages"), + output=resp.get("output"), + ) + + def get_response(self) -> AgentResponse: """エージェントのレスポンスを取得する関数""" try: From fba162aed3fcccdc8bd5c715adba5237163baa2e Mon Sep 17 00:00:00 2001 From: haruki26 Date: Tue, 28 Jan 2025 13:21:55 +0000 Subject: [PATCH 29/31] =?UTF-8?q?=E5=91=BC=E3=81=B3=E5=87=BA=E3=81=97?= =?UTF-8?q?=E3=83=84=E3=83=BC=E3=83=AB=E3=81=A7=E3=82=B9=E3=83=88=E3=83=AA?= =?UTF-8?q?=E3=83=BC=E3=83=9F=E3=83=B3=E3=82=B0=E3=82=92=E5=8F=AF=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/sc_system_ai/template/calling_agent.py | 30 +++++++++++++++++++--- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/src/sc_system_ai/template/calling_agent.py b/src/sc_system_ai/template/calling_agent.py index da91a4c..cc5603d 100644 --- a/src/sc_system_ai/template/calling_agent.py +++ b/src/sc_system_ai/template/calling_agent.py @@ -1,4 +1,6 @@ +import asyncio import logging +from queue import Queue from typing import cast from langchain_core.tools import BaseTool @@ -53,6 +55,10 @@ def __init__(self): # AgentResponseを保持する変数 response: AgentResponse | None = None + # ストリーミングのセットアップ + queue: Queue = Queue() + is_streaming: bool = False + def __init__(self) -> None: super().__init__() @@ -65,18 +71,24 @@ def _run( # エージェントの呼び出し try: agent = self.agent(user_info=self.user_info) + agent.queue = self.queue except Exception as e: logger.error(f"エージェントの呼び出しに失敗しました: {e}") raise e else: logger.debug(f"エージェントの呼び出しに成功しました: {self.agent}") - resp = agent.invoke(user_input) - self.response = resp - if resp.error is not None: - return resp.error + if self.is_streaming: + asyncio.run(agent.stream_on_tool(user_input)) + resp = agent.get_response() + else: + resp = agent.invoke(user_input) + self.response = resp + if resp.error is not None: + return resp.error return cast(str, resp.output) + def set_user_info(self, user_info: User) -> None: """ユーザー情報の設定 @@ -106,6 +118,16 @@ def set_tool_info( self.description = description self.agent = agent + def setup_streaming(self, queue: Queue) -> None: + """ストリーミングのセットアップ""" + self.is_streaming = True + self.queue = queue + + def cancel_streaming(self) -> None: + """ストリーミングのキャンセル""" + self.is_streaming = False + self.queue = Queue() + calling_agent = CallingAgent() From 13e52db1f86a2c11dcb88632b30f08d237e2053a Mon Sep 17 00:00:00 2001 From: haruki26 Date: Tue, 28 Jan 2025 13:23:03 +0000 Subject: [PATCH 30/31] =?UTF-8?q?=E5=88=86=E9=A1=9E=E3=82=A8=E3=83=BC?= =?UTF-8?q?=E3=82=B8=E3=82=A7=E3=83=B3=E3=83=88=E3=81=A7=E3=83=84=E3=83=BC?= =?UTF-8?q?=E3=83=AB=E3=82=B9=E3=83=88=E3=83=AA=E3=83=BC=E3=83=9F=E3=83=B3?= =?UTF-8?q?=E3=82=B0=E3=81=AE=E8=A8=AD=E5=AE=9A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/sc_system_ai/agents/classify_agent.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/sc_system_ai/agents/classify_agent.py b/src/sc_system_ai/agents/classify_agent.py index 46045c8..f03aac5 100644 --- a/src/sc_system_ai/agents/classify_agent.py +++ b/src/sc_system_ai/agents/classify_agent.py @@ -1,3 +1,4 @@ +from collections.abc import AsyncIterator from typing import cast from langchain_openai import AzureChatOpenAI @@ -10,7 +11,7 @@ ) from sc_system_ai.agents.tools.calling_small_talk_agent import calling_small_talk_agent from sc_system_ai.agents.tools.classify_role import classify_role -from sc_system_ai.template.agent import Agent, AgentResponse +from sc_system_ai.template.agent import Agent, AgentResponse, StreamingAgentResponse from sc_system_ai.template.ai_settings import llm from sc_system_ai.template.calling_agent import CallingAgent from sc_system_ai.template.user_prompts import User @@ -27,8 +28,6 @@ あなたがユーザーと会話を行ってはいけません。 ユーザーの入力、会話の流れから適切なエージェントを選択してください。 引き継いだエージェントが処理を完了するまで、そのエージェントがユーザーと会話を続けるようにしてください。 - -適切なエージェントの選択、呼び出しができなかった場合は、そのままユーザーとの会話を続けてください。 """ # agentクラスの作成 @@ -57,6 +56,9 @@ def set_tools(self, tools: list) -> None: def invoke(self, message: str) -> AgentResponse: # toolの出力がAgentReaponseで返って来るので整形 + for tool in self.tool.tools: + if isinstance(tool, CallingAgent): + tool.cancel_streaming() resp = super().invoke(message) resp.document_id = self._doc_id_checker() return resp @@ -71,6 +73,13 @@ def _doc_id_checker(self) -> list[str] | None: return tool.document_id return None + async def stream(self, message: str, return_length: int = 5) -> AsyncIterator[StreamingAgentResponse]: + for tool in self.tool.tools: + if isinstance(tool, CallingAgent): + tool.setup_streaming(self.queue) + async for output in super().stream(message, return_length): + yield output + if __name__ == "__main__": from sc_system_ai.logging_config import setup_logging From 02043b10ba751cdc96f28b4fa337b82cbf94be1e Mon Sep 17 00:00:00 2001 From: haruki26 Date: Tue, 28 Jan 2025 13:23:45 +0000 Subject: [PATCH 31/31] =?UTF-8?q?=E3=82=B9=E3=83=88=E3=83=AA=E3=83=BC?= =?UTF-8?q?=E3=83=9F=E3=83=B3=E3=82=B0=E3=81=AE=E3=83=86=E3=82=B9=E3=83=88?= =?UTF-8?q?=E3=82=B3=E3=83=BC=E3=83=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/sc_system_ai/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sc_system_ai/main.py b/src/sc_system_ai/main.py index c83447d..aa0deb8 100644 --- a/src/sc_system_ai/main.py +++ b/src/sc_system_ai/main.py @@ -297,7 +297,7 @@ async def streaming_chat() -> None: # ストリーミング呼び出し async def stream() -> None: - async for r in chat.stream(message=message, command="dummy"): + async for r in chat.stream(message="京都テックについて教えて"): print(r) asyncio.run(stream())