diff --git a/src/cohere/client.py b/src/cohere/client.py
index 14ef841f1..40d9cc241 100644
--- a/src/cohere/client.py
+++ b/src/cohere/client.py
@@ -178,7 +178,8 @@ def __exit__(self, exc_type, exc_value, traceback):
     def embed(
         self,
         *,
-        texts: typing.Sequence[str],
+        texts: typing.Optional[typing.Sequence[str]] = OMIT,
+        images: typing.Optional[typing.Sequence[str]] = OMIT,
         model: typing.Optional[str] = OMIT,
         input_type: typing.Optional[EmbedInputType] = OMIT,
         embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT,
@@ -190,6 +191,7 @@ def embed(
             return BaseCohere.embed(
                 self,
                 texts=texts,
+                images=images,
                 model=model,
                 input_type=input_type,
                 embedding_types=embedding_types,
@@ -197,21 +199,28 @@ def embed(
                 request_options=request_options,
             )
 
+        texts = texts or []
         texts_batches = [texts[i : i + embed_batch_size] for i in range(0, len(texts), embed_batch_size)]
+        images = images or []
+        images_batches = [images[i : i + embed_batch_size] for i in range(0, len(images), embed_batch_size)]
+
+        zipped = zip(texts_batches, images_batches)
+
         responses = [
             response
             for response in self._executor.map(
-                lambda text_batch: BaseCohere.embed(
+                lambda batch: BaseCohere.embed(
                     self,
-                    texts=text_batch,
+                    texts=batch[0],
+                    images=batch[1],
                     model=model,
                     input_type=input_type,
                     embedding_types=embedding_types,
                     truncate=truncate,
                     request_options=request_options,
                 ),
-                texts_batches,
+                zipped,
             )
         ]
 
@@ -366,7 +375,8 @@ async def __aexit__(self, exc_type, exc_value, traceback):
     async def embed(
         self,
         *,
-        texts: typing.Sequence[str],
+        texts: typing.Optional[typing.Sequence[str]] = OMIT,
+        images: typing.Optional[typing.Sequence[str]] = OMIT,
         model: typing.Optional[str] = OMIT,
         input_type: typing.Optional[EmbedInputType] = OMIT,
         embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT,
@@ -385,22 +395,30 @@ async def embed(
                 request_options=request_options,
             )
 
+
+        texts = texts or []
         texts_batches = [texts[i : i + embed_batch_size] for i in range(0, len(texts), embed_batch_size)]
+        images = images or []
+        images_batches = [images[i : i + embed_batch_size] for i in range(0, len(images), embed_batch_size)]
+
+        zipped = zip(texts_batches, images_batches)
+
         responses = typing.cast(
             typing.List[EmbedResponse],
             await asyncio.gather(
                 *[
                     AsyncBaseCohere.embed(
                         self,
-                        texts=text_batch,
+                        texts=batch[0],
+                        images=batch[1],
                         model=model,
                         input_type=input_type,
                         embedding_types=embedding_types,
                         truncate=truncate,
                         request_options=request_options,
                     )
-                    for text_batch in texts_batches
+                    for batch in zipped
                 ]
            ),
        )
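One thing to note about the batching above: `zip(texts_batches, images_batches)` stops at the shorter of the two lists, so a texts-only call (where `images` defaults to `OMIT` and `images_batches` is empty) would produce no batches at all, unless an earlier guard returns before this code path is reached. Below is a minimal sketch of a padded pairing that keeps single-modality calls working; the `OMIT` stand-in and the `paired_batches` helper are illustrative and not part of the SDK.

```python
import itertools
import typing

OMIT: typing.Any = object()  # stand-in for the SDK's "parameter not provided" sentinel


def paired_batches(
    texts: typing.Optional[typing.Sequence[str]],
    images: typing.Optional[typing.Sequence[str]],
    batch_size: int,
) -> typing.Iterator[typing.Tuple[typing.Any, typing.Any]]:
    """Yield (text_batch, image_batch) pairs, padding the shorter side with OMIT.

    A plain zip() truncates to the shorter list, so a texts-only call
    (images_batches == []) would yield no pairs and no requests would be sent.
    """
    texts = texts or []
    images = images or []
    texts_batches = [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]
    images_batches = [images[i : i + batch_size] for i in range(0, len(images), batch_size)]
    return itertools.zip_longest(texts_batches, images_batches, fillvalue=OMIT)


# Texts only: both text batches are still produced, the image side stays OMIT.
for text_batch, image_batch in paired_batches(["a", "b", "c"], None, batch_size=2):
    print(text_batch, image_batch is OMIT)
# ['a', 'b'] True
# ['c'] True
```

Wiring this in would only mean replacing the `zipped = zip(...)` line; `texts=batch[0]` / `images=batch[1]` would then sometimes receive `OMIT`, which should be equivalent to leaving the parameter out, assuming the base client treats the sentinel that way.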