Commit
Changes made: To reorganize the workflow, we made the following changes:

(1) llama_index_rag.py: added the functions load_docs, docs_to_nodes, and persist_to_dir;
(2) rag_agents.py: relocated the internal function _prepare_args_from_config to llama_index_rag.py as a method of the class LlamaIndexRAG(RAGBase).
艾渔 committed Apr 25, 2024
1 parent 3007809 commit d7ae87c
Showing 2 changed files with 143 additions and 106 deletions.
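
For orientation, the following is a minimal, self-contained sketch of the load → split → index → persist flow that this commit factors into the new methods. It assumes llama-index >= 0.10; the ./data and ./storage paths, the chunk parameters, and the embedding model resolved from llama-index's global settings are placeholder assumptions, not part of the diff:

# Sketch of the flow factored into load_docs / docs_to_nodes / persist_to_dir.
# Assumes llama-index >= 0.10; paths and chunk parameters are placeholders.
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
from llama_index.core.ingestion import IngestionPipeline
from llama_index.core.node_parser import SentenceSplitter

docs = SimpleDirectoryReader("./data").load_data()       # roughly load_docs
pipeline = IngestionPipeline(
    transformations=[SentenceSplitter(chunk_size=1024, chunk_overlap=20)],
)
nodes = pipeline.run(documents=docs)                      # roughly docs_to_nodes
index = VectorStoreIndex(nodes=nodes)                     # embed model comes from llama-index settings
index.storage_context.persist(persist_dir="./storage")   # roughly persist_to_dir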
173 changes: 142 additions & 31 deletions examples/conversation_with_RAG_agents/rag/llama_index_rag.py
@@ -6,6 +6,7 @@

from typing import Any, Optional, List, Union
from loguru import logger
+import importlib
import os.path

try:
@@ -205,7 +206,6 @@ def load_data(
    def store_and_index(
        self,
        docs_list: Any,
-        vector_store: Union[BasePydanticVectorStore, VectorStore, None] = None,
        retriever: Optional[BaseRetriever] = None,
        transformations: Optional[list[NodeParser]] = None,
        store_and_index_args_list: Optional[list] = None,
@@ -217,8 +217,6 @@ def store_and_index(
            docs_list (Any):
                documents to be processed, usually expected to be in
                llama index Documents.
-            vector_store (Union[BasePydanticVectorStore, VectorStore, None]):
-                vector store in llama index
            retriever (Optional[BaseRetriever]):
                optional, specifies the retriever in llama index to be used
            transformations (Optional[list[NodeParser]]):
Expand Down Expand Up @@ -246,42 +244,19 @@ def store_and_index(
nodes = []
# we build nodes by using the IngestionPipeline for each document
for i in range(len(docs_list)):
# load the transformation
transformations = store_and_index_args_list[i].get(
"transformations", None)
# if it is not specified, use the default configuration
if transformations is None:
transformations = [
SentenceSplitter(
chunk_size=self.config.get(
"chunk_size",
DEFAULT_CHUNK_SIZE,
),
chunk_overlap=self.config.get(
"chunk_overlap",
DEFAULT_CHUNK_OVERLAP,
),
),
]

# adding embedding model as the last step of transformation
# https://docs.llamaindex.ai/en/stable/module_guides/loading/ingestion_pipeline/root.html
transformations.append(self.emb_model)

# use in memory to construct an index
pipeline = IngestionPipeline(
transformations=transformations,
nodes = nodes + self.docs_to_nodes(
docs=docs_list[i],
transformations=store_and_index_args_list[i].get(
"transformations", None)
)
# stack up the nodes from the pipline
nodes = nodes + pipeline.run(documents=docs_list[i])

# feed all the nodes to embedding model to calculate index
self.index = VectorStoreIndex(
nodes=nodes,
embed_model=self.emb_model,
)
# persist the calculated index
self.index.storage_context.persist(persist_dir=self.persist_dir)
self.persist_to_dir()
else:
# load the storage_context
storage_context = StorageContext.from_defaults(
@@ -310,6 +285,86 @@ def store_and_index(
        self.retriever = retriever
        return self.index

+    def persist_to_dir(self):
+        """
+        Persist the index to the directory.
+        """
+        self.index.storage_context.persist(persist_dir=self.persist_dir)
+
+    def load_docs(self, index_config: dict) -> Any:
+        """
+        Load the documents by configurations.
+        Args:
+            index_config (dict):
+                the index configuration
+        Return:
+            Any: the loaded documents
+        """
+        if "load_data" in index_config:
+            load_data_args = self._prepare_args_from_config(
+                index_config["load_data"],
+            )
+        else:
+            try:
+                from llama_index.core import SimpleDirectoryReader
+            except ImportError as exc_inner:
+                raise ImportError(
+                    "LlamaIndexAgent requires llama-index to be installed. "
+                    "Please run `pip install llama-index`",
+                ) from exc_inner
+            load_data_args = {
+                "loader": SimpleDirectoryReader(
+                    index_config["set_default_data_path"]),
+            }
+        logger.info(f"rag.load_data args: {load_data_args}")
+        docs = self.load_data(**load_data_args)
+        return docs
+
+    def docs_to_nodes(
+        self,
+        docs: Any,
+        transformations: Optional[list[NodeParser]] = None,
+    ) -> Any:
+        """
+        Convert the documents to nodes.
+        Args:
+            docs (Any):
+                documents to be processed, usually expected to be in
+                llama index Documents.
+            transformations (Optional[list[NodeParser]]):
+                specifies the transformations (operators) to
+                process documents (e.g., split the documents into smaller
+                chunks)
+        Return:
+            Any: the nodes produced from the processed documents
+        """
+        # if it is not specified, use the default configuration
+        if transformations is None:
+            transformations = [
+                SentenceSplitter(
+                    chunk_size=self.config.get(
+                        "chunk_size",
+                        DEFAULT_CHUNK_SIZE,
+                    ),
+                    chunk_overlap=self.config.get(
+                        "chunk_overlap",
+                        DEFAULT_CHUNK_OVERLAP,
+                    ),
+                ),
+            ]
+        # adding embedding model as the last step of transformation
+        # https://docs.llamaindex.ai/en/stable/module_guides/loading/ingestion_pipeline/root.html
+        transformations.append(self.emb_model)
+
+        # use in memory to construct an index
+        pipeline = IngestionPipeline(
+            transformations=transformations,
+        )
+        # stack up the nodes from the pipeline
+        nodes = pipeline.run(documents=docs)
+        return nodes
+
    def set_retriever(self, retriever: BaseRetriever) -> None:
        """
        Reset the retriever if necessary.
@@ -342,3 +397,59 @@ def retrieve(self, query: str, to_list_strs: bool = False) -> list[Any]:
                results.append(node.get_text())
            return results
        return retrieved

+    def _prepare_args_from_config(
+        self,
+        config: dict,
+    ) -> Any:
+        """
+        Helper function to build args for the two functions:
+        load_data(...) and store_and_index(docs, ...)
+        in RAG classes.
+        Args:
+            config (dict): a dictionary containing configurations
+        Returns:
+            Any: an object that is parsed/built to be an element
+            of input to the function of RAG module.
+        """
+        if not isinstance(config, dict):
+            return config
+
+        if "create_object" in config:
+            # if a term in args is an object,
+            # recursively create the object with args from config
+            module_name = config.get("module", "")
+            class_name = config.get("class", "")
+            init_args = config.get("init_args", {})
+            try:
+                cur_module = importlib.import_module(module_name)
+                cur_class = getattr(cur_module, class_name)
+                init_args = self._prepare_args_from_config(init_args)
+                logger.info(
+                    f"load and build object: {cur_module, cur_class, init_args}",
+                )
+                return cur_class(**init_args)
+            except ImportError as exc_inner:
+                logger.error(
+                    f"Failed to load class {class_name} "
+                    f"from module {module_name}",
+                )
+                raise ImportError(
+                    f"Failed to load class {class_name} "
+                    f"from module {module_name}",
+                ) from exc_inner
+        else:
+            prepared_args = {}
+            for key, value in config.items():
+                if isinstance(value, list):
+                    prepared_args[key] = []
+                    for c in value:
+                        prepared_args[key].append(
+                            self._prepare_args_from_config(c),
+                        )
+                elif isinstance(value, dict):
+                    prepared_args[key] = self._prepare_args_from_config(value)
+                else:
+                    prepared_args[key] = value
+            return prepared_args
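
To make the "create_object" recursion concrete, here is a hypothetical config fragment (not from the repository; the path is a placeholder) and what _prepare_args_from_config would do with it:

# Hypothetical example: any dict containing the key "create_object" is
# replaced by an instance of module.class(**init_args), recursively.
load_data_config = {
    "loader": {
        "create_object": True,
        "module": "llama_index.core",
        "class": "SimpleDirectoryReader",
        "init_args": {"input_dir": "./data"},  # placeholder path
    },
}
# rag._prepare_args_from_config(load_data_config) would return
# {"loader": SimpleDirectoryReader(input_dir="./data")}, which load_docs
# then unpacks into rag.load_data(**load_data_args).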
76 changes: 1 addition & 75 deletions examples/conversation_with_RAG_agents/rag_agents.py
@@ -71,62 +71,6 @@ def __init__(
    def init_rag(self) -> RAGBase:
        """initialize RAG with configuration"""

-    def _prepare_args_from_config(
-        self,
-        config: dict,
-    ) -> Any:
-        """
-        Helper function to build args for the two functions:
-        rag.load_data(...) and rag.store_and_index(docs, ...)
-        in RAG classes.
-        Args:
-            config (dict): a dictionary containing configurations
-        Returns:
-            Any: an object that is parsed/built to be an element
-            of input to the function of RAG module.
-        """
-        if not isinstance(config, dict):
-            return config
-
-        if "create_object" in config:
-            # if a term in args is a object,
-            # recursively create object with args from config
-            module_name = config.get("module", "")
-            class_name = config.get("class", "")
-            init_args = config.get("init_args", {})
-            try:
-                cur_module = importlib.import_module(module_name)
-                cur_class = getattr(cur_module, class_name)
-                init_args = self._prepare_args_from_config(init_args)
-                logger.info(
-                    f"load and build object{cur_module, cur_class, init_args}",
-                )
-                return cur_class(**init_args)
-            except ImportError as exc_inner:
-                logger.error(
-                    f"Fail to load class {class_name} "
-                    f"from module {module_name}",
-                )
-                raise ImportError(
-                    f"Fail to load class {class_name} "
-                    f"from module {module_name}",
-                ) from exc_inner
-        else:
-            prepared_args = {}
-            for key, value in config.items():
-                if isinstance(value, list):
-                    prepared_args[key] = []
-                    for c in value:
-                        prepared_args[key].append(
-                            self._prepare_args_from_config(c),
-                        )
-                elif isinstance(value, dict):
-                    prepared_args[key] = self._prepare_args_from_config(value)
-                else:
-                    prepared_args[key] = value
-            return prepared_args

    def reply(
        self,
        x: dict = None,
@@ -312,25 +256,7 @@ def init_rag(self) -> LlamaIndexRAG:
        # and transformations, the length of the list depends on
        # the total count of loaded data.
        for index_config_i in range(len(index_config)):
if "load_data" in index_config[index_config_i]:
load_data_args = self._prepare_args_from_config(
index_config[index_config_i]["load_data"],
)
else:
try:
from llama_index.core import SimpleDirectoryReader
except ImportError as exc_inner:
raise ImportError(
" LlamaIndexAgent requires llama-index to be install."
"Please run `pip install llama-index`",
) from exc_inner
load_data_args = {
"loader": SimpleDirectoryReader(
index_config[index_config_i][
"set_default_data_path"]),
}
logger.info(f"rag.load_data args: {load_data_args}")
docs = rag.load_data(**load_data_args)
docs = rag.load_docs(index_config = index_config[index_config_i])
            docs_list.append(docs)

        # store and indexing for each file type
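
With the relocation in place, the agent-side call reduces to rag.load_docs(...). The two hypothetical index_config shapes below exercise both branches of the new method (the key names are taken from the diff; the paths are placeholders):

# Branch 1: an explicit "load_data" section; the loader is built by
# _prepare_args_from_config via the "create_object" convention.
config_with_loader = {
    "load_data": {
        "loader": {
            "create_object": True,
            "module": "llama_index.core",
            "class": "SimpleDirectoryReader",
            "init_args": {"input_dir": "./docs"},  # placeholder
        },
    },
}
# Branch 2: no "load_data" section; load_docs falls back to a plain
# SimpleDirectoryReader over "set_default_data_path".
config_with_default_path = {"set_default_data_path": "./docs"}  # placeholder

# docs = rag.load_docs(index_config=config_with_loader)
# docs = rag.load_docs(index_config=config_with_default_path)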
