diff --git a/src/sc_system_ai/agents/tools/search_school_data.py b/src/sc_system_ai/agents/tools/search_school_data.py
index 1e98385..fdea2e1 100644
--- a/src/sc_system_ai/agents/tools/search_school_data.py
+++ b/src/sc_system_ai/agents/tools/search_school_data.py
@@ -31,10 +31,6 @@ def search_school_database_cosmos(search_word: str, top_k: int = 2) -> list[Docu
     """学校に関する情報を検索する関数(現在のデータベースを参照)"""
     cosmos_manager = CosmosDBManager()
     docs = cosmos_manager.similarity_search(search_word, k=top_k)
-
-    for doc in docs:
-        source = cosmos_manager.get_source_by_id(doc.metadata["id"])
-        doc.metadata["source"] = source
     return docs
 
 
@@ -54,13 +50,12 @@ def _run(
         """use the tool."""
         logger.info(f"Search School Data Toolが次の値で呼び出されました: {search_word}")
         result = search_school_database_cosmos(search_word)
-        i = 1
         search_result = []
-        for doc in result:
+        for i, doc in enumerate(result):
             if hasattr(doc, 'page_content'):
                 search_result.append(
-                    f'・検索結果{i}は以下の通りです。\n{doc.page_content}\n参考URL: "{doc.metadata["source"]}"\n\n')
-                i += 1
+                    f'・検索結果{i + 1}は以下の通りです。\n{doc.page_content}\n参考URL: "{doc.metadata["id"]}"\n\n'
+                )
         return search_result
 
 
diff --git a/src/sc_system_ai/logging_config.py b/src/sc_system_ai/logging_config.py
index 082b2f2..52a8595 100644
--- a/src/sc_system_ai/logging_config.py
+++ b/src/sc_system_ai/logging_config.py
@@ -39,5 +39,9 @@ def setup_logging() -> None:
     package_logger = logging.getLogger("sc_system_ai")
     package_logger.setLevel(logging.DEBUG)
 
+    # azure.coreのログメッセージをWARNING以上で出力する
+    azure_logger = logging.getLogger("azure.core")
+    azure_logger.setLevel(logging.WARNING)
+
     # langchainのログメッセージを出力する
     set_verbose(True)
diff --git a/src/sc_system_ai/template/azure_cosmos.py b/src/sc_system_ai/template/azure_cosmos.py
index 4911287..e8be466 100644
--- a/src/sc_system_ai/template/azure_cosmos.py
+++ b/src/sc_system_ai/template/azure_cosmos.py
@@ -1,6 +1,7 @@
 import logging
 import os
-from typing import Any
+from datetime import datetime
+from typing import Any, Literal, cast
 
 from azure.cosmos import CosmosClient, PartitionKey
 from dotenv import load_dotenv
@@ -11,6 +12,7 @@
 from langchain_core.embeddings import Embeddings
 
 from sc_system_ai.template.ai_settings import embeddings
+from sc_system_ai.template.document_formatter import md_formatter, text_formatter
 
 load_dotenv()
 
@@ -79,36 +81,95 @@ def __init__(
             create_container=create_container,
         )
 
+    def create_document(
+        self,
+        text: str,
+        text_type: Literal["markdown", "plain"] = "markdown"
+    ) -> list[str]:
+        """データベースに新しいdocumentを作成する関数"""
+        logger.info("新しいdocumentを作成します")
+        texts, metadatas = self._division_document(
+            md_formatter(text) if text_type == "markdown" else text_formatter(text)
+        )
+        ids = self._insert_texts(texts, metadatas)
+        return ids
+
+    def _division_document(
+        self,
+        documents: list[Document]
+    ) -> tuple[list[str], list[dict[str, Any]]]:
+        """documentを分割する関数"""
+        docs, metadata = [], []
+        for doc in documents:
+            docs.append(doc.page_content)
+            metadata.append(doc.metadata)
+        return docs, metadata
+
+    def update_document(
+        self,
+        id: str,
+        text: str,
+    ) -> str:
+        """データベースのdocumentを更新する関数"""
+        logger.info("documentを更新します")
+
+        # metadataのupdated_atを更新
+        query = "SELECT c.metadata FROM c WHERE c.id = @id"
+        parameters = [{"name": "@id", "value": id}]
+
+        try:
+            item = self._container.query_items(
+                query=query,
+                parameters=cast(list[dict[str, Any]], parameters),  # mypyがエラー吐くのでキャスト
+                enable_cross_partition_query=True
+            ).next()
+        except StopIteration:
+            logger.error(f"{id=}のdocumentが見つかりませんでした")
+            return "documentが見つかりませんでした"
+
+        metadata = item["metadata"]
+        metadata["updated_at"] = datetime.now().strftime("%Y-%m-%d")
+
+        to_upsert = {
+            "id": id,
+            "text": text,
+            self._embedding_key: self._embedding.embed_documents([text])[0],
+            "metadata": metadata,
+        }
+        self._container.upsert_item(body=to_upsert)
+        return id
+
     def read_all_documents(self) -> list[Document]:
-        """全てのdocumentsを読み込む関数"""
+        """全てのdocumentsとIDを読み込む関数"""
         logger.info("全てのdocumentsを読み込みます")
         query = "SELECT c.id, c.text FROM c"
         items = list(self._container.query_items(
-            query=query, enable_cross_partition_query=True))
-        docs = []
-        i = 1
+            query=query, enable_cross_partition_query=True)
+        )
+        docs: list[Document] = []
         for item in items:
             text = item["text"]
