Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/cosmosdb crud #146

Merged
merged 14 commits into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/sc_system_ai/logging_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,9 @@ def setup_logging() -> None:
package_logger = logging.getLogger("sc_system_ai")
package_logger.setLevel(logging.DEBUG)

# azure.coreのログメッセージをWARNING以上で出力する
azure_logger = logging.getLogger("azure.core")
azure_logger.setLevel(logging.WARNING)

# langchainのログメッセージを出力する
set_verbose(True)
169 changes: 147 additions & 22 deletions src/sc_system_ai/template/azure_cosmos.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import os
from typing import Any
from datetime import datetime
from typing import Any, Literal, cast

from azure.cosmos import CosmosClient, PartitionKey
from dotenv import load_dotenv
Expand All @@ -11,6 +12,7 @@
from langchain_core.embeddings import Embeddings

from sc_system_ai.template.ai_settings import embeddings
from sc_system_ai.template.document_formatter import md_formatter, text_formatter

load_dotenv()

Expand Down Expand Up @@ -79,48 +81,171 @@ def __init__(
create_container=create_container,
)

def _similarity_search_with_score(
    self,
    embeddings: list[float],
    k: int = 1,
    pre_filter: dict | None = None,
    with_embedding: bool = False
) -> list[tuple[Document, float]]:
    """Query the container by vector distance and return scored documents.

    Args:
        embeddings: Query vector, bound to the ``@embeddings`` parameter.
        k: Value bound to ``@limit``. The ``TOP @limit`` clause is only
            emitted when no ``limit_offset_clause`` is supplied.
        pre_filter: Optional dict. ``"where_clause"`` and
            ``"limit_offset_clause"`` are read from it and inserted into
            the query text verbatim (they are not parameterized).
        with_embedding: When true, the stored vector is copied into each
            document's metadata under ``self._embedding_key``.

    Returns:
        ``(Document, score)`` tuples where ``score`` is the
        ``SimilarityScore`` column of the query result. Each document's
        metadata additionally carries the item's ``id``.
    """
    clauses = pre_filter if pre_filter is not None else {}
    where_clause = clauses.get("where_clause")
    limit_offset_clause = clauses.get("limit_offset_clause")

    fragments = ["SELECT "]
    # TOP and an explicit limit/offset clause are mutually exclusive here.
    if limit_offset_clause is None:
        fragments.append("TOP @limit ")
    fragments.append(
        "c.id, c[@embeddingKey], c.text, c.metadata, "
        "VectorDistance(c[@embeddingKey], @embeddings) AS SimilarityScore FROM c"
    )
    if where_clause is not None:
        fragments.append(f" {where_clause}")
    fragments.append(" ORDER BY VectorDistance(c[@embeddingKey], @embeddings)")
    if limit_offset_clause is not None:
        fragments.append(f" {limit_offset_clause}")
    query = "".join(fragments)

    parameters = [
        {"name": "@limit", "value": k},
        {"name": "@embeddingKey", "value": self._embedding_key},
        {"name": "@embeddings", "value": embeddings},
    ]
    rows = list(
        self._container.query_items(
            query=query, parameters=parameters, enable_cross_partition_query=True
        )
    )

    scored_docs = []
    for row in rows:
        content = row["text"]
        meta = row["metadata"]
        # Expose the item id to callers through the document metadata.
        meta["id"] = row["id"]
        distance = row["SimilarityScore"]
        if with_embedding:
            meta[self._embedding_key] = row[self._embedding_key]
        scored_docs.append((Document(page_content=content, metadata=meta), distance))
    return scored_docs

def create_document(
    self,
    text: str,
    text_type: Literal["markdown", "plain"] = "markdown"
) -> list[str]:
    """Format ``text`` into documents and insert them into the database.

    Args:
        text: Raw text to store.
        text_type: ``"markdown"`` routes the text through ``md_formatter``;
            any other value routes it through ``text_formatter``.

    Returns:
        Whatever ``self._insert_texts`` returns for the inserted texts
        (annotated as a list of ids).
    """
    logger.info("新しいdocumentを作成します")
    if text_type == "markdown":
        formatted = md_formatter(text)
    else:
        formatted = text_formatter(text)
    contents, metadatas = self._division_document(formatted)
    return self._insert_texts(contents, metadatas)

def _division_document(
    self,
    documents: list[Document]
) -> tuple[list[str], list[dict[str, Any]]]:
    """Split documents into parallel lists of page contents and metadata.

    Args:
        documents: Objects exposing ``page_content`` and ``metadata``.

    Returns:
        ``(contents, metadatas)`` in the same order as ``documents``.
    """
    logger.info("documentを分割します")
    # Single pass over the input, then unzip into the two result lists.
    pairs = [(document.page_content, document.metadata) for document in documents]
    contents = [content for content, _ in pairs]
    metadatas = [meta for _, meta in pairs]
    return contents, metadatas

def update_document(
    self,
    id: str,
    text: str,
) -> str:
    """Replace a document's text, re-embed it and stamp ``updated_at``.

    The existing metadata is looked up by id, its ``updated_at`` entry is
    set to today's date (``%Y-%m-%d`` from ``datetime.now()``), and an item
    consisting of id, text, embedding and metadata is upserted.

    Args:
        id: Id of the document to update.
        text: New text for the document.

    Returns:
        ``id`` on success, or the message ``"documentが見つかりませんでした"``
        when the lookup yields no item.
    """
    logger.info("documentを更新します")

    # Look up the current metadata so it can be carried over to the new item.
    lookup_sql = "SELECT c.metadata FROM c WHERE c.id = @id"
    lookup_params = [{"name": "@id", "value": id}]
    try:
        found = self._container.query_items(
            query=lookup_sql,
            # cast only to satisfy mypy's expected parameter type
            parameters=cast(list[dict[str, Any]], lookup_params),
            enable_cross_partition_query=True,
        ).next()
    except StopIteration:
        logger.error(f"{id=}のdocumentが見つかりませんでした")
        return "documentが見つかりませんでした"

    meta = found["metadata"]
    meta["updated_at"] = datetime.now().strftime("%Y-%m-%d")

    embedding_key = self._embedding_key
    vector = self._embedding.embed_documents([text])[0]
    self._container.upsert_item(
        body={
            "id": id,
            "text": text,
            embedding_key: vector,
            "metadata": meta,
        }
    )
    return id

def read_all_documents(self) -> list[Document]:
    """Read every document in the container together with its id.

    Returns:
        One ``Document`` per stored item, with the item's ``text`` as
        ``page_content`` and ``{"id": <item id>}`` as metadata.
    """
    logger.info("全てのdocumentsを読み込みます")
    rows = self._container.query_items(
        query="SELECT c.id, c.text FROM c",
        enable_cross_partition_query=True,
    )
    # Only the id is exposed as metadata; the raw item is not passed through.
    return [
        Document(page_content=row["text"], metadata={"id": row["id"]})
        for row in rows
    ]

def get_source_by_id(self, id: str) -> str:
    """Return the stored text of the document whose id equals ``id``.

    Args:
        id: Value matched against ``c.id``.

    Returns:
        The item's ``text`` field, or the message
        ``"sourceが見つかりませんでした"`` when no item matches or the field
        is missing / not a string.
    """
    not_found = "sourceが見つかりませんでした"
    logger.info(f"{id=}のsourceを取得します")

    # Bind the id as a query parameter instead of splicing it into the SQL
    # text, so quotes in the id cannot break or alter the query.
    parameters: list[dict[str, Any]] = [{"name": "@id", "value": id}]
    rows = self._container.query_items(
        query="SELECT c.text FROM c WHERE c.id = @id",
        parameters=parameters,
        enable_cross_partition_query=True,
    )

    # An empty result is reported through the not-found message rather than
    # letting StopIteration escape to the caller.
    item = next(iter(rows), None)
    if item is None:
        logger.error(f"{id=}のdocumentが見つかりませんでした")
        return not_found

    result = item.get("text")
    return result if isinstance(result, str) else not_found



if __name__ == "__main__":
    # Manual smoke script: talks to the configured Cosmos DB container.
    from sc_system_ai.logging_config import setup_logging
    setup_logging()

    cosmos_manager = CosmosDBManager()

    # Disabled example: similarity search followed by a lookup by id.
    # query = "京都テック"
    # # results = cosmos_manager.read_all_documents()
    # results = cosmos_manager.similarity_search(query, k=1)
    # print(results[0])
    # print(results[0].metadata["id"])

    # # Fetch the source of the document with the given id
    # ids = results[0].metadata["id"]
    # print(f"{ids=}")
    # doc = cosmos_manager.get_source_by_id(ids)
    # print(doc)

    # Update a document
    text = """ストリーミングレスポンスに対応するためにジェネレータとして定義されています。
エージェントが回答の生成を終えてからレスポンスを受け取ることも可能です。"""
    _id = "c55bb571-498a-4db9-9da0-e9e35d46906b"
    print(cosmos_manager.update_document(_id, text))
Loading
Loading