
Commit

init
Chao1Han committed Nov 27, 2024
1 parent: 62eea62 · commit: fc3f2c7
Showing 5 changed files with 85 additions and 13 deletions.
caffe2/CMakeLists.txt (3 changes: 3 additions, 0 deletions)
@@ -1104,6 +1104,9 @@ if(USE_XPU)
message(WARNING "Failed to include ATen XPU implementation target")
else()
target_link_libraries(torch_xpu PRIVATE torch_xpu_ops)
if(USE_C10D_XCCL)
target_compile_definitions(torch_xpu PUBLIC USE_C10D_XCCL)
endif()
if(MSVC)
# Windows
target_link_options(torch_xpu PRIVATE
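
A quick way to confirm that the compile definition added above actually made it into a given build is to query the availability helper introduced later in this commit. This is a minimal sketch, assuming a PyTorch build configured with USE_XPU and USE_C10D_XCCL:

    import torch.distributed as dist

    # True only when torch was built with USE_C10D_XCCL and the XCCL binding imported cleanly.
    print(dist.is_xccl_available())
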
torch/_C/_distributed_c10d.pyi (9 changes: 9 additions, 0 deletions)
@@ -679,3 +679,12 @@ class _SymmetricMemory:
def stream_write_value32(
tensor: torch.Tensor, offset: int, val: int
) -> torch.Tensor: ...

class ProcessGroupXCCL(Backend):
def __init__(
self,
store: Store,
rank: int,
size: int,
): ...

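As an illustration of the stub above, the sketch below constructs the new process group directly from a store, a rank, and a world size, mirroring the ProcessGroupXCCL(store, rank, size) signature. It assumes a build with USE_C10D_XCCL and an available XPU runtime; the host, port, and environment variables are placeholder values, and in practice most users would go through init_process_group instead.

    import os
    import torch.distributed as dist

    # Values normally provided by a launcher such as torchrun; defaults are for illustration.
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))

    # Rendezvous store shared by all ranks; host and port are placeholders.
    store = dist.TCPStore("127.0.0.1", 29500, world_size, rank == 0)

    # Matches the stubbed constructor: ProcessGroupXCCL(store, rank, size).
    pg = dist.ProcessGroupXCCL(store, rank, world_size)
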
torch/csrc/distributed/c10d/ProcessGroup.hpp (9 changes: 8 additions, 1 deletion)
@@ -77,7 +77,8 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
NCCL = 2,
UCC = 3,
MPI = 4,
CUSTOM = 5,
XCCL = 5,
CUSTOM = 6,
};

static std::string backendTypeToString(const BackendType& type) {
@@ -86,6 +87,8 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
return "gloo";
case BackendType::NCCL:
return "nccl";
case BackendType::XCCL:
return "xccl";
case BackendType::UCC:
return "ucc";
case BackendType::MPI:
@@ -106,6 +109,8 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
return BackendType::GLOO;
} else if (backend == "nccl") {
return BackendType::NCCL;
} else if (backend == "xccl") {
return BackendType::XCCL;
} else if (backend == "ucc") {
return BackendType::UCC;
} else if (backend == "mpi") {
@@ -636,6 +641,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
// TODO: HACK for backend name to get sequence number for that backend.
if (backendType == ProcessGroup::BackendType::GLOO ||
backendType == ProcessGroup::BackendType::NCCL ||
backendType == ProcessGroup::BackendType::XCCL ||
backendType == ProcessGroup::BackendType::UCC) {
getDefaultBackend()->setSequenceNumberForGroup();
} else {
@@ -657,6 +663,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
// TODO: HACK for backend name to get sequence number for that backend.
if (backendType == ProcessGroup::BackendType::GLOO ||
backendType == ProcessGroup::BackendType::NCCL ||
backendType == ProcessGroup::BackendType::XCCL ||
backendType == ProcessGroup::BackendType::UCC) {
return getDefaultBackend()->getSequenceNumberForGroup();
} else {
torch/csrc/distributed/c10d/init.cpp (22 changes: 22 additions, 0 deletions)
@@ -24,6 +24,10 @@
#include <torch/csrc/distributed/c10d/ProcessGroupWrapper.hpp>
#endif

#ifdef USE_C10D_XCCL
#include <torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp>
#endif

#ifdef USE_C10D_NCCL
#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
@@ -2311,6 +2315,7 @@ The hook must have the following signature:
.value("UNDEFINED", ::c10d::ProcessGroup::BackendType::UNDEFINED)
.value("GLOO", ::c10d::ProcessGroup::BackendType::GLOO)
.value("NCCL", ::c10d::ProcessGroup::BackendType::NCCL)
.value("XCCL", ::c10d::ProcessGroup::BackendType::XCCL)
.value("UCC", ::c10d::ProcessGroup::BackendType::UCC)
.value("MPI", ::c10d::ProcessGroup::BackendType::MPI)
.value("CUSTOM", ::c10d::ProcessGroup::BackendType::CUSTOM)
@@ -2927,6 +2932,23 @@ Example::
py::call_guard<py::gil_scoped_release>());
#endif

#ifdef USE_C10D_XCCL
auto processGroupXCCL =
intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupXCCL>(
module, "ProcessGroupXCCL", backend)
.def(
py::init([](const c10::intrusive_ptr<::c10d::Store>& store,
int rank,
int size) {
return c10::make_intrusive<::c10d::ProcessGroupXCCL>(
store, rank, size);
}),
py::arg("store"),
py::arg("rank"),
py::arg("size"),
py::call_guard<py::gil_scoped_release>());
#endif

#ifdef USE_C10D_UCC
auto processGroupUCC =
intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupUCC>(
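
Beyond the class binding, the init.cpp additions also register an XCCL member on the BackendType enum, so the new backend is visible from Python even before a process group is created. A small sketch of what that exposes, assuming the lowercase "xccl" string added in distributed_c10d.py below:

    import torch.distributed as dist

    # New enum member registered by the .value("XCCL", ...) binding above.
    print(dist.ProcessGroup.BackendType.XCCL)

    # The lowercase string form maps to the same enum value via Backend.backend_type_map.
    assert dist.Backend.backend_type_map[dist.Backend.XCCL] == dist.ProcessGroup.BackendType.XCCL
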
torch/distributed/distributed_c10d.py (55 changes: 43 additions, 12 deletions)
@@ -89,6 +89,7 @@
"is_nccl_available",
"is_torchelastic_launched",
"is_ucc_available",
"is_xccl_available",
"isend",
"monitored_barrier",
"new_group",
@@ -132,6 +133,7 @@
_NCCL_AVAILABLE = True
_GLOO_AVAILABLE = True
_UCC_AVAILABLE = True
_XCCL_AVAILABLE = True

_pickler = pickle.Pickler
_unpickler = pickle.Unpickler
@@ -195,6 +197,14 @@ def _export_c_types() -> None:
except ImportError:
_UCC_AVAILABLE = False

try:
from torch._C._distributed_c10d import ProcessGroupXCCL

ProcessGroupXCCL.__module__ = "torch.distributed.distributed_c10d"
__all__ += ["ProcessGroupXCCL"]
except ImportError:
_XCCL_AVAILABLE = False

logger = logging.getLogger(__name__)

PG_WRAPPER_STORE_PREFIX = "pg_wrapper"
@@ -224,7 +234,7 @@ class Backend(str):
"""
An enum-like class for backends.
Available backends: GLOO, NCCL, UCC, MPI, and other registered backends.
Available backends: GLOO, NCCL, UCC, MPI, XCCL, and other registered backends.
The values of this class are lowercase strings, e.g., ``"gloo"``. They can
be accessed as attributes, e.g., ``Backend.NCCL``.
@@ -244,21 +254,24 @@ class Backend(str):
NCCL = "nccl"
UCC = "ucc"
MPI = "mpi"
XCCL = "xccl"

_BackendPlugin = namedtuple("_BackendPlugin", ["creator_fn", "extended_api"])

_plugins: Dict[str, _BackendPlugin] = {}

backend_list = [UNDEFINED, GLOO, NCCL, UCC, MPI]
backend_list = [UNDEFINED, GLOO, NCCL, XCCL, UCC, MPI]

default_device_backend_map: Dict[str, str] = {
"cpu": GLOO,
"cuda": NCCL,
"xpu": XCCL,
}

backend_capability: Dict[str, List[str]] = {
GLOO: ["cpu", "cuda"],
NCCL: ["cuda"],
XCCL: ["xpu"],
UCC: ["cpu", "cuda"],
MPI: ["cpu", "cuda"],
}
@@ -267,6 +280,7 @@ class Backend(str):
UNDEFINED: ProcessGroup.BackendType.UNDEFINED,
GLOO: ProcessGroup.BackendType.GLOO,
NCCL: ProcessGroup.BackendType.NCCL,
XCCL: ProcessGroup.BackendType.XCCL,
UCC: ProcessGroup.BackendType.UCC,
MPI: ProcessGroup.BackendType.MPI,
}
@@ -327,7 +341,7 @@ def register_backend(
Backend.backend_list.append(name.lower())
if devices is not None:
for device in devices:
if device != "cpu" and device != "cuda":
if device != "cpu" and device != "cuda" and device != "xpu":
Backend.default_device_backend_map[device] = name.lower()
Backend.backend_type_map[name.lower()] = ProcessGroup.BackendType.CUSTOM

@@ -340,7 +354,7 @@ def register_backend(
"`cuda`. Please specify it via the `devices` argument of "
"`register_backend`."
)
Backend.backend_capability[name.lower()] = ["cpu", "cuda"]
Backend.backend_capability[name.lower()] = ["cpu", "cuda", "xpu"]
elif isinstance(devices, str):
# Single device string specified. Simply convert to list.
Backend.backend_capability[name.lower()] = [devices]
@@ -1185,6 +1199,11 @@ def is_ucc_available() -> bool:
return _UCC_AVAILABLE


def is_xccl_available() -> bool:
"""Check if the XCCL backend is available."""
return _XCCL_AVAILABLE


def is_backend_available(backend: str) -> bool:
"""
Check backend availability.
@@ -1437,6 +1456,10 @@ def _set_pg_timeout(timeout: timedelta, group: Optional[ProcessGroup] = None) ->
backends.add(backend) # type: ignore[arg-type]
elif is_gloo_available() and isinstance(backend, ProcessGroupGloo):
backends.add(backend) # type: ignore[arg-type]
if torch.device("xpu") in devices and is_xccl_available():
backend = group._get_backend(torch.device("xpu"))
if isinstance(backend, ProcessGroupXCCL):
backends.add(backend) # type: ignore[arg-type]
if len(backends) == 0:
warnings.warn("Set timeout is now only supported for either nccl or gloo.")
for backend in backends:
@@ -1472,7 +1495,7 @@ def init_process_group(
Args:
backend (str or Backend, optional): The backend to use. Depending on
build-time configurations, valid values include ``mpi``, ``gloo``,
build-time configurations, valid values include ``mpi``, ``gloo``, ``xccl``,
``nccl``, and ``ucc``. If the backend is not provided, then both a ``gloo``
and ``nccl`` backend will be created, see notes below for how multiple
backends are managed. This field can be given as a lowercase string
@@ -1752,10 +1775,9 @@ def _new_process_group_helper(
"created, please use a different group name"
)

if device_id is not None and (device_id.index is None or device_id.type != "cuda"):
if device_id is not None and device_id.index is None:
raise ValueError(
"init_process_group device_id parameter must be a cuda device with an "
"id, e.g. cuda:0, not just cuda or cpu"
"init_process_group device_id parameter must be a device with an index"
)

# Note: _new_process_group_helper is only called from init_process_group, which always provides a timeout value
@@ -1885,6 +1907,17 @@ def _new_process_group_helper(
backend_prefix_store, group_rank, group_size, timeout=timeout
)
backend_type = ProcessGroup.BackendType.UCC
elif backend_str == Backend.XCCL:
if not is_xccl_available():
raise RuntimeError("Distributed package doesn't have XCCL built in")
if backend_options is not None:
assert isinstance(
backend_options, ProcessGroupXCCL.Options
), "Expected backend_options argument to be of type ProcessGroupXCCL.Options"
backend_class = ProcessGroupXCCL(
backend_prefix_store, group_rank, group_size
)
backend_type = ProcessGroup.BackendType.XCCL
else:
assert (
backend_str.upper() in Backend._plugins
@@ -2693,15 +2726,14 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):
return _IllegalWork()
else:
return None

work = group.allreduce([tensor], opts)

if async_op:
return work
else:
work.wait()


@_exception_logger
@deprecated(
"`torch.distributed.all_reduce_coalesced` will be deprecated. If you must "
@@ -4059,15 +4091,14 @@ def reduce_scatter_tensor(output, input, op=ReduceOp.SUM, group=None, async_op=F
return _IllegalWork()
else:
return None

work = group._reduce_scatter_base(output, input, opts)

if async_op:
return work
else:
work.wait()


@deprecated(
"`torch.distributed._reduce_scatter_base` is a private function and will be deprecated. "
"Please use `torch.distributed.reduce_scatter_tensor` instead.",
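
Taken together, the Python-side changes above let "xccl" be passed to init_process_group and route collectives on "xpu" tensors through the new backend. The following end-to-end sketch assumes a build with USE_C10D_XCCL, at least one XPU device per rank, and a launch via torchrun (which sets RANK, LOCAL_RANK, WORLD_SIZE, and the master address/port):

    import os
    import torch
    import torch.distributed as dist

    if dist.is_xccl_available():
        dist.init_process_group(backend="xccl")

        # XCCL serves the "xpu" device type (see the backend_capability table above).
        local_rank = int(os.environ.get("LOCAL_RANK", "0"))
        t = torch.ones(4, device=f"xpu:{local_rank}")

        dist.all_reduce(t, op=dist.ReduceOp.SUM)
        dist.destroy_process_group()
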
