[core][compiled graphs] Support reduce scatter collective in compiled graph #49404
base: master
Conversation
Signed-off-by: Puyuan Yao <[email protected]> rebase to updated main branch
Signed-off-by: Puyuan Yao <[email protected]>
Overall looks good! Left some comments to polish.
@@ -9,6 +9,23 @@ class _CollectiveOp(Enum):

@PublicAPI
class ReduceOp(_CollectiveOp):
I changed the types to this because Python has strict requirements on subclassing enums.
Could you elaborate more on this? What's the problem and how to solve it?
It seems like an enum can only extend an enum that does not define any members. So ReduceOp can extend _CollectiveOp, but AllReduceOp can't extend ReduceOp if ReduceOp defines members. So in order to have AllReduceOp and ReduceScatterOp extend ReduceOp, ReduceOp can't define any members, and AllReduceOp and ReduceScatterOp each have to define the ops themselves.
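For readers unfamiliar with this restriction, here is a minimal, self-contained sketch of the constraint being described. The class names follow the discussion above; the member names and values are illustrative, not taken from the PR diff:

```python
from enum import Enum


class _CollectiveOp(Enum):
    """Base marker enum. It defines no members, so it may be subclassed."""
    pass


class ReduceOp(_CollectiveOp):
    """Also member-free, so the concrete op enums below can extend it."""
    pass


class AllReduceOp(ReduceOp):
    # Members cannot be inherited from a parent enum, so each concrete
    # enum has to repeat them. Values here are illustrative.
    SUM = 0
    PRODUCT = 1


class ReduceScatterOp(ReduceOp):
    SUM = 0
    PRODUCT = 1


# Members of both concrete enums are still instances of the shared base type.
assert isinstance(AllReduceOp.SUM, ReduceOp)

# Subclassing an enum that already defines members is rejected by Python:
try:
    class ExtendedOp(AllReduceOp):
        MIN = 2
except TypeError as err:
    print(err)  # e.g. "... cannot extend <enum 'AllReduceOp'>"
```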
@PublicAPI
class AllReduceReduceOp(ReduceOp):
    SUM = 0
Are you saying you have to define {SUM, PRODUCT, ...} manually in both all reduce and reduce scatter ops?
@@ -69,20 +71,26 @@ async def wait_collective(self, op_id: int, data: "torch.Tensor", op: ReduceOp):
    def _apply_op(self, op: ReduceOp, tensors: List["torch.Tensor"]) -> "torch.Tensor":
        """Apply the specified reduction operation across a list of tensors."""

    SUM = 0
Do you need to manually specify the five values? Why not just use ReduceOp.*? It's less error-prone.
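One possible way to address the duplication the reviewer is pointing at, sketched here as an illustration rather than as what the PR does: Python's functional Enum API can generate each concrete op enum from a single table while still subclassing the member-free ReduceOp. The member names and values below are assumptions (the usual NCCL-style reductions), not the ones in the PR:

```python
from enum import Enum


class _CollectiveOp(Enum):
    pass


class ReduceOp(_CollectiveOp):
    # Deliberately member-free so it can be subclassed (see discussion above).
    pass


# Single source of truth for the reduction kinds; assumed NCCL-style values.
_REDUCE_MEMBERS = {"SUM": 0, "PRODUCT": 1, "MIN": 2, "MAX": 3, "AVG": 4}

# Calling a member-free enum class with a name and a mapping creates a new
# enum derived from it, so the values are written down exactly once.
AllReduceOp = ReduceOp("AllReduceOp", _REDUCE_MEMBERS, module=__name__)
ReduceScatterOp = ReduceOp("ReduceScatterOp", _REDUCE_MEMBERS, module=__name__)

assert isinstance(AllReduceOp.SUM, ReduceOp)
assert AllReduceOp.SUM.value == ReduceScatterOp.SUM.value == 0
```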
Signed-off-by: Puyuan Yao <[email protected]>
Why are these changes needed?
Currently, allreduce is the only collective operation available in Ray Compiled Graphs. We plan to add the other collective operations required by FSDP in the future.
Proposed API:
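The proposed API snippet did not survive in this rendering of the PR description. As a rough, unofficial sketch only: if the new op mirrors the existing ray.experimental.collective.allreduce binding for compiled graphs, usage might look roughly like the following. The reducescatter symbol, its bind() signature, and the shard semantics are assumptions, not confirmed by this page:

```python
# Hypothetical usage; reducescatter and its bind() signature are assumed to
# mirror the existing allreduce collective binding for Ray Compiled Graphs.
import ray
import torch
from ray.dag import InputNode, MultiOutputNode
from ray.experimental.collective import reducescatter  # assumed new symbol


@ray.remote(num_gpus=1)
class Worker:
    def compute(self, x: int) -> "torch.Tensor":
        return torch.ones(4, device="cuda") * x

    def consume(self, shard: "torch.Tensor") -> float:
        # Each worker receives only its shard of the reduced tensor.
        return shard.sum().item()


workers = [Worker.remote() for _ in range(2)]

with InputNode() as inp:
    computes = [w.compute.bind(inp) for w in workers]
    # Assumed API: reduce the tensors across workers and scatter one shard
    # back to each of them, analogous to allreduce.bind(...).
    shards = reducescatter.bind(computes)
    dag = MultiOutputNode([w.consume.bind(s) for w, s in zip(workers, shards)])

compiled_dag = dag.experimental_compile()
result = ray.get(compiled_dag.execute(1))
```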
Related issue number
Meta-issue: #47983
Checks

- I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
- I've run scripts/format.sh to lint the changes in this PR.
- I've added any new APIs to the API Reference. For example, if I added a method in Tune, I've added it in doc/source/tune/api/ under the corresponding .rst file.