Merge branch 'main' into Feature/#335
bwook00 authored Apr 21, 2024 · 2 parents f8ddc5f + b9d24ef · commit a71f8a4
Showing 2 changed files with 35 additions and 51 deletions.
autorag/nodes/passagereranker/flag_embedding.py (21 additions, 40 deletions)
@@ -1,11 +1,11 @@
-import asyncio
 from typing import List, Tuple
 
+import pandas as pd
 import torch
 from FlagEmbedding import FlagReranker
 
 from autorag.nodes.passagereranker.base import passage_reranker_node
-from autorag.utils.util import process_batch
+from autorag.utils.util import make_batch, sort_by_scores, flatten_apply, select_top_k
 
 
 @passage_reranker_node
@@ -31,48 +31,29 @@ def flag_embedding_reranker(queries: List[str], contents_list: List[List[str]],
     model = FlagReranker(
         model_name_or_path=model_name, use_fp16=use_fp16
     )
-    tasks = [flag_embedding_reranker_pure(query, contents, scores, top_k, ids, model)
-             for query, contents, scores, ids in zip(queries, contents_list, scores_list, ids_list)]
-    loop = asyncio.get_event_loop()
-    results = loop.run_until_complete(process_batch(tasks, batch_size=batch))
-    content_result = list(map(lambda x: x[0], results))
-    id_result = list(map(lambda x: x[1], results))
-    score_result = list(map(lambda x: x[2], results))
+    nested_list = [list(map(lambda x: [query, x], content_list)) for query, content_list in zip(queries, contents_list)]
+    rerank_scores = flatten_apply(flag_embedding_run_model, nested_list, model=model, batch_size=batch)
+
+    df = pd.DataFrame({
+        'contents': contents_list,
+        'ids': ids_list,
+        'scores': rerank_scores,
+    })
+    df[['contents', 'ids', 'scores']] = df.apply(sort_by_scores, axis=1, result_type='expand')
+    results = select_top_k(df, ['contents', 'ids', 'scores'], top_k)
 
     del model
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
 
-    return content_result, id_result, score_result
-
-
-async def flag_embedding_reranker_pure(query: str, contents: List[str], scores: List[float], top_k: int,
-                                       ids: List[str], model) -> Tuple[List[str], List[str], List[float]]:
-    """
-    Rerank a list of contents based on their relevance to a query using BAAI Reranker model.
-
-    :param query: The query to use for reranking
-    :param contents: The list of contents to rerank
-    :param scores: The list of scores retrieved from the initial ranking
-    :param ids: The list of ids retrieved from the initial ranking
-    :param top_k: The number of passages to be retrieved
-    :param model: BAAI Reranker model.
-    :return: tuple of lists containing the reranked contents, ids, and scores
-    """
-    input_texts = [(query, content) for content in contents]
-    with torch.no_grad():
-        pred_scores = model.compute_score(sentence_pairs=input_texts)
-
-    content_ids_probs = list(zip(contents, ids, pred_scores))
-
-    # Sort the list of pairs based on the relevance score in descending order
-    sorted_content_ids_probs = sorted(content_ids_probs, key=lambda x: x[2], reverse=True)
-
-    # crop with top_k
-    if len(contents) < top_k:
-        top_k = len(contents)
-    sorted_content_ids_probs = sorted_content_ids_probs[:top_k]
-
-    content_result, id_result, score_result = zip(*sorted_content_ids_probs)
-
-    return list(content_result), list(id_result), list(score_result)
+    return results['contents'].tolist(), results['ids'].tolist(), results['scores'].tolist()
+
+
+def flag_embedding_run_model(input_texts, model, batch_size: int):
+    batch_input_texts = make_batch(input_texts, batch_size)
+    results = []
+    for batch_texts in batch_input_texts:
+        with torch.no_grad():
+            pred_scores = model.compute_score(sentence_pairs=batch_texts)
+        results.extend(pred_scores)
+    return results
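A note on the refactor above: the old implementation scheduled one coroutine per query and drove them through asyncio's event loop with process_batch, while the new implementation flattens every [query, passage] pair into a single list and scores it synchronously in fixed-size batches, so no event loop is needed. The helpers make_batch and flatten_apply are imported from autorag.utils.util and their implementations are not part of this diff; the sketch below is only an assumption about their semantics (chunking, and flatten / apply / re-nest), inferred from how the call sites above use them.

from typing import Any, Callable, List


def make_batch(items: List[Any], batch_size: int) -> List[List[Any]]:
    # Assumed behavior: split a flat list into consecutive chunks of at most batch_size.
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]


def flatten_apply(fn: Callable[..., List[Any]], nested_list: List[List[Any]], **kwargs) -> List[List[Any]]:
    # Assumed behavior: flatten the nested [query, passage] pairs, run fn once
    # over the whole flat list, then re-nest the results so each output row
    # matches the length of the corresponding input row.
    flat = [item for row in nested_list for item in row]
    flat_results = fn(flat, **kwargs)
    nested_results, cursor = [], 0
    for row in nested_list:
        nested_results.append(flat_results[cursor:cursor + len(row)])
        cursor += len(row)
    return nested_results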
autorag/nodes/passagereranker/flag_embedding_llm.py (14 additions, 11 deletions)
@@ -1,12 +1,12 @@
-import asyncio
 from typing import List, Tuple
 
+import pandas as pd
 import torch
 from FlagEmbedding import FlagLLMReranker
 
 from autorag.nodes.passagereranker.base import passage_reranker_node
-from autorag.nodes.passagereranker.flag_embedding import flag_embedding_reranker_pure
-from autorag.utils.util import process_batch
+from autorag.nodes.passagereranker.flag_embedding import flag_embedding_run_model
+from autorag.utils.util import flatten_apply, sort_by_scores, select_top_k
 
 
 @passage_reranker_node
@@ -32,16 +32,19 @@ def flag_embedding_llm_reranker(queries: List[str], contents_list: List[List[str]],
     model = FlagLLMReranker(
         model_name_or_path=model_name, use_fp16=use_fp16
     )
-    tasks = [flag_embedding_reranker_pure(query, contents, scores, top_k, ids, model)
-             for query, contents, scores, ids in zip(queries, contents_list, scores_list, ids_list)]
-    loop = asyncio.get_event_loop()
-    results = loop.run_until_complete(process_batch(tasks, batch_size=batch))
-    content_result = list(map(lambda x: x[0], results))
-    id_result = list(map(lambda x: x[1], results))
-    score_result = list(map(lambda x: x[2], results))
+    nested_list = [list(map(lambda x: [query, x], content_list)) for query, content_list in zip(queries, contents_list)]
+    rerank_scores = flatten_apply(flag_embedding_run_model, nested_list, model=model, batch_size=batch)
+
+    df = pd.DataFrame({
+        'contents': contents_list,
+        'ids': ids_list,
+        'scores': rerank_scores,
+    })
+    df[['contents', 'ids', 'scores']] = df.apply(sort_by_scores, axis=1, result_type='expand')
+    results = select_top_k(df, ['contents', 'ids', 'scores'], top_k)
 
     del model
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
 
-    return content_result, id_result, score_result
+    return results['contents'].tolist(), results['ids'].tolist(), results['scores'].tolist()
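For reference, a hypothetical call shape for the rewritten LLM reranker. The parameter names are taken from the diff above, but the argument values and the model name are made up for illustration, and it is an assumption that the @passage_reranker_node decorator leaves the function callable in this plain form.

# Hypothetical usage; values and model name are illustrative only.
contents, ids, scores = flag_embedding_llm_reranker(
    queries=["what is autorag?"],
    contents_list=[["AutoRAG automates RAG pipeline tuning.", "An unrelated passage."]],
    scores_list=[[0.5, 0.3]],
    ids_list=[["passage-0", "passage-1"]],
    top_k=1,
    batch=64,
    model_name="BAAI/bge-reranker-v2-gemma",
    use_fp16=True,
)
# Expected shape: one entry per query in each returned list; with top_k=1,
# contents[0] holds the single highest-scoring passage for the first query.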
