Commit

Merge branch 'Feature/#338' into Feature/#335
bwook00 committed Apr 19, 2024
2 parents a6dcf3f + dbacc9a commit 0f682db
Showing 22 changed files with 533 additions and 19 deletions.
15 changes: 10 additions & 5 deletions autorag/data/corpus/langchain.py
@@ -25,11 +25,16 @@ def langchain_documents_to_parquet(langchain_documents: List[Document],
        Default is False.
    :return: Corpus data as pd.DataFrame
    """
-    corpus_df = pd.DataFrame({
-        'doc_id': [str(uuid.uuid4()) for _ in range(len(langchain_documents))],
-        'contents': list(map(lambda doc: doc.page_content, langchain_documents)),
-        'metadata': list(map(lambda doc: add_essential_metadata(doc.metadata), langchain_documents)),
-    })
+    doc_ids = [str(uuid.uuid4()) for _ in langchain_documents]
+    corpus_df = pd.DataFrame([
+        {
+            'doc_id': doc_id,
+            'contents': doc.page_content,
+            'metadata': add_essential_metadata(doc.metadata, prev_id, next_id)
+        }
+        for doc, doc_id, prev_id, next_id in
+        zip(langchain_documents, doc_ids, [None] + doc_ids[:-1], doc_ids[1:] + [None])
+    ])

    if output_filepath is not None:
        save_parquet_safe(corpus_df, output_filepath, upsert=upsert)
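The neighbor-linking idiom added above pairs each document with its own id plus the ids of its immediate neighbors by zipping the id list against two shifted copies of itself. A minimal, self-contained sketch of the pattern (toy ids, not AutoRAG code):

# Neighbor-linking zip idiom from the diff above, on toy ids.
doc_ids = ['a', 'b', 'c']
# Shifting right gives each item's predecessor; shifting left gives its successor.
# None marks the missing neighbor at either end of the corpus.
triples = list(zip(doc_ids, [None] + doc_ids[:-1], doc_ids[1:] + [None]))
print(triples)  # [('a', None, 'b'), ('b', 'a', 'c'), ('c', 'b', None)]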
29 changes: 19 additions & 10 deletions autorag/data/corpus/llama_index.py
@@ -26,11 +26,16 @@ def llama_documents_to_parquet(llama_documents: List[Document],
        Default is False.
    :return: Corpus data as pd.DataFrame
    """
-    doc_lst = list(map(lambda doc: {
-        'doc_id': str(uuid.uuid4()),
-        'contents': doc.text,
-        'metadata': add_essential_metadata(doc.metadata)
-    }, llama_documents))
+    doc_ids = [str(uuid.uuid4()) for _ in llama_documents]
+    doc_lst = [
+        {
+            'doc_id': doc_id,
+            'contents': doc.text,
+            'metadata': add_essential_metadata(doc.metadata, prev_id, next_id)
+        }
+        for doc, doc_id, prev_id, next_id in zip(llama_documents, doc_ids, [None] + doc_ids[:-1], doc_ids[1:] + [None])
+    ]

    processed_df = pd.DataFrame(doc_lst)

    if output_filepath is not None:
@@ -57,11 +62,15 @@ def llama_text_node_to_parquet(text_nodes: List[TextNode],
    :return: Corpus data as pd.DataFrame
    """

-    corpus_df = pd.DataFrame(list(map(lambda node: {
-        'doc_id': node.node_id,
-        'contents': node.text,
-        'metadata': add_essential_metadata(node.metadata)
-    }, text_nodes)))
+    node_ids = [node.node_id for node in text_nodes]
+    corpus_df = pd.DataFrame([
+        {
+            'doc_id': node_id,
+            'contents': node.text,
+            'metadata': add_essential_metadata(node.metadata, prev_id, next_id)
+        }
+        for node, node_id, prev_id, next_id in zip(text_nodes, node_ids, [None] + node_ids[:-1], node_ids[1:] + [None])
+    ])

    if output_filepath is not None:
        save_parquet_safe(corpus_df, output_filepath, upsert=upsert)
6 changes: 5 additions & 1 deletion autorag/data/utils/util.py
@@ -31,9 +31,13 @@ def get_file_metadata(file_path: str) -> Dict:
    }


-def add_essential_metadata(metadata: Dict) -> Dict:
+def add_essential_metadata(metadata: Dict, prev_id: str, next_id: str) -> Dict:
    if 'last_modified_datetime' not in metadata:
        metadata['last_modified_datetime'] = datetime.now()
+    if 'prev_id' not in metadata:
+        metadata['prev_id'] = prev_id
+    if 'next_id' not in metadata:
+        metadata['next_id'] = next_id
    return metadata


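A quick usage sketch of the updated helper (hypothetical values): keys already present in the metadata are preserved, and only missing ones are filled in.

from autorag.data.utils.util import add_essential_metadata

meta = {'prev_id': 'keep-me'}
out = add_essential_metadata(meta, prev_id='doc-0', next_id='doc-2')
print(out['prev_id'])  # 'keep-me' -- already present, so not overwritten
print(out['next_id'])  # 'doc-2'   -- absent, so filled from the argument
# 'last_modified_datetime' is also set here because it was absent.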
1 change: 1 addition & 0 deletions autorag/nodes/passageaugmenter/__init__.py
@@ -0,0 +1 @@
from .prev_next_augmenter import prev_next_augmenter
95 changes: 95 additions & 0 deletions autorag/nodes/passageaugmenter/base.py
@@ -0,0 +1,95 @@
import functools
import itertools
import logging
import os
from pathlib import Path
from typing import List, Union, Tuple

import numpy as np
import pandas as pd
import torch

from autorag import embedding_models
from autorag.evaluate.metric.util import calculate_cosine_similarity
from autorag.utils import result_to_dataframe, validate_qa_dataset, fetch_contents, sort_by_scores
from autorag.utils.util import reconstruct_list, filter_dict_keys

logger = logging.getLogger("AutoRAG")


def passage_augmenter_node(func):
    @functools.wraps(func)
    @result_to_dataframe(["retrieved_contents", "retrieved_ids", "retrieve_scores"])
    def wrapper(
            project_dir: Union[str, Path],
            previous_result: pd.DataFrame,
            *args, **kwargs) -> Tuple[List[List[str]], List[List[str]], List[List[float]]]:
        validate_qa_dataset(previous_result)
        data_dir = os.path.join(project_dir, "data")

        # find queries columns
        assert "query" in previous_result.columns, "previous_result must have query column."
        queries = previous_result["query"].tolist()

        # find ids columns
        assert "retrieved_ids" in previous_result.columns, "previous_result must have retrieved_ids column."
        ids = previous_result["retrieved_ids"].tolist()

        corpus_df = pd.read_parquet(os.path.join(data_dir, "corpus.parquet"))

        if func.__name__ == 'prev_next_augmenter':
            slim_corpus_df = corpus_df[["doc_id", "metadata"]]
            slim_corpus_df['metadata'] = slim_corpus_df['metadata'].apply(filter_dict_keys, keys=['prev_id', 'next_id'])

            mode = kwargs.pop("mode", 'next')
            num_passages = kwargs.pop("num_passages", 1)

            # get augmented ids
            ids = func(ids_list=ids, corpus_df=slim_corpus_df, mode=mode, num_passages=num_passages)
        else:
            ids = func(ids_list=ids, *args, **kwargs)

        # fetch contents from corpus to use augmented ids
        contents = fetch_contents(corpus_df, ids)

        # set embedding model for getting scores
        embedding_model_str = kwargs.pop("embedding_model", 'openai')
        query_embeddings, contents_embeddings = embedding_query_content(queries, contents, embedding_model_str,
                                                                        batch=128)

        # get scores from calculated cosine similarity
        scores = [np.array([calculate_cosine_similarity(query_embedding, x) for x in content_embeddings]).tolist()
                  for query_embedding, content_embeddings in zip(query_embeddings, contents_embeddings)]

        # sort by scores
        df = pd.DataFrame({
            'contents': contents,
            'ids': ids,
            'scores': scores,
        })
        df[['contents', 'ids', 'scores']] = df.apply(sort_by_scores, axis=1, result_type='expand')
        augmented_contents, augmented_ids, augmented_scores = \
            df['contents'].tolist(), df['ids'].tolist(), df['scores'].tolist()

        return augmented_contents, augmented_ids, augmented_scores

    return wrapper


def embedding_query_content(queries: List[str], contents_list: List[List[str]],
                            embedding_model: str, batch: int = 128):
    embedding_model = embedding_models[embedding_model]

    # Embedding using batch
    embedding_model.embed_batch_size = batch
    query_embeddings = embedding_model.get_text_embedding_batch(queries)

    content_lengths = list(map(len, contents_list))
    content_embeddings_flatten = embedding_model.get_text_embedding_batch(list(
        itertools.chain.from_iterable(contents_list)))
    content_embeddings = reconstruct_list(content_embeddings_flatten, content_lengths)

    del embedding_model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return query_embeddings, content_embeddings
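reconstruct_list is imported from autorag.utils.util but does not appear in this diff; judging from its use here (regrouping the flattened content embeddings back into per-query sublists by length), a minimal stand-in could look like the sketch below. This is an assumption about its behavior, not the library's actual implementation.

from typing import Any, List

def reconstruct_list_sketch(flat: List[Any], lengths: List[int]) -> List[List[Any]]:
    # Regroup a flat list into consecutive sublists of the given lengths.
    result, pos = [], 0
    for n in lengths:
        result.append(flat[pos:pos + n])
        pos += n
    return result

print(reconstruct_list_sketch([1, 2, 3, 4, 5], [2, 3]))  # [[1, 2], [3, 4, 5]]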
57 changes: 57 additions & 0 deletions autorag/nodes/passageaugmenter/prev_next_augmenter.py
@@ -0,0 +1,57 @@
from typing import List

import pandas as pd

from autorag.nodes.passageaugmenter.base import passage_augmenter_node


@passage_augmenter_node
def prev_next_augmenter(ids_list: List[List[str]],
                        corpus_df: pd.DataFrame,
                        num_passages: int = 1,
                        mode: str = 'next'
                        ) -> List[List[str]]:
    """
    Add passages before and/or after the retrieved passage.
    For more information, visit https://docs.llamaindex.ai/en/stable/examples/node_postprocessor/PrevNextPostprocessorDemo/.

    :param ids_list: The list of lists of ids retrieved
    :param corpus_df: The corpus dataframe
    :param num_passages: The number of passages to add before and after the retrieved passage
        Default is 1.
    :param mode: The mode of augmentation
        'prev': add passages before the retrieved passage
        'next': add passages after the retrieved passage
        'both': add passages before and after the retrieved passage
        Default is 'next'.
    :return: The list of lists of augmented ids
    """
    if mode not in ['prev', 'next', 'both']:
        raise ValueError(f"mode must be 'prev', 'next', or 'both', but got {mode}")

    augmented_ids = [prev_next_augmenter_pure(ids, corpus_df, mode, num_passages) for ids in ids_list]

    return augmented_ids


def prev_next_augmenter_pure(ids: List[str], corpus_df: pd.DataFrame, mode: str, num_passages: int):
    def fetch_id_sequence(start_id, key):
        sequence = []
        current_id = start_id
        for _ in range(num_passages):
            current_id = corpus_df.loc[corpus_df['doc_id'] == current_id]['metadata'].values[0].get(key)
            if current_id is None:
                break
            sequence.append(current_id)
        return sequence

    augmented_group = []
    for id_ in ids:
        current_ids = [id_]
        if mode in ['prev', 'both']:
            current_ids = fetch_id_sequence(id_, 'prev_id')[::-1] + current_ids
        if mode in ['next', 'both']:
            current_ids += fetch_id_sequence(id_, 'next_id')
        augmented_group.extend(current_ids)
    return augmented_group
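A worked example of the traversal on a toy three-passage corpus (doc ids and chaining are hypothetical; prev_next_augmenter_pure itself is the function above):

import pandas as pd

# Toy corpus chained a -> b -> c through the prev_id/next_id metadata.
corpus_df = pd.DataFrame({
    'doc_id': ['a', 'b', 'c'],
    'metadata': [
        {'prev_id': None, 'next_id': 'b'},
        {'prev_id': 'a', 'next_id': 'c'},
        {'prev_id': 'b', 'next_id': None},
    ],
})

# Augment the single retrieved id 'b' with one neighbor on each side.
print(prev_next_augmenter_pure(['b'], corpus_df, mode='both', num_passages=1))
# ['a', 'b', 'c']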
69 changes: 69 additions & 0 deletions autorag/nodes/passageaugmenter/run.py
@@ -0,0 +1,69 @@
import logging
import os
import pathlib
from typing import List, Callable, Dict

import pandas as pd

from autorag.nodes.retrieval.run import evaluate_retrieval_node
from autorag.strategy import measure_speed, filter_by_threshold, select_best_average

logger = logging.getLogger("AutoRAG")


def run_passage_augmenter_node(modules: List[Callable],
                               module_params: List[Dict],
                               previous_result: pd.DataFrame,
                               node_line_dir: str,
                               strategies: Dict,
                               ) -> pd.DataFrame:
    if not os.path.exists(node_line_dir):
        os.makedirs(node_line_dir)
    project_dir = pathlib.PurePath(node_line_dir).parent.parent
    retrieval_gt = pd.read_parquet(os.path.join(project_dir, "data", "qa.parquet"))['retrieval_gt'].tolist()

    results, execution_times = zip(*map(lambda task: measure_speed(
        task[0], project_dir=project_dir, previous_result=previous_result, **task[1]), zip(modules, module_params)))
    average_times = list(map(lambda x: x / len(results[0]), execution_times))

    # run metrics before filtering
    if strategies.get('metrics') is None:
        raise ValueError("You must have at least one metric for passage_augmenter evaluation.")
    results = list(map(lambda x: evaluate_retrieval_node(x, retrieval_gt, strategies.get('metrics')), results))

    # save results to folder
    save_dir = os.path.join(node_line_dir, "passage_augmenter")  # node name
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    filepaths = list(map(lambda x: os.path.join(save_dir, f'{x}.parquet'), range(len(modules))))
    list(map(lambda x: x[0].to_parquet(x[1], index=False), zip(results, filepaths)))  # execute save to parquet
    filenames = list(map(lambda x: os.path.basename(x), filepaths))

    summary_df = pd.DataFrame({
        'filename': filenames,
        'module_name': list(map(lambda module: module.__name__, modules)),
        'module_params': module_params,
        'execution_time': average_times,
        **{f'passage_augmenter_{metric}': list(map(lambda result: result[metric].mean(), results)) for metric in
           strategies.get('metrics')},
    })

    # filter by strategies
    if strategies.get('speed_threshold') is not None:
        results, filenames = filter_by_threshold(results, average_times, strategies['speed_threshold'], filenames)
    selected_result, selected_filename = select_best_average(results, strategies.get('metrics'), filenames)
    # change metric name columns to passage_augmenter_metric_name
    selected_result = selected_result.rename(columns={
        metric_name: f'passage_augmenter_{metric_name}' for metric_name in strategies['metrics']})
    # drop retrieval result columns in previous_result
    previous_result = previous_result.drop(columns=['retrieved_contents', 'retrieved_ids', 'retrieve_scores'])
    best_result = pd.concat([previous_result, selected_result], axis=1)

    # add 'is_best' column to summary file
    summary_df['is_best'] = summary_df['filename'] == selected_filename

    # save files
    summary_df.to_csv(os.path.join(save_dir, "summary.csv"), index=False)
    best_result.to_parquet(os.path.join(save_dir, f'best_{os.path.splitext(selected_filename)[0]}.parquet'),
                           index=False)
    return best_result
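From the checks in run_passage_augmenter_node above, the strategies dict must carry a 'metrics' list and may carry a 'speed_threshold'; a plausible shape, with illustrative values only (the metric names are the ones listed in the passage filter docstring below):

# Illustrative strategies dict for run_passage_augmenter_node; values are examples.
strategies = {
    'metrics': ['retrieval_f1', 'retrieval_recall', 'retrieval_precision'],
    'speed_threshold': 10,  # optional; modules whose average per-row time exceeds this are filtered out
}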
4 changes: 2 additions & 2 deletions autorag/nodes/passagefilter/run.py
@@ -22,7 +22,7 @@ def run_passage_filter_node(modules: List[Callable],
        Could be retrieval, reranker, passage filter modules result.
        It means it must contain 'query', 'retrieved_contents', 'retrieved_ids', 'retrieve_scores' columns.
    :param node_line_dir: This node line's directory.
-    :param strategies: Strategies for passage reranker node.
+    :param strategies: Strategies for passage filter node.
        In this node, we use 'retrieval_f1', 'retrieval_recall' and 'retrieval_precision'.
        You can skip evaluation when you use only one module and a module parameter.
    :return: The best result dataframe with previous result columns.
Expand All @@ -38,7 +38,7 @@ def run_passage_filter_node(modules: List[Callable],

    # run metrics before filtering
    if strategies.get('metrics') is None:
-        raise ValueError("You must at least one metrics for passage_reranker evaluation.")
+        raise ValueError("You must have at least one metric for passage_filter evaluation.")
    results = list(map(lambda x: evaluate_retrieval_node(x, retrieval_gt, strategies.get('metrics')), results))

    # save results to folder
3 changes: 3 additions & 0 deletions autorag/support.py
@@ -25,6 +25,8 @@ def get_support_modules(module_name: str) -> Callable:
        'hybrid_cc': ('autorag.nodes.retrieval', 'hybrid_cc'),
        'hybrid_rsf': ('autorag.nodes.retrieval', 'hybrid_rsf'),
        'hybrid_dbsf': ('autorag.nodes.retrieval', 'hybrid_dbsf'),
+        # passage_augmenter
+        'prev_next_augmenter': ('autorag.nodes.passageaugmenter', 'prev_next_augmenter'),
        # passage_reranker
        'monot5': ('autorag.nodes.passagereranker', 'monot5'),
        'tart': ('autorag.nodes.passagereranker', 'tart'),
@@ -67,5 +69,6 @@ def get_support_nodes(node_name: str) -> Callable:
        'passage_filter': ('autorag.nodes.passagefilter.run', 'run_passage_filter_node'),
        'passage_compressor': ('autorag.nodes.passagecompressor.run', 'run_passage_compressor_node'),
        'passage_reranker': ('autorag.nodes.passagereranker.run', 'run_passage_reranker_node'),
+        'passage_augmenter': ('autorag.nodes.passageaugmenter.run', 'run_passage_augmenter_node'),
    }
    return dynamically_find_function(node_name, support_nodes)
2 changes: 1 addition & 1 deletion autorag/utils/__init__.py
@@ -1,2 +1,2 @@
from .preprocess import validate_qa_dataset, validate_corpus_dataset, cast_qa_dataset, cast_corpus_dataset
-from .util import fetch_contents, result_to_dataframe
+from .util import fetch_contents, result_to_dataframe, sort_by_scores
10 changes: 10 additions & 0 deletions autorag/utils/util.py
@@ -316,3 +316,13 @@ def select_top_k(df, column_names: List[str], top_k: int):
    for column_name in column_names:
        df[column_name] = df[column_name].apply(lambda x: x[:top_k])
    return df


+def filter_dict_keys(dict_, keys: List[str]):
+    result = {}
+    for key in keys:
+        if key in dict_:
+            result[key] = dict_[key]
+        else:
+            raise KeyError(f"Key '{key}' not found in dictionary.")
+    return result
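A short usage sketch of filter_dict_keys on a toy dict; note the strict behavior, where a requested key that is absent raises instead of being silently skipped:

from autorag.utils.util import filter_dict_keys

meta = {'prev_id': 'a', 'next_id': 'c', 'path': '/tmp/doc.txt'}
print(filter_dict_keys(meta, ['prev_id', 'next_id']))  # {'prev_id': 'a', 'next_id': 'c'}

filter_dict_keys(meta, ['missing'])  # raises KeyError: "Key 'missing' not found in dictionary."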
37 changes: 37 additions & 0 deletions docs/source/api_spec/autorag.nodes.passageaugmenter.rst
@@ -0,0 +1,37 @@
autorag.nodes.passageaugmenter package
======================================

Submodules
----------

autorag.nodes.passageaugmenter.base module
------------------------------------------

.. automodule:: autorag.nodes.passageaugmenter.base
:members:
:undoc-members:
:show-inheritance:

autorag.nodes.passageaugmenter.prev\_next\_augmenter module
-----------------------------------------------------------

.. automodule:: autorag.nodes.passageaugmenter.prev_next_augmenter
:members:
:undoc-members:
:show-inheritance:

autorag.nodes.passageaugmenter.run module
-----------------------------------------

.. automodule:: autorag.nodes.passageaugmenter.run
:members:
:undoc-members:
:show-inheritance:

Module contents
---------------

.. automodule:: autorag.nodes.passageaugmenter
:members:
:undoc-members:
:show-inheritance: