Skip to content

Commit

Permalink
Add table tests
Browse files Browse the repository at this point in the history
  • Loading branch information
augray committed Sep 1, 2024
1 parent 07e12f9 commit d139041
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 6 deletions.
10 changes: 5 additions & 5 deletions src/airtrain/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,18 @@
from airtrain.client import client


# Unpack is only >3.11 . We'll just rely on type
# checking in those versions to catch mistakes.
# This will make Unpack[CreationArgs] into
# Optional[Any] for lower versions, which should
# pass checks.
if sys.version_info > (3, 11):
from typing import TypedDict, Unpack

class CreationArgs(TypedDict):
name: Optional[str]
embedding_column: Optional[str]
else:
# Unpack is only >3.11 . We'll just rely on type
# checking in those versions to catch mistakes.
# This will make Unpack[CreationArgs] into
# Optional[Any] for lower versions, which should
# pass checks.
from typing import Optional as Unpack # noqa
from typing import Any as CreationArgs # noqa

Expand Down
22 changes: 21 additions & 1 deletion src/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pyarrow as pa
import pytest

from airtrain.core import DatasetMetadata, upload_from_dicts
from airtrain.core import DatasetMetadata, upload_from_arrow_tables, upload_from_dicts
from tests.fixtures import MockAirtrainClient, mock_client # noqa: F401


Expand Down Expand Up @@ -139,3 +139,23 @@ def test_bad_embeds(mock_client: MockAirtrainClient): # noqa: F811
]
with pytest.raises(ValueError):
upload_from_dicts(data, embedding_column="bar")


def test_upload_from_arrow_tables(mock_client: MockAirtrainClient): # noqa: F811
table_1 = pa.table({"foo": [1, 2, 3], "bar": ["a", "b", "c"]})
table_2 = pa.table({"foo": [4, 5, 6], "bar": ["d", "e", "f"]})
uploaded = upload_from_arrow_tables([table_1, table_2], name="My Arrow")
fake_dataset = mock_client.get_fake_dataset(uploaded.id)
table = fake_dataset.ingested
assert table is not None
assert table.shape[0] == table_1.shape[0] + table_2.shape[0]
assert table["foo"].to_pylist() == [1, 2, 3, 4, 5, 6]
assert table["bar"].to_pylist() == ["a", "b", "c", "d", "e", "f"]


def test_upload_from_mismatched_tables(mock_client: MockAirtrainClient): # noqa: F811
table_1 = pa.table({"foo": [1, 2, 3], "bar": ["a", "b", "c"]})
table_2 = pa.table({"foo": ["d", "e", "f"], "bar": [4, 5, 6]})

with pytest.raises(ValueError):
upload_from_arrow_tables([table_1, table_2], name="My Arrow")

0 comments on commit d139041

Please sign in to comment.