-            item["number"] = i
-            i += 1
+            _id = item["id"]
             docs.append(
-                Document(page_content=text, metadata=item))
-        logger.debug(f"{docs[0].page_content=}, \n\nlength: {len(docs)}")
+                Document(page_content=text, metadata={"id": _id})
+            )
         return docs
 
     def get_source_by_id(self, id: str) -> str:
         """idを指定してsourceを取得する関数"""
         logger.info(f"{id=}のsourceを取得します")
-        item = self._container.read_item(item=id, partition_key=id)
+        query = "SELECT c.text FROM c WHERE c.id = " + f"'{id}'"
+        item = self._container.query_items(
+            query=query, enable_cross_partition_query=True
+        ).next()
 
-        result = item.get("source")
+        result = item["text"]
         if type(result) is str:
             return result
         else:
             return "sourceが見つかりませんでした"
 
-
 if __name__ == "__main__":
     from sc_system_ai.logging_config import setup_logging
     setup_logging()
@@ -118,9 +179,16 @@ def get_source_by_id(self, id: str) -> str:
     # results = cosmos_manager.read_all_documents()
     results = cosmos_manager.similarity_search(query, k=1)
     print(results[0])
-
-    # idで指定したドキュメントのsourceを取得
-    ids = results[0].metadata["id"]
-    print(f"{ids=}")
-    doc = cosmos_manager.get_source_by_id(ids)
-    print(doc)
+    print(results[0].metadata["id"])
+
+    # # idで指定したドキュメントのsourceを取得
+    # ids = results[0].metadata["id"]
+    # print(f"{ids=}")
+    # doc = cosmos_manager.get_source_by_id(ids)
+    # print(doc)
+
+# # documentを更新
+#     text = """ストリーミングレスポンスに対応するためにジェネレータとして定義されています。
+# エージェントが回答の生成を終えてからレスポンスを受け取ることも可能です。"""
+#     _id = "c55bb571-498a-4db9-9da0-e9e35d46906b"
+#     print(cosmos_manager.update_document(_id, text))
diff --git a/src/sc_system_ai/template/document_formatter.py b/src/sc_system_ai/template/document_formatter.py
new file mode 100644
index 0000000..2b27712
--- /dev/null
+++ b/src/sc_system_ai/template/document_formatter.py
@@ -0,0 +1,226 @@
+import re
+from datetime import datetime
+from typing import Any
+
+from langchain_core.documents import Document
+from langchain_text_splitters import (
+    CharacterTextSplitter,
+    MarkdownHeaderTextSplitter,
+    RecursiveCharacterTextSplitter,
+)
+
+CHUNK_SIZE = 1000
+CHUNK_OVERLAP = 200
+
+def _max_level(text: str) -> int:
+    """Markdownのヘッダーの最大レベルを返す関数"""
+    headers = re.findall(r"^#+", text, re.MULTILINE)
+    return max([len(h) for h in headers]) if headers else 0
+
+def markdown_splitter(
+    text: str,
+) -> list[Document]:
+    """Markdownをヘッダーで分割する関数"""
+    headers_to_split_on = [
+        ("#" * (i + 1), f"Header {i + 1}")
+        for i in range(_max_level(text))
+    ]
+    splitter = MarkdownHeaderTextSplitter(
+        headers_to_split_on,
+        return_each_line=True,
+    )
+    return splitter.split_text(text)
+
+def _find_header(document: Document) -> str | None:
+    """ドキュメントのヘッダー名を返す関数"""
+    i = 0
+    while True:
+        if document.metadata.get(f"Header {i + 1}") is None:
+            break
+        i += 1
+    return document.metadata[f"Header {i}"] if i != 0 else None
+
+def recursive_document_splitter(
+    documents: list[Document],
+    chunk_size: int,
+    chunk_overlap: int,
+) -> list[Document]:
+    """再帰的に分割する関数"""
+    splitter = RecursiveCharacterTextSplitter(
+        chunk_size=chunk_size,
+        chunk_overlap=chunk_overlap,
+    )
+    return splitter.split_documents(documents)
+
+def document_splitter(
+    documents: Document | list[Document],
+    separator: str = "\n\n",
+    chunk_size: int = CHUNK_SIZE,
+    chunk_overlap: int = CHUNK_OVERLAP,
+) -> list[Document]:
+    """Documentを分割する関数"""
+    _documents = documents if isinstance(documents, list) else [documents]
+    splitter = CharacterTextSplitter(
+        separator=separator,
+        chunk_size=chunk_size,
+        chunk_overlap=chunk_overlap,
+    )
+    return splitter.split_documents(_documents)
+
+def character_splitter(
+    text: str,
+    separator: str = "\n\n",
+    chunk_size: int = CHUNK_SIZE,
+    chunk_overlap: int = CHUNK_OVERLAP,
+) -> list[Document]:
+    """文字列を分割する関数"""
+    character_splitter = CharacterTextSplitter(
+        separator=separator,
+        chunk_size=chunk_size,
+        chunk_overlap=chunk_overlap,
+    )
+    splitted_text = character_splitter.split_text(text)
+    return character_splitter.create_documents(splitted_text)
+
+def add_metadata(
+    documents: list[Document],
+    title: str,
+    source: str | None = None,
+    with_timestamp: bool = True,
+    with_section_number: bool = False,
+    **kwargs: Any
+) -> list[Document]:
+    """メタデータを追加する関数
+    Args:
+        documents (list[Document]): ドキュメントのリスト
+        title (str): タイトル
+        source (str, optional): ソース.
+        with_timestamp (bool, optional): タイムスタンプの有無. Defaults to True.
+        with_section_number (bool, optional): セクション番号の有無. Defaults to False.
+    """
+    i = 1
+    date = datetime.now().strftime("%Y-%m-%d")
+    for doc in documents:
+        doc.metadata["title"] = title
+
+        if source is not None and \
+            doc.metadata.get("source") is None:
+            doc.metadata["source"] = source
+
+        if with_timestamp and \
+            doc.metadata.get("created_at") is None:
+            doc.metadata["created_at"] = date
+            doc.metadata["updated_at"] = date
+
+        if with_section_number and \
+            doc.metadata.get("section_number") is None:
+            doc.metadata["section_number"] = i
+            i += 1
+
+        for key, value in kwargs.items():
+            doc.metadata[key] = value
+
+    return documents
+
+def md_formatter(
+    text: str,
+    chunk_size: int = CHUNK_SIZE,
+    chunk_overlap: int = CHUNK_OVERLAP,
+    **kwargs: Any
+) -> list[Document]:
+    """Markdown形式のテキストをフォーマットする関数
+    Args:
+        text (str): Markdown形式のテキスト
+        chunk_size (int, optional): 分割するサイズ.
+        chunk_overlap (int, optional): オーバーラップのサイズ.
+
+    chunk_sizeを超えるテキストは再分割し、メタデータにセクション番号を付与します.
+    """
+    formatted_docs: list[Document] = []
+    for doc in markdown_splitter(text):
+        t = _find_header(doc)
+        if len(doc.page_content) > chunk_size:
+            rdocs = recursive_document_splitter(
+                [doc],
+                chunk_size=chunk_size,
+                chunk_overlap=chunk_overlap,
+            )
+            formatted_docs += add_metadata(
+                rdocs,
+                title=t if t is not None else rdocs[0].page_content,
+                with_section_number=True,
+                **kwargs
+            )
+        else:
+            formatted_docs += add_metadata(
+                [doc],
+                title=t if t is not None else doc.page_content,
+                **kwargs
+            )
+
+    return formatted_docs
+
+def text_formatter(
+    text: str,
+    separator: str = "\n\n",
+    chunk_size: int = CHUNK_SIZE,
+    chunk_overlap: int = CHUNK_OVERLAP,
+    **kwargs: Any
+) -> list[Document]:
+    """テキストをフォーマットする関数
+    Args:
+        text (str): テキスト
+        separator (str, optional): 区切り文字.
+        chunk_size (int, optional): 分割するサイズ.
+        chunk_overlap (int, optional): オーバーラップのサイズ.
+
+    セパレータとチャンクサイズでテキストを分割し、メタデータにセクション番号を付与します.
+    """
+    docs = character_splitter(
+        text,
+        separator=separator,
+        chunk_size=chunk_size,
+        chunk_overlap=chunk_overlap,
+    )
+    return add_metadata(
+        docs,
+        title=docs[0].page_content,
+        with_section_number=True,
+        **kwargs
+    )
+
+if __name__ == "__main__":
+    md_text = """
+# Sample Markdown
+This is a sample markdown text.
+
+
+## piyo
+There is section 2.
+### fuga
+but, there is section 3.
+
+
+## Are you ...?
+Are you hogehoge?
+
+
+### negative answer
+No, I'm fugafuga.
+
+
+### positive answer
+Yes, I'm hogehoge.
+"""
+    def print_docs(docs: list[Document]) -> None:
+        for doc in docs:
+            print(doc.page_content)
+            print(doc.metadata)
+            print()
+
+
+    docs = md_formatter(md_text)
+    print_docs(docs)
+
+    docs = text_formatter(md_text)
+    print_docs(docs)