Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

move functional tests out of unit test suite #832

Merged
merged 3 commits into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 96 additions & 0 deletions tests/func/test_client.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
import asyncio
import os
import sys
from pathlib import Path

import pytest
from fsspec.asyn import sync
from hypothesis import HealthCheck, assume, given, settings
from hypothesis import strategies as st
from tqdm import tqdm

from datachain.asyn import get_loop
from datachain.client import Client
from tests.data import ENTRIES
from tests.utils import uppercase_scheme

# Hypothesis strategy for non-empty text excluding Unicode control ("Cc")
# and surrogate ("Cs") characters; used as a random relative-path component.
_non_null_text = st.text(
    alphabet=st.characters(blacklist_categories=["Cc", "Cs"]), min_size=1
)


@pytest.fixture
Expand Down Expand Up @@ -91,3 +101,89 @@ def test_fetch_dir_does_not_return_self(client, cloud_type):
)

assert "directory" not in subdirs


@settings(suppress_health_check=[HealthCheck.function_scoped_fixture], deadline=None)
@given(rel_path=_non_null_text)
def test_parse_url(cloud_test_catalog, rel_path, cloud_type):
    """Client.parse_url splits a URL into a storage URI and a relative part."""
    if cloud_type == "file":
        # An absolute path would not be relative to the bucket root.
        assume(not rel_path.startswith("/"))
    full_url = f"{cloud_test_catalog.src_uri}/{rel_path}"
    uri, rel_part = Client.parse_url(full_url)
    if cloud_type == "file":
        # Local URLs are split on the last path separator only.
        head, _, tail = full_url.rpartition("/")
        assert uri == head
        assert rel_part == tail
    else:
        assert uri == cloud_test_catalog.src_uri
        assert rel_part == rel_path


@settings(suppress_health_check=[HealthCheck.function_scoped_fixture], deadline=None)
@given(rel_path=_non_null_text)
def test_get_client(cloud_test_catalog, rel_path, cloud_type):
    """get_client returns a client with a URI for any URL under the bucket."""
    ctc = cloud_test_catalog
    full_url = f"{ctc.src_uri}/{rel_path}"
    client = Client.get_client(full_url, ctc.catalog.cache)
    assert client
    assert client.uri


@settings(suppress_health_check=[HealthCheck.function_scoped_fixture], deadline=None)
@given(rel_path=_non_null_text)
def test_parse_url_uppercase_scheme(cloud_test_catalog, rel_path, cloud_type):
    """parse_url treats the URL scheme case-insensitively."""
    if cloud_type == "file":
        assume(not rel_path.startswith("/"))
    bucket_uri = cloud_test_catalog.src_uri
    upper_url = f"{uppercase_scheme(bucket_uri)}/{rel_path}"
    uri, rel_part = Client.parse_url(upper_url)
    if cloud_type == "file":
        # Compare against the lowercase-scheme form of the same URL.
        expected = f"{bucket_uri}/{rel_path}"
        head, _, tail = expected.rpartition("/")
        assert uri == head
        assert rel_part == tail
    else:
        assert uri == bucket_uri
        assert rel_part == rel_path


@pytest.mark.parametrize("cloud_type", ["file"], indirect=True)
def test_parse_file_absolute_path_without_protocol(cloud_test_catalog):
    """An absolute local path without a scheme parses into cwd URI + basename."""
    cwd = Path.cwd()
    uri, rel_part = Client.parse_url(str(cwd / "animals"))
    assert uri == cwd.as_uri()
    assert rel_part == "animals"


@pytest.mark.parametrize("cloud_type", ["file"], indirect=True)
def test_parse_file_relative_path_multiple_dirs_back(cloud_test_catalog):
    """A '../../...' relative path resolves against the cwd's grandparent."""
    relative_url = "../../animals".replace("/", os.sep)
    uri, rel_part = Client.parse_url(relative_url)
    assert uri == Path.cwd().parents[1].as_uri()
    assert rel_part == "animals"


@pytest.mark.parametrize("cloud_type", ["file"], indirect=True)
@pytest.mark.parametrize("url", ["./animals".replace("/", os.sep), "animals"])
def test_parse_file_relative_path_working_dir(cloud_test_catalog, url):
    """Both './animals' and bare 'animals' resolve against the working dir."""
    parsed_uri, parsed_rel = Client.parse_url(url)
    assert parsed_uri == Path.cwd().as_uri()
    assert parsed_rel == "animals"


