Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add POST /tokens/count API support #7

Merged
merged 1 commit into from
Dec 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
* [Взаимная аутентификация по протоколу TLS (mTLS)](./example_auth_certs_mtls.py)
* [Пример вопрос - ответ, для проверки сертификатов используем файл с корневым сертификатом Минцифры России](./example_russian_trusted_root_ca.py)
* [Пример перенаправления заголовков запроса в GigaChat](./example_contextvars.py)
* [Пример подсчёта токенов](./example_tokens.py)
7 changes: 7 additions & 0 deletions examples/example_tokens.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""Пример подсчета токенов в тексте"""
from gigachat import GigaChat

# Используйте токен, полученный в личном кабинете из поля Авторизационные данные
with GigaChat(credentials=..., verify_ssl_certs=False, model='GigaChat-Pro', ) as giga:
result = giga.tokens_count(input=["12345"], model="GigaChat-Pro")
print(result)
91 changes: 91 additions & 0 deletions src/gigachat/api/post_tokens_count.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from http import HTTPStatus
from typing import Any, Dict, Optional, List

import httpx

from gigachat.context import (
authorization_cvar,
client_id_cvar,
operation_id_cvar,
request_id_cvar,
service_id_cvar,
session_id_cvar,
)
from gigachat.exceptions import AuthenticationError, ResponseError
from gigachat.models import TokensCount


def _get_kwargs(
*,
input: List[str],
model: str,
access_token: Optional[str] = None,
) -> Dict[str, Any]:
headers = {}

if access_token:
headers["Authorization"] = f"Bearer {access_token}"

authorization = authorization_cvar.get()
client_id = client_id_cvar.get()
session_id = session_id_cvar.get()
request_id = request_id_cvar.get()
service_id = service_id_cvar.get()
operation_id = operation_id_cvar.get()

if authorization:
headers["Authorization"] = authorization
if client_id:
headers["X-Client-ID"] = client_id
if session_id:
headers["X-Session-ID"] = session_id
if request_id:
headers["X-Request-ID"] = request_id
if service_id:
headers["X-Service-ID"] = service_id
if operation_id:
headers["X-Operation-ID"] = operation_id
json_data = {"model": model, "input": input}

return {
"method": "POST",
"url": "/tokens/count",
"headers": headers,
"json": json_data
}


def _build_response(response: httpx.Response) -> List[TokensCount]:
print(response.json()[0])
if response.status_code == HTTPStatus.OK:
return [TokensCount(**row) for row in response.json()]
elif response.status_code == HTTPStatus.UNAUTHORIZED:
raise AuthenticationError(response.url, response.status_code, response.content, response.headers)
else:
raise ResponseError(response.url, response.status_code, response.content, response.headers)


def sync(
client: httpx.Client,
*,
input: List[str],
model: str,
access_token: Optional[str] = None,
) -> List[TokensCount]:
"""Возвращает объект с информацией о количестве токенов"""
kwargs = _get_kwargs(access_token=access_token, input=input, model=model)
response = client.request(**kwargs)
return _build_response(response)


async def asyncio(
client: httpx.AsyncClient,
*,
input: List[str],
model: str,
access_token: Optional[str] = None,
) -> List[TokensCount]:
"""Возвращает объект с информацией о количестве токенов"""
kwargs = _get_kwargs(access_token=access_token, input=input, model=model)
response = await client.request(**kwargs)
return _build_response(response)
14 changes: 13 additions & 1 deletion src/gigachat/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@
Optional,
TypeVar,
Union,
List
)

import httpx

from gigachat.api import get_model, get_models, post_auth, post_chat, post_token, stream_chat
from gigachat.api import get_model, get_models, post_auth, post_chat, post_token, stream_chat, post_tokens_count
from gigachat.exceptions import AuthenticationError
from gigachat.models import (
AccessToken,
Expand All @@ -26,6 +27,7 @@
Model,
Models,
Token,
TokensCount,
)
from gigachat.settings import Settings

Expand Down Expand Up @@ -188,6 +190,10 @@ def _decorator(self, call: Callable[..., T]) -> T:
self._update_token()
return call()

def tokens_count(self, input: List[str], model: str) -> List[TokensCount]:
"""Возвращает объект с информацией о количестве токенов"""
return self._decorator(lambda: post_tokens_count.sync(self._client, access_token=self.token, input=input, model=model))

def get_models(self) -> Models:
"""Возвращает массив объектов с данными доступных моделей"""
return self._decorator(lambda: get_models.sync(self._client, access_token=self.token))
Expand Down Expand Up @@ -267,6 +273,12 @@ async def _adecorator(self, acall: Callable[..., Awaitable[T]]) -> T:
await self._aupdate_token()
return await acall()

async def atokens_count(self, input: List[str], model: str) -> List[TokensCount]:
async def _acall() -> Model:
return await post_tokens_count.asyncio(self._aclient, access_token=self.token, input=input, model=model)

return await self._adecorator(_acall)

async def aget_models(self) -> Models:
"""Возвращает массив объектов с данными доступных моделей"""

Expand Down
3 changes: 3 additions & 0 deletions src/gigachat/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from gigachat.models.models import Models
from gigachat.models.token import Token
from gigachat.models.usage import Usage
from gigachat.models.tokens_count import TokensCount


__all__ = (
"AccessToken",
Expand All @@ -26,4 +28,5 @@
"Models",
"Token",
"Usage",
"TokensCount"
)
14 changes: 14 additions & 0 deletions src/gigachat/models/tokens_count.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from typing import List

from gigachat.pydantic_v1 import BaseModel, Field


class TokensCount(BaseModel):
"""Информация о количестве токенов"""

tokens: int
"""Количество токенов в соответствующей строке."""
characters: int
"""Количество токенов в соответствующей строке."""
object_: str = Field(alias="object")
"""Тип сущности в ответе, например, список"""
7 changes: 7 additions & 0 deletions tests/data/tokens_count.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[
{
"object": "tokens",
"tokens": 7,
"characters": 36
}
]
27 changes: 26 additions & 1 deletion tests/unit_tests/gigachat/test_client.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import pytest
from pytest_httpx import HTTPXMock
from typing import List

from gigachat.client import GigaChatAsyncClient, GigaChatSyncClient, _get_kwargs
from gigachat.exceptions import AuthenticationError
from gigachat.models import Chat, ChatCompletion, ChatCompletionChunk, Model, Models
from gigachat.models import Chat, ChatCompletion, ChatCompletionChunk, Model, Models, TokensCount
from gigachat.settings import Settings

from ...utils import get_bytes, get_json
Expand All @@ -14,13 +15,15 @@
TOKEN_URL = f"{BASE_URL}/token"
MODELS_URL = f"{BASE_URL}/models"
MODEL_URL = f"{BASE_URL}/models/model"
TOKENS_COUNT_URL = f"{BASE_URL}/tokens/count"

ACCESS_TOKEN = get_json("access_token.json")
TOKEN = get_json("token.json")
CHAT = Chat.parse_obj(get_json("chat.json"))
CHAT_COMPLETION = get_json("chat_completion.json")
CHAT_COMPLETION_STREAM = get_bytes("chat_completion.stream")
MODELS = get_json("models.json")
TOKENS_COUNT = get_json("tokens_count.json")
MODEL = get_json("model.json")

HEADERS_STREAM = {"Content-Type": "text/event-stream"}
Expand All @@ -33,6 +36,16 @@ def test__get_kwargs() -> None:
assert _get_kwargs(settings)


def test_get_tokens_count(httpx_mock: HTTPXMock) -> None:
httpx_mock.add_response(url=TOKENS_COUNT_URL, json=TOKENS_COUNT)

with GigaChatSyncClient(base_url=BASE_URL) as client:
response = client.tokens_count(input=["123"], model="GigaChat:latest")
assert isinstance(response, List)
for row in response:
assert isinstance(row, TokensCount)


def test_get_models(httpx_mock: HTTPXMock) -> None:
httpx_mock.add_response(url=MODELS_URL, json=MODELS)

Expand Down Expand Up @@ -240,6 +253,18 @@ async def test_aget_models(httpx_mock: HTTPXMock) -> None:
assert isinstance(response, Models)


@pytest.mark.asyncio()
async def test_atokens_count(httpx_mock: HTTPXMock) -> None:
httpx_mock.add_response(url=TOKENS_COUNT_URL, json=TOKENS_COUNT)

async with GigaChatAsyncClient(base_url=BASE_URL) as client:
response = await client.atokens_count(input=["text"], model="GigaChat:latest")

assert isinstance(response, List)
for row in response:
assert isinstance(row, TokensCount)


@pytest.mark.asyncio()
async def test_aget_model(httpx_mock: HTTPXMock) -> None:
httpx_mock.add_response(url=MODEL_URL, json=MODEL)
Expand Down