Skip to content

Commit

Permalink
fix(python-sdk): organization environment variable and default client
Browse files Browse the repository at this point in the history
* Use the correct environment variable `NUMEROUS_ORGANIZATION_ID`.
* Define exception types for relevant exceptions in `_client` module.
* Raise exception if organization ID is not configured.
* Default to singleton client.
  • Loading branch information
jfeodor committed Sep 25, 2024
1 parent 2183b1d commit 29854ad
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 41 deletions.
94 changes: 63 additions & 31 deletions python/src/numerous/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,23 +41,46 @@
from numerous.threaded_event_loop import ThreadedEventLoop


API_URL_NOT_SET = "NUMEROUS_API_URL environment variable is not set"
MESSAGE_NOT_SET = "NUMEROUS_API_ACCESS_TOKEN environment variable is not set"
COLLECTED_OBJECTS_NUMBER = 100


class APIURLMissingError(Exception):
_msg = "NUMEROUS_API_URL environment variable is not set"

def __init__(self) -> None:
super().__init__(self._msg)


class APIAccessTokenMissingError(Exception):
_msg = "NUMEROUS_API_ACCESS_TOKEN environment variable is not set"

def __init__(self) -> None:
super().__init__(self._msg)


class OrganizationIDMissingError(Exception):
_msg = "NUMEROUS_ORGANIZATION_ID environment variable is not set"

def __init__(self) -> None:
super().__init__(self._msg)


class Client:
def __init__(self, client: GQLClient) -> None:
self.client = client
def __init__(self, gql: GQLClient) -> None:
self._gql = gql
self._threaded_event_loop = ThreadedEventLoop()
self._threaded_event_loop.start()
self.organization_id = os.getenv("ORGANIZATION_ID", "default_organization")

organization_id = os.getenv("NUMEROUS_ORGANIZATION_ID")
if not organization_id:
raise OrganizationIDMissingError
self._organization_id = organization_id

auth_token = os.getenv("NUMEROUS_API_ACCESS_TOKEN")
if not auth_token:
raise ValueError(MESSAGE_NOT_SET)
raise APIAccessTokenMissingError

self.kwargs = {"headers": {"Authorization": f"Bearer {auth_token}"}}
self._headers = {"Authorization": f"Bearer {auth_token}"}

