Skip to content

Commit

Permalink
add payload str, mtls, tests
Browse files Browse the repository at this point in the history
  • Loading branch information
serega-nk committed Oct 18, 2023
1 parent 3ad7936 commit 8b06e58
Show file tree
Hide file tree
Showing 17 changed files with 203 additions and 152 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ lint:

.PHONY: mypy ## Perform type-checking
mypy:
poetry run mypy src
poetry run mypy src tests

.PHONY: test ## Run tests and generate a coverage report
test:
Expand Down
46 changes: 15 additions & 31 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,40 +30,22 @@ pip install gigachat
Пример показывает как импортировать библиотеку в [GigaChain](https://github.com/ai-forever/gigachain) и использовать ее для обращения к GigaChat:

```py
"""Пример работы с чатом"""
from gigachat import GigaChat
from gigachat.models import Chat, Messages, MessagesRole


payload = Chat(
messages=[
Messages(
role=MessagesRole.SYSTEM,
content="Ты внимательный бот-психолог, который помогает пользователю решить его проблемы."
)
],
temperature=0.7,
max_tokens=100,
)

# Используйте токен, полученный в личном кабинете в поле Авторизационные данные
# Используйте токен, полученный в личном кабинете из поля Авторизационные данные
with GigaChat(credentials=..., verify_ssl_certs=False) as giga:
while True:
user_input = input("User: ")
payload.messages.append(Messages(role=MessagesRole.USER, content=user_input))
response = giga.chat(payload)
payload.messages.append(response.choices[0].message)
print("Bot: ", response.choices[0].message.content)
response = giga.chat("Какие факторы влияют на стоимость страховки на дом?")
print(response.choices[0].message.content)
```

[Больше примеров](./examples/README.md).

## Способы авторизации
## Авторизация запросов

Авторизация с помощью [токена доступа GigaChat API](https://developers.sber.ru/docs/ru/gigachat/api/authorization):
Авторизация с помощью токена (в личном кабинете из поля Авторизационные данные):

```py
giga = GigaChat(access_token=...)
giga = GigaChat(credentials=...)
```

Авторизация с помощью логина и пароля:
Expand All @@ -76,16 +58,19 @@ giga = GigaChat(
)
```

## Дополнительные настройки

Отключение авторизации:
Взаимная аутентификация по протоколу TLS (mTLS):

```py
giga = GigaChat(use_auth=False)
giga = GigaChat(
base_url="https://gigachat.devices.sberbank.ru/api/v1",
ca_bundle_file="certs/ca.pem",
cert_file="certs/tls.pem",
key_file="certs/tls.key",
key_file_password="123456",
)
```

> [!NOTE]
> Функция может быть полезна, например, при авторизации с помощью Istio service mesh.
## Дополнительные настройки

Отключение проверки сертификатов:

Expand All @@ -96,7 +81,6 @@ giga = GigaChat(verify_ssl_certs=False)
> [!WARNING]
> Отключение проверки сертификатов снижает безопасность обмена данными.

### Настройки в переменных окружения

Чтобы задать настройки с помощью переменных окружения, используйте префикс `GIGACHAT_`.
Expand Down
4 changes: 3 additions & 1 deletion examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,7 @@

Здесь вы найдете примеры работы с сервисом GigaChat с помощью библиотеки:

* [Пример вопрос - ответ](./example_ask.py)
* [Работа с чатом](./simple_chat.py)
* [Ассинхронная работа с потоковой обработкой токенов](./streaming_asyncio.py)
* [Асинхронная работа с потоковой обработкой токенов](./streaming_asyncio.py)
* [Взаимная аутентификация по протоколу TLS (mTLS)](./example_auth_certs_mtls.py)
7 changes: 7 additions & 0 deletions examples/example_ask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""Пример вопрос - ответ"""
from gigachat import GigaChat

# Используйте токен, полученный в личном кабинете из поля Авторизационные данные
with GigaChat(credentials=..., verify_ssl_certs=False) as giga:
response = giga.chat("Какие факторы влияют на стоимость страховки на дом?")
print(response.choices[0].message.content)
13 changes: 13 additions & 0 deletions examples/example_auth_certs_mtls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""Взаимная аутентификация по протоколу TLS (mTLS)"""

from gigachat import GigaChat

with GigaChat(
base_url="https://gigachat.devices.sberbank.ru/api/v1",
ca_bundle_file="certs/ca.pem",
cert_file="certs/tls.pem",
key_file="certs/tls.key",
key_file_password="123456",
) as giga:
response = giga.chat("Какие факторы влияют на стоимость страховки на дом?")
print(response.choices[0].message.content)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "gigachat"
version = "0.1.3"
version = "0.1.4"
description = "GigaChat Python Library"
authors = ["Konstantin Krestnikov <[email protected]>"]
license = "MIT"
Expand Down
57 changes: 44 additions & 13 deletions src/gigachat/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
Chat,
ChatCompletion,
ChatCompletionChunk,
Messages,
MessagesRole,
Model,
Models,
Token,
Expand All @@ -34,11 +36,16 @@

def _get_kwargs(settings: Settings) -> Dict[str, Any]:
"""Настройки для подключения к API GIGACHAT"""
return {
kwargs = {
"base_url": settings.base_url,
"verify": settings.verify_ssl_certs,
"timeout": httpx.Timeout(settings.timeout),
}
if settings.ca_bundle_file:
kwargs["verify"] = settings.ca_bundle_file
if settings.cert_file:
kwargs["cert"] = (settings.cert_file, settings.key_file, settings.key_file_password)
return kwargs


def _get_auth_kwargs(settings: Settings) -> Dict[str, Any]:
Expand All @@ -49,11 +56,14 @@ def _get_auth_kwargs(settings: Settings) -> Dict[str, Any]:
}


def _parse_chat(chat: Union[Chat, Dict[str, Any]], model: Optional[str]) -> Chat:
payload = Chat.parse_obj(chat)
def _parse_chat(payload: Union[Chat, Dict[str, Any], str], model: Optional[str]) -> Chat:
if isinstance(payload, str):
chat = Chat(messages=[Messages(role=MessagesRole.USER, content=payload)])
else:
chat = Chat.parse_obj(payload)
if model:
payload.model = model
return payload
chat.model = model
return chat


def _build_access_token(token: Token) -> AccessToken:
Expand All @@ -76,24 +86,45 @@ def __init__(
password: Optional[str] = None,
timeout: Optional[float] = None,
verify_ssl_certs: Optional[bool] = None,
use_auth: Optional[bool] = None,
verbose: Optional[bool] = None,
ca_bundle_file: Optional[str] = None,
cert_file: Optional[str] = None,
key_file: Optional[str] = None,
key_file_password: Optional[str] = None,
**_kwargs: Any,
) -> None:
config = {k: v for k, v in locals().items() if k != "self" and v is not None}
kwargs: Dict[str, Any] = {
"base_url": base_url,
"auth_url": auth_url,
"credentials": credentials,
"scope": scope,
"access_token": access_token,
"model": model,
"user": user,
"password": password,
"timeout": timeout,
"verify_ssl_certs": verify_ssl_certs,
"verbose": verbose,
"ca_bundle_file": ca_bundle_file,
"cert_file": cert_file,
"key_file": key_file,
"key_file_password": key_file_password,
}
config = {k: v for k, v in kwargs.items() if v is not None}
self._settings = Settings(**config)
if self._settings.access_token:
self._access_token = AccessToken(access_token=self._settings.access_token, expires_at=0)

@property
def token(self) -> Optional[str]:
if self._settings.use_auth and self._access_token:
if self._access_token:
return self._access_token.access_token
else:
return None

@property
def _use_auth(self) -> bool:
return self._settings.use_auth
return bool(self._settings.credentials or (self._settings.user and self._settings.password))

def _check_validity_token(self) -> bool:
"""Проверить время завершения действия токена"""
Expand Down Expand Up @@ -162,12 +193,12 @@ def get_model(self, model: str) -> Model:
"""Возвращает объект с описанием указанной модели"""
return self._decorator(lambda: get_model.sync(self._client, model=model, access_token=self.token))

def chat(self, payload: Union[Chat, Dict[str, Any]]) -> ChatCompletion:
def chat(self, payload: Union[Chat, Dict[str, Any], str]) -> ChatCompletion:
"""Возвращает ответ модели с учетом переданных сообщений"""
chat = _parse_chat(payload, model=self._settings.model)
return self._decorator(lambda: post_chat.sync(self._client, chat=chat, access_token=self.token))

def stream(self, payload: Union[Chat, Dict[str, Any]]) -> Iterator[ChatCompletionChunk]:
def stream(self, payload: Union[Chat, Dict[str, Any], str]) -> Iterator[ChatCompletionChunk]:
"""Возвращает ответ модели с учетом переданных сообщений"""
chat = _parse_chat(payload, model=self._settings.model)

Expand Down Expand Up @@ -249,7 +280,7 @@ async def _acall() -> Model:

return await self._adecorator(_acall)

async def achat(self, payload: Union[Chat, Dict[str, Any]]) -> ChatCompletion:
async def achat(self, payload: Union[Chat, Dict[str, Any], str]) -> ChatCompletion:
"""Возвращает ответ модели с учетом переданных сообщений"""
chat = _parse_chat(payload, model=self._settings.model)

Expand All @@ -258,7 +289,7 @@ async def _acall() -> ChatCompletion:

return await self._adecorator(_acall)

async def astream(self, payload: Union[Chat, Dict[str, Any]]) -> AsyncIterator[ChatCompletionChunk]:
async def astream(self, payload: Union[Chat, Dict[str, Any], str]) -> AsyncIterator[ChatCompletionChunk]:
"""Возвращает ответ модели с учетом переданных сообщений"""
chat = _parse_chat(payload, model=self._settings.model)

Expand Down
21 changes: 7 additions & 14 deletions src/gigachat/settings.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import logging
from typing import Any, Dict, Optional
from typing import Optional

from gigachat.pydantic_v1 import BaseSettings, root_validator
from gigachat.pydantic_v1 import BaseSettings

ENV_PREFIX = "GIGACHAT_"

# BASE_URL = "https://beta.saluteai.sberdevices.ru/v1"
BASE_URL = "https://gigachat.devices.sberbank.ru/api/v1"
AUTH_URL = "https://ngw.devices.sberbank.ru:9443/api/v2/oauth"
SCOPE = "GIGACHAT_API_CORP"
Expand All @@ -26,17 +24,12 @@ class Settings(BaseSettings):
timeout: float = 30.0
verify_ssl_certs: bool = True

use_auth: bool = True
verbose: bool = False

ca_bundle_file: Optional[str] = None
cert_file: Optional[str] = None
key_file: Optional[str] = None
key_file_password: Optional[str] = None

class Config:
env_prefix = ENV_PREFIX

@root_validator
def check_credentials(cls, values: Dict[str, Any]) -> Dict[str, Any]:
if values["use_auth"]:
use_secrets = values["credentials"] or values["access_token"] or (values["user"] and values["password"])
if not use_secrets:
logging.warning("Please provide GIGACHAT_CREDENTIALS environment variables.")

return values
2 changes: 1 addition & 1 deletion tests/unit_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@


@pytest.fixture(autouse=True)
def _delenv(monkeypatch) -> None:
def _delenv(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(os, "environ", {})
14 changes: 7 additions & 7 deletions tests/unit_tests/gigachat/api/test_get_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import httpx
import pytest
import pytest_httpx
from pytest_httpx import HTTPXMock

from gigachat.api import get_model
from gigachat.exceptions import AuthenticationError, ResponseError
Expand All @@ -14,7 +14,7 @@
MODEL = get_json("model.json")


def test_sync(httpx_mock: pytest_httpx.HTTPXMock):
def test_sync(httpx_mock: HTTPXMock) -> None:
httpx_mock.add_response(url=MODEL_URL, json=MODEL)

with httpx.Client(base_url=BASE_URL) as client:
Expand All @@ -23,31 +23,31 @@ def test_sync(httpx_mock: pytest_httpx.HTTPXMock):
assert isinstance(response, Model)


def test_sync_value_error(httpx_mock):
def test_sync_value_error(httpx_mock: HTTPXMock) -> None:
httpx_mock.add_response(url=MODEL_URL, json={})

with httpx.Client(base_url=BASE_URL) as client:
with pytest.raises(ValueError, match="3 validation errors for Model*"):
get_model.sync(client, model="model")


def test_sync_authentication_error(httpx_mock):
def test_sync_authentication_error(httpx_mock: HTTPXMock) -> None:
httpx_mock.add_response(url=MODEL_URL, status_code=401)

with httpx.Client(base_url=BASE_URL) as client:
with pytest.raises(AuthenticationError):
get_model.sync(client, model="model")


def test_sync_response_error(httpx_mock):
def test_sync_response_error(httpx_mock: HTTPXMock) -> None:
httpx_mock.add_response(url=MODEL_URL, status_code=400)

with httpx.Client(base_url=BASE_URL) as client:
with pytest.raises(ResponseError):
get_model.sync(client, model="model")


def test_sync_headers(httpx_mock: pytest_httpx.HTTPXMock):
def test_sync_headers(httpx_mock: HTTPXMock) -> None:
httpx_mock.add_response(url=MODEL_URL, json=MODEL)

with httpx.Client(base_url=BASE_URL) as client:
Expand All @@ -64,7 +64,7 @@ def test_sync_headers(httpx_mock: pytest_httpx.HTTPXMock):


@pytest.mark.asyncio()
async def test_asyncio(httpx_mock):
async def test_asyncio(httpx_mock: HTTPXMock) -> None:
httpx_mock.add_response(url=MODEL_URL, json=MODEL)

async with httpx.AsyncClient(base_url=BASE_URL) as client:
Expand Down
Loading

0 comments on commit 8b06e58

Please sign in to comment.