Skip to content

Commit

Permalink
Fix formatting
Browse files (browse the repository at this point in the history)
  • Loading branch information
iuliaturc committed Oct 30, 2024
1 parent 62e7537 commit a42ab83
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 13 deletions.
4 changes: 2 additions & 2 deletions sage/reranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@ def build_reranker(provider: str, model: Optional[str] = None, top_k: int = 5) -
RerankerProvider.COHERE.value: "COHERE_API_KEY",
RerankerProvider.NVIDIA.value: "NVIDIA_API_KEY",
RerankerProvider.JINA.value: "JINA_API_KEY",
- RerankerProvider.VOYAGE.value: "VOYAGE_API_KEY"
+ RerankerProvider.VOYAGE.value: "VOYAGE_API_KEY",
}

provider_defaults = {
RerankerProvider.HUGGINGFACE.value: "cross-encoder/ms-marco-MiniLM-L-6-v2",
RerankerProvider.COHERE.value: "rerank-english-v3.0",
RerankerProvider.NVIDIA.value: "nvidia/nv-rerankqa-mistral-4b-v3",
- RerankerProvider.VOYAGE.value: "rerank-1"
+ RerankerProvider.VOYAGE.value: "rerank-1",
}

model = model or provider_defaults.get(provider)
Expand Down
25 changes: 14 additions & 11 deletions sage/vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

import marqo
import nltk
- from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
+ from langchain_community.retrievers import BM25Retriever
from langchain_community.vectorstores import Marqo
from langchain_community.vectorstores import Pinecone as LangChainPinecone
from langchain_core.documents import Document
Expand Down Expand Up @@ -134,25 +134,28 @@ def upsert_batch(self, vectors: List[Vector], namespace: str):
self.index.upsert(vectors=pinecone_vectors, namespace=namespace)

def as_retriever(self, top_k: int, embeddings: Embeddings, namespace: str):
- bm25_retriever = BM25Retriever(
- embeddings=embeddings,
- sparse_encoder=self.bm25_encoder,
- index=self.index,
- namespace=namespace,
- top_k=top_k,
- ) if self.bm25_encoder else None
+ bm25_retriever = (
+ BM25Retriever(
+ embeddings=embeddings,
+ sparse_encoder=self.bm25_encoder,
+ index=self.index,
+ namespace=namespace,
+ top_k=top_k,
+ )
+ if self.bm25_encoder
+ else None
+ )

dense_retriever = LangChainPinecone.from_existing_index(
index_name=self.index_name, embedding=embeddings, namespace=namespace
).as_retriever(search_kwargs={"k": top_k})

if bm25_retriever:
- return EnsembleRetriever(retrievers=[dense_retriever, bm25_retriever], weights=[self.alpha, 1-self.alpha])
+ return EnsembleRetriever(retrievers=[dense_retriever, bm25_retriever], weights=[self.alpha, 1 - self.alpha])
else:
return dense_retriever



class MarqoVectorStore(VectorStore):
"""Vector store implementation using Marqo."""

Expand Down

0 comments on commit a42ab83

Please sign in to comment.