diff --git a/tests/python/test_tracker.py b/tests/python/test_tracker.py index 5d508f0d17ff..95074553acd7 100644 --- a/tests/python/test_tracker.py +++ b/tests/python/test_tracker.py @@ -1,5 +1,7 @@ import re import sys +from functools import partial, update_wrapper +from typing import Dict, Union import numpy as np import pytest @@ -13,8 +15,8 @@ def test_rabit_tracker(): tracker = RabitTracker(host_ip="127.0.0.1", n_workers=1) tracker.start() - with xgb.collective.CommunicatorContext(**tracker.worker_args()): - ret = xgb.collective.broadcast("test1234", 0) + with collective.CommunicatorContext(**tracker.worker_args()): + ret = collective.broadcast("test1234", 0) assert str(ret) == "test1234" @@ -26,7 +28,7 @@ def test_socket_error(): env["dmlc_tracker_port"] = 0 env["dmlc_retry"] = 1 with pytest.raises(ValueError, match="Failed to bootstrap the communication."): - with xgb.collective.CommunicatorContext(**env): + with collective.CommunicatorContext(**env): pass with pytest.raises(ValueError): tracker.free() @@ -70,16 +72,15 @@ def test_rabit_ops(): run_rabit_ops(client, n_workers) -def run_allreduce(client) -> None: - from xgboost.dask import CommunicatorContext, _get_dask_config, _get_rabit_args - workers = tm.get_client_workers(client) - rabit_args = client.sync(_get_rabit_args, len(workers), _get_dask_config(), client) - n_workers = len(workers) +def run_allreduce(pool, n_workers: int) -> None: + tracker = RabitTracker(host_ip="127.0.0.1", n_workers=n_workers) + tracker.start() + args = tracker.worker_args() - def local_test(worker_id: int) -> None: + def local_test(worker_id: int, rabit_args: Dict[str, Union[str, int]]) -> None: x = np.full(shape=(1024 * 1024 * 32), fill_value=1.0) - with CommunicatorContext(**rabit_args): + with collective.CommunicatorContext(**rabit_args): k = np.asarray([1.0]) for i in range(128): m = collective.allreduce(k, collective.Op.SUM) @@ -88,46 +89,48 @@ def local_test(worker_id: int) -> None: y = collective.allreduce(x, collective.Op.SUM) np.testing.assert_allclose(y, np.full_like(y, fill_value=float(n_workers))) - futures = client.map(local_test, range(len(workers)), workers=workers) - results = client.gather(futures) + fn = update_wrapper(partial(local_test, rabit_args=args), local_test) + results = pool.map(fn, range(n_workers)) + for r in results: + assert r is None -@pytest.mark.skipif(**tm.no_dask()) +@pytest.mark.skipif(**tm.no_loky()) def test_allreduce() -> None: - from distributed import Client, LocalCluster + from loky import get_reusable_executor n_workers = 4 - for i in range(2): - with LocalCluster(n_workers=n_workers) as cluster: - with Client(cluster) as client: - for i in range(2): - run_allreduce(client) - + n_trials = 2 + for _ in range(n_trials): + with get_reusable_executor(max_workers=n_workers) as pool: + run_allreduce(pool, n_workers) -def run_broadcast(client): - from xgboost.dask import _get_dask_config, _get_rabit_args - workers = tm.get_client_workers(client) - rabit_args = client.sync(_get_rabit_args, len(workers), _get_dask_config(), client) +def run_broadcast(pool, n_workers: int) -> None: + tracker = RabitTracker(host_ip="127.0.0.1", n_workers=n_workers) + tracker.start() + args = tracker.worker_args() - def local_test(worker_id): + def local_test(worker_id: int, rabit_args: Dict[str, Union[str, int]]): with collective.CommunicatorContext(**rabit_args): res = collective.broadcast(17, 0) return res - futures = client.map(local_test, range(len(workers)), workers=workers) - results = client.gather(futures) - np.testing.assert_allclose(np.array(results), 17) + fn = update_wrapper(partial(local_test, rabit_args=args), local_test) + results = pool.map(fn, range(n_workers)) + np.testing.assert_allclose(np.array(list(results)), 17) -@pytest.mark.skipif(**tm.no_dask()) +@pytest.mark.skipif(**tm.no_loky()) def test_broadcast(): - from distributed import Client, LocalCluster + from loky import get_reusable_executor - n_workers = 3 - with LocalCluster(n_workers=n_workers) as cluster: - with Client(cluster) as client: - run_broadcast(client) + n_workers = 4 + n_trials = 2 + + for _ in range(n_trials): + with get_reusable_executor(max_workers=n_workers) as pool: + run_broadcast(pool, n_workers) @pytest.mark.skipif(**tm.no_ipv6()) @@ -151,7 +154,7 @@ def local_test(worker_id): with xgb.dask.CommunicatorContext(**args) as ctx: task_id = ctx["DMLC_TASK_ID"] matched = re.search(".*-([0-9]).*", task_id) - rank = xgb.collective.get_rank() + rank = collective.get_rank() # As long as the number of workers is lesser than 10, rank and worker id # should be the same assert rank == int(matched.group(1)) @@ -170,21 +173,12 @@ def local_test(worker_id): client.gather(futures) -@pytest.fixture -def local_cluster(): - from distributed import LocalCluster - - n_workers = 8 - with LocalCluster(n_workers=n_workers, dashboard_address=":0") as cluster: - yield cluster - - ops_strategy = strategies.lists( strategies.sampled_from(["broadcast", "allreduce_max", "allreduce_sum"]) ) -@pytest.mark.skipif(**tm.no_dask()) +@pytest.mark.skipif(**tm.no_loky()) @given(ops=ops_strategy, size=strategies.integers(2**4, 2**16)) @settings( deadline=None, @@ -192,12 +186,14 @@ def local_cluster(): max_examples=10, suppress_health_check=[HealthCheck.function_scoped_fixture], ) -def test_ops_restart_comm(local_cluster, ops, size) -> None: - from distributed import Client +def test_ops_restart_comm(ops, size) -> None: + from loky import get_reusable_executor + + n_workers = 8 - def local_test(w: int, n_workers: int) -> None: + def local_test(w: int, rabit_args: Dict[str, Union[str, int]]) -> None: a = np.arange(0, n_workers) - with xgb.dask.CommunicatorContext(**args): + with collective.CommunicatorContext(**rabit_args): for op in ops: if op == "broadcast": b = collective.broadcast(a, root=1) @@ -211,27 +207,21 @@ def local_test(w: int, n_workers: int) -> None: else: raise ValueError() - with Client(local_cluster) as client: - workers = tm.get_client_workers(client) - args = client.sync( - xgb.dask._get_rabit_args, - len(workers), - None, - client, - ) + with get_reusable_executor(max_workers=n_workers) as pool: + tracker = RabitTracker(host_ip="127.0.0.1", n_workers=n_workers) + tracker.start() + args = tracker.worker_args() - workers = tm.get_client_workers(client) - n_workers = len(workers) + fn = update_wrapper(partial(local_test, rabit_args=args), local_test) + results = pool.map(fn, range(n_workers)) - futures = client.map( - local_test, range(len(workers)), workers=workers, n_workers=n_workers - ) - client.gather(futures) + for r in results: + assert r is None -@pytest.mark.skipif(**tm.no_dask()) -def test_ops_reuse_comm(local_cluster) -> None: - from distributed import Client +@pytest.mark.skipif(**tm.no_loky()) +def test_ops_reuse_comm() -> None: + from loky import get_reusable_executor rng = np.random.default_rng(1994) n_examples = 10 @@ -239,10 +229,13 @@ def test_ops_reuse_comm(local_cluster) -> None: ["broadcast", "allreduce_sum", "allreduce_max"], size=n_examples ).tolist() - def local_test(w: int, n_workers: int) -> None: + n_workers = 8 + n_trials = 8 + + def local_test(w: int, rabit_args: Dict[str, Union[str, int]]) -> None: a = np.arange(0, n_workers) - with xgb.dask.CommunicatorContext(**args): + with collective.CommunicatorContext(**rabit_args): for op in ops: if op == "broadcast": b = collective.broadcast(a, root=1) @@ -257,18 +250,13 @@ def local_test(w: int, n_workers: int) -> None: else: raise ValueError() - with Client(local_cluster) as client: - workers = tm.get_client_workers(client) - args = client.sync( - xgb.dask._get_rabit_args, - len(workers), - None, - client, - ) - - n_workers = len(workers) - - futures = client.map( - local_test, range(len(workers)), workers=workers, n_workers=n_workers - ) - client.gather(futures) + with get_reusable_executor(max_workers=n_workers) as pool: + for _ in range(n_trials): + tracker = RabitTracker(host_ip="127.0.0.1", n_workers=n_workers) + tracker.start() + args = tracker.worker_args() + + fn = update_wrapper(partial(local_test, rabit_args=args), local_test) + results = pool.map(fn, range(n_workers)) + for r in results: + assert r is None