Skip to content

Commit

Permalink
分類エージェントの出力のフォーマット
Browse files Browse the repository at this point in the history
  • Loading branch information
haruki26 committed Jan 21, 2025
1 parent f44f1d9 commit 4973cab
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions src/sc_system_ai/agents/classify_agent.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from collections.abc import Iterator
from typing import Any, cast

from langchain_openai import AzureChatOpenAI

# from sc_system_ai.agents.tools import magic_function
from sc_system_ai.agents.search_school_data_agent import SearchSchoolDataAgentResponse
from sc_system_ai.agents.tools.calling_dummy_agent import calling_dummy_agent
from sc_system_ai.agents.tools.calling_search_school_data_agent import calling_search_school_data_agent

# from sc_system_ai.agents.tools import magic_function
from sc_system_ai.agents.tools.classify_role import classify_role
from sc_system_ai.template.agent import Agent
from sc_system_ai.template.agent import Agent, AgentResponse
from sc_system_ai.template.ai_settings import llm
from sc_system_ai.template.calling_agent import CallingAgent
from sc_system_ai.template.user_prompts import User
Expand Down Expand Up @@ -53,6 +56,15 @@ def set_tools(self, tools: list) -> None:

super().set_tools(tools)

def invoke(self, message: str) -> Iterator[str | AgentResponse | SearchSchoolDataAgentResponse]:
if self.is_streaming:
yield from super().invoke(message)
else:
# ツールの出力をそのまま返却
resp = cast(dict[str, Any], next(super().invoke(message)))
yield resp["output"]


if __name__ == "__main__":
from sc_system_ai.logging_config import setup_logging
setup_logging()
Expand All @@ -67,7 +79,7 @@ def set_tools(self, tools: list) -> None:
user_info.conversations.add_conversations_list(history)

while True:
classify_agent = ClassifyAgent(user_info=user_info)
classify_agent = ClassifyAgent(user_info=user_info, is_streaming=False)
# classify_agent.display_agent_info()
# print(main_agent.get_agent_prompt())
# classify_agent.display_agent_prompt()
Expand Down

0 comments on commit 4973cab

Please sign in to comment.