Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Use bigger fixture tree for distributed tests #4

Draft
wants to merge 11 commits into
base: main
Choose a base branch
from
91 changes: 76 additions & 15 deletions tests/func/test_dataset_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@
import os
import pickle
import random
import signal
import subprocess # nosec B404
import uuid
from datetime import datetime, timedelta, timezone
from json import dumps
from textwrap import dedent
from time import sleep
from unittest.mock import ANY, patch

import numpy as np
Expand Down Expand Up @@ -51,6 +54,7 @@
from tests.data import ENTRIES
from tests.utils import (
DEFAULT_TREE,
LARGE_TREE,
NUM_TREE,
SIMPLE_DS_QUERY_RECORDS,
TARRED_TREE,
Expand All @@ -62,6 +66,52 @@
text_embedding,
)

WORKER_COUNT = 1
WORKER_SHUTDOWN_WAIT_SEC = 30
# Maximum number of 1-second polls to wait for the Celery worker(s) to come up.
WORKER_STARTUP_WAIT_SEC = 10


@pytest.fixture()
def run_datachain_worker():
    """Spawn a Celery worker subprocess for distributed-UDF tests.

    Skips the test unless the ``DATACHAIN_DISTRIBUTED`` environment variable
    is set. Yields the list of worker ``subprocess.Popen`` handles once the
    worker(s) report as active, and shuts them down (SIGTERM, then SIGKILL
    after ``WORKER_SHUTDOWN_WAIT_SEC``) on teardown.

    Raises:
        RuntimeError: if no worker becomes active within
            ``WORKER_STARTUP_WAIT_SEC`` seconds.
    """
    if not os.environ.get("DATACHAIN_DISTRIBUTED"):
        pytest.skip("Distributed tests are disabled")
    workers = []
    worker_cmd = [
        "celery",
        "-A",
        "datachain_server.distributed",
        "worker",
        "--loglevel=INFO",
        "-P",
        "solo",
        "-Q",
        "datachain-worker",
        "-n",
        "datachain-worker-tests",
    ]
    workers.append(subprocess.Popen(worker_cmd, shell=False))  # noqa: S603
    try:
        from datachain_server.distributed import app

        inspect = app.control.inspect()
        # Poll once per second until the worker(s) report as active.
        for _ in range(WORKER_STARTUP_WAIT_SEC):
            if inspect.active():
                break
            sleep(1)
        else:
            # BUG FIX: the previous code raised whenever the attempt counter
            # hit its limit, even if the worker had become active on the
            # final poll. Re-check before declaring failure.
            if not inspect.active():
                raise RuntimeError("Celery worker(s) did not start in time")

        yield workers
    finally:
        # Popen.terminate()/kill() are the portable equivalents of
        # os.kill(pid, SIGTERM/SIGKILL) and avoid signal-module juggling.
        for worker in workers:
            worker.terminate()
        for worker in workers:
            try:
                worker.wait(timeout=WORKER_SHUTDOWN_WAIT_SEC)
            except subprocess.TimeoutExpired:
                worker.kill()


def from_result_row(col_names, row):
    """Build a mapping of column name -> value for a single result row."""
    return {name: value for name, value in zip(col_names, row)}
Expand Down Expand Up @@ -1075,8 +1125,8 @@ def name_len_interrupt(_name):


@pytest.mark.parametrize(
"cloud_type,version_aware",
[("s3", True)],
"cloud_type,version_aware,tree",
[("s3", True, LARGE_TREE)],
indirect=True,
)
@pytest.mark.parametrize("batch", [False, True])
Expand All @@ -1086,7 +1136,14 @@ def name_len_interrupt(_name):
reason="Set the DATACHAIN_DISTRIBUTED environment variable "
"to test distributed UDFs",
)
def test_udf_distributed(cloud_test_catalog_tmpfile, batch, workers, datachain_job_id):
def test_udf_distributed(
cloud_test_catalog_tmpfile,
batch,
workers,
tree,
datachain_job_id,
run_datachain_worker,
):
catalog = cloud_test_catalog_tmpfile.catalog
sources = [cloud_test_catalog_tmpfile.src_uri]
globs = [s.rstrip("/") + "/*" for s in sources]
Expand All @@ -1111,14 +1168,14 @@ def name_len_batch(names):

q = (
DatasetQuery(name="animals", version=1, catalog=catalog)
.filter(C.size < 13)
.filter(C.parent.glob("cats*") | (C.size < 4))
.filter(C.size < 90)
.filter(C.parent.glob("cats*") | (C.size > 30))
.add_signals(udf_func, parallel=2, workers=workers)
.select(C.name, C.name_len, C.blank)
)
result = q.db_results()

assert len(result) == 3
assert len(result) == 148
string_default = String.default_value(catalog.warehouse.db.dialect)
for r in result:
# Check that the UDF ran successfully
Expand All @@ -1127,8 +1184,8 @@ def name_len_batch(names):


@pytest.mark.parametrize(
"cloud_type,version_aware",
[("s3", True)],
"cloud_type,version_aware,tree",
[("s3", True, LARGE_TREE)],
indirect=True,
)
@pytest.mark.parametrize("workers", (1, 2))
Expand All @@ -1138,7 +1195,7 @@ def name_len_batch(names):
"to test distributed UDFs",
)
def test_udf_distributed_exec_error(
cloud_test_catalog_tmpfile, workers, datachain_job_id
cloud_test_catalog_tmpfile, workers, datachain_job_id, tree, run_datachain_worker
):
catalog = cloud_test_catalog_tmpfile.catalog
sources = [cloud_test_catalog_tmpfile.src_uri]
Expand All @@ -1162,16 +1219,18 @@ def name_len_error(_name):


@pytest.mark.parametrize(
"cloud_type,version_aware",
[("s3", True)],
"cloud_type,version_aware,tree",
[("s3", True, LARGE_TREE)],
indirect=True,
)
@pytest.mark.skipif(
"not os.environ.get('DATACHAIN_DISTRIBUTED')",
reason="Set the DATACHAIN_DISTRIBUTED environment variable "
"to test distributed UDFs",
)
def test_udf_distributed_interrupt(cloud_test_catalog_tmpfile, capfd, datachain_job_id):
def test_udf_distributed_interrupt(
cloud_test_catalog_tmpfile, capfd, datachain_job_id, tree, run_datachain_worker
):
catalog = cloud_test_catalog_tmpfile.catalog
sources = [cloud_test_catalog_tmpfile.src_uri]
globs = [s.rstrip("/") + "/*" for s in sources]
Expand All @@ -1197,16 +1256,18 @@ def name_len_interrupt(_name):


@pytest.mark.parametrize(
"cloud_type,version_aware",
[("s3", True)],
"cloud_type,version_aware, tree",
[("s3", True, LARGE_TREE)],
indirect=True,
)
@pytest.mark.skipif(
"not os.environ.get('DATACHAIN_DISTRIBUTED')",
reason="Set the DATACHAIN_DISTRIBUTED environment variable "
"to test distributed UDFs",
)
def test_udf_distributed_cancel(cloud_test_catalog_tmpfile, capfd, datachain_job_id):
def test_udf_distributed_cancel(
cloud_test_catalog_tmpfile, capfd, datachain_job_id, tree, run_datachain_worker
):
catalog = cloud_test_catalog_tmpfile.catalog
metastore = catalog.metastore
sources = [cloud_test_catalog_tmpfile.src_uri]
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/lib/test_datachain_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def teardown(self):
assert udf._had_teardown is False


def test_bootstrap_in_chain():
def test_bootstrap_in_chain(catalog):
base = 1278
prime = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]

Expand All @@ -83,7 +83,7 @@ def test_bootstrap_in_chain():
assert res == [base + val for val in prime]


def test_vars_duplication_error():
def test_vars_duplication_error(catalog):
with pytest.raises(DatasetPrepareError):
(
DataChain.from_values(val=[2, 3, 5, 7, 11, 13, 17, 19, 23, 29])
Expand Down
12 changes: 12 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,18 @@ def make_index(catalog, src: str, entries, ttl: int = 1234):
"others": {"dog4": "ruff"},
},
}

# A bigger fixture tree so distributed runs get a decent number of tasks.
# Mirrors the DEFAULT_TREE layout: a description plus "cats" and "dogs".
LARGE_TREE: dict[str, Any] = {
    "description": "Cats and Dogs",
    "cats": {f"cat{size}": "a" * size for size in range(1, 128)},
    "dogs": dict(
        {f"dogs{size}": "a" * size for size in range(1, 64)},
        others={f"dogs{size}": "a" * size for size in range(64, 98)},
    ),
}

# 1024 entries keyed by zero-padded index ("000000" .. "001023").
NUM_TREE = {format(num, "06d"): str(num) for num in range(1024)}


Expand Down