Skip to content

Commit

Permalink
Add UCXX tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pentschev committed Oct 30, 2023
1 parent 79efe09 commit ab80c58
Show file tree
Hide file tree
Showing 8 changed files with 183 additions and 59 deletions.
42 changes: 32 additions & 10 deletions dask_cuda/tests/test_dgx.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,13 @@ def test_default():
assert not p.exitcode


def _test_tcp_over_ucx():
ucp = pytest.importorskip("ucp")
def _test_tcp_over_ucx(protocol):
if protocol == "ucx":
ucp = pytest.importorskip("ucp")
elif protocol == "ucxx":
ucp = pytest.importorskip("ucxx")

with LocalCUDACluster(enable_tcp_over_ucx=True) as cluster:
with LocalCUDACluster(protocol=protocol, enable_tcp_over_ucx=True) as cluster:
with Client(cluster) as client:
res = da.from_array(numpy.arange(10000), chunks=(1000,))
res = res.sum().compute()
Expand All @@ -93,10 +96,17 @@ def check_ucx_options():
assert all(client.run(check_ucx_options).values())


def test_tcp_over_ucx():
ucp = pytest.importorskip("ucp") # NOQA: F841
@pytest.mark.parametrize(
"protocol",
["ucx", "ucxx"],
)
def test_tcp_over_ucx(protocol):
if protocol == "ucx":
pytest.importorskip("ucp")
elif protocol == "ucxx":
pytest.importorskip("ucxx")

p = mp.Process(target=_test_tcp_over_ucx)
p = mp.Process(target=_test_tcp_over_ucx, args=(protocol,))
p.start()
p.join()
assert not p.exitcode
Expand All @@ -117,9 +127,14 @@ def test_tcp_only():
assert not p.exitcode


def _test_ucx_infiniband_nvlink(enable_infiniband, enable_nvlink, enable_rdmacm):
def _test_ucx_infiniband_nvlink(
protocol, enable_infiniband, enable_nvlink, enable_rdmacm
):
cupy = pytest.importorskip("cupy")
ucp = pytest.importorskip("ucp")
if protocol == "ucx":
ucp = pytest.importorskip("ucp")
elif protocol == "ucxx":
ucp = pytest.importorskip("ucxx")

if enable_infiniband is None and enable_nvlink is None and enable_rdmacm is None:
enable_tcp_over_ucx = None
Expand All @@ -135,13 +150,15 @@ def _test_ucx_infiniband_nvlink(enable_infiniband, enable_nvlink, enable_rdmacm)
cm_tls_priority = ["tcp"]

initialize(
protocol=protocol,
enable_tcp_over_ucx=enable_tcp_over_ucx,
enable_infiniband=enable_infiniband,
enable_nvlink=enable_nvlink,
enable_rdmacm=enable_rdmacm,
)