def _create_collection_ref(
self,
Expand All @@ -82,11 +105,11 @@ def _create_collection_ref(
async def _create_collection(
self, collection_key: str, parent_collection_key: Optional[str] = None
) -> Optional[CollectionReference]:
response = await self.client.collection_create(
self.organization_id,
response = await self._gql.collection_create(
self._organization_id,
collection_key,
parent_collection_key,
**self.kwargs,
headers=self._headers,
)
return self._create_collection_ref(response.collection_create)

Expand Down Expand Up @@ -142,11 +165,11 @@ def _create_collection_document_ref(
async def _get_collection_document(
self, collection_key: str, document_key: str
) -> Optional[CollectionDocumentReference]:
response = await self.client.collection_document(
self.organization_id,
response = await self._gql.collection_document(
self._organization_id,
collection_key,
document_key,
**self.kwargs,
headers=self._headers,
)
if isinstance(
response.collection_create,
Expand All @@ -165,11 +188,11 @@ def get_collection_document(
async def _set_collection_document(
self, collection_id: str, document_key: str, document_data: str
) -> Optional[CollectionDocumentReference]:
response = await self.client.collection_document_set(
response = await self._gql.collection_document_set(
collection_id,
document_key,
document_data,
**self.kwargs,
headers=self._headers,
)
return self._create_collection_document_ref(response.collection_document_set)

Expand All @@ -183,8 +206,8 @@ def set_collection_document(
async def _delete_collection_document(
self, document_id: str
) -> Optional[CollectionDocumentReference]:
response = await self.client.collection_document_delete(
document_id, **self.kwargs
response = await self._gql.collection_document_delete(
document_id, headers=self._headers
)
return self._create_collection_document_ref(response.collection_document_delete)

Expand All @@ -198,8 +221,8 @@ def delete_collection_document(
async def _add_collection_document_tag(
self, document_id: str, tag: TagInput
) -> Optional[CollectionDocumentReference]:
response = await self.client.collection_document_tag_add(
document_id, tag, **self.kwargs
response = await self._gql.collection_document_tag_add(
document_id, tag, headers=self._headers
)
return self._create_collection_document_ref(
response.collection_document_tag_add
Expand All @@ -215,8 +238,8 @@ def add_collection_document_tag(
async def _delete_collection_document_tag(
self, document_id: str, tag_key: str
) -> Optional[CollectionDocumentReference]:
response = await self.client.collection_document_tag_delete(
document_id, tag_key, **self.kwargs
response = await self._gql.collection_document_tag_delete(
document_id, tag_key, headers=self._headers
)
return self._create_collection_document_ref(
response.collection_document_tag_delete
Expand All @@ -235,13 +258,13 @@ async def _get_collection_documents(
end_cursor: str,
tag_input: Optional[TagInput],
) -> tuple[Optional[list[Optional[CollectionDocumentReference]]], bool, str]:
response = await self.client.collection_documents(
self.organization_id,
response = await self._gql.collection_documents(
self._organization_id,
collection_key,
tag_input,
after=end_cursor,
first=COLLECTED_OBJECTS_NUMBER,
**self.kwargs,
headers=self._headers,
)

collection = response.collection_create
Expand Down Expand Up @@ -269,12 +292,12 @@ def get_collection_documents(
async def _get_collection_collections(
self, collection_key: str, end_cursor: str
) -> tuple[Optional[list[Optional[CollectionReference]]], bool, str]:
response = await self.client.collection_collections(
self.organization_id,
response = await self._gql.collection_collections(
self._organization_id,
collection_key,
after=end_cursor,
first=COLLECTED_OBJECTS_NUMBER,
**self.kwargs,
headers=self._headers,
)

collection = response.collection_create
Expand All @@ -300,9 +323,18 @@ def get_collection_collections(
)


def _open_client() -> Client:
_client: Optional[Client] = None


def _get_client() -> Client:
global _client # noqa: PLW0603

if _client is not None:
return _client

url = os.getenv("NUMEROUS_API_URL")
if not url:
raise ValueError(API_URL_NOT_SET)
client = GQLClient(url=url)
return Client(client)
raise APIURLMissingError

_client = Client(GQLClient(url=url))
return _client
4 changes: 2 additions & 2 deletions python/src/numerous/collection/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Optional

from numerous._client import Client, _open_client
from numerous._client import Client, _get_client
from numerous.collection.numerous_collection import NumerousCollection


Expand All @@ -11,6 +11,6 @@ def collection(
) -> NumerousCollection:
"""Get or create a collection by name."""
if _client is None:
_client = _open_client()
_client = _get_client()
collection_ref_key = _client.get_collection_reference(collection_key)
return NumerousCollection(collection_ref_key, _client)
2 changes: 1 addition & 1 deletion python/tests/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def _collection_create_collection_not_found(ref_id: str) -> CollectionCreate:
@pytest.fixture(autouse=True)
def _set_env_vars(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("NUMEROUS_API_URL", "url_value")
monkeypatch.setenv("ORGANIZATION_ID", ORGANIZATION_ID)
monkeypatch.setenv("NUMEROUS_ORGANIZATION_ID", ORGANIZATION_ID)
monkeypatch.setenv("NUMEROUS_API_ACCESS_TOKEN", "token")


Expand Down
11 changes: 4 additions & 7 deletions python/tests/test_numerous_client.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
import pytest

from numerous._client import Client, _open_client
from numerous._client import Client, _get_client


@pytest.fixture(autouse=True)
def _set_env_vars(monkeypatch: pytest.MonkeyPatch) -> None:
def test_open_client_returns_new_client(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("NUMEROUS_API_URL", "url_value")
monkeypatch.setenv("NUMEROUS_API_ACCESS_TOKEN", "token")
monkeypatch.setenv("NUMEROUS_ORGANIZATION_ID", "organization-id")


def test_open_client_returns_new_client() -> None:
"""Testing client."""
client = _open_client()
client = _get_client()

assert isinstance(client, Client)

0 comments on commit 29854ad

Please sign in to comment.