[core][compiled graphs] Support reduce scatter collective in compiled graph #49404

Open · wants to merge 9 commits into base: master

Conversation

@anyadontfly (Contributor) commented Dec 22, 2024

Why are these changes needed?

Currently, allreduce is the only collective operation supported in Ray Compiled Graphs. This PR adds support for reduce-scatter; we plan to add the other collective operations required by FSDP in the future.

Proposed API:

from ray.dag import InputNode, MultiOutputNode
import ray.experimental.collective as collective
# ReduceOp is the reduce-op enum extended in this PR (a subclass of _CollectiveOp).

# `workers` is a list of actor handles whose return_tensor task returns a torch.Tensor.
with InputNode() as inp:
    dag = [worker.return_tensor.bind(inp) for worker in workers]
    dag = collective.reducescatter.bind(dag, ReduceOp.SUM)
    dag = MultiOutputNode(dag)
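
For context, reduce-scatter reduces the workers' tensors element-wise and then scatters the result, so each worker ends up with only its own chunk of the reduced tensor (unlike allreduce, where every worker receives the full result). Below is a minimal, Ray-free sketch of these semantics in plain torch; the function name, shapes, and chunking along dim 0 are illustrative assumptions, not this PR's implementation:

import torch

def reduce_scatter_reference(tensors, op="sum"):
    # One equally shaped tensor per worker; world_size == number of workers.
    world_size = len(tensors)
    stacked = torch.stack(tensors)
    reduced = stacked.sum(dim=0) if op == "sum" else stacked.prod(dim=0)
    # Each worker keeps only its 1/world_size slice of the reduced result.
    return list(torch.chunk(reduced, world_size, dim=0))

# Two workers, each contributing a length-4 tensor:
outs = reduce_scatter_reference([torch.ones(4), torch.arange(4.0)])
# outs[0] == tensor([1., 2.]) goes to worker 0; outs[1] == tensor([3., 4.]) goes to worker 1.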

Related issue number

Meta-issue: #47983

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

@jcotant1 added the core (Issues that should be addressed in Ray Core) and compiled-graphs labels on Dec 23, 2024
@dengwxn (Contributor) left a comment:

Overall looks good! Left some comments to polish.

Review threads (outdated, resolved):
  • python/ray/experimental/collective/reducescatter.py (2 threads)
  • python/ray/dag/tests/experimental/test_torch_tensor_dag.py (2 threads)
  • python/ray/dag/collective_node.py (3 threads)
@anyadontfly changed the title from "[compiled graphs] Support reduce scatter collective in compiled graph" to "[core][compiled graphs] Support reduce scatter collective in compiled graph" on Dec 23, 2024
@@ -9,6 +9,23 @@ class _CollectiveOp(Enum):

@PublicAPI
class ReduceOp(_CollectiveOp):
@anyadontfly (Contributor, Author):
I changed the types to this because Python has strict requirements on subclassing enums.

Contributor:

Could you elaborate more on this? What's the problem and how to solve it?

@anyadontfly (Contributor, Author) commented Dec 25, 2024:

It seems that Python enums can only be subclassed if they do not define any members. So ReduceOp can extend _CollectiveOp, but AllReduceOp can't extend ReduceOp if ReduceOp defines members.
So in order to have AllReduceOp and ReduceScatterOp extend ReduceOp, ReduceOp can't define any members; AllReduceOp and ReduceScatterOp each have to define the ops themselves.
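
A minimal standard-library illustration of this constraint (generic Python behavior, not code from this PR; the class names below are placeholders):

from enum import Enum

class Base(Enum):
    pass                       # no members, so it can still be subclassed

class WithMembers(Base):
    SUM = 0                    # once members are defined, the enum is sealed

try:
    class Child(WithMembers):  # analogous to AllReduceOp extending a ReduceOp that defines members
        PRODUCT = 1
except TypeError as exc:
    print(exc)                 # e.g. "Cannot extend enumerations"; exact wording varies by Python version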


@PublicAPI
class AllReduceReduceOp(ReduceOp):
SUM = 0
Contributor:

Are you saying you have to define {SUM, PRODUCT, ...} manually in both all reduce and reduce scatter ops?

Review threads (outdated, resolved):
  • python/ray/dag/collective_node.py
  • python/ray/dag/tests/experimental/test_torch_tensor_dag.py
@@ -69,20 +71,26 @@ async def wait_collective(self, op_id: int, data: "torch.Tensor", op: ReduceOp):
def _apply_op(self, op: ReduceOp, tensors: List["torch.Tensor"]) -> "torch.Tensor":
"""Apply the specified reduction operation across a list of tensors."""

SUM = 0
Contributor:

Do you need to manually specify the five values? Why not just use ReduceOp.*? It's less error-prone.
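
For illustration, the reviewer's suggestion could look roughly like the sketch below: dispatch on the ReduceOp members instead of redefining their numeric values inside _apply_op. The stand-in ReduceOp enum, the five member names, and writing _apply_op as a free function are assumptions made only to keep the sketch self-contained; they are not this PR's actual code.

from enum import Enum
from typing import List

import torch


class ReduceOp(Enum):
    # Stand-in for the enum touched in this PR; the real definition lives in Ray.
    SUM = 0
    PRODUCT = 1
    MAX = 2
    MIN = 3
    AVG = 4


def _apply_op(op: ReduceOp, tensors: List["torch.Tensor"]) -> "torch.Tensor":
    """Apply the specified reduction operation across a list of equally shaped tensors."""
    stacked = torch.stack(tensors)
    if op is ReduceOp.SUM:
        return stacked.sum(dim=0)
    if op is ReduceOp.PRODUCT:
        return stacked.prod(dim=0)
    if op is ReduceOp.MAX:
        return stacked.max(dim=0).values
    if op is ReduceOp.MIN:
        return stacked.min(dim=0).values
    if op is ReduceOp.AVG:
        return stacked.mean(dim=0)   # assumes floating-point tensors
    raise ValueError(f"Unsupported reduce op: {op}")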

Signed-off-by: Puyuan Yao <[email protected]>
@anyadontfly requested a review from dengwxn on December 26, 2024 01:23
Labels: core (Issues that should be addressed in Ray Core), compiled-graphs
3 participants