Skip to content

Commit

Permalink
Add http API wrapper (#2)
Browse files Browse the repository at this point in the history
Add a wrapper of our HTTP API for the rest of the SDK to use.


Testing
--------

In addition to the unit tests, this was tested with the following driver code
while pointing at the dev Airtrain deployment (remember, SDK users will
NOT be using this API directly; this is just to test the wrapper):

```python
import io

import pyarrow.json as jsonl
import pyarrow.parquet as pq

from airtrain.client import client


def main() -> None:
    """Create a dataset, upload two rows to it as parquet, and trigger ingestion."""
    # Two newline-delimited JSON records, held entirely in memory.
    rows_jsonl = io.BytesIO(b'{"text": "hello"}\n{"text": "world"}')

    # Convert the JSONL rows to parquet, then rewind so the upload starts at byte 0.
    parquet_bytes = io.BytesIO()
    pq.write_table(jsonl.read_json(rows_jsonl), parquet_bytes)
    parquet_bytes.seek(0)

    api = client()
    created = api.create_dataset(name="From API Client", embedding_column_name=None)
    api.upload_dataset_data(created.dataset_id, parquet_bytes)
    ingest = api.trigger_dataset_ingest(created.dataset_id)

    print(
        f"Created dataset {created.dataset_id} "
        f"with ingest job {ingest.ingest_job_id}"
    )


if __name__ == "__main__":
    main()
```
  • Loading branch information
augray authored Sep 4, 2024
1 parent b4140df commit 877f90d
Show file tree
Hide file tree
Showing 10 changed files with 636 additions and 7 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
tmp.*

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
10 changes: 10 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,17 @@ PY_VERSION := "3.11"

.PHONY: wheel
wheel:
rm -rf dist build src/*.egg-info
uvx pip wheel -w dist .
.PHONY: test-release
test-release: wheel
uvx twine check dist/*airtrain*.whl
uvx twine upload --repository testpypi dist/*airtrain*.whl

.PHONY: release
release: wheel
uvx twine check dist/*airtrain*.whl
uvx twine upload dist/*airtrain*.whl

.PHONY: test-release
test-release: wheel
Expand Down
5 changes: 1 addition & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,6 @@
<a href="https://docs.python.org/3.12/" target="_blank">
<img height="30px" src="https://img.shields.io/badge/Python-3.12-blue?style=for-the-badge&logo=python" alt="Python Version">
</a>
<a href="https://github.com/sematic-ai/py-airtrain/actions/workflows/ci.yaml?query=branch%3Amain+" target="_blank">
<img height="30px" src="https://github.com/sematic-ai/py-airtrain/actions/workflows/ci.yaml/badge.svg?branch=main" alt="CI status">
</a>
</p>


Expand All @@ -51,7 +48,7 @@ Then you may upload a new dataset as follows:
import airtrain as at

# Can also be set with the environment variable AIRTRAIN_API_KEY
at.api_key = "sUpErSeCr3t"
at.set_api_key("sUpErSeCr3t")

url = at.upload_from_dicts(
[
Expand Down
7 changes: 5 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ dependencies = [
"numpy>=1.26.0; python_version >= '3.12'",
"numpy<=1.24.4; python_version == '3.8'",
"numpy>=1.19.3; python_version >= '3.9'",

"airtrain",
]

classifiers = [
Expand Down Expand Up @@ -55,6 +53,11 @@ Documentation = "https://docs.airtrain.ai/"
requires = ["setuptools", "setuptools-scm"]
build-backend = "setuptools.build_meta"

[tool.setuptools.packages.find]
where = ["src"]
include = ["airtrain*"]
exclude = ["tests*", "*.tests*"]

[tool.ruff]
line-length = 90

Expand Down
1 change: 1 addition & 0 deletions src/airtrain/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from airtrain.client import set_api_key # noqa: F401
from airtrain.core import DatasetMetadata, upload_from_dicts # noqa: F401
244 changes: 244 additions & 0 deletions src/airtrain/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
import io
import logging
import os
from dataclasses import dataclass
from functools import lru_cache
from typing import Any, Dict, Iterable, Optional

import httpx


logger = logging.getLogger(__name__)


# JSON object bodies: RequestJson is what _post_json sends, ResponseJson is what
# the response handling hands back.
RequestJson = Dict[str, Any]
ResponseJson = Dict[str, Any]


# Name of the environment variable read for an API key when none was set explicitly.
API_KEY_ENV_VAR: str = "AIRTRAIN_API_KEY"
# Key stored by set_api_key(); when not None it takes precedence over the env var.
_DEFAULT_API_KEY: Optional[str] = None
# Base URL used when no base_url argument and no AIRTRAIN_API_URL env var is given.
_DEFAULT_BASE_URL: str = "https://api.airtrain.ai"
# Size passed to buffer.read() for each chunk when streaming an upload.
_BUFFER_CHUNK_SIZE = 8192


class BadRequestError(Exception):
    """Raised when the API rejects a request as invalid (a 4xx-class failure).

    Base class for the more specific client-side errors in this module.
    """


class AuthenticationError(BadRequestError):
    """Raised when the caller lacks the credentials or permission for a request."""


class NotFoundError(BadRequestError):
    """Raised when the requested resource does not exist."""


class ServerError(Exception):
    """Raised when the server fails or replies in a way the client did not expect."""


@dataclass
class CreateDatasetResponse:
    """Parsed result of the dataset-creation call."""

    # Identifier the server assigned to the new dataset ("datasetId" in the response).
    dataset_id: str
    # Row limit reported by the server for this dataset ("rowLimit" in the response).
    row_limit: int


@dataclass
class TriggerIngestResponse:
    """Parsed result of the ingest-trigger call."""

    # Identifier of the ingestion job ("ingestionJobId" in the response).
    ingest_job_id: str


class AirtrainClient:
    """A direct wrapper around Airtrain's HTTP API. Intended for internal package use.

    SDK users should NOT use this class directly and should not assume it will have a
    stable API. It makes no attempt to make sure calls are sequenced in an appropriate
    order or that information is passed between calls in a logical way. There should be
    a direct 1:1 correspondence between methods here and API endpoints that the SDK
    needs to interact with.
    """

    def __init__(
        self, api_key: Optional[str] = None, base_url: Optional[str] = None
    ) -> None:
        """Resolve credentials and the base URL, then open the HTTP client.

        Args:
            api_key: Key to authenticate with. When omitted, the key set via
                set_api_key or the AIRTRAIN_API_KEY environment variable is used.
            base_url: Root URL of the API. When omitted, the AIRTRAIN_API_URL
                environment variable is used, falling back to the built-in default.

        Raises:
            AuthenticationError: If no API key can be found.
        """
        resolved_key = api_key or _find_api_key()
        # Validate before creating the HTTP client so that a missing key does not
        # leave an unclosed httpx.Client behind.
        if resolved_key is None:
            raise AuthenticationError(
                "No Api key found. "
                "Set one with the environment variable 'AIRTRAIN_API_KEY' or the "
                "function airtrain.set_api_key"
            )
        self._api_key: str = resolved_key

        # An empty AIRTRAIN_API_URL is treated as unset rather than as an empty URL.
        resolved_url = (
            base_url or os.environ.get("AIRTRAIN_API_URL") or _DEFAULT_BASE_URL
        )
        # _full_url adds the separator, so drop any trailing slashes here.
        self._base_url: str = resolved_url.rstrip("/")
        self._http_client = httpx.Client()

    def trigger_dataset_ingest(self, dataset_id: str) -> TriggerIngestResponse:
        """Wraps: POST /dataset/[id]/ingest

        Raises:
            ServerError: If the response lacks a string "ingestionJobId".
        """
        response = self._post_json(url_path=f"dataset/{dataset_id}/ingest", content={})
        job_id = response.get("ingestionJobId")
        if not isinstance(job_id, str):
            raise ServerError(f"Malformed response: {response}")
        return TriggerIngestResponse(ingest_job_id=job_id)

    def create_dataset(
        self, name: str, embedding_column_name: Optional[str]
    ) -> CreateDatasetResponse:
        """Wraps: POST /dataset

        Raises:
            ServerError: If the response lacks a string "datasetId" or an integer
                "rowLimit".
        """
        response = self._post_json(
            "dataset", {"name": name, "embeddingColumn": embedding_column_name}
        )
        dataset_id = response.get("datasetId")
        row_limit = response.get("rowLimit")

        if not (isinstance(dataset_id, str) and isinstance(row_limit, int)):
            raise ServerError(f"Malformed response: {response}")
        return CreateDatasetResponse(dataset_id=dataset_id, row_limit=row_limit)

    def upload_dataset_data(self, dataset_id: str, data: io.BufferedIOBase) -> None:
        """Wraps: PUT /dataset/[id]/source

        The buffer is read from its current position and sent as parquet.
        """
        self._put_bytes(
            url_path=f"dataset/{dataset_id}/source",
            content=data,
            params={"format": "parquet"},
        )

    def _auth_headers(self) -> Dict[str, str]:
        """Build the Authorization header from the current API key."""
        return {"Authorization": f"Bearer {self._api_key}"}

    def _post_json(
        self, url_path: str, content: RequestJson, params: Optional[Dict[str, str]] = None
    ) -> ResponseJson:
        """POST a JSON body and return the "data" object of the JSON response.

        Raises:
            BadRequestError: For 4xx responses (including its subclasses
                AuthenticationError and NotFoundError).
            ServerError: For other non-2xx responses, or when the payload is not
                the expected JSON object.
        """
        response = self._http_client.post(
            self._full_url(url_path),
            headers=self._auth_headers(),
            json=content,
            params=params,
        )
        data = self._handle_response(response, expect_json=True)
        # Callers use dict methods on the result; reject a null or non-object "data"
        # with a ServerError instead of relying on an assert.
        if not isinstance(data, dict):
            logger.error("Malformed response text:\n%s", response.text)
            raise ServerError("Malformed server response.")
        return data

    def _put_bytes(
        self,
        url_path: str,
        content: io.BufferedIOBase,
        params: Optional[Dict[str, str]] = None,
    ) -> None:
        """Upload a byte stream via the redirect that the API is expected to return.

        The first PUT carries only dummy data and must be answered with a redirect;
        the real content is then streamed to the redirect target.

        Raises:
            BadRequestError: If either request gets a 4xx response (including its
                subclasses AuthenticationError and NotFoundError).
            ServerError: If the first request is not redirected, or the upload
                gets any other non-2xx response.
        """
        headers = self._auth_headers()
        headers["Content-Type"] = "application/octet-stream"

        response = self._http_client.put(
            self._full_url(url_path),
            headers=headers,
            content=iter([b""]),  # send some dummy data to not consume the stream
            params=params,
            follow_redirects=False,
        )
        redirect = response.next_request
        if redirect is None:
            logger.error("Response text:\n%s", response.text)
            if 400 <= response.status_code < 500:
                # Report client errors (bad key, unknown dataset, ...) with the same
                # exception types as every other call; this always raises for 4xx.
                self._handle_response(response, expect_json=False)
            raise ServerError(f"Expected redirect but got: {response.status_code}")

        response = self._http_client.put(
            redirect.url,
            headers=redirect.headers,
            content=_buffer_to_byte_iterable(content),
            follow_redirects=False,
        )
        self._handle_response(response, expect_json=False)

    def _handle_response(
        self, response: httpx.Response, expect_json: bool
    ) -> Optional[ResponseJson]:
        """Turn an HTTP response into a payload or the matching exception.

        Args:
            response: The response to inspect.
            expect_json: Whether a JSON object with a "data" key is required.

        Returns:
            The value under "data" when expect_json is true, otherwise None.

        Raises:
            AuthenticationError: For 401 and 403.
            NotFoundError: For 404.
            BadRequestError: For any other 4xx.
            ServerError: For any other non-2xx status, or when expect_json is true
                and the body is not a JSON object containing "data".
        """
        response_json: Optional[ResponseJson] = None
        try:
            response_json = response.json()
        except Exception as e:
            # raise more appropriate error below, but log content here
            # to help debug.
            logger.debug("Response did not contain json. %s", e)

        error_message: Optional[str] = None
        if isinstance(response_json, dict):
            # Prefer the display-oriented message when the server provides one.
            error_message = response_json.get(
                "errorMessageDisplay"
            ) or response_json.get("errorMessage")

        status_code = response.status_code
        request_kind = response.request.method
        url_path = response.request.url.path
        base_message = (
            error_message or f"Got '{status_code}' from {request_kind} to {url_path}"
        )

        if status_code in (401, 403):
            logger.error("Authentication error response text:\n%s", response.text)
            raise AuthenticationError(f"You may not have access. {base_message}")
        if status_code == 404:
            raise NotFoundError(f"The resource may not exist. {base_message}")
        if 400 <= status_code < 500:
            raise BadRequestError(f"Bad Request. {base_message}")
        if status_code // 100 != 2:
            logger.error("Server error response text:\n%s", response.text)
            # Consider 100s, 300s, 500s to all be server errors because they are not
            # expected from the API.
            raise ServerError(f"Server error. {base_message}")

        if not expect_json:
            return None

        if not isinstance(response_json, dict) or "data" not in response_json:
            # All our json APIs return dicts.
            logger.error("Malformed response text:\n%s", response.text)
            raise ServerError("Malformed server response.")
        return response_json["data"]

    def _full_url(self, url_path: str) -> str:
        """Join the base URL and an endpoint path with exactly one slash."""
        return f"{self._base_url}/{url_path.lstrip('/')}"


@lru_cache(maxsize=1)
def client() -> AirtrainClient:
    """Return the shared default Airtrain client (internal API, advanced usage).

    The client is built on the first call; lru_cache then hands back that same
    instance on every later call.
    """
    default_client = AirtrainClient()
    return default_client


def _find_api_key() -> Optional[str]:
    """Look up the API key: the explicitly set key first, then the environment.

    Returns None when neither source provides a key.
    """
    explicit_key = _DEFAULT_API_KEY
    return os.environ.get(API_KEY_ENV_VAR) if explicit_key is None else explicit_key


def set_api_key(api_key: str) -> None:
    """Explicitly set the API key for the default client.

    The key is stored at module level so a default client created later picks it
    up. A default client that already exists is updated in place.

    Args:
        api_key: The key to use for subsequent requests.

    Raises:
        AuthenticationError: If api_key is None.
    """
    global _DEFAULT_API_KEY
    if api_key is None:
        raise AuthenticationError("Invalid API key; must not be None")
    _DEFAULT_API_KEY = api_key

    # Only touch the default client if it has already been built. Calling client()
    # unconditionally would create it (and its HTTP connection pool) as a side
    # effect of configuration and fix its base URL earlier than necessary.
    if client.cache_info().currsize:
        client()._api_key = api_key


def _buffer_to_byte_iterable(buffer: io.BufferedIOBase) -> Iterable[bytes]:
    """Lazily yield the buffer's remaining content in fixed-size chunks.

    Reading starts at the buffer's current position and stops at the first empty
    read, so the whole payload is never held in memory at once.
    """
    while chunk := buffer.read(_BUFFER_CHUNK_SIZE):
        yield chunk
Loading

0 comments on commit 877f90d

Please sign in to comment.