Skip to content

Commit

Permalink
Merge pull request #174 from KTC-Security-Circle/fix/response
Browse files Browse the repository at this point in the history
Fix/response
  • Loading branch information
snow7y authored Jan 28, 2025
2 parents 869c8ef + 02043b1 commit 7c0e34d
Show file tree
Hide file tree
Showing 12 changed files with 398 additions and 252 deletions.
7 changes: 4 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ click==8.1.7 ; python_version >= "3.10" and python_version < "4.0"
colorama==0.4.6 ; python_version >= "3.10" and python_version < "4.0" and platform_system == "Windows"
dataclasses-json==0.6.7 ; python_version >= "3.10" and python_version < "4.0"
distro==1.9.0 ; python_version >= "3.10" and python_version < "4.0"
duckduckgo-search==6.3.7 ; python_version >= "3.10" and python_version < "4.0"
duckduckgo-search==7.2.1 ; python_version >= "3.10" and python_version < "4.0"
exceptiongroup==1.2.2 ; python_version >= "3.10" and python_version < "3.11"
frozenlist==1.4.1 ; python_version >= "3.10" and python_version < "4.0"
greenlet==3.1.1 ; python_version < "3.13" and (platform_machine == "aarch64" or platform_machine == "ppc64le" or platform_machine == "x86_64" or platform_machine == "amd64" or platform_machine == "AMD64" or platform_machine == "win32" or platform_machine == "WIN32") and python_version >= "3.10"
Expand All @@ -28,17 +28,18 @@ jsonpointer==3.0.0 ; python_version >= "3.10" and python_version < "4.0"
langchain-community==0.3.13 ; python_version >= "3.10" and python_version < "4.0"
langchain-core==0.3.28 ; python_version >= "3.10" and python_version < "4.0"
langchain-openai==0.2.14 ; python_version >= "3.10" and python_version < "4.0"
langchain-text-splitters==0.3.5 ; python_version >= "3.10" and python_version < "4.0"
langchain-text-splitters==0.3.4 ; python_version >= "3.10" and python_version < "4.0"
langchain==0.3.13 ; python_version >= "3.10" and python_version < "4.0"
langsmith==0.1.147 ; python_version >= "3.10" and python_version < "4.0"
lxml==5.3.0 ; python_version >= "3.10" and python_version < "4.0"
marshmallow==3.22.0 ; python_version >= "3.10" and python_version < "4.0"
multidict==6.1.0 ; python_version >= "3.10" and python_version < "4.0"
mypy-extensions==1.0.0 ; python_version >= "3.10" and python_version < "4.0"
numpy==1.26.4 ; python_version >= "3.10" and python_version < "4.0"
openai==1.58.1 ; python_version >= "3.10" and python_version < "4.0"
orjson==3.10.7 ; python_version >= "3.10" and python_version < "4.0" and platform_python_implementation != "PyPy"
packaging==24.1 ; python_version >= "3.10" and python_version < "4.0"
primp==0.8.1 ; python_version >= "3.10" and python_version < "4.0"
primp==0.10.0 ; python_version >= "3.10" and python_version < "4.0"
propcache==0.2.0 ; python_version >= "3.10" and python_version < "4.0"
pydantic-core==2.23.4 ; python_version >= "3.10" and python_version < "4.0"
pydantic-settings==2.5.2 ; python_version >= "3.10" and python_version < "4.0"
Expand Down
64 changes: 40 additions & 24 deletions src/sc_system_ai/agents/classify_agent.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
from collections.abc import Iterator
from typing import Any, cast
from collections.abc import AsyncIterator
from typing import cast

from langchain_openai import AzureChatOpenAI

