
Add more algebra simplify rules #1291

Merged: 3 commits merged into alibaba:main from add_algebra_simplifier on May 20, 2024
Conversation

eedalong (Collaborator)

No description provided.

@eedalong eedalong requested a review from Yancey1989 May 17, 2024 09:01
@eedalong eedalong self-assigned this May 17, 2024
@eedalong eedalong force-pushed the add_algebra_simplifier branch 4 times, most recently from edd31ee to a939e1c on May 19, 2024 at 13:58
@eedalong eedalong changed the title from "Optimize algebra simplify pass" to "Add more algebra simplify rules" on May 19, 2024
}
};

struct OptimizationBarrierSimplifierPattern
Yancey1989 (Collaborator)

what's the purpose of this pattern rewriter?

@eedalong eedalong (Collaborator, Author) May 20, 2024

The graph lowered from TorchAcc contains the following pattern:

B = All-Gather(A)
C = Reshape(B)
D = OptimizationBarrier(C)
E = Slice(D)

A standalone Reshape op causes an extra global-memory read and write, so this rule rewrites the graph into:

B = All-Gather(A)
C = OptimizationBarrier(B)
D = Reshape(C)
E = Slice(D)
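
For illustration, here is a minimal MLIR rewrite sketch of this rule. It is not the PR's actual implementation: the pattern name is made up, the include paths vary by checkout, and it assumes a single-operand mhlo.optimization_barrier fed directly by an mhlo.reshape.

// Illustrative sketch only, not the code from this PR.
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"  // path varies across mlir-hlo versions
#include "mlir/IR/PatternMatch.h"

struct SwapReshapeWithBarrierSketch
    : public mlir::OpRewritePattern<mlir::mhlo::OptimizationBarrierOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult matchAndRewrite(
      mlir::mhlo::OptimizationBarrierOp barrier,
      mlir::PatternRewriter& rewriter) const override {
    // Keep the illustration simple: handle only the single-operand barrier.
    if (barrier->getNumOperands() != 1) return mlir::failure();
    auto reshape =
        barrier->getOperand(0).getDefiningOp<mlir::mhlo::ReshapeOp>();
    if (!reshape) return mlir::failure();

    // The new barrier consumes the reshape's input (the all-gather result),
    // so its result type matches that input's type.
    mlir::Value src = reshape.getOperand();
    auto newBarrier = rewriter.create<mlir::mhlo::OptimizationBarrierOp>(
        barrier.getLoc(), mlir::TypeRange{src.getType()},
        mlir::ValueRange{src});

    // Re-create the reshape after the barrier and use it in place of the
    // original barrier result, so the downstream slice is unchanged.
    auto newReshape = rewriter.create<mlir::mhlo::ReshapeOp>(
        reshape.getLoc(), reshape.getType(), newBarrier.getResult(0));
    rewriter.replaceOp(barrier, newReshape.getResult());
    return mlir::success();
  }
};

The sketch only swaps the order of the two ops; the barrier still depends on the all-gather result, and the consumer Slice still sees a value of the reshaped type.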

@Yancey1989

@Yancey1989 Yancey1989 (Collaborator) May 20, 2024

Thank you for the clarification. After discussing with @eedalong in the office, it appears that the reshape->optimization_barrier->slice pattern is not one that commonly needs simplification. In the XLA FSDP implementation, this pattern serves as a workaround for TPUs; it is not required for NVIDIA GPUs. Perhaps we could set _shard_size_multiple=1, which would eliminate the need for this pattern. Ref: https://github.com/AlibabaPAI/xla/blob/67edb354372a8e5cbce41f20ecfca68b328635c6/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py#L1435
cc @anw90

@Yancey1989 Yancey1989 (Collaborator) left a comment

LGTM

@eedalong eedalong merged commit 59c9279 into alibaba:main May 20, 2024
11 checks passed