Skip to content

Commit

Permalink
Update reranker.py (#99)
Browse files Browse the repository at this point in the history
Refactored reranker.py
  • Loading branch information
ghimirebibek authored Oct 30, 2024
1 parent a6571a5 commit 72d2a87
Showing 1 changed file with 36 additions and 21 deletions.
57 changes: 36 additions & 21 deletions sage/reranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,30 +20,45 @@ class RerankerProvider(Enum):
VOYAGE = "voyage"


def build_reranker(provider: str, model: Optional[str] = None, top_k: Optional[int] = 5) -> BaseDocumentCompressor:
def build_reranker(provider: str, model: Optional[str] = None, top_k: int = 5) -> Optional[BaseDocumentCompressor]:
if provider == RerankerProvider.NONE.value:
return None

api_key_env_vars = {
RerankerProvider.COHERE.value: "COHERE_API_KEY",
RerankerProvider.NVIDIA.value: "NVIDIA_API_KEY",
RerankerProvider.JINA.value: "JINA_API_KEY",
RerankerProvider.VOYAGE.value: "VOYAGE_API_KEY"
}

provider_defaults = {
RerankerProvider.HUGGINGFACE.value: "cross-encoder/ms-marco-MiniLM-L-6-v2",
RerankerProvider.COHERE.value: "rerank-english-v3.0",
RerankerProvider.NVIDIA.value: "nvidia/nv-rerankqa-mistral-4b-v3",
RerankerProvider.VOYAGE.value: "rerank-1"
}

model = model or provider_defaults.get(provider)

if provider == RerankerProvider.HUGGINGFACE.value:
model = model or "cross-encoder/ms-marco-MiniLM-L-6-v2"
encoder_model = HuggingFaceCrossEncoder(model_name=model)
return CrossEncoderReranker(model=encoder_model, top_n=top_k)
if provider == RerankerProvider.COHERE.value:
if not os.environ.get("COHERE_API_KEY"):
raise ValueError("Please set the COHERE_API_KEY environment variable")
model = model or "rerank-english-v3.0"
return CohereRerank(model=model, cohere_api_key=os.environ.get("COHERE_API_KEY"), top_n=top_k)
if provider == RerankerProvider.NVIDIA.value:
if not os.environ.get("NVIDIA_API_KEY"):
raise ValueError("Please set the NVIDIA_API_KEY environment variable")
model = model or "nvidia/nv-rerankqa-mistral-4b-v3"
return NVIDIARerank(model=model, api_key=os.environ.get("NVIDIA_API_KEY"), top_n=top_k, truncate="END")
if provider == RerankerProvider.JINA.value:
if not os.environ.get("JINA_API_KEY"):
raise ValueError("Please set the JINA_API_KEY environment variable")
return JinaRerank(top_n=top_k)
if provider == RerankerProvider.VOYAGE.value:
if not os.environ.get("VOYAGE_API_KEY"):
raise ValueError("Please set the VOYAGE_API_KEY environment variable")
model = model or "rerank-1"
return VoyageAIRerank(model=model, api_key=os.environ.get("VOYAGE_API_KEY"), top_k=top_k)

if provider in api_key_env_vars:
api_key = os.getenv(api_key_env_vars[provider])
if not api_key:
raise ValueError(f"Please set the {api_key_env_vars[provider]} environment variable")

if provider == RerankerProvider.COHERE.value:
return CohereRerank(model=model, cohere_api_key=api_key, top_n=top_k)

if provider == RerankerProvider.NVIDIA.value:
return NVIDIARerank(model=model, api_key=api_key, top_n=top_k, truncate="END")

if provider == RerankerProvider.JINA.value:
return JinaRerank(top_n=top_k)

if provider == RerankerProvider.VOYAGE.value:
return VoyageAIRerank(model=model, api_key=api_key, top_k=top_k)

raise ValueError(f"Invalid reranker provider: {provider}")

0 comments on commit 72d2a87

Please sign in to comment.