diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
index dbd765ab44b13e..a47fcbd0d7c1a4 100644
--- a/caffe2/CMakeLists.txt
+++ b/caffe2/CMakeLists.txt
@@ -1104,6 +1104,9 @@ if(USE_XPU)
       message(WARNING "Failed to include ATen XPU implementation target")
     else()
       target_link_libraries(torch_xpu PRIVATE torch_xpu_ops)
+      if(USE_C10D_XCCL)
+        target_compile_definitions(torch_xpu PUBLIC USE_C10D_XCCL)
+      endif()
       if(MSVC)
         # Windows
         target_link_options(torch_xpu PRIVATE
diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi
index c14e195cd43672..ab59afc88b610b 100644
--- a/torch/_C/_distributed_c10d.pyi
+++ b/torch/_C/_distributed_c10d.pyi
@@ -679,3 +679,12 @@ class _SymmetricMemory:
     def stream_write_value32(
         tensor: torch.Tensor, offset: int, val: int
     ) -> torch.Tensor: ...
+
+class ProcessGroupXCCL(Backend):
+    def __init__(
+        self,
+        store: Store,
+        rank: int,
+        size: int,
+    ): ...
+
\ No newline at end of file
diff --git a/torch/csrc/distributed/c10d/ProcessGroup.hpp b/torch/csrc/distributed/c10d/ProcessGroup.hpp
index 5cba3a39629d4e..1259813d5d7a5f 100644
--- a/torch/csrc/distributed/c10d/ProcessGroup.hpp
+++ b/torch/csrc/distributed/c10d/ProcessGroup.hpp
@@ -77,7 +77,8 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
     NCCL = 2,
     UCC = 3,
     MPI = 4,
-    CUSTOM = 5,
+    XCCL = 5,
+    CUSTOM = 6,
   };
 
   static std::string backendTypeToString(const BackendType& type) {
@@ -86,6 +87,8 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
         return "gloo";
       case BackendType::NCCL:
         return "nccl";
+      case BackendType::XCCL:
+        return "xccl";
       case BackendType::UCC:
         return "ucc";
       case BackendType::MPI:
@@ -106,6 +109,8 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
       return BackendType::GLOO;
     } else if (backend == "nccl") {
       return BackendType::NCCL;
+    } else if (backend == "xccl") {
+      return BackendType::XCCL;
     } else if (backend == "ucc") {
       return BackendType::UCC;
     } else if (backend == "mpi") {
@@ -636,6 +641,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
     // TODO: HACK for backend name to get sequence number for that backend.
     if (backendType == ProcessGroup::BackendType::GLOO ||
         backendType == ProcessGroup::BackendType::NCCL ||
+        backendType == ProcessGroup::BackendType::XCCL ||
         backendType == ProcessGroup::BackendType::UCC) {
       getDefaultBackend()->setSequenceNumberForGroup();
     } else {
@@ -657,6 +663,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
     // TODO: HACK for backend name to get sequence number for that backend.
     if (backendType == ProcessGroup::BackendType::GLOO ||
         backendType == ProcessGroup::BackendType::NCCL ||
+        backendType == ProcessGroup::BackendType::XCCL ||
         backendType == ProcessGroup::BackendType::UCC) {
       return getDefaultBackend()->getSequenceNumberForGroup();
     } else {
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index 01fc8cb45a3336..5644f3d00f3e7a 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -24,6 +24,10 @@
 #include <torch/csrc/distributed/c10d/ProcessGroupWrapper.hpp>
 #endif
 
+#ifdef USE_C10D_XCCL
+#include <torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp>
+#endif
+
 #ifdef USE_C10D_NCCL
 #include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
 #include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
@@ -2311,6 +2315,7 @@ The hook must have the following signature:
       .value("UNDEFINED", ::c10d::ProcessGroup::BackendType::UNDEFINED)
       .value("GLOO", ::c10d::ProcessGroup::BackendType::GLOO)
       .value("NCCL", ::c10d::ProcessGroup::BackendType::NCCL)
+      .value("XCCL", ::c10d::ProcessGroup::BackendType::XCCL)
       .value("UCC", ::c10d::ProcessGroup::BackendType::UCC)
       .value("MPI", ::c10d::ProcessGroup::BackendType::MPI)
       .value("CUSTOM", ::c10d::ProcessGroup::BackendType::CUSTOM)
@@ -2927,6 +2932,23 @@ Example::
           py::call_guard<py::gil_scoped_release>());
 #endif
 
+#ifdef USE_C10D_XCCL
+  auto processGroupXCCL =
+      intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupXCCL>(
+          module, "ProcessGroupXCCL", backend)
+          .def(
+              py::init([](const c10::intrusive_ptr<::c10d::Store>& store,
+                          int rank,
+                          int size) {
+                return c10::make_intrusive<::c10d::ProcessGroupXCCL>(
+                    store, rank, size);
+              }),
+              py::arg("store"),
+              py::arg("rank"),
+              py::arg("size"),
+              py::call_guard<py::gil_scoped_release>());
+#endif
+
 #ifdef USE_C10D_UCC
   auto processGroupUCC =
       intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupUCC>(
diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index 8c8d7f2a8d8ee2..8d18edee42059e 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -89,6 +89,7 @@
     "is_nccl_available",
     "is_torchelastic_launched",
     "is_ucc_available",
+    "is_xccl_available",
     "isend",
     "monitored_barrier",
     "new_group",
@@ -132,6 +133,7 @@
 _NCCL_AVAILABLE = True
 _GLOO_AVAILABLE = True
 _UCC_AVAILABLE = True
+_XCCL_AVAILABLE = True
 
 _pickler = pickle.Pickler
 _unpickler = pickle.Unpickler
@@ -195,6 +197,14 @@ def _export_c_types() -> None:
 except ImportError:
     _UCC_AVAILABLE = False
 
+try:
+    from torch._C._distributed_c10d import ProcessGroupXCCL
+
+    ProcessGroupXCCL.__module__ = "torch.distributed.distributed_c10d"
+    __all__ += ["ProcessGroupXCCL"]
+except ImportError:
+    _XCCL_AVAILABLE = False
+
 logger = logging.getLogger(__name__)
 
 PG_WRAPPER_STORE_PREFIX = "pg_wrapper"
@@ -224,7 +234,7 @@ class Backend(str):
     """
    An enum-like class for backends.
 
-    Available backends: GLOO, NCCL, UCC, MPI, and other registered backends.
+    Available backends: GLOO, NCCL, UCC, MPI, XCCL, and other registered backends.
 
     The values of this class are lowercase strings, e.g., ``"gloo"``. They can be
     accessed as attributes, e.g., ``Backend.NCCL``.
@@ -244,21 +254,24 @@ class Backend(str):
     NCCL = "nccl"
     UCC = "ucc"
     MPI = "mpi"
+    XCCL = "xccl"
 
     _BackendPlugin = namedtuple("_BackendPlugin", ["creator_fn", "extended_api"])
 
     _plugins: Dict[str, _BackendPlugin] = {}
 
-    backend_list = [UNDEFINED, GLOO, NCCL, UCC, MPI]
+    backend_list = [UNDEFINED, GLOO, NCCL, XCCL, UCC, MPI]
 
     default_device_backend_map: Dict[str, str] = {
         "cpu": GLOO,
         "cuda": NCCL,
+        "xpu": XCCL,
     }
 
     backend_capability: Dict[str, List[str]] = {
         GLOO: ["cpu", "cuda"],
         NCCL: ["cuda"],
+        XCCL: ["xpu"],
         UCC: ["cpu", "cuda"],
         MPI: ["cpu", "cuda"],
     }
@@ -267,6 +280,7 @@ class Backend(str):
         UNDEFINED: ProcessGroup.BackendType.UNDEFINED,
         GLOO: ProcessGroup.BackendType.GLOO,
         NCCL: ProcessGroup.BackendType.NCCL,
+        XCCL: ProcessGroup.BackendType.XCCL,
         UCC: ProcessGroup.BackendType.UCC,
         MPI: ProcessGroup.BackendType.MPI,
     }
@@ -327,7 +341,7 @@ def register_backend(
     Backend.backend_list.append(name.lower())
     if devices is not None:
         for device in devices:
-            if device != "cpu" and device != "cuda":
+            if device != "cpu" and device != "cuda" and device != "xpu":
                 Backend.default_device_backend_map[device] = name.lower()
 
     Backend.backend_type_map[name.lower()] = ProcessGroup.BackendType.CUSTOM
@@ -340,7 +354,7 @@ def register_backend(
             "`cuda`. Please specify it via the `devices` argument of "
             "`register_backend`."
         )
-        Backend.backend_capability[name.lower()] = ["cpu", "cuda"]
+        Backend.backend_capability[name.lower()] = ["cpu", "cuda", "xpu"]
     elif isinstance(devices, str):
         # Single device string specified. Simply convert to list.
         Backend.backend_capability[name.lower()] = [devices]
@@ -1185,6 +1199,11 @@ def is_ucc_available() -> bool:
     return _UCC_AVAILABLE
 
 
+def is_xccl_available() -> bool:
+    """Check if the XCCL backend is available."""
+    return _XCCL_AVAILABLE
+
+
 def is_backend_available(backend: str) -> bool:
     """
     Check backend availability.
@@ -1437,6 +1456,10 @@ def _set_pg_timeout(timeout: timedelta, group: Optional[ProcessGroup] = None) ->
             backends.add(backend)  # type: ignore[arg-type]
         elif is_gloo_available() and isinstance(backend, ProcessGroupGloo):
             backends.add(backend)  # type: ignore[arg-type]
+    if torch.device("xpu") in devices and is_xccl_available():
+        backend = group._get_backend(torch.device("xpu"))
+        if isinstance(backend, ProcessGroupXCCL):
+            backends.add(backend)  # type: ignore[arg-type]
     if len(backends) == 0:
         warnings.warn("Set timeout is now only supported for either nccl or gloo.")
     for backend in backends:
@@ -1472,7 +1495,7 @@ def init_process_group(
 
     Args:
         backend (str or Backend, optional): The backend to use. Depending on
-            build-time configurations, valid values include ``mpi``, ``gloo``,
+            build-time configurations, valid values include ``mpi``, ``gloo``, ``xccl``,
            ``nccl``, and ``ucc``. If the backend is not provided, then both a ``gloo``
            and ``nccl`` backend will be created, see notes below for how multiple
            backends are managed. This field can be given as a lowercase string
@@ -1752,10 +1775,9 @@ def _new_process_group_helper(
             "created, please use a different group name"
         )
 
-    if device_id is not None and (device_id.index is None or device_id.type != "cuda"):
+    if device_id is not None and device_id.index is None:
         raise ValueError(
-            "init_process_group device_id parameter must be a cuda device with an "
-            "id, e.g. cuda:0, not just cuda or cpu"
+            "init_process_group device_id parameter must be a device with an index"
         )
 
     # Note: _new_process_group_helper is only called from init_process_group, which always provides a timeout value
@@ -1885,6 +1907,17 @@ def _new_process_group_helper(
                 backend_prefix_store, group_rank, group_size, timeout=timeout
             )
             backend_type = ProcessGroup.BackendType.UCC
+        elif backend_str == Backend.XCCL:
+            if not is_xccl_available():
+                raise RuntimeError("Distributed package doesn't have XCCL built in")
+            if backend_options is not None:
+                assert isinstance(
+                    backend_options, ProcessGroupXCCL.Options
+                ), "Expected backend_options argument to be of type ProcessGroupXCCL.Options"
+            backend_class = ProcessGroupXCCL(
+                backend_prefix_store, group_rank, group_size
+            )
+            backend_type = ProcessGroup.BackendType.XCCL
         else:
             assert (
                 backend_str.upper() in Backend._plugins
@@ -2693,7 +2726,7 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):
             return _IllegalWork()
         else:
             return None
-    
+
     work = group.allreduce([tensor], opts)
 
     if async_op:
@@ -2701,7 +2734,6 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):
     else:
         work.wait()
 
-    
 @_exception_logger
 @deprecated(
     "`torch.distributed.all_reduce_coalesced` will be deprecated. If you must "
@@ -4059,7 +4091,7 @@ def reduce_scatter_tensor(output, input, op=ReduceOp.SUM, group=None, async_op=F
             return _IllegalWork()
         else:
             return None
-    
+
     work = group._reduce_scatter_base(output, input, opts)
 
     if async_op:
@@ -4067,7 +4099,6 @@ def reduce_scatter_tensor(output, input, op=ReduceOp.SUM, group=None, async_op=F
     else:
         work.wait()
 
-    
 @deprecated(
     "`torch.distributed._reduce_scatter_base` is a private function and will be deprecated. "
     "Please use `torch.distributed.reduce_scatter_tensor` instead.",
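Not part of the patch itself: below is a minimal usage sketch of the new backend from the Python side. It assumes a build configured with `USE_C10D_XCCL`, an Intel GPU exposed as an `xpu` device, and the usual `torchrun`/env-var rendezvous (`MASTER_ADDR`, `MASTER_PORT`, `RANK`, `WORLD_SIZE`); the device-selection and fallback logic is illustrative only.

```python
# Hedged example: exercises the surface this diff adds (backend="xccl",
# is_xccl_available(), the "xpu" -> XCCL default mapping). Not taken from the PR.
import os

import torch
import torch.distributed as dist


def main() -> None:
    use_xccl = dist.is_xccl_available() and torch.xpu.is_available()
    if use_xccl:
        # LOCAL_RANK is set by torchrun; bind each rank to one XPU.
        torch.xpu.set_device(int(os.environ.get("LOCAL_RANK", "0")))
        dist.init_process_group(backend="xccl")
    else:
        # Fall back to gloo on builds/machines without XCCL.
        dist.init_process_group(backend="gloo")

    device = "xpu" if use_xccl else "cpu"
    t = torch.ones(4, device=device)
    dist.all_reduce(t, op=dist.ReduceOp.SUM)  # t now holds world_size on every rank
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

Because the patch also adds `"xpu": XCCL` to `Backend.default_device_backend_map`, calling `init_process_group()` without an explicit backend should likewise route `xpu` tensors through XCCL once the backend is built in.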