From 877f90d918e99fb176b795598ca2fccf10461075 Mon Sep 17 00:00:00 2001 From: augray Date: Wed, 4 Sep 2024 11:36:39 -0700 Subject: [PATCH] Add http API wrapper (#2) Add a wrapper of our HTTP API for the rest of the SDK to use. Testing -------- In addition to the unit tests, tested with the following driver code while pointing at the dev Airtrain deployment (remember, sdk users will NOT be using this API directly; this is just to test the wrapper): ```python import io import pyarrow.json as jsonl import pyarrow.parquet as pq from airtrain.client import client def main() -> None: json_buffer = io.BytesIO() json_buffer.write(b'{"text": "hello"}\n') json_buffer.write(b'{"text": "world"}') json_buffer.seek(0) parquet_buffer = io.BytesIO() table = jsonl.read_json(json_buffer) pq.write_table(table, parquet_buffer) response = client().create_dataset(name="From API Client", embedding_column_name=None) parquet_buffer.seek(0) client().upload_dataset_data(response.dataset_id, parquet_buffer) ingest_response = client().trigger_dataset_ingest(response.dataset_id) print( f"Created dataset {response.dataset_id} " f"with ingest job {ingest_response.ingest_job_id}" ) if __name__ == "__main__": main() ``` --- .gitignore | 2 + Makefile | 10 ++ README.md | 5 +- pyproject.toml | 7 +- src/airtrain/__init__.py | 1 + src/airtrain/client.py | 244 +++++++++++++++++++++++++++++++++++++++ src/airtrain/retry.py | 222 +++++++++++++++++++++++++++++++++++ src/tests/test_client.py | 116 +++++++++++++++++++ src/tests/utils.py | 35 ++++++ uv.lock | 1 - 10 files changed, 636 insertions(+), 7 deletions(-) create mode 100644 src/airtrain/client.py create mode 100644 src/airtrain/retry.py create mode 100644 src/tests/test_client.py create mode 100644 src/tests/utils.py diff --git a/.gitignore b/.gitignore index bce382d..42cccb6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ +tmp.* + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/Makefile b/Makefile index b84551a..afcd940 100644 --- a/Makefile +++ b/Makefile @@ -4,7 +4,17 @@ PY_VERSION := "3.11" .PHONY: wheel wheel: + rm -rf dist build src/*.egg-info uvx pip wheel -w dist . +.PHONY: test-release +test-release: wheel + uvx twine check dist/*airtrain*.whl + uvx twine upload --repository testpypi dist/*airtrain*.whl + +.PHONY: release +release: wheel + uvx twine check dist/*airtrain*.whl + uvx twine upload dist/*airtrain*.whl .PHONY: test-release test-release: wheel diff --git a/README.md b/README.md index 53fb380..ad490a7 100644 --- a/README.md +++ b/README.md @@ -27,9 +27,6 @@ Python Version - - CI status -
@@ -51,7 +48,7 @@ Then you may upload a new dataset as follows: import airtrain as at # Can also be set with the environment variable AIRTRAIN_API_KEY -at.api_key = "sUpErSeCr3t" +at.set_api_key("sUpErSeCr3t") url = at.upload_from_dicts( [ diff --git a/pyproject.toml b/pyproject.toml index 454fb2c..d7bd282 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,8 +18,6 @@ dependencies = [ "numpy>=1.26.0; python_version >= '3.12'", "numpy<=1.24.4; python_version == '3.8'", "numpy>=1.19.3; python_version >= '3.9'", - - "airtrain", ] classifiers = [ @@ -55,6 +53,11 @@ Documentation = "https://docs.airtrain.ai/" requires = ["setuptools", "setuptools-scm"] build-backend = "setuptools.build_meta" +[tool.setuptools.packages.find] +where = ["src"] +include = ["airtrain*"] +exclude = ["tests*", "*.tests*"] + [tool.ruff] line-length = 90 diff --git a/src/airtrain/__init__.py b/src/airtrain/__init__.py index 973e3bc..1b8950f 100644 --- a/src/airtrain/__init__.py +++ b/src/airtrain/__init__.py @@ -1 +1,2 @@ +from airtrain.client import set_api_key # noqa: F401 from airtrain.core import DatasetMetadata, upload_from_dicts # noqa: F401 diff --git a/src/airtrain/client.py b/src/airtrain/client.py new file mode 100644 index 0000000..e99c184 --- /dev/null +++ b/src/airtrain/client.py @@ -0,0 +1,244 @@ +import io +import logging +import os +from dataclasses import dataclass +from functools import lru_cache +from typing import Any, Dict, Iterable, Optional + +import httpx + + +logger = logging.getLogger(__name__) + + +RequestJson = Dict[str, Any] +ResponseJson = Dict[str, Any] + + +API_KEY_ENV_VAR: str = "AIRTRAIN_API_KEY" +_DEFAULT_API_KEY: Optional[str] = None +_DEFAULT_BASE_URL: str = "https://api.airtrain.ai" +_BUFFER_CHUNK_SIZE = 8192 + + +class BadRequestError(Exception): + """The request was bad for some reason.""" + + pass + + +class AuthenticationError(BadRequestError): + """The caller does not have permission to complete the request.""" + + pass + + +class NotFoundError(BadRequestError): + """The caller requested something that doesn't exist.""" + + pass + + +class ServerError(Exception): + """There was some problem with the server.""" + + pass + + +@dataclass +class CreateDatasetResponse: + dataset_id: str + row_limit: int + + +@dataclass +class TriggerIngestResponse: + ingest_job_id: str + + +class AirtrainClient: + """A direct wrapper around Airtrain's HTTP API. Intended for internal package use. + + SDK users should NOT use this class directly and should not assume it will have a + stable API. It makes no attempt to make sure calls are sequenced in an appropriate + order or that information is passed between calls in a logical way. There should be + a direct 1:1 correspondence between methods here and API endpoints that the SDK + needs to interact with. + """ + + def __init__( + self, api_key: Optional[str] = None, base_url: Optional[str] = None + ) -> None: + self._api_key: str = api_key or _find_api_key() # type: ignore + self._base_url: str = base_url or os.environ.get( # type: ignore + "AIRTRAIN_API_URL", _DEFAULT_BASE_URL + ) + if self._base_url.endswith("/"): + self._base_url = self._base_url[:-1] + self._http_client = httpx.Client() + + if self._api_key is None: + raise AuthenticationError( + "No Api key found. 
" + "Set one with the environment variable 'AIRTRAIN_API_KEY' or the " + "function airtrain.set_api_key" + ) + + def trigger_dataset_ingest(self, dataset_id: str) -> TriggerIngestResponse: + """Wraps: POST /dataset/[id]/ingest""" + response = self._post_json(url_path=f"dataset/{dataset_id}/ingest", content={}) + job_id = response.get("ingestionJobId") + if not isinstance(job_id, str): + raise ServerError(f"Malformed response: {response}") + return TriggerIngestResponse(ingest_job_id=job_id) + + def create_dataset( + self, name: str, embedding_column_name: Optional[str] + ) -> CreateDatasetResponse: + """Wraps: POST /dataset""" + response = self._post_json( + "dataset", dict(name=name, embeddingColumn=embedding_column_name) + ) + dataset_id = response.get("datasetId") + row_limit = response.get("rowLimit") + + if not (isinstance(dataset_id, str) and isinstance(row_limit, int)): + raise ServerError(f"Malformed response: {response}") + return CreateDatasetResponse(dataset_id=dataset_id, row_limit=row_limit) + + def upload_dataset_data(self, dataset_id: str, data: io.BufferedIOBase) -> None: + """Wraps: PUT /dataset/[id]/source""" + self._put_bytes( + url_path=f"dataset/{dataset_id}/source", + content=data, + params={"format": "parquet"}, + ) + + def _post_json( + self, url_path: str, content: RequestJson, params: Optional[Dict[str, str]] = None + ) -> ResponseJson: + headers = { + "Authorization": f"Bearer {self._api_key}", + } + url = self._full_url(url_path) + response = self._http_client.post( + url, headers=headers, json=content, params=params + ) + response_json = self._handle_response(response, expect_json=True) + assert response_json is not None # please mypy + return response_json + + def _put_bytes( + self, + url_path: str, + content: io.BufferedIOBase, + params: Optional[Dict[str, str]] = None, + ) -> None: + headers = { + "Authorization": f"Bearer {self._api_key}", + "Content-Type": "application/octet-stream", + } + url = self._full_url(url_path) + + response = self._http_client.put( + url, + headers=headers, + content=iter([b""]), # send some dummy data to not consume the stream + params=params, + follow_redirects=False, + ) + if response.next_request is None: + logger.error("Response text:\n%s", response.text) + raise ServerError(f"Expected redirect but got: {response.status_code}") + + response = self._http_client.put( + response.next_request.url, + headers=response.next_request.headers, + content=_buffer_to_byte_iterable(content), + follow_redirects=False, + ) + self._handle_response(response, expect_json=False) + + def _handle_response( + self, response: httpx.Response, expect_json: bool + ) -> Optional[ResponseJson]: + response_json: Optional[ResponseJson] = None + error_message: Optional[str] = None + request_kind = response.request.method + url_path = response.request.url.path + + try: + response_json = response.json() + except Exception as e: + # raise more appropriate error below, but log content here + # to help debug. + logger.debug("Response did not contain json. %s", e) + + if isinstance(response_json, dict): + error_message = response_json.get("errorMessage") + error_message = response_json.get("errorMessageDisplay") or error_message + + base_message = ( + error_message + or f"Got '{response.status_code}' from {request_kind} to {url_path}" + ) + status_code = response.status_code + if status_code in (401, 403): + logger.error("Authentication error response text:\n%s", response.text) + raise AuthenticationError(f"You may not have access. 
{base_message}") + if status_code == 404: + raise NotFoundError(f"The resource may not exist. {base_message}") + if 400 <= status_code < 500: + raise BadRequestError(f"Bad Request. {base_message}") + if status_code // 100 != 2: + logger.error("Server error response text:\n%s", response.text) + # Consider 100s, 300s, 500s to all be server errors because they are not + # expected from the API. + raise ServerError(f"Server error. {base_message}") + + if expect_json and not ( + isinstance(response_json, dict) and "data" in response_json + ): + # All our json APIs return dicts. + logger.error("Malformed response text:\n%s", response.text) + raise ServerError("Malformed server response.") + + if expect_json: + return response_json["data"] # type: ignore + return None + + def _full_url(self, url_path: str) -> str: + return f"{self._base_url}/{url_path}" + + +@lru_cache(maxsize=1) +def client() -> AirtrainClient: + """Get the default Airtrain client. This is an internal API for advanced usage.""" + # Since we have used the lru_cache this will always return the same instance. + return AirtrainClient() + + +def _find_api_key() -> Optional[str]: + global _DEFAULT_API_KEY + if _DEFAULT_API_KEY is not None: + return _DEFAULT_API_KEY + return os.environ.get(API_KEY_ENV_VAR) + + +def set_api_key(api_key: str) -> None: + """Explicitly set the API key for the default client.""" + if api_key is None: + raise AuthenticationError("Invalid API key; must not be None") + global _DEFAULT_API_KEY + _DEFAULT_API_KEY = api_key + + # In case the default client already exists. + client()._api_key = api_key + + +def _buffer_to_byte_iterable(buffer: io.BufferedIOBase) -> Iterable[bytes]: + while True: + chunk = buffer.read(_BUFFER_CHUNK_SIZE) + if not chunk: + break + yield chunk diff --git a/src/airtrain/retry.py b/src/airtrain/retry.py new file mode 100644 index 0000000..7017fab --- /dev/null +++ b/src/airtrain/retry.py @@ -0,0 +1,222 @@ +# Credited to invl: https://github.com/invl/retry/blob/master/retry/api.py +# adapted to follow our linting and docstring style. Used here instead +# of depending on that library to avoid taking on another explicit +# third-party dep. +# Retrieved on 08/10/22, original source code uses the Apache 2 +# license. + +# Standard Library +import logging +import random +import time +from functools import partial, wraps +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union + + +logging_logger = logging.getLogger(__name__) + + +def decorator(caller): + """Turns caller into a decorator. + Unlike decorator module, function signature is not preserved. + + Parameters + ---------- + caller: caller(f, *args, **kwargs) + """ + + def decor(f): + @wraps(f) + def wrapper(*args, **kwargs): + return caller(f, *args, **kwargs) + + return wrapper + + return decor + + +def __retry_internal( + f: Callable[[], Any], + exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]] = Exception, + tries: int = -1, + delay: float = 0, + max_delay: Optional[float] = None, + backoff: float = 1, + jitter: Union[float, Tuple[float, float]] = 0, + logger: Optional[logging.Logger] = logging_logger, +): + """Executes a function and retries it if it failed. + + Parameters + ---------- + f: + the function to execute. + exceptions: + an exception or a tuple of exceptions to catch. default: Exception. + tries: + the maximum number of attempts. default: -1 (infinite). + delay: + initial delay between attempts. default: 0. + max_delay: + the maximum value of delay. default: None (no limit). 
+ backoff: + multiplier applied to delay between attempts. default: 1 (no backoff). + jitter: + extra seconds added to delay between attempts. default: 0. fixed if a number, + random if a range tuple (min, max). + logger: + logger.warning(fmt, error, delay) will be called on failed attempts. + default: retry.logging_logger. if None, logging is disabled. + + Returns + ------- + the result of the f function. + """ + _tries, _delay = tries, delay + while _tries != 0: + try: + return f() + except exceptions as e: + _tries = max(-1, _tries - 1) + if _tries == 0: + raise + + if logger is not None: + logger.warning(e) + logger.warning( + "Retrying %s in %s seconds with %s tries left...", + f.__name__, + _delay, + _tries, + ) + + time.sleep(_delay) + _delay *= backoff + + if isinstance(jitter, tuple): + _delay += random.uniform(*jitter) + else: + _delay += jitter + + if max_delay is not None: + _delay = min(_delay, max_delay) + + +def retry( + exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]] = Exception, + tries: int = -1, + delay: float = 0, + max_delay: Optional[float] = None, + backoff: float = 1, + jitter: float = 0, + logger: Optional[logging.Logger] = logging_logger, +): + """Returns a retry decorator. + + Parameters + ---------- + exceptions: + an exception or a tuple of exceptions to catch. default: Exception. + tries: + the maximum number of attempts. default: -1 (infinite). + delay: + initial delay between attempts. default: 0. + max_delay: + the maximum value of delay. default: None (no limit). + backoff: + multiplier applied to delay between attempts. default: 1 (no backoff). + jitter: + extra seconds added to delay between attempts. default: 0. + fixed if a number, random if a range tuple (min, max) + logger: + logger.warning(fmt, error, delay) will be called on failed attempts. + default: retry.logging_logger. if None, logging is disabled. + + Returns + ------- + A retry decorator. + """ + + @decorator + def retry_decorator(f, *fargs, **fkwargs): + args = fargs if fargs is not None else list() + kwargs = fkwargs if fkwargs is not None else dict() + partialed = _named_partial(f, *args, **kwargs) + + return __retry_internal( + f=partialed, + exceptions=exceptions, + tries=tries, + delay=delay, + max_delay=max_delay, + backoff=backoff, + jitter=jitter, + logger=logger, + ) + + return retry_decorator + + +def retry_call( + f: Callable[..., Any], + fargs: Optional[List[Any]] = None, + fkwargs: Optional[Dict[str, Any]] = None, + exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]] = Exception, + tries: int = -1, + delay: float = 0, + max_delay: Optional[float] = None, + backoff: float = 1, + jitter: float = 0, + logger: Optional[logging.Logger] = logging_logger, +): + """Calls a function and re-executes it if it failed. + + Parameters + ---------- + f: + the function to execute. + fargs: + the positional arguments of the function to execute. + fkwargs: + the named arguments of the function to execute. + exceptions: + an exception or a tuple of exceptions to catch. default: Exception. + tries: + the maximum number of attempts. default: -1 (infinite). + delay: + initial delay between attempts. default: 0. + max_delay: + the maximum value of delay. default: None (no limit). + backoff: + multiplier applied to delay between attempts. default: 1 (no backoff). + jitter: + extra seconds added to delay between attempts. default: 0. + fixed if a number, random if a range tuple (min, max) + logger: + logger.warning(fmt, error, delay) will be called on failed attempts. 
+ default: retry.logging_logger. if None, logging is disabled. + + Returns + -------- + the result of the f function. + """ + args = fargs if fargs is not None else list() + kwargs = fkwargs if fkwargs is not None else dict() + partialed = _named_partial(f, *args, **kwargs) + + return __retry_internal( + f=partialed, + exceptions=exceptions, + tries=tries, + delay=delay, + max_delay=max_delay, + backoff=backoff, + jitter=jitter, + logger=logger, + ) + + +def _named_partial(f: Callable[..., Any], *args, **kwargs) -> Callable[..., Any]: + partialed = partial(f, *args, **kwargs) + setattr(partialed, "__name__", getattr(f, "__name__", str(f))) + return partialed diff --git a/src/tests/test_client.py b/src/tests/test_client.py new file mode 100644 index 0000000..ed624eb --- /dev/null +++ b/src/tests/test_client.py @@ -0,0 +1,116 @@ +import io +import itertools +import os +from unittest.mock import MagicMock + +import pytest + +from airtrain.client import ( + API_KEY_ENV_VAR, + AirtrainClient, + AuthenticationError, + BadRequestError, + NotFoundError, + ServerError, + client, + set_api_key, + _buffer_to_byte_iterable, +) +from tests.utils import environment_variables + + +def test_set_api_key(): + client.cache_clear() + + # make sure no existing env var is clouding the test. + assert os.environ.get(API_KEY_ENV_VAR) is None + with pytest.raises(AuthenticationError): + client() + + with environment_variables({API_KEY_ENV_VAR: "foo"}): + c = client() + assert isinstance(c, AirtrainClient) + assert c._api_key == "foo" + + client.cache_clear() + with pytest.raises(AuthenticationError): + client() + + set_api_key("bar") + c = client() + assert isinstance(c, AirtrainClient) + assert c._api_key == "bar" + + set_api_key("baz") + assert c._api_key == "baz" + + +def test_default_client(): + client.cache_clear() + set_api_key("secret") + + assert client() is client() + + +def test_buffer_to_byte_iterable(): + hex_digits = "0123456789abcdef" + + # some non-repeating byte data that should be larger than the byte buffer + original_bytes = bytes.fromhex( + "".join( + "".join(p) + for p in itertools.islice(itertools.permutations(hex_digits, 16), 2**12) + ) + ) + + buffer_iter = _buffer_to_byte_iterable(io.BytesIO(original_bytes)) + + counter = itertools.count() + read_bytes = b"".join(buffer for buffer, _ in zip(buffer_iter, counter)) + assert read_bytes == original_bytes + + # ensure we had to use more than 1 buffer. 
+ assert next(counter) > 2 + + +def test_handle_response(): + c = AirtrainClient(api_key="secret") + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"data": 42} + assert c._handle_response(mock_response, expect_json=True) == 42 + + mock_response.reset_mock() + mock_response.status_code = 200 + mock_response.json.return_value = 42 + with pytest.raises(ServerError): + c._handle_response(mock_response, expect_json=True) + assert c._handle_response(mock_response, expect_json=False) is None + + mock_response.status_code = 404 + mock_response.json.return_value = { + "data": 42, + "errorMessage": "These are not the droids you're looking for", + } + with pytest.raises(NotFoundError, match=r".*not the droids.*"): + c._handle_response(mock_response, expect_json=True) + + mock_response.status_code = 401 + with pytest.raises(AuthenticationError): + c._handle_response(mock_response, expect_json=True) + + mock_response.status_code = 400 + with pytest.raises(BadRequestError): + c._handle_response(mock_response, expect_json=True) + + mock_response.status_code = 500 + with pytest.raises(ServerError): + c._handle_response(mock_response, expect_json=True) + + mock_response.reset_mock() + mock_response.status_code = 200 + mock_response.json.side_effect = ValueError("A problem has be to your computer") + with pytest.raises(ServerError): + c._handle_response(mock_response, expect_json=True) + c._handle_response(mock_response, expect_json=False) diff --git a/src/tests/utils.py b/src/tests/utils.py new file mode 100644 index 0000000..292c5e3 --- /dev/null +++ b/src/tests/utils.py @@ -0,0 +1,35 @@ +import contextlib +import os +from typing import Dict, Optional + + +@contextlib.contextmanager +def environment_variables(to_set: Dict[str, Optional[str]]): + """ + Context manager to configure the os environ. + + After exiting the context, the original env vars will be back in place. + + Parameters + ---------- + to_set: + A dict from env var name to env var value. If the env var value is None, that will + be treated as indicating that the env var should be unset within the managed + context. + """ + backup_of_changed_keys = {k: os.environ.get(k, None) for k in to_set.keys()} + + def update_environ_with(env_dict): + for key, value in env_dict.items(): + if value is None: + if key in os.environ: + del os.environ[key] + else: + os.environ[key] = value + + update_environ_with(to_set) + + try: + yield + finally: + update_environ_with(backup_of_changed_keys) diff --git a/uv.lock b/uv.lock index 30fd063..c696e94 100644 --- a/uv.lock +++ b/uv.lock @@ -28,7 +28,6 @@ dev = [ [package.metadata] requires-dist = [ - { name = "airtrain", editable = "." }, { name = "httpx", specifier = ">=0.25.0" }, { name = "numpy", marker = "python_full_version == '3.8.*'", specifier = "<=1.24.4" }, { name = "numpy", marker = "python_full_version >= '3.9'", specifier = ">=1.19.3" },
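
Retry helper usage
------------------

`src/airtrain/retry.py` is vendored here (adapted from invl/retry) but is not yet wired into `client.py` by this patch. As a rough sketch of how a later change might apply it to the client wrapper (the function name and parameter values below are illustrative only, not part of this diff):

```python
from airtrain.client import CreateDatasetResponse, ServerError, client
from airtrain.retry import retry


@retry(exceptions=ServerError, tries=3, delay=1.0, backoff=2, jitter=0.5)
def create_dataset_with_retries(name: str) -> CreateDatasetResponse:
    # Only ServerError (raised for server-side failures and malformed responses)
    # is retried; BadRequestError, NotFoundError, and AuthenticationError
    # propagate to the caller immediately.
    return client().create_dataset(name=name, embedding_column_name=None)


response = create_dataset_with_retries("From retry sketch")
print(response.dataset_id, response.row_limit)
```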