Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

auto parallel profiling op #1

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 52 additions & 1 deletion oneflow/core/auto_parallel/sbp_constructor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@ limitations under the License.
*/

#include "oneflow/core/auto_parallel/sbp_constructor.h"
#include <iostream>
#include "oneflow/core/auto_parallel/auto_memory.h"
#include "oneflow/core/auto_parallel/sbp_node.h"
#include "oneflow/core/auto_parallel/sbp_util.h"
#include "oneflow/core/common/just.h"
#include "oneflow/core/common/singleton.h"
#include "oneflow/core/device/cuda_util.h"
#include "oneflow/core/framework/sbp_infer_util.h"
Expand Down Expand Up @@ -520,10 +522,55 @@ void SbpConstructor::PrintSBPGraphDebugInfo() {
auto_parallel::DecideOrder(node_list, order, [&](OpNode* a, OpNode* b) {
return a->op().op_name().compare(b->op().op_name()) > 0;
});

// get cost by running time start

std::cout << "--------------------------------------------------------------" << std::endl;
std::cout << "------------------get cost by running time start--------------" << std::endl;
double total_comp_cost_0 = 0;
for (int32_t i = 0; i < node_list.size(); i++) {
OpNode* op_node = node_list[order[i]];
std::cout << op_node->op().op_name() << " (^_^): " << op_node->op().op_conf().op_type_case()
<< std::endl;
auto LogicalBlobDesc4Bn = [&](const std::string& bn) -> const BlobDesc& {
const LogicalBlobId& lbi = op_node->op().BnInOp2Lbi(bn);
return op_node->LogicalBlobDesc4Lbi(lbi);
};
const ParallelDesc& parallel_desc = op_node->parallel_desc();

SbpNode* sbp_node = op_name2sbp_node_[op_node->op().op_name()];

// auto GetCompCost = [&](int32_t sbp_id) -> Maybe<void> {
// double comp_cost = JUST(op_node->op().GetComputeComplexity(
// &sbp_node->sbp_sig_list_[sbp_id], LogicalBlobDesc4Bn, parallel_desc));
// return Maybe<void>::Ok();
// };

std::cout << "sbp_node->sbp_sig_list_.size(): " << sbp_node->sbp_sig_list_.size() << std::endl;
double comp_cost = CHECK_JUST(op_node->op().GetComputeComplexity(
&sbp_node->sbp_sig_list_[sbp_node->final_sbp_sig_id_], LogicalBlobDesc4Bn, parallel_desc));

// if (comp_cost > GetValidMaxCopyCost()) {
// sbp_node->cost_[sbp_id] = comp_cost;
// } else {
// sbp_node->cost_[sbp_id] =
// cost_ratio_ * comp_cost
// * JUST(op_node->op().GetInputOutputFastestTimeShape())->elem_cnt();
// }
std::cout << "comp_cost: " << comp_cost << std::endl;
total_comp_cost_0 += comp_cost;
// for (int32_t sbp_id = 0; sbp_id < sbp_node->sbp_sig_list_.size(); sbp_id++) {
// }
}
std::cout << "Total cost: " << total_comp_cost_0 << std::endl;
std::cout << "--------------------------------------------------------------" << std::endl;
// get cost by running time end

std::vector<int32_t> str_order;

// test debug
std::cout << "Finish deciding order" << std::endl;
double total_cost = 0;

for (int32_t i = 0; i < node_list.size(); i++) {
OpNode* op_node = node_list[order[i]];
Expand All @@ -533,7 +580,9 @@ void SbpConstructor::PrintSBPGraphDebugInfo() {
// Print debug information for sbp graph
CHECK(it != op_name2sbp_node_.end());
const SbpNode* sbp_node = it->second;
std::cout << "Computation Cost: " << sbp_node->weighted_cost_[sbp_node->final_sbp_sig_id_];
double node_cost = sbp_node->weighted_cost_[sbp_node->final_sbp_sig_id_];
total_cost += node_cost;
std::cout << "Computation Cost: " << node_cost;
std::cout << ", Min Layer: " << sbp_node->min_layer_ << ", Max Layer: " << sbp_node->max_layer_
<< ", Tributary Layer: " << sbp_node->tributary_layer_
<< ", in trunk: " << sbp_node->on_trunk_
Expand Down Expand Up @@ -570,6 +619,8 @@ void SbpConstructor::PrintSBPGraphDebugInfo() {
}
std::cout << std::endl;
}

std::cout << "Total cost: " << total_cost << std::endl;
}

} // namespace auto_parallel
Expand Down
47 changes: 46 additions & 1 deletion oneflow/core/graph/op_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.
#include "oneflow/core/job/lazy_mode.h"
#include "oneflow/core/auto_parallel/algorithm_util.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/job/sbp_parallel.pb.h"

