diff --git a/autorag/data/corpus/langchain.py b/autorag/data/corpus/langchain.py index adf42b469..85fecb406 100644 --- a/autorag/data/corpus/langchain.py +++ b/autorag/data/corpus/langchain.py @@ -25,11 +25,16 @@ def langchain_documents_to_parquet(langchain_documents: List[Document], Default is False. :return: Corpus data as pd.DataFrame """ - corpus_df = pd.DataFrame({ - 'doc_id': [str(uuid.uuid4()) for _ in range(len(langchain_documents))], - 'contents': list(map(lambda doc: doc.page_content, langchain_documents)), - 'metadata': list(map(lambda doc: add_essential_metadata(doc.metadata), langchain_documents)), - }) + doc_ids = [str(uuid.uuid4()) for _ in langchain_documents] + corpus_df = pd.DataFrame([ + { + 'doc_id': doc_id, + 'contents': doc.page_content, + 'metadata': add_essential_metadata(doc.metadata, prev_id, next_id) + } + for doc, doc_id, prev_id, next_id in + zip(langchain_documents, doc_ids, [None] + doc_ids[:-1], doc_ids[1:] + [None]) + ]) if output_filepath is not None: save_parquet_safe(corpus_df, output_filepath, upsert=upsert) diff --git a/autorag/data/corpus/llama_index.py b/autorag/data/corpus/llama_index.py index 8e482981d..8aca61817 100644 --- a/autorag/data/corpus/llama_index.py +++ b/autorag/data/corpus/llama_index.py @@ -26,11 +26,16 @@ def llama_documents_to_parquet(llama_documents: List[Document], Default is False. :return: Corpus data as pd.DataFrame """ - doc_lst = list(map(lambda doc: { - 'doc_id': str(uuid.uuid4()), - 'contents': doc.text, - 'metadata': add_essential_metadata(doc.metadata) - }, llama_documents)) + doc_ids = [str(uuid.uuid4()) for _ in llama_documents] + doc_lst = [ + { + 'doc_id': doc_id, + 'contents': doc.text, + 'metadata': add_essential_metadata(doc.metadata, prev_id, next_id) + } + for doc, doc_id, prev_id, next_id in zip(llama_documents, doc_ids, [None] + doc_ids[:-1], doc_ids[1:] + [None]) + ] + processed_df = pd.DataFrame(doc_lst) if output_filepath is not None: @@ -57,11 +62,15 @@ def llama_text_node_to_parquet(text_nodes: List[TextNode], :return: Corpus data as pd.DataFrame """ - corpus_df = pd.DataFrame(list(map(lambda node: { - 'doc_id': node.node_id, - 'contents': node.text, - 'metadata': add_essential_metadata(node.metadata) - }, text_nodes))) + node_ids = [node.node_id for node in text_nodes] + corpus_df = pd.DataFrame([ + { + 'doc_id': node_id, + 'contents': node.text, + 'metadata': add_essential_metadata(node.metadata, prev_id, next_id) + } + for node, node_id, prev_id, next_id in zip(text_nodes, node_ids, [None] + node_ids[:-1], node_ids[1:] + [None]) + ]) if output_filepath is not None: save_parquet_safe(corpus_df, output_filepath, upsert=upsert) diff --git a/autorag/data/utils/util.py b/autorag/data/utils/util.py index a7858505b..97bd6432a 100644 --- a/autorag/data/utils/util.py +++ b/autorag/data/utils/util.py @@ -31,9 +31,13 @@ def get_file_metadata(file_path: str) -> Dict: } -def add_essential_metadata(metadata: Dict) -> Dict: +def add_essential_metadata(metadata: Dict, prev_id: str, next_id: str) -> Dict: if 'last_modified_datetime' not in metadata: metadata['last_modified_datetime'] = datetime.now() + if 'prev_id' not in metadata: + metadata['prev_id'] = prev_id + if 'next_id' not in metadata: + metadata['next_id'] = next_id return metadata diff --git a/autorag/nodes/passageaugmenter/__init__.py b/autorag/nodes/passageaugmenter/__init__.py new file mode 100644 index 000000000..c90ee5e79 --- /dev/null +++ b/autorag/nodes/passageaugmenter/__init__.py @@ -0,0 +1 @@ +from .prev_next_augmenter import prev_next_augmenter 
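Note: the prev/next linking that both corpus converters above now apply boils down to the pattern below. This is a standalone sketch for readers, not part of the patch; `link_prev_next` is a hypothetical helper name.

```python
import uuid
from typing import Dict, List, Optional


def link_prev_next(contents: List[str]) -> List[Dict]:
    # Assign a fresh doc_id per passage and record the neighbouring ids in metadata,
    # mirroring the zip(..., [None] + doc_ids[:-1], doc_ids[1:] + [None]) pattern above.
    doc_ids = [str(uuid.uuid4()) for _ in contents]
    prev_ids: List[Optional[str]] = [None] + doc_ids[:-1]
    next_ids: List[Optional[str]] = doc_ids[1:] + [None]
    return [
        {'doc_id': doc_id, 'contents': text,
         'metadata': {'prev_id': prev_id, 'next_id': next_id}}
        for text, doc_id, prev_id, next_id in zip(contents, doc_ids, prev_ids, next_ids)
    ]


rows = link_prev_next(["chunk 1", "chunk 2", "chunk 3"])
assert rows[0]['metadata']['prev_id'] is None               # first passage has no predecessor
assert rows[1]['metadata']['prev_id'] == rows[0]['doc_id']  # middle passage points back
assert rows[-1]['metadata']['next_id'] is None              # last passage has no successor
```

The first passage's `prev_id` and the last passage's `next_id` stay `None`, which is what lets the augmenter stop cleanly at corpus boundaries.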
diff --git a/autorag/nodes/passageaugmenter/base.py b/autorag/nodes/passageaugmenter/base.py new file mode 100644 index 000000000..10404fd57 --- /dev/null +++ b/autorag/nodes/passageaugmenter/base.py @@ -0,0 +1,95 @@ +import functools +import itertools +import logging +import os +from pathlib import Path +from typing import List, Union, Tuple + +import numpy as np +import pandas as pd +import torch + +from autorag import embedding_models +from autorag.evaluate.metric.util import calculate_cosine_similarity +from autorag.utils import result_to_dataframe, validate_qa_dataset, fetch_contents, sort_by_scores +from autorag.utils.util import reconstruct_list, filter_dict_keys + +logger = logging.getLogger("AutoRAG") + + +def passage_augmenter_node(func): + @functools.wraps(func) + @result_to_dataframe(["retrieved_contents", "retrieved_ids", "retrieve_scores"]) + def wrapper( + project_dir: Union[str, Path], + previous_result: pd.DataFrame, + *args, **kwargs) -> Tuple[List[List[str]], List[List[str]], List[List[float]]]: + validate_qa_dataset(previous_result) + data_dir = os.path.join(project_dir, "data") + + # find queries columns + assert "query" in previous_result.columns, "previous_result must have query column." + queries = previous_result["query"].tolist() + + # find ids columns + assert "retrieved_ids" in previous_result.columns, "previous_result must have retrieved_ids column." + ids = previous_result["retrieved_ids"].tolist() + + corpus_df = pd.read_parquet(os.path.join(data_dir, "corpus.parquet")) + + if func.__name__ == 'prev_next_augmenter': + slim_corpus_df = corpus_df[["doc_id", "metadata"]] + slim_corpus_df['metadata'] = slim_corpus_df['metadata'].apply(filter_dict_keys, keys=['prev_id', 'next_id']) + + mode = kwargs.pop("mode", 'next') + num_passages = kwargs.pop("num_passages", 1) + + # get augmented ids + ids = func(ids_list=ids, corpus_df=slim_corpus_df, mode=mode, num_passages=num_passages) + else: + ids = func(ids_list=ids, *args, **kwargs) + + # fetch contents from corpus to use augmented ids + contents = fetch_contents(corpus_df, ids) + + # set embedding model for getting scores + embedding_model_str = kwargs.pop("embedding_model", 'openai') + query_embeddings, contents_embeddings = embedding_query_content(queries, contents, embedding_model_str, + batch=128) + + # get scores from calculated cosine similarity + scores = [np.array([calculate_cosine_similarity(query_embedding, x) for x in content_embeddings]).tolist() + for query_embedding, content_embeddings in zip(query_embeddings, contents_embeddings)] + + # sort by scores + df = pd.DataFrame({ + 'contents': contents, + 'ids': ids, + 'scores': scores, + }) + df[['contents', 'ids', 'scores']] = df.apply(sort_by_scores, axis=1, result_type='expand') + augmented_contents, augmented_ids, augmented_scores = \ + df['contents'].tolist(), df['ids'].tolist(), df['scores'].tolist() + + return augmented_contents, augmented_ids, augmented_scores + + return wrapper + + +def embedding_query_content(queries: List[str], contents_list: List[List[str]], + embedding_model: str, batch: int = 128): + embedding_model = embedding_models[embedding_model] + + # Embedding using batch + embedding_model.embed_batch_size = batch + query_embeddings = embedding_model.get_text_embedding_batch(queries) + + content_lengths = list(map(len, contents_list)) + content_embeddings_flatten = embedding_model.get_text_embedding_batch(list( + itertools.chain.from_iterable(contents_list))) + content_embeddings = reconstruct_list(content_embeddings_flatten, 
content_lengths) + + del embedding_model + if torch.cuda.is_available(): + torch.cuda.empty_cache() + return query_embeddings, content_embeddings diff --git a/autorag/nodes/passageaugmenter/prev_next_augmenter.py b/autorag/nodes/passageaugmenter/prev_next_augmenter.py new file mode 100644 index 000000000..c1e4d8a96 --- /dev/null +++ b/autorag/nodes/passageaugmenter/prev_next_augmenter.py @@ -0,0 +1,57 @@ +from typing import List + +import pandas as pd + +from autorag.nodes.passageaugmenter.base import passage_augmenter_node + + +@passage_augmenter_node +def prev_next_augmenter(ids_list: List[List[str]], + corpus_df: pd.DataFrame, + num_passages: int = 1, + mode: str = 'next' + ) -> List[List[str]]: + """ + Add passages before and/or after the retrieved passage. + For more information, visit https://docs.llamaindex.ai/en/stable/examples/node_postprocessor/PrevNextPostprocessorDemo/. + + :param ids_list: The list of lists of ids retrieved + :param corpus_df: The corpus dataframe + :param num_passages: The number of passages to add before and after the retrieved passage + Default is 1. + :param mode: The mode of augmentation + 'prev': add passages before the retrieved passage + 'next': add passages after the retrieved passage + 'both': add passages before and after the retrieved passage + Default is 'next'. + :return: The list of lists of augmented ids + """ + if mode not in ['prev', 'next', 'both']: + raise ValueError(f"mode must be 'prev', 'next', or 'both', but got {mode}") + + augmented_ids = [(lambda ids: prev_next_augmenter_pure(ids, corpus_df, mode, num_passages))(ids) for ids in + ids_list] + + return augmented_ids + + +def prev_next_augmenter_pure(ids: List[str], corpus_df: pd.DataFrame, mode: str, num_passages: int): + def fetch_id_sequence(start_id, key): + sequence = [] + current_id = start_id + for _ in range(num_passages): + current_id = corpus_df.loc[corpus_df['doc_id'] == current_id]['metadata'].values[0].get(key) + if current_id is None: + break + sequence.append(current_id) + return sequence + + augmented_group = [] + for id_ in ids: + current_ids = [id_] + if mode in ['prev', 'both']: + current_ids = fetch_id_sequence(id_, 'prev_id')[::-1] + current_ids + if mode in ['next', 'both']: + current_ids += fetch_id_sequence(id_, 'next_id') + augmented_group.extend(current_ids) + return augmented_group diff --git a/autorag/nodes/passageaugmenter/run.py b/autorag/nodes/passageaugmenter/run.py new file mode 100644 index 000000000..c097f407f --- /dev/null +++ b/autorag/nodes/passageaugmenter/run.py @@ -0,0 +1,69 @@ +import logging +import os +import pathlib +from typing import List, Callable, Dict + +import pandas as pd + +from autorag.nodes.retrieval.run import evaluate_retrieval_node +from autorag.strategy import measure_speed, filter_by_threshold, select_best_average + +logger = logging.getLogger("AutoRAG") + + +def run_passage_augmenter_node(modules: List[Callable], + module_params: List[Dict], + previous_result: pd.DataFrame, + node_line_dir: str, + strategies: Dict, + ) -> pd.DataFrame: + if not os.path.exists(node_line_dir): + os.makedirs(node_line_dir) + project_dir = pathlib.PurePath(node_line_dir).parent.parent + retrieval_gt = pd.read_parquet(os.path.join(project_dir, "data", "qa.parquet"))['retrieval_gt'].tolist() + + results, execution_times = zip(*map(lambda task: measure_speed( + task[0], project_dir=project_dir, previous_result=previous_result, **task[1]), zip(modules, module_params))) + average_times = list(map(lambda x: x / len(results[0]), execution_times)) + + # 
run metrics before filtering + if strategies.get('metrics') is None: + raise ValueError("You must have at least one metric for passage_augmenter evaluation.") + results = list(map(lambda x: evaluate_retrieval_node(x, retrieval_gt, strategies.get('metrics')), results)) + + # save results to folder + save_dir = os.path.join(node_line_dir, "passage_augmenter") # node name + if not os.path.exists(save_dir): + os.makedirs(save_dir) + filepaths = list(map(lambda x: os.path.join(save_dir, f'{x}.parquet'), range(len(modules)))) + list(map(lambda x: x[0].to_parquet(x[1], index=False), zip(results, filepaths))) # execute save to parquet + filenames = list(map(lambda x: os.path.basename(x), filepaths)) + + summary_df = pd.DataFrame({ + 'filename': filenames, + 'module_name': list(map(lambda module: module.__name__, modules)), + 'module_params': module_params, + 'execution_time': average_times, + **{f'passage_augmenter_{metric}': list(map(lambda result: result[metric].mean(), results)) for metric in + strategies.get('metrics')}, + }) + + # filter by strategies + if strategies.get('speed_threshold') is not None: + results, filenames = filter_by_threshold(results, average_times, strategies['speed_threshold'], filenames) + selected_result, selected_filename = select_best_average(results, strategies.get('metrics'), filenames) + # change metric name columns to passage_augmenter_metric_name + selected_result = selected_result.rename(columns={ + metric_name: f'passage_augmenter_{metric_name}' for metric_name in strategies['metrics']}) + # drop retrieval result columns in previous_result + previous_result = previous_result.drop(columns=['retrieved_contents', 'retrieved_ids', 'retrieve_scores']) + best_result = pd.concat([previous_result, selected_result], axis=1) + + # add 'is_best' column to summary file + summary_df['is_best'] = summary_df['filename'] == selected_filename + + # save files + summary_df.to_csv(os.path.join(save_dir, "summary.csv"), index=False) + best_result.to_parquet(os.path.join(save_dir, f'best_{os.path.splitext(selected_filename)[0]}.parquet'), + index=False) + return best_result diff --git a/autorag/nodes/passagefilter/run.py b/autorag/nodes/passagefilter/run.py index a888f7af1..083c414fb 100644 --- a/autorag/nodes/passagefilter/run.py +++ b/autorag/nodes/passagefilter/run.py @@ -22,7 +22,7 @@ def run_passage_filter_node(modules: List[Callable], Could be retrieval, reranker, passage filter modules result. It means it must contain 'query', 'retrieved_contents', 'retrieved_ids', 'retrieve_scores' columns. :param node_line_dir: This node line's directory. - :param strategies: Strategies for passage reranker node. + :param strategies: Strategies for passage filter node. In this node, we use 'retrieval_f1', 'retrieval_recall' and 'retrieval_precision'. You can skip evaluation when you use only one module and a module parameter. :return: The best result dataframe with previous result columns.
@@ -38,7 +38,7 @@ def run_passage_filter_node(modules: List[Callable], # run metrics before filtering if strategies.get('metrics') is None: - raise ValueError("You must at least one metrics for passage_reranker evaluation.") + raise ValueError("You must at least one metrics for passage_filter evaluation.") results = list(map(lambda x: evaluate_retrieval_node(x, retrieval_gt, strategies.get('metrics')), results)) # save results to folder diff --git a/autorag/support.py b/autorag/support.py index 444866a20..c2e24eb6f 100644 --- a/autorag/support.py +++ b/autorag/support.py @@ -25,6 +25,8 @@ def get_support_modules(module_name: str) -> Callable: 'hybrid_cc': ('autorag.nodes.retrieval', 'hybrid_cc'), 'hybrid_rsf': ('autorag.nodes.retrieval', 'hybrid_rsf'), 'hybrid_dbsf': ('autorag.nodes.retrieval', 'hybrid_dbsf'), + # passage_augmenter + 'prev_next_augmenter': ('autorag.nodes.passageaugmenter', 'prev_next_augmenter'), # passage_reranker 'monot5': ('autorag.nodes.passagereranker', 'monot5'), 'tart': ('autorag.nodes.passagereranker', 'tart'), @@ -67,5 +69,6 @@ def get_support_nodes(node_name: str) -> Callable: 'passage_filter': ('autorag.nodes.passagefilter.run', 'run_passage_filter_node'), 'passage_compressor': ('autorag.nodes.passagecompressor.run', 'run_passage_compressor_node'), 'passage_reranker': ('autorag.nodes.passagereranker.run', 'run_passage_reranker_node'), + 'passage_augmenter': ('autorag.nodes.passageaugmenter.run', 'run_passage_augmenter_node'), } return dynamically_find_function(node_name, support_nodes) diff --git a/autorag/utils/__init__.py b/autorag/utils/__init__.py index 6a54a61a9..e26eb0c78 100644 --- a/autorag/utils/__init__.py +++ b/autorag/utils/__init__.py @@ -1,2 +1,2 @@ from .preprocess import validate_qa_dataset, validate_corpus_dataset, cast_qa_dataset, cast_corpus_dataset -from .util import fetch_contents, result_to_dataframe +from .util import fetch_contents, result_to_dataframe, sort_by_scores diff --git a/autorag/utils/util.py b/autorag/utils/util.py index 10403885a..cca0950ee 100644 --- a/autorag/utils/util.py +++ b/autorag/utils/util.py @@ -316,3 +316,13 @@ def select_top_k(df, column_names: List[str], top_k: int): for column_name in column_names: df[column_name] = df[column_name].apply(lambda x: x[:top_k]) return df + + +def filter_dict_keys(dict_, keys: List[str]): + result = {} + for key in keys: + if key in dict_: + result[key] = dict_[key] + else: + raise KeyError(f"Key '{key}' not found in dictionary.") + return result diff --git a/docs/source/api_spec/autorag.nodes.passageaugmenter.rst b/docs/source/api_spec/autorag.nodes.passageaugmenter.rst new file mode 100644 index 000000000..ca19bcf2a --- /dev/null +++ b/docs/source/api_spec/autorag.nodes.passageaugmenter.rst @@ -0,0 +1,37 @@ +autorag.nodes.passageaugmenter package +====================================== + +Submodules +---------- + +autorag.nodes.passageaugmenter.base module +------------------------------------------ + +.. automodule:: autorag.nodes.passageaugmenter.base + :members: + :undoc-members: + :show-inheritance: + +autorag.nodes.passageaugmenter.prev\_next\_augmenter module +----------------------------------------------------------- + +.. automodule:: autorag.nodes.passageaugmenter.prev_next_augmenter + :members: + :undoc-members: + :show-inheritance: + +autorag.nodes.passageaugmenter.run module +----------------------------------------- + +.. automodule:: autorag.nodes.passageaugmenter.run + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. 
automodule:: autorag.nodes.passageaugmenter + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api_spec/autorag.nodes.passagereranker.rst b/docs/source/api_spec/autorag.nodes.passagereranker.rst index 6745ebad8..fc42db2f4 100644 --- a/docs/source/api_spec/autorag.nodes.passagereranker.rst +++ b/docs/source/api_spec/autorag.nodes.passagereranker.rst @@ -60,6 +60,14 @@ autorag.nodes.passagereranker.jina module :undoc-members: :show-inheritance: +autorag.nodes.passagereranker.just\_lab module +---------------------------------------------- + +.. automodule:: autorag.nodes.passagereranker.just_lab + :members: + :undoc-members: + :show-inheritance: + autorag.nodes.passagereranker.koreranker module ----------------------------------------------- diff --git a/docs/source/api_spec/autorag.nodes.retrieval.rst b/docs/source/api_spec/autorag.nodes.retrieval.rst index c2a6423d5..6013aa76e 100644 --- a/docs/source/api_spec/autorag.nodes.retrieval.rst +++ b/docs/source/api_spec/autorag.nodes.retrieval.rst @@ -52,6 +52,14 @@ autorag.nodes.retrieval.hybrid\_rsf module :undoc-members: :show-inheritance: +autorag.nodes.retrieval.recursive\_chunk module +----------------------------------------------- + +.. automodule:: autorag.nodes.retrieval.recursive_chunk + :members: + :undoc-members: + :show-inheritance: + autorag.nodes.retrieval.run module ---------------------------------- diff --git a/docs/source/api_spec/autorag.nodes.rst b/docs/source/api_spec/autorag.nodes.rst index 6c14f9b0c..1264b86ad 100644 --- a/docs/source/api_spec/autorag.nodes.rst +++ b/docs/source/api_spec/autorag.nodes.rst @@ -8,6 +8,7 @@ Subpackages :maxdepth: 4 autorag.nodes.generator + autorag.nodes.passageaugmenter autorag.nodes.passagecompressor autorag.nodes.passagefilter autorag.nodes.passagereranker diff --git a/docs/source/nodes/passage_augmenter/passage_augmenter.md b/docs/source/nodes/passage_augmenter/passage_augmenter.md new file mode 100644 index 000000000..0be68d4de --- /dev/null +++ b/docs/source/nodes/passage_augmenter/passage_augmenter.md @@ -0,0 +1,47 @@ +# Passage Augmenter + +### 🔎 **Definition** + +Passage augmenter is a node that augments retrieved passages. +In contrast to the passage filter node, which removes passages, this node adds passages. + +### 🤸 **Benefits** + +The primary benefit of the passage augmenter is that it allows users to fetch additional passages. + +## **Node Parameters** + +**embedding_model** + +- **Description**: The embedding model name to be used for calculating the cosine similarity between the query and the + augmented passages. + +### Example config.yaml file + +```yaml +node_lines: + - node_line_name: retrieve_node_line # Arbitrary node line name + nodes: + - node_type: passage_augmenter + strategy: + metrics: [ retrieval_f1, retrieval_recall, retrieval_precision ] + speed_threshold: 5 + embedding_model: openai + modules: + - module_type: pass_passage_augmenter + - module_type: prev_next_augmenter + mode: next +``` + +```{admonition} What is pass_passage_augmenter? +Its purpose is to test the performance of 'not using' any passage augmenter module. +Not using a passage augmenter node can sometimes be the better option. +So with this module, you can automatically test the performance without using any passage augmenter module.
+``` + +```{toctree} +--- +maxdepth: 1 +--- +prev_next_augmenter.md +``` diff --git a/docs/source/nodes/passage_augmenter/prev_next_augmenter.md b/docs/source/nodes/passage_augmenter/prev_next_augmenter.md new file mode 100644 index 000000000..57d70ef63 --- /dev/null +++ b/docs/source/nodes/passage_augmenter/prev_next_augmenter.md @@ -0,0 +1,25 @@ +# Prev Next Augmenter + +This module is inspired by +LlamaIndex ['Forward/Backward Augmentation'](https://docs.llamaindex.ai/en/stable/examples/node_postprocessor/PrevNextPostprocessorDemo/). +It allows users to fetch additional passages. + +## **Module Parameters** + +- **num_passages** : The number of passages to add before and after the retrieved passage + Default is 1. +- **mode** : The mode of augmentation + - `prev`: add passages before the retrieved passage + - `next`: add passages after the retrieved passage + - `both`: add passages before and after the retrieved passage + + Default is 'next'. + +## **Example config.yaml** + +```yaml +modules: + - module_type: prev_next_augmenter + num_passages: 1 + mode: next +``` diff --git a/tests/autorag/data/corpus/test_base.py b/tests/autorag/data/corpus/test_base.py index 66d941d14..6c01afc43 100644 --- a/tests/autorag/data/corpus/test_base.py +++ b/tests/autorag/data/corpus/test_base.py @@ -13,4 +13,6 @@ def validate_corpus(result_df: pd.DataFrame, length: int, parquet_filepath): assert ['test text'] * length == result_df['contents'].tolist() assert all(['last_modified_datetime' in metadata for metadata in result_df['metadata'].tolist()]) + assert all(['prev_id' in metadata for metadata in result_df['metadata'].tolist()]) + assert all(['next_id' in metadata for metadata in result_df['metadata'].tolist()]) assert all([isinstance(doc_id, str) for doc_id in result_df['doc_id'].tolist()]) diff --git a/tests/autorag/nodes/passageaugmenter/test_base_passage_augmenter.py b/tests/autorag/nodes/passageaugmenter/test_base_passage_augmenter.py new file mode 100644 index 000000000..c9c02793f --- /dev/null +++ b/tests/autorag/nodes/passageaugmenter/test_base_passage_augmenter.py @@ -0,0 +1,24 @@ +import os +import pathlib + +import pandas as pd + +root_dir = pathlib.PurePath(os.path.dirname(os.path.realpath(__file__))).parent.parent.parent +project_dir = os.path.join(root_dir, "resources", "sample_project") +qa_data = pd.read_parquet(os.path.join(project_dir, "data", "qa.parquet")) +corpus_data = pd.read_parquet(os.path.join(project_dir, "data", "corpus.parquet")) +doc_id_list = corpus_data["doc_id"].tolist() +previous_result = pd.DataFrame({ + 'qid': [1, 2], + 'query': ['What is the capital of France?', 'How many members are in Newjeans?'], + 'retrieved_ids': [[doc_id_list[1], doc_id_list[3]], + [doc_id_list[5], doc_id_list[7]]], + 'retrieved_contents': [['Paris is the capital of France.', 'Paris is one of the capital from France. 
Isn\'t it?'], + ['Newjeans has 5 members.', 'Danielle is one of the members of Newjeans.']], + 'retrieve_scores': [[0.1, 0.8], [0.1, 0.2]], + 'retrieval_gt': [1, 1], + 'generation_gt': ['answer', 'answer'], + 'retrieval_f1': [0.4, 1.0], + 'retrieval_recall': [1.0, 0.3] +}) +ids_list = [[doc_id_list[1]], [doc_id_list[3]], [doc_id_list[0], doc_id_list[29]]] diff --git a/tests/autorag/nodes/passageaugmenter/test_prev_next_augmenter.py b/tests/autorag/nodes/passageaugmenter/test_prev_next_augmenter.py new file mode 100644 index 000000000..77696526e --- /dev/null +++ b/tests/autorag/nodes/passageaugmenter/test_prev_next_augmenter.py @@ -0,0 +1,50 @@ +from autorag.nodes.passageaugmenter import prev_next_augmenter + +from tests.autorag.nodes.passageaugmenter.test_base_passage_augmenter import ids_list, project_dir, \ + previous_result, corpus_data, doc_id_list + + +def test_prev_next_augmenter_next(): + results = prev_next_augmenter.__wrapped__(ids_list, corpus_data, num_passages=1, mode='next') + assert results == [[doc_id_list[1], doc_id_list[2]], + [doc_id_list[3], doc_id_list[4]], + [doc_id_list[0], doc_id_list[1], doc_id_list[29]]] + + +def test_prev_next_augmenter_prev(): + results = prev_next_augmenter.__wrapped__(ids_list, corpus_data, num_passages=1, mode='prev') + assert results == [[doc_id_list[0], doc_id_list[1]], + [doc_id_list[2], doc_id_list[3]], + [doc_id_list[0], doc_id_list[28], doc_id_list[29]]] + + +def test_prev_next_augmenter_both(): + results = prev_next_augmenter.__wrapped__(ids_list, corpus_data, num_passages=1, mode='both') + assert results == [[doc_id_list[0], doc_id_list[1], doc_id_list[2]], + [doc_id_list[2], doc_id_list[3], doc_id_list[4]], + [doc_id_list[0], doc_id_list[1], doc_id_list[28], doc_id_list[29]]] + + +def test_prev_next_augmenter_multi_passages(): + results = prev_next_augmenter.__wrapped__(ids_list, corpus_data, num_passages=3, mode='prev') + assert results == [[doc_id_list[0], doc_id_list[1]], + [doc_id_list[0], doc_id_list[1], doc_id_list[2], doc_id_list[3]], + [doc_id_list[0], doc_id_list[26], doc_id_list[27], doc_id_list[28], doc_id_list[29]]] + + +def test_prev_next_augmenter_node(): + result_df = prev_next_augmenter(project_dir=project_dir, previous_result=previous_result, mode='next') + contents = result_df["retrieved_contents"].tolist() + ids = result_df["retrieved_ids"].tolist() + scores = result_df["retrieve_scores"].tolist() + assert len(contents) == len(ids) == len(scores) == 2 + assert len(contents[0]) == len(ids[0]) == len(scores[0]) == 4 + for content_list, id_list, score_list in zip(contents, ids, scores): + for i, (content, _id, score) in enumerate(zip(content_list, id_list, score_list)): + assert isinstance(content, str) + assert isinstance(_id, str) + assert isinstance(score, float) + assert _id in corpus_data["doc_id"].tolist() + assert content == corpus_data[corpus_data["doc_id"] == _id]["contents"].values[0] + if i >= 1: + assert score_list[i - 1] >= score_list[i] diff --git a/tests/autorag/nodes/passageaugmenter/test_run_passage_augmenter.py b/tests/autorag/nodes/passageaugmenter/test_run_passage_augmenter.py new file mode 100644 index 000000000..f4292d7d3 --- /dev/null +++ b/tests/autorag/nodes/passageaugmenter/test_run_passage_augmenter.py @@ -0,0 +1,59 @@ +import os +import tempfile + +import pandas as pd +import pytest + +from autorag.nodes.passageaugmenter import prev_next_augmenter +from autorag.nodes.passageaugmenter.run import run_passage_augmenter_node +from autorag.utils.util import load_summary_file +from 
tests.autorag.nodes.passageaugmenter.test_base_passage_augmenter import qa_data, corpus_data, previous_result + + +@pytest.fixture +def node_line_dir(): + with tempfile.TemporaryDirectory() as project_dir: + data_dir = os.path.join(project_dir, "data") + os.makedirs(data_dir) + qa_data.to_parquet(os.path.join(data_dir, "qa.parquet"), index=False) + corpus_data.to_parquet(os.path.join(data_dir, "corpus.parquet"), index=False) + trial_dir = os.path.join(project_dir, "0") + os.makedirs(trial_dir) + node_line_dir = os.path.join(trial_dir, "node_line_1") + os.makedirs(node_line_dir) + yield node_line_dir + + +def test_run_passage_augmenter_node(node_line_dir): + modules = [prev_next_augmenter] + module_params = [{'num_passages': 1}] + strategies = { + 'metrics': ['retrieval_f1', 'retrieval_recall'], + } + best_result = run_passage_augmenter_node(modules, module_params, previous_result, node_line_dir, strategies) + assert os.path.exists(os.path.join(node_line_dir, "passage_augmenter")) + assert set(best_result.columns) == {'qid', 'query', 'retrieval_gt', 'generation_gt', + 'retrieved_contents', 'retrieved_ids', 'retrieve_scores', + 'retrieval_f1', 'retrieval_recall', + 'passage_augmenter_retrieval_f1', + 'passage_augmenter_retrieval_recall'} + # test summary feature + summary_path = os.path.join(node_line_dir, "passage_augmenter", "summary.csv") + assert os.path.exists(summary_path) + result_path = os.path.join(node_line_dir, "passage_augmenter", '0.parquet') + assert os.path.exists(result_path) + result_df = pd.read_parquet(result_path) + summary_df = load_summary_file(summary_path) + assert set(summary_df.columns) == {'filename', 'passage_augmenter_retrieval_f1', + 'passage_augmenter_retrieval_recall', + 'module_name', 'module_params', 'execution_time', 'is_best'} + assert len(summary_df) == 1 + assert summary_df['filename'][0] == "0.parquet" + assert summary_df['passage_augmenter_retrieval_f1'][0] == result_df['retrieval_f1'].mean() + assert summary_df['passage_augmenter_retrieval_recall'][0] == result_df['retrieval_recall'].mean() + assert summary_df['module_name'][0] == "prev_next_augmenter" + assert summary_df['module_params'][0] == {'num_passages': 1} + assert summary_df['execution_time'][0] > 0 + # test the best file is saved properly + best_path = summary_df[summary_df['is_best']]['filename'].values[0] + assert os.path.exists(os.path.join(node_line_dir, "passage_augmenter", f"best_{best_path}")) diff --git a/tests/resources/sample_project/data/corpus.parquet b/tests/resources/sample_project/data/corpus.parquet index 3ad10ddd9..c45dbd17f 100644 Binary files a/tests/resources/sample_project/data/corpus.parquet and b/tests/resources/sample_project/data/corpus.parquet differ
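Note: to make the behavior of the new node concrete, here is a minimal, self-contained sketch of the prev/next walk that `prev_next_augmenter_pure` performs over the corpus metadata. The toy corpus and the `walk` helper name are illustrative only; the real module reads `corpus.parquet` and is driven by `mode` and `num_passages`.

```python
import pandas as pd

# Toy corpus of five chained passages, mirroring the prev_id/next_id metadata
# written by the corpus converters in this patch. All names here are illustrative.
doc_ids = [f"doc-{i}" for i in range(5)]
corpus_df = pd.DataFrame({
    'doc_id': doc_ids,
    'metadata': [
        {'prev_id': doc_ids[i - 1] if i > 0 else None,
         'next_id': doc_ids[i + 1] if i < 4 else None}
        for i in range(5)
    ],
})


def walk(start_id: str, key: str, num_passages: int) -> list:
    # Follow prev_id/next_id links for up to num_passages hops, stopping at the end of the chain.
    sequence, current_id = [], start_id
    for _ in range(num_passages):
        current_id = corpus_df.loc[corpus_df['doc_id'] == current_id, 'metadata'].values[0].get(key)
        if current_id is None:
            break
        sequence.append(current_id)
    return sequence


# mode='both' with num_passages=1 keeps the retrieved passage and adds its neighbours.
augmented = walk('doc-2', 'prev_id', 1)[::-1] + ['doc-2'] + walk('doc-2', 'next_id', 1)
print(augmented)  # ['doc-1', 'doc-2', 'doc-3']
```

With `mode='next'` only the forward walk runs, and with `mode='prev'` only the backward one, which matches the expectations encoded in the unit tests above.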