Skip to content

Commit

Permalink
add compression parameter to embed
Browse files Browse the repository at this point in the history
  • Loading branch information
alekhya-n committed Oct 22, 2023
1 parent d6c9d60 commit 9052ec6
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 12 deletions.
9 changes: 3 additions & 6 deletions cohere/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,8 +393,7 @@ def embed(
texts: List[str],
model: Optional[str] = None,
truncate: Optional[str] = None,
- compress: Optional[bool] = False,
- compression_codebook: Optional[str] = "default",
+ compression: Optional[str] = None,
input_type: Optional[str] = None,
) -> Embeddings:
"""Returns an Embeddings object for the provided texts. Visit https://cohere.ai/embed to learn about embeddings.
Expand All @@ -403,8 +402,7 @@ def embed(
texts (List[str]): A list of strings to embed.
model (str): (Optional) The model ID to use for embedding the text.
truncate (str): (Optional) One of NONE|START|END, defaults to END. How the API handles text longer than the maximum token length.
- compress (bool): (Optional) Whether to compress the embeddings. When True, the compressed_embeddings will be returned as integers in the range [0, 255].
- compression_codebook (str): (Optional) The compression codebook to use for compressed embeddings. Defaults to "default".
+ compression (str): (Optional) One of "int8" or "binary". The type of compression to use for the embeddings.
input_type (str): (Optional) One of "classification", "clustering", "search_document", "search_query". The type of input text provided to embed.
"""
responses = {
Expand All @@ -420,8 +418,7 @@ def embed(
"model": model,
"texts": texts_batch,
"truncate": truncate,
- "compress": compress,
- "compression_codebook": compression_codebook,
+ "compression": compression,
"input_type": input_type,
}
)
Expand Down
9 changes: 3 additions & 6 deletions cohere/client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,7 @@ async def embed(
texts: List[str],
model: Optional[str] = None,
truncate: Optional[str] = None,
- compress: Optional[bool] = False,
- compression_codebook: Optional[str] = "default",
+ compression: Optional[str] = None,
input_type: Optional[str] = None,
) -> Embeddings:
"""Returns an Embeddings object for the provided texts. Visit https://cohere.ai/embed to learn about embeddings.
Expand All @@ -281,17 +280,15 @@ async def embed(
texts (List[str]): A list of strings to embed.
model (str): (Optional) The model ID to use for embedding the text.
truncate (str): (Optional) One of NONE|START|END, defaults to END. How the API handles text longer than the maximum token length.
- compress (bool): (Optional) Whether to compress the embeddings. When True, the compressed_embeddings will be returned as integers in the range [0, 255].
- compression_codebook (str): (Optional) The compression codebook to use for compressed embeddings. Defaults to "default".
+ compression (str): (Optional) One of "int8" or "binary". The type of compression to use for the embeddings.
input_type (str): (Optional) One of "classification", "clustering", "search_document", "search_query". The type of input text provided to embed.
"""
json_bodys = [
dict(
texts=texts[i : i + cohere.COHERE_EMBED_BATCH_SIZE],
model=model,
truncate=truncate,
- compress=compress,
- compression_codebook=compression_codebook,
+ compression=compression,
input_type=input_type,
)
for i in range(0, len(texts), cohere.COHERE_EMBED_BATCH_SIZE)
Expand Down

0 comments on commit 9052ec6

Please sign in to comment.