From 4973cab60c6f772a47b2d00c44fac1380880e255 Mon Sep 17 00:00:00 2001 From: haruki26 Date: Tue, 21 Jan 2025 16:09:00 +0000 Subject: [PATCH] =?UTF-8?q?=E5=88=86=E9=A1=9E=E3=82=A8=E3=83=BC=E3=82=B8?= =?UTF-8?q?=E3=82=A7=E3=83=B3=E3=83=88=E3=81=AE=E5=87=BA=E5=8A=9B=E3=81=AE?= =?UTF-8?q?=E3=83=95=E3=82=A9=E3=83=BC=E3=83=9E=E3=83=83=E3=83=88?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/sc_system_ai/agents/classify_agent.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/src/sc_system_ai/agents/classify_agent.py b/src/sc_system_ai/agents/classify_agent.py index 20dd308..a10d377 100644 --- a/src/sc_system_ai/agents/classify_agent.py +++ b/src/sc_system_ai/agents/classify_agent.py @@ -1,11 +1,14 @@ +from collections.abc import Iterator +from typing import Any, cast + from langchain_openai import AzureChatOpenAI +# from sc_system_ai.agents.tools import magic_function +from sc_system_ai.agents.search_school_data_agent import SearchSchoolDataAgentResponse from sc_system_ai.agents.tools.calling_dummy_agent import calling_dummy_agent from sc_system_ai.agents.tools.calling_search_school_data_agent import calling_search_school_data_agent - -# from sc_system_ai.agents.tools import magic_function from sc_system_ai.agents.tools.classify_role import classify_role -from sc_system_ai.template.agent import Agent +from sc_system_ai.template.agent import Agent, AgentResponse from sc_system_ai.template.ai_settings import llm from sc_system_ai.template.calling_agent import CallingAgent from sc_system_ai.template.user_prompts import User @@ -53,6 +56,15 @@ def set_tools(self, tools: list) -> None: super().set_tools(tools) + def invoke(self, message: str) -> Iterator[str | AgentResponse | SearchSchoolDataAgentResponse]: + if self.is_streaming: + yield from super().invoke(message) + else: + # ツールの出力をそのまま返却 + resp = cast(dict[str, Any], next(super().invoke(message))) + yield resp["output"] + + if __name__ == "__main__": from sc_system_ai.logging_config import setup_logging setup_logging() @@ -67,7 +79,7 @@ def set_tools(self, tools: list) -> None: user_info.conversations.add_conversations_list(history) while True: - classify_agent = ClassifyAgent(user_info=user_info) + classify_agent = ClassifyAgent(user_info=user_info, is_streaming=False) # classify_agent.display_agent_info() # print(main_agent.get_agent_prompt()) # classify_agent.display_agent_prompt()