|
| 1 | +import os |
| 2 | +from typing import Any, Dict, Optional, List |
| 3 | + |
| 4 | +from chromadbx.reranking import ( |
| 5 | + Queries, |
| 6 | + RankedResults, |
| 7 | + Rerankable, |
| 8 | + RerankedDocuments, |
| 9 | + RerankedQueryResult, |
| 10 | + RerankerID, |
| 11 | + RerankingFunction, |
| 12 | +) |
| 13 | +from chromadbx.reranking.utils import get_query_documents_tuples |
| 14 | + |
| 15 | + |
| 16 | +class CohereReranker(RerankingFunction[Rerankable, RankedResults]): |
| 17 | + def __init__( |
| 18 | + self, |
| 19 | + api_key: str, |
| 20 | + model_name: Optional[str] = "rerank-v3.5", |
| 21 | + *, |
| 22 | + raw_scores: bool = False, |
| 23 | + top_n: Optional[int] = None, |
| 24 | + max_tokens_per_document: Optional[int] = 4096, |
| 25 | + timeout: Optional[int] = 60, |
| 26 | + max_retries: Optional[int] = 3, |
| 27 | + additional_headers: Optional[Dict[str, Any]] = None, |
| 28 | + ): |
| 29 | + """ |
| 30 | + Initialize the CohereReranker. |
| 31 | +
|
| 32 | + Args: |
| 33 | + api_key: The Cohere API key. |
| 34 | + model_name: The Cohere model to use for reranking. Defaults to `rerank-v3.5`. |
| 35 | + raw_scores: Whether to return the raw scores from the Cohere API. Defaults to `False`. |
| 36 | + top_n: The number of results to return. Defaults to `None`. |
| 37 | + max_tokens_per_document: The maximum number of tokens per document. Defaults to `4096`. |
| 38 | + timeout: The timeout for the Cohere API request. Defaults to `60`. |
| 39 | + max_retries: The maximum number of retries for the Cohere API request. Defaults to `3`. |
| 40 | + additional_headers: Additional headers to include in the Cohere API request. Defaults to `None`. |
| 41 | + """ |
| 42 | + try: |
| 43 | + import cohere |
| 44 | + from cohere.core.request_options import RequestOptions |
| 45 | + except ImportError: |
| 46 | + raise ImportError( |
| 47 | + "cohere is not installed. Please install it with `pip install cohere`" |
| 48 | + ) |
| 49 | + if not api_key and not os.getenv("COHERE_API_KEY"): |
| 50 | + raise ValueError( |
| 51 | + "API key is required. Please set the COHERE_API_KEY environment variable or pass it directly." |
| 52 | + ) |
| 53 | + if not model_name: |
| 54 | + raise ValueError( |
| 55 | + "Model name is required. Please set the model_name parameter or use the default value." |
| 56 | + ) |
| 57 | + self._client = cohere.ClientV2(api_key) |
| 58 | + self._model_name = model_name |
| 59 | + self._top_n = top_n |
| 60 | + self._raw_scores = raw_scores |
| 61 | + self._max_tokens_per_document = max_tokens_per_document |
| 62 | + self._request_options = RequestOptions( |
| 63 | + timeout_in_seconds=timeout, |
| 64 | + max_retries=max_retries, |
| 65 | + additional_headers=additional_headers, |
| 66 | + ) |
| 67 | + |
| 68 | + def id(self) -> RerankerID: |
| 69 | + return RerankerID("cohere") |
| 70 | + |
| 71 | + def _combine_reranked_results( |
| 72 | + self, results_list: List["cohere.v2.types.V2RerankResponse"], rerankables: Rerankable # type: ignore # noqa: F821 |
| 73 | + ) -> RankedResults: |
| 74 | + all_ordered_scores = [] |
| 75 | + |
| 76 | + for results in results_list: |
| 77 | + if self._raw_scores: |
| 78 | + ordered_scores = [ |
| 79 | + r.relevance_score |
| 80 | + for r in sorted(results.results, key=lambda x: x.index) # type: ignore |
| 81 | + ] |
| 82 | + else: # by default we calculate the distance to make results comparable with Chroma distance |
| 83 | + ordered_scores = [ |
| 84 | + 1 - r.relevance_score |
| 85 | + for r in sorted(results.results, key=lambda x: x.index) # type: ignore |
| 86 | + ] |
| 87 | + all_ordered_scores.append(ordered_scores) |
| 88 | + |
| 89 | + if isinstance(rerankables, dict): |
| 90 | + combined_ordered_scores = [ |
| 91 | + score for sublist in all_ordered_scores for score in sublist |
| 92 | + ] |
| 93 | + if len(rerankables["ids"]) != len(combined_ordered_scores): |
| 94 | + combined_ordered_scores = combined_ordered_scores + [None] * ( |
| 95 | + len(rerankables["ids"]) - len(combined_ordered_scores) |
| 96 | + ) |
| 97 | + return RerankedQueryResult( |
| 98 | + ids=rerankables["ids"], |
| 99 | + embeddings=rerankables["embeddings"] |
| 100 | + if "embeddings" in rerankables |
| 101 | + else None, |
| 102 | + documents=rerankables["documents"] |
| 103 | + if "documents" in rerankables |
| 104 | + else None, |
| 105 | + uris=rerankables["uris"] if "uris" in rerankables else None, |
| 106 | + data=rerankables["data"] if "data" in rerankables else None, |
| 107 | + metadatas=rerankables["metadatas"] |
| 108 | + if "metadatas" in rerankables |
| 109 | + else None, |
| 110 | + distances=rerankables["distances"] |
| 111 | + if "distances" in rerankables |
| 112 | + else None, |
| 113 | + included=rerankables["included"] if "included" in rerankables else None, |
| 114 | + ranked_distances={self.id(): combined_ordered_scores}, |
| 115 | + ) |
| 116 | + elif isinstance(rerankables, list): |
| 117 | + if len(results_list) > 1: |
| 118 | + raise ValueError("Cannot rerank documents with multiple results") |
| 119 | + combined_ordered_scores = [ |
| 120 | + score for sublist in all_ordered_scores for score in sublist |
| 121 | + ] |
| 122 | + if len(rerankables) != len(combined_ordered_scores): |
| 123 | + combined_ordered_scores = combined_ordered_scores + [None] * ( |
| 124 | + len(rerankables) - len(combined_ordered_scores) |
| 125 | + ) |
| 126 | + return RerankedDocuments( |
| 127 | + documents=rerankables, |
| 128 | + ranked_distances={self.id(): combined_ordered_scores}, |
| 129 | + ) |
| 130 | + else: |
| 131 | + raise ValueError("Invalid rerankables type") |
| 132 | + |
| 133 | + def __call__(self, queries: Queries, rerankables: Rerankable) -> RankedResults: |
| 134 | + query_documents_tuples = get_query_documents_tuples(queries, rerankables) |
| 135 | + results = [] |
| 136 | + for query, documents in query_documents_tuples: |
| 137 | + response = self._client.rerank( |
| 138 | + model=self._model_name, |
| 139 | + query=query, |
| 140 | + documents=documents, |
| 141 | + top_n=self._top_n or len(documents), |
| 142 | + max_tokens_per_doc=self._max_tokens_per_document, |
| 143 | + request_options=self._request_options, |
| 144 | + ) |
| 145 | + results.append(response) |
| 146 | + return self._combine_reranked_results(results, rerankables) |
0 commit comments