From 3ee3266659d66e9e4cbc683b8d541b8320803253 Mon Sep 17 00:00:00 2001
From: ZiTao-Li
Date: Mon, 29 Apr 2024 22:26:42 +0800
Subject: [PATCH] add knowledge bank

---
 .../configs/agent_config.json                 | 132 +----------
 .../configs/detailed_rag_config_example.json  | 114 ++++++++++
 .../rag_example.py                            |  43 +++-
 src/agentscope/agents/rag_agents.py           |  88 +++-----
 src/agentscope/rag/__init__.py                |  10 +-
 src/agentscope/rag/knowledge_bank.py          | 113 ++++++++++
 src/agentscope/rag/langchain_rag.py           | 210 ------------------
 src/agentscope/rag/llama_index_rag.py         |  25 +--
 src/agentscope/rag/rag.py                     |   5 +-
 9 files changed, 323 insertions(+), 417 deletions(-)
 create mode 100644 examples/conversation_with_RAG_agents/configs/detailed_rag_config_example.json
 create mode 100644 src/agentscope/rag/knowledge_bank.py
 delete mode 100644 src/agentscope/rag/langchain_rag.py

diff --git a/examples/conversation_with_RAG_agents/configs/agent_config.json b/examples/conversation_with_RAG_agents/configs/agent_config.json
index 5f38640bb..124e7a578 100644
--- a/examples/conversation_with_RAG_agents/configs/agent_config.json
+++ b/examples/conversation_with_RAG_agents/configs/agent_config.json
@@ -8,29 +8,10 @@
             "model_config_name": "qwen_config",
             "emb_model_config_name": "qwen_emb_config",
             "rag_config": {
-                "index_config": [
-                    {
-                        "load_data": {
-                            "loader": {
-                                "create_object": true,
-                                "module": "llama_index.core",
-                                "class": "SimpleDirectoryReader",
-                                "init_args": {
-                                    "input_dir": "../../docs/sphinx_doc/en/source/tutorial",
-                                    "required_exts": [
-                                        ".md"
-                                    ]
-                                }
-                            }
-                        }
-                    }
-                ],
-                "chunk_size": 2048,
-                "chunk_overlap": 40,
+                "knowledge_id": "agentscope_tutorial_rag",
                 "similarity_top_k": 5,
                 "log_retrieval": false,
-                "recent_n_mem": 1,
-                "persist_dir": "./rag_storage/tutorial_assist"
+                "recent_n_mem": 1
             }
         }
     },
@@ -43,43 +24,10 @@
             "model_config_name": "qwen_config",
             "emb_model_config_name": "qwen_emb_config",
             "rag_config": {
-                "index_config": [
-                    {
-                        "load_data": {
-                            "loader": {
-                                "create_object": true,
-                                "module": "llama_index.core",
-                                "class": "SimpleDirectoryReader",
-                                "init_args": {
-                                    "input_dir": "../../src/agentscope",
-                                    "recursive": true,
-                                    "required_exts": [
-                                        ".py"
-                                    ]
-                                }
-                            }
-                        },
-                        "store_and_index": {
-                            "transformations": [
-                                {
-                                    "create_object": true,
-                                    "module": "llama_index.core.node_parser",
-                                    "class": "CodeSplitter",
-                                    "init_args": {
-                                        "language": "python",
-                                        "chunk_lines": 100
-                                    }
-                                }
-                            ]
-                        }
-                    }
-                ],
-                "chunk_size": 2048,
-                "chunk_overlap": 40,
+                "knowledge_id": "agentscope_code_rag",
                 "similarity_top_k": 5,
                 "log_retrieval": false,
-                "recent_n_mem": 1,
-                "persist_dir": "./rag_storage/code_assist"
+                "recent_n_mem": 1
             }
         }
     },
@@ -92,32 +40,13 @@
             "model_config_name": "qwen_config",
             "emb_model_config_name": "qwen_emb_config",
             "rag_config": {
-                "index_config": [
-                    {
-                        "load_data": {
-                            "loader": {
-                                "create_object": true,
-                                "module": "llama_index.core",
-                                "class": "SimpleDirectoryReader",
-                                "init_args": {
-                                    "input_dir": "../../docs/docstring_html/",
-                                    "required_exts": [
-                                        ".html"
-                                    ]
-                                }
-                            }
-                        }
-                    }
-                ],
-                "chunk_size": 2048,
-                "chunk_overlap": 40,
+                "knowledge_id": "agentscope_api_rag",
                 "similarity_top_k": 3,
                 "log_retrieval": true,
                 "recent_n_mem": 1,
-                "persist_dir": "./rag_storage/api_assist",
                 "repo_base": "../../",
                 "file_dir": "../../docs/docstring_html/"
-                }
+            }
         }
     },
     {
@@ -139,54 +68,7 @@
             "model_config_name": "qwen_config",
             "emb_model_config_name": "qwen_emb_config",
             "rag_config": {
-                "index_config": [
-                    {
-                        "load_data": {
-                            "loader": {
-                                "create_object": true,
-                                "module": "llama_index.core",
-                                "class": "SimpleDirectoryReader",
"SimpleDirectoryReader", - "init_args": { - "input_dir": "../../docs/sphinx_doc/en/source/tutorial", - "required_exts": [ - ".md" - ] - } - } - } - }, - { - "load_data": { - "loader": { - "create_object": true, - "module": "llama_index.core", - "class": "SimpleDirectoryReader", - "init_args": { - "input_dir": "../../src/agentscope", - "recursive": true, - "required_exts": [ - ".py" - ] - } - } - }, - "store_and_index": { - "transformations": [ - { - "create_object": true, - "module": "llama_index.core.node_parser", - "class": "CodeSplitter", - "init_args": { - "language": "python", - "chunk_lines": 100 - } - } - ] - } - } - ], - "chunk_size": 2048, - "chunk_overlap": 40, + "knowledge_id": "agentscope_global_rag", "similarity_top_k": 5, "log_retrieval": false, "recent_n_mem": 1, diff --git a/examples/conversation_with_RAG_agents/configs/detailed_rag_config_example.json b/examples/conversation_with_RAG_agents/configs/detailed_rag_config_example.json new file mode 100644 index 000000000..8e3a95042 --- /dev/null +++ b/examples/conversation_with_RAG_agents/configs/detailed_rag_config_example.json @@ -0,0 +1,114 @@ +[ + { + "knowledge_id": "agentscope_code_rag", + "persist_dir": "./rag_storage/searching_assist", + "chunk_size": 2048, + "chunk_overlap": 40, + "data_processing": [ + { + "load_data": { + "loader": { + "create_object": true, + "module": "llama_index.core", + "class": "SimpleDirectoryReader", + "init_args": { + "input_dir": "../../src/agentscope", + "recursive": true, + "required_exts": [ + ".py" + ] + } + } + }, + "store_and_index": { + "transformations": [ + { + "create_object": true, + "module": "llama_index.core.node_parser", + "class": "CodeSplitter", + "init_args": { + "language": "python", + "chunk_lines": 100 + } + } + ] + } + } + ] + }, + { + "knowledge_id": "agentscope_api_rag", + "persist_dir": "./rag_storage/searching_assist", + "chunk_size": 2048, + "chunk_overlap": 40, + "data_processing": [ + { + "load_data": { + "loader": { + "create_object": true, + "module": "llama_index.core", + "class": "SimpleDirectoryReader", + "init_args": { + "input_dir": "../../docs/docstring_html/", + "required_exts": [ + ".html" + ] + } + } + } + } + ] + }, + { + "knowledge_id": "agentscope_global_rag", + "persist_dir": "./rag_storage/searching_assist", + "chunk_size": 2048, + "chunk_overlap": 40, + "data_processing": [ + { + "load_data": { + "loader": { + "create_object": true, + "module": "llama_index.core", + "class": "SimpleDirectoryReader", + "init_args": { + "input_dir": "../../docs/sphinx_doc/en/source/tutorial", + "required_exts": [ + ".md" + ] + } + } + } + }, + { + "load_data": { + "loader": { + "create_object": true, + "module": "llama_index.core", + "class": "SimpleDirectoryReader", + "init_args": { + "input_dir": "../../src/agentscope", + "recursive": true, + "required_exts": [ + ".py" + ] + } + } + }, + "store_and_index": { + "transformations": [ + { + "create_object": true, + "module": "llama_index.core.node_parser", + "class": "CodeSplitter", + "init_args": { + "language": "python", + "chunk_lines": 100 + } + } + ] + } + } + ] + } +] \ No newline at end of file diff --git a/examples/conversation_with_RAG_agents/rag_example.py b/examples/conversation_with_RAG_agents/rag_example.py index 0369d8925..9295e8c56 100644 --- a/examples/conversation_with_RAG_agents/rag_example.py +++ b/examples/conversation_with_RAG_agents/rag_example.py @@ -10,6 +10,7 @@ import agentscope from agentscope.agents import UserAgent, DialogAgent, LlamaIndexAgent +from agentscope.rag import KnowledgeBank 
 AGENT_CHOICE_PROMPT = """
@@ -59,6 +60,31 @@ def main() -> None:
         config["api_key"] = f"{os.environ.get('DASHSCOPE_API_KEY')}"
     agentscope.init(model_configs=model_configs)
 
+    # initialize the knowledge bank (for RAG)
+    knowledge_bank = KnowledgeBank()
+    # a simple example of importing data into RAG
+    knowledge_bank.add_data_for_rag(
+        knowledge_id="agentscope_tutorial_rag",
+        emb_model_name="qwen_emb_config",
+        data_dirs_and_types={
+            "../../docs/sphinx_doc/en/source/tutorial": [".md"],
+        },
+        persist_dir="./rag_storage/tutorial_assist",
+    )
+    # more detailed configurations can be provided by loading a config file
+    with open(
+        "configs/detailed_rag_config_example.json",
+        "r",
+        encoding="utf-8",
+    ) as f:
+        knowledge_configs = json.load(f)
+    for config in knowledge_configs:
+        knowledge_bank.add_data_for_rag(
+            knowledge_id=config["knowledge_id"],
+            emb_model_name="qwen_emb_config",
+            index_config=config,
+        )
+
     with open("configs/agent_config.json", "r", encoding="utf-8") as f:
         agent_configs = json.load(f)
 
@@ -76,13 +102,20 @@ def main() -> None:
     searching_agent = LlamaIndexAgent(**agent_configs[4]["args"])
 
-    rag_agents = [
+    rag_agent_list = [
         tutorial_agent,
         code_explain_agent,
         api_agent,
         searching_agent,
     ]
-    rag_agent_names = [agent.name for agent in rag_agents]
+    rag_agent_names = [agent.name for agent in rag_agent_list]
+
+    for rag_agent in rag_agent_list:
+        rag_agent.init_rag(
+            rag_module=knowledge_bank.get_rag(
+                rag_agent.rag_config["knowledge_id"],
+            ),
+        )
 
     # define a guide agent
     rag_agent_descriptions = [
@@ -91,7 +124,7 @@ def main() -> None:
         + "\n agent description:"
         + agent.description
         + "\n"
-        for agent in rag_agents
+        for agent in rag_agent_list
     ]
     agent_configs[3]["args"].pop("description")
     agent_configs[3]["args"]["sys_prompt"] = agent_configs[3]["args"][
@@ -114,14 +147,14 @@ def main() -> None:
         x.role = "user"  # to enforce dashscope requirement on roles
         if len(x["content"]) == 0 or str(x["content"]).startswith("exit"):
             break
-        speak_list = filter_agents(x.get("content", ""), rag_agents)
+        speak_list = filter_agents(x.get("content", ""), rag_agent_list)
         if len(speak_list) == 0:
             guide_response = guide_agent(x)
             # Only one agent can be called in the current version,
             # we may support multi-agent conversation later
             speak_list = filter_agents(
                 guide_response.get("content", ""),
-                rag_agents,
+                rag_agent_list,
             )
         agent_name_list = [agent.name for agent in speak_list]
         for agent_name, agent in zip(agent_name_list, speak_list):
diff --git a/src/agentscope/agents/rag_agents.py b/src/agentscope/agents/rag_agents.py
index 8e83853e7..af606802e 100644
--- a/src/agentscope/agents/rag_agents.py
+++ b/src/agentscope/agents/rag_agents.py
@@ -62,11 +62,11 @@ def __init__(
 
         # setup RAG configurations
        self.rag_config = rag_config or {}
-        # use LlamaIndexAgent OR LangChainAgent
-        self.rag = self.init_rag()
+        # rag module to be initialized later
+        self.rag = None
 
     @abstractmethod
-    def init_rag(self) -> RAGBase:
+    def init_rag(self, **kwargs: Any) -> None:
        """initialize RAG with configuration"""
 
     def reply(
@@ -182,49 +182,17 @@ def __init__(
             memory_config (dict):
                 memory configuration
             rag_config (dict):
-                config for RAG. It contains the parameters for
-                RAG modules functions:
-                rag.load_data(...) and rag.store_and_index(docs, ...)
-                If not provided, the default setting will be used.
-                An example of the config for retrieving code files
-                is as following:
-
-                "rag_config":{
-                    "index_configs": [
-                        {
-                            "load_data": {
-                                "loader": {
-                                    "create_object": true,
-                                    "module": "llama_index.core",
-                                    "class": "SimpleDirectoryReader",
-                                    "init_args": {
-                                        "input_dir": "path/to/data",
-                                        "recursive": true
-                                        ...
-                                    }
-                                }
-                            },
-                            "store_and_index": {
-                                "transformations": [
-                                    {
-                                        "create_object": true,
-                                        "module": "llama_index.core.node_parser",
-                                        "class": "CodeSplitter",
-                                        "init_args": {
-                                            "language": "python",
-                                            "chunk_lines": 100
-                                        }
-                                    }
-                                ]
-                            }
-                        }
-                    ],
-                    "chunk_size": 2048,
-                    "chunk_overlap": 40,
-                    "similarity_top_k": 10,
-                    "log_retrieval": true,
-                    "recent_n_mem": 1
-                }
+                config for the RAG module. It contains at least
+                the following parameters:
+                "knowledge_id" (str):
+                    identifier of the knowledge in the KnowledgeBank,
+                "similarity_top_k" (int):
+                    how many nodes/documents to retrieve,
+                "log_retrieval" (bool):
+                    whether to log the retrieved content,
+                "recent_n_mem" (int):
+                    how many memories are used to form the query
+                    (default is 1, using only the current input to reply)
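+
+                A minimal example of such a rag_config (a sketch that
+                mirrors configs/agent_config.json in the RAG example):
+
+                "rag_config": {
+                    "knowledge_id": "agentscope_tutorial_rag",
+                    "similarity_top_k": 5,
+                    "log_retrieval": false,
+                    "recent_n_mem": 1
+                }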
         """
         super().__init__(
             name=name,
@@ -236,17 +204,27 @@ def __init__(
         )
         self.description = kwargs.get("description", "")
 
-    def init_rag(self) -> LlamaIndexRAG:
+    def init_rag(
+        self,
+        rag_module: Optional[RAGBase] = None,
+        index_config: Optional[dict] = None,
+        **kwargs: Any,
+    ) -> None:
         # dynamic loading loader
         # initiate RAG related attributes
-        rag = LlamaIndexRAG(
-            name=self.name,
-            model=self.model,
-            emb_model=self.emb_model,
-            rag_config=self.rag_config,
-            index_config=self.rag_config.get("index_config"),
-        )
-        return rag
+        if rag_module is None and index_config is not None:
+            self.rag = LlamaIndexRAG(
+                name=self.name,
+                model=self.model,
+                emb_model=self.emb_model,
+                rag_config=self.rag_config,
+                index_config=index_config,
+            )
+        elif rag_module is not None:
+            self.rag = rag_module
+            self.rag.rag_config = self.rag_config
+        else:
+            raise ValueError("Expected either rag_module or index_config")
 
     def reply(
         self,
diff --git a/src/agentscope/rag/__init__.py b/src/agentscope/rag/__init__.py
index ac407feb1..90a7c1240 100644
--- a/src/agentscope/rag/__init__.py
+++ b/src/agentscope/rag/__init__.py
@@ -1,17 +1,11 @@
 # -*- coding: utf-8 -*-
 """ Import all pipeline related modules in the package.
""" from .rag import RAGBase - from .llama_index_rag import LlamaIndexRAG - -try: - from .langchain_rag import LangChainRAG -except Exception: - LangChainRAG = None # type: ignore # NOQA - +from .knowledge_bank import KnowledgeBank __all__ = [ "RAGBase", "LlamaIndexRAG", - "LangChainRAG", + "KnowledgeBank", ] diff --git a/src/agentscope/rag/knowledge_bank.py b/src/agentscope/rag/knowledge_bank.py new file mode 100644 index 000000000..01465b9e6 --- /dev/null +++ b/src/agentscope/rag/knowledge_bank.py @@ -0,0 +1,113 @@ +# -*- coding: utf-8 -*- +""" +Knowledge bank for making RAG module easier to use +""" +import copy +from typing import Optional + +from agentscope.models import load_model_by_config_name +from .rag import RAGBase +from .llama_index_rag import LlamaIndexRAG + + +DEFAULT_INDEX_CONFIG = { + "knowledge_id": "", + "persist_dir": "", + "data_processing": [], +} +DEFAULT_LOADER_CONFIG = { + "load_data": { + "loader": { + "create_object": True, + "module": "llama_index.core", + "class": "SimpleDirectoryReader", + "init_args": {}, + }, + }, +} +DEFAULT_INIT_CONFIG = { + "input_dir": "", + "recursive": True, + "required_exts": [], +} + + +class KnowledgeBank: + """ + KnowledgeBank enables + 1) provide an easy and fast way to initialize the RAG model; + 2) make RAG model reusable and sharable for multiple agent. + """ + + def __init__(self) -> None: + """initialize the knowledge bank""" + self.stored_knowledge: dict[str, RAGBase] = {} + + def add_data_for_rag( + self, + knowledge_id: str, + emb_model_name: str, + data_dirs_and_types: dict[str, list[str]] = None, + model_name: Optional[str] = None, + persist_dir: Optional[str] = None, + index_config: Optional[dict] = None, + ) -> None: + """ + Transform data in a directory to be ready to work with RAG. + Args: + knowledge_id (str): + emb_model_name (str): + model_name (Optional[str]): + data_dirs_and_types (dict[str, list[str]]): + persist_dir (Optional[str]): + index_config (ptional[dict]): + """ + if knowledge_id in self.stored_knowledge: + raise ValueError(f"knowledge_id {knowledge_id} already exists.") + + if persist_dir is None: + persist_dir = "./rag_storage/" + + assert data_dirs_and_types is not None or index_config is not None + + if index_config is None: + index_config = copy.deepcopy(DEFAULT_INDEX_CONFIG) + for data_dir, types in data_dirs_and_types.items(): + loader_config = copy.deepcopy(DEFAULT_LOADER_CONFIG) + loader_init = copy.deepcopy(DEFAULT_INIT_CONFIG) + loader_init["input_dir"] = data_dir + loader_init["required_exts"] = types + loader_config["load_data"]["loader"]["init_args"] = loader_init + index_config["data_processing"].append(loader_config) + index_config["persist_dir"] = persist_dir + + self.stored_knowledge[knowledge_id] = LlamaIndexRAG( + name=knowledge_id, + emb_model=load_model_by_config_name(emb_model_name), + model=load_model_by_config_name(model_name) + if model_name + else None, + index_config=index_config, + ) + + def get_rag( + self, + knowledge_id: str, + duplicate: bool = False, + ) -> RAGBase: + """ + Get a RAG from the knowledge bank. + Args: + knowledge_id (str): + unique id for the RAG + duplicate (bool): + whether return a copy of of the RAG. 
+        """
+        if knowledge_id not in self.stored_knowledge:
+            raise ValueError(f"{knowledge_id} has not been added yet.")
+
+        rag = self.stored_knowledge[knowledge_id]
+        if duplicate:
+            rag = copy.deepcopy(rag)
+
+        return rag
diff --git a/src/agentscope/rag/langchain_rag.py b/src/agentscope/rag/langchain_rag.py
deleted file mode 100644
index bda2fe20d..000000000
--- a/src/agentscope/rag/langchain_rag.py
+++ /dev/null
@@ -1,210 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-This module is integrate the LangChain RAG model into our AgentScope package
-"""
-
-
-from typing import Any, Optional, Union
-
-try:
-    from langchain_core.vectorstores import VectorStore
-    from langchain_core.documents import Document
-    from langchain_core.embeddings import Embeddings
-    from langchain_community.document_loaders.base import BaseLoader
-    from langchain_community.vectorstores import Chroma
-    from langchain_text_splitters.base import TextSplitter
-    from langchain_text_splitters import CharacterTextSplitter
-except ImportError:
-    VectorStore = None
-    Document = None
-    Embeddings = None
-    BaseLoader = None
-    Chroma = None
-    TextSplitter = None
-    CharacterTextSplitter = None
-
-from agentscope.models import ModelWrapperBase
-from .rag import RAGBase
-from .rag import (
-    DEFAULT_CHUNK_OVERLAP,
-    DEFAULT_CHUNK_SIZE,
-)
-
-
-class _LangChainEmbModel(Embeddings):
-    """
-    Dummy wrapper to convert the ModelWrapperBase embedding model
-    to a LanguageChain RAG model
-    """
-
-    def __init__(self, emb_model: ModelWrapperBase) -> None:
-        """
-        Dummy wrapper
-        Args:
-            emb_model (ModelWrapperBase): embedding model of
-                ModelWrapperBase type
-        """
-        self._emb_model_wrapper = emb_model
-
-    def embed_documents(self, texts: list[str]) -> list[list[float]]:
-        """
-        Wrapper function for embedding list of documents
-        Args:
-            texts (list[str]): list of texts to be embedded
-        """
-        results = [
-            list(self._emb_model_wrapper(t).embedding[0]) for t in texts
-        ]
-        return results
-
-    def embed_query(self, text: str) -> list[float]:
-        """
-        Wrapper function for embedding a single query
-        Args:
-            text (str): query to be embedded
-        """
-        return list(self._emb_model_wrapper(text).embedding[0])
-
-
-class LangChainRAG(RAGBase):
-    """
-    This class is a wrapper around the LangChain RAG.
-    """
-
-    def __init__(
-        self,
-        model: Optional[ModelWrapperBase],
-        emb_model: Union[ModelWrapperBase, Embeddings, None],
-        config: Optional[dict] = None,
-        **kwargs: Any,
-    ) -> None:
-        """
-        Initializes the LangChainRAG
-        Args:
-            model (ModelWrapperBase):
-                The language model used for final synthesis
-            emb_model ( Union[ModelWrapperBase, Embeddings, None]):
-                The embedding model used for generate embeddings
-            config (dict):
-                The additional configuration for llama index rag
-        """
-        super().__init__(model, emb_model, **kwargs)
-
-        self.loader = None
-        self.splitter = None
-        self.retriever = None
-        self.vector_store = None
-
-        if VectorStore is None:
-            raise ImportError(
-                "Please install LangChain RAG packages to use LangChain RAG.",
-            )
-
-        self.config = config or {}
-        if isinstance(emb_model, ModelWrapperBase):
-            self.emb_model = _LangChainEmbModel(emb_model)
-        elif isinstance(emb_model, Embeddings):
-            self.emb_model = emb_model
-        else:
-            raise TypeError(
-                f"Embedding model does not support {type(self.emb_model)}.",
-            )
-
-    def load_data(
-        self,
-        loader: BaseLoader,
-        query: Optional[Any] = None,
-        **kwargs: Any,
-    ) -> list[Document]:
-        # pylint: disable=unused-argument
-        """
-        Loading data from a directory
-        Args:
-            loader (BaseLoader):
-                accepting a LangChain loader instance
-            query (str):
-                accepting a query, LangChain does not rely on this
-        Returns:
-            list[Document]: a list of documents loaded
-        """
-        self.loader = loader
-        docs = self.loader.load()
-        return docs
-
-    def store_and_index(
-        self,
-        docs_list: Any,
-        vector_store: Optional[VectorStore] = None,
-        splitter: Optional[TextSplitter] = None,
-        **kwargs: Any,
-    ) -> Any:
-        # pylint: disable=unused-argument
-        """
-        Preprocessing the loaded documents.
-        Args:
-            docs_list (Any):
-                documents to be processed
-            vector_store (Optional[VectorStore]):
-                vector store in LangChain RAG
-            splitter (Optional[TextSplitter]):
-                optional, specifies the splitter to preprocess
-                the documents
-
-        Returns:
-            None
-
-        In LlamaIndex terms, an Index is a data structure composed
-        of Document objects, designed to enable querying by an LLM.
-        For example:
-        1) preprocessing documents with
-        2) generate embedding,
-        3) store the embedding-content to vdb
-        """
-        self.splitter = splitter or CharacterTextSplitter(
-            chunk_size=self.config.get("chunk_size", DEFAULT_CHUNK_SIZE),
-            chunk_overlap=self.config.get(
-                "chunk_overlap",
-                DEFAULT_CHUNK_OVERLAP,
-            ),
-        )
-        all_splits = []
-        for docs in docs_list:
-            all_splits = all_splits + self.splitter.split_documents(docs)
-
-        # indexing the chunks and store them into the vector store
-        if vector_store is None:
-            vector_store = Chroma()
-        self.vector_store = vector_store.from_documents(
-            documents=all_splits,
-            embedding=self.emb_model,
-        )
-
-        # build retriever
-        search_type = self.config.get("search_type", "similarity")
-        self.retriever = self.vector_store.as_retriever(
-            search_type=search_type,
-            search_kwargs={
-                "k": self.config.get("similarity_top_k", 6),
-            },
-        )
-
-    def retrieve(self, query: Any, to_list_strs: bool = False) -> list[Any]:
-        """
-        This is a basic retrieve function with LangChain APIs
-        Args:
-            query: query is expected to be a question in string
-
-        Returns:
-            list of answers
-
-        More advanced retriever can refer to
-        https://python.langchain.com/docs/modules/data_connection/retrievers/
-        """
-
-        retrieved_docs = self.retriever.invoke(query)
-        if to_list_strs:
-            results = []
-            for doc in retrieved_docs:
-                results.append(doc.page_content)
-            return results
-        return retrieved_docs
diff --git a/src/agentscope/rag/llama_index_rag.py b/src/agentscope/rag/llama_index_rag.py
index d74454614..2b3e74fa2 100644
--- a/src/agentscope/rag/llama_index_rag.py
+++ b/src/agentscope/rag/llama_index_rag.py
@@ -126,10 +126,10 @@ class LlamaIndexRAG(RAGBase):
     def __init__(
         self,
         name: str,
-        model: ModelWrapperBase,
+        model: Optional[ModelWrapperBase] = None,
         emb_model: Union[ModelWrapperBase, BaseEmbedding, None] = None,
-        rag_config: dict = None,
         index_config: dict = None,
+        rag_config: Optional[dict] = None,
         overwrite_index: Optional[bool] = False,
         showprogress: Optional[bool] = True,
         **kwargs: Any,
@@ -155,20 +155,19 @@ def __init__(
                 The language model used for final synthesis
             emb_model (Optional[ModelWrapperBase]):
                 The embedding model used for generate embeddings
-            rag_config (dict):
-                The configuration for llama index rag
             index_config (dict):
                 The configuration to generate the index
+            rag_config (dict):
+                The configuration for llama index rag
             overwrite_index (Optional[bool]):
-                Whether to overwrite the index whiel refreshing
+                Whether to overwrite the index while refreshing
             showprogress (Optional[bool]):
                 Whether to show the indexing progress
         """
         super().__init__(model, emb_model, rag_config, **kwargs)
         self.name = name
-        self.persist_dir = rag_config.get("persist_dir", "/")
+        self.persist_dir = index_config.get("persist_dir", "/")
         self.emb_model = emb_model
-        self.rag_config = rag_config
         self.index_config = index_config
         self.overwrite_index = overwrite_index
         self.showprogress = showprogress
@@ -207,7 +206,7 @@ def _init_rag(self) -> None:
             self.refresh_index()
         else:
             self._data_to_index()
-        self._set_retriever()
+        self.set_retriever()
         logger.info(f"RAG agent {self.name} initialization completed!\n")
 
     def _load_index(self) -> None:
@@ -239,7 +238,7 @@ def _data_to_index(self) -> None:
         nodes = []
         # load data to documents and set transformations
         # using information in index_config
-        for config in self.index_config:
+        for config in self.index_config.get("data_processing"):
             documents = self._data_to_docs(config=config)
             transformations = self._set_transformations(config=config).get(
"transformations", @@ -371,11 +370,11 @@ def _set_transformations(self, config: dict) -> Any: else: transformations = [ SentenceSplitter( - chunk_size=self.rag_config.get( + chunk_size=self.index_config.get( "chunk_size", DEFAULT_CHUNK_SIZE, ), - chunk_overlap=self.rag_config.get( + chunk_overlap=self.index_config.get( "chunk_overlap", DEFAULT_CHUNK_OVERLAP, ), @@ -389,7 +388,7 @@ def _set_transformations(self, config: dict) -> Any: transformations = {"transformations": transformations} return transformations - def _set_retriever( + def set_retriever( self, retriever: Optional[BaseRetriever] = None, **kwargs: Any, @@ -447,7 +446,7 @@ def refresh_index(self) -> None: """ Refresh the index when needed. """ - for config in self.index_config: + for config in self.index_config.get("data_processing"): documents = self._data_to_docs(config=config) # store and indexing for each file type transformations = self._set_transformations(config=config).get( diff --git a/src/agentscope/rag/rag.py b/src/agentscope/rag/rag.py index b9885dd21..85543ade2 100644 --- a/src/agentscope/rag/rag.py +++ b/src/agentscope/rag/rag.py @@ -30,7 +30,7 @@ class RAGBase(ABC): def __init__( self, - model: Optional[ModelWrapperBase], + model: Optional[ModelWrapperBase] = None, emb_model: Any = None, rag_config: Optional[dict] = None, **kwargs: Any, @@ -65,6 +65,9 @@ def retrieve( return a list with retrieved documents (in strings) """ + def set_retriever(self, **kwargs: Any) -> None: + """update retriever of RAG module""" + def post_processing( self, retrieved_docs: list[str],