Skip to content

Commit

Permalink
Implemented Reciprocal Rank Fusion (#87)
Browse files Browse the repository at this point in the history
* Implemented Ensemble Retriever

* Used self.alpha for the weights

* Cleaned up the return statement
  • Loading branch information
aarya-16 authored Oct 29, 2024
1 parent 045e127 commit a6571a5
Showing 1 changed file with 17 additions and 12 deletions.
29 changes: 17 additions & 12 deletions sage/vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

import marqo
import nltk
from langchain_community.retrievers import PineconeHybridSearchRetriever
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
from langchain_community.vectorstores import Marqo
from langchain_community.vectorstores import Pinecone as LangChainPinecone
from langchain_core.documents import Document
Expand Down Expand Up @@ -133,19 +134,23 @@ def upsert_batch(self, vectors: List[Vector], namespace: str):
self.index.upsert(vectors=pinecone_vectors, namespace=namespace)

def as_retriever(self, top_k: int, embeddings: Embeddings, namespace: str):
# Builds the retriever object for this vector store.
#
# This capture is a commit-diff view whose +/- markers and indentation were
# lost, so the pre-commit and post-commit bodies appear back to back below.
# NOTE(review): the removed/added split annotated here is inferred from the
# hunk header (-133,19 +134,23) and the "17 additions and 12 deletions"
# summary -- confirm against the raw diff before relying on it.
#
# Presumably REMOVED by this commit: when a BM25 encoder was configured, the
# method returned a PineconeHybridSearchRetriever that received self.alpha.
if self.bm25_encoder:
return PineconeHybridSearchRetriever(
embeddings=embeddings,
sparse_encoder=self.bm25_encoder,
index=self.index,
namespace=namespace,
top_k=top_k,
alpha=self.alpha,
)

# Presumably REMOVED: the dense retriever used to be returned directly; the
# call's arguments are the two shared context lines further down.
return LangChainPinecone.from_existing_index(
# Presumably ADDED: optional sparse retriever; None when self.bm25_encoder
# is falsy.
# NOTE(review): these are the same keywords the old
# PineconeHybridSearchRetriever call received (minus alpha) -- confirm that
# BM25Retriever actually accepts them.
bm25_retriever = BM25Retriever(
embeddings=embeddings,
sparse_encoder=self.bm25_encoder,
index=self.index,
namespace=namespace,
top_k=top_k,
) if self.bm25_encoder else None

# Dense retriever over the existing Pinecone index named self.index_name,
# passing top_k as the search "k".
dense_retriever = LangChainPinecone.from_existing_index(
index_name=self.index_name, embedding=embeddings, namespace=namespace
).as_retriever(search_kwargs={"k": top_k})

# With a sparse retriever available, combine both in an EnsembleRetriever
# weighted [self.alpha, 1 - self.alpha] (dense first); otherwise return the
# dense retriever alone.
if bm25_retriever:
return EnsembleRetriever(retrievers=[dense_retriever, bm25_retriever], weights=[self.alpha, 1-self.alpha])
else:
return dense_retriever



class MarqoVectorStore(VectorStore):
Expand Down

0 comments on commit a6571a5

Please sign in to comment.