Skip to content

Commit d62d415

Browse files
authored
feat: Cohere Reranking function (#82)
Closes #79
1 parent 6823575 commit d62d415

File tree

8 files changed

+303
-3
lines changed

8 files changed

+303
-3
lines changed

.github/workflows/test.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,4 @@ jobs:
8787
CF_GATEWAY_ENDPOINT: ${{ secrets.CF_GATEWAY_ENDPOINT }}
8888
TOGETHER_API_KEY: ${{ secrets.TOGETHER_API_KEY }}
8989
NOMIC_API_KEY: ${{ secrets.NOMIC_API_KEY }}
90+
COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}

README.md

+2
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ pip install chromadbx
1919
- [SpaCy](https://github.com/amikos-tech/chromadbx/blob/main/docs/embeddings.md#spacy) embeddings
2020
- [Together](https://github.com/amikos-tech/chromadbx/blob/main/docs/embeddings.md#together) embeddings.
2121
- [Nomic](https://github.com/amikos-tech/chromadbx/blob/main/docs/embeddings.md#nomic) embeddings.
22+
- [Reranking](https://github.com/amikos-tech/chromadbx/blob/main/docs/reranking.md) - rerank documents and query results using Cohere, OpenAI, or custom reranking functions.
23+
- [Cohere](https://github.com/amikos-tech/chromadbx/blob/main/docs/reranking.md#cohere) - rerank documents and query results using Cohere.
2224

2325
## Usage
2426

chromadbx/reranking/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,15 @@ class RerankedQueryResult(TypedDict):
4848
metadatas: Optional[List[List[Metadata]]]
4949
distances: Optional[List[List[Distance]]]
5050
included: Include
51-
ranked_distances: Dict[RerankerID, List[Distance]]
51+
ranked_distances: Dict[RerankerID, List[Distances]]
5252

5353

5454
class RerankedDocuments(TypedDict):
5555
documents: List[Documents]
5656
ranked_distances: Dict[RerankerID, Distances]
5757

5858

59-
RankedResults = Union[List[Documents], List[RerankedQueryResult]]
59+
RankedResults = Union[RerankedDocuments, RerankedQueryResult]
6060

6161
D = TypeVar("D", bound=Rerankable, contravariant=True)
6262
T = TypeVar("T", bound=RankedResults, covariant=True)

chromadbx/reranking/cohere.py

+146
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
import os
2+
from typing import Any, Dict, Optional, List
3+
4+
from chromadbx.reranking import (
5+
Queries,
6+
RankedResults,
7+
Rerankable,
8+
RerankedDocuments,
9+
RerankedQueryResult,
10+
RerankerID,
11+
RerankingFunction,
12+
)
13+
from chromadbx.reranking.utils import get_query_documents_tuples
14+
15+
16+
class CohereReranker(RerankingFunction[Rerankable, RankedResults]):
17+
def __init__(
18+
self,
19+
api_key: str,
20+
model_name: Optional[str] = "rerank-v3.5",
21+
*,
22+
raw_scores: bool = False,
23+
top_n: Optional[int] = None,
24+
max_tokens_per_document: Optional[int] = 4096,
25+
timeout: Optional[int] = 60,
26+
max_retries: Optional[int] = 3,
27+
additional_headers: Optional[Dict[str, Any]] = None,
28+
):
29+
"""
30+
Initialize the CohereReranker.
31+
32+
Args:
33+
api_key: The Cohere API key.
34+
model_name: The Cohere model to use for reranking. Defaults to `rerank-v3.5`.
35+
raw_scores: Whether to return the raw scores from the Cohere API. Defaults to `False`.
36+
top_n: The number of results to return. Defaults to `None`.
37+
max_tokens_per_document: The maximum number of tokens per document. Defaults to `4096`.
38+
timeout: The timeout for the Cohere API request. Defaults to `60`.
39+
max_retries: The maximum number of retries for the Cohere API request. Defaults to `3`.
40+
additional_headers: Additional headers to include in the Cohere API request. Defaults to `None`.
41+
"""
42+
try:
43+
import cohere
44+
from cohere.core.request_options import RequestOptions
45+
except ImportError:
46+
raise ImportError(
47+
"cohere is not installed. Please install it with `pip install cohere`"
48+
)
49+
if not api_key and not os.getenv("COHERE_API_KEY"):
50+
raise ValueError(
51+
"API key is required. Please set the COHERE_API_KEY environment variable or pass it directly."
52+
)
53+
if not model_name:
54+
raise ValueError(
55+
"Model name is required. Please set the model_name parameter or use the default value."
56+
)
57+
self._client = cohere.ClientV2(api_key)
58+
self._model_name = model_name
59+
self._top_n = top_n
60+
self._raw_scores = raw_scores
61+
self._max_tokens_per_document = max_tokens_per_document
62+
self._request_options = RequestOptions(
63+
timeout_in_seconds=timeout,
64+
max_retries=max_retries,
65+
additional_headers=additional_headers,
66+
)
67+
68+
def id(self) -> RerankerID:
69+
return RerankerID("cohere")
70+
71+
def _combine_reranked_results(
72+
self, results_list: List["cohere.v2.types.V2RerankResponse"], rerankables: Rerankable # type: ignore # noqa: F821
73+
) -> RankedResults:
74+
all_ordered_scores = []
75+
76+
for results in results_list:
77+
if self._raw_scores:
78+
ordered_scores = [
79+
r.relevance_score
80+
for r in sorted(results.results, key=lambda x: x.index) # type: ignore
81+
]
82+
else: # by default we calculate the distance to make results comparable with Chroma distance
83+
ordered_scores = [
84+
1 - r.relevance_score
85+
for r in sorted(results.results, key=lambda x: x.index) # type: ignore
86+
]
87+
all_ordered_scores.append(ordered_scores)
88+
89+
if isinstance(rerankables, dict):
90+
combined_ordered_scores = [
91+
score for sublist in all_ordered_scores for score in sublist
92+
]
93+
if len(rerankables["ids"]) != len(combined_ordered_scores):
94+
combined_ordered_scores = combined_ordered_scores + [None] * (
95+
len(rerankables["ids"]) - len(combined_ordered_scores)
96+
)
97+
return RerankedQueryResult(
98+
ids=rerankables["ids"],
99+
embeddings=rerankables["embeddings"]
100+
if "embeddings" in rerankables
101+
else None,
102+
documents=rerankables["documents"]
103+
if "documents" in rerankables
104+
else None,
105+
uris=rerankables["uris"] if "uris" in rerankables else None,
106+
data=rerankables["data"] if "data" in rerankables else None,
107+
metadatas=rerankables["metadatas"]
108+
if "metadatas" in rerankables
109+
else None,
110+
distances=rerankables["distances"]
111+
if "distances" in rerankables
112+
else None,
113+
included=rerankables["included"] if "included" in rerankables else None,
114+
ranked_distances={self.id(): combined_ordered_scores},
115+
)
116+
elif isinstance(rerankables, list):
117+
if len(results_list) > 1:
118+
raise ValueError("Cannot rerank documents with multiple results")
119+
combined_ordered_scores = [
120+
score for sublist in all_ordered_scores for score in sublist
121+
]
122+
if len(rerankables) != len(combined_ordered_scores):
123+
combined_ordered_scores = combined_ordered_scores + [None] * (
124+
len(rerankables) - len(combined_ordered_scores)
125+
)
126+
return RerankedDocuments(
127+
documents=rerankables,
128+
ranked_distances={self.id(): combined_ordered_scores},
129+
)
130+
else:
131+
raise ValueError("Invalid rerankables type")
132+
133+
def __call__(self, queries: Queries, rerankables: Rerankable) -> RankedResults:
134+
query_documents_tuples = get_query_documents_tuples(queries, rerankables)
135+
results = []
136+
for query, documents in query_documents_tuples:
137+
response = self._client.rerank(
138+
model=self._model_name,
139+
query=query,
140+
documents=documents,
141+
top_n=self._top_n or len(documents),
142+
max_tokens_per_doc=self._max_tokens_per_document,
143+
request_options=self._request_options,
144+
)
145+
results.append(response)
146+
return self._combine_reranked_results(results, rerankables)

docs/reranking.md

+79
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Reranking
2+
3+
Reranking is a process of reordering a list of items based on their relevance to a query. This project supports reranking of documents and query results.
4+
5+
```python
6+
from chromadbx.reranking.some_reranker import SomeReranker
7+
import chromadb
8+
some_reranker = SomeReranker()
9+
10+
client = chromadb.Client()
11+
12+
collection = client.get_collection("documents")
13+
14+
results = collection.query(
15+
query_texts=["What is the capital of the United States?"],
16+
n_results=10,
17+
)
18+
19+
reranked_results = some_reranker(results)
20+
21+
print("Documents:", reranked_results["documents"][0])
22+
print("Distances:", reranked_results["distances"][0])
23+
print("Reranked distances:", reranked_results["ranked_distances"][some_reranker.id()][0])
24+
```
25+
26+
> [!NOTE]
27+
> It is our intent that all officially supported reranking functions shall return distances instead of scores to be consistent with the core Chroma project. However, this is not a hard requirement and you should check the documentation for each reranking function you plan to use.
28+
29+
The following reranking functions are supported:
30+
31+
| Reranking Function | Official Docs |
32+
| ------------------ | ------------- |
33+
| [Cohere](#cohere) | [docs](https://docs.cohere.com/docs/rerank-2) |
34+
35+
## Cohere
36+
37+
Cohere reranking function offers a convinient wrapper around the Cohere API to rerank documents and query results. For more information on Cohere reranking, visit the official [docs](https://docs.cohere.com/docs/rerank-2) or [API docs](https://docs.cohere.com/reference/rerank).
38+
39+
You need to install the `cohere` package to use this reranking function.
40+
41+
42+
```bash
43+
pip install cohere # or poetry add cohere
44+
```
45+
46+
Before using the reranking function, you need to obtain [Cohere API](https://dashboard.cohere.com/api-keys) key and set the `COHERE_API_KEY` environment variable.
47+
48+
> [!TIP]
49+
> By default, the reranking function will return distances. If you need to get the raw scores, set the `raw_scores` parameter to `True`.
50+
51+
```python
52+
import os
53+
import chromadb
54+
from chromadbx.reranking import CohereReranker
55+
56+
cohere = CohereReranker(api_key=os.getenv("COHERE_API_KEY"))
57+
58+
client = chromadb.Client()
59+
60+
collection = client.get_collection("documents")
61+
62+
results = collection.query(
63+
query_texts=["What is the capital of the United States?"],
64+
n_results=10,
65+
)
66+
67+
reranked_results = cohere(results)
68+
```
69+
70+
Available options:
71+
72+
- `api_key`: The Cohere API key.
73+
- `model_name`: The Cohere model to use for reranking. Defaults to `rerank-v3.5`.
74+
- `raw_scores`: Whether to return the raw scores from the Cohere API. Defaults to `False`.
75+
- `top_n`: The number of results to return. Defaults to `None`.
76+
- `max_tokens_per_document`: The maximum number of tokens per document. Defaults to `4096`.
77+
- `timeout`: The timeout for the Cohere API request. Defaults to `60`.
78+
- `max_retries`: The maximum number of retries for the Cohere API request. Defaults to `3`.
79+
- `additional_headers`: Additional headers to include in the Cohere API request. Defaults to `None`.

pyproject.toml

+2
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,12 @@ llama-embedder = "^0.0.7"
3434
mistralai = "^1.1.0"
3535
spacy = "^3.8.4"
3636
together = "^1.3.11"
37+
cohere = "^5.13.8"
3738

3839
[tool.poetry.extras]
3940
ids = ["ulid-py", "nanoid"]
4041
embeddings = ["llama-embedder", "onnxruntime", "huggingface_hub", "mistralai", "spacy", "together", "vertexai"]
42+
reranking = ["cohere"]
4143
core = ["chromadb"]
4244

4345
[build-system]

test/embeddings/test_nomic.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import pytest
33
from chromadbx.embeddings.nomic import NomicEmbeddingFunction
44

5-
httpx = pytest.importorskip("httpx", reason="nomic not installed")
5+
httpx = pytest.importorskip("httpx", reason="httpx not installed")
66

77

88
@pytest.mark.skipif(

test/reranking/test_cohere.py

+70
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import os
2+
from typing import cast
3+
4+
from chromadb import QueryResult
5+
import pytest
6+
from chromadbx.reranking import RerankedDocuments, RerankedQueryResult
7+
from chromadbx.reranking.cohere import CohereReranker
8+
9+
10+
from unittest.mock import MagicMock
11+
12+
_cohere = pytest.importorskip("cohere", reason="cohere not installed")
13+
14+
15+
def test_cohere_mock_rerank_documents() -> None:
16+
mock_client = MagicMock()
17+
mock_client.rerank.return_value = MagicMock(results=[])
18+
19+
cohere = CohereReranker(api_key="test")
20+
cohere._client = mock_client
21+
22+
queries = "What is the capital of the United States?"
23+
rerankables = ["Washington, D.C.", "New York", "Los Angeles"]
24+
25+
cohere(queries, rerankables)
26+
mock_client.rerank.assert_called_once_with(
27+
model="rerank-v3.5",
28+
query=queries,
29+
documents=rerankables,
30+
top_n=len(rerankables),
31+
max_tokens_per_doc=4096,
32+
request_options=cohere._request_options,
33+
)
34+
35+
36+
@pytest.mark.skipif(
37+
os.getenv("COHERE_API_KEY") is None,
38+
reason="COHERE_API_KEY environment variable is not set",
39+
)
40+
def test_cohere_rerank_documents() -> None:
41+
cohere = CohereReranker(api_key=os.getenv("COHERE_API_KEY", ""))
42+
queries = "What is the capital of the United States?"
43+
rerankables = ["Washington, D.C.", "New York", "Los Angeles"]
44+
result = cast(RerankedDocuments, cohere(queries, rerankables))
45+
assert "ranked_distances" in result
46+
assert len(result["ranked_distances"][cohere.id()]) == len(rerankables)
47+
assert result["ranked_distances"][cohere.id()].index(
48+
min(result["ranked_distances"][cohere.id()])
49+
) == rerankables.index("Washington, D.C.")
50+
51+
52+
@pytest.mark.skipif(
53+
os.getenv("COHERE_API_KEY") is None,
54+
reason="COHERE_API_KEY environment variable is not set",
55+
)
56+
def test_cohere_rerank_documents_with_query_result() -> None:
57+
cohere = CohereReranker(api_key=os.getenv("COHERE_API_KEY", ""))
58+
queries = ["What is the capital of the United States?"]
59+
rerankables = QueryResult(
60+
documents=[["Washington, D.C.", "New York", "Los Angeles"]],
61+
metadatas=[[{"source": "test"}, {"source": "test"}, {"source": "test"}]],
62+
embeddings=[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
63+
ids=[["id1", "id2", "id3"]],
64+
)
65+
result = cast(RerankedQueryResult, cohere(queries, rerankables))
66+
assert "ranked_distances" in result
67+
assert len(result["ranked_distances"][cohere.id()]) == len(rerankables["ids"][0])
68+
assert result["ranked_distances"][cohere.id()].index(
69+
min(result["ranked_distances"][cohere.id()])
70+
) == rerankables["ids"][0].index("id1")

0 commit comments

Comments
 (0)