namespace oneflow {

Expand Down Expand Up @@ -583,12 +584,56 @@ void OpGraph::PrintSBPGraphDebugInfo() const {
});
std::vector<int32_t> str_order;

std::cout << "--------------------------------------------------------------" << std::endl;
std::cout << "------------------get cost by running time start--------------" << std::endl;
double total_comp_cost_0 = 0;
for (int32_t i = 0; i < NodeList.size(); i++) {
OpNode* op_node = NodeList[order[i]];
std::cout << op_node->op().op_name() << " (^_^): " << op_node->op().op_conf().op_type_case()
<< std::endl;
auto LogicalBlobDesc4Bn = [&](const std::string& bn) -> const BlobDesc& {
const LogicalBlobId& lbi = op_node->op().BnInOp2Lbi(bn);
return op_node->LogicalBlobDesc4Lbi(lbi);
};
const ParallelDesc& parallel_desc = op_node->parallel_desc();

// SbpNode* sbp_node = op_name2sbp_node_[op_node->op().op_name()];

// auto GetCompCost = [&](int32_t sbp_id) -> Maybe<void> {
// double comp_cost = JUST(op_node->op().GetComputeComplexity(
// &sbp_node->sbp_sig_list_[sbp_id], LogicalBlobDesc4Bn, parallel_desc));
// return Maybe<void>::Ok();
// };

auto sig = op_node->nd_sbp_signature();
// std::cout << "sbp_node->sbp_sig_list_.size(): " << sbp_node->sbp_sig_list_.size() <<
// std::endl;
double comp_cost =
CHECK_JUST(op_node->op().GetComputeComplexity(&sig, LogicalBlobDesc4Bn, parallel_desc));

// if (comp_cost > GetValidMaxCopyCost()) {
// sbp_node->cost_[sbp_id] = comp_cost;
// } else {
// sbp_node->cost_[sbp_id] =
// cost_ratio_ * comp_cost
// * JUST(op_node->op().GetInputOutputFastestTimeShape())->elem_cnt();
// }
std::cout << "comp_cost: " << comp_cost << std::endl;
total_comp_cost_0 += comp_cost;
// for (int32_t sbp_id = 0; sbp_id < sbp_node->sbp_sig_list_.size(); sbp_id++) {
// }
}
std::cout << "Total cost: " << total_comp_cost_0 << std::endl;
std::cout << "--------------------------------------------------------------" << std::endl;

// test debug
std::cout << "Finish deciding order" << std::endl;

for (int32_t i = 0; i < NodeList.size(); i++) {
OpNode* op_node = NodeList[order[i]];
std::cout << op_node->op().op_name() << " (^_^):" << std::endl;
// std::cout << op_node->op().op_name() << " (^_^):" << std::endl;
std::cout << op_node->op().op_name() << " (^_^): " << op_node->op().op_conf().op_type_case()
<< std::endl;
// Sort before printing
const auto& op_input_bns = op_node->op().input_bns();
auto comp = [](const std::string& a, const std::string& b) { return a.compare(b) > 0; };
Expand Down
6 changes: 4 additions & 2 deletions oneflow/core/job_rewriter/auto_parallel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,14 @@ class AutoParallelPass final : public JobPass {
Maybe<void> Apply(const OpGraph& op_graph, Job* job) const;

Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override {
const OpGraph op_graph(*job);
if (GlobalProcessCtx::Rank() == 0) { op_graph.PrintSBPGraphDebugInfo(); }
if (!job->job_conf().enable_auto_parallel()) { return Maybe<void>::Ok(); }
VLOG(3) << "=== Enable AutoParallel ===";
if (job->job_conf().enable_auto_parallel_ignore_user_sbp_config()) {
JUST(RemoveParallelCastOps(job));
}
const OpGraph op_graph(*job);
// const OpGraph op_graph(*job);
return Apply(op_graph, job);
}

Expand All @@ -62,7 +64,7 @@ Maybe<void> AutoParallelPass::Apply(const OpGraph& op_graph, Job* job) const {
<< std::chrono::duration_cast<std::chrono::milliseconds>(time_end - time_begin).count()
<< " ms\n";
if (GlobalProcessCtx::Rank() == 0) {
// sbp_constructor.PrintSBPGraphDebugInfo();
sbp_constructor.PrintSBPGraphDebugInfo();
JUST(sbp_constructor.CheckSbpAgreement(*job));
}
return Maybe<void>::Ok();
Expand Down
Loading