added all gather
Signed-off-by: Puyuan Yao <[email protected]>
anyadontfly committed Jan 6, 2025
1 parent f1864e3 commit 3cb14f2
Showing 9 changed files with 219 additions and 7 deletions.
16 changes: 12 additions & 4 deletions python/ray/dag/collective_node.py
@@ -39,8 +39,8 @@ class _CollectiveOperation:
     def __init__(
         self,
         input_nodes: List[DAGNode],
-        op: _CollectiveOp,
         transport: Optional[Union[str, Communicator]] = None,
+        op: Optional[_CollectiveOp] = None,
     ):
         if len(input_nodes) == 0:
             raise ValueError("Expected input nodes for a collective operation")
@@ -66,8 +66,8 @@ def __init__(
             )
 
         self._op = op
-        if not isinstance(self._op, ReduceOp):
-            raise NotImplementedError("Only ReduceOp is implemented")
+        if self._op is not None and not isinstance(self._op, ReduceOp):
+            raise NotImplementedError(f"Unimplemented collective operation: {self._op}")
         if transport is None:
             transport = TorchTensorType.NCCL
         self._type_hint = TorchTensorType(transport=transport, _direct_return=True)
@@ -129,7 +129,15 @@ def execute(self, send_buf: "torch.Tensor") -> "torch.Tensor":
         if not isinstance(send_buf, torch.Tensor):
             raise ValueError("Expected a torch tensor")
         communicator = self.get_communicator()
-        if isinstance(self._op, AllReduceOp):
+        if self._op is None:
+            world_size = len(self._actor_handles)
+            recv_buf = torch.empty(
+                (send_buf.shape[0] * world_size, *send_buf.shape[1:]),
+                dtype=send_buf.dtype,
+                device=send_buf.device,
+            )
+            communicator.allgather(send_buf, recv_buf)
+        elif isinstance(self._op, AllReduceOp):
             recv_buf = torch.empty_like(send_buf)
             communicator.allreduce(send_buf, recv_buf, self._op)
         elif isinstance(self._op, ReduceScatterOp):
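
A note on the new allgather branch in execute above: the receive buffer is sized by multiplying the leading dimension of the send buffer by the world size, so each rank's contribution is stacked along dim 0. A minimal standalone sketch of that arithmetic (the world size and tensor shape below are made up for illustration; no communicator is involved):

    import torch

    world_size = 2                  # number of participating actors (illustrative)
    send_buf = torch.zeros(3, 4)    # each rank contributes a (3, 4) tensor
    recv_buf = torch.empty(
        (send_buf.shape[0] * world_size, *send_buf.shape[1:]),
        dtype=send_buf.dtype,
        device=send_buf.device,
    )
    # After the gather, rank 0's rows occupy recv_buf[0:3] and rank 1's recv_buf[3:6].
    assert recv_buf.shape == (6, 4)
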
46 changes: 46 additions & 0 deletions python/ray/dag/tests/experimental/test_torch_tensor_dag.py
@@ -1780,6 +1780,52 @@ def test_nccl_reduce_scatter_with_class_method_output_node(ray_start_regular):
         t3,
     ]
     assert all(torch.equal(r, e) for r, e in zip(result, expected_result))
+
+
+@pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True)
+def test_torch_tensor_nccl_all_gather(ray_start_regular):
+    """
+    Test basic all-gather.
+    """
+    if not USE_GPU:
+        pytest.skip("NCCL tests require GPUs")
+
+    assert (
+        sum(node["Resources"].get("GPU", 0) for node in ray.nodes()) > 1
+    ), "This test requires at least 2 GPUs"
+
+    actor_cls = TorchTensorWorker.options(num_cpus=0, num_gpus=1)
+
+    num_workers = 2
+    workers = [actor_cls.remote() for _ in range(num_workers)]
+
+    with InputNode() as inp:
+        computes = [
+            worker.compute_with_tuple_args.bind(inp, i)
+            for i, worker in enumerate(workers)
+        ]
+        collectives = collective.allgather.bind(computes)
+        recvs = [
+            worker.recv_tensor.bind(collective)
+            for worker, collective in zip(workers, collectives)
+        ]
+        dag = MultiOutputNode(recvs)
+
+    compiled_dag = dag.experimental_compile()
+
+    for i in range(3):
+        i += 1
+        shape = (i, i)
+        dtype = torch.float16
+        value = i
+        ref = compiled_dag.execute(
+            [(shape, dtype, value) for _ in range(num_workers)]
+        )
+        result = ray.get(ref)
+        for tensor in result:
+            tensor = tensor.to("cpu")
+            expected_tensor_val = torch.ones((num_workers * i, i), dtype=dtype) * value
+            assert torch.equal(tensor, expected_tensor_val)
 
 
 @pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 2}], indirect=True)
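
To make the expected values in the new test concrete: on the iteration where i == 2, each of the two workers sends a (2, 2) float16 tensor filled with 2, and after the all-gather every worker holds the (4, 2) row-wise concatenation. The same expectation can be checked on CPU without Ray or NCCL (a sketch mirroring the assertion above):

    import torch

    num_workers, i = 2, 2
    send = torch.ones((i, i), dtype=torch.float16) * i   # what each worker contributes
    gathered = torch.cat([send] * num_workers, dim=0)    # what each worker receives
    expected = torch.ones((num_workers * i, i), dtype=torch.float16) * i
    assert torch.equal(gathered, expected)
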
17 changes: 17 additions & 0 deletions python/ray/experimental/channel/communicator.py
@@ -157,6 +157,23 @@ def reducescatter(
             op: The reduce operation.
         """
         raise NotImplementedError
+
+    @abstractmethod
+    def allgather(
+        self,
+        send_buf: "torch.Tensor",
+        recv_buf: "torch.Tensor",
+    ) -> None:
+        """
+        Collectively allgather the tensor across the group.
+
+        Args:
+            send_buf: The input torch.tensor to allgather. It should already be
+                on this actor's default device.
+            recv_buf: The output torch.tensor to store the allgather result.
+        """
+        raise NotImplementedError
 
     @abstractmethod
     def destroy() -> None:
14 changes: 14 additions & 0 deletions python/ray/experimental/channel/cpu_communicator.py
@@ -179,6 +179,20 @@ def reducescatter(
         ]
         self.num_ops[barrier_key] += 1
 
+    def allgather(
+        self,
+        send_buf: "torch.Tensor",
+        recv_buf: "torch.Tensor",
+    ):
+        all_ranks = [
+            self.get_rank(actor_handle) for actor_handle in self.get_actor_handles()
+        ]
+        barrier_key = "barrier-collective-" + "-".join(map(str, sorted(all_ranks)))
+        barrier = CPUCommBarrier.options(name=barrier_key, get_if_exists=True).remote(
+            self._world_size
+        )
+        self.barriers.add(barrier)
+
     def destroy(self) -> None:
         for barrier in self.barriers:
             ray.kill(barrier)
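
The CPU-backend hunk above registers the shared CPUCommBarrier actor; the data exchange itself is not shown in this hunk, so it is worth restating the semantics the CPU path ultimately has to reproduce: every rank ends up with the concatenation of all ranks' send buffers, in rank order. A pure-torch sketch of that result (the per-rank tensors here are fabricated locally; in the real backend they would travel through the barrier actor):

    import torch

    world_size = 3
    # One send buffer per rank, identified by its fill value.
    per_rank_send = [torch.full((2, 3), float(rank)) for rank in range(world_size)]

    # All-gather result: identical on every rank, concatenated along dim 0 in rank order.
    recv_buf = torch.cat(per_rank_send, dim=0)
    assert recv_buf.shape == (2 * world_size, 3)
    assert torch.equal(recv_buf[2:4], torch.full((2, 3), 1.0))   # rank 1's slice
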
40 changes: 40 additions & 0 deletions python/ray/experimental/channel/nccl_group.py
@@ -321,6 +321,46 @@ def reducescatter(
                 "There may be a dtype mismatch between input tensors from "
                 "different ranks."
             )
+
+    def allgather(
+        self,
+        send_buf: "torch.Tensor",
+        recv_buf: "torch.Tensor",
+    ):
+        if self._closed:
+            raise RayChannelError("NCCL group has been destroyed.")
+
+        assert send_buf.dtype == recv_buf.dtype, (
+            "Ray Compiled Graph derived the dtype of recv_buf from send_buf, "
+            "so send_buf and recv_buf must have the same dtype. "
+            "If you see this error, please file an issue at Ray repository."
+        )
+
+        if not recv_buf.shape[0] == send_buf.shape[0] * len(self.get_actor_handles()):
+            raise ValueError(
+                "Ray Compiled Graph only supports all-gather on tensors "
+                "of the same size."
+            )
+        self._comm.allGather(
+            self.nccl_util.get_tensor_ptr(send_buf),
+            self.nccl_util.get_tensor_ptr(recv_buf),
+            send_buf.numel(),
+            self.nccl_util.get_nccl_tensor_dtype(send_buf),
+            self._cuda_stream.ptr,
+        )
+
+        # Buffer values are undefined if NCCL ops are aborted. Therefore, we
+        # need to synchronize here and check that the channel is still open to
+        # ensure that the receive buffer is valid.
+        # TODO(swang): Avoid CUDA synchronization.
+        # TODO(wxdeng): Use check_async_error.
+        self._cuda_stream.synchronize()
+        if self._closed:
+            raise RayChannelError(
+                "NCCL group has been destroyed during allgather operation. "
+                "There may be a dtype mismatch between input tensors from "
+                "different ranks."
+            )
 
     @property
     def recv_stream(self) -> Optional["cp.cuda.ExternalStream"]:
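
The shape guard in the NCCL allgather above only verifies that recv_buf's leading dimension equals send_buf's leading dimension times the group size; the per-rank element count passed to allGather is send_buf.numel(). A small illustration of the buffer relationship that check enforces (a 4-rank group is assumed for the example):

    import torch

    world_size = 4                 # stands in for len(self.get_actor_handles())
    send_buf = torch.empty(8, 16, dtype=torch.float16)
    recv_buf = torch.empty(8 * world_size, 16, dtype=torch.float16)

    assert recv_buf.shape[0] == send_buf.shape[0] * world_size   # passes the guard
    assert recv_buf.numel() == send_buf.numel() * world_size     # what NCCL fills
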
3 changes: 2 additions & 1 deletion python/ray/experimental/collective/__init__.py
@@ -1,4 +1,5 @@
 from ray.experimental.collective.allreduce import allreduce
 from ray.experimental.collective.reducescatter import reducescatter
+from ray.experimental.collective.allgather import allgather
 
-__all__ = ["allreduce", "reducescatter"]
+__all__ = ["allreduce", "reducescatter", "allgather"]
86 changes: 86 additions & 0 deletions python/ray/experimental/collective/allgather.py
@@ -0,0 +1,86 @@
+import logging
+from typing import List, Optional, Union
+
+import ray
+from ray.dag.collective_node import CollectiveOutputNode, _CollectiveOperation
+from ray.dag.constants import (
+    BIND_INDEX_KEY,
+    COLLECTIVE_OPERATION_KEY,
+    PARENT_CLASS_NODE_KEY,
+)
+from ray.experimental.channel.torch_tensor_type import Communicator, TorchTensorType
+
+logger = logging.getLogger(__name__)
+
+
+class AllGatherWrapper:
+    """Wrapper for NCCL all-gather."""
+
+    def bind(
+        self,
+        input_nodes: List["ray.dag.DAGNode"],
+        transport: Optional[Union[str, Communicator]] = None,
+    ) -> List[CollectiveOutputNode]:
+        """
+        Bind input nodes with a collective operation. The collective operation is
+        directly applied to the torch tensors from the input nodes. The output
+        nodes return the gathered tensors.
+
+        Requirements:
+        1. Each input node returns a torch tensor.
+        2. Each input node is from a different actor.
+        3. If a custom transport is specified, its actor set matches the actor set
+           of the input nodes.
+        4. All tensors have the same shape.
+
+        Requirements 1-3 are checked in the `CollectiveGroup` constructor.
+        Requirement 4 is not checked yet.
+
+        Args:
+            input_nodes: A list of DAG nodes.
+            transport: GPU communicator for the collective operation. If not
+                specified, the default NCCL is used.
+
+        Returns:
+            A list of collective output nodes.
+        """
+        if transport is None:
+            transport = TorchTensorType.NCCL
+        collective_op = _CollectiveOperation(input_nodes, transport)
+        collective_output_nodes: List[CollectiveOutputNode] = []
+
+        for input_node in input_nodes:
+            actor_handle: Optional[
+                "ray.actor.ActorHandle"
+            ] = input_node._get_actor_handle()
+            if actor_handle is None:
+                raise ValueError("Expected an actor handle from the input node")
+            collective_output_node = CollectiveOutputNode(
+                method_name="allgather",
+                method_args=(input_node,),
+                method_kwargs=dict(),
+                method_options=dict(),
+                other_args_to_resolve={
+                    PARENT_CLASS_NODE_KEY: actor_handle,
+                    BIND_INDEX_KEY: actor_handle._ray_dag_bind_index,
+                    COLLECTIVE_OPERATION_KEY: collective_op,
+                },
+            )
+            actor_handle._ray_dag_bind_index += 1
+            collective_output_nodes.append(collective_output_node)
+
+        return collective_output_nodes
+
+    def __call__(
+        self,
+        tensor_list,
+        tensor,
+        group_name: str = "default",
+    ):
+        from ray.util.collective.collective import allgather
+
+        return allgather(tensor_list, tensor, group_name)
+
+
+allgather = AllGatherWrapper()
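
For orientation on the two entry points defined above: bind builds CollectiveOutputNodes for use inside a compiled graph (the path the new test exercises), while __call__ simply forwards to the eager ray.util.collective allgather. A hedged sketch of the graph-mode usage follows; the GpuWorker actor and its methods are illustrative stand-ins rather than part of this diff, and a 2-GPU cluster is assumed:

    import ray
    import torch
    from ray.dag import InputNode, MultiOutputNode
    from ray.experimental import collective


    @ray.remote(num_gpus=1)
    class GpuWorker:  # illustrative actor, not part of this commit
        def make_tensor(self, value: int) -> "torch.Tensor":
            return torch.ones(4, 4, device="cuda") * value

        def report_shape(self, t: "torch.Tensor") -> tuple:
            return tuple(t.shape)


    workers = [GpuWorker.remote() for _ in range(2)]
    with InputNode() as inp:
        sends = [w.make_tensor.bind(inp) for w in workers]
        gathered = collective.allgather.bind(sends)   # one output node per worker
        dag = MultiOutputNode(
            [w.report_shape.bind(t) for w, t in zip(workers, gathered)]
        )

    compiled_dag = dag.experimental_compile()
    print(ray.get(compiled_dag.execute(3)))           # expected: [(8, 4), (8, 4)]
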
2 changes: 1 addition & 1 deletion python/ray/experimental/collective/allreduce.py
@@ -53,7 +53,7 @@ def bind(
         """
         if transport is None:
             transport = TorchTensorType.NCCL
-        collective_op = _CollectiveOperation(input_nodes, op, transport)
+        collective_op = _CollectiveOperation(input_nodes, transport, op)
         collective_output_nodes: List[CollectiveOutputNode] = []
 
         for input_node in input_nodes:
2 changes: 1 addition & 1 deletion python/ray/experimental/collective/reducescatter.py
@@ -53,7 +53,7 @@ def bind(
         """
         if transport is None:
             transport = TorchTensorType.NCCL
-        collective_op = _CollectiveOperation(input_nodes, op, transport)
+        collective_op = _CollectiveOperation(input_nodes, transport, op)
         collective_output_nodes: List[CollectiveOutputNode] = []
 
         for input_node in input_nodes:
