Skip to content

Commit

Permalink
Merge pull request #172 from KTC-Security-Circle/feature/search-agent
Browse files Browse the repository at this point in the history
Feature/search agent
  • Loading branch information
snow7y authored Jan 21, 2025
2 parents 3aa3ce4 + 4973cab commit 165b7e5
Show file tree
Hide file tree
Showing 7 changed files with 164 additions and 50 deletions.
34 changes: 24 additions & 10 deletions src/sc_system_ai/agents/classify_agent.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
from langchain_openai import AzureChatOpenAI
from collections.abc import Iterator
from typing import Any, cast

from sc_system_ai.agents.tools.calling_dummy_agent import calling_dummy_agent
from langchain_openai import AzureChatOpenAI

# from sc_system_ai.agents.tools import magic_function
from sc_system_ai.agents.search_school_data_agent import SearchSchoolDataAgentResponse
from sc_system_ai.agents.tools.calling_dummy_agent import calling_dummy_agent
from sc_system_ai.agents.tools.calling_search_school_data_agent import calling_search_school_data_agent
from sc_system_ai.agents.tools.classify_role import classify_role
from sc_system_ai.template.agent import Agent
from sc_system_ai.template.agent import Agent, AgentResponse
from sc_system_ai.template.ai_settings import llm
from sc_system_ai.template.calling_agent import CallingAgent
from sc_system_ai.template.user_prompts import User

classify_agent_tools = [
# magic_function,
classify_role,
calling_dummy_agent
calling_dummy_agent,
calling_search_school_data_agent
]

