-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Puyuan Yao <[email protected]>
- Loading branch information
1 parent
f1864e3
commit 3cb14f2
Showing
9 changed files
with
219 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
from ray.experimental.collective.allreduce import allreduce | ||
from ray.experimental.collective.reducescatter import reducescatter | ||
from ray.experimental.collective.allgather import allgather | ||
|
||
__all__ = ["allreduce", "reducescatter"] | ||
__all__ = ["allreduce", "reducescatter", "allgather"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
import logging | ||
from typing import List, Optional, Union | ||
|
||
import ray | ||
from ray.dag.collective_node import CollectiveOutputNode, _CollectiveOperation | ||
from ray.dag.constants import ( | ||
BIND_INDEX_KEY, | ||
COLLECTIVE_OPERATION_KEY, | ||
PARENT_CLASS_NODE_KEY, | ||
) | ||
from ray.experimental.channel.torch_tensor_type import Communicator, TorchTensorType | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class AllGatherWrapper: | ||
"""Wrapper for NCCL all-gather.""" | ||
|
||
def bind( | ||
self, | ||
input_nodes: List["ray.dag.DAGNode"], | ||
transport: Optional[Union[str, Communicator]] = None, | ||
) -> List[CollectiveOutputNode]: | ||
""" | ||
Bind input nodes with a collective operation. The collective operation is | ||
directly applied to the torch tensors from the input nodes. The output nodes | ||
are the results of the collective operation in the same torch tensors. | ||
Requirements: | ||
1. Each input node returns a torch tensor. | ||
2. Each input node is from a different actor. | ||
3. If a custom transport is specified, its actor set matches the actor set | ||
of the input nodes. | ||
4. All tensors have the same shape. | ||
Requirements 1-3 are checked in the `CollectiveGroup` constructor. | ||
Requirement 4 is not checked yet. | ||
Args: | ||
input_nodes: A list of DAG nodes. | ||
op: The collective operation. | ||
transport: GPU communicator for the collective operation. If not | ||
specified, the default NCCL is used. | ||
Returns: | ||
A list of collective output nodes. | ||
""" | ||
if transport is None: | ||
transport = TorchTensorType.NCCL | ||
collective_op = _CollectiveOperation(input_nodes, transport) | ||
collective_output_nodes: List[CollectiveOutputNode] = [] | ||
|
||
for input_node in input_nodes: | ||
actor_handle: Optional[ | ||
"ray.actor.ActorHandle" | ||
] = input_node._get_actor_handle() | ||
if actor_handle is None: | ||
raise ValueError("Expected an actor handle from the input node") | ||
collective_output_node = CollectiveOutputNode( | ||
method_name=f"allgather", | ||
method_args=(input_node,), | ||
method_kwargs=dict(), | ||
method_options=dict(), | ||
other_args_to_resolve={ | ||
PARENT_CLASS_NODE_KEY: actor_handle, | ||
BIND_INDEX_KEY: actor_handle._ray_dag_bind_index, | ||
COLLECTIVE_OPERATION_KEY: collective_op, | ||
}, | ||
) | ||
actor_handle._ray_dag_bind_index += 1 | ||
collective_output_nodes.append(collective_output_node) | ||
|
||
return collective_output_nodes | ||
|
||
def __call__( | ||
self, | ||
tensor_list, | ||
tensor, | ||
group_name: str = "default", | ||
): | ||
from ray.util.collective.collective import allgather | ||
|
||
return allgather(tensor_list, tensor, group_name) | ||
|
||
|
||
allgather = AllGatherWrapper() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters