Skip to content

Commit

Permalink
Support UCXX alongside UCX-Py
Browse files (browse the repository at this point in the history)
  • Loading branch information
pentschev committed Oct 30, 2023
1 parent d9e1001 commit 79efe09
Show file tree
Hide file tree
Showing 9 changed files with 64 additions and 26 deletions.
2 changes: 1 addition & 1 deletion dask_cuda/benchmarks/local_cudf_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def pretty_print_results(args, address_to_index, p2p_bw, results):
key="Device memory limit", value=f"{format_bytes(args.device_memory_limit)}"
)
print_key_value(key="RMM Pool", value=f"{not args.disable_rmm_pool}")
- if args.protocol == "ucx":
+ if args.protocol in ["ucx", "ucxx"]:
print_key_value(key="TCP", value=f"{args.enable_tcp_over_ucx}")
print_key_value(key="InfiniBand", value=f"{args.enable_infiniband}")
print_key_value(key="NVLink", value=f"{args.enable_nvlink}")
Expand Down
2 changes: 1 addition & 1 deletion dask_cuda/benchmarks/local_cudf_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def pretty_print_results(args, address_to_index, p2p_bw, results):
)
print_key_value(key="RMM Pool", value=f"{not args.disable_rmm_pool}")
print_key_value(key="Frac-match", value=f"{args.frac_match}")
- if args.protocol == "ucx":
+ if args.protocol in ["ucx", "ucxx"]:
print_key_value(key="TCP", value=f"{args.enable_tcp_over_ucx}")
print_key_value(key="InfiniBand", value=f"{args.enable_infiniband}")
print_key_value(key="NVLink", value=f"{args.enable_nvlink}")
Expand Down
2 changes: 1 addition & 1 deletion dask_cuda/benchmarks/local_cudf_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def pretty_print_results(args, address_to_index, p2p_bw, results):
key="Device memory limit", value=f"{format_bytes(args.device_memory_limit)}"
)
print_key_value(key="RMM Pool", value=f"{not args.disable_rmm_pool}")
- if args.protocol == "ucx":
+ if args.protocol in ["ucx", "ucxx"]:
print_key_value(key="TCP", value=f"{args.enable_tcp_over_ucx}")
print_key_value(key="InfiniBand", value=f"{args.enable_infiniband}")
print_key_value(key="NVLink", value=f"{args.enable_nvlink}")
Expand Down
2 changes: 1 addition & 1 deletion dask_cuda/benchmarks/local_cupy.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def pretty_print_results(args, address_to_index, p2p_bw, results):
)
print_key_value(key="RMM Pool", value=f"{not args.disable_rmm_pool}")
print_key_value(key="Protocol", value=f"{args.protocol}")
- if args.protocol == "ucx":
+ if args.protocol in ["ucx", "ucxx"]:
print_key_value(key="TCP", value=f"{args.enable_tcp_over_ucx}")
print_key_value(key="InfiniBand", value=f"{args.enable_infiniband}")
print_key_value(key="NVLink", value=f"{args.enable_nvlink}")
Expand Down
2 changes: 1 addition & 1 deletion dask_cuda/benchmarks/local_cupy_map_overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def pretty_print_results(args, address_to_index, p2p_bw, results):
)
print_key_value(key="RMM Pool", value=f"{not args.disable_rmm_pool}")
print_key_value(key="Protocol", value=f"{args.protocol}")
- if args.protocol == "ucx":
+ if args.protocol in ["ucx", "ucxx"]:
print_key_value(key="TCP", value=f"{args.enable_tcp_over_ucx}")
print_key_value(key="InfiniBand", value=f"{args.enable_infiniband}")
print_key_value(key="NVLink", value=f"{args.enable_nvlink}")
Expand Down
2 changes: 1 addition & 1 deletion dask_cuda/benchmarks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def parse_benchmark_args(description="Generic dask-cuda Benchmark", args_list=[]
cluster_args.add_argument(
"-p",
"--protocol",
- choices=["tcp", "ucx"],
+ choices=["tcp", "ucx", "ucxx"],
default="tcp",
type=str,
help="The communication protocol to use.",
Expand Down
63 changes: 47 additions & 16 deletions dask_cuda/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import numba.cuda

import dask
- import distributed.comm.ucx
from distributed.diagnostics.nvml import get_device_index_and_uuid, has_cuda_context

from .utils import get_ucx_config
Expand All @@ -23,12 +22,21 @@ def _create_cuda_context_handler():
numba.cuda.current_context()


- def _create_cuda_context():
+ def _create_cuda_context(protocol="ucx"):
+ if protocol not in ["ucx", "ucxx"]:
+ return
try:
# Added here to ensure the parent `LocalCUDACluster` process creates the CUDA
# context directly from the UCX module, thus avoiding a similar warning there.
try:
- distributed.comm.ucx.init_once()
+ if protocol == "ucx":
+ import distributed.comm.ucx
+
+ distributed.comm.ucx.init_once()
+ elif protocol == "ucxx":
+ import distributed_ucxx.ucxx
+
+ distributed_ucxx.ucxx.init_once()
except ModuleNotFoundError:
# UCX initialization has to be delegated to Distributed, it will take care
# of setting correct environment variables and importing `ucp` after that.
Expand All @@ -39,20 +47,35 @@ def _create_cuda_context():
os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")[0]
)
ctx = has_cuda_context()
- if (
- ctx.has_context
- and not distributed.comm.ucx.cuda_context_created.has_context
- ):
- distributed.comm.ucx._warn_existing_cuda_context(ctx, os.getpid())
+ if protocol == "ucx":
+ if (
+ ctx.has_context
+ and not distributed.comm.ucx.cuda_context_created.has_context
+ ):
+ distributed.comm.ucx._warn_existing_cuda_context(ctx, os.getpid())
+ elif protocol == "ucxx":
+ if (
+ ctx.has_context
+ and not distributed_ucxx.ucxx.cuda_context_created.has_context
+ ):
+ distributed_ucxx.ucxx._warn_existing_cuda_context(ctx, os.getpid())

_create_cuda_context_handler()

- if not distributed.comm.ucx.cuda_context_created.has_context:
- ctx = has_cuda_context()
- if ctx.has_context and ctx.device_info != cuda_visible_device:
- distributed.comm.ucx._warn_cuda_context_wrong_device(
- cuda_visible_device, ctx.device_info, os.getpid()
- )
+ if protocol == "ucx":
+ if not distributed.comm.ucx.cuda_context_created.has_context:
+ ctx = has_cuda_context()
+ if ctx.has_context and ctx.device_info != cuda_visible_device:
+ distributed.comm.ucx._warn_cuda_context_wrong_device(
+ cuda_visible_device, ctx.device_info, os.getpid()
+ )
+ elif protocol == "ucxx":
+ if not distributed_ucxx.ucxx.cuda_context_created.has_context:
+ ctx = has_cuda_context()
+ if ctx.has_context and ctx.device_info != cuda_visible_device:
+ distributed_ucxx.ucxx._warn_cuda_context_wrong_device(
+ cuda_visible_device, ctx.device_info, os.getpid()
+ )

except Exception:
logger.error("Unable to start CUDA Context", exc_info=True)
Expand All @@ -64,6 +87,7 @@ def initialize(
enable_infiniband=None,
enable_nvlink=None,
enable_rdmacm=None,
+ protocol="ucx",
):
"""Create CUDA context and initialize UCX-Py, depending on user parameters.
Expand Down Expand Up @@ -118,7 +142,7 @@ def initialize(
dask.config.set({"distributed.comm.ucx": ucx_config})

if create_cuda_context:
- _create_cuda_context()
+ _create_cuda_context(protocol=protocol)


@click.command()
Expand All @@ -127,6 +151,12 @@ def initialize(
default=False,
help="Create CUDA context",
)
+ @click.option(
+ "--protocol",
+ default=None,
+ type=str,
+ help="Communication protocol, such as: 'tcp', 'tls', 'ucx' or 'ucxx'.",
+ )
@click.option(
"--enable-tcp-over-ucx/--disable-tcp-over-ucx",
default=False,
Expand All @@ -150,10 +180,11 @@ def initialize(
def dask_setup(
service,
create_cuda_context,
+ protocol,
enable_tcp_over_ucx,
enable_infiniband,
enable_nvlink,
enable_rdmacm,
):
if create_cuda_context:
- _create_cuda_context()
+ _create_cuda_context(protocol=protocol)
9 changes: 6 additions & 3 deletions dask_cuda/local_cuda_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,8 +319,11 @@ def __init__(
if enable_tcp_over_ucx or enable_infiniband or enable_nvlink:
if protocol is None:
protocol = "ucx"
- elif protocol != "ucx":
- raise TypeError("Enabling InfiniBand or NVLink requires protocol='ucx'")
+ elif protocol not in ["ucx", "ucxx"]:
+ raise TypeError(
+ "Enabling InfiniBand or NVLink requires protocol='ucx' or "
+ "protocol='ucxx'"
+ )

self.host = kwargs.get("host", None)

Expand Down Expand Up @@ -371,7 +374,7 @@ def __init__(
) + ["dask_cuda.initialize"]
self.new_spec["options"]["preload_argv"] = self.new_spec["options"].get(
"preload_argv", []
- ) + ["--create-cuda-context"]
+ ) + ["--create-cuda-context", "--protocol", protocol]

self.cuda_visible_devices = CUDA_VISIBLE_DEVICES
self.scale(n_workers)
Expand Down
6 changes: 5 additions & 1 deletion dask_cuda/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def get_preload_options(
if create_cuda_context:
preload_options["preload_argv"].append("--create-cuda-context")

- if protocol == "ucx":
+ if protocol in ["ucx", "ucxx"]:
initialize_ucx_argv = []
if enable_tcp_over_ucx:
initialize_ucx_argv.append("--enable-tcp-over-ucx")
Expand Down Expand Up @@ -625,6 +625,10 @@ def get_worker_config(dask_worker):
import ucp

ret["ucx-transports"] = ucp.get_active_transports()
+ elif scheme == "ucxx":
+ import ucxx
+
+ ret["ucx-transports"] = ucxx.get_active_transports()

# comm timeouts
ret["distributed.comm.timeouts"] = dask.config.get("distributed.comm.timeouts")
Expand Down

0 comments on commit 79efe09

Please sign in to comment.