diff --git a/src/sc_system_ai/agents/classify_agent.py b/src/sc_system_ai/agents/classify_agent.py index 20dd308..a10d377 100644 --- a/src/sc_system_ai/agents/classify_agent.py +++ b/src/sc_system_ai/agents/classify_agent.py @@ -1,11 +1,14 @@ +from collections.abc import Iterator +from typing import Any, cast + from langchain_openai import AzureChatOpenAI +# from sc_system_ai.agents.tools import magic_function +from sc_system_ai.agents.search_school_data_agent import SearchSchoolDataAgentResponse from sc_system_ai.agents.tools.calling_dummy_agent import calling_dummy_agent from sc_system_ai.agents.tools.calling_search_school_data_agent import calling_search_school_data_agent - -# from sc_system_ai.agents.tools import magic_function from sc_system_ai.agents.tools.classify_role import classify_role -from sc_system_ai.template.agent import Agent +from sc_system_ai.template.agent import Agent, AgentResponse from sc_system_ai.template.ai_settings import llm from sc_system_ai.template.calling_agent import CallingAgent from sc_system_ai.template.user_prompts import User @@ -53,6 +56,15 @@ def set_tools(self, tools: list) -> None: super().set_tools(tools) + def invoke(self, message: str) -> Iterator[str | AgentResponse | SearchSchoolDataAgentResponse]: + if self.is_streaming: + yield from super().invoke(message) + else: + # ツールの出力をそのまま返却 + resp = cast(dict[str, Any], next(super().invoke(message))) + yield resp["output"] + + if __name__ == "__main__": from sc_system_ai.logging_config import setup_logging setup_logging() @@ -67,7 +79,7 @@ def set_tools(self, tools: list) -> None: user_info.conversations.add_conversations_list(history) while True: - classify_agent = ClassifyAgent(user_info=user_info) + classify_agent = ClassifyAgent(user_info=user_info, is_streaming=False) # classify_agent.display_agent_info() # print(main_agent.get_agent_prompt()) # classify_agent.display_agent_prompt()