@pytest.mark.parametrize("cloud_type", ["file"], indirect=True)
@pytest.mark.skipif(
    sys.platform == "win32",
    reason="home dir shortcut (~) is not available on windows",
)
def test_parse_file_relative_path_home_dir(cloud_test_catalog):
    """A '~/...' path expands to the user's home directory."""
    uri, rel_part = Client.parse_url("~/animals")
    # Path.home() is a classmethod; no need for a throwaway Path() instance.
    assert uri == Path.home().as_uri()
    assert rel_part == "animals"


@pytest.mark.parametrize("cloud_type", ["s3", "azure", "gs"], indirect=True)
def test_parse_cloud_path_ends_with_slash(cloud_test_catalog):
    """A trailing slash stays in the relative part, not in the storage URI."""
    url = f"{cloud_test_catalog.src_uri}/animals/"
    parsed_uri, rel_part = Client.parse_url(url)
    assert parsed_uri == cloud_test_catalog.src_uri
    assert rel_part == "animals/"
160 changes: 160 additions & 0 deletions tests/func/test_data_storage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
from datetime import datetime
from typing import Any

import pytest

from datachain.sql.types import (
JSON,
Array,
Boolean,
DateTime,
Float,
Float32,
Float64,
Int,
String,
)
from tests.utils import (
DEFAULT_TREE,
TARRED_TREE,
create_tar_dataset_with_legacy_columns,
)

# Source tree combining a tarred tree, the default flat tree, and a deeply
# nested directory, so dir expansion is exercised over mixed entry kinds.
COMPLEX_TREE: dict[str, Any] = {
    **TARRED_TREE,
    **DEFAULT_TREE,
    "nested": {"dir": {"path": {"abc.txt": "abc"}}},
}


@pytest.mark.parametrize("tree", [COMPLEX_TREE], indirect=True)
def test_dir_expansion(cloud_test_catalog, version_aware, cloud_type):
    """Dir expansion emits one row per file plus synthetic directory rows.

    Expanded tar files appear twice: once as a plain file entry and once as a
    directory subtree holding the tar members.
    """
    # NOTE(review): versions appear for version-aware buckets; "gs" is listed
    # here as always reporting versions — confirm against the gs fixture.
    has_version = version_aware or cloud_type == "gs"

    ctc = cloud_test_catalog
    session = ctc.session
    catalog = ctc.catalog
    src_uri = ctc.src_uri
    if cloud_type == "file":
        # we don't want to index things in parent directory
        src_uri += "/"

    # Fix: pass the adjusted src_uri (trailing slash for local storage) to the
    # indexing helper; previously the raw ctc.src_uri was passed, making the
    # adjustment above dead code.
    dc = create_tar_dataset_with_legacy_columns(session, src_uri, "dc")
    dataset = catalog.get_dataset(dc.name)
    with catalog.warehouse.clone() as warehouse:
        dr = warehouse.dataset_rows(dataset, object_name="file")
        de = dr.dir_expansion()
        q = de.query(dr.get_table())

        columns = (
            "id",
            "is_dir",
            "source",
            "path",
            "version",
            "location",
        )

        result = [dict(zip(columns, r)) for r in warehouse.db.execute(q)]
        to_compare = [(r["path"], r["is_dir"], r["version"] != "") for r in result]

    assert all(r["source"] == ctc.src_uri for r in result)

    # Note, we have both a file and a directory entry for expanded tar files
    expected = [
        ("animals.tar", 0, has_version),
        ("animals.tar", 1, False),
        ("animals.tar/cats", 1, False),
        ("animals.tar/cats/cat1", 0, has_version),
        ("animals.tar/cats/cat2", 0, has_version),
        ("animals.tar/description", 0, has_version),
        ("animals.tar/dogs", 1, False),
        ("animals.tar/dogs/dog1", 0, has_version),
        ("animals.tar/dogs/dog2", 0, has_version),
        ("animals.tar/dogs/dog3", 0, has_version),
        ("animals.tar/dogs/others", 1, False),
        ("animals.tar/dogs/others/dog4", 0, has_version),
        ("cats", 1, False),
        ("cats/cat1", 0, has_version),
        ("cats/cat2", 0, has_version),
        ("description", 0, has_version),
        ("dogs", 1, False),
        ("dogs/dog1", 0, has_version),
        ("dogs/dog2", 0, has_version),
        ("dogs/dog3", 0, has_version),
        ("dogs/others", 1, False),
        ("dogs/others/dog4", 0, has_version),
        ("nested", 1, False),
        ("nested/dir", 1, False),
        ("nested/dir/path", 1, False),
        ("nested/dir/path/abc.txt", 0, has_version),
    ]

    assert to_compare == expected


