From b9d24efa6441bc7092d7ca9096f70739197ce152 Mon Sep 17 00:00:00 2001
From: "Bwook (Byoungwook) Kim"
Date: Sun, 21 Apr 2024 15:33:11 +0900
Subject: [PATCH] Parallel processing at Flag embedding reranker and Flag
 embedding llm reranker (#328)

* update util

* update flag embedding

* parallel processing flag embedding

* parallel processing flag llm embedding

* use new util function 'select_top_k'

---------

Co-authored-by: Jeffrey (Dongkyu) Kim
---
 .../nodes/passagereranker/flag_embedding.py  | 61 +++++++------------
 .../passagereranker/flag_embedding_llm.py    | 25 ++++----
 2 files changed, 35 insertions(+), 51 deletions(-)

diff --git a/autorag/nodes/passagereranker/flag_embedding.py b/autorag/nodes/passagereranker/flag_embedding.py
index 8143c5df0..4f2fd8aba 100644
--- a/autorag/nodes/passagereranker/flag_embedding.py
+++ b/autorag/nodes/passagereranker/flag_embedding.py
@@ -1,11 +1,11 @@
-import asyncio
 from typing import List, Tuple
 
+import pandas as pd
 import torch
 from FlagEmbedding import FlagReranker
 
 from autorag.nodes.passagereranker.base import passage_reranker_node
-from autorag.utils.util import process_batch
+from autorag.utils.util import make_batch, sort_by_scores, flatten_apply, select_top_k
 
 
 @passage_reranker_node
@@ -31,48 +31,29 @@ def flag_embedding_reranker(queries: List[str], contents_list: List[List[str]],
     model = FlagReranker(
         model_name_or_path=model_name, use_fp16=use_fp16
     )
-    tasks = [flag_embedding_reranker_pure(query, contents, scores, top_k, ids, model)
-             for query, contents, scores, ids in zip(queries, contents_list, scores_list, ids_list)]
-    loop = asyncio.get_event_loop()
-    results = loop.run_until_complete(process_batch(tasks, batch_size=batch))
-    content_result = list(map(lambda x: x[0], results))
-    id_result = list(map(lambda x: x[1], results))
-    score_result = list(map(lambda x: x[2], results))
+    nested_list = [list(map(lambda x: [query, x], content_list)) for query, content_list in zip(queries, contents_list)]
+    rerank_scores = flatten_apply(flag_embedding_run_model, nested_list, model=model, batch_size=batch)
+
+    df = pd.DataFrame({
+        'contents': contents_list,
+        'ids': ids_list,
+        'scores': rerank_scores,
+    })
+    df[['contents', 'ids', 'scores']] = df.apply(sort_by_scores, axis=1, result_type='expand')
+    results = select_top_k(df, ['contents', 'ids', 'scores'], top_k)
 
     del model
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
 
-    return content_result, id_result, score_result
-
-
-async def flag_embedding_reranker_pure(query: str, contents: List[str], scores: List[float], top_k: int,
-                                       ids: List[str], model) -> Tuple[List[str], List[str], List[float]]:
-    """
-    Rerank a list of contents based on their relevance to a query using BAAI Reranker model.
-
-    :param query: The query to use for reranking
-    :param contents: The list of contents to rerank
-    :param scores: The list of scores retrieved from the initial ranking
-    :param ids: The list of ids retrieved from the initial ranking
-    :param top_k: The number of passages to be retrieved
-    :param model: BAAI Reranker model.
-    :return: tuple of lists containing the reranked contents, ids, and scores
-    """
-    input_texts = [(query, content) for content in contents]
-    with torch.no_grad():
-        pred_scores = model.compute_score(sentence_pairs=input_texts)
-
-    content_ids_probs = list(zip(contents, ids, pred_scores))
-
-    # Sort the list of pairs based on the relevance score in descending order
-    sorted_content_ids_probs = sorted(content_ids_probs, key=lambda x: x[2], reverse=True)
-
-    # crop with top_k
-    if len(contents) < top_k:
-        top_k = len(contents)
-    sorted_content_ids_probs = sorted_content_ids_probs[:top_k]
+    return results['contents'].tolist(), results['ids'].tolist(), results['scores'].tolist()
 
-    content_result, id_result, score_result = zip(*sorted_content_ids_probs)
-    return list(content_result), list(id_result), list(score_result)
+
+def flag_embedding_run_model(input_texts, model, batch_size: int):
+    batch_input_texts = make_batch(input_texts, batch_size)
+    results = []
+    for batch_texts in batch_input_texts:
+        with torch.no_grad():
+            pred_scores = model.compute_score(sentence_pairs=batch_texts)
+        results.extend(pred_scores)
+    return results
diff --git a/autorag/nodes/passagereranker/flag_embedding_llm.py b/autorag/nodes/passagereranker/flag_embedding_llm.py
index a79f19eba..1569f69e1 100644
--- a/autorag/nodes/passagereranker/flag_embedding_llm.py
+++ b/autorag/nodes/passagereranker/flag_embedding_llm.py
@@ -1,12 +1,12 @@
-import asyncio
 from typing import List, Tuple
 
+import pandas as pd
 import torch
 from FlagEmbedding import FlagLLMReranker
 
 from autorag.nodes.passagereranker.base import passage_reranker_node
-from autorag.nodes.passagereranker.flag_embedding import flag_embedding_reranker_pure
-from autorag.utils.util import process_batch
+from autorag.nodes.passagereranker.flag_embedding import flag_embedding_run_model
+from autorag.utils.util import flatten_apply, sort_by_scores, select_top_k
 
 
 @passage_reranker_node
@@ -32,16 +32,19 @@ def flag_embedding_llm_reranker(queries: List[str], contents_list: List[List[str
     model = FlagLLMReranker(
         model_name_or_path=model_name, use_fp16=use_fp16
     )
-    tasks = [flag_embedding_reranker_pure(query, contents, scores, top_k, ids, model)
-             for query, contents, scores, ids in zip(queries, contents_list, scores_list, ids_list)]
-    loop = asyncio.get_event_loop()
-    results = loop.run_until_complete(process_batch(tasks, batch_size=batch))
-    content_result = list(map(lambda x: x[0], results))
-    id_result = list(map(lambda x: x[1], results))
-    score_result = list(map(lambda x: x[2], results))
+    nested_list = [list(map(lambda x: [query, x], content_list)) for query, content_list in zip(queries, contents_list)]
+    rerank_scores = flatten_apply(flag_embedding_run_model, nested_list, model=model, batch_size=batch)
+
+    df = pd.DataFrame({
+        'contents': contents_list,
+        'ids': ids_list,
+        'scores': rerank_scores,
+    })
+    df[['contents', 'ids', 'scores']] = df.apply(sort_by_scores, axis=1, result_type='expand')
+    results = select_top_k(df, ['contents', 'ids', 'scores'], top_k)
 
     del model
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
 
-    return content_result, id_result, score_result
+    return results['contents'].tolist(), results['ids'].tolist(), results['scores'].tolist()
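
For readers skimming this patch outside the repo, the sketch below shows the pattern both rerankers move to: flatten every (query, passage) pair into one list, score it in fixed-size batches on a single model instance, then re-nest, sort, and truncate per query. FlagReranker and its compute_score(sentence_pairs=...) call are the real FlagEmbedding APIs used in the patch; make_batch, run_model, and rerank below are simplified stand-ins written only to illustrate what autorag's make_batch, flatten_apply, sort_by_scores, and select_top_k do together, and the model checkpoint name is just an example.

from typing import List, Tuple

import torch
from FlagEmbedding import FlagReranker


def make_batch(items: list, batch_size: int) -> List[list]:
    # Simplified stand-in for autorag.utils.util.make_batch: split a flat
    # list into consecutive chunks of at most batch_size items.
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]


def run_model(pairs: List[List[str]], model, batch_size: int) -> List[float]:
    # Same shape as the patch's flag_embedding_run_model: score every
    # (query, passage) pair in fixed-size batches on one shared model,
    # instead of issuing one compute_score call per query.
    scores: List[float] = []
    for batch in make_batch(pairs, batch_size):
        with torch.no_grad():
            pred = model.compute_score(sentence_pairs=batch)
        # compute_score returns a bare float when given a single pair,
        # so normalize to a list before extending.
        scores.extend(pred if isinstance(pred, list) else [pred])
    return scores


def rerank(queries: List[str], contents_list: List[List[str]], top_k: int,
           model, batch: int = 64) -> List[List[Tuple[str, float]]]:
    # Flatten all pairs so the whole workload shares one batching loop
    # (the role flatten_apply plays in the patch) ...
    pairs = [[q, c] for q, cs in zip(queries, contents_list) for c in cs]
    flat_scores = run_model(pairs, model, batch)
    # ... then re-nest per query, sort by score descending, and keep top_k
    # (the roles of sort_by_scores and select_top_k).
    results, offset = [], 0
    for cs in contents_list:
        scored = list(zip(cs, flat_scores[offset:offset + len(cs)]))
        offset += len(cs)
        scored.sort(key=lambda pair: pair[1], reverse=True)
        results.append(scored[:top_k])
    return results


if __name__ == '__main__':
    model = FlagReranker('BAAI/bge-reranker-large', use_fp16=True)  # example checkpoint
    print(rerank(['what does a reranker do?'],
                 [['A reranker reorders retrieved passages.', 'Bananas are yellow.']],
                 top_k=1, model=model))

Compared with the removed asyncio path, which awaited one compute_score call per query, batching across the flattened pair list keeps the model fed with uniformly sized batches regardless of how many passages each query retrieved.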