Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[coll] Use loky for rabit op tests. #10828

Merged
merged 2 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions python-package/xgboost/testing/updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,8 +218,13 @@ def check_extmem_qdm(
)

booster_it = xgb.train({"device": device}, Xy_it, num_boost_round=8)
X, y, w = it.as_arrays()
Xy = xgb.QuantileDMatrix(X, y, weight=w)
it = tm.IteratorForTest(
*tm.make_batches(
n_samples_per_batch, n_features, n_batches, use_cupy=device != "cpu"
),
cache=None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
cache=None,
cache="cache",
on_host=on_host,

Did you mean to re-create the same iterator?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They are not the same; the second one doesn't cache data to disk.

)
Xy = xgb.QuantileDMatrix(it)
booster = xgb.train({"device": device}, Xy, num_boost_round=8)

if device == "cpu":
Expand Down
61 changes: 26 additions & 35 deletions tests/python/test_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,44 +34,48 @@ def test_socket_error():
tracker.free()


def run_rabit_ops(pool, n_workers: int, address: str) -> None:
    """Run basic collective (rabit) allreduce ops across a process pool.

    Parameters
    ----------
    pool :
        A loky reusable executor (or any executor with a ``map`` method).
    n_workers :
        Number of workers the tracker expects; must match the pool size.
    address :
        Host IP for the tracker, e.g. ``"127.0.0.1"`` or ``"::1"`` for IPv6.
    """
    # Start a tracker locally; its worker_args() carry the connection info
    # each worker needs to join the collective communicator.
    tracker = RabitTracker(host_ip=address, n_workers=n_workers)
    tracker.start()
    args = tracker.worker_args()

    def local_test(worker_id: int, rabit_args: dict) -> int:
        # Runs inside a pool worker: join the communicator, then exercise
        # SUM and MAX allreduce and check the expected global results.
        with collective.CommunicatorContext(**rabit_args):
            a = 1
            assert collective.is_distributed()
            arr = np.array([a])
            reduced = collective.allreduce(arr, collective.Op.SUM)
            # Every worker contributes 1, so the sum equals n_workers.
            assert reduced[0] == n_workers

            arr = np.array([worker_id])
            reduced = collective.allreduce(arr, collective.Op.MAX)
            # Worker ids are 0..n_workers-1, so the max is n_workers - 1.
            assert reduced == n_workers - 1

            return 1

    # Bind the tracker args via partial; update_wrapper keeps the function's
    # metadata so loky can pickle/identify it correctly.
    fn = update_wrapper(partial(local_test, rabit_args=args), local_test)
    results = pool.map(fn, range(n_workers))
    # Each worker returns 1 on success.
    assert sum(results) == n_workers


@pytest.mark.skipif(**tm.no_loky())
def test_rabit_ops():
    """Exercise rabit collective ops over IPv4 using a loky process pool."""
    from loky import get_reusable_executor

    n_workers = 4
    with get_reusable_executor(max_workers=n_workers) as pool:
        run_rabit_ops(pool, n_workers, "127.0.0.1")


@pytest.mark.skipif(**tm.no_ipv6())
@pytest.mark.skipif(**tm.no_loky())
def test_rabit_ops_ipv6():
    """Exercise rabit collective ops over IPv6 using a loky process pool."""
    from loky import get_reusable_executor

    n_workers = 4
    with get_reusable_executor(max_workers=n_workers) as pool:
        # "::1" is the IPv6 loopback address.
        run_rabit_ops(pool, n_workers, "::1")


def run_allreduce(pool, n_workers: int) -> None:
tracker = RabitTracker(host_ip="127.0.0.1", n_workers=n_workers)
Expand Down Expand Up @@ -133,19 +137,6 @@ def test_broadcast():
run_broadcast(pool, n_workers)


# NOTE(review): this is the pre-PR dask-based IPv6 test, removed in this
# change in favor of the loky-based version above. Restored formatting only.
@pytest.mark.skipif(**tm.no_ipv6())
@pytest.mark.skipif(**tm.no_dask())
def test_rabit_ops_ipv6():
    """Exercise rabit collective ops over IPv6 using a local dask cluster."""
    import dask
    from distributed import Client, LocalCluster

    n_workers = 3
    # Point the tracker at the IPv6 loopback so workers connect over "::1".
    with dask.config.set({"xgboost.scheduler_address": "[::1]"}):
        with LocalCluster(n_workers=n_workers, host="[::1]") as cluster:
            with Client(cluster) as client:
                run_rabit_ops(client, n_workers)


@pytest.mark.skipif(**tm.no_dask())
def test_rank_assignment() -> None:
from distributed import Client, LocalCluster
Expand Down
Loading