# from sc_system_ai.agents.tools import magic_function
from sc_system_ai.agents.search_school_data_agent import SearchSchoolDataAgentResponse
from sc_system_ai.agents.tools.calling_dummy_agent import calling_dummy_agent
from sc_system_ai.agents.tools.calling_search_school_data_agent import calling_search_school_data_agent
from sc_system_ai.agents.tools.calling_search_school_data_agent import (
CallingSearchSchoolDataAgent,
calling_search_school_data_agent,
)
from sc_system_ai.agents.tools.calling_small_talk_agent import calling_small_talk_agent
from sc_system_ai.agents.tools.classify_role import classify_role
from sc_system_ai.template.agent import Agent, AgentResponse
from sc_system_ai.template.agent import Agent, AgentResponse, StreamingAgentResponse
from sc_system_ai.template.ai_settings import llm
from sc_system_ai.template.calling_agent import CallingAgent
from sc_system_ai.template.user_prompts import User
Expand All @@ -17,14 +20,14 @@
# magic_function,
classify_role,
calling_dummy_agent,
calling_search_school_data_agent
calling_search_school_data_agent,
calling_small_talk_agent,
]

classify_agent_info = """あなたの役割は適切なエージェントを選択し処理を引き継ぐことです。
あなたがユーザーと会話を行ってはいけません。
ユーザーの入力、会話の流れから適切なエージェントを選択してください。
引き継いだエージェントが処理を完了するまで、そのエージェントがユーザーと会話を続けるようにしてください。
適切なエージェントの選択、呼び出しができなかった場合は、そのままユーザーとの会話を続けてください。
"""

# agentクラスの作成
Expand All @@ -35,14 +38,10 @@ def __init__(
self,
llm: AzureChatOpenAI = llm,
user_info: User | None = None,
is_streaming: bool = True,
return_length: int = 5
):
super().__init__(
llm=llm,
user_info=user_info if user_info is not None else User(),
is_streaming=is_streaming,
return_length=return_length
)
self.assistant_info = classify_agent_info
super().set_assistant_info(self.assistant_info)
Expand All @@ -53,16 +52,33 @@ def set_tools(self, tools: list) -> None:
for tool in tools:
if isinstance(tool, CallingAgent):
tool.set_user_info(self.user_info)

super().set_tools(tools)

def invoke(self, message: str) -> Iterator[str | AgentResponse | SearchSchoolDataAgentResponse]:
if self.is_streaming:
yield from super().invoke(message)
else:
# ツールの出力をそのまま返却
resp = cast(dict[str, Any], next(super().invoke(message)))
yield resp["output"]
def invoke(self, message: str) -> AgentResponse:
# toolの出力がAgentReaponseで返って来るので整形
for tool in self.tool.tools:
if isinstance(tool, CallingAgent):
tool.cancel_streaming()
resp = super().invoke(message)
resp.document_id = self._doc_id_checker()
return resp

def _doc_id_checker(self) -> list[str] | None:
"""
ドキュメントIDが存在するか確認する
"""
for tool in self.tool.tools:
if isinstance(tool, CallingSearchSchoolDataAgent):
if tool.document_id is not None:
return tool.document_id
return None

async def stream(self, message: str, return_length: int = 5) -> AsyncIterator[StreamingAgentResponse]:
for tool in self.tool.tools:
if isinstance(tool, CallingAgent):
tool.setup_streaming(self.queue)
async for output in super().stream(message, return_length):
yield output


