diff --git a/src/together/client.py b/src/together/client.py index ab2e362e..91518230 100644 --- a/src/together/client.py +++ b/src/together/client.py @@ -18,6 +18,7 @@ class Together: images: resources.Images models: resources.Models fine_tuning: resources.FineTuning + rerank: resources.Rerank # client options client: TogetherClient @@ -77,6 +78,7 @@ def __init__( self.images = resources.Images(self.client) self.models = resources.Models(self.client) self.fine_tuning = resources.FineTuning(self.client) + self.rerank = resources.Rerank(self.client) class AsyncTogether: @@ -87,6 +89,7 @@ class AsyncTogether: images: resources.AsyncImages models: resources.AsyncModels fine_tuning: resources.AsyncFineTuning + rerank: resources.AsyncRerank # client options client: TogetherClient @@ -146,6 +149,7 @@ def __init__( self.images = resources.AsyncImages(self.client) self.models = resources.AsyncModels(self.client) self.fine_tuning = resources.AsyncFineTuning(self.client) + self.rerank = resources.AsyncRerank(self.client) Client = Together diff --git a/src/together/resources/__init__.py b/src/together/resources/__init__.py index 01931501..e5e85eac 100644 --- a/src/together/resources/__init__.py +++ b/src/together/resources/__init__.py @@ -5,6 +5,7 @@ from together.resources.finetune import AsyncFineTuning, FineTuning from together.resources.images import AsyncImages, Images from together.resources.models import AsyncModels, Models +from together.resources.rerank import AsyncRerank, Rerank __all__ = [ @@ -22,4 +23,6 @@ "Images", "AsyncModels", "Models", + "AsyncRerank", + "Rerank", ] diff --git a/src/together/resources/rerank.py b/src/together/resources/rerank.py new file mode 100644 index 00000000..4e9aacd4 --- /dev/null +++ b/src/together/resources/rerank.py @@ -0,0 +1,124 @@ +from __future__ import annotations + +from typing import List, Dict, Any + +from together.abstract import api_requestor +from together.together_response import TogetherResponse +from together.types import ( + RerankRequest, + RerankResponse, + TogetherClient, + TogetherRequest, +) + + +class Rerank: + def __init__(self, client: TogetherClient) -> None: + self._client = client + + def create( + self, + *, + model: str, + query: str, + documents: List[str] | List[Dict[str, Any]], + top_n: int | None = None, + return_documents: bool = False, + rank_fields: List[str] | None = None, + ) -> RerankResponse: + """ + Method to generate completions based on a given prompt using a specified model. + + Args: + model (str): The name of the model to query. + query (str): The input query or list of queries to rerank. + documents (List[str] | List[Dict[str, Any]]): List of documents to be reranked. + top_n (int | None): Number of top results to return. + return_documents (bool): Flag to indicate whether to return documents. + rank_fields (List[str] | None): Fields to be used for ranking the documents. + + Returns: + RerankResponse: Object containing reranked scores and documents + """ + + requestor = api_requestor.APIRequestor( + client=self._client, + ) + + parameter_payload = RerankRequest( + model=model, + query=query, + documents=documents, + top_n=top_n, + return_documents=return_documents, + rank_fields=rank_fields, + ).model_dump(exclude_none=True) + + response, _, _ = requestor.request( + options=TogetherRequest( + method="POST", + url="rerank", + params=parameter_payload, + ), + stream=False, + ) + + assert isinstance(response, TogetherResponse) + + return RerankResponse(**response.data) + + +class AsyncRerank: + def __init__(self, client: TogetherClient) -> None: + self._client = client + + async def create( + self, + *, + model: str, + query: str, + documents: List[str] | List[Dict[str, Any]], + top_n: int | None = None, + return_documents: bool = False, + rank_fields: List[str] | None = None, + ) -> RerankResponse: + """ + Async method to generate completions based on a given prompt using a specified model. + + Args: + model (str): The name of the model to query. + query (str): The input query or list of queries to rerank. + documents (List[str] | List[Dict[str, Any]]): List of documents to be reranked. + top_n (int | None): Number of top results to return. + return_documents (bool): Flag to indicate whether to return documents. + rank_fields (List[str] | None): Fields to be used for ranking the documents. + + Returns: + RerankResponse: Object containing reranked scores and documents + """ + + requestor = api_requestor.APIRequestor( + client=self._client, + ) + + parameter_payload = RerankRequest( + model=model, + query=query, + documents=documents, + top_n=top_n, + return_documents=return_documents, + rank_fields=rank_fields, + ).model_dump(exclude_none=True) + + response, _, _ = await requestor.arequest( + options=TogetherRequest( + method="POST", + url="rerank", + params=parameter_payload, + ), + stream=False, + ) + + assert isinstance(response, TogetherResponse) + + return RerankResponse(**response.data) diff --git a/src/together/types/__init__.py b/src/together/types/__init__.py index b99cfb96..baca1789 100644 --- a/src/together/types/__init__.py +++ b/src/together/types/__init__.py @@ -35,7 +35,10 @@ ImageResponse, ) from together.types.models import ModelObject - +from together.types.rerank import ( + RerankRequest, + RerankResponse, +) __all__ = [ "TogetherClient", @@ -66,4 +69,6 @@ "TrainingType", "FullTrainingType", "LoRATrainingType", + "RerankRequest", + "RerankResponse", ] diff --git a/src/together/types/rerank.py b/src/together/types/rerank.py new file mode 100644 index 00000000..f8dc8b71 --- /dev/null +++ b/src/together/types/rerank.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from typing import List, Literal, Dict, Any + +from together.types.abstract import BaseModel +from together.types.common import UsageData + + +class RerankRequest(BaseModel): + # model to query + model: str + # input or list of inputs + query: str + # list of documents + documents: List[str] | List[Dict[str, Any]] + # return top_n results + top_n: int | None = None + # boolean to return documents + return_documents: bool = False + # field selector for documents + rank_fields: List[str] | None = None + + +class RerankChoicesData(BaseModel): + # response index + index: int + # object type + relevance_score: float + # rerank response + document: Dict[str, Any] | None = None + + +class RerankResponse(BaseModel): + # job id + id: str | None = None + # object type + object: Literal["rerank"] | None = None + # query model + model: str | None = None + # list of reranked results + results: List[RerankChoicesData] | None = None + # usage stats + usage: UsageData | None = None diff --git a/tests/unit/test_async_client.py b/tests/unit/test_async_client.py index 18839628..0b11b39d 100644 --- a/tests/unit/test_async_client.py +++ b/tests/unit/test_async_client.py @@ -113,3 +113,12 @@ def test_fine_tuning_initialized(self, async_together_instance): assert async_together_instance.fine_tuning is not None assert isinstance(async_together_instance.fine_tuning._client, TogetherClient) + + def test_rerank_initialized(self, async_together_instance): + """ + Test initializing rerank + """ + + assert async_together_instance.rerank is not None + + assert isinstance(async_together_instance.rerank._client, TogetherClient) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 5855e55b..f8bdcbe6 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -114,3 +114,12 @@ def test_fine_tuning_initialized(self, sync_together_instance): assert sync_together_instance.fine_tuning is not None assert isinstance(sync_together_instance.fine_tuning._client, TogetherClient) + + def test_rerank_initialized(self, sync_together_instance): + """ + Test initializing rerank + """ + + assert sync_together_instance.rerank is not None + + assert isinstance(sync_together_instance.rerank._client, TogetherClient)