Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/cosmosdb crud #146

Merged
merged 14 commits into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 3 additions & 8 deletions src/sc_system_ai/agents/tools/search_school_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,6 @@ def search_school_database_cosmos(search_word: str, top_k: int = 2) -> list[Docu
"""学校に関する情報を検索する関数(現在のデータベースを参照)"""
cosmos_manager = CosmosDBManager()
docs = cosmos_manager.similarity_search(search_word, k=top_k)

for doc in docs:
source = cosmos_manager.get_source_by_id(doc.metadata["id"])
doc.metadata["source"] = source
return docs


Expand All @@ -54,13 +50,12 @@ def _run(
"""use the tool."""
logger.info(f"Search School Data Toolが次の値で呼び出されました: {search_word}")
result = search_school_database_cosmos(search_word)
i = 1
search_result = []
for doc in result:
for i, doc in enumerate(result):
if hasattr(doc, 'page_content'):
search_result.append(
f'・検索結果{i}は以下の通りです。\n{doc.page_content}\n参考URL: "{doc.metadata["source"]}"\n\n')
i += 1
f'・検索結果{i + 1}は以下の通りです。\n{doc.page_content}\n参考URL: "{doc.metadata["id"]}"\n\n'
)
return search_result


Expand Down
4 changes: 4 additions & 0 deletions src/sc_system_ai/logging_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,9 @@ def setup_logging() -> None:
package_logger = logging.getLogger("sc_system_ai")
package_logger.setLevel(logging.DEBUG)

# azure.coreのログメッセージをWARNING以上で出力する
azure_logger = logging.getLogger("azure.core")
azure_logger.setLevel(logging.WARNING)

# langchainのログメッセージを出力する
set_verbose(True)
104 changes: 86 additions & 18 deletions src/sc_system_ai/template/azure_cosmos.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import os
from typing import Any
from datetime import datetime
from typing import Any, Literal, cast

from azure.cosmos import CosmosClient, PartitionKey
from dotenv import load_dotenv
Expand All @@ -11,6 +12,7 @@
from langchain_core.embeddings import Embeddings

from sc_system_ai.template.ai_settings import embeddings
from sc_system_ai.template.document_formatter import md_formatter, text_formatter

load_dotenv()

Expand Down Expand Up @@ -79,36 +81,95 @@ def __init__(
create_container=create_container,
)

def create_document(
    self,
    text: str,
    text_type: Literal["markdown", "plain"] = "markdown",
) -> list[str]:
    """Format ``text`` into documents and insert them into the database.

    Args:
        text: Raw text to store.
        text_type: ``"markdown"`` routes the text through ``md_formatter``;
            any other value routes it through ``text_formatter``.

    Returns:
        The value returned by ``self._insert_texts`` (presumably the ids of
        the inserted items -- confirm against that helper).
    """
    logger.info("新しいdocumentを作成します")

    # Pick the formatter first so the split step receives ready documents.
    if text_type == "markdown":
        formatted = md_formatter(text)
    else:
        formatted = text_formatter(text)

    page_texts, page_metadatas = self._division_document(formatted)
    return self._insert_texts(page_texts, page_metadatas)

def _division_document(
    self,
    documents: list[Document]
) -> tuple[list[str], list[dict[str, Any]]]:
    """Split documents into parallel lists of page contents and metadata.

    Args:
        documents: Documents whose ``page_content`` and ``metadata`` are read.

    Returns:
        A ``(contents, metadatas)`` tuple; both lists follow the input order,
        and both are empty when ``documents`` is empty.
    """
    # Walk the input exactly once, reading content before metadata per item.
    pairs = [(entry.page_content, entry.metadata) for entry in documents]
    contents = [content for content, _ in pairs]
    metadatas = [meta for _, meta in pairs]
    return contents, metadatas

def update_document(
    self,
    id: str,
    text: str,
) -> str:
    """Replace the text of the stored document ``id`` and refresh its embedding.

    The item's current ``metadata`` is fetched, its ``updated_at`` entry is
    set to today's date (``YYYY-MM-DD`` from ``datetime.now()``), and the
    item is upserted with the new text and a newly computed embedding.

    Args:
        id: Value matched against ``c.id`` in the container.
        text: New text; stored under ``"text"`` and embedded.

    Returns:
        ``id`` on success, or the message ``"documentが見つかりませんでした"``
        when the query yields no item.
    """
    logger.info("documentを更新します")

    # Fetch only the metadata so that its updated_at entry can be refreshed.
    query = "SELECT c.metadata FROM c WHERE c.id = @id"
    parameters = [{"name": "@id", "value": id}]

    results = self._container.query_items(
        query=query,
        parameters=cast(list[dict[str, Any]], parameters),  # cast keeps mypy quiet
        enable_cross_partition_query=True,
    )
    item = next(iter(results), None)
    if item is None:
        logger.error(f"{id=}のdocumentが見つかりませんでした")
        return "documentが見つかりませんでした"

    # An item stored without metadata comes back with no "metadata" key;
    # start from an empty mapping instead of raising KeyError.
    metadata = item.get("metadata") or {}
    metadata["updated_at"] = datetime.now().strftime("%Y-%m-%d")

    # NOTE(review): only these four fields are written back; presumably any
    # other properties on the stored item are dropped by the upsert -- confirm.
    to_upsert = {
        "id": id,
        "text": text,
        self._embedding_key: self._embedding.embed_documents([text])[0],
        "metadata": metadata,
    }
    self._container.upsert_item(body=to_upsert)
    return id

def read_all_documents(self) -> list[Document]:
    """Load every stored item as a ``Document`` that carries its id.

    Runs a cross-partition query selecting ``id`` and ``text`` from the
    container.

    Returns:
        One ``Document`` per item, with the item's ``text`` as
        ``page_content`` and ``{"id": <item id>}`` as ``metadata``. The list
        is empty when the query yields no items.
    """
    logger.info("全てのdocumentsを読み込みます")
    query = "SELECT c.id, c.text FROM c"
    items = self._container.query_items(
        query=query, enable_cross_partition_query=True
    )
    # Build the result straight from the query iterator; no intermediate
    # list is needed because the items are consumed exactly once.
    docs: list[Document] = [
        Document(page_content=item["text"], metadata={"id": item["id"]})
        for item in items
    ]
    return docs

def get_source_by_id(self, id: str) -> str:
    """Return the ``text`` field of the item whose id equals ``id``.

    Despite the method name, the value read from the item is its ``text``
    property.

    Args:
        id: Value matched against ``c.id`` in the container.

    Returns:
        The item's text when it is a string; otherwise (no matching item, or
        a missing / non-string ``text``) the message
        ``"sourceが見つかりませんでした"``.
    """
    logger.info(f"{id=}のsourceを取得します")

    # Bind the id as a query parameter rather than splicing it into the SQL
    # text, so quotes inside ``id`` cannot alter the query.
    query = "SELECT c.text FROM c WHERE c.id = @id"
    parameters = [{"name": "@id", "value": id}]
    results = self._container.query_items(
        query=query,
        parameters=cast(list[dict[str, Any]], parameters),  # cast keeps mypy quiet
        enable_cross_partition_query=True,
    )

    # An empty result set maps to the not-found message instead of letting
    # StopIteration escape to the caller.
    item = next(iter(results), None)
    result = item.get("text") if item is not None else None
    if isinstance(result, str):
        return result
    return "sourceが見つかりませんでした"



if __name__ == "__main__":
from sc_system_ai.logging_config import setup_logging
setup_logging()
Expand All @@ -118,9 +179,16 @@ def get_source_by_id(self, id: str) -> str:
# results = cosmos_manager.read_all_documents()
results = cosmos_manager.similarity_search(query, k=1)
print(results[0])

# idで指定したドキュメントのsourceを取得
ids = results[0].metadata["id"]
print(f"{ids=}")
doc = cosmos_manager.get_source_by_id(ids)
print(doc)
print(results[0].metadata["id"])

# # idで指定したドキュメントのsourceを取得
# ids = results[0].metadata["id"]
# print(f"{ids=}")
# doc = cosmos_manager.get_source_by_id(ids)
# print(doc)

# # documentを更新
# text = """ストリーミングレスポンスに対応するためにジェネレータとして定義されています。
# エージェントが回答の生成を終えてからレスポンスを受け取ることも可能です。"""
# _id = "c55bb571-498a-4db9-9da0-e9e35d46906b"
# print(cosmos_manager.update_document(_id, text))
Loading
Loading