
Basic driver working
augray committed Aug 29, 2024
1 parent 5cf2d3d commit c4f4432
Showing 5 changed files with 169 additions and 9 deletions.
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -76,3 +76,7 @@ enabled = true
[[tool.mypy.overrides]]
module = "airtrain.*"
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = "pyarrow.*"
ignore_missing_imports = true
6 changes: 5 additions & 1 deletion src/airtrain/__init__.py
@@ -1,2 +1,6 @@
from airtrain.client import set_api_key # noqa: F401
-from airtrain.core import DatasetMetadata, upload_from_dicts # noqa: F401
+from airtrain.core import ( # noqa: F401
+    DatasetMetadata,
+    upload_from_arrow_tables,
+    upload_from_dicts,
+)
8 changes: 8 additions & 0 deletions src/airtrain/client.py
@@ -84,6 +84,14 @@ def __init__(
"function airtrain.set_api_key"
)

def dataset_dashboard_url(self, dataset_id: str) -> str:
"""Get the webapp URL for a dataset, given its id."""
api_url = self._base_url
app_url = api_url.replace("://api.dev", "://airtrain.dev").replace(
"://api.", "://app."
)
return f"{app_url}/dataset/{dataset_id}"

def trigger_dataset_ingest(self, dataset_id: str) -> TriggerIngestResponse:
"""Wraps: POST /dataset/[id]/ingest"""
response = self._post_json(url_path=f"dataset/{dataset_id}/ingest", content={})
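The new `dataset_dashboard_url` helper derives the webapp URL from the client's API base URL by simple string substitution. A minimal sketch of that mapping, using a hypothetical base URL (the client's real base URL is not shown in this diff):

```python
# Illustration only: mirrors the substitution performed by dataset_dashboard_url.
# "https://api.example-airtrain.com" is a placeholder, not the client's real default.
api_url = "https://api.example-airtrain.com"
app_url = api_url.replace("://api.dev", "://airtrain.dev").replace("://api.", "://app.")
print(f"{app_url}/dataset/some-dataset-id")
# -> https://app.example-airtrain.com/dataset/some-dataset-id
```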
151 changes: 146 additions & 5 deletions src/airtrain/core.py
@@ -1,5 +1,36 @@
import io
import logging
import sys
from collections import defaultdict
from dataclasses import dataclass, fields
-from typing import Any, Dict, Iterable, Union
+from datetime import datetime
+from itertools import islice
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, TypeVar, Union

import pyarrow as pa
import pyarrow.parquet as pq

from airtrain.client import client


if sys.version_info[:2] > (3, 11):
from typing import TypedDict, Unpack

class CreationArgs(TypedDict):
name: Optional[str]
embedding_column: Optional[str]

CreationKeywordArgs = Unpack[CreationArgs]
else:
# Unpack is only available on Python > 3.11. On older versions we fall
# back to Any and rely on type checking under newer Pythons to catch mistakes.
CreationKeywordArgs = Any


logger = logging.getLogger(__name__)


_MAX_BATCH_SIZE: int = 2000


@dataclass
@@ -20,12 +51,122 @@ def __post_init__(self) -> None:

def upload_from_dicts(
data: Iterable[Dict[str, Any]],
schema: Optional[pa.Schema] = None,
**kwargs: CreationKeywordArgs,
) -> DatasetMetadata:
"""Upload an Airtrain dataset from the provided dictionaries.
Parameters
----------
data:
An iterable of dictionary data to construct an Airtrain dataset out of.
Each row in the data must be a python dictionary, and use only types
that can be converted into pyarrow types. Data will be intermediately
represented as pyarrow tables.
schema:
Optionally, the Arrow schema the data conforms to. If not provided, the
schema will be inferred from a sample of the data.
kwargs:
See `upload_from_arrow_tables` for other arguments.
Returns
-------
A DatasetMetadata object summarizing the created dataset.
"""
data = iter(data) # to ensure itertools works even if it was a list, etc.
batches = _batched(data, _MAX_BATCH_SIZE)
return upload_from_arrow_tables(
data=_dict_batches_to_tables(batches, schema),
**kwargs,
)


def upload_from_arrow_tables(
data: Iterable[pa.Table],
name: Union[str, None] = None,
embedding_column: Union[str, None] = None,
) -> DatasetMetadata:
name = name or f"My Dataset {datetime.now()}"
c = client()
creation_call_result = c.create_dataset(
name=name, embedding_column_name=embedding_column
)
limit = creation_call_result.row_limit
dataset_id = creation_call_result.dataset_id
size = 0
schema: Optional[pa.Schema] = None

for table in data:
if schema is None:
schema = table.schema
if schema != table.schema:
logger.error("Mismatched schemas:\n%s\n\n%s", schema, table.schema)
raise ValueError("All uploaded tables must have the same schema.")
if embedding_column is not None and embedding_column not in schema.names:
# TODO: validate embedding schema
raise ValueError(
f"No column '{embedding_column}' found in data for embeddings"
)
table = table[: limit - size]

upload_buffer = io.BytesIO()
pq.write_table(table, upload_buffer)
upload_buffer.seek(0)
c.upload_dataset_data(dataset_id, upload_buffer)
size += table.shape[0]

if size >= limit:
break

c.trigger_dataset_ingest(dataset_id)
return DatasetMetadata(
-    name=name or "My Dataset",
-    id="abc123",
-    url="https://example.com",
-    size=0,
+    name=name,
+    id=dataset_id,
+    url=c.dataset_dashboard_url(dataset_id),
+    size=size,
)


T = TypeVar("T")


def _dict_batches_to_tables(
batches: Iterable[Tuple[Dict[str, Any], ...]], schema: Optional[pa.Schema] = None
) -> Iterable[pa.Table]:
for batch in batches:
table = _dicts_to_table(batch, schema)
if schema is None:
# ensure later batches use the same schema.
schema = table.schema
yield table


def _dicts_to_table(
dicts: Tuple[Dict[str, Any], ...], schema: Optional[pa.Schema]
) -> pa.Table:
columns: Set[str] = set()
for row in dicts:
if not isinstance(row, dict):
logger.error("Unexpected row: %s", row)
raise ValueError("All data rows must be python dicts.")
columns.update(row.keys())

table_dict: Dict[str, List[Any]] = defaultdict(list)
for row in dicts:
for column in columns:
table_dict[column].append(row.get(column))
return pa.table(table_dict, schema=schema)


# This is in the standard lib in itertools as of 3.12; this code
# is adapted from documentation there.
def _batched(iterable: Iterable[T], n: int) -> Iterable[Tuple[T, ...]]:
if n < 1:
raise ValueError("n must be at least one")
iterator = iter(iterable)
batch: Tuple[T, ...] = ()
while True:
batch = tuple(islice(iterator, n))
if len(batch) == 0:
break
yield batch
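Taken together, the new driver batches plain dicts into Arrow tables and streams them to the API as parquet. A usage sketch based only on the signatures in this diff (the API key value, dataset names, and column names below are placeholders):

```python
import airtrain
import pyarrow as pa

# Placeholder key; set_api_key is exported from airtrain (see client.py).
airtrain.set_api_key("YOUR_API_KEY")

# Upload from plain dicts; rows are batched into pyarrow tables internally.
meta = airtrain.upload_from_dicts(
    [{"text": "hello", "label": 1}, {"text": "world", "label": 0}],
    name="My demo dataset",
)
print(meta.url)  # dashboard URL returned in the DatasetMetadata

# Or upload pre-built Arrow tables directly.
table = pa.table({"text": ["hello", "world"], "label": [1, 0]})
meta = airtrain.upload_from_arrow_tables([table], name="My arrow dataset")
```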
9 changes: 6 additions & 3 deletions src/tests/test_core.py
@@ -1,6 +1,9 @@
-from airtrain.core import DatasetMetadata, upload_from_dicts
+import pytest
+
+from airtrain.client import AuthenticationError
+from airtrain.core import upload_from_dicts


def test_upload_from_dicts():
-    result = upload_from_dicts([{"foo": 42}, {"foo": 43}], name="Foo dataset")
-    assert isinstance(result, DatasetMetadata)
+    with pytest.raises(AuthenticationError):
+        upload_from_dicts([{"foo": 42}, {"foo": 43}], name="Foo dataset")