with LocalCUDACluster(
protocol=protocol,
interface="ib0",
enable_tcp_over_ucx=enable_tcp_over_ucx,
enable_infiniband=enable_infiniband,
Expand Down Expand Up @@ -171,6 +188,7 @@ def check_ucx_options():
assert all(client.run(check_ucx_options).values())


@pytest.mark.parametrize("protocol", ["ucx", "ucxx"])
@pytest.mark.parametrize(
"params",
[
Expand All @@ -185,8 +203,11 @@ def check_ucx_options():
_get_dgx_version() == DGXVersion.DGX_A100,
reason="Automatic InfiniBand device detection Unsupported for %s" % _get_dgx_name(),
)
def test_ucx_infiniband_nvlink(params):
ucp = pytest.importorskip("ucp") # NOQA: F841
def test_ucx_infiniband_nvlink(protocol, params):
if protocol == "ucx":
ucp = pytest.importorskip("ucp")
elif protocol == "ucxx":
ucp = pytest.importorskip("ucxx")

if params["enable_infiniband"]:
if not any([at.startswith("rc") for at in ucp.get_active_transports()]):
Expand All @@ -195,6 +216,7 @@ def test_ucx_infiniband_nvlink(params):
p = mp.Process(
target=_test_ucx_infiniband_nvlink,
args=(
protocol,
params["enable_infiniband"],
params["enable_nvlink"],
params["enable_rdmacm"],
Expand Down
8 changes: 4 additions & 4 deletions dask_cuda/tests/test_explicit_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def _test_local_cluster(protocol):
assert sum(c.run(my_rank, 0)) == sum(range(4))


@pytest.mark.parametrize("protocol", ["tcp", "ucx"])
@pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"])
def test_local_cluster(protocol):
p = mp.Process(target=_test_local_cluster, args=(protocol,))
p.start()
Expand Down Expand Up @@ -160,7 +160,7 @@ def _test_dataframe_shuffle(backend, protocol, n_workers, _partitions):

@pytest.mark.parametrize("nworkers", [1, 2, 3])
@pytest.mark.parametrize("backend", ["pandas", "cudf"])
@pytest.mark.parametrize("protocol", ["tcp", "ucx"])
@pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"])
@pytest.mark.parametrize("_partitions", [True, False])
def test_dataframe_shuffle(backend, protocol, nworkers, _partitions):
if backend == "cudf":
Expand Down Expand Up @@ -256,7 +256,7 @@ def _test_dataframe_shuffle_merge(backend, protocol, n_workers):

@pytest.mark.parametrize("nworkers", [1, 2, 4])
@pytest.mark.parametrize("backend", ["pandas", "cudf"])
@pytest.mark.parametrize("protocol", ["tcp", "ucx"])
@pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"])
def test_dataframe_shuffle_merge(backend, protocol, nworkers):
if backend == "cudf":
pytest.importorskip("cudf")
Expand Down Expand Up @@ -293,7 +293,7 @@ def _test_jit_unspill(protocol):
assert_eq(got, expected)


@pytest.mark.parametrize("protocol", ["tcp", "ucx"])
@pytest.mark.parametrize("protocol", ["tcp", "ucx", "ucxx"])
def test_jit_unspill(protocol):
pytest.importorskip("cudf")

Expand Down
8 changes: 6 additions & 2 deletions dask_cuda/tests/test_from_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,16 @@

from dask_cuda import LocalCUDACluster

pytest.importorskip("ucp")
cupy = pytest.importorskip("cupy")


@pytest.mark.parametrize("protocol", ["ucx", "tcp"])
@pytest.mark.parametrize("protocol", ["ucx", "ucxx", "tcp"])
def test_ucx_from_array(protocol):
if protocol == "ucx":
pytest.importorskip("ucp")
elif protocol == "ucxx":
pytest.importorskip("ucxx")

N = 10_000
with LocalCUDACluster(protocol=protocol) as cluster:
with Client(cluster):
Expand Down
85 changes: 64 additions & 21 deletions dask_cuda/tests/test_initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,23 @@
from dask_cuda.utils_test import IncreasedCloseTimeoutNanny

mp = mp.get_context("spawn") # type: ignore
ucp = pytest.importorskip("ucp")

# Notice, all of the following tests are executed in a new process such
# that the UCX options of the different tests don't conflict.
# Furthermore, all tests do some computation to trigger initialization
# of UCX before retrieving the current config.


def _test_initialize_ucx_tcp():
def _test_initialize_ucx_tcp(protocol):
if protocol == "ucx":
ucp = pytest.importorskip("ucp")
elif protocol == "ucxx":
ucp = pytest.importorskip("ucxx")

kwargs = {"enable_tcp_over_ucx": True}
initialize(**kwargs)
initialize(protocol=protocol, **kwargs)
with LocalCluster(
protocol="ucx",
protocol=protocol,
dashboard_address=None,
n_workers=1,
threads_per_worker=1,
Expand All @@ -50,18 +54,29 @@ def check_ucx_options():
assert all(client.run(check_ucx_options).values())


def test_initialize_ucx_tcp():
p = mp.Process(target=_test_initialize_ucx_tcp)
@pytest.mark.parametrize("protocol", ["ucx", "ucxx"])
def test_initialize_ucx_tcp(protocol):
if protocol == "ucx":
pytest.importorskip("ucp")
elif protocol == "ucxx":
pytest.importorskip("ucxx")

p = mp.Process(target=_test_initialize_ucx_tcp, args=(protocol,))
p.start()
p.join()
assert not p.exitcode


def _test_initialize_ucx_nvlink():
def _test_initialize_ucx_nvlink(protocol):
if protocol == "ucx":
ucp = pytest.importorskip("ucp")
elif protocol == "ucxx":
ucp = pytest.importorskip("ucxx")

kwargs = {"enable_nvlink": True}
initialize(**kwargs)
initialize(protocol=protocol, **kwargs)
with LocalCluster(
protocol="ucx",
protocol=protocol,
dashboard_address=None,
n_workers=1,
threads_per_worker=1,
Expand All @@ -87,18 +102,29 @@ def check_ucx_options():
assert all(client.run(check_ucx_options).values())


def test_initialize_ucx_nvlink():
p = mp.Process(target=_test_initialize_ucx_nvlink)
@pytest.mark.parametrize("protocol", ["ucx", "ucxx"])
def test_initialize_ucx_nvlink(protocol):
if protocol == "ucx":
pytest.importorskip("ucp")
elif protocol == "ucxx":
pytest.importorskip("ucxx")

p = mp.Process(target=_test_initialize_ucx_nvlink, args=(protocol,))
p.start()
p.join()
assert not p.exitcode


def _test_initialize_ucx_infiniband():
def _test_initialize_ucx_infiniband(protocol):
if protocol == "ucx":
ucp = pytest.importorskip("ucp")
elif protocol == "ucxx":
ucp = pytest.importorskip("ucxx")

kwargs = {"enable_infiniband": True}
initialize(**kwargs)
initialize(protocol=protocol, **kwargs)
with LocalCluster(
protocol="ucx",
protocol=protocol,
dashboard_address=None,
n_workers=1,
threads_per_worker=1,
Expand Down Expand Up @@ -127,17 +153,28 @@ def check_ucx_options():
@pytest.mark.skipif(
"ib0" not in psutil.net_if_addrs(), reason="Infiniband interface ib0 not found"
)
def test_initialize_ucx_infiniband():
p = mp.Process(target=_test_initialize_ucx_infiniband)
@pytest.mark.parametrize("protocol", ["ucx", "ucxx"])
def test_initialize_ucx_infiniband(protocol):
if protocol == "ucx":
pytest.importorskip("ucp")
elif protocol == "ucxx":
pytest.importorskip("ucxx")

p = mp.Process(target=_test_initialize_ucx_infiniband, args=(protocol,))
p.start()
p.join()
assert not p.exitcode


def _test_initialize_ucx_all():
initialize()
def _test_initialize_ucx_all(protocol):
if protocol == "ucx":
ucp = pytest.importorskip("ucp")
elif protocol == "ucxx":
ucp = pytest.importorskip("ucxx")

initialize(protocol=protocol)
with LocalCluster(
protocol="ucx",
protocol=protocol,
dashboard_address=None,
n_workers=1,
threads_per_worker=1,
Expand Down Expand Up @@ -166,8 +203,14 @@ def check_ucx_options():
assert all(client.run(check_ucx_options).values())


def test_initialize_ucx_all():
p = mp.Process(target=_test_initialize_ucx_all)
@pytest.mark.parametrize("protocol", ["ucx", "ucxx"])
def test_initialize_ucx_all(protocol):
if protocol == "ucx":
pytest.importorskip("ucp")
elif protocol == "ucxx":
pytest.importorskip("ucxx")

p = mp.Process(target=_test_initialize_ucx_all, args=(protocol,))
p.start()
p.join()
assert not p.exitcode
Loading

0 comments on commit ab80c58

Please sign in to comment.