-
Notifications
You must be signed in to change notification settings - Fork 104
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
move functional tests out of unit test suite #832
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
from datetime import datetime | ||
from typing import Any | ||
|
||
import pytest | ||
|
||
from datachain.sql.types import ( | ||
JSON, | ||
Array, | ||
Boolean, | ||
DateTime, | ||
Float, | ||
Float32, | ||
Float64, | ||
Int, | ||
String, | ||
) | ||
from tests.utils import ( | ||
DEFAULT_TREE, | ||
TARRED_TREE, | ||
create_tar_dataset_with_legacy_columns, | ||
) | ||
|
||
COMPLEX_TREE: dict[str, Any] = { | ||
**TARRED_TREE, | ||
**DEFAULT_TREE, | ||
"nested": {"dir": {"path": {"abc.txt": "abc"}}}, | ||
} | ||
|
||
|
||
@pytest.mark.parametrize("tree", [COMPLEX_TREE], indirect=True)
def test_dir_expansion(cloud_test_catalog, version_aware, cloud_type):
    """Dir expansion yields one row per file and per (virtual) directory.

    Expanded tar archives contribute both a file entry for the archive itself
    and directory/file entries for their contents.
    """
    # Versioned object stores report a non-empty version for file entries.
    # NOTE(review): gs is assumed to always be version-aware here — confirm.
    has_version = version_aware or cloud_type == "gs"

    ctc = cloud_test_catalog
    session = ctc.session
    catalog = ctc.catalog
    src_uri = ctc.src_uri
    if cloud_type == "file":
        # we don't want to index things in parent directory
        src_uri += "/"

    # FIX: previously passed ctc.src_uri, silently discarding the trailing
    # slash added above and leaving the local `src_uri` unused.
    dc = create_tar_dataset_with_legacy_columns(session, src_uri, "dc")
    dataset = catalog.get_dataset(dc.name)
    with catalog.warehouse.clone() as warehouse:
        dr = warehouse.dataset_rows(dataset, object_name="file")
        de = dr.dir_expansion()
        q = de.query(dr.get_table())

        columns = (
            "id",
            "is_dir",
            "source",
            "path",
            "version",
            "location",
        )

        result = [dict(zip(columns, r)) for r in warehouse.db.execute(q)]
        to_compare = [(r["path"], r["is_dir"], r["version"] != "") for r in result]

    assert all(r["source"] == ctc.src_uri for r in result)

    # Note, we have both a file and a directory entry for expanded tar files
    expected = [
        ("animals.tar", 0, has_version),
        ("animals.tar", 1, False),
        ("animals.tar/cats", 1, False),
        ("animals.tar/cats/cat1", 0, has_version),
        ("animals.tar/cats/cat2", 0, has_version),
        ("animals.tar/description", 0, has_version),
        ("animals.tar/dogs", 1, False),
        ("animals.tar/dogs/dog1", 0, has_version),
        ("animals.tar/dogs/dog2", 0, has_version),
        ("animals.tar/dogs/dog3", 0, has_version),
        ("animals.tar/dogs/others", 1, False),
        ("animals.tar/dogs/others/dog4", 0, has_version),
        ("cats", 1, False),
        ("cats/cat1", 0, has_version),
        ("cats/cat2", 0, has_version),
        ("description", 0, has_version),
        ("dogs", 1, False),
        ("dogs/dog1", 0, has_version),
        ("dogs/dog2", 0, has_version),
        ("dogs/dog3", 0, has_version),
        ("dogs/others", 1, False),
        ("dogs/others/dog4", 0, has_version),
        ("nested", 1, False),
        ("nested/dir", 1, False),
        ("nested/dir/path", 1, False),
        ("nested/dir/path/abc.txt", 0, has_version),
    ]

    assert to_compare == expected
