Skip to content

Commit

Permalink
Change text -> texts for embeddings
Browse files Browse the repository at this point in the history
  • Loading branch information
Konstantin Krestnikov authored and Konstantin Krestnikov committed Dec 27, 2023
1 parent 3e4c73e commit ab2ba2d
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 12 deletions.
2 changes: 1 addition & 1 deletion examples/example_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@
credentials=...,
verify_ssl_certs=False,
) as giga:
response = giga.embeddings("Hello world!")
response = giga.embeddings(["Hello world!"])
print(response)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "gigachat"
version = "0.1.12"
version = "0.1.12.1"
description = "GigaChat. Python-library for GigaChain and LangChain"
authors = ["Konstantin Krestnikov <[email protected]>", "Sergey Malyshev <[email protected]>"]
license = "MIT"
Expand Down
8 changes: 4 additions & 4 deletions src/gigachat/api/post_embeddings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from http import HTTPStatus
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional

import httpx

Expand All @@ -16,7 +16,7 @@

def _get_kwargs(
*,
input_: str,
input_: List[str],
model: str,
access_token: Optional[str] = None,
) -> Dict[str, Any]:
Expand Down Expand Up @@ -62,7 +62,7 @@ def _build_response(response: httpx.Response) -> Embeddings:
def sync(
client: httpx.Client,
*,
input_: str,
input_: List[str],
model: str,
access_token: Optional[str] = None,
) -> Embeddings:
Expand All @@ -74,7 +74,7 @@ def sync(
async def asyncio(
client: httpx.AsyncClient,
*,
input_: str,
input_: List[str],
model: str,
access_token: Optional[str] = None,
) -> Embeddings:
Expand Down
8 changes: 4 additions & 4 deletions src/gigachat/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,10 +206,10 @@ def tokens_count(self, input_: List[str], model: Optional[str] = None) -> List[T
lambda: post_tokens_count.sync(self._client, input_=input_, model=model, access_token=self.token)
)

def embeddings(self, text: str, model: str = "Embeddings") -> Embeddings:
def embeddings(self, texts: List[str], model: str = "Embeddings") -> Embeddings:
"""Возвращает эмбеддинги"""
return self._decorator(
lambda: post_embeddings.sync(self._client, access_token=self.token, input_=text, model=model)
lambda: post_embeddings.sync(self._client, access_token=self.token, input_=texts, model=model)
)

def get_models(self) -> Models:
Expand Down Expand Up @@ -301,11 +301,11 @@ async def _acall() -> List[TokensCount]:

return await self._adecorator(_acall)

async def aembeddings(self, text: str, model: str = "Embeddings") -> Embeddings:
async def aembeddings(self, texts: List[str], model: str = "Embeddings") -> Embeddings:
"""Возвращает эмбеддинги"""

async def _acall() -> Embeddings:
return await post_embeddings.asyncio(self._aclient, access_token=self.token, input_=text, model=model)
return await post_embeddings.asyncio(self._aclient, access_token=self.token, input_=texts, model=model)

return await self._adecorator(_acall)

Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/gigachat/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def test_embeddings(httpx_mock: HTTPXMock) -> None:
httpx_mock.add_response(url=EMBEDDINGS_URL, json=EMBEDDINGS)

with GigaChatSyncClient(base_url=BASE_URL) as client:
response = client.embeddings(text="text", model="model")
response = client.embeddings(texts=["text"], model="model")
assert isinstance(response, Embeddings)
for row in response.data:
assert isinstance(row, Embedding)
Expand Down Expand Up @@ -456,7 +456,7 @@ async def test_aembeddings(httpx_mock: HTTPXMock) -> None:
httpx_mock.add_response(url=EMBEDDINGS_URL, json=EMBEDDINGS)

async with GigaChatAsyncClient(base_url=BASE_URL) as client:
response = await client.aembeddings(text="text", model="model")
response = await client.aembeddings(texts=["text"], model="model")
assert isinstance(response, Embeddings)
for row in response.data:
assert isinstance(row, Embedding)
Expand Down

0 comments on commit ab2ba2d

Please sign in to comment.