-
Notifications
You must be signed in to change notification settings - Fork 726
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enable ZeRO with auto parallel #9288
Merged
Merged
Changes from 6 commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
54771bc
Enable ZeRO with auto parallel in the first setting
Yipeng1994 d7ab8c9
Remove compute_cost parameter
Yipeng1994 021f30e
Move the addition of wait time into sbp_node
Yipeng1994 ad38ff1
Remove transfer cost since it is merged into the GetTransferCost()
Yipeng1994 02486af
Rename mainstem to trunk
Yipeng1994 8e4c063
Merge branch 'master' into feat-auto_parallel-ZeRO
Yipeng1994 17956e0
Update warning
Yipeng1994 c29d464
Merge branch 'feat-auto_parallel-ZeRO' of github.com:Oneflow-Inc/onef…
Yipeng1994 2a3f564
Merge branch 'master' into feat-auto_parallel-ZeRO
Yipeng1994 a092846
Merge branch 'master' into feat-auto_parallel-ZeRO
mergify[bot] File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -537,7 +537,7 @@ void SbpNode::RaiseConsumerNum(const HashMap<std::string, SbpNode*>& op_name2sbp | |
void SbpNode::SpreadAvailWaitTime(const std::vector<double>& trunk_cost, | ||
const std::vector<double>& acc_trunk_cost, | ||
const HashMap<std::string, SbpNode*>& op_name2sbp_node, | ||
double wait_time, double transfer_cost) { | ||
double wait_time) { | ||
// skip the proxy nodes and the sources | ||
if (min_layer_ <= 0) { return; } | ||
// Have not finished spreading for consumers or downstream nodes or already visited. | ||
|
@@ -577,15 +577,13 @@ void SbpNode::SpreadAvailWaitTime(const std::vector<double>& trunk_cost, | |
// (1) P->S0->S0->S0->B | ||
// (2) p->B->B->B->B | ||
// We would use (2) when the tensor is relatively tiny. | ||
this_edge->wait_time_ += transfer_cost; | ||
// Do not inherit trunk cost for nodes on the trunk | ||
if (!producer->on_trunk_) { | ||
// Inherit the minimal of the trunk cost from consumers | ||
producer->DropAvailWaitTime(curr_trunk_cost); | ||
} | ||
producer->counter_--; | ||
producer->SpreadAvailWaitTime(trunk_cost, acc_trunk_cost, op_name2sbp_node, wait_time, | ||
transfer_cost); | ||
producer->SpreadAvailWaitTime(trunk_cost, acc_trunk_cost, op_name2sbp_node, wait_time); | ||
} | ||
// Put the rest the trunk cost in the upstream nodes. | ||
for (const auto& ctrl_in_op_name : op_node_->op().op_conf().ctrl_in_op_name()) { | ||
|
@@ -601,8 +599,7 @@ void SbpNode::SpreadAvailWaitTime(const std::vector<double>& trunk_cost, | |
producer->DropAvailWaitTime(curr_trunk_cost); | ||
} | ||
producer->counter_--; | ||
producer->SpreadAvailWaitTime(trunk_cost, acc_trunk_cost, op_name2sbp_node, wait_time, | ||
transfer_cost); | ||
producer->SpreadAvailWaitTime(trunk_cost, acc_trunk_cost, op_name2sbp_node, wait_time); | ||
} | ||
} | ||
// Set counter_ to be -1, do not visit it again. | ||
|
@@ -619,18 +616,24 @@ void SbpNode::DropAvailWaitTime(double curr_trunk_cost) { | |
|
||
// Assemble copy cost for all the incoming edges | ||
|
||
void SbpNode::InitializeCopyCost(bool compute_cost, bool use_sbp_collector) { | ||
void SbpNode::InitializeCopyCost(bool use_sbp_collector) { | ||
for (SbpEdge* this_edge : edges_in_) { | ||
const auto* sbp_node_producer = this_edge->start_node_; | ||
OpNode* producer = sbp_node_producer->op_node_; | ||
|
||
// skip it if proxy | ||
if (use_sbp_collector && !producer) { continue; } | ||
|
||
// look through input blobs | ||
for (const std::string& ibn : op_node_->op().input_bns()) { | ||
if (producer->op().op_name() == op_node_->SrcNode4Ibn(ibn).op().op_name()) { | ||
this_edge->InitializeCopyCost(ibn, compute_cost, use_sbp_collector); | ||
this_edge->InitializeCopyCost(ibn, use_sbp_collector); | ||
} | ||
} | ||
// Add Wait time | ||
for (auto& cost_row : this_edge->cost_) { | ||
for (auto& cost_value : cost_row) { | ||
// If transferring between devices, we need to add wait time. | ||
if (cost_value > 0.0) { cost_value += this_edge->wait_time_; } | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 看起来最主要的增加的逻辑是 Add Wait time 这里? 最主要的删除逻辑是删掉了很多 compute_cost 的开关 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 只不过由赋值改为了增加。当然这些实现在最终结果上都是一样的 |
||
} | ||
} | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是因为这里不做 edge 的遍历,所以可以变快?
这样改完后,逻辑还等价不?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里我测了一下,时间基本是一样的。逻辑都是等价的。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
主要变快的地方在于 https://github.com/Oneflow-Inc/oneflow/pull/9288/files#diff-40b436fe2eff96c43760f1c9abc5c2a1518c696046d8df8279b5bb6d70b6beaaL267-L281
InitializeCopyCost() 的调用从2次变到1次