|
||
|
||
@pytest.mark.parametrize(
    "cloud_type,version_aware",
    [("s3", True)],
    indirect=True,
)
def test_convert_type(cloud_test_catalog):
    """Exercise ``warehouse.convert_type`` over scalar, array and JSON types.

    Covers pass-through of matching types, int-to-float promotion (including
    element-wise promotion inside nested arrays), JSON serialization of
    dict/list values, and rejection of lossy conversions.
    """
    warehouse = cloud_test_catalog.catalog.warehouse
    now = datetime.now()

    def convert(value, sql_type):
        # Same call shape the warehouse uses when coercing a column value.
        return warehouse.convert_type(
            value,
            sql_type,
            warehouse.python_type(sql_type),
            type(sql_type).__name__,
            "test_column",
        )

    # int promotes to every float flavour
    for float_type in (Float, Float32, Float64):
        promoted = convert(1, float_type())
        assert promoted == 1.0
        assert isinstance(promoted, float)

    # identical types pass through unchanged
    assert convert(1, Int()) == 1
    assert convert(1.5, Float()) == 1.5
    assert convert(True, Boolean()) is True
    assert convert("s", String()) == "s"
    assert convert(now, DateTime()) == now
    assert convert([1, 2], Array(Int)) == [1, 2]
    assert convert([1.5, 2.5], Array(Float)) == [1.5, 2.5]
    assert convert(["a", "b"], Array(String)) == ["a", "b"]
    assert convert([[1, 2], [3, 4]], Array(Array(Int))) == [
        [1, 2],
        [3, 4],
    ]

    # JSON: strings pass through, dicts/lists are serialized, scalars rejected
    assert convert('{"a": 1}', JSON()) == '{"a": 1}'
    assert convert({"a": 1}, JSON()) == '{"a": 1}'
    assert convert([{"a": 1}], JSON()) == '[{"a": 1}]'
    with pytest.raises(ValueError):
        convert(0.5, JSON())

    # int array promotes element-wise to float
    promoted = convert([1, 2], Array(Float))
    assert promoted == [1.0, 2.0]
    assert all(isinstance(item, float) for item in promoted)

    # promotion also applies inside nested arrays
    nested = convert([[1, 2], [3, 4]], Array(Array(Float)))
    assert nested == [[1.0, 2.0], [3.0, 4.0]]
    assert all(isinstance(item, float) for row in nested for item in row)

    # lossy float -> int conversion is rejected, scalar or in a list
    with pytest.raises(ValueError):
        convert(1.5, Int())
    with pytest.raises(ValueError):
        convert([1.5, 1], Array(Int))
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
import functools | ||
import json | ||
import math | ||
import os | ||
import pickle | ||
|
@@ -555,6 +556,23 @@ def test_mutate_existing_column(test_session): | |
assert list(ds.order_by("ids").collect()) == [(2,), (3,), (4,)] | ||
|
||
|
||
@pytest.mark.parametrize("processes", [False, 2, True])
@pytest.mark.xdist_group(name="tmpfile")
def test_parallel(processes, test_session_tmpfile):
    """A parallel map must yield the same results as a serial one.

    ``processes`` covers disabled (False), a fixed worker count (2), and
    auto-sized (True) parallelism.
    """
    prefix = "t & "
    vals = ["a", "b", "c", "d", "e", "f", "g", "h", "i"]

    chain = (
        DataChain.from_values(key=vals, session=test_session_tmpfile)
        .settings(parallel=processes)
        .map(res=lambda key: prefix + key)
        .order_by("res")
    )

    assert list(chain.collect("res")) == [prefix + v for v in vals]
|
||
|
||
@pytest.mark.parametrize( | ||
"cloud_type,version_aware", | ||
[("s3", True)], | ||
|
@@ -613,6 +631,36 @@ def name_len(name): | |
assert count == 7 | ||
|
||
|
||
@pytest.mark.xdist_group(name="tmpfile")
def test_udf_parallel_boostrap(test_session_tmpfile):
    """Parallel UDF workers must call ``setup()`` before processing rows.

    Each worker's mapper should observe the value assigned in ``setup`` (not
    the constructor default, and not the ``teardown`` value) for every row.

    NOTE(review): name keeps the historical "boostrap" typo so test selection
    (e.g. ``-k``) is unaffected.
    """
    vals = ["a", "b", "c", "d", "e", "f"]

    class MyMapper(Mapper):
        DEFAULT_VALUE = 84
        BOOTSTRAP_VALUE = 1452
        TEARDOWN_VALUE = 98763

        def __init__(self):
            super().__init__()
            self.value = MyMapper.DEFAULT_VALUE
            self._had_teardown = False

        def process(self, *args) -> int:
            # Echo whatever lifecycle stage last wrote to self.value.
            return self.value

        def setup(self):
            self.value = MyMapper.BOOTSTRAP_VALUE

        def teardown(self):
            self.value = MyMapper.TEARDOWN_VALUE

    chain = DataChain.from_values(key=vals, session=test_session_tmpfile)
    results = list(chain.settings(parallel=4).map(res=MyMapper()).collect("res"))

    assert results == [MyMapper.BOOTSTRAP_VALUE] * len(vals)
|
||
|
||
@pytest.mark.parametrize( | ||
"cloud_type,version_aware", | ||
[("s3", True)], | ||
|
@@ -1653,6 +1701,47 @@ def test_to_from_parquet_partitioned_remote(cloud_test_catalog_upload): | |
assert df_equal(df1, df) | ||
|
||
|
||
# These deprecation warnings occur in the datamodel-code-generator package.
@pytest.mark.filterwarnings("ignore::pydantic.warnings.PydanticDeprecatedSince20")
def test_to_from_json(tmp_dir, test_session):
    """Round-trip a chain through ``to_json`` and ``from_json``."""
    df = pd.DataFrame(DF_DATA)
    path = tmp_dir / "test.json"
    DataChain.from_pandas(df, session=test_session).order_by(
        "first_name", "age"
    ).to_json(path)

    # The written file is a flat JSON array of row objects.
    with open(path) as f:
        rows = json.load(f)
    expected_rows = [
        {"first_name": n, "age": a, "city": c}
        for n, a, c in zip(DF_DATA["first_name"], DF_DATA["age"], DF_DATA["city"])
    ]
    assert rows == expected_rows

    # Reading it back reproduces the original frame under the "json" signal.
    dc_from = DataChain.from_json(path.as_uri(), session=test_session)
    round_tripped = dc_from.select(
        "json.first_name", "json.age", "json.city"
    ).to_pandas()["json"]
    assert df_equal(round_tripped, df)
|
||
|
||
# These deprecation warnings occur in the datamodel-code-generator package.
@pytest.mark.filterwarnings("ignore::pydantic.warnings.PydanticDeprecatedSince20")
def test_from_json_jmespath(tmp_dir, test_session):
    """``from_json`` with a jmespath expression reads a nested JSON document."""
    df = pd.DataFrame(DF_DATA)
    rows = [
        {"first_name": n, "age": a, "city": c}
        for n, a, c in zip(DF_DATA["first_name"], DF_DATA["age"], DF_DATA["city"])
    ]
    # Wrap the rows inside extra metadata so jmespath has something to select.
    path = tmp_dir / "test.json"
    with open(path, "w") as f:
        json.dump({"author": "Test User", "version": 5, "values": rows}, f)

    dc_from = DataChain.from_json(
        path.as_uri(), jmespath="values", session=test_session
    )
    extracted = dc_from.select(
        "values.first_name", "values.age", "values.city"
    ).to_pandas()["values"]
    assert df_equal(extracted, df)
|
||
|
||
# These deprecation warnings occur in the datamodel-code-generator package. | ||
@pytest.mark.filterwarnings("ignore::pydantic.warnings.PydanticDeprecatedSince20") | ||
def test_to_from_json_remote(cloud_test_catalog_upload): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This test looks like a unit test. Unrelated, but this probably does not need the
s3 fixture (I am not asking you to fix this, just a drive-by comment).