Skip to content

Commit

Permalink
add rerank api support
Browse files Browse the repository at this point in the history
  • Loading branch information
orangetin committed Aug 26, 2024
1 parent af7826b commit b248178
Show file tree
Hide file tree
Showing 7 changed files with 198 additions and 1 deletion.
4 changes: 4 additions & 0 deletions src/together/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class Together:
images: resources.Images
models: resources.Models
fine_tuning: resources.FineTuning
rerank: resources.Rerank

# client options
client: TogetherClient
Expand Down Expand Up @@ -77,6 +78,7 @@ def __init__(
self.images = resources.Images(self.client)
self.models = resources.Models(self.client)
self.fine_tuning = resources.FineTuning(self.client)
self.rerank = resources.Rerank(self.client)


class AsyncTogether:
Expand All @@ -87,6 +89,7 @@ class AsyncTogether:
images: resources.AsyncImages
models: resources.AsyncModels
fine_tuning: resources.AsyncFineTuning
rerank: resources.AsyncRerank

# client options
client: TogetherClient
Expand Down Expand Up @@ -146,6 +149,7 @@ def __init__(
self.images = resources.AsyncImages(self.client)
self.models = resources.AsyncModels(self.client)
self.fine_tuning = resources.AsyncFineTuning(self.client)
self.rerank = resources.AsyncRerank(self.client)


Client = Together
Expand Down
3 changes: 3 additions & 0 deletions src/together/resources/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from together.resources.finetune import AsyncFineTuning, FineTuning
from together.resources.images import AsyncImages, Images
from together.resources.models import AsyncModels, Models
from together.resources.rerank import AsyncRerank, Rerank


__all__ = [
Expand All @@ -22,4 +23,6 @@
"Images",
"AsyncModels",
"Models",
"AsyncRerank",
"Rerank",
]
124 changes: 124 additions & 0 deletions src/together/resources/rerank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
from __future__ import annotations

from typing import List, Dict, Any

from together.abstract import api_requestor
from together.together_response import TogetherResponse
from together.types import (
RerankRequest,
RerankResponse,
TogetherClient,
TogetherRequest,
)


class Rerank:
def __init__(self, client: TogetherClient) -> None:
self._client = client

def create(
self,
*,
model: str,
query: str,
documents: List[str] | List[Dict[str, Any]],
top_n: int | None = None,
return_documents: bool = False,
rank_fields: List[str] | None = None,
) -> RerankResponse:
"""
Method to generate completions based on a given prompt using a specified model.
Args:
model (str): The name of the model to query.
query (str): The input query or list of queries to rerank.
documents (List[str] | List[Dict[str, Any]]): List of documents to be reranked.
top_n (int | None): Number of top results to return.
return_documents (bool): Flag to indicate whether to return documents.
rank_fields (List[str] | None): Fields to be used for ranking the documents.
Returns:
RerankResponse: Object containing reranked scores and documents
"""

requestor = api_requestor.APIRequestor(
client=self._client,
)

parameter_payload = RerankRequest(
model=model,
query=query,
documents=documents,
top_n=top_n,
return_documents=return_documents,
rank_fields=rank_fields,
).model_dump(exclude_none=True)

response, _, _ = requestor.request(
options=TogetherRequest(
method="POST",
url="rerank",
params=parameter_payload,
),
stream=False,
)

assert isinstance(response, TogetherResponse)

return RerankResponse(**response.data)


class AsyncRerank:
def __init__(self, client: TogetherClient) -> None:
self._client = client

async def create(
self,
*,
model: str,
query: str,
documents: List[str] | List[Dict[str, Any]],
top_n: int | None = None,
return_documents: bool = False,
rank_fields: List[str] | None = None,
) -> RerankResponse:
"""
Async method to generate completions based on a given prompt using a specified model.
Args:
model (str): The name of the model to query.
query (str): The input query or list of queries to rerank.
documents (List[str] | List[Dict[str, Any]]): List of documents to be reranked.
top_n (int | None): Number of top results to return.
return_documents (bool): Flag to indicate whether to return documents.
rank_fields (List[str] | None): Fields to be used for ranking the documents.
Returns:
RerankResponse: Object containing reranked scores and documents
"""

requestor = api_requestor.APIRequestor(
client=self._client,
)

parameter_payload = RerankRequest(
model=model,
query=query,
documents=documents,
top_n=top_n,
return_documents=return_documents,
rank_fields=rank_fields,
).model_dump(exclude_none=True)

response, _, _ = await requestor.arequest(
options=TogetherRequest(
method="POST",
url="rerank",
params=parameter_payload,
),
stream=False,
)

assert isinstance(response, TogetherResponse)

return RerankResponse(**response.data)
7 changes: 6 additions & 1 deletion src/together/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@
ImageResponse,
)
from together.types.models import ModelObject

from together.types.rerank import (
RerankRequest,
RerankResponse,
)

__all__ = [
"TogetherClient",
Expand Down Expand Up @@ -66,4 +69,6 @@
"TrainingType",
"FullTrainingType",
"LoRATrainingType",
"RerankRequest",
"RerankResponse",
]
43 changes: 43 additions & 0 deletions src/together/types/rerank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from __future__ import annotations

from typing import List, Literal, Dict, Any

from together.types.abstract import BaseModel
from together.types.common import UsageData


class RerankRequest(BaseModel):
# model to query
model: str
# input or list of inputs
query: str
# list of documents
documents: List[str] | List[Dict[str, Any]]
# return top_n results
top_n: int | None = None
# boolean to return documents
return_documents: bool = False
# field selector for documents
rank_fields: List[str] | None = None


class RerankChoicesData(BaseModel):
# response index
index: int
# object type
relevance_score: float
# rerank response
document: Dict[str, Any] | None = None


class RerankResponse(BaseModel):
# job id
id: str | None = None
# object type
object: Literal["rerank"] | None = None
# query model
model: str | None = None
# list of reranked results
results: List[RerankChoicesData] | None = None
# usage stats
usage: UsageData | None = None
9 changes: 9 additions & 0 deletions tests/unit/test_async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,12 @@ def test_fine_tuning_initialized(self, async_together_instance):
assert async_together_instance.fine_tuning is not None

assert isinstance(async_together_instance.fine_tuning._client, TogetherClient)

def test_rerank_initialized(self, async_together_instance):
"""
Test initializing rerank
"""

assert async_together_instance.rerank is not None

assert isinstance(async_together_instance.rerank._client, TogetherClient)
9 changes: 9 additions & 0 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,12 @@ def test_fine_tuning_initialized(self, sync_together_instance):
assert sync_together_instance.fine_tuning is not None

assert isinstance(sync_together_instance.fine_tuning._client, TogetherClient)

def test_rerank_initialized(self, sync_together_instance):
"""
Test initializing rerank
"""

assert sync_together_instance.rerank is not None

assert isinstance(sync_together_instance.rerank._client, TogetherClient)

0 comments on commit b248178

Please sign in to comment.