@pytest.mark.parametrize(
    "cloud_type,version_aware",
    [("s3", True)],
    indirect=True,
)
def test_convert_type(cloud_test_catalog):
    """warehouse.convert_type coerces values into compatible SQL column types
    and raises ValueError for incompatible ones.

    (Reconstructed: GitHub review-comment UI text had been spliced into the
    function body and is removed here.)
    """
    ctc = cloud_test_catalog
    catalog = ctc.catalog
    warehouse = catalog.warehouse
    now = datetime.now()

    def run_convert_type(value, sql_type):
        # Convert a single value the way the warehouse would for a column.
        return warehouse.convert_type(
            value,
            sql_type,
            warehouse.python_type(sql_type),
            type(sql_type).__name__,
            "test_column",
        )

    # convert int to float
    for f in [Float, Float32, Float64]:
        converted = run_convert_type(1, f())
        assert converted == 1.0
        assert isinstance(converted, float)

    # types match, nothing to convert
    assert run_convert_type(1, Int()) == 1
    assert run_convert_type(1.5, Float()) == 1.5
    assert run_convert_type(True, Boolean()) is True
    assert run_convert_type("s", String()) == "s"
    assert run_convert_type(now, DateTime()) == now
    assert run_convert_type([1, 2], Array(Int)) == [1, 2]
    assert run_convert_type([1.5, 2.5], Array(Float)) == [1.5, 2.5]
    assert run_convert_type(["a", "b"], Array(String)) == ["a", "b"]
    assert run_convert_type([[1, 2], [3, 4]], Array(Array(Int))) == [
        [1, 2],
        [3, 4],
    ]

    # JSON: strings pass through, dicts/lists are serialized, other scalars fail
    assert run_convert_type('{"a": 1}', JSON()) == '{"a": 1}'
    assert run_convert_type({"a": 1}, JSON()) == '{"a": 1}'
    assert run_convert_type([{"a": 1}], JSON()) == '[{"a": 1}]'
    with pytest.raises(ValueError):
        run_convert_type(0.5, JSON())

    # convert array to compatible type
    converted = run_convert_type([1, 2], Array(Float))
    assert converted == [1.0, 2.0]
    assert all(isinstance(c, float) for c in converted)

    # convert nested array to compatible type
    converted = run_convert_type([[1, 2], [3, 4]], Array(Array(Float)))
    assert converted == [[1.0, 2.0], [3.0, 4.0]]
    assert all(isinstance(c, float) for c in converted[0])
    assert all(isinstance(c, float) for c in converted[1])

    # error, float to int
    with pytest.raises(ValueError):
        run_convert_type(1.5, Int())

    # error, float to int in list
    with pytest.raises(ValueError):
        run_convert_type([1.5, 1], Array(Int))
89 changes: 89 additions & 0 deletions tests/func/test_datachain.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import functools
import json
import math
import os
import pickle
Expand Down Expand Up @@ -555,6 +556,23 @@ def test_mutate_existing_column(test_session):
assert list(ds.order_by("ids").collect()) == [(2,), (3,), (4,)]


@pytest.mark.parametrize("processes", [False, 2, True])
@pytest.mark.xdist_group(name="tmpfile")
def test_parallel(processes, test_session_tmpfile):
    """map() applies the UDF correctly under each parallel setting.

    (Reconstructed: GitHub review-comment UI text had been spliced between the
    def line and the body and is removed here.)
    """
    prefix = "t & "
    vals = ["a", "b", "c", "d", "e", "f", "g", "h", "i"]

    res = list(
        DataChain.from_values(key=vals, session=test_session_tmpfile)
        .settings(parallel=processes)
        .map(res=lambda key: prefix + key)
        .order_by("res")
        .collect("res")
    )

    assert res == [prefix + v for v in vals]


@pytest.mark.parametrize(
"cloud_type,version_aware",
[("s3", True)],
Expand Down Expand Up @@ -613,6 +631,36 @@ def name_len(name):
assert count == 7


@pytest.mark.xdist_group(name="tmpfile")
def test_udf_parallel_boostrap(test_session_tmpfile):
    """Each parallel worker must call setup() before processing, so every row
    sees BOOTSTRAP_VALUE instead of the constructor default.

    (Reconstructed: GitHub review-comment UI text had been spliced between the
    def line and the body and is removed here. The name keeps the original
    "boostrap" spelling since it is the test's public identifier.)
    """
    vals = ["a", "b", "c", "d", "e", "f"]

    class MyMapper(Mapper):
        DEFAULT_VALUE = 84
        BOOTSTRAP_VALUE = 1452
        TEARDOWN_VALUE = 98763

        def __init__(self):
            super().__init__()
            self.value = MyMapper.DEFAULT_VALUE
            self._had_teardown = False

        def process(self, *args) -> int:
            return self.value

        def setup(self):
            self.value = MyMapper.BOOTSTRAP_VALUE

        def teardown(self):
            self.value = MyMapper.TEARDOWN_VALUE

    chain = DataChain.from_values(key=vals, session=test_session_tmpfile)

    res = list(chain.settings(parallel=4).map(res=MyMapper()).collect("res"))

    assert res == [MyMapper.BOOTSTRAP_VALUE] * len(vals)


@pytest.mark.parametrize(
"cloud_type,version_aware",
[("s3", True)],
Expand Down Expand Up @@ -1653,6 +1701,47 @@ def test_to_from_parquet_partitioned_remote(cloud_test_catalog_upload):
assert df_equal(df1, df)


# These deprecation warnings occur in the datamodel-code-generator package.
@pytest.mark.filterwarnings("ignore::pydantic.warnings.PydanticDeprecatedSince20")
def test_to_from_json(tmp_dir, test_session):
    """Round-trip a chain through a JSON file and verify both directions."""
    df = pd.DataFrame(DF_DATA)
    path = tmp_dir / "test.json"
    DataChain.from_pandas(df, session=test_session).order_by(
        "first_name", "age"
    ).to_json(path)

    expected_records = [
        {"first_name": n, "age": a, "city": c}
        for n, a, c in zip(DF_DATA["first_name"], DF_DATA["age"], DF_DATA["city"])
    ]
    with open(path) as f:
        assert json.load(f) == expected_records

    restored = DataChain.from_json(path.as_uri(), session=test_session)
    df_roundtrip = restored.select(
        "json.first_name", "json.age", "json.city"
    ).to_pandas()
    assert df_equal(df_roundtrip["json"], df)


# These deprecation warnings occur in the datamodel-code-generator package.
@pytest.mark.filterwarnings("ignore::pydantic.warnings.PydanticDeprecatedSince20")
def test_from_json_jmespath(tmp_dir, test_session):
    """from_json with a jmespath expression reads only the selected subtree."""
    df = pd.DataFrame(DF_DATA)
    records = [
        {"first_name": n, "age": a, "city": c}
        for n, a, c in zip(DF_DATA["first_name"], DF_DATA["age"], DF_DATA["city"])
    ]
    path = tmp_dir / "test.json"
    # Wrap the records so "values" must be selected via jmespath.
    document = {"author": "Test User", "version": 5, "values": records}
    with open(path, "w") as f:
        json.dump(document, f)

    restored = DataChain.from_json(
        path.as_uri(), jmespath="values", session=test_session
    )
    df_loaded = restored.select(
        "values.first_name", "values.age", "values.city"
    ).to_pandas()
    assert df_equal(df_loaded["values"], df)


# These deprecation warnings occur in the datamodel-code-generator package.
@pytest.mark.filterwarnings("ignore::pydantic.warnings.PydanticDeprecatedSince20")
def test_to_from_json_remote(cloud_test_catalog_upload):
Expand Down
Loading
Loading