if __name__ == "__main__":
Expand All @@ -79,7 +95,7 @@ def invoke(self, message: str) -> Iterator[str | AgentResponse | SearchSchoolDat
user_info.conversations.add_conversations_list(history)

while True:
classify_agent = ClassifyAgent(user_info=user_info, is_streaming=False)
classify_agent = ClassifyAgent(user_info=user_info)
# classify_agent.display_agent_info()
# print(main_agent.get_agent_prompt())
# classify_agent.display_agent_prompt()
Expand All @@ -89,17 +105,17 @@ def invoke(self, message: str) -> Iterator[str | AgentResponse | SearchSchoolDat
break

# 通常の呼び出し
resp = next(classify_agent.invoke(user))
resp = classify_agent.invoke(user)
print(resp)

# ストリーミング呼び出し
# for output in classify_agent.invoke(user):
# print(output)
# resp = classify_agent.get_response()

if type(resp) is dict:
if type(resp) is AgentResponse:
new_conversation = [
("human", user),
("ai", resp["output"])
("ai", cast(str,resp.output))
]
user_info.conversations.add_conversations_list(new_conversation)
9 changes: 3 additions & 6 deletions src/sc_system_ai/agents/dummy_agent.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# ダミーのエージェント
from typing import cast

from langchain_openai import AzureChatOpenAI

Expand Down Expand Up @@ -49,14 +50,10 @@ def __init__(
self,
llm: AzureChatOpenAI = llm,
user_info: User | None = None,
is_streaming: bool = True,
return_length: int = 5
):
super().__init__(
llm=llm,
user_info=user_info if user_info is not None else User(),
is_streaming=is_streaming,
return_length=return_length
)
self.assistant_info = dummy_agent_info
super().set_assistant_info(self.assistant_info)
Expand Down Expand Up @@ -96,9 +93,9 @@ def __init__(
print(output)
resp = dummy_agent.get_response()

if type(resp) is dict:
if resp.error is not None:
new_conversation = [
("human", user),
("ai", resp["output"])
("ai", cast(str, resp.output))
]
user_info.conversations.add_conversations_list(new_conversation)
8 changes: 2 additions & 6 deletions src/sc_system_ai/agents/main_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,10 @@ def __init__(
self,
llm: AzureChatOpenAI = llm,
user_info: User | None = None,
is_streaming: bool = True,
return_length: int = 5
):
super().__init__(
llm=llm,
user_info=user_info if user_info is not None else User(),
is_streaming=is_streaming,
return_length=return_length
)
self.assistant_info = main_agent_info
super().set_assistant_info(self.assistant_info)
Expand All @@ -43,9 +39,9 @@ def __init__(
user_info = User(name=user_name, major=user_major)
user_info.conversations.add_conversations_list(history)

main_agent = MainAgent(user_info=user_info, is_streaming=False)
main_agent = MainAgent(user_info=user_info)
main_agent.display_agent_info()
# print(main_agent.get_agent_prompt())
main_agent.display_agent_prompt()
print(next(main_agent.invoke("magic function に3をいれて")))
print(main_agent.invoke("magic function に3をいれて"))

38 changes: 18 additions & 20 deletions src/sc_system_ai/agents/search_school_data_agent.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from collections.abc import Iterator
from typing import cast
from collections.abc import AsyncIterator

from langchain_openai import AzureChatOpenAI

# from sc_system_ai.agents.tools import magic_function
from sc_system_ai.agents.tools.search_school_data import search_school_database_cosmos
from sc_system_ai.template.agent import Agent, AgentResponse
from sc_system_ai.template.agent import Agent, AgentResponse, StreamingAgentResponse
from sc_system_ai.template.ai_settings import llm
from sc_system_ai.template.user_prompts import User

Expand All @@ -19,42 +18,41 @@
## 学校の情報
"""

class SearchSchoolDataAgentResponse(AgentResponse):
document_id: list[str]

# agentクラスの作成

class SearchSchoolDataAgent(Agent):
def __init__(
self,
llm: AzureChatOpenAI = llm,
user_info: User | None = None,
is_streaming: bool = True,
return_length: int = 5
):
super().__init__(
llm=llm,
user_info=user_info if user_info is not None else User(),
is_streaming=is_streaming,
return_length=return_length
)
self.assistant_info = search_school_data_agent_info

def invoke(self, message: str) -> Iterator[SearchSchoolDataAgentResponse]:
# Agentクラスのストリーミングを改修後にストリーミング実装
self.cancel_streaming()
def _add_search_result(self, message: str) -> list[str]:
search = search_school_database_cosmos(message)
ids = []
for doc in search:
self.assistant_info += f"### {doc.metadata['title']}\n" + doc.page_content + "\n"
ids.append(doc.metadata["id"])
super().set_assistant_info(self.assistant_info)
return ids

def invoke(self, message: str) -> AgentResponse:
# Agentクラスのストリーミングを改修後にストリーミング実装
ids = self._add_search_result(message)
resp = super().invoke(message)
resp.document_id = ids
return resp

resp = cast(AgentResponse, next(super().invoke(message)))
yield {
**resp,
"document_id": ids
}
async def stream(self, message: str, return_length: int = 5) -> AsyncIterator[StreamingAgentResponse]:
ids = self._add_search_result(message)
async for resp in super().stream(message, return_length):
yield resp
self.result.document_id = ids

if __name__ == "__main__":
from sc_system_ai.logging_config import setup_logging
Expand All @@ -68,5 +66,5 @@ def invoke(self, message: str) -> Iterator[SearchSchoolDataAgentResponse]:
]
user_info = User(name=user_name, major=user_major)
user_info.conversations.add_conversations_list(history)
agent = SearchSchoolDataAgent(user_info=user_info, is_streaming=False)
print(next(agent.invoke("京都テックについて教えて")))
agent = SearchSchoolDataAgent(user_info=user_info)
print(agent.invoke("京都テックについて教えて"))
49 changes: 49 additions & 0 deletions src/sc_system_ai/agents/small_talk_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from langchain_openai import AzureChatOpenAI

from sc_system_ai.agents.tools import magic_function
from sc_system_ai.template.agent import Agent
from sc_system_ai.template.ai_settings import llm
from sc_system_ai.template.user_prompts import User

main_agent_tools = [magic_function]
main_agent_info = """あなたの役割はユーザーと雑談を行うことです。
ユーザーが楽しめるような会話になるようにしてください。
"""

# agentクラスの作成


class SmallTalkAgent(Agent):
def __init__(
self,
llm: AzureChatOpenAI = llm,
user_info: User | None = None,
):
super().__init__(
llm=llm,
user_info=user_info if user_info is not None else User(),
)
self.assistant_info = main_agent_info
super().set_assistant_info(self.assistant_info)
super().set_tools(main_agent_tools)


if __name__ == "__main__":
from sc_system_ai.logging_config import setup_logging
setup_logging()
# ユーザー情報
user_name = "hogehoge"
user_major = "fugafuga専攻"
history = [
("human", "こんにちは!"),
("ai", "本日はどのようなご用件でしょうか?")
]
user_info = User(name=user_name, major=user_major)
user_info.conversations.add_conversations_list(history)

agent = SmallTalkAgent(user_info=user_info)
agent.display_agent_info()
# print(main_agent.get_agent_prompt())
agent.display_agent_prompt()
print(agent.invoke("magic function に3をいれて"))

Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@


class CallingSearchSchoolDataAgent(CallingAgent):
# tool側でidを保持する
document_id: list[str] | None = None

def __init__(self) -> None:
super().__init__()
self.set_tool_info(
Expand All @@ -18,6 +21,11 @@ def __init__(self) -> None:
agent=SearchSchoolDataAgent
)

def _run(self, user_input: str) -> str:
resp = super()._run(user_input)
self.document_id = self.response.document_id if self.response is not None else None
return resp

calling_search_school_data_agent = CallingSearchSchoolDataAgent()

if __name__ == "__main__":
Expand Down
26 changes: 26 additions & 0 deletions src/sc_system_ai/agents/tools/calling_small_talk_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import logging

from sc_system_ai.agents.small_talk_agent import SmallTalkAgent
from sc_system_ai.template.calling_agent import CallingAgent
from sc_system_ai.template.user_prompts import User

logger = logging.getLogger(__name__)


class CallingSmallTalkAgent(CallingAgent):
def __init__(self) -> None:
super().__init__()
self.set_tool_info(
name="calling_small_talk_agent",
description="雑談エージェントを呼び出すツール",
agent=SmallTalkAgent,
)

calling_small_talk_agent = CallingSmallTalkAgent()

if __name__ == "__main__":
from sc_system_ai.logging_config import setup_logging
setup_logging()

calling_small_talk_agent.set_user_info(User(name="hogehoge", major="fugafuga専攻"))
print(calling_small_talk_agent.invoke({"user_input": "こんにちは"}))
Loading

0 comments on commit 7c0e34d

Please sign in to comment.