Skip to content

Commit

Permalink
Chatクラスからの呼び出し処理を作成
Browse files Browse the repository at this point in the history
  • Loading branch information
haruki26 committed Jan 21, 2025
1 parent 1c84f2e commit f44f1d9
Showing 1 changed file with 14 additions and 12 deletions.
26 changes: 14 additions & 12 deletions src/sc_system_ai/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
import logging
from collections.abc import Iterator
from importlib import import_module
from typing import Literal, TypedDict
from typing import Literal, TypedDict, cast

from sc_system_ai.template.agent import Agent
from sc_system_ai.template.ai_settings import llm
Expand Down Expand Up @@ -144,7 +144,7 @@ def invoke(
self,
message: str,
command: AGENT = "classify"
) -> Iterator[str]:
) -> Iterator[Response]:
"""エージェントを呼び出し、チャットを行う関数
Args:
Expand All @@ -167,20 +167,15 @@ def invoke(
if self.is_streaming:
for resp in self.agent.invoke(message):
if type(resp) is str:
yield resp
yield self._create_response({"output": resp})
else:
resp = next(self.agent.invoke(message))

if type(resp) is dict:
if "error" in resp:
yield resp["error"]
else:
yield resp["output"]
yield self._create_response(cast(dict, resp))

def _call_agent(self, command: AGENT) -> None:
try:
module_name = f"sc_system_ai.agents.{command}_agent"
class_name = f"{command.capitalize()}Agent"
class_name = "".join([cn.capitalize() for cn in command.split("_")]) + "Agent"
module = import_module(module_name)
agent_class = getattr(module, class_name)

Expand All @@ -194,6 +189,13 @@ def _call_agent(self, command: AGENT) -> None:
logger.error(f"エージェントが見つかりません: {command}")
raise ValueError(f"エージェントが見つかりません: {command}") from None

def _create_response(self, resp: dict) -> Response:
return {
"output": resp.get("output"),
"error": resp.get("error"),
"document_id": resp.get("document_id")
}



def static_chat() -> None:
Expand Down Expand Up @@ -257,15 +259,15 @@ def streaming_chat() -> None:
],
is_streaming=False,
)
message = "京都テックについて教えて"
message = "私の名前と専攻は何ですか?"

# try:
# resp = chat.agent.get_response()
# except Exception:
# pass

# # 通常呼び出し
resp = next(chat.invoke(message=message, command="search_school_data"))
resp = next(chat.invoke(message=message, command="dummy"))
print(resp)

# ストリーミング呼び出し
Expand Down

0 comments on commit f44f1d9

Please sign in to comment.