Skip to content

Commit

Permalink
feat: support SSL context settings (#41)
Browse files Browse the repository at this point in the history
* feat: support SSL context settings
* Readme update
* fix: Pass ssl context in settings
* chore: version up
  • Loading branch information
Rai220 authored Dec 28, 2024
1 parent a1e1ef7 commit 9429a07
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 1 deletion.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ giga = GigaChat(
cert_file="certs/tls.pem", # published_pem.txt
key_file="certs/tls.key",
key_file_password="123456",
ssl_context=context # optional ssl.SSLContext instance
)
```

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "gigachat"
version = "0.1.36"
version = "0.1.37"
description = "GigaChat. Python-library for GigaChain and LangChain"
authors = ["Konstantin Krestnikov <[email protected]>", "Sergey Malyshev <[email protected]>"]
license = "MIT"
Expand Down
7 changes: 7 additions & 0 deletions src/gigachat/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import ssl
from functools import cached_property
from typing import (
Any,
Expand Down Expand Up @@ -69,6 +70,8 @@ def _get_kwargs(settings: Settings) -> Dict[str, Any]:
"verify": settings.verify_ssl_certs,
"timeout": httpx.Timeout(settings.timeout),
}
if settings.ssl_context:
kwargs["verify"] = settings.ssl_context
if settings.ca_bundle_file:
kwargs["verify"] = settings.ca_bundle_file
if settings.cert_file:
Expand All @@ -86,6 +89,8 @@ def _get_auth_kwargs(settings: Settings) -> Dict[str, Any]:
"verify": settings.verify_ssl_certs,
"timeout": httpx.Timeout(settings.timeout),
}
if settings.ssl_context:
kwargs["verify"] = settings.ssl_context
if settings.ca_bundle_file:
kwargs["verify"] = settings.ca_bundle_file
return kwargs
Expand Down Expand Up @@ -131,6 +136,7 @@ def __init__(
cert_file: Optional[str] = None,
key_file: Optional[str] = None,
key_file_password: Optional[str] = None,
ssl_context: Optional[ssl.SSLContext] = None,
flags: Optional[List[str]] = None,
**_unknown_kwargs: Any,
) -> None:
Expand All @@ -154,6 +160,7 @@ def __init__(
"cert_file": cert_file,
"key_file": key_file,
"key_file_password": key_file_password,
"ssl_context": ssl_context,
"flags": flags,
}
config = {k: v for k, v in kwargs.items() if v is not None}
Expand Down
2 changes: 2 additions & 0 deletions src/gigachat/settings.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import ssl
from typing import List, Optional

from gigachat.pydantic_v1 import BaseSettings
Expand Down Expand Up @@ -33,6 +34,7 @@ class Settings(BaseSettings):

verbose: bool = False

ssl_context: Optional[ssl.SSLContext] = None
ca_bundle_file: Optional[str] = None
cert_file: Optional[str] = None
key_file: Optional[str] = None
Expand Down
18 changes: 18 additions & 0 deletions tests/unit_tests/gigachat/test_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import ssl
from typing import List, Optional

import pytest
Expand Down Expand Up @@ -68,16 +69,33 @@
CREDENTIALS = "NmIwNzhlODgtNDlkNC00ZjFmLTljMjMtYjFiZTZjMjVmNTRlOmU3NWJlNjVhLTk4YjAtNGY0Ni1iOWVhLTljMDkwZGE4YTk4MQ=="


def _make_ssl_context() -> ssl.SSLContext:
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
return context


def test__get_kwargs() -> None:
settings = Settings(ca_bundle_file="ca.pem", cert_file="tls.pem", key_file="tls.key")
assert _get_kwargs(settings)


def test__get_kwargs_ssl() -> None:
context = _make_ssl_context()
settings = Settings(ssl_context=context)
assert _get_kwargs(settings)["verify"] == context


def test__get_auth_kwargs() -> None:
settings = Settings(ca_bundle_file="ca.pem", cert_file="tls.pem", key_file="tls.key")
assert _get_auth_kwargs(settings)


def test__get_auth_kwargs_ssl() -> None:
context = _make_ssl_context()
settings = Settings(ssl_context=context)
assert _get_kwargs(settings)["verify"] == context


@pytest.mark.parametrize(
("payload_value", "setting_value", "expected"),
[
Expand Down

0 comments on commit 9429a07

Please sign in to comment.