classify_agent_info = """あなたの役割は適切なエージェントを選択し処理を引き継ぐことです。
Expand Down Expand Up @@ -51,6 +56,15 @@ def set_tools(self, tools: list) -> None:

super().set_tools(tools)

def invoke(self, message: str) -> Iterator[str | AgentResponse | SearchSchoolDataAgentResponse]:
if self.is_streaming:
yield from super().invoke(message)
else:
# ツールの出力をそのまま返却
resp = cast(dict[str, Any], next(super().invoke(message)))
yield resp["output"]


if __name__ == "__main__":
from sc_system_ai.logging_config import setup_logging
setup_logging()
Expand All @@ -65,7 +79,7 @@ def set_tools(self, tools: list) -> None:
user_info.conversations.add_conversations_list(history)

while True:
classify_agent = ClassifyAgent(user_info=user_info)
classify_agent = ClassifyAgent(user_info=user_info, is_streaming=False)
# classify_agent.display_agent_info()
# print(main_agent.get_agent_prompt())
# classify_agent.display_agent_prompt()
Expand All @@ -75,13 +89,13 @@ def set_tools(self, tools: list) -> None:
break

# 通常の呼び出し
# resp = classify_agent.invoke(user)
# print(resp)
resp = next(classify_agent.invoke(user))
print(resp)

# ストリーミング呼び出し
for output in classify_agent.invoke(user):
print(output)
resp = classify_agent.get_response()
# for output in classify_agent.invoke(user):
# print(output)
# resp = classify_agent.get_response()

if type(resp) is dict:
new_conversation = [
Expand Down
72 changes: 72 additions & 0 deletions src/sc_system_ai/agents/search_school_data_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from collections.abc import Iterator
from typing import cast

from langchain_openai import AzureChatOpenAI

# from sc_system_ai.agents.tools import magic_function
from sc_system_ai.agents.tools.search_school_data import search_school_database_cosmos
from sc_system_ai.template.agent import Agent, AgentResponse
from sc_system_ai.template.ai_settings import llm
from sc_system_ai.template.user_prompts import User

# search_school_data_agent_tools = [
# # magic_function,
# ]

search_school_data_agent_info = """あなたの役割は学校の情報をもとにユーザーの質問に回答することです。
以下に学校の情報について示します。
## 学校の情報
"""

class SearchSchoolDataAgentResponse(AgentResponse):
document_id: list[str]

# agentクラスの作成

class SearchSchoolDataAgent(Agent):
def __init__(
self,
llm: AzureChatOpenAI = llm,
user_info: User | None = None,
is_streaming: bool = True,
return_length: int = 5
):
super().__init__(
llm=llm,
user_info=user_info if user_info is not None else User(),
is_streaming=is_streaming,
return_length=return_length
)
self.assistant_info = search_school_data_agent_info

def invoke(self, message: str) -> Iterator[SearchSchoolDataAgentResponse]:
# Agentクラスのストリーミングを改修後にストリーミング実装
self.cancel_streaming()
search = search_school_database_cosmos(message)
ids = []
for doc in search:
self.assistant_info += f"### {doc.metadata['title']}\n" + doc.page_content + "\n"
ids.append(doc.metadata["id"])
super().set_assistant_info(self.assistant_info)

resp = cast(AgentResponse, next(super().invoke(message)))
yield {
**resp,
"document_id": ids
}

if __name__ == "__main__":
from sc_system_ai.logging_config import setup_logging
setup_logging()
# ユーザー情報
user_name = "hogehoge"
user_major = "fugafuga専攻"
history = [
("human", "こんにちは!"),
("ai", "本日はどのようなご用件でしょうか?")
]
user_info = User(name=user_name, major=user_major)
user_info.conversations.add_conversations_list(history)
agent = SearchSchoolDataAgent(user_info=user_info, is_streaming=False)
print(next(agent.invoke("京都テックについて教えて")))
28 changes: 28 additions & 0 deletions src/sc_system_ai/agents/tools/calling_search_school_data_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# dummyAgentの呼び出しを行うツール

import logging

from sc_system_ai.agents.search_school_data_agent import SearchSchoolDataAgent
from sc_system_ai.template.calling_agent import CallingAgent
from sc_system_ai.template.user_prompts import User

logger = logging.getLogger(__name__)


class CallingSearchSchoolDataAgent(CallingAgent):
def __init__(self) -> None:
super().__init__()
self.set_tool_info(
name="calling_search_school_data_agent",
description="学校情報を検索するエージェントを呼び出すツール",
agent=SearchSchoolDataAgent
)

calling_search_school_data_agent = CallingSearchSchoolDataAgent()

if __name__ == "__main__":
from sc_system_ai.logging_config import setup_logging
setup_logging()

calling_search_school_data_agent.set_user_info(User(name="hogehoge", major="fugafuga専攻"))
print(calling_search_school_data_agent.invoke({"user_input": "京都テックについて教えて"}))
4 changes: 4 additions & 0 deletions src/sc_system_ai/agents/tools/classify_role.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class Output(BaseModel):
"遅延届",
"早退届",
"公欠届",
"学校情報検索"
]
similarity_score: float = Field(ge=0.0, le=1.0)

Expand Down Expand Up @@ -88,6 +89,9 @@ def check_same_word(
"遅延届",
"早退届",
"公欠届",
],
"学校情報検索": [
"学校情報の検索"
]
}

Expand Down
49 changes: 28 additions & 21 deletions src/sc_system_ai/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,20 @@
import logging
from collections.abc import Iterator
from importlib import import_module
from typing import Literal
from typing import Literal, TypedDict, cast

from sc_system_ai.template.agent import Agent
from sc_system_ai.template.ai_settings import llm
from sc_system_ai.template.user_prompts import User

logger = logging.getLogger(__name__)

AGENT = Literal["classify", "dummy"]
AGENT = Literal["classify", "dummy", "search_school_data"]

class Response(TypedDict):
output: str | None
error: str | None
document_id: list[str] | None

class Chat:
"""Chatクラス
Expand Down Expand Up @@ -139,7 +144,7 @@ def invoke(
self,
message: str,
command: AGENT = "classify"
) -> Iterator[str]:
) -> Iterator[Response]:
"""エージェントを呼び出し、チャットを行う関数
Args:
Expand All @@ -162,20 +167,15 @@ def invoke(
if self.is_streaming:
for resp in self.agent.invoke(message):
if type(resp) is str:
yield resp
yield self._create_response({"output": resp})
else:
resp = next(self.agent.invoke(message))

if type(resp) is dict:
if "error" in resp:
yield resp["error"]
else:
yield resp["output"]
yield self._create_response(cast(dict, resp))

def _call_agent(self, command: AGENT) -> None:
try:
module_name = f"sc_system_ai.agents.{command}_agent"
class_name = f"{command.capitalize()}Agent"
class_name = "".join([cn.capitalize() for cn in command.split("_")]) + "Agent"
module = import_module(module_name)
agent_class = getattr(module, class_name)

Expand All @@ -189,6 +189,13 @@ def _call_agent(self, command: AGENT) -> None:
logger.error(f"エージェントが見つかりません: {command}")
raise ValueError(f"エージェントが見つかりません: {command}") from None

def _create_response(self, resp: dict) -> Response:
return {
"output": resp.get("output"),
"error": resp.get("error"),
"document_id": resp.get("document_id")
}



def static_chat() -> None:
Expand Down Expand Up @@ -254,17 +261,17 @@ def streaming_chat() -> None:
)
message = "私の名前と専攻は何ですか?"

try:
resp = chat.agent.get_response()
except Exception:
pass
# try:
# resp = chat.agent.get_response()
# except Exception:
# pass

# # 通常呼び出し
# resp = next(chat.invoke(message=message, command="dummy"))
# print(resp)
resp = next(chat.invoke(message=message, command="dummy"))
print(resp)

# ストリーミング呼び出し
chat.is_streaming = True
for r in chat.invoke(message=message, command="dummy"):
print(r)
chat.agent.get_response()
# chat.is_streaming = True
# for r in chat.invoke(message=message, command="dummy"):
# print(r)
# chat.agent.get_response()
9 changes: 5 additions & 4 deletions src/sc_system_ai/template/azure_cosmos.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,8 @@ def update_document(
if text is not None:
if text_type is None:
raise TypeError("textを更新する際はtext_typeを指定してください。")
result = self._update_text(
id, text, text_type, item["metadata"].get("group_id", None)
result = self._text_updater(
id, text, text_type, metadata, item["metadata"].get("group_id", None)
)

if any([title, metadata, del_metadata]):
Expand Down Expand Up @@ -261,11 +261,12 @@ def _create_patch(
})
return patch

def _update_text(
def _text_updater(
self,
id: str,
text: str,
text_type: Literal["markdown", "plain"],
metadata: dict[str, Any] | None = None,
group_id: str | None = None,
) -> list[str]:
"""textを更新する関数"""
Expand All @@ -277,7 +278,7 @@ def _update_text(
for d in data:
self.delete_document_by_id(d["id"])

ids = self.create_document(text, text_type)
ids = self.create_document(text, text_type, metadata=metadata)
patch = [{
"op": "replace",
"path": "/metadata/created_at",
Expand Down
18 changes: 3 additions & 15 deletions src/sc_system_ai/template/calling_agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Any
from typing import Any, cast

from langchain_core.tools import BaseTool
from pydantic import BaseModel, ConfigDict, Field
Expand Down Expand Up @@ -57,7 +57,7 @@ def __init__(self) -> None:
def _run(
self,
user_input: str,
) -> str:
) -> dict[str, Any]:
logger.info(f"Calling Agent Toolが次の値で呼び出されました: {user_input}")

# エージェントの呼び出し
Expand All @@ -71,19 +71,7 @@ def _run(

resp = next(agent.invoke(user_input))

return self._type_checker(resp)

def _type_checker(self, response: Any) -> str:
"""レスポンスの型チェック"""
resp = ""
if type(response) is dict:
if "output" in response:
resp = response["output"]
elif type(response) is str:
resp = response

return resp

return cast(dict[str, Any], resp)

def set_user_info(self, user_info: User) -> None:
"""ユーザー情報の設定
Expand Down

0 comments on commit 165b7e5

Please sign in to comment.