Commit

move functional tests out of unit test suite (#832)
* move tests using cloud_test_catalog into func directory

* move tests using tmpfile catalog

* move long running tests that read/write from disk
mattseddon authored Jan 20, 2025
1 parent 2c9c71f commit d4698fe
Showing 17 changed files with 587 additions and 544 deletions.
96 changes: 96 additions & 0 deletions tests/func/test_client.py
@@ -1,12 +1,22 @@
import asyncio
import os
import sys
from pathlib import Path

import pytest
from fsspec.asyn import sync
from hypothesis import HealthCheck, assume, given, settings
from hypothesis import strategies as st
from tqdm import tqdm

from datachain.asyn import get_loop
from datachain.client import Client
from tests.data import ENTRIES
from tests.utils import uppercase_scheme

_non_null_text = st.text(
alphabet=st.characters(blacklist_categories=["Cc", "Cs"]), min_size=1
)


@pytest.fixture
@@ -91,3 +101,89 @@ def test_fetch_dir_does_not_return_self(client, cloud_type):
)

assert "directory" not in subdirs


@settings(suppress_health_check=[HealthCheck.function_scoped_fixture], deadline=None)
@given(rel_path=_non_null_text)
def test_parse_url(cloud_test_catalog, rel_path, cloud_type):
if cloud_type == "file":
assume(not rel_path.startswith("/"))
bucket_uri = cloud_test_catalog.src_uri
url = f"{bucket_uri}/{rel_path}"
uri, rel_part = Client.parse_url(url)
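    # for local paths, parse_url splits into parent directory and filename;
    # for cloud URIs, into bucket URI and relative path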
if cloud_type == "file":
assert uri == url.rsplit("/", 1)[0]
assert rel_part == url.rsplit("/", 1)[1]
else:
assert uri == bucket_uri
assert rel_part == rel_path


@settings(suppress_health_check=[HealthCheck.function_scoped_fixture], deadline=None)
@given(rel_path=_non_null_text)
def test_get_client(cloud_test_catalog, rel_path, cloud_type):
catalog = cloud_test_catalog.catalog
bucket_uri = cloud_test_catalog.src_uri
url = f"{bucket_uri}/{rel_path}"
client = Client.get_client(url, catalog.cache)
assert client
assert client.uri


@settings(suppress_health_check=[HealthCheck.function_scoped_fixture], deadline=None)
@given(rel_path=_non_null_text)
def test_parse_url_uppercase_scheme(cloud_test_catalog, rel_path, cloud_type):
if cloud_type == "file":
assume(not rel_path.startswith("/"))
bucket_uri = cloud_test_catalog.src_uri
bucket_uri_upper = uppercase_scheme(bucket_uri)
url = f"{bucket_uri_upper}/{rel_path}"
uri, rel_part = Client.parse_url(url)
if cloud_type == "file":
url = f"{bucket_uri}/{rel_path}"
assert uri == url.rsplit("/", 1)[0]
assert rel_part == url.rsplit("/", 1)[1]
else:
assert uri == bucket_uri
assert rel_part == rel_path


@pytest.mark.parametrize("cloud_type", ["file"], indirect=True)
def test_parse_file_absolute_path_without_protocol(cloud_test_catalog):
working_dir = Path().absolute()
uri, rel_part = Client.parse_url(str(working_dir / Path("animals")))
assert uri == working_dir.as_uri()
assert rel_part == "animals"


@pytest.mark.parametrize("cloud_type", ["file"], indirect=True)
def test_parse_file_relative_path_multiple_dirs_back(cloud_test_catalog):
uri, rel_part = Client.parse_url("../../animals".replace("/", os.sep))
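    # parents[1] is two levels above the working directory, matching "../.."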
assert uri == Path().absolute().parents[1].as_uri()
assert rel_part == "animals"


@pytest.mark.parametrize("cloud_type", ["file"], indirect=True)
@pytest.mark.parametrize("url", ["./animals".replace("/", os.sep), "animals"])
def test_parse_file_relative_path_working_dir(cloud_test_catalog, url):
uri, rel_part = Client.parse_url(url)
assert uri == Path().absolute().as_uri()
assert rel_part == "animals"


@pytest.mark.parametrize("cloud_type", ["file"], indirect=True)
def test_parse_file_relative_path_home_dir(cloud_test_catalog):
    if sys.platform == "win32":
        # home dir shortcut (~) is not available on Windows
        pytest.skip()
uri, rel_part = Client.parse_url("~/animals")
assert uri == Path().home().as_uri()
assert rel_part == "animals"


@pytest.mark.parametrize("cloud_type", ["s3", "azure", "gs"], indirect=True)
def test_parse_cloud_path_ends_with_slash(cloud_test_catalog):
uri = f"{cloud_test_catalog.src_uri}/animals/"
uri, rel_part = Client.parse_url(uri)
assert uri == cloud_test_catalog.src_uri
assert rel_part == "animals/"
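
For reference, the tests above rely on the uppercase_scheme helper imported
from tests.utils. A minimal sketch of what that helper is assumed to do
(uppercase only the URL scheme, leaving the rest of the URI intact):

def uppercase_scheme(uri: str) -> str:
    # "s3://bucket/key" -> "S3://bucket/key"
    scheme, rest = uri.split("://", 1)
    return f"{scheme.upper()}://{rest}"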
160 changes: 160 additions & 0 deletions tests/func/test_data_storage.py
@@ -0,0 +1,160 @@
from datetime import datetime
from typing import Any

import pytest

from datachain.sql.types import (
JSON,
Array,
Boolean,
DateTime,
Float,
Float32,
Float64,
Int,
String,
)
from tests.utils import (
DEFAULT_TREE,
TARRED_TREE,
create_tar_dataset_with_legacy_columns,
)

COMPLEX_TREE: dict[str, Any] = {
**TARRED_TREE,
**DEFAULT_TREE,
"nested": {"dir": {"path": {"abc.txt": "abc"}}},
}


@pytest.mark.parametrize("tree", [COMPLEX_TREE], indirect=True)
def test_dir_expansion(cloud_test_catalog, version_aware, cloud_type):
has_version = version_aware or cloud_type == "gs"

ctc = cloud_test_catalog
session = ctc.session
catalog = ctc.catalog
src_uri = ctc.src_uri
    if cloud_type == "file":
        # we don't want to index things in the parent directory
        src_uri += "/"

    dc = create_tar_dataset_with_legacy_columns(session, src_uri, "dc")
dataset = catalog.get_dataset(dc.name)
with catalog.warehouse.clone() as warehouse:
dr = warehouse.dataset_rows(dataset, object_name="file")
de = dr.dir_expansion()
q = de.query(dr.get_table())

columns = (
"id",
"is_dir",
"source",
"path",
"version",
"location",
)

result = [dict(zip(columns, r)) for r in warehouse.db.execute(q)]
to_compare = [(r["path"], r["is_dir"], r["version"] != "") for r in result]

assert all(r["source"] == ctc.src_uri for r in result)

    # Note: we have both a file and a directory entry for expanded tar files
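    # Directory rows never carry a version, so their flag below is False;
    # file rows expect one only when has_version is true.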
expected = [
("animals.tar", 0, has_version),
("animals.tar", 1, False),
("animals.tar/cats", 1, False),
("animals.tar/cats/cat1", 0, has_version),
("animals.tar/cats/cat2", 0, has_version),
("animals.tar/description", 0, has_version),
("animals.tar/dogs", 1, False),
("animals.tar/dogs/dog1", 0, has_version),
("animals.tar/dogs/dog2", 0, has_version),
("animals.tar/dogs/dog3", 0, has_version),
("animals.tar/dogs/others", 1, False),
("animals.tar/dogs/others/dog4", 0, has_version),
("cats", 1, False),
("cats/cat1", 0, has_version),
("cats/cat2", 0, has_version),
("description", 0, has_version),
("dogs", 1, False),
("dogs/dog1", 0, has_version),
("dogs/dog2", 0, has_version),
("dogs/dog3", 0, has_version),
("dogs/others", 1, False),
("dogs/others/dog4", 0, has_version),
("nested", 1, False),
("nested/dir", 1, False),
("nested/dir/path", 1, False),
("nested/dir/path/abc.txt", 0, has_version),
]

assert to_compare == expected


@pytest.mark.parametrize(
"cloud_type,version_aware",
[("s3", True)],
indirect=True,
)
def test_convert_type(cloud_test_catalog):
ctc = cloud_test_catalog
catalog = ctc.catalog
warehouse = catalog.warehouse
now = datetime.now()

def run_convert_type(value, sql_type):
return warehouse.convert_type(
value,
sql_type,
warehouse.python_type(sql_type),
type(sql_type).__name__,
"test_column",
)

# convert int to float
for f in [Float, Float32, Float64]:
converted = run_convert_type(1, f())
assert converted == 1.0
assert isinstance(converted, float)

# types match, nothing to convert
assert run_convert_type(1, Int()) == 1
assert run_convert_type(1.5, Float()) == 1.5
assert run_convert_type(True, Boolean()) is True
assert run_convert_type("s", String()) == "s"
assert run_convert_type(now, DateTime()) == now
assert run_convert_type([1, 2], Array(Int)) == [1, 2]
assert run_convert_type([1.5, 2.5], Array(Float)) == [1.5, 2.5]
assert run_convert_type(["a", "b"], Array(String)) == ["a", "b"]
assert run_convert_type([[1, 2], [3, 4]], Array(Array(Int))) == [
[1, 2],
[3, 4],
]

# JSON Tests
assert run_convert_type('{"a": 1}', JSON()) == '{"a": 1}'
assert run_convert_type({"a": 1}, JSON()) == '{"a": 1}'
assert run_convert_type([{"a": 1}], JSON()) == '[{"a": 1}]'
with pytest.raises(ValueError):
run_convert_type(0.5, JSON())

# convert array to compatible type
converted = run_convert_type([1, 2], Array(Float))
assert converted == [1.0, 2.0]
assert all(isinstance(c, float) for c in converted)

# convert nested array to compatible type
converted = run_convert_type([[1, 2], [3, 4]], Array(Array(Float)))
assert converted == [[1.0, 2.0], [3.0, 4.0]]
assert all(isinstance(c, float) for c in converted[0])
assert all(isinstance(c, float) for c in converted[1])

# error, float to int
with pytest.raises(ValueError):
run_convert_type(1.5, Int())

# error, float to int in list
with pytest.raises(ValueError):
run_convert_type([1.5, 1], Array(Int))
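
The JSON assertions above pin down a pass-through/serialize/reject contract.
A minimal sketch of that contract (convert_json_sketch is an illustrative
stand-in, not the warehouse API):

import json

def convert_json_sketch(value):
    # strings pass through untouched; dicts and lists are serialized;
    # anything else (e.g. a bare float) is rejected, as the test expects
    if isinstance(value, str):
        return value
    if isinstance(value, (dict, list)):
        return json.dumps(value)
    raise ValueError(f"unsupported JSON value: {value!r}")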
89 changes: 89 additions & 0 deletions tests/func/test_datachain.py
@@ -1,4 +1,5 @@
import functools
import json
import math
import os
import pickle
@@ -555,6 +556,23 @@ def test_mutate_existing_column(test_session):
assert list(ds.order_by("ids").collect()) == [(2,), (3,), (4,)]


@pytest.mark.parametrize("processes", [False, 2, True])
@pytest.mark.xdist_group(name="tmpfile")
def test_parallel(processes, test_session_tmpfile):
prefix = "t & "
vals = ["a", "b", "c", "d", "e", "f", "g", "h", "i"]
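    # with parallel workers the mapped results arrive in nondeterministic
    # order, hence the order_by("res") before collecting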

res = list(
DataChain.from_values(key=vals, session=test_session_tmpfile)
.settings(parallel=processes)
.map(res=lambda key: prefix + key)
.order_by("res")
.collect("res")
)

assert res == [prefix + v for v in vals]


@pytest.mark.parametrize(
"cloud_type,version_aware",
[("s3", True)],
@@ -613,6 +631,36 @@ def name_len(name):
assert count == 7


@pytest.mark.xdist_group(name="tmpfile")
def test_udf_parallel_boostrap(test_session_tmpfile):
vals = ["a", "b", "c", "d", "e", "f"]

class MyMapper(Mapper):
DEFAULT_VALUE = 84
BOOTSTRAP_VALUE = 1452
TEARDOWN_VALUE = 98763

def __init__(self):
super().__init__()
self.value = MyMapper.DEFAULT_VALUE
self._had_teardown = False

def process(self, *args) -> int:
return self.value

def setup(self):
self.value = MyMapper.BOOTSTRAP_VALUE

def teardown(self):
self.value = MyMapper.TEARDOWN_VALUE

chain = DataChain.from_values(key=vals, session=test_session_tmpfile)

res = list(chain.settings(parallel=4).map(res=MyMapper()).collect("res"))
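    # setup() replaces DEFAULT_VALUE in each worker before any row is
    # processed, so every collected result equals BOOTSTRAP_VALUE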

assert res == [MyMapper.BOOTSTRAP_VALUE] * len(vals)


@pytest.mark.parametrize(
"cloud_type,version_aware",
[("s3", True)],
@@ -1653,6 +1701,47 @@ def test_to_from_parquet_partitioned_remote(cloud_test_catalog_upload):
assert df_equal(df1, df)


# These deprecation warnings occur in the datamodel-code-generator package.
@pytest.mark.filterwarnings("ignore::pydantic.warnings.PydanticDeprecatedSince20")
def test_to_from_json(tmp_dir, test_session):
df = pd.DataFrame(DF_DATA)
dc_to = DataChain.from_pandas(df, session=test_session)
path = tmp_dir / "test.json"
dc_to.order_by("first_name", "age").to_json(path)

with open(path) as f:
values = json.load(f)
assert values == [
{"first_name": n, "age": a, "city": c}
for n, a, c in zip(DF_DATA["first_name"], DF_DATA["age"], DF_DATA["city"])
]

dc_from = DataChain.from_json(path.as_uri(), session=test_session)
df1 = dc_from.select("json.first_name", "json.age", "json.city").to_pandas()
df1 = df1["json"]
assert df_equal(df1, df)


# These deprecation warnings occur in the datamodel-code-generator package.
@pytest.mark.filterwarnings("ignore::pydantic.warnings.PydanticDeprecatedSince20")
def test_from_json_jmespath(tmp_dir, test_session):
df = pd.DataFrame(DF_DATA)
values = [
{"first_name": n, "age": a, "city": c}
for n, a, c in zip(DF_DATA["first_name"], DF_DATA["age"], DF_DATA["city"])
]
path = tmp_dir / "test.json"
with open(path, "w") as f:
json.dump({"author": "Test User", "version": 5, "values": values}, f)

dc_from = DataChain.from_json(
path.as_uri(), jmespath="values", session=test_session
)
df1 = dc_from.select("values.first_name", "values.age", "values.city").to_pandas()
df1 = df1["values"]
assert df_equal(df1, df)
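
For context, the jmespath="values" argument selects the nested list before
records are modeled. A minimal illustration with the standalone jmespath
package (assumed here to mirror DataChain's path handling):

import jmespath

doc = {"author": "Test User", "version": 5, "values": [{"first_name": "Ann"}]}
# "values" selects the nested list that from_json will model
assert jmespath.search("values", doc) == [{"first_name": "Ann"}]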


# These deprecation warnings occur in the datamodel-code-generator package.
@pytest.mark.filterwarnings("ignore::pydantic.warnings.PydanticDeprecatedSince20")
def test_to_from_json_remote(cloud_test_catalog_upload):