From 8b06e586d978fcaaeb7aee67d903265615fc2bda Mon Sep 17 00:00:00 2001 From: Sergey Malyshev Date: Wed, 18 Oct 2023 11:25:08 +0300 Subject: [PATCH] add payload str, mtls, tests --- Makefile | 2 +- README.md | 46 +++---- examples/README.md | 4 +- examples/example_ask.py | 7 ++ examples/example_auth_certs_mtls.py | 13 ++ pyproject.toml | 2 +- src/gigachat/client.py | 57 +++++++-- src/gigachat/settings.py | 21 ++-- tests/unit_tests/conftest.py | 2 +- .../unit_tests/gigachat/api/test_get_model.py | 14 +-- .../gigachat/api/test_get_models.py | 13 +- .../unit_tests/gigachat/api/test_post_auth.py | 13 +- .../unit_tests/gigachat/api/test_post_chat.py | 13 +- .../gigachat/api/test_post_token.py | 13 +- .../gigachat/api/test_stream_chat.py | 17 +-- tests/unit_tests/gigachat/test_client.py | 112 +++++++++++------- tests/unit_tests/gigachat/test_settings.py | 6 +- 17 files changed, 203 insertions(+), 152 deletions(-) create mode 100644 examples/example_ask.py create mode 100644 examples/example_auth_certs_mtls.py diff --git a/Makefile b/Makefile index 99599c0..11eac1e 100644 --- a/Makefile +++ b/Makefile @@ -47,7 +47,7 @@ lint: .PHONY: mypy ## Perform type-checking mypy: - poetry run mypy src + poetry run mypy src tests .PHONY: test ## Run tests and generate a coverage report test: diff --git a/README.md b/README.md index 63fc2e2..ad10df2 100644 --- a/README.md +++ b/README.md @@ -30,40 +30,22 @@ pip install gigachat Пример показывает как импортировать библиотеку в [GigaChain](https://github.com/ai-forever/gigachain) и использовать ее для обращения к GigaChat: ```py -"""Пример работы с чатом""" from gigachat import GigaChat -from gigachat.models import Chat, Messages, MessagesRole - - -payload = Chat( - messages=[ - Messages( - role=MessagesRole.SYSTEM, - content="Ты внимательный бот-психолог, который помогает пользователю решить его проблемы." - ) - ], - temperature=0.7, - max_tokens=100, -) -# Используйте токен, полученный в личном кабинете в поле Авторизационные данные +# Используйте токен, полученный в личном кабинете из поля Авторизационные данные with GigaChat(credentials=..., verify_ssl_certs=False) as giga: - while True: - user_input = input("User: ") - payload.messages.append(Messages(role=MessagesRole.USER, content=user_input)) - response = giga.chat(payload) - payload.messages.append(response.choices[0].message) - print("Bot: ", response.choices[0].message.content) + response = giga.chat("Какие факторы влияют на стоимость страховки на дом?") + print(response.choices[0].message.content) ``` [Больше примеров](./examples/README.md). -## Способы авторизации +## Авторизация запросов -Авторизация с помощью [токена доступа GigaChat API](https://developers.sber.ru/docs/ru/gigachat/api/authorization): +Авторизация с помощью токена (в личном кабинете из поля Авторизационные данные): ```py -giga = GigaChat(access_token=...) +giga = GigaChat(credentials=...) ``` Авторизация с помощью логина и пароля: @@ -76,16 +58,19 @@ giga = GigaChat( ) ``` -## Дополнительные настройки - -Отключение авторизации: +Взаимная аутентификация по протоколу TLS (mTLS): ```py -giga = GigaChat(use_auth=False) +giga = GigaChat( + base_url="https://gigachat.devices.sberbank.ru/api/v1", + ca_bundle_file="certs/ca.pem", + cert_file="certs/tls.pem", + key_file="certs/tls.key", + key_file_password="123456", +) ``` -> [!NOTE] -> Функция может быть полезна, например, при авторизации с помощью Istio service mesh. +## Дополнительные настройки Отключение проверки сертификатов: @@ -96,7 +81,6 @@ giga = GigaChat(verify_ssl_certs=False) > [!WARNING] > Отключение проверки сертификатов снижает безопасность обмена данными. - ### Настройки в переменных окружения Чтобы задать настройки с помощью переменных окружения, используйте префикс `GIGACHAT_`. diff --git a/examples/README.md b/examples/README.md index f437f3d..868b6c4 100644 --- a/examples/README.md +++ b/examples/README.md @@ -2,5 +2,7 @@ Здесь вы найдете примеры работы с сервисом GigaChat с помощью библиотеки: +* [Пример вопрос - ответ](./example_ask.py) * [Работа с чатом](./simple_chat.py) -* [Ассинхронная работа с потоковой обработкой токенов](./streaming_asyncio.py) \ No newline at end of file +* [Асинхронная работа с потоковой обработкой токенов](./streaming_asyncio.py) +* [Взаимная аутентификация по протоколу TLS (mTLS)](./example_auth_certs_mtls.py) diff --git a/examples/example_ask.py b/examples/example_ask.py new file mode 100644 index 0000000..f2e6916 --- /dev/null +++ b/examples/example_ask.py @@ -0,0 +1,7 @@ +"""Пример вопрос - ответ""" +from gigachat import GigaChat + +# Используйте токен, полученный в личном кабинете из поля Авторизационные данные +with GigaChat(credentials=..., verify_ssl_certs=False) as giga: + response = giga.chat("Какие факторы влияют на стоимость страховки на дом?") + print(response.choices[0].message.content) diff --git a/examples/example_auth_certs_mtls.py b/examples/example_auth_certs_mtls.py new file mode 100644 index 0000000..4cc504c --- /dev/null +++ b/examples/example_auth_certs_mtls.py @@ -0,0 +1,13 @@ +"""Взаимная аутентификация по протоколу TLS (mTLS)""" + +from gigachat import GigaChat + +with GigaChat( + base_url="https://gigachat.devices.sberbank.ru/api/v1", + ca_bundle_file="certs/ca.pem", + cert_file="certs/tls.pem", + key_file="certs/tls.key", + key_file_password="123456", +) as giga: + response = giga.chat("Какие факторы влияют на стоимость страховки на дом?") + print(response.choices[0].message.content) diff --git a/pyproject.toml b/pyproject.toml index 7fd779d..c45c0a0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "gigachat" -version = "0.1.3" +version = "0.1.4" description = "GigaChat Python Library" authors = ["Konstantin Krestnikov "] license = "MIT" diff --git a/src/gigachat/client.py b/src/gigachat/client.py index 9c05920..30f0cb4 100644 --- a/src/gigachat/client.py +++ b/src/gigachat/client.py @@ -21,6 +21,8 @@ Chat, ChatCompletion, ChatCompletionChunk, + Messages, + MessagesRole, Model, Models, Token, @@ -34,11 +36,16 @@ def _get_kwargs(settings: Settings) -> Dict[str, Any]: """Настройки для подключения к API GIGACHAT""" - return { + kwargs = { "base_url": settings.base_url, "verify": settings.verify_ssl_certs, "timeout": httpx.Timeout(settings.timeout), } + if settings.ca_bundle_file: + kwargs["verify"] = settings.ca_bundle_file + if settings.cert_file: + kwargs["cert"] = (settings.cert_file, settings.key_file, settings.key_file_password) + return kwargs def _get_auth_kwargs(settings: Settings) -> Dict[str, Any]: @@ -49,11 +56,14 @@ def _get_auth_kwargs(settings: Settings) -> Dict[str, Any]: } -def _parse_chat(chat: Union[Chat, Dict[str, Any]], model: Optional[str]) -> Chat: - payload = Chat.parse_obj(chat) +def _parse_chat(payload: Union[Chat, Dict[str, Any], str], model: Optional[str]) -> Chat: + if isinstance(payload, str): + chat = Chat(messages=[Messages(role=MessagesRole.USER, content=payload)]) + else: + chat = Chat.parse_obj(payload) if model: - payload.model = model - return payload + chat.model = model + return chat def _build_access_token(token: Token) -> AccessToken: @@ -76,24 +86,45 @@ def __init__( password: Optional[str] = None, timeout: Optional[float] = None, verify_ssl_certs: Optional[bool] = None, - use_auth: Optional[bool] = None, verbose: Optional[bool] = None, + ca_bundle_file: Optional[str] = None, + cert_file: Optional[str] = None, + key_file: Optional[str] = None, + key_file_password: Optional[str] = None, + **_kwargs: Any, ) -> None: - config = {k: v for k, v in locals().items() if k != "self" and v is not None} + kwargs: Dict[str, Any] = { + "base_url": base_url, + "auth_url": auth_url, + "credentials": credentials, + "scope": scope, + "access_token": access_token, + "model": model, + "user": user, + "password": password, + "timeout": timeout, + "verify_ssl_certs": verify_ssl_certs, + "verbose": verbose, + "ca_bundle_file": ca_bundle_file, + "cert_file": cert_file, + "key_file": key_file, + "key_file_password": key_file_password, + } + config = {k: v for k, v in kwargs.items() if v is not None} self._settings = Settings(**config) if self._settings.access_token: self._access_token = AccessToken(access_token=self._settings.access_token, expires_at=0) @property def token(self) -> Optional[str]: - if self._settings.use_auth and self._access_token: + if self._access_token: return self._access_token.access_token else: return None @property def _use_auth(self) -> bool: - return self._settings.use_auth + return bool(self._settings.credentials or (self._settings.user and self._settings.password)) def _check_validity_token(self) -> bool: """Проверить время завершения действия токена""" @@ -162,12 +193,12 @@ def get_model(self, model: str) -> Model: """Возвращает объект с описанием указанной модели""" return self._decorator(lambda: get_model.sync(self._client, model=model, access_token=self.token)) - def chat(self, payload: Union[Chat, Dict[str, Any]]) -> ChatCompletion: + def chat(self, payload: Union[Chat, Dict[str, Any], str]) -> ChatCompletion: """Возвращает ответ модели с учетом переданных сообщений""" chat = _parse_chat(payload, model=self._settings.model) return self._decorator(lambda: post_chat.sync(self._client, chat=chat, access_token=self.token)) - def stream(self, payload: Union[Chat, Dict[str, Any]]) -> Iterator[ChatCompletionChunk]: + def stream(self, payload: Union[Chat, Dict[str, Any], str]) -> Iterator[ChatCompletionChunk]: """Возвращает ответ модели с учетом переданных сообщений""" chat = _parse_chat(payload, model=self._settings.model) @@ -249,7 +280,7 @@ async def _acall() -> Model: return await self._adecorator(_acall) - async def achat(self, payload: Union[Chat, Dict[str, Any]]) -> ChatCompletion: + async def achat(self, payload: Union[Chat, Dict[str, Any], str]) -> ChatCompletion: """Возвращает ответ модели с учетом переданных сообщений""" chat = _parse_chat(payload, model=self._settings.model) @@ -258,7 +289,7 @@ async def _acall() -> ChatCompletion: return await self._adecorator(_acall) - async def astream(self, payload: Union[Chat, Dict[str, Any]]) -> AsyncIterator[ChatCompletionChunk]: + async def astream(self, payload: Union[Chat, Dict[str, Any], str]) -> AsyncIterator[ChatCompletionChunk]: """Возвращает ответ модели с учетом переданных сообщений""" chat = _parse_chat(payload, model=self._settings.model) diff --git a/src/gigachat/settings.py b/src/gigachat/settings.py index 6d29f12..f999ed5 100644 --- a/src/gigachat/settings.py +++ b/src/gigachat/settings.py @@ -1,11 +1,9 @@ -import logging -from typing import Any, Dict, Optional +from typing import Optional -from gigachat.pydantic_v1 import BaseSettings, root_validator +from gigachat.pydantic_v1 import BaseSettings ENV_PREFIX = "GIGACHAT_" -# BASE_URL = "https://beta.saluteai.sberdevices.ru/v1" BASE_URL = "https://gigachat.devices.sberbank.ru/api/v1" AUTH_URL = "https://ngw.devices.sberbank.ru:9443/api/v2/oauth" SCOPE = "GIGACHAT_API_CORP" @@ -26,17 +24,12 @@ class Settings(BaseSettings): timeout: float = 30.0 verify_ssl_certs: bool = True - use_auth: bool = True verbose: bool = False + ca_bundle_file: Optional[str] = None + cert_file: Optional[str] = None + key_file: Optional[str] = None + key_file_password: Optional[str] = None + class Config: env_prefix = ENV_PREFIX - - @root_validator - def check_credentials(cls, values: Dict[str, Any]) -> Dict[str, Any]: - if values["use_auth"]: - use_secrets = values["credentials"] or values["access_token"] or (values["user"] and values["password"]) - if not use_secrets: - logging.warning("Please provide GIGACHAT_CREDENTIALS environment variables.") - - return values diff --git a/tests/unit_tests/conftest.py b/tests/unit_tests/conftest.py index ccb04ab..333c421 100644 --- a/tests/unit_tests/conftest.py +++ b/tests/unit_tests/conftest.py @@ -4,5 +4,5 @@ @pytest.fixture(autouse=True) -def _delenv(monkeypatch) -> None: +def _delenv(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(os, "environ", {}) diff --git a/tests/unit_tests/gigachat/api/test_get_model.py b/tests/unit_tests/gigachat/api/test_get_model.py index a02ec52..86546cd 100644 --- a/tests/unit_tests/gigachat/api/test_get_model.py +++ b/tests/unit_tests/gigachat/api/test_get_model.py @@ -1,6 +1,6 @@ import httpx import pytest -import pytest_httpx +from pytest_httpx import HTTPXMock from gigachat.api import get_model from gigachat.exceptions import AuthenticationError, ResponseError @@ -14,7 +14,7 @@ MODEL = get_json("model.json") -def test_sync(httpx_mock: pytest_httpx.HTTPXMock): +def test_sync(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=MODEL_URL, json=MODEL) with httpx.Client(base_url=BASE_URL) as client: @@ -23,7 +23,7 @@ def test_sync(httpx_mock: pytest_httpx.HTTPXMock): assert isinstance(response, Model) -def test_sync_value_error(httpx_mock): +def test_sync_value_error(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=MODEL_URL, json={}) with httpx.Client(base_url=BASE_URL) as client: @@ -31,7 +31,7 @@ def test_sync_value_error(httpx_mock): get_model.sync(client, model="model") -def test_sync_authentication_error(httpx_mock): +def test_sync_authentication_error(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=MODEL_URL, status_code=401) with httpx.Client(base_url=BASE_URL) as client: @@ -39,7 +39,7 @@ def test_sync_authentication_error(httpx_mock): get_model.sync(client, model="model") -def test_sync_response_error(httpx_mock): +def test_sync_response_error(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=MODEL_URL, status_code=400) with httpx.Client(base_url=BASE_URL) as client: @@ -47,7 +47,7 @@ def test_sync_response_error(httpx_mock): get_model.sync(client, model="model") -def test_sync_headers(httpx_mock: pytest_httpx.HTTPXMock): +def test_sync_headers(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=MODEL_URL, json=MODEL) with httpx.Client(base_url=BASE_URL) as client: @@ -64,7 +64,7 @@ def test_sync_headers(httpx_mock: pytest_httpx.HTTPXMock): @pytest.mark.asyncio() -async def test_asyncio(httpx_mock): +async def test_asyncio(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=MODEL_URL, json=MODEL) async with httpx.AsyncClient(base_url=BASE_URL) as client: diff --git a/tests/unit_tests/gigachat/api/test_get_models.py b/tests/unit_tests/gigachat/api/test_get_models.py index 4e66da3..79b3a89 100644 --- a/tests/unit_tests/gigachat/api/test_get_models.py +++ b/tests/unit_tests/gigachat/api/test_get_models.py @@ -1,5 +1,6 @@ import httpx import pytest +from pytest_httpx import HTTPXMock from gigachat.api import get_models from gigachat.exceptions import AuthenticationError, ResponseError @@ -13,7 +14,7 @@ MODELS = get_json("models.json") -def test_sync(httpx_mock): +def test_sync(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=MODELS_URL, json=MODELS) with httpx.Client(base_url=BASE_URL) as client: @@ -22,7 +23,7 @@ def test_sync(httpx_mock): assert isinstance(response, Models) -def test_sync_value_error(httpx_mock): +def test_sync_value_error(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=MODELS_URL, json={}) with httpx.Client(base_url=BASE_URL) as client: @@ -30,7 +31,7 @@ def test_sync_value_error(httpx_mock): get_models.sync(client) -def test_sync_authentication_error(httpx_mock): +def test_sync_authentication_error(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=MODELS_URL, status_code=401) with httpx.Client(base_url=BASE_URL) as client: @@ -38,7 +39,7 @@ def test_sync_authentication_error(httpx_mock): get_models.sync(client) -def test_sync_response_error(httpx_mock): +def test_sync_response_error(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=MODELS_URL, status_code=400) with httpx.Client(base_url=BASE_URL) as client: @@ -46,7 +47,7 @@ def test_sync_response_error(httpx_mock): get_models.sync(client) -def test_sync_headers(httpx_mock): +def test_sync_headers(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=MODELS_URL, json=MODELS) with httpx.Client(base_url=BASE_URL) as client: @@ -58,7 +59,7 @@ def test_sync_headers(httpx_mock): @pytest.mark.asyncio() -async def test_asyncio(httpx_mock): +async def test_asyncio(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=MODELS_URL, json=MODELS) async with httpx.AsyncClient(base_url=BASE_URL) as client: diff --git a/tests/unit_tests/gigachat/api/test_post_auth.py b/tests/unit_tests/gigachat/api/test_post_auth.py index f76dd03..dd739e6 100644 --- a/tests/unit_tests/gigachat/api/test_post_auth.py +++ b/tests/unit_tests/gigachat/api/test_post_auth.py @@ -1,5 +1,6 @@ import httpx import pytest +from pytest_httpx import HTTPXMock from gigachat.api import post_auth from gigachat.exceptions import AuthenticationError, ResponseError @@ -12,7 +13,7 @@ ACCESS_TOKEN = get_json("access_token.json") -def test_sync(httpx_mock): +def test_sync(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=MOCK_URL, json=ACCESS_TOKEN) with httpx.Client() as client: @@ -21,7 +22,7 @@ def test_sync(httpx_mock): assert isinstance(response, AccessToken) -def test_sync_value_error(httpx_mock): +def test_sync_value_error(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=MOCK_URL, json={}) with httpx.Client() as client: @@ -29,7 +30,7 @@ def test_sync_value_error(httpx_mock): post_auth.sync(client, url=MOCK_URL, credentials="credentials", scope="scope") -def test_sync_authentication_error(httpx_mock): +def test_sync_authentication_error(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=MOCK_URL, status_code=401) with httpx.Client() as client: @@ -37,7 +38,7 @@ def test_sync_authentication_error(httpx_mock): post_auth.sync(client, url=MOCK_URL, credentials="credentials", scope="scope") -def test_sync_response_error(httpx_mock): +def test_sync_response_error(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=MOCK_URL, status_code=400) with httpx.Client() as client: @@ -45,7 +46,7 @@ def test_sync_response_error(httpx_mock): post_auth.sync(client, url=MOCK_URL, credentials="credentials", scope="scope") -def test_sync_headers(httpx_mock): +def test_sync_headers(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=MOCK_URL, json=ACCESS_TOKEN) with httpx.Client() as client: @@ -57,7 +58,7 @@ def test_sync_headers(httpx_mock): @pytest.mark.asyncio() -async def test_asyncio(httpx_mock): +async def test_asyncio(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=MOCK_URL, json=ACCESS_TOKEN) async with httpx.AsyncClient() as client: diff --git a/tests/unit_tests/gigachat/api/test_post_chat.py b/tests/unit_tests/gigachat/api/test_post_chat.py index e1424d8..c237942 100644 --- a/tests/unit_tests/gigachat/api/test_post_chat.py +++ b/tests/unit_tests/gigachat/api/test_post_chat.py @@ -1,5 +1,6 @@ import httpx import pytest +from pytest_httpx import HTTPXMock from gigachat.api import post_chat from gigachat.exceptions import AuthenticationError, ResponseError @@ -14,7 +15,7 @@ CHAT_COMPLETION = get_json("chat_completion.json") -def test_sync(httpx_mock): +def test_sync(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=MOCK_URL, json=CHAT_COMPLETION) with httpx.Client(base_url=BASE_URL) as client: @@ -23,7 +24,7 @@ def test_sync(httpx_mock): assert isinstance(response, ChatCompletion) -def test_sync_value_error(httpx_mock): +def test_sync_value_error(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=MOCK_URL, json={}) with httpx.Client(base_url=BASE_URL) as client: @@ -31,7 +32,7 @@ def test_sync_value_error(httpx_mock): post_chat.sync(client, chat=CHAT) -def test_sync_authentication_error(httpx_mock): +def test_sync_authentication_error(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=MOCK_URL, status_code=401) with httpx.Client(base_url=BASE_URL) as client: @@ -39,7 +40,7 @@ def test_sync_authentication_error(httpx_mock): post_chat.sync(client, chat=CHAT) -def test_sync_response_error(httpx_mock): +def test_sync_response_error(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=MOCK_URL, status_code=400) with httpx.Client(base_url=BASE_URL) as client: @@ -47,7 +48,7 @@ def test_sync_response_error(httpx_mock): post_chat.sync(client, chat=CHAT) -def test_sync_headers(httpx_mock): +def test_sync_headers(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=MOCK_URL, json=CHAT_COMPLETION) with httpx.Client(base_url=BASE_URL) as client: @@ -64,7 +65,7 @@ def test_sync_headers(httpx_mock): @pytest.mark.asyncio() -async def test_asyncio(httpx_mock): +async def test_asyncio(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=MOCK_URL, json=CHAT_COMPLETION) async with httpx.AsyncClient(base_url=BASE_URL) as client: diff --git a/tests/unit_tests/gigachat/api/test_post_token.py b/tests/unit_tests/gigachat/api/test_post_token.py index f9b91e6..e598437 100644 --- a/tests/unit_tests/gigachat/api/test_post_token.py +++ b/tests/unit_tests/gigachat/api/test_post_token.py @@ -1,5 +1,6 @@ import httpx import pytest +from pytest_httpx import HTTPXMock from gigachat.api import post_token from gigachat.exceptions import AuthenticationError, ResponseError @@ -13,7 +14,7 @@ TOKEN = get_json("token.json") -def test_sync(httpx_mock): +def test_sync(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=MOCK_URL, json=TOKEN) with httpx.Client(base_url=BASE_URL) as client: @@ -22,7 +23,7 @@ def test_sync(httpx_mock): assert isinstance(response, Token) -def test_sync_value_error(httpx_mock): +def test_sync_value_error(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=MOCK_URL, json={}) with httpx.Client(base_url=BASE_URL) as client: @@ -30,7 +31,7 @@ def test_sync_value_error(httpx_mock): post_token.sync(client, user="user", password="password") -def test_sync_authentication_error(httpx_mock): +def test_sync_authentication_error(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=MOCK_URL, status_code=401) with httpx.Client(base_url=BASE_URL) as client: @@ -38,7 +39,7 @@ def test_sync_authentication_error(httpx_mock): post_token.sync(client, user="user", password="password") -def test_sync_response_error(httpx_mock): +def test_sync_response_error(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=MOCK_URL, status_code=400) with httpx.Client(base_url=BASE_URL) as client: @@ -46,7 +47,7 @@ def test_sync_response_error(httpx_mock): post_token.sync(client, user="user", password="password") -def test_sync_headers(httpx_mock): +def test_sync_headers(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=MOCK_URL, json=TOKEN) with httpx.Client(base_url=BASE_URL) as client: @@ -63,7 +64,7 @@ def test_sync_headers(httpx_mock): @pytest.mark.asyncio() -async def test_asyncio(httpx_mock): +async def test_asyncio(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=MOCK_URL, json=TOKEN) async with httpx.AsyncClient(base_url=BASE_URL) as client: diff --git a/tests/unit_tests/gigachat/api/test_stream_chat.py b/tests/unit_tests/gigachat/api/test_stream_chat.py index a190139..29ddf26 100644 --- a/tests/unit_tests/gigachat/api/test_stream_chat.py +++ b/tests/unit_tests/gigachat/api/test_stream_chat.py @@ -1,5 +1,6 @@ import httpx import pytest +from pytest_httpx import HTTPXMock from gigachat.api import stream_chat from gigachat.exceptions import AuthenticationError, ResponseError @@ -15,7 +16,7 @@ HEADERS_STREAM = {"Content-Type": "text/event-stream"} -def test_sync(httpx_mock): +def test_sync(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=MOCK_URL, content=CHAT_COMPLETION_STREAM, headers=HEADERS_STREAM) with httpx.Client(base_url=BASE_URL) as client: @@ -26,15 +27,15 @@ def test_sync(httpx_mock): assert response[2].choices[0].finish_reason == "stop" -def test_sync_content_type_error(httpx_mock): - httpx_mock.add_response(url=MOCK_URL, content=CHAT_COMPLETION_STREAM, headers={}) +def test_sync_content_type_error(httpx_mock: HTTPXMock) -> None: + httpx_mock.add_response(url=MOCK_URL, content=CHAT_COMPLETION_STREAM) with httpx.Client(base_url=BASE_URL) as client: with pytest.raises(httpx.TransportError): list(stream_chat.sync(client, chat=CHAT)) -def test_sync_value_error(httpx_mock): +def test_sync_value_error(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=MOCK_URL, content=b"data: {}", headers=HEADERS_STREAM) with httpx.Client(base_url=BASE_URL) as client: @@ -42,7 +43,7 @@ def test_sync_value_error(httpx_mock): list(stream_chat.sync(client, chat=CHAT)) -def test_sync_authentication_error(httpx_mock): +def test_sync_authentication_error(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=MOCK_URL, status_code=401) with httpx.Client(base_url=BASE_URL) as client: @@ -50,7 +51,7 @@ def test_sync_authentication_error(httpx_mock): list(stream_chat.sync(client, chat=CHAT)) -def test_sync_response_error(httpx_mock): +def test_sync_response_error(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=MOCK_URL, status_code=400) with httpx.Client(base_url=BASE_URL) as client: @@ -58,7 +59,7 @@ def test_sync_response_error(httpx_mock): list(stream_chat.sync(client, chat=CHAT)) -def test_sync_headers(httpx_mock): +def test_sync_headers(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=MOCK_URL, content=CHAT_COMPLETION_STREAM, headers=HEADERS_STREAM) with httpx.Client(base_url=BASE_URL) as client: @@ -79,7 +80,7 @@ def test_sync_headers(httpx_mock): @pytest.mark.asyncio() -async def test_asyncio(httpx_mock): +async def test_asyncio(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=MOCK_URL, content=CHAT_COMPLETION_STREAM, headers=HEADERS_STREAM) async with httpx.AsyncClient(base_url=BASE_URL) as client: diff --git a/tests/unit_tests/gigachat/test_client.py b/tests/unit_tests/gigachat/test_client.py index ab4df50..8b7f19f 100644 --- a/tests/unit_tests/gigachat/test_client.py +++ b/tests/unit_tests/gigachat/test_client.py @@ -1,8 +1,10 @@ import pytest +from pytest_httpx import HTTPXMock -from gigachat.client import GigaChatAsyncClient, GigaChatSyncClient +from gigachat.client import GigaChatAsyncClient, GigaChatSyncClient, _get_kwargs from gigachat.exceptions import AuthenticationError from gigachat.models import Chat, ChatCompletion, ChatCompletionChunk, Model, Models +from gigachat.settings import Settings from ...utils import get_bytes, get_json @@ -26,34 +28,39 @@ CREDENTIALS = "NmIwNzhlODgtNDlkNC00ZjFmLTljMjMtYjFiZTZjMjVmNTRlOmU3NWJlNjVhLTk4YjAtNGY0Ni1iOWVhLTljMDkwZGE4YTk4MQ==" -def test_get_models(httpx_mock): +def test__get_kwargs() -> None: + settings = Settings(ca_bundle_file="ca.pem", cert_file="tls.pem", key_file="tls.key") + assert _get_kwargs(settings) + + +def test_get_models(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=MODELS_URL, json=MODELS) - with GigaChatSyncClient(base_url=BASE_URL, use_auth=False) as client: + with GigaChatSyncClient(base_url=BASE_URL) as client: response = client.get_models() assert isinstance(response, Models) -def test_get_model(httpx_mock): +def test_get_model(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=MODEL_URL, json=MODEL) - with GigaChatSyncClient(base_url=BASE_URL, use_auth=False) as client: + with GigaChatSyncClient(base_url=BASE_URL) as client: response = client.get_model("model") assert isinstance(response, Model) -def test_chat(httpx_mock): +def test_chat(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=CHAT_URL, json=CHAT_COMPLETION) - with GigaChatSyncClient(base_url=BASE_URL, use_auth=False, model="model") as client: - response = client.chat(CHAT) + with GigaChatSyncClient(base_url=BASE_URL, model="model") as client: + response = client.chat("text") assert isinstance(response, ChatCompletion) -def test_chat_access_token(httpx_mock): +def test_chat_access_token(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=CHAT_URL, json=CHAT_COMPLETION) access_token = "access_token" @@ -63,7 +70,7 @@ def test_chat_access_token(httpx_mock): assert isinstance(response, ChatCompletion) -def test_chat_credentials(httpx_mock): +def test_chat_credentials(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=AUTH_URL, json=ACCESS_TOKEN) httpx_mock.add_response(url=CHAT_URL, json=CHAT_COMPLETION) @@ -73,7 +80,7 @@ def test_chat_credentials(httpx_mock): assert isinstance(response, ChatCompletion) -def test_chat_user_password(httpx_mock): +def test_chat_user_password(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=CHAT_URL, json=CHAT_COMPLETION) httpx_mock.add_response(url=TOKEN_URL, json=TOKEN) @@ -83,7 +90,7 @@ def test_chat_user_password(httpx_mock): assert isinstance(response, ChatCompletion) -def test_chat_authentication_error(httpx_mock): +def test_chat_authentication_error(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=AUTH_URL, json=ACCESS_TOKEN) httpx_mock.add_response(url=CHAT_URL, status_code=401) @@ -92,7 +99,7 @@ def test_chat_authentication_error(httpx_mock): client.chat(CHAT) -def test_chat_update_token_credentials(httpx_mock): +def test_chat_update_token_credentials(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=AUTH_URL, json=ACCESS_TOKEN) httpx_mock.add_response(url=CHAT_URL, status_code=401) access_token = "access_token" @@ -107,7 +114,7 @@ def test_chat_update_token_credentials(httpx_mock): assert client.token != access_token -def test_chat_update_token_user_password(httpx_mock): +def test_chat_update_token_user_password(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=CHAT_URL, status_code=401) httpx_mock.add_response(url=TOKEN_URL, json=TOKEN) access_token = "access_token" @@ -120,7 +127,7 @@ def test_chat_update_token_user_password(httpx_mock): assert client.token != access_token -def test_chat_update_token(httpx_mock): +def test_chat_update_token_false(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=CHAT_URL, status_code=401) access_token = "access_token" @@ -128,10 +135,10 @@ def test_chat_update_token(httpx_mock): assert client.token == access_token with pytest.raises(AuthenticationError): client.chat(CHAT) - assert client.token is None + assert client.token == access_token -def test_chat_update_token_success(httpx_mock): +def test_chat_update_token_success(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=CHAT_URL, status_code=401) httpx_mock.add_response(url=CHAT_URL, json=CHAT_COMPLETION) httpx_mock.add_response(url=TOKEN_URL, json=TOKEN) @@ -146,7 +153,7 @@ def test_chat_update_token_success(httpx_mock): assert isinstance(response, ChatCompletion) -def test_chat_update_token_error(httpx_mock): +def test_chat_update_token_error(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=CHAT_URL, status_code=401) httpx_mock.add_response(url=TOKEN_URL, json=TOKEN) access_token = "access_token" @@ -160,10 +167,10 @@ def test_chat_update_token_error(httpx_mock): assert client.token != access_token -def test_stream(httpx_mock): +def test_stream(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=CHAT_URL, content=CHAT_COMPLETION_STREAM, headers=HEADERS_STREAM) - with GigaChatSyncClient(base_url=BASE_URL, use_auth=False) as client: + with GigaChatSyncClient(base_url=BASE_URL) as client: response = list(client.stream(CHAT)) assert len(response) == 3 @@ -171,11 +178,11 @@ def test_stream(httpx_mock): assert response[2].choices[0].finish_reason == "stop" -def test_stream_access_token(httpx_mock): +def test_stream_access_token(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=CHAT_URL, content=CHAT_COMPLETION_STREAM, headers=HEADERS_STREAM) access_token = "access_token" - with GigaChatSyncClient(base_url=BASE_URL, access_token=access_token) as client: + with GigaChatSyncClient(base_url=BASE_URL, access_token=access_token, user="user", password="password") as client: response = list(client.stream(CHAT)) assert len(response) == 3 @@ -183,7 +190,7 @@ def test_stream_access_token(httpx_mock): assert response[2].choices[0].finish_reason == "stop" -def test_stream_authentication_error(httpx_mock): +def test_stream_authentication_error(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=AUTH_URL, json=ACCESS_TOKEN) httpx_mock.add_response(url=CHAT_URL, status_code=401) @@ -192,7 +199,7 @@ def test_stream_authentication_error(httpx_mock): list(client.stream(CHAT)) -def test_stream_update_token_success(httpx_mock): +def test_stream_update_token_success(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=CHAT_URL, status_code=401) httpx_mock.add_response(url=CHAT_URL, content=CHAT_COMPLETION_STREAM, headers=HEADERS_STREAM) httpx_mock.add_response(url=TOKEN_URL, json=TOKEN) @@ -209,7 +216,7 @@ def test_stream_update_token_success(httpx_mock): assert response[2].choices[0].finish_reason == "stop" -def test_stream_update_token_error(httpx_mock): +def test_stream_update_token_error(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=CHAT_URL, status_code=401) httpx_mock.add_response(url=TOKEN_URL, json=TOKEN) access_token = "access_token" @@ -224,37 +231,37 @@ def test_stream_update_token_error(httpx_mock): @pytest.mark.asyncio() -async def test_aget_models(httpx_mock): +async def test_aget_models(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=MODELS_URL, json=MODELS) - async with GigaChatAsyncClient(base_url=BASE_URL, use_auth=False) as client: + async with GigaChatAsyncClient(base_url=BASE_URL) as client: response = await client.aget_models() assert isinstance(response, Models) @pytest.mark.asyncio() -async def test_aget_model(httpx_mock): +async def test_aget_model(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=MODEL_URL, json=MODEL) - async with GigaChatAsyncClient(base_url=BASE_URL, use_auth=False) as client: + async with GigaChatAsyncClient(base_url=BASE_URL) as client: response = await client.aget_model("model") assert isinstance(response, Model) @pytest.mark.asyncio() -async def test_achat(httpx_mock): +async def test_achat(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=CHAT_URL, json=CHAT_COMPLETION) - async with GigaChatAsyncClient(base_url=BASE_URL, use_auth=False) as client: - response = await client.achat(CHAT) + async with GigaChatAsyncClient(base_url=BASE_URL) as client: + response = await client.achat("text") assert isinstance(response, ChatCompletion) @pytest.mark.asyncio() -async def test_achat_access_token(httpx_mock): +async def test_achat_access_token(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=CHAT_URL, json=CHAT_COMPLETION) access_token = "access_token" @@ -265,7 +272,7 @@ async def test_achat_access_token(httpx_mock): @pytest.mark.asyncio() -async def test_achat_credentials(httpx_mock): +async def test_achat_credentials(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=AUTH_URL, json=ACCESS_TOKEN) httpx_mock.add_response(url=CHAT_URL, json=CHAT_COMPLETION) @@ -276,7 +283,7 @@ async def test_achat_credentials(httpx_mock): @pytest.mark.asyncio() -async def test_achat_user_password(httpx_mock): +async def test_achat_user_password(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=CHAT_URL, json=CHAT_COMPLETION) httpx_mock.add_response(url=TOKEN_URL, json=TOKEN) @@ -287,7 +294,7 @@ async def test_achat_user_password(httpx_mock): @pytest.mark.asyncio() -async def test_achat_authentication_error(httpx_mock): +async def test_achat_authentication_error(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=AUTH_URL, json=ACCESS_TOKEN) httpx_mock.add_response(url=CHAT_URL, status_code=401) @@ -297,7 +304,7 @@ async def test_achat_authentication_error(httpx_mock): @pytest.mark.asyncio() -async def test_achat_update_token(httpx_mock): +async def test_achat_update_token_false(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=CHAT_URL, status_code=401) access_token = "access_token" @@ -305,11 +312,11 @@ async def test_achat_update_token(httpx_mock): assert client.token == access_token with pytest.raises(AuthenticationError): await client.achat(CHAT) - assert client.token is None + assert client.token == access_token @pytest.mark.asyncio() -async def test_achat_update_token_credentials(httpx_mock): +async def test_achat_update_token_credentials(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=AUTH_URL, json=ACCESS_TOKEN) httpx_mock.add_response(url=CHAT_URL, status_code=401) access_token = "access_token" @@ -325,7 +332,7 @@ async def test_achat_update_token_credentials(httpx_mock): @pytest.mark.asyncio() -async def test_achat_update_token_user_password(httpx_mock): +async def test_achat_update_token_user_password(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=CHAT_URL, status_code=401) httpx_mock.add_response(url=TOKEN_URL, json=TOKEN) access_token = "access_token" @@ -341,10 +348,10 @@ async def test_achat_update_token_user_password(httpx_mock): @pytest.mark.asyncio() -async def test_astream(httpx_mock): +async def test_astream(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=CHAT_URL, content=CHAT_COMPLETION_STREAM, headers=HEADERS_STREAM) - async with GigaChatAsyncClient(base_url=BASE_URL, use_auth=False) as client: + async with GigaChatAsyncClient(base_url=BASE_URL) as client: response = [chunk async for chunk in client.astream(CHAT)] assert len(response) == 3 @@ -353,11 +360,13 @@ async def test_astream(httpx_mock): @pytest.mark.asyncio() -async def test_astream_access_token(httpx_mock): +async def test_astream_access_token(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=CHAT_URL, content=CHAT_COMPLETION_STREAM, headers=HEADERS_STREAM) access_token = "access_token" - async with GigaChatAsyncClient(base_url=BASE_URL, access_token=access_token) as client: + async with GigaChatAsyncClient( + base_url=BASE_URL, access_token=access_token, user="user", password="password" + ) as client: response = [chunk async for chunk in client.astream(CHAT)] assert len(response) == 3 @@ -366,7 +375,7 @@ async def test_astream_access_token(httpx_mock): @pytest.mark.asyncio() -async def test_astream_authentication_error(httpx_mock): +async def test_astream_authentication_error(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=AUTH_URL, json=ACCESS_TOKEN) httpx_mock.add_response(url=CHAT_URL, status_code=401) @@ -376,7 +385,7 @@ async def test_astream_authentication_error(httpx_mock): @pytest.mark.asyncio() -async def test_astream_update_token(httpx_mock): +async def test_astream_update_token(httpx_mock: HTTPXMock) -> None: httpx_mock.add_response(url=CHAT_URL, status_code=401) httpx_mock.add_response(url=TOKEN_URL, json=TOKEN) access_token = "access_token" @@ -389,3 +398,14 @@ async def test_astream_update_token(httpx_mock): _ = [chunk async for chunk in client.astream(CHAT)] assert client.token assert client.token != access_token + + +def test__update_token() -> None: + with GigaChatSyncClient(base_url=BASE_URL) as client: + client._update_token() + + +@pytest.mark.asyncio() +async def test__aupdate_token() -> None: + async with GigaChatAsyncClient(base_url=BASE_URL) as client: + await client._aupdate_token() diff --git a/tests/unit_tests/gigachat/test_settings.py b/tests/unit_tests/gigachat/test_settings.py index d3e6142..6e95969 100644 --- a/tests/unit_tests/gigachat/test_settings.py +++ b/tests/unit_tests/gigachat/test_settings.py @@ -1,9 +1,5 @@ from gigachat.settings import Settings -def test_settings(): +def test_settings() -> None: assert Settings() - - -def test_settings_use_auth_false(): - assert Settings(use_auth=False)