Skip to content

Commit

Permalink
fix typo and register frontend
Browse files Browse the repository at this point in the history
  • Loading branch information
Chao1Han committed Sep 3, 2024
1 parent 486b61a commit 6844932
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 5 deletions.
2 changes: 1 addition & 1 deletion cmake/Dependencies.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -1154,7 +1154,7 @@ endif()
if(USE_XCCL)
if(NOT USE_XPU)
message(WARNING
"Not using XPU, so disabling USE_NUSE_XCCLCCL. Suppress this warning with "
"Not using XPU, so disabling USE_XCCL. Suppress this warning with "
"-DUSE_XCCL=OFF.")
caffe2_update_option(USE_XCCL OFF)
elseif(NOT CMAKE_SYSTEM_NAME STREQUAL "Linux")
Expand Down
2 changes: 1 addition & 1 deletion cmake/Modules/FindXCCL.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ find_file(
NO_DEFAULT_PATH
)

# Find include/sycl path from include path.
# Find include/oneapi path from include path.
find_file(
XCCL_INCLUDE_ONEAPI_DIR
NAMES oneapi
Expand Down
29 changes: 26 additions & 3 deletions torch/distributed/distributed_c10d.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@
_NCCL_AVAILABLE = True
_GLOO_AVAILABLE = True
_UCC_AVAILABLE = True
_XCCL_AVAILABLE = True

_pickler = pickle.Pickler
_unpickler = pickle.Unpickler
Expand Down Expand Up @@ -193,6 +194,14 @@ def _export_c_types() -> None:
except ImportError:
_UCC_AVAILABLE = False

# ProcessGroupXCCL is an optional binding: when the C extension does not
# provide it, the import fails and the availability flag is cleared instead
# of propagating the error.
try:
    from torch._C._distributed_c10d import ProcessGroupXCCL
except ImportError:
    _XCCL_AVAILABLE = False
else:
    # Present the binding as a member of this module and export it publicly.
    ProcessGroupXCCL.__module__ = "torch.distributed.distributed_c10d"
    __all__ += ["ProcessGroupXCCL"]

logger = logging.getLogger(__name__)

PG_WRAPPER_STORE_PREFIX = "pg_wrapper"
Expand Down Expand Up @@ -222,7 +231,7 @@ class Backend(str):
"""
An enum-like class for backends.
Available backends: GLOO, NCCL, UCC, MPI, and other registered backends.
Available backends: GLOO, NCCL, UCC, MPI, XCCL, and other registered backends.
The values of this class are lowercase strings, e.g., ``"gloo"``. They can
be accessed as attributes, e.g., ``Backend.NCCL``.
Expand All @@ -242,6 +251,7 @@ class Backend(str):
# Built-in backend names. Per the class docstring, every value is a lowercase
# string (e.g. ``"gloo"``); ``init_process_group`` likewise documents ``xccl``
# as the accepted spelling. The process-group helper compares the requested
# backend string against ``Backend.XCCL`` directly, so the value must use the
# same lowercase form as its siblings.
NCCL = "nccl"
UCC = "ucc"
MPI = "mpi"
XCCL = "xccl"

_BackendPlugin = namedtuple("_BackendPlugin", ["creator_fn", "extended_api"])

Expand Down Expand Up @@ -1097,6 +1107,9 @@ def is_ucc_available() -> bool:
"""Check if the UCC backend is available."""
return _UCC_AVAILABLE

def is_xccl_available() -> bool:
    """
    Check if the XCCL backend is available.

    Returns:
        The module-level ``_XCCL_AVAILABLE`` flag, which starts as ``True``
        and is set to ``False`` when importing ``ProcessGroupXCCL`` raises
        ``ImportError``.
    """
    return _XCCL_AVAILABLE

def is_backend_available(backend: str) -> bool:
"""
Expand Down Expand Up @@ -1385,7 +1398,7 @@ def init_process_group(
Args:
backend (str or Backend, optional): The backend to use. Depending on
build-time configurations, valid values include ``mpi``, ``gloo``,
build-time configurations, valid values include ``mpi``, ``gloo``, ``xccl``,
``nccl``, and ``ucc``. If the backend is not provided, then both a ``gloo``
and ``nccl`` backend will be created, see notes below for how multiple
backends are managed. This field can be given as a lowercase string
Expand Down Expand Up @@ -1762,7 +1775,6 @@ def _new_process_group_helper(
pg_options = ProcessGroupNCCL.Options()
pg_options.is_high_priority_stream = False
pg_options._timeout = timeout

if split_from:
pg_options.split_from = split_from
pg_options.split_color = _process_group_color(global_ranks_in_group)
Expand All @@ -1781,6 +1793,17 @@ def _new_process_group_helper(
backend_prefix_store, group_rank, group_size, timeout=timeout
)
backend_type = ProcessGroup.BackendType.UCC
elif backend_str == Backend.XCCL:
if not is_xccl_available():
raise RuntimeError("Distributed package doesn't have XCCL built in")
if pg_options is not None:
assert isinstance(
pg_options, ProcessGroupXCCL.Options
), "Expected pg_options argument to be of type ProcessGroupXCCL.Options"
backend_class = ProcessGroupXCCL(
backend_prefix_store, group_rank, group_size
)
backend_type = ProcessGroup.BackendType.XCCL
else:
assert (
backend_str.upper() in Backend._plugins
Expand Down

0 comments on commit 6844932

Please sign in to comment.