diff --git a/sage/reranker.py b/sage/reranker.py index db1e008..0aa2c37 100644 --- a/sage/reranker.py +++ b/sage/reranker.py @@ -28,14 +28,14 @@ def build_reranker(provider: str, model: Optional[str] = None, top_k: int = 5) - RerankerProvider.COHERE.value: "COHERE_API_KEY", RerankerProvider.NVIDIA.value: "NVIDIA_API_KEY", RerankerProvider.JINA.value: "JINA_API_KEY", - RerankerProvider.VOYAGE.value: "VOYAGE_API_KEY" + RerankerProvider.VOYAGE.value: "VOYAGE_API_KEY", } provider_defaults = { RerankerProvider.HUGGINGFACE.value: "cross-encoder/ms-marco-MiniLM-L-6-v2", RerankerProvider.COHERE.value: "rerank-english-v3.0", RerankerProvider.NVIDIA.value: "nvidia/nv-rerankqa-mistral-4b-v3", - RerankerProvider.VOYAGE.value: "rerank-1" + RerankerProvider.VOYAGE.value: "rerank-1", } model = model or provider_defaults.get(provider) diff --git a/sage/vector_store.py b/sage/vector_store.py index d9f69e0..826980e 100644 --- a/sage/vector_store.py +++ b/sage/vector_store.py @@ -8,8 +8,8 @@ import marqo import nltk -from langchain_community.retrievers import BM25Retriever from langchain.retrievers import EnsembleRetriever +from langchain_community.retrievers import BM25Retriever from langchain_community.vectorstores import Marqo from langchain_community.vectorstores import Pinecone as LangChainPinecone from langchain_core.documents import Document @@ -134,25 +134,28 @@ def upsert_batch(self, vectors: List[Vector], namespace: str): self.index.upsert(vectors=pinecone_vectors, namespace=namespace) def as_retriever(self, top_k: int, embeddings: Embeddings, namespace: str): - bm25_retriever = BM25Retriever( - embeddings=embeddings, - sparse_encoder=self.bm25_encoder, - index=self.index, - namespace=namespace, - top_k=top_k, - ) if self.bm25_encoder else None + bm25_retriever = ( + BM25Retriever( + embeddings=embeddings, + sparse_encoder=self.bm25_encoder, + index=self.index, + namespace=namespace, + top_k=top_k, + ) + if self.bm25_encoder + else None + ) dense_retriever = LangChainPinecone.from_existing_index( index_name=self.index_name, embedding=embeddings, namespace=namespace ).as_retriever(search_kwargs={"k": top_k}) - + if bm25_retriever: - return EnsembleRetriever(retrievers=[dense_retriever, bm25_retriever], weights=[self.alpha, 1-self.alpha]) + return EnsembleRetriever(retrievers=[dense_retriever, bm25_retriever], weights=[self.alpha, 1 - self.alpha]) else: return dense_retriever - class MarqoVectorStore(VectorStore): """Vector store implementation using Marqo."""