From 972dfeea885aba50d7f0bbe8a64838ffdb716edb Mon Sep 17 00:00:00 2001 From: chengtbf <472491134@qq.com> Date: Thu, 11 Aug 2022 10:06:37 +0000 Subject: [PATCH 01/66] default nccl use compute stream in grad acc --- python/oneflow/nn/graph/graph_config.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/python/oneflow/nn/graph/graph_config.py b/python/oneflow/nn/graph/graph_config.py index ff5b044b779..5e291f12b1f 100644 --- a/python/oneflow/nn/graph/graph_config.py +++ b/python/oneflow/nn/graph/graph_config.py @@ -116,7 +116,7 @@ def build(self, x): Args: mode (bool): if set to true, optimizer states of Data Parallel will be sharded across devices. - stage (int): optimization stage, range from 1 to 3. + stage (int): optimization stage, range from 1 to 3. shard_min_size (int): min size (element count) of a shard of an optimizer state. shard_restore_level (int): level to restore sharded parameter to whole parameter for consumer operators, level 0 is no restore, level 1 is soft restore, level 2 is hard restore. Note that this paremeter is at pre-alpha stage. """ @@ -178,7 +178,7 @@ def __init__(self): self.bn1 = flow.nn.BatchNorm1d(100) self.config.allow_fuse_add_to_output(True) def build(self, x): - bn = self.bn1(x) + bn = self.bn1(x) out = bn + x return out @@ -191,7 +191,7 @@ def build(self, x): def allow_fuse_cast_scale(self, mode: bool = True): r"""If set to true, try to fuse cast and scalar_mul_by_tensor to improve performance. - + For example: .. code-block:: python @@ -240,6 +240,11 @@ def build(self, x): value (int): num of steps. """ self.proto.num_gradient_accumulation_steps = value + if value > 1: + # NOTE(chengcheng): when use gradient accumulation, optimizer nccl allreduce can NOT + # overlap with backward, so nccl use compute stream is optimization without negative + # effects. + nccl_config.enable_use_compute_stream(True) def set_outputs_buffer_size(self, value: int = 2): r"""Set the outputs buffer size of ``nn.Graph``. @@ -278,7 +283,7 @@ def build(self, x): return self.m(x) graph = Graph() - + Args: mode (bool, optional): The default vaule is True. """ @@ -289,7 +294,7 @@ def enable_straighten_algorithm(self, mode: bool = True): If using nccl compute stream, turning it on might not speed up the training. If not using nccl compute stream, turning it on might slow down data parallelism by 0.6% and slow down model parallelism by 6%. - Considering memory, enabling the straighten algorithm is forbidden with one machine/device only, and not recommended under pipeline parallelism. + Considering memory, enabling the straighten algorithm is forbidden with one machine/device only, and not recommended under pipeline parallelism. 
""" self.proto.enable_straighten_algorithm_in_task_graph = mode From 5c19afa3521b2f87398a05230a201e0906d1999c Mon Sep 17 00:00:00 2001 From: chengtbf <472491134@qq.com> Date: Tue, 16 Aug 2022 04:44:20 +0000 Subject: [PATCH 02/66] rm sharable mem block graph --- .../core/graph/sharable_mem_block_graph.cpp | 97 ------------------- oneflow/core/graph/sharable_mem_block_graph.h | 64 ------------ oneflow/core/graph/task_graph.cpp | 24 ++++- 3 files changed, 19 insertions(+), 166 deletions(-) delete mode 100644 oneflow/core/graph/sharable_mem_block_graph.cpp delete mode 100644 oneflow/core/graph/sharable_mem_block_graph.h diff --git a/oneflow/core/graph/sharable_mem_block_graph.cpp b/oneflow/core/graph/sharable_mem_block_graph.cpp deleted file mode 100644 index 26b4c9303eb..00000000000 --- a/oneflow/core/graph/sharable_mem_block_graph.cpp +++ /dev/null @@ -1,97 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -#include "oneflow/core/graph/sharable_mem_block_graph.h" -#include "oneflow/core/register/register_desc.h" -#include "oneflow/core/register/runtime_register_desc.h" -#include "oneflow/core/graph/inplace_regst_graph.h" - -namespace oneflow { - -namespace { - -bool IsConsumersAndProducerInSameChain(const RegstDescProto& regst_desc, - const PlanTaskGraph& plan_task_graph) { - auto ChainId4TaskId = [&](int64_t task_id) { - return plan_task_graph.TaskProto4TaskId(task_id)->task_set_info().chain_id(); - }; - int64_t producer_chain_id = ChainId4TaskId(regst_desc.producer_task_id()); - for (int64_t consumer_task_id : regst_desc.consumer_task_id()) { - if (ChainId4TaskId(consumer_task_id) != producer_chain_id) { return false; } - } - return true; -} -void ForEachInplacedRegstDescs( - const HashSet regst_desc, - const std::function&)>& Handler) { - InplaceRegstGraph inplace_gph(regst_desc); - inplace_gph.ForEachConnectedComponent([&](const HashSet& nodes) { - if (nodes.size() == 1) { return; } - HashSet regst_descs; - for (const auto* node : nodes) { CHECK(regst_descs.emplace(node->regst_desc()).second); } - Handler(regst_descs); - }); -} - -} // namespace - -SharableMemBlockNode::SharableMemBlockNode(int64_t chain_id, - const HashSet& regst_descs) - : chain_id_(chain_id), regst_descs_(regst_descs.begin(), regst_descs.end()) {} - -SharableMemBlockGraph::SharableMemBlockGraph( - const PlanTaskGraph& plan_task_gph, - const std::function& IsSharable) { - HashMap> chain_id2regst_descs; - for (const TaskProto& task : plan_task_gph.plan().task()) { - for (const auto& pair : task.produced_regst_desc()) { - if (IsConsumersAndProducerInSameChain(pair.second, plan_task_gph) - && IsSharable(pair.second)) { - CHECK(chain_id2regst_descs[task.task_set_info().chain_id()].emplace(&pair.second).second); - } - } - } - for (const auto& pair : chain_id2regst_descs) { - HashMap regst_desc2node; - for (const auto* regst_desc : pair.second) { - auto* node = new SharableMemBlockNode(pair.first, {regst_desc}); - AddAllocatedNode(node); - 
CHECK(regst_desc2node.emplace(regst_desc, node).second); - } - ForEachInplacedRegstDescs(pair.second, [&](const HashSet& regst_descs) { - auto* parent = new SharableMemBlockNode(pair.first, regst_descs); - AddAllocatedNode(parent); - for (const RegstDescProto* regst_desc : regst_descs) { - auto* edge = new SharableMemBlockEdge(); - AddAllocatedEdge(edge); - Connect(parent, edge, regst_desc2node.at(regst_desc)); - } - }); - } -} - -void SharableMemBlockGraph::ForEachSourceNodeGroup( - const std::function& GroupBy, - const std::function&)>& Handler) const { - HashMap> group_key2source_nodes; - for (const SharableMemBlockNode* source : source_nodes()) { - group_key2source_nodes[GroupBy(source)].emplace_back(source); - } - for (const auto& pair : group_key2source_nodes) { - if (pair.second.size() > 1) { Handler(pair.second); } - } -} - -} // namespace oneflow diff --git a/oneflow/core/graph/sharable_mem_block_graph.h b/oneflow/core/graph/sharable_mem_block_graph.h deleted file mode 100644 index 7f4b1c6186a..00000000000 --- a/oneflow/core/graph/sharable_mem_block_graph.h +++ /dev/null @@ -1,64 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -#ifndef ONEFLOW_CORE_GRAPH_SHARABLE_MEM_BLOCK_GRAPH_H_ -#define ONEFLOW_CORE_GRAPH_SHARABLE_MEM_BLOCK_GRAPH_H_ - -#include "oneflow/core/graph/graph.h" -#include "oneflow/core/register/register_desc.pb.h" -#include "oneflow/core/graph/plan_task_graph.h" - -namespace oneflow { - -class SharableMemBlockEdge; - -class SharableMemBlockNode final : public Node { - public: - OF_DISALLOW_COPY_AND_MOVE(SharableMemBlockNode); - SharableMemBlockNode(int64_t chain_id, const HashSet& regst_descs); - - ~SharableMemBlockNode() = default; - - int64_t chain_id() const { return chain_id_; } - const std::vector& regst_descs() const { return regst_descs_; } - - private: - const int64_t chain_id_; - const std::vector regst_descs_; -}; - -class SharableMemBlockEdge final : public Edge { - public: - OF_DISALLOW_COPY_AND_MOVE(SharableMemBlockEdge); - SharableMemBlockEdge() = default; - ~SharableMemBlockEdge() = default; -}; - -class SharableMemBlockGraph final - : public Graph { - public: - OF_DISALLOW_COPY_AND_MOVE(SharableMemBlockGraph); - SharableMemBlockGraph(const PlanTaskGraph& plan_task_gph, - const std::function& IsSharable); - ~SharableMemBlockGraph() = default; - - void ForEachSourceNodeGroup( - const std::function& GroupBy, - const std::function&)>& Handler) const; -}; - -} // namespace oneflow - -#endif // ONEFLOW_CORE_GRAPH_SHARABLE_MEM_BLOCK_GRAPH_H_ diff --git a/oneflow/core/graph/task_graph.cpp b/oneflow/core/graph/task_graph.cpp index e6d8853b191..8be777f7530 100644 --- a/oneflow/core/graph/task_graph.cpp +++ b/oneflow/core/graph/task_graph.cpp @@ -102,11 +102,6 @@ bool IsSpecialOpNotConsiderMergeInChain(const Operator* op) { return true; } } - // NOTE(chengcheng): ONLY nccl_use_compute_stream = false will exclude optimizer pass ops - if (!Singleton::Get()->nccl_use_compute_stream() - && IsOptimizerPassOp(op)) { - 
    return true;
-  }
   return false;
 }
 
@@ -569,6 +564,25 @@ void TaskGraph::SetOrderInGraphForEachNode() {
 }
 
 void TaskGraph::MergeChain() {
+  const OpGraph& op_graph = *Singleton<OpGraph>::Get();
+
+  std::vector<const OpNode*> ordered_op_nodes;
+  HashMap<const OpNode*, int64_t> op_node2global_order;
+  op_graph.TopoForEachNodeWithCtrlEdge([&](const OpNode* node) {
+    ordered_op_nodes.emplace_back(node);
+    op_node2global_order.emplace(node, ordered_op_nodes.size() - 1);
+  });
+
+  std::vector<HashSet<const OpNode*>> subgraph_list;
+  FindAllConnectedSubgraphForGpuExecOrder(&subgraph_list, op_graph, ordered_op_nodes);
+  if (subgraph_list.size() == 0) { return Maybe<void>::Ok(); }
+
+  auto CmpOpNodeOrder = [&](const OpNode* lhs, const OpNode* rhs) {
+    return op_node2global_order.at(lhs) < op_node2global_order.at(rhs);
+  };
+
+  auto IsReachable = op_graph.MakePredicatorIsOpNameDataOrCtrlReachable();
+
   int64_t chain_id = 0;
   for (auto* this_node : ordered_task_nodes_) {
     // skip if this node has been set in a chain.

From e08a79ad2dff61036594259e9c53bbe2ff2f3d1a Mon Sep 17 00:00:00 2001
From: chengtbf <472491134@qq.com>
Date: Wed, 17 Aug 2022 11:21:33 +0000
Subject: [PATCH 03/66] half implement of LogicalChains

---
 oneflow/core/graph/task_graph.cpp             |  50 +-
 oneflow/core/graph/task_node.cpp              |   1 -
 .../insert_nccl_logical_op_pass.cpp           |   1 -
 .../core/job_rewriter/logical_chain_pass.cpp  | 726 ++++++++++++++++++
 oneflow/core/operator/op_conf.proto           |   2 +
 5 files changed, 777 insertions(+), 3 deletions(-)
 create mode 100644 oneflow/core/job_rewriter/logical_chain_pass.cpp

diff --git a/oneflow/core/graph/task_graph.cpp b/oneflow/core/graph/task_graph.cpp
index 8be777f7530..b9bb7f36aa2 100644
--- a/oneflow/core/graph/task_graph.cpp
+++ b/oneflow/core/graph/task_graph.cpp
@@ -563,9 +563,57 @@ void TaskGraph::SetOrderInGraphForEachNode() {
   TopoForEachNode(SetOrderInGraph);
 }
 
+void GetLogicalChains(std::vector<HashSet<const OpNode*>>* ret,
+                      const OpGraph& op_graph,
+                      const std::vector<const OpNode*>& order) {
+  HashSet<const OpNode*> visited;
+
+  for (const OpNode* seed_node : order) {
+    if (visited.find(seed_node) != visited.end()) { continue; }
+    CHECK(visited.insert(seed_node).second);
+    const ParallelDesc& seed_parallel_desc = seed_node->parallel_desc();
+    // NOTE(chengcheng): ONLY consider GPU op and parallel num > 1.
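+    // i.e. chains are only seeded from multi-device CUDA placements; CPU or single-device seeds are skipped below.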
+ if (seed_parallel_desc.device_type() != DeviceType::kCUDA) { continue; } + if (seed_parallel_desc.parallel_num() <= 1) { continue; } + if (IsBreakpointOpNode(seed_node)) { continue; } + + HashSet this_subgraph; + std::queue queued_nodes; + + std::shared_ptr seed_time_shape = GetOpNodeTimeShape(seed_node); + queued_nodes.push(seed_node); + while (!queued_nodes.empty()) { + const OpNode* cur_node = queued_nodes.front(); + queued_nodes.pop(); + + CHECK(cur_node->parallel_desc().EqualsIgnoringHierarchy(seed_parallel_desc)); + CHECK(this_subgraph.insert(cur_node).second); + + cur_node->ForEachNodeOnInOutEdge([&](const OpNode* next_node) { + if (visited.find(next_node) == visited.end() && (!IsBreakpointOpNode(next_node)) + && next_node->parallel_desc().EqualsIgnoringHierarchy(seed_parallel_desc) + && SharedPtrShapeEqual(GetOpNodeTimeShape(next_node), seed_time_shape)) { + CHECK(visited.insert(next_node).second); + queued_nodes.push(next_node); + } + }); + } + + if (this_subgraph.size() > 1) { + ret->emplace_back(HashSet()); + ret->back().swap(this_subgraph); + } + } + + std::sort(ret->begin(), ret->end(), + [](const HashSet& lhs, const HashSet& rhs) { + return lhs.size() > rhs.size(); + }); +} + void TaskGraph::MergeChain() { const OpGraph& op_graph = *Singleton::Get(); - + std::vector ordered_op_nodes; HashMap op_node2global_order; op_graph.TopoForEachNodeWithCtrlEdge([&](const OpNode* node) { diff --git a/oneflow/core/graph/task_node.cpp b/oneflow/core/graph/task_node.cpp index fd129f73caa..65b741d47d7 100644 --- a/oneflow/core/graph/task_node.cpp +++ b/oneflow/core/graph/task_node.cpp @@ -204,7 +204,6 @@ bool TaskNode::IsMeaningLess() { return produced_regsts_.empty() && consumed_reg void TaskNode::ToProto(TaskProto* task_proto) const { // Step1: process some scalar items. - CHECK_NE(chain_id_, -1); task_proto->set_task_type(GetTaskType()); task_proto->set_machine_id(machine_id_); task_proto->set_thrd_id(thrd_id_); diff --git a/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp b/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp index 20885c6633e..1679a8d3303 100644 --- a/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp +++ b/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp @@ -33,7 +33,6 @@ namespace oneflow { namespace { -// Do InsertNcclLogicalOpPass will use backward recomputation for sublinear memory cost. class InsertNcclLogicalOpPass final : public JobPass { public: OF_DISALLOW_COPY_AND_MOVE(InsertNcclLogicalOpPass); diff --git a/oneflow/core/job_rewriter/logical_chain_pass.cpp b/oneflow/core/job_rewriter/logical_chain_pass.cpp new file mode 100644 index 00000000000..9184479bc38 --- /dev/null +++ b/oneflow/core/job_rewriter/logical_chain_pass.cpp @@ -0,0 +1,726 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/nd_sbp.h" +#include "oneflow/core/framework/instructions_builder.h" +#include "oneflow/core/framework/sbp_infer_util.h" +#include "oneflow/core/job/scope.h" +#include "oneflow/core/job/sbp_parallel.h" +#include "oneflow/core/job/job.pb.h" +#include "oneflow/core/job/nd_sbp_util.h" +#include "oneflow/core/job_rewriter/job_pass.h" +#include "oneflow/core/job_rewriter/calculation_pass.h" +#include "oneflow/core/vm/vm_util.h" +#include "oneflow/core/vm/symbol_storage.h" +#include "oneflow/core/operator/operator.h" +#include "oneflow/core/common/env_var/env_var.h" + +namespace oneflow { + +DEFINE_ENV_BOOL(ENABLE_LOGICAL_CHAIN, true); + +namespace { + + +class LogicalChainPass final : public JobPass { + public: + OF_DISALLOW_COPY_AND_MOVE(LogicalChainPass); + LogicalChainPass() = default; + ~LogicalChainPass() = default; + + Maybe Apply(Job* job, JobPassCtx* ctx) const override { + if (!IsEnabled(*ctx)) { return Maybe::Ok(); } + const OpGraph op_graph(*job); + JobBuilder job_builder(job); + return Apply(op_graph, &job_builder); + } + + bool IsEnabled(const JobPassCtx& ctx) const { + return EnvBool(); + } + + Maybe Apply(const OpGraph& op_graph, JobBuilder* job_builder) const; +}; + +bool IsBreakpointOpNode(const OpNode* node) { + // NOTE(chengcheng): breakpoint op is special which CANNOT merge in chain such as: + // variable, tick, repeat/acc/pack/unpack change timeshape + const Operator& op = node->op(); + const OperatorConf& op_conf = op.op_conf(); + + // TODO(chengcheng): filter ops which has special type + // TODO(chengcheng): get stream by op type + if (op_conf.has_variable_conf() /* varialbe */ + || op_conf.has_tick_conf() || op_conf.has_device_tick_conf() + || op_conf.has_src_subset_tick_conf() || op_conf.has_dst_subset_tick_conf() + || op_conf.has_source_tick_conf() || op_conf.has_sink_tick_conf() + || op_conf.has_acc_tick_conf() + || op_conf.has_critical_section_wait_tick_conf() + || op_conf.has_critical_section_callback_tick_conf() /* tick */ + || op_conf.has_input_conf() + || op_conf.has_output_conf() /* io */ + || op_conf.has_wait_and_send_ids_conf() + || op_conf.has_callback_notify_conf() /* ctrl */ + || op_conf.has_image_decoder_random_crop_resize_conf() /* gpu decode */) { + return true; + } + + if (op_conf.has_user_conf()) { + const std::string& user_type_name = op_conf.user_conf().op_type_name(); + // TODO(chengcheng): acc node can be merged in chain. 
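+    // For now these user ops are also treated as chain breakpoints (repeat/acc/pack/unpack change the time shape; the copy/buffer ops are assumed to need their own stream).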
+ if (user_type_name == "repeat" || user_type_name == "acc" || user_type_name == "pack" + || user_type_name == "unpack" || user_type_name == "identity_buffer" + || user_type_name == "copy_h2d" || user_type_name == "copy_d2h") { + return true; + } + } + return false; +} + +bool IsAccOpNode(const OpNode* node) { + return node->op().op_conf().has_user_conf() + && node->op().op_conf().user_conf().op_type_name() == "acc"; +} + +bool IsRepeatOpNode(const OpNode* node) { + return node->op().op_conf().has_user_conf() + && node->op().op_conf().user_conf().op_type_name() == "repeat"; +} + +std::shared_ptr GetOpNodeFastestTimeShape(const OpNode* op_node) { + return CHECK_JUST(op_node->op().GetInputOutputFastestTimeShape()); +} + +std::shared_ptr GetOpNodeInputTimeShape(const OpNode* op_node) { + return CHECK_JUST(op_node->op().GetInputBlobFastestTimeShape()); +} + +bool SharedPtrShapeEqual(const std::shared_ptr& lhs, + const std::shared_ptr& rhs) { + return (*lhs) == (*rhs); +} + +void GetChainsWithTimeShape(std::vector>* ret, + const OpGraph& op_graph, + const std::vector& order, + const std::shared_ptr& seed_time_shape) { + HashSet visited; + for (const OpNode* seed_node : order) { + if (visited.find(seed_node) != visited.end()) { continue; } + CHECK(visited.insert(seed_node).second); + const ParallelDesc& seed_parallel_desc = seed_node->parallel_desc(); + // TODO(chengcheng): support cpu chain. + if (seed_parallel_desc.device_type() == DeviceType::kCPU) { continue; } + if (!SharedPtrShapeEqual(GetOpNodeFastestTimeShape(seed_node), seed_time_shape) { continue; } + if (IsBreakpointOpNode(seed_node)) { continue; } + + HashSet this_subgraph; + std::queue queued_nodes; + + queued_nodes.push(seed_node); + while (!queued_nodes.empty()) { + const OpNode* cur_node = queued_nodes.front(); + queued_nodes.pop(); + + CHECK(cur_node->parallel_desc().EqualsIgnoringHierarchy(seed_parallel_desc)); + CHECK(this_subgraph.insert(cur_node).second); + + cur_node->ForEachNodeOnInOutEdge([&](const OpNode* next_node) { + if (visited.find(next_node) == visited.end() && (!IsBreakpointOpNode(next_node)) + && next_node->parallel_desc().EqualsIgnoringHierarchy(seed_parallel_desc) + && SharedPtrShapeEqual(GetOpNodeFastestTimeShape(next_node), seed_time_shape)) { + CHECK(visited.insert(next_node).second); + queued_nodes.push(next_node); + } + }); + } + + if (this_subgraph.size() > 1) { + ret->emplace_back(HashSet()); + ret->back().swap(this_subgraph); + } + } + + std::sort(ret->begin(), ret->end(), + [](const HashSet& lhs, const HashSet& rhs) { + return lhs.size() > rhs.size(); + }); +} + +struct LogicalChain { + int64_t logical_chain_id; + std::vector ordered_op_nodes; + int64_t begin_op_global_order; + int64_t end_op_global_order; + const OpNode* begin_op; + const OpNode* end_op; +}; + +struct PlacementLogicalChainsInfo { + std::vector> ordered_logical_chains; + std::vector ordered_acc_op_nodes; + const ParallelDesc* seed_parallel_desc; +}; + +std::string GenParallelConfKey(const ParallelConf& conf) { + std::string ret = conf.device_tag(); + for (const auto& name : conf.device_name()) { ret += ("-" + name); } + return ret; +} + +Maybe LogicalChainPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const { + std::vector ordered_op_nodes; + HashMap op_node2global_order; + // TODO(chengcheng) : better order for memory. 
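+  // The seed time shape starts as (1, 1) and is replaced by the largest time shape seen while walking the graph.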
+ std::shared_ptr seed_time_shape = std::make_shared(Shape({1, 1})); + op_graph.TopoForEachNodeWithCtrlEdge([&](const OpNode* node) { + ordered_op_nodes.emplace_back(node); + op_node2global_order.emplace(node, ordered_op_nodes.size() - 1); + std::shared_ptr this_time_shape = GetOpNodeFastestTimeShape(node); + if (this_time_shape->elem_cnt() > seed_time_shape->elem_cnt()) { + seed_time_shape = this_time_shape; + } + }); + + VLOG(2) << " seed time shape = " << seed_time_shape->ToString(); + + std::vector> logical_chains; + GetChainsWithTimeShape(&logical_chains, op_graph, ordered_op_nodes, seed_time_shape); + if (logical_chains.size() == 0) { return Maybe::Ok(); } + + int64_t logical_chain_id = 0; + auto NewLogicalChainId = [&]() { return logical_chain_id++}; + + auto CmpOpNodeOrder = [&](const OpNode* lhs, const OpNode* rhs) { + return op_node2global_order.at(lhs) < op_node2global_order.at(rhs); + }; + auto IsReachable = op_graph.MakePredicatorIsOpNameDataOrCtrlReachable(); + + HashMap placement2logical_chains; + for (const auto& origin_logical_chain : logical_chains) { + const OpNode* rand_node = *subgraph.begin(); + const ParallelDesc& this_parallel_desc = rand_node->parallel_desc(); + std::string key = GenParallelConfKey(this_parallel_desc.parallel_conf()); + const std::shared_ptr& this_time_shape = GetOpNodeFastestTimeShape(rand_node); + auto it = placement2subgraphs.find(key); + if (it == placement2subgraphs.end()) { + it = placement2subgraphs.emplace(key, PlacementNcclSubGraghsInfo()).first; + auto& info = it->second; + info.seed_parallel_desc = &this_parallel_desc; + info.seed_time_shape = this_time_shape; + info.ordered_subgraph.emplace_back(std::make_shared()); + InitInsertNcclSubGraphInfoFromSet(info.ordered_subgraph.back(), subgraph, + op_node2global_order, CmpOpNodeOrder); + } else { + auto& info = it->second; + if (SharedPtrShapeEqual(info.seed_time_shape, this_time_shape)) { + CHECK(this_parallel_desc.EqualsIgnoringHierarchy(*info.seed_parallel_desc)); + std::shared_ptr nccl_subgraph_info = + std::make_shared(); + InitInsertNcclSubGraphInfoFromSet(nccl_subgraph_info, subgraph, op_node2global_order, + CmpOpNodeOrder); + CHECK_GT(info.ordered_subgraph.size(), 0); + const auto& first_graph = info.ordered_subgraph.front(); + const auto& last_graph = info.ordered_subgraph.back(); + int64_t first_order = first_graph->begin_op_global_order; + int64_t last_order = last_graph->end_op_global_order; + if (nccl_subgraph_info->end_op_global_order < first_order) { + if (IsReachable(nccl_subgraph_info->end_op->op().op_name(), + first_graph->begin_op->op().op_name())) { + info.ordered_subgraph.insert(info.ordered_subgraph.begin(), nccl_subgraph_info); + } + } else if (nccl_subgraph_info->begin_op_global_order > last_order) { + if (IsReachable(last_graph->end_op->op().op_name(), + nccl_subgraph_info->begin_op->op().op_name())) { + info.ordered_subgraph.emplace_back(nccl_subgraph_info); + } + } else { + auto before = info.ordered_subgraph.begin(); + auto next = before + 1; + while (next != info.ordered_subgraph.end()) { + if ((*before)->end_op_global_order < nccl_subgraph_info->begin_op_global_order + && nccl_subgraph_info->end_op_global_order < (*next)->begin_op_global_order) { + if (IsReachable((*before)->end_op->op().op_name(), + nccl_subgraph_info->begin_op->op().op_name()) + && IsReachable(nccl_subgraph_info->end_op->op().op_name(), + (*next)->begin_op->op().op_name())) { + info.ordered_subgraph.insert(next, nccl_subgraph_info); + } + break; + } + before = next; + next++; + } + } + } + } + 
} + + for (const OpNode* this_node : ordered_op_nodes) { + if (IsAccOpNode(this_node)) { + const ParallelDesc& this_parallel_desc = this_node->parallel_desc(); + std::string key = GenParallelConfKey(this_parallel_desc.parallel_conf()); + auto it = placement2subgraphs.find(key); + if (it != placement2subgraphs.end()) { + it->second.ordered_acc_op_nodes.emplace_back(this_node); + } + } + } + + for (auto& pair : placement2subgraphs) { + PlacementNcclSubGraghsInfo& info = pair.second; + for (int i = 0; i < info.ordered_subgraph.size() - 1; i++) { + CHECK_LT(info.ordered_subgraph.at(i)->end_op_global_order, + info.ordered_subgraph.at(i + 1)->begin_op_global_order); + } + + // NOTE(chengcheng): insert nccl ops for each subgraph + uint32_t stream_offset = 0; + int64_t total_op_num = 0; + for (int i = 0; i < info.ordered_subgraph.size(); i++) { + auto& ordered_op_nodes = info.ordered_subgraph.at(i)->ordered_op_nodes; + InsertNcclLogicalOpsInSubGraph(op_graph, job_builder, ordered_op_nodes, IsReachable, i, + &stream_offset); + total_op_num += ordered_op_nodes.size(); + } + if (stream_offset >= 2 && total_op_num >= 1000) { + LOG(WARNING) << " In Graph: " << job_builder->job().job_conf().job_name() + << " Placement: " << pair.first << " the total_op_num = " << total_op_num + << " and has " << stream_offset + << " different nccl stream which is possible to trigger cuda stream kernel " + "launch upper limit." + << " So the nccl logical kernel will from async to sync exec, which may affect " + "performance."; + EagerNcclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton::Get()); + comm_mgr->SetAsyncLaunchNcclLogicalKernel(false); + } + + // NOTE(chengcheng): insert acc for all subgraph with same placement group + const OpNode* bw_sink_op = info.ordered_subgraph.back()->end_op; + const std::vector& ordered_acc_op_nodes = info.ordered_acc_op_nodes; + + if (!ordered_acc_op_nodes.empty()) { + InsertBwSinkAccTickAndNcclLogicalOpsInPlacementGroupAfterAcc( + op_graph, job_builder, ordered_acc_op_nodes, op_node2global_order, bw_sink_op); + } + } + + return Maybe::Ok(); +} + + + + +bool IsOpEdgeAllowInsertNccl(const OpEdge* edge, + const std::shared_ptr& seed_time_shape) { + const OpNode* src_node = edge->src_node(); + const OpNode* dst_node = edge->dst_node(); + const ParallelDesc& src_parallel_desc = src_node->parallel_desc(); + return src_parallel_desc.device_type() == DeviceType::kCUDA + && src_parallel_desc.parallel_num() > 1 + && src_parallel_desc.EqualsIgnoringHierarchy(dst_node->parallel_desc()) + && SharedPtrShapeEqual(GetOpNodeFastestTimeShape(src_node), seed_time_shape) + && SharedPtrShapeEqual(GetOpNodeFastestTimeShape(dst_node), seed_time_shape); +} + +struct InsertedNcclInfo { + OperatorConf nccl_op_conf; + ParallelConf nccl_parallel_conf; + int64_t order; + const OpNode* src_node; + const OpNode* dst_node; + std::string debug_str; +}; + +void InsertNcclLogicalOpsAfterAcc(const OpGraph& op_graph, + const HashMap& op_node2global_order, + const std::vector& ordered_acc_op_nodes, + const std::string& bw_sink_tick_op_name, + HashMap* mut_consumer_name2op, + std::vector* nccl_op_confs, + std::vector* nccl_op_parallel_confs) { + HashSet visited; + std::shared_ptr seed_time_shape = GetOpNodeFastestTimeShape(ordered_acc_op_nodes.front()); + std::vector nccl_op_infos; + + std::vector ordered_after_acc_subgraph; + // NOTE(chengcheng): bfs for op_edge may create duplicated node. 
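+  // so reached nodes are collected in a set first and sorted by global order afterwards.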
+ HashSet after_acc_subgraph_nodes; + HashMap op2subgraph_order; + + for (const OpNode* acc : ordered_acc_op_nodes) { + std::queue queued_edges; + for (const OpEdge* op_edge : acc->out_edges()) { + if (visited.find(op_edge) == visited.end() + && IsOpEdgeAllowInsertNccl(op_edge, seed_time_shape)) { + queued_edges.push(op_edge); + CHECK(visited.insert(op_edge).second); + if (!IsAccOpNode(op_edge->dst_node())) { + after_acc_subgraph_nodes.insert(op_edge->dst_node()); + } + } + } + + auto NextEdgeNode2AfterAccSubGraph = [&](const OpEdge* next_edge, const OpNode* next_node) { + if (visited.find(next_edge) == visited.end() + && IsOpEdgeAllowInsertNccl(next_edge, seed_time_shape)) { + CHECK(visited.insert(next_edge).second); + queued_edges.push(next_edge); + if (!IsAccOpNode(next_node)) { after_acc_subgraph_nodes.insert(next_node); } + } + }; + + // bfs search each edge after acc allow insert nccl. try insert. + while (!queued_edges.empty()) { + const OpEdge* op_edge = queued_edges.front(); + queued_edges.pop(); + + for (const LogicalBlobId& lbi : op_edge->lbis()) { + const OpNode* src_node = op_edge->src_node(); + const OpNode* dst_node = op_edge->dst_node(); + const std::string& src_op_name = src_node->op().op_name(); + const std::string& dst_op_name = dst_node->op().op_name(); + OperatorConf nccl_op; + ParallelDesc src_reduced_parallel_desc = op_edge->src_node()->parallel_desc(); + ParallelDesc dst_reduced_parallel_desc = op_edge->dst_node()->parallel_desc(); + NdSbp src_reduced_nd_sbp; + NdSbp dst_reduced_nd_sbp; + if (!TryBuildNcclLogicalOpConf(&nccl_op, op_edge->src_node(), op_edge->dst_node(), lbi, + &src_reduced_parallel_desc, &dst_reduced_parallel_desc, + &src_reduced_nd_sbp, &dst_reduced_nd_sbp)) { + continue; + } + auto it = mut_consumer_name2op->find(dst_op_name); + if (it == mut_consumer_name2op->end()) { + auto ret_pair = mut_consumer_name2op->emplace(dst_op_name, dst_node->op().op_conf()); + CHECK(ret_pair.second); + it = ret_pair.first; + } + // insert nccl op + user_op::UserOpConfWrapper nccl_op_wrapper(nccl_op); + for (const std::string& ibn : op_edge->lbi2ibns().at(lbi)) { + std::string old_lbn = ReplaceInputLbnInOpCustomizedConf(&(it->second), ibn, + nccl_op_wrapper.output("out", 0)); + } + + InsertedNcclInfo nccl_op_info; + nccl_op_info.nccl_op_conf = nccl_op; + nccl_op_info.nccl_parallel_conf = src_reduced_parallel_desc.parallel_conf(); + nccl_op_info.order = op_node2global_order.at(src_node); + nccl_op_info.src_node = src_node; + nccl_op_info.dst_node = dst_node; + nccl_op_info.debug_str = + (" After ACC insert nccl op: " + nccl_op.name() + " from [" + src_op_name + + ", sbp=" + NdSbpToString(src_node->NdSbp4Lbi(lbi)) + "] to [" + dst_op_name + + ", sbp=" + NdSbpToString(dst_node->NdSbp4Lbi(lbi)) + + ", src_order=" + std::to_string(nccl_op_info.order) + "]\n"); + nccl_op_infos.emplace_back(nccl_op_info); + } + + // NOTE(chengcheng): BFS for all edges and nodes after acc. 
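+      // Expand through both the in- and out-edges of the src and dst nodes so the whole post-acc region is visited.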
+ for (const OpEdge* dst_node_out_edge : op_edge->dst_node()->out_edges()) { + NextEdgeNode2AfterAccSubGraph(dst_node_out_edge, dst_node_out_edge->dst_node()); + } + for (const OpEdge* dst_node_in_edge : op_edge->dst_node()->in_edges()) { + NextEdgeNode2AfterAccSubGraph(dst_node_in_edge, dst_node_in_edge->src_node()); + } + for (const OpEdge* src_node_out_edge : op_edge->src_node()->out_edges()) { + NextEdgeNode2AfterAccSubGraph(src_node_out_edge, src_node_out_edge->dst_node()); + } + for (const OpEdge* src_node_in_edge : op_edge->src_node()->in_edges()) { + NextEdgeNode2AfterAccSubGraph(src_node_in_edge, src_node_in_edge->src_node()); + } + } + } + + for (const auto* node : after_acc_subgraph_nodes) { ordered_after_acc_subgraph.push_back(node); } + + CHECK_EQ(after_acc_subgraph_nodes.size(), ordered_after_acc_subgraph.size()); + + std::sort(nccl_op_infos.begin(), nccl_op_infos.end(), + [](const InsertedNcclInfo& lhs, const InsertedNcclInfo& rhs) { + return lhs.order < rhs.order; + }); + + std::sort(ordered_after_acc_subgraph.begin(), ordered_after_acc_subgraph.end(), + [&](const OpNode* lhs, const OpNode* rhs) { + return op_node2global_order.at(lhs) < op_node2global_order.at(rhs); + }); + + auto IsReachable = op_graph.MakePredicatorIsOpNameDataOrCtrlReachable(); + + for (int64_t i = 0; i < ordered_after_acc_subgraph.size(); ++i) { + op2subgraph_order.emplace(ordered_after_acc_subgraph.at(i), i); + } + + for (int64_t i = 1; i < ordered_after_acc_subgraph.size(); ++i) { + const OpNode* this_node = ordered_after_acc_subgraph.at(i); + const OpNode* pre_node = ordered_after_acc_subgraph.at(i - 1); + const std::string& this_op_name = this_node->op().op_name(); + const std::string& pre_op_name = pre_node->op().op_name(); + // build ctrl edge if need. + if (!IsReachable(pre_op_name, this_op_name)) { + auto it = mut_consumer_name2op->find(this_op_name); + if (it == mut_consumer_name2op->end()) { + auto ret_pair = mut_consumer_name2op->emplace(this_op_name, this_node->op().op_conf()); + CHECK(ret_pair.second); + it = ret_pair.first; + } + OperatorConf* mut_op_conf = &(it->second); + mut_op_conf->add_ctrl_in_op_name(pre_op_name); + } + } + + for (int64_t i = 0; i < nccl_op_infos.size(); ++i) { + auto& info = nccl_op_infos.at(i); + if (i == 0) { + info.nccl_op_conf.add_ctrl_in_op_name(bw_sink_tick_op_name); + } else { + info.nccl_op_conf.add_ctrl_in_op_name(nccl_op_infos.at(i - 1).nccl_op_conf.name()); + } + + nccl_op_confs->emplace_back(info.nccl_op_conf); + nccl_op_parallel_confs->emplace_back(info.nccl_parallel_conf); + VLOG(3) << info.debug_str; + + // NOTE(chengcheng): Try add ctrl between nccl and src op next node for strict exec order. 
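+    // If the src op has a successor in the ordered post-acc subgraph, that successor gets this nccl op as a ctrl input (unless it is already the dst op).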
+ auto src_op_it = op2subgraph_order.find(info.src_node); + if (src_op_it != op2subgraph_order.end()) { + const int64_t src_sub_order = src_op_it->second; + const int64_t next_sub_order = src_sub_order + 1; + if (next_sub_order < ordered_after_acc_subgraph.size()) { + const OpNode* next_op = ordered_after_acc_subgraph.at(next_sub_order); + const std::string& next_op_name = next_op->op().op_name(); + const std::string& dst_op_name = info.dst_node->op().op_name(); + if (next_op_name != dst_op_name) { + if (mut_consumer_name2op->find(next_op_name) == mut_consumer_name2op->end()) { + CHECK(mut_consumer_name2op->emplace(next_op_name, next_op->op().op_conf()).second); + } + // NOTE(chengcheng): MUST add ctrl edge for strict exec orde + mut_consumer_name2op->at(next_op_name).add_ctrl_in_op_name(info.nccl_op_conf.name()); + } + } + } + } +} + +struct InsertNcclSubGraph { + std::vector ordered_op_nodes; + int64_t begin_op_global_order; + int64_t end_op_global_order; + const OpNode* begin_op; + const OpNode* end_op; +}; + +struct PlacementNcclSubGraghsInfo { + std::vector> ordered_subgraph; + std::vector ordered_acc_op_nodes; + const ParallelDesc* seed_parallel_desc; + std::shared_ptr seed_time_shape; +}; + +void InitInsertNcclSubGraphInfoFromSet( + std::shared_ptr nccl_subgraph_info, const HashSet& subgraph, + const HashMap& op_node2global_order, + const std::function& CmpOpNodeOrder) { + auto* subgraph_ordered_nodes = &nccl_subgraph_info->ordered_op_nodes; + subgraph_ordered_nodes->assign(subgraph.begin(), subgraph.end()); + std::sort(subgraph_ordered_nodes->begin(), subgraph_ordered_nodes->end(), CmpOpNodeOrder); + nccl_subgraph_info->begin_op = subgraph_ordered_nodes->front(); + nccl_subgraph_info->end_op = subgraph_ordered_nodes->back(); + nccl_subgraph_info->begin_op_global_order = op_node2global_order.at(nccl_subgraph_info->begin_op); + nccl_subgraph_info->end_op_global_order = op_node2global_order.at(nccl_subgraph_info->end_op); + CHECK(nccl_subgraph_info->begin_op != nccl_subgraph_info->end_op); + CHECK_LT(nccl_subgraph_info->begin_op_global_order, nccl_subgraph_info->end_op_global_order); +} + +constexpr uint32_t kMaxNcclComputeStreamCount = 8; + +std::string GetStreamIndexName(uint32_t id) { return "NCCL_COMPUTE_" + std::to_string(id); } + +void InsertNcclLogicalOpsInSubGraph( + const OpGraph& op_graph, JobBuilder* job_builder, + const std::vector& subgraph_order, + const std::function& IsReachable, + const int32_t subgraph_id_in_same_placement_group, uint32_t* stream_offset) { + HashMap node2subgraph_order; + node2subgraph_order.reserve(subgraph_order.size()); + for (int64_t i = 0; i < subgraph_order.size(); ++i) { + CHECK(node2subgraph_order.emplace(subgraph_order.at(i), i).second); + } + + if (Singleton::Get()->enable_debug_mode()) { + VLOG(3) << " Try insert nccl logical ops into job: " << job_builder->job().job_conf().job_name() + << ". Begin...\n"; + } + + HashSet mut_op_names; + const OpNode* first_node = subgraph_order.at(0); + HashMap subgraph_op_name2conf; + subgraph_op_name2conf.emplace(first_node->op().op_name(), first_node->op().op_conf()); + + // add ctrl for strict order. + for (int64_t i = 1; i < subgraph_order.size(); ++i) { + const OpNode* this_node = subgraph_order.at(i); + const OpNode* pre_node = subgraph_order.at(i - 1); + const std::string& this_op_name = this_node->op().op_name(); + const std::string& pre_op_name = pre_node->op().op_name(); + CHECK(subgraph_op_name2conf.emplace(this_op_name, this_node->op().op_conf()).second); + // build ctrl edge if need. 
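+    // Only add the ctrl edge when the two ops are not already connected by a data or ctrl path.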
+ if (!IsReachable(pre_op_name, this_op_name)) { + subgraph_op_name2conf.at(this_op_name).add_ctrl_in_op_name(pre_op_name); + mut_op_names.insert(this_op_name); + } + } + + std::vector nccl_op_confs; + std::vector nccl_op_parallel_confs; + // NOTE(chengcheng): ONLY support insert nccl to dst for memory. + InsertNcclLogicalOpsAsCloseAsPossibleToDstNode(&subgraph_op_name2conf, &mut_op_names, + &nccl_op_confs, &nccl_op_parallel_confs, + subgraph_order, node2subgraph_order); + + if (Singleton::Get()->enable_debug_mode()) { + VLOG(3) << " Try insert nccl logical ops into job: " << job_builder->job().job_conf().job_name() + << ". ...End\n\n"; + } + + // NOTE(chengcheng): For NCCL logical correct exec order in pipeline multi-subgraph. + do { + if (nccl_op_confs.empty()) { break; } + int64_t nccl_compute_stream_id = *stream_offset; + if (nccl_compute_stream_id >= kMaxNcclComputeStreamCount) { + break; // NOTE(chengcheng): ONLY support kMaxNcclComputeStreamCount insert nccl subgraphs. + } + std::string stream_index_name = GetStreamIndexName(nccl_compute_stream_id); + + // NOTE(chengcheng): set ALL subgraph op and ALL nccl op stream index. + for (auto& pair : subgraph_op_name2conf) { + mut_op_names.insert(pair.first); + pair.second.set_stream_name_hint(stream_index_name); + } + for (auto& nccl_op : nccl_op_confs) { nccl_op.set_stream_name_hint(stream_index_name); } + (*stream_offset)++; + } while (false); + + std::vector mut_op_confs; + mut_op_confs.reserve(mut_op_names.size()); + for (const std::string& mut_op_name : mut_op_names) { + mut_op_confs.emplace_back(subgraph_op_name2conf.at(mut_op_name)); + } + job_builder->MutOpsOnlyOnce(mut_op_confs); + + CHECK_EQ(nccl_op_confs.size(), nccl_op_parallel_confs.size()); + for (int64_t i = 0; i < nccl_op_confs.size(); ++i) { + CHECK_JUST(job_builder->AddOp(nccl_op_parallel_confs.at(i), nccl_op_confs.at(i))); + } +} + +void InsertBwSinkAccTickAndNcclLogicalOpsInPlacementGroupAfterAcc( + const OpGraph& op_graph, JobBuilder* job_builder, + const std::vector& ordered_acc_op_nodes, + const HashMap& op_node2global_order, const OpNode* bw_sink_op) { + const OpNode* first_acc_op = ordered_acc_op_nodes.front(); + std::shared_ptr time_shape_before_acc = GetOpNodeFastestTimeShape(bw_sink_op); + std::shared_ptr time_shape_after_acc = GetOpNodeFastestTimeShape(first_acc_op); + VLOG(3) << " Find acc ops (num=" << ordered_acc_op_nodes.size() + << ") in Job: " << job_builder->job().job_conf().job_name() + << ", we will try insert special identity and ctrl for " + << " UNSAFE handle ALL nccl ops between different time shape: " + << time_shape_before_acc->DebugStr() << "->acc->" << time_shape_after_acc->DebugStr() + << "\n\n"; + CHECK_GT(time_shape_before_acc->elem_cnt(), time_shape_after_acc->elem_cnt()); + CHECK_EQ(time_shape_before_acc->elem_cnt() % time_shape_after_acc->elem_cnt(), 0); + + for (const OpNode* acc : ordered_acc_op_nodes) { + CHECK(SharedPtrShapeEqual(time_shape_before_acc, GetOpNodeInputTimeShape(acc))); + CHECK(SharedPtrShapeEqual(time_shape_after_acc, GetOpNodeFastestTimeShape(acc))); + } + + // NOTE(chengcheng): insert acc_tick after bw_sink_op, and this tick op conf will control + // after_acc_nccl_ops start. 
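+  // The control chain built below is: bw_sink output -> cast_to_tick -> acc_tick -> device_tick; the first nccl op inserted after acc takes that device tick as a ctrl input.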
+ const auto& obns = bw_sink_op->op().output_bns(); + CHECK(!obns.empty()); + const std::string bw_sink_op_out_lbn = + GenLogicalBlobName(bw_sink_op->op().BnInOp2Lbi(obns.Get(0))); + VLOG(3) << " bw_sink_op : " << bw_sink_op->op().op_conf().DebugString(); + + user_op::UserOpConfWrapper cast_to_tick_op = + user_op::UserOpConfWrapperBuilder("System-CastToTick-" + NewUniqueId()) + .OpTypeName("cast_to_tick") + .Input("in", bw_sink_op_out_lbn) + .Output("out") + .Build(); + + OperatorConf bw_sink_acc_tick_conf; + bw_sink_acc_tick_conf.set_name(std::string("System-BwSinkTick-AccTick_") + NewUniqueId()); + auto* acc_conf = bw_sink_acc_tick_conf.mutable_acc_tick_conf(); + acc_conf->set_one(cast_to_tick_op.output("out", 0)); + acc_conf->set_acc("acc"); + acc_conf->set_max_acc_num(time_shape_before_acc->elem_cnt() / time_shape_after_acc->elem_cnt()); + + OperatorConf bw_sink_final_tick_conf; + bw_sink_final_tick_conf.set_name(std::string("System-BwSinkFinalTick-DeviceTick_") + + NewUniqueId()); + auto* tick_conf = bw_sink_final_tick_conf.mutable_device_tick_conf(); + tick_conf->add_tick(GenLogicalBlobName(bw_sink_acc_tick_conf.name(), "acc")); + tick_conf->set_out("out"); + + // insert nccl ops after acc + std::vector after_acc_nccl_op_confs; + std::vector after_acc_nccl_parallel_confs; + HashMap mut_consumer_name2op; + + InsertNcclLogicalOpsAfterAcc(op_graph, op_node2global_order, ordered_acc_op_nodes, + bw_sink_final_tick_conf.name(), &mut_consumer_name2op, + &after_acc_nccl_op_confs, &after_acc_nccl_parallel_confs); + + if (after_acc_nccl_op_confs.empty()) { + CHECK(after_acc_nccl_parallel_confs.empty()); + CHECK(mut_consumer_name2op.empty()); + } else { + // insert bw sink acc tick ops + CHECK_JUST( + job_builder->AddOp(bw_sink_op->parallel_desc().parallel_conf(), cast_to_tick_op.op_conf())); + VLOG(3) << " Insert cast_to_tick_op : " << cast_to_tick_op.op_conf().DebugString(); + + CHECK_JUST( + job_builder->AddOp(bw_sink_op->parallel_desc().parallel_conf(), bw_sink_acc_tick_conf)); + VLOG(3) << " Insert bw_sink_acc_tick_op : " << bw_sink_acc_tick_conf.DebugString(); + + CHECK_JUST( + job_builder->AddOp(bw_sink_op->parallel_desc().parallel_conf(), bw_sink_final_tick_conf)); + VLOG(3) << " Insert bw_sink_final_tick_op : " << bw_sink_final_tick_conf.DebugString(); + + // insert nccl ops after acc + for (const auto& pair : mut_consumer_name2op) { + CHECK_JUST(job_builder->MutOpOnlyOnce(pair.second)); + } + CHECK_EQ(after_acc_nccl_op_confs.size(), after_acc_nccl_parallel_confs.size()); + for (int64_t i = 0; i < after_acc_nccl_op_confs.size(); ++i) { + CHECK_JUST( + job_builder->AddOp(after_acc_nccl_parallel_confs.at(i), after_acc_nccl_op_confs.at(i))); + } + } +} + + + +} // namespace + +REGISTER_JOB_PASS("LogicalChainPass", LogicalChainPass); + +} // namespace oneflow diff --git a/oneflow/core/operator/op_conf.proto b/oneflow/core/operator/op_conf.proto index 1daf645e55a..cd77ae32f70 100644 --- a/oneflow/core/operator/op_conf.proto +++ b/oneflow/core/operator/op_conf.proto @@ -397,6 +397,8 @@ message OperatorConf { optional string stream_name_hint = 9; optional string pass_tag = 10; optional string loc = 11 [default = ""]; + optional int64 logical_chain_id = 12 [default = -1]; + optional int64 logical_order = 13 [default = -1]; oneof op_type { // system op CopyCommNetOpConf copy_comm_net_conf = 106; From d9a5c8201a0c83b226607cfba8c5976ecd40b82f Mon Sep 17 00:00:00 2001 From: chengtbf <472491134@qq.com> Date: Thu, 18 Aug 2022 09:16:40 +0000 Subject: [PATCH 04/66] part-0 : Logical Chain --- 
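Notes:

This part moves the chain decision into the logical graph: the pass is meant to record logical_chain_id / logical_order on the OperatorConf (the two new fields in op_conf.proto), and TaskGraph::MergeChain simply reuses logical_chain_id when ENABLE_LOGICAL_CHAIN is set instead of re-deriving chains from connected task subgraphs; tasks left with chain_id == -1 are now skipped by memory chain init. The grouping rule stays as in the pass above: BFS from a seed op, merging connected ops that share the placement and the time shape, need no boxing in between, and are not breakpoint ops. Below is a minimal Python sketch of that rule; it is illustrative only, not OneFlow code, and the Node type, its attributes and the helper names are assumptions made for the example.

from collections import deque

class Node:
    def __init__(self, name, placement, time_shape, is_breakpoint=False):
        self.name = name
        self.placement = placement          # e.g. "cuda:0-3" (hypothetical label)
        self.time_shape = time_shape        # e.g. (1, 8) for 8 micro-batches
        self.is_breakpoint = is_breakpoint  # variable / tick / repeat / acc style ops
        self.neighbors = []                 # producers and consumers, one undirected adjacency list

def mergeable(seed, nxt):
    # Same rule as the pass: same placement, same time shape, not a breakpoint.
    return (not nxt.is_breakpoint
            and nxt.placement == seed.placement
            and nxt.time_shape == seed.time_shape)

def assign_logical_chain_ids(topo_ordered_nodes):
    """Return {node name: chain id}; ops that end up alone keep chain id -1."""
    chain_id_of = {n.name: -1 for n in topo_ordered_nodes}
    visited = set()
    next_chain_id = 0
    for seed in topo_ordered_nodes:
        if seed in visited or seed.is_breakpoint:
            continue
        visited.add(seed)
        component = []
        queue = deque([seed])
        while queue:  # BFS over connected mergeable ops
            cur = queue.popleft()
            component.append(cur)
            for nxt in cur.neighbors:
                if nxt not in visited and mergeable(seed, nxt):
                    visited.add(nxt)
                    queue.append(nxt)
        if len(component) > 1:  # a single op is not worth a chain
            for n in component:
                chain_id_of[n.name] = next_chain_id
            next_chain_id += 1
    return chain_id_of

# Tiny usage example: matmul and relu share a chain, acc stays unchained.
a = Node("matmul", "cuda:0-3", (1, 8))
b = Node("relu", "cuda:0-3", (1, 8))
c = Node("acc", "cuda:0-3", (1, 1), is_breakpoint=True)
a.neighbors, b.neighbors, c.neighbors = [b], [a, c], [b]
print(assign_logical_chain_ids([a, b, c]))  # {'matmul': 0, 'relu': 0, 'acc': -1}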
oneflow/core/graph/task_graph.cpp | 56 +- .../core/job/intra_job_mem_sharing_util.cpp | 1 + .../core/job_rewriter/logical_chain_pass.cpp | 628 ++++-------------- 3 files changed, 141 insertions(+), 544 deletions(-) diff --git a/oneflow/core/graph/task_graph.cpp b/oneflow/core/graph/task_graph.cpp index b9bb7f36aa2..3aa281713c3 100644 --- a/oneflow/core/graph/task_graph.cpp +++ b/oneflow/core/graph/task_graph.cpp @@ -563,9 +563,8 @@ void TaskGraph::SetOrderInGraphForEachNode() { TopoForEachNode(SetOrderInGraph); } -void GetLogicalChains(std::vector>* ret, - const OpGraph& op_graph, - const std::vector& order) { +void GetLogicalChains(std::vector>* ret, const OpGraph& op_graph, + const std::vector& order) { HashSet visited; for (const OpNode* seed_node : order) { @@ -612,40 +611,29 @@ void GetLogicalChains(std::vector>* ret, } void TaskGraph::MergeChain() { - const OpGraph& op_graph = *Singleton::Get(); - - std::vector ordered_op_nodes; - HashMap op_node2global_order; - op_graph.TopoForEachNodeWithCtrlEdge([&](const OpNode* node) { - ordered_op_nodes.emplace_back(node); - op_node2global_order.emplace(node, ordered_op_nodes.size() - 1); - }); - - std::vector> subgraph_list; - FindAllConnectedSubgraphForGpuExecOrder(&subgraph_list, op_graph, ordered_op_nodes); - if (subgraph_list.size() == 0) { return Maybe::Ok(); } - - auto CmpOpNodeOrder = [&](const OpNode* lhs, const OpNode* rhs) { - return op_node2global_order.at(lhs) < op_node2global_order.at(rhs); - }; - - auto IsReachable = op_graph.MakePredicatorIsOpNameDataOrCtrlReachable(); - - int64_t chain_id = 0; - for (auto* this_node : ordered_task_nodes_) { - // skip if this node has been set in a chain. - if (this_node->chain_id() != -1) { continue; } - - CHECK_EQ(this_node->chain_id(), -1); - if (CanBeMergedInChain(this_node)) { - TraverseConnectedSubGraphMergeInThisChain(this_node, chain_id); - } else { - this_node->set_chain_id(chain_id); + if (EnvBool()) { + for (auto* this_node : ordered_task_nodes_) { + const auto* comp_node = dynamic_cast(node); + const int64_t logical_chain_id = comp_node->op()->op_conf().logical_chain_id(); + if (logical_chain_id != -1) { this_node->set_chain_id(logical_chain_id); } } + } else { + int64_t chain_id = 0; + for (auto* this_node : ordered_task_nodes_) { + // skip if this node has been set in a chain. 
+ if (this_node->chain_id() != -1) { continue; } + + CHECK_EQ(this_node->chain_id(), -1); + if (CanBeMergedInChain(this_node)) { + TraverseConnectedSubGraphMergeInThisChain(this_node, chain_id); + } else { + this_node->set_chain_id(chain_id); + } - ++chain_id; + ++chain_id; + } + for (auto* node : ordered_task_nodes_) { CHECK_NE(node->chain_id(), -1); } } - for (auto* node : ordered_task_nodes_) { CHECK_NE(node->chain_id(), -1); } } void TaskGraph::BuildCtrlRegstDescInSameChain() { diff --git a/oneflow/core/job/intra_job_mem_sharing_util.cpp b/oneflow/core/job/intra_job_mem_sharing_util.cpp index 1ad9657bf83..f1a2fef5b31 100644 --- a/oneflow/core/job/intra_job_mem_sharing_util.cpp +++ b/oneflow/core/job/intra_job_mem_sharing_util.cpp @@ -99,6 +99,7 @@ void InitMemoryChains(Plan* plan, DeviceType device_type = stream_id.device_id().device_type(); // TODO(zwx): eliminate this special 'is cpu' determine if (device_type == DeviceType::kCPU) { continue; } + if (task->task_set_info().chain_id() == -1) { continue; } int64_t device_id = stream_id.device_id().device_index(); int64_t device_unique_id = GenDeviceUniqueId(machine_id, device_id); MemoryChain* mem_chain = diff --git a/oneflow/core/job_rewriter/logical_chain_pass.cpp b/oneflow/core/job_rewriter/logical_chain_pass.cpp index 9184479bc38..16e8341a031 100644 --- a/oneflow/core/job_rewriter/logical_chain_pass.cpp +++ b/oneflow/core/job_rewriter/logical_chain_pass.cpp @@ -34,7 +34,6 @@ DEFINE_ENV_BOOL(ENABLE_LOGICAL_CHAIN, true); namespace { - class LogicalChainPass final : public JobPass { public: OF_DISALLOW_COPY_AND_MOVE(LogicalChainPass); @@ -48,9 +47,7 @@ class LogicalChainPass final : public JobPass { return Apply(op_graph, &job_builder); } - bool IsEnabled(const JobPassCtx& ctx) const { - return EnvBool(); - } + bool IsEnabled(const JobPassCtx& ctx) const { return EnvBool(); } Maybe Apply(const OpGraph& op_graph, JobBuilder* job_builder) const; }; @@ -60,22 +57,19 @@ bool IsBreakpointOpNode(const OpNode* node) { // variable, tick, repeat/acc/pack/unpack change timeshape const Operator& op = node->op(); const OperatorConf& op_conf = op.op_conf(); - - // TODO(chengcheng): filter ops which has special type + + // TODO(chengcheng): filter ops which has special type // TODO(chengcheng): get stream by op type - if (op_conf.has_variable_conf() /* varialbe */ - || op_conf.has_tick_conf() || op_conf.has_device_tick_conf() - || op_conf.has_src_subset_tick_conf() || op_conf.has_dst_subset_tick_conf() - || op_conf.has_source_tick_conf() || op_conf.has_sink_tick_conf() - || op_conf.has_acc_tick_conf() - || op_conf.has_critical_section_wait_tick_conf() - || op_conf.has_critical_section_callback_tick_conf() /* tick */ - || op_conf.has_input_conf() - || op_conf.has_output_conf() /* io */ - || op_conf.has_wait_and_send_ids_conf() - || op_conf.has_callback_notify_conf() /* ctrl */ - || op_conf.has_image_decoder_random_crop_resize_conf() /* gpu decode */) { - return true; + if (op_conf.has_variable_conf() /* varialbe */ + || op_conf.has_tick_conf() || op_conf.has_device_tick_conf() + || op_conf.has_src_subset_tick_conf() || op_conf.has_dst_subset_tick_conf() + || op_conf.has_source_tick_conf() || op_conf.has_sink_tick_conf() + || op_conf.has_acc_tick_conf() || op_conf.has_critical_section_wait_tick_conf() + || op_conf.has_critical_section_callback_tick_conf() /* tick */ + || op_conf.has_input_conf() || op_conf.has_output_conf() /* io */ + || op_conf.has_wait_and_send_ids_conf() || op_conf.has_callback_notify_conf() /* ctrl */ + || 
op_conf.has_image_decoder_random_crop_resize_conf() /* gpu decode */) { + return true; } if (op_conf.has_user_conf()) { @@ -113,10 +107,21 @@ bool SharedPtrShapeEqual(const std::shared_ptr& lhs, return (*lhs) == (*rhs); } -void GetChainsWithTimeShape(std::vector>* ret, - const OpGraph& op_graph, - const std::vector& order, - const std::shared_ptr& seed_time_shape) { +bool NeedInsertBoxingBetweenOpNodes(const OpNode* a_node, const OpNode* b_node, + const OpEdge* edge) { + CHECK(a_node != b_node && (edge->src_node() == a_node || edge->src_node() == b_node) + && (edge->dst_node() == a_node || edge->dst_node() == b_node)); + if (a_node->parallel_desc().parallel_num() > 1) { + for (const auto& lbi : edge->lbis()) { + if (a_node->NdSbp4Lbi(lbi) != b_node->NdSbp4Lbi(lbi)) { return true; } + } + } + return false; +} + +void GetLogicalChainsWithTimeShape(std::vector>* ret, + const OpGraph& op_graph, const std::vector& order, + const std::shared_ptr& seed_time_shape) { HashSet visited; for (const OpNode* seed_node : order) { if (visited.find(seed_node) != visited.end()) { continue; } @@ -124,7 +129,7 @@ void GetChainsWithTimeShape(std::vector>* ret, const ParallelDesc& seed_parallel_desc = seed_node->parallel_desc(); // TODO(chengcheng): support cpu chain. if (seed_parallel_desc.device_type() == DeviceType::kCPU) { continue; } - if (!SharedPtrShapeEqual(GetOpNodeFastestTimeShape(seed_node), seed_time_shape) { continue; } + if (!SharedPtrShapeEqual(GetOpNodeFastestTimeShape(seed_node), seed_time_shape)) { continue; } if (IsBreakpointOpNode(seed_node)) { continue; } HashSet this_subgraph; @@ -138,14 +143,23 @@ void GetChainsWithTimeShape(std::vector>* ret, CHECK(cur_node->parallel_desc().EqualsIgnoringHierarchy(seed_parallel_desc)); CHECK(this_subgraph.insert(cur_node).second); - cur_node->ForEachNodeOnInOutEdge([&](const OpNode* next_node) { + auto SearchToNextNode = [&](const OpNode* cur_node, const OpNode* next_node, + const OpEdge* edge) { if (visited.find(next_node) == visited.end() && (!IsBreakpointOpNode(next_node)) && next_node->parallel_desc().EqualsIgnoringHierarchy(seed_parallel_desc) - && SharedPtrShapeEqual(GetOpNodeFastestTimeShape(next_node), seed_time_shape)) { + && SharedPtrShapeEqual(GetOpNodeFastestTimeShape(next_node), seed_time_shape) + && !NeedInsertBoxingBetweenOpNodes(cur_node, next_node, edge)) { CHECK(visited.insert(next_node).second); queued_nodes.push(next_node); } - }); + }; + + for (const OpEdge* in_edge : cur_node->in_edges()) { + SearchToNextNode(cur_node, in_edge->src_node(), in_edge); + } + for (const OpEdge* out_edge : cur_node->out_edges()) { + SearchToNextNode(cur_node, out_edge->dst_node(), out_edge); + } } if (this_subgraph.size() > 1) { @@ -154,10 +168,12 @@ void GetChainsWithTimeShape(std::vector>* ret, } } + /* std::sort(ret->begin(), ret->end(), [](const HashSet& lhs, const HashSet& rhs) { return lhs.size() > rhs.size(); }); + */ } struct LogicalChain { @@ -167,12 +183,20 @@ struct LogicalChain { int64_t end_op_global_order; const OpNode* begin_op; const OpNode* end_op; + LogicalChain() + : logical_chain_id(-1), + begin_op_global_order(-1), + end_op_global_order(-1), + begin_op(nullptr), + end_op(nullptr) {} }; struct PlacementLogicalChainsInfo { std::vector> ordered_logical_chains; std::vector ordered_acc_op_nodes; + std::shared_ptr after_acc_logical_chain; const ParallelDesc* seed_parallel_desc; + PlacementLogicalChainsInfo() : seed_parallel_desc(nullptr) {} }; std::string GenParallelConfKey(const ParallelConf& conf) { @@ -181,9 +205,35 @@ std::string 
GenParallelConfKey(const ParallelConf& conf) { return ret; } +void InitPlacementLogicalChainsInfoFromSet( + const std::shared_ptr& logical_chain, + const HashSet& origin_logical_chain, + const HashMap& op_node2global_order, + const std::function& CmpOpNodeOrder) { + auto* logical_chain_ordered_nodes = &logical_chain->ordered_op_nodes; + CHECK(logical_chain_ordered_nodes->empty()); + logical_chain_ordered_nodes->assign(origin_logical_chain.begin(), origin_logical_chain.end()); + std::sort(logical_chain_ordered_nodes->begin(), logical_chain_ordered_nodes->end(), + CmpOpNodeOrder); + logical_chain->begin_op = logical_chain_ordered_nodes->front(); + logical_chain->end_op = logical_chain_ordered_nodes->back(); + logical_chain->begin_op_global_order = op_node2global_order.at(logical_chain->begin_op); + logical_chain->end_op_global_order = op_node2global_order.at(logical_chain->end_op); + CHECK(logical_chain->begin_op != logical_chain->end_op); + CHECK_LT(logical_chain->begin_op_global_order, logical_chain->end_op_global_order); +} + +void CreateAfterAccLogicalChain(const std::shared_ptr& after_acc_logical_chain, + const OpGraph& op_graph, + const std::vector& ordered_acc_op_nodes, + const HashMap& op_node2global_order) { + // TODO(chengcheng); +} + Maybe LogicalChainPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const { std::vector ordered_op_nodes; HashMap op_node2global_order; + HashMap mut_op_name2conf; // TODO(chengcheng) : better order for memory. std::shared_ptr seed_time_shape = std::make_shared(Shape({1, 1})); op_graph.TopoForEachNodeWithCtrlEdge([&](const OpNode* node) { @@ -193,532 +243,90 @@ Maybe LogicalChainPass::Apply(const OpGraph& op_graph, JobBuilder* job_bui if (this_time_shape->elem_cnt() > seed_time_shape->elem_cnt()) { seed_time_shape = this_time_shape; } + mut_op_name2conf.emplace(node->op().op_name(), node->op().op_conf()); }); VLOG(2) << " seed time shape = " << seed_time_shape->ToString(); std::vector> logical_chains; - GetChainsWithTimeShape(&logical_chains, op_graph, ordered_op_nodes, seed_time_shape); + GetLogicalChainsWithTimeShape(&logical_chains, op_graph, ordered_op_nodes, seed_time_shape); if (logical_chains.size() == 0) { return Maybe::Ok(); } int64_t logical_chain_id = 0; - auto NewLogicalChainId = [&]() { return logical_chain_id++}; + auto NewLogicalChainId = [&]() { return logical_chain_id++; }; auto CmpOpNodeOrder = [&](const OpNode* lhs, const OpNode* rhs) { return op_node2global_order.at(lhs) < op_node2global_order.at(rhs); }; + auto CmpLogicalChainOrder = [&](const std::shared_ptr& lhs, + const std::shared_ptr& rhs) { + return lhs->begin_op_global_order < rhs->begin_op_global_order; + }; auto IsReachable = op_graph.MakePredicatorIsOpNameDataOrCtrlReachable(); HashMap placement2logical_chains; for (const auto& origin_logical_chain : logical_chains) { - const OpNode* rand_node = *subgraph.begin(); + const OpNode* rand_node = *origin_logical_chain.begin(); const ParallelDesc& this_parallel_desc = rand_node->parallel_desc(); std::string key = GenParallelConfKey(this_parallel_desc.parallel_conf()); - const std::shared_ptr& this_time_shape = GetOpNodeFastestTimeShape(rand_node); - auto it = placement2subgraphs.find(key); - if (it == placement2subgraphs.end()) { - it = placement2subgraphs.emplace(key, PlacementNcclSubGraghsInfo()).first; - auto& info = it->second; - info.seed_parallel_desc = &this_parallel_desc; - info.seed_time_shape = this_time_shape; - info.ordered_subgraph.emplace_back(std::make_shared()); - 
InitInsertNcclSubGraphInfoFromSet(info.ordered_subgraph.back(), subgraph, - op_node2global_order, CmpOpNodeOrder); - } else { - auto& info = it->second; - if (SharedPtrShapeEqual(info.seed_time_shape, this_time_shape)) { - CHECK(this_parallel_desc.EqualsIgnoringHierarchy(*info.seed_parallel_desc)); - std::shared_ptr nccl_subgraph_info = - std::make_shared(); - InitInsertNcclSubGraphInfoFromSet(nccl_subgraph_info, subgraph, op_node2global_order, - CmpOpNodeOrder); - CHECK_GT(info.ordered_subgraph.size(), 0); - const auto& first_graph = info.ordered_subgraph.front(); - const auto& last_graph = info.ordered_subgraph.back(); - int64_t first_order = first_graph->begin_op_global_order; - int64_t last_order = last_graph->end_op_global_order; - if (nccl_subgraph_info->end_op_global_order < first_order) { - if (IsReachable(nccl_subgraph_info->end_op->op().op_name(), - first_graph->begin_op->op().op_name())) { - info.ordered_subgraph.insert(info.ordered_subgraph.begin(), nccl_subgraph_info); - } - } else if (nccl_subgraph_info->begin_op_global_order > last_order) { - if (IsReachable(last_graph->end_op->op().op_name(), - nccl_subgraph_info->begin_op->op().op_name())) { - info.ordered_subgraph.emplace_back(nccl_subgraph_info); - } - } else { - auto before = info.ordered_subgraph.begin(); - auto next = before + 1; - while (next != info.ordered_subgraph.end()) { - if ((*before)->end_op_global_order < nccl_subgraph_info->begin_op_global_order - && nccl_subgraph_info->end_op_global_order < (*next)->begin_op_global_order) { - if (IsReachable((*before)->end_op->op().op_name(), - nccl_subgraph_info->begin_op->op().op_name()) - && IsReachable(nccl_subgraph_info->end_op->op().op_name(), - (*next)->begin_op->op().op_name())) { - info.ordered_subgraph.insert(next, nccl_subgraph_info); - } - break; - } - before = next; - next++; - } - } - } + auto it = placement2logical_chains.find(key); + if (it == placement2logical_chains.end()) { + it = placement2logical_chains.emplace(key, PlacementLogicalChainsInfo()).first; + it->second.seed_parallel_desc = &this_parallel_desc; } + auto& info = it->second; + info.ordered_logical_chains.emplace_back(std::make_shared()); + InitPlacementLogicalChainsInfoFromSet(info.ordered_logical_chains.back(), origin_logical_chain, + op_node2global_order, CmpOpNodeOrder); + } + + for (auto& pair : placement2logical_chains) { + std::sort(pair.second.ordered_logical_chains.begin(), pair.second.ordered_logical_chains.end(), + CmpLogicalChainOrder); } for (const OpNode* this_node : ordered_op_nodes) { if (IsAccOpNode(this_node)) { const ParallelDesc& this_parallel_desc = this_node->parallel_desc(); std::string key = GenParallelConfKey(this_parallel_desc.parallel_conf()); - auto it = placement2subgraphs.find(key); - if (it != placement2subgraphs.end()) { + auto it = placement2logical_chains.find(key); + if (it != placement2logical_chains.end()) { it->second.ordered_acc_op_nodes.emplace_back(this_node); } } + mut_op_name2conf.at(this_node->op().op_name()) + .set_logical_order(JUST(MapAt(op_node2global_order, this_node))); } - for (auto& pair : placement2subgraphs) { - PlacementNcclSubGraghsInfo& info = pair.second; - for (int i = 0; i < info.ordered_subgraph.size() - 1; i++) { - CHECK_LT(info.ordered_subgraph.at(i)->end_op_global_order, - info.ordered_subgraph.at(i + 1)->begin_op_global_order); - } - - // NOTE(chengcheng): insert nccl ops for each subgraph - uint32_t stream_offset = 0; - int64_t total_op_num = 0; - for (int i = 0; i < info.ordered_subgraph.size(); i++) { - auto& ordered_op_nodes 
= info.ordered_subgraph.at(i)->ordered_op_nodes; - InsertNcclLogicalOpsInSubGraph(op_graph, job_builder, ordered_op_nodes, IsReachable, i, - &stream_offset); - total_op_num += ordered_op_nodes.size(); - } - if (stream_offset >= 2 && total_op_num >= 1000) { - LOG(WARNING) << " In Graph: " << job_builder->job().job_conf().job_name() - << " Placement: " << pair.first << " the total_op_num = " << total_op_num - << " and has " << stream_offset - << " different nccl stream which is possible to trigger cuda stream kernel " - "launch upper limit." - << " So the nccl logical kernel will from async to sync exec, which may affect " - "performance."; - EagerNcclCommMgr* comm_mgr = CHECK_NOTNULL(Singleton::Get()); - comm_mgr->SetAsyncLaunchNcclLogicalKernel(false); + for (auto& pair : placement2logical_chains) { + auto& info = pair.second; + for (int i = 0; i < info.ordered_logical_chains.size() - 1; i++) { + CHECK_LT(JUST(VectorAt(info.ordered_logical_chains, i))->begin_op_global_order, + JUST(VectorAt(info.ordered_logical_chains, i + 1))->begin_op_global_order); } - // NOTE(chengcheng): insert acc for all subgraph with same placement group - const OpNode* bw_sink_op = info.ordered_subgraph.back()->end_op; + // NOTE(chengcheng): create logical chain after acc, and merge with first logical chain. const std::vector& ordered_acc_op_nodes = info.ordered_acc_op_nodes; - if (!ordered_acc_op_nodes.empty()) { - InsertBwSinkAccTickAndNcclLogicalOpsInPlacementGroupAfterAcc( - op_graph, job_builder, ordered_acc_op_nodes, op_node2global_order, bw_sink_op); - } - } - - return Maybe::Ok(); -} - - - - -bool IsOpEdgeAllowInsertNccl(const OpEdge* edge, - const std::shared_ptr& seed_time_shape) { - const OpNode* src_node = edge->src_node(); - const OpNode* dst_node = edge->dst_node(); - const ParallelDesc& src_parallel_desc = src_node->parallel_desc(); - return src_parallel_desc.device_type() == DeviceType::kCUDA - && src_parallel_desc.parallel_num() > 1 - && src_parallel_desc.EqualsIgnoringHierarchy(dst_node->parallel_desc()) - && SharedPtrShapeEqual(GetOpNodeFastestTimeShape(src_node), seed_time_shape) - && SharedPtrShapeEqual(GetOpNodeFastestTimeShape(dst_node), seed_time_shape); -} - -struct InsertedNcclInfo { - OperatorConf nccl_op_conf; - ParallelConf nccl_parallel_conf; - int64_t order; - const OpNode* src_node; - const OpNode* dst_node; - std::string debug_str; -}; - -void InsertNcclLogicalOpsAfterAcc(const OpGraph& op_graph, - const HashMap& op_node2global_order, - const std::vector& ordered_acc_op_nodes, - const std::string& bw_sink_tick_op_name, - HashMap* mut_consumer_name2op, - std::vector* nccl_op_confs, - std::vector* nccl_op_parallel_confs) { - HashSet visited; - std::shared_ptr seed_time_shape = GetOpNodeFastestTimeShape(ordered_acc_op_nodes.front()); - std::vector nccl_op_infos; - - std::vector ordered_after_acc_subgraph; - // NOTE(chengcheng): bfs for op_edge may create duplicated node. 
- HashSet after_acc_subgraph_nodes; - HashMap op2subgraph_order; - - for (const OpNode* acc : ordered_acc_op_nodes) { - std::queue queued_edges; - for (const OpEdge* op_edge : acc->out_edges()) { - if (visited.find(op_edge) == visited.end() - && IsOpEdgeAllowInsertNccl(op_edge, seed_time_shape)) { - queued_edges.push(op_edge); - CHECK(visited.insert(op_edge).second); - if (!IsAccOpNode(op_edge->dst_node())) { - after_acc_subgraph_nodes.insert(op_edge->dst_node()); - } - } - } - - auto NextEdgeNode2AfterAccSubGraph = [&](const OpEdge* next_edge, const OpNode* next_node) { - if (visited.find(next_edge) == visited.end() - && IsOpEdgeAllowInsertNccl(next_edge, seed_time_shape)) { - CHECK(visited.insert(next_edge).second); - queued_edges.push(next_edge); - if (!IsAccOpNode(next_node)) { after_acc_subgraph_nodes.insert(next_node); } - } - }; - - // bfs search each edge after acc allow insert nccl. try insert. - while (!queued_edges.empty()) { - const OpEdge* op_edge = queued_edges.front(); - queued_edges.pop(); - - for (const LogicalBlobId& lbi : op_edge->lbis()) { - const OpNode* src_node = op_edge->src_node(); - const OpNode* dst_node = op_edge->dst_node(); - const std::string& src_op_name = src_node->op().op_name(); - const std::string& dst_op_name = dst_node->op().op_name(); - OperatorConf nccl_op; - ParallelDesc src_reduced_parallel_desc = op_edge->src_node()->parallel_desc(); - ParallelDesc dst_reduced_parallel_desc = op_edge->dst_node()->parallel_desc(); - NdSbp src_reduced_nd_sbp; - NdSbp dst_reduced_nd_sbp; - if (!TryBuildNcclLogicalOpConf(&nccl_op, op_edge->src_node(), op_edge->dst_node(), lbi, - &src_reduced_parallel_desc, &dst_reduced_parallel_desc, - &src_reduced_nd_sbp, &dst_reduced_nd_sbp)) { - continue; - } - auto it = mut_consumer_name2op->find(dst_op_name); - if (it == mut_consumer_name2op->end()) { - auto ret_pair = mut_consumer_name2op->emplace(dst_op_name, dst_node->op().op_conf()); - CHECK(ret_pair.second); - it = ret_pair.first; - } - // insert nccl op - user_op::UserOpConfWrapper nccl_op_wrapper(nccl_op); - for (const std::string& ibn : op_edge->lbi2ibns().at(lbi)) { - std::string old_lbn = ReplaceInputLbnInOpCustomizedConf(&(it->second), ibn, - nccl_op_wrapper.output("out", 0)); - } - - InsertedNcclInfo nccl_op_info; - nccl_op_info.nccl_op_conf = nccl_op; - nccl_op_info.nccl_parallel_conf = src_reduced_parallel_desc.parallel_conf(); - nccl_op_info.order = op_node2global_order.at(src_node); - nccl_op_info.src_node = src_node; - nccl_op_info.dst_node = dst_node; - nccl_op_info.debug_str = - (" After ACC insert nccl op: " + nccl_op.name() + " from [" + src_op_name - + ", sbp=" + NdSbpToString(src_node->NdSbp4Lbi(lbi)) + "] to [" + dst_op_name - + ", sbp=" + NdSbpToString(dst_node->NdSbp4Lbi(lbi)) - + ", src_order=" + std::to_string(nccl_op_info.order) + "]\n"); - nccl_op_infos.emplace_back(nccl_op_info); - } - - // NOTE(chengcheng): BFS for all edges and nodes after acc. 
- for (const OpEdge* dst_node_out_edge : op_edge->dst_node()->out_edges()) { - NextEdgeNode2AfterAccSubGraph(dst_node_out_edge, dst_node_out_edge->dst_node()); - } - for (const OpEdge* dst_node_in_edge : op_edge->dst_node()->in_edges()) { - NextEdgeNode2AfterAccSubGraph(dst_node_in_edge, dst_node_in_edge->src_node()); - } - for (const OpEdge* src_node_out_edge : op_edge->src_node()->out_edges()) { - NextEdgeNode2AfterAccSubGraph(src_node_out_edge, src_node_out_edge->dst_node()); - } - for (const OpEdge* src_node_in_edge : op_edge->src_node()->in_edges()) { - NextEdgeNode2AfterAccSubGraph(src_node_in_edge, src_node_in_edge->src_node()); - } - } - } - - for (const auto* node : after_acc_subgraph_nodes) { ordered_after_acc_subgraph.push_back(node); } - - CHECK_EQ(after_acc_subgraph_nodes.size(), ordered_after_acc_subgraph.size()); - - std::sort(nccl_op_infos.begin(), nccl_op_infos.end(), - [](const InsertedNcclInfo& lhs, const InsertedNcclInfo& rhs) { - return lhs.order < rhs.order; - }); - - std::sort(ordered_after_acc_subgraph.begin(), ordered_after_acc_subgraph.end(), - [&](const OpNode* lhs, const OpNode* rhs) { - return op_node2global_order.at(lhs) < op_node2global_order.at(rhs); - }); - - auto IsReachable = op_graph.MakePredicatorIsOpNameDataOrCtrlReachable(); - - for (int64_t i = 0; i < ordered_after_acc_subgraph.size(); ++i) { - op2subgraph_order.emplace(ordered_after_acc_subgraph.at(i), i); - } - - for (int64_t i = 1; i < ordered_after_acc_subgraph.size(); ++i) { - const OpNode* this_node = ordered_after_acc_subgraph.at(i); - const OpNode* pre_node = ordered_after_acc_subgraph.at(i - 1); - const std::string& this_op_name = this_node->op().op_name(); - const std::string& pre_op_name = pre_node->op().op_name(); - // build ctrl edge if need. - if (!IsReachable(pre_op_name, this_op_name)) { - auto it = mut_consumer_name2op->find(this_op_name); - if (it == mut_consumer_name2op->end()) { - auto ret_pair = mut_consumer_name2op->emplace(this_op_name, this_node->op().op_conf()); - CHECK(ret_pair.second); - it = ret_pair.first; - } - OperatorConf* mut_op_conf = &(it->second); - mut_op_conf->add_ctrl_in_op_name(pre_op_name); + CreateAfterAccLogicalChain(info.after_acc_logical_chain, op_graph, ordered_acc_op_nodes, + op_node2global_order); } - } - for (int64_t i = 0; i < nccl_op_infos.size(); ++i) { - auto& info = nccl_op_infos.at(i); - if (i == 0) { - info.nccl_op_conf.add_ctrl_in_op_name(bw_sink_tick_op_name); - } else { - info.nccl_op_conf.add_ctrl_in_op_name(nccl_op_infos.at(i - 1).nccl_op_conf.name()); - } - - nccl_op_confs->emplace_back(info.nccl_op_conf); - nccl_op_parallel_confs->emplace_back(info.nccl_parallel_conf); - VLOG(3) << info.debug_str; - - // NOTE(chengcheng): Try add ctrl between nccl and src op next node for strict exec order. 
- auto src_op_it = op2subgraph_order.find(info.src_node); - if (src_op_it != op2subgraph_order.end()) { - const int64_t src_sub_order = src_op_it->second; - const int64_t next_sub_order = src_sub_order + 1; - if (next_sub_order < ordered_after_acc_subgraph.size()) { - const OpNode* next_op = ordered_after_acc_subgraph.at(next_sub_order); - const std::string& next_op_name = next_op->op().op_name(); - const std::string& dst_op_name = info.dst_node->op().op_name(); - if (next_op_name != dst_op_name) { - if (mut_consumer_name2op->find(next_op_name) == mut_consumer_name2op->end()) { - CHECK(mut_consumer_name2op->emplace(next_op_name, next_op->op().op_conf()).second); - } - // NOTE(chengcheng): MUST add ctrl edge for strict exec orde - mut_consumer_name2op->at(next_op_name).add_ctrl_in_op_name(info.nccl_op_conf.name()); - } + for (auto& logical_chain : info.ordered_logical_chains) { + logical_chain->logical_chain_id = NewLogicalChainId(); + for (const OpNode* op_node : logical_chain->ordered_op_nodes) { + JUST(MapAt(mut_op_name2conf, op_node->op().op_name())) + .set_logical_chain_id(logical_chain->logical_chain_id); } } } -} - -struct InsertNcclSubGraph { - std::vector ordered_op_nodes; - int64_t begin_op_global_order; - int64_t end_op_global_order; - const OpNode* begin_op; - const OpNode* end_op; -}; - -struct PlacementNcclSubGraghsInfo { - std::vector> ordered_subgraph; - std::vector ordered_acc_op_nodes; - const ParallelDesc* seed_parallel_desc; - std::shared_ptr seed_time_shape; -}; - -void InitInsertNcclSubGraphInfoFromSet( - std::shared_ptr nccl_subgraph_info, const HashSet& subgraph, - const HashMap& op_node2global_order, - const std::function& CmpOpNodeOrder) { - auto* subgraph_ordered_nodes = &nccl_subgraph_info->ordered_op_nodes; - subgraph_ordered_nodes->assign(subgraph.begin(), subgraph.end()); - std::sort(subgraph_ordered_nodes->begin(), subgraph_ordered_nodes->end(), CmpOpNodeOrder); - nccl_subgraph_info->begin_op = subgraph_ordered_nodes->front(); - nccl_subgraph_info->end_op = subgraph_ordered_nodes->back(); - nccl_subgraph_info->begin_op_global_order = op_node2global_order.at(nccl_subgraph_info->begin_op); - nccl_subgraph_info->end_op_global_order = op_node2global_order.at(nccl_subgraph_info->end_op); - CHECK(nccl_subgraph_info->begin_op != nccl_subgraph_info->end_op); - CHECK_LT(nccl_subgraph_info->begin_op_global_order, nccl_subgraph_info->end_op_global_order); -} - -constexpr uint32_t kMaxNcclComputeStreamCount = 8; - -std::string GetStreamIndexName(uint32_t id) { return "NCCL_COMPUTE_" + std::to_string(id); } - -void InsertNcclLogicalOpsInSubGraph( - const OpGraph& op_graph, JobBuilder* job_builder, - const std::vector& subgraph_order, - const std::function& IsReachable, - const int32_t subgraph_id_in_same_placement_group, uint32_t* stream_offset) { - HashMap node2subgraph_order; - node2subgraph_order.reserve(subgraph_order.size()); - for (int64_t i = 0; i < subgraph_order.size(); ++i) { - CHECK(node2subgraph_order.emplace(subgraph_order.at(i), i).second); - } - - if (Singleton::Get()->enable_debug_mode()) { - VLOG(3) << " Try insert nccl logical ops into job: " << job_builder->job().job_conf().job_name() - << ". Begin...\n"; - } - - HashSet mut_op_names; - const OpNode* first_node = subgraph_order.at(0); - HashMap subgraph_op_name2conf; - subgraph_op_name2conf.emplace(first_node->op().op_name(), first_node->op().op_conf()); - - // add ctrl for strict order. 
- for (int64_t i = 1; i < subgraph_order.size(); ++i) { - const OpNode* this_node = subgraph_order.at(i); - const OpNode* pre_node = subgraph_order.at(i - 1); - const std::string& this_op_name = this_node->op().op_name(); - const std::string& pre_op_name = pre_node->op().op_name(); - CHECK(subgraph_op_name2conf.emplace(this_op_name, this_node->op().op_conf()).second); - // build ctrl edge if need. - if (!IsReachable(pre_op_name, this_op_name)) { - subgraph_op_name2conf.at(this_op_name).add_ctrl_in_op_name(pre_op_name); - mut_op_names.insert(this_op_name); - } - } - - std::vector nccl_op_confs; - std::vector nccl_op_parallel_confs; - // NOTE(chengcheng): ONLY support insert nccl to dst for memory. - InsertNcclLogicalOpsAsCloseAsPossibleToDstNode(&subgraph_op_name2conf, &mut_op_names, - &nccl_op_confs, &nccl_op_parallel_confs, - subgraph_order, node2subgraph_order); - if (Singleton::Get()->enable_debug_mode()) { - VLOG(3) << " Try insert nccl logical ops into job: " << job_builder->job().job_conf().job_name() - << ". ...End\n\n"; - } - - // NOTE(chengcheng): For NCCL logical correct exec order in pipeline multi-subgraph. - do { - if (nccl_op_confs.empty()) { break; } - int64_t nccl_compute_stream_id = *stream_offset; - if (nccl_compute_stream_id >= kMaxNcclComputeStreamCount) { - break; // NOTE(chengcheng): ONLY support kMaxNcclComputeStreamCount insert nccl subgraphs. - } - std::string stream_index_name = GetStreamIndexName(nccl_compute_stream_id); + // NOTE(chengcheng): update global order and chain id for ops. + for (const auto& pair : mut_op_name2conf) { JUST(job_builder->MutOpOnlyOnce(pair.second)); } - // NOTE(chengcheng): set ALL subgraph op and ALL nccl op stream index. - for (auto& pair : subgraph_op_name2conf) { - mut_op_names.insert(pair.first); - pair.second.set_stream_name_hint(stream_index_name); - } - for (auto& nccl_op : nccl_op_confs) { nccl_op.set_stream_name_hint(stream_index_name); } - (*stream_offset)++; - } while (false); - - std::vector mut_op_confs; - mut_op_confs.reserve(mut_op_names.size()); - for (const std::string& mut_op_name : mut_op_names) { - mut_op_confs.emplace_back(subgraph_op_name2conf.at(mut_op_name)); - } - job_builder->MutOpsOnlyOnce(mut_op_confs); - - CHECK_EQ(nccl_op_confs.size(), nccl_op_parallel_confs.size()); - for (int64_t i = 0; i < nccl_op_confs.size(); ++i) { - CHECK_JUST(job_builder->AddOp(nccl_op_parallel_confs.at(i), nccl_op_confs.at(i))); - } -} - -void InsertBwSinkAccTickAndNcclLogicalOpsInPlacementGroupAfterAcc( - const OpGraph& op_graph, JobBuilder* job_builder, - const std::vector& ordered_acc_op_nodes, - const HashMap& op_node2global_order, const OpNode* bw_sink_op) { - const OpNode* first_acc_op = ordered_acc_op_nodes.front(); - std::shared_ptr time_shape_before_acc = GetOpNodeFastestTimeShape(bw_sink_op); - std::shared_ptr time_shape_after_acc = GetOpNodeFastestTimeShape(first_acc_op); - VLOG(3) << " Find acc ops (num=" << ordered_acc_op_nodes.size() - << ") in Job: " << job_builder->job().job_conf().job_name() - << ", we will try insert special identity and ctrl for " - << " UNSAFE handle ALL nccl ops between different time shape: " - << time_shape_before_acc->DebugStr() << "->acc->" << time_shape_after_acc->DebugStr() - << "\n\n"; - CHECK_GT(time_shape_before_acc->elem_cnt(), time_shape_after_acc->elem_cnt()); - CHECK_EQ(time_shape_before_acc->elem_cnt() % time_shape_after_acc->elem_cnt(), 0); - - for (const OpNode* acc : ordered_acc_op_nodes) { - CHECK(SharedPtrShapeEqual(time_shape_before_acc, GetOpNodeInputTimeShape(acc))); - 
CHECK(SharedPtrShapeEqual(time_shape_after_acc, GetOpNodeFastestTimeShape(acc))); - } - - // NOTE(chengcheng): insert acc_tick after bw_sink_op, and this tick op conf will control - // after_acc_nccl_ops start. - const auto& obns = bw_sink_op->op().output_bns(); - CHECK(!obns.empty()); - const std::string bw_sink_op_out_lbn = - GenLogicalBlobName(bw_sink_op->op().BnInOp2Lbi(obns.Get(0))); - VLOG(3) << " bw_sink_op : " << bw_sink_op->op().op_conf().DebugString(); - - user_op::UserOpConfWrapper cast_to_tick_op = - user_op::UserOpConfWrapperBuilder("System-CastToTick-" + NewUniqueId()) - .OpTypeName("cast_to_tick") - .Input("in", bw_sink_op_out_lbn) - .Output("out") - .Build(); - - OperatorConf bw_sink_acc_tick_conf; - bw_sink_acc_tick_conf.set_name(std::string("System-BwSinkTick-AccTick_") + NewUniqueId()); - auto* acc_conf = bw_sink_acc_tick_conf.mutable_acc_tick_conf(); - acc_conf->set_one(cast_to_tick_op.output("out", 0)); - acc_conf->set_acc("acc"); - acc_conf->set_max_acc_num(time_shape_before_acc->elem_cnt() / time_shape_after_acc->elem_cnt()); - - OperatorConf bw_sink_final_tick_conf; - bw_sink_final_tick_conf.set_name(std::string("System-BwSinkFinalTick-DeviceTick_") - + NewUniqueId()); - auto* tick_conf = bw_sink_final_tick_conf.mutable_device_tick_conf(); - tick_conf->add_tick(GenLogicalBlobName(bw_sink_acc_tick_conf.name(), "acc")); - tick_conf->set_out("out"); - - // insert nccl ops after acc - std::vector after_acc_nccl_op_confs; - std::vector after_acc_nccl_parallel_confs; - HashMap mut_consumer_name2op; - - InsertNcclLogicalOpsAfterAcc(op_graph, op_node2global_order, ordered_acc_op_nodes, - bw_sink_final_tick_conf.name(), &mut_consumer_name2op, - &after_acc_nccl_op_confs, &after_acc_nccl_parallel_confs); - - if (after_acc_nccl_op_confs.empty()) { - CHECK(after_acc_nccl_parallel_confs.empty()); - CHECK(mut_consumer_name2op.empty()); - } else { - // insert bw sink acc tick ops - CHECK_JUST( - job_builder->AddOp(bw_sink_op->parallel_desc().parallel_conf(), cast_to_tick_op.op_conf())); - VLOG(3) << " Insert cast_to_tick_op : " << cast_to_tick_op.op_conf().DebugString(); - - CHECK_JUST( - job_builder->AddOp(bw_sink_op->parallel_desc().parallel_conf(), bw_sink_acc_tick_conf)); - VLOG(3) << " Insert bw_sink_acc_tick_op : " << bw_sink_acc_tick_conf.DebugString(); - - CHECK_JUST( - job_builder->AddOp(bw_sink_op->parallel_desc().parallel_conf(), bw_sink_final_tick_conf)); - VLOG(3) << " Insert bw_sink_final_tick_op : " << bw_sink_final_tick_conf.DebugString(); - - // insert nccl ops after acc - for (const auto& pair : mut_consumer_name2op) { - CHECK_JUST(job_builder->MutOpOnlyOnce(pair.second)); - } - CHECK_EQ(after_acc_nccl_op_confs.size(), after_acc_nccl_parallel_confs.size()); - for (int64_t i = 0; i < after_acc_nccl_op_confs.size(); ++i) { - CHECK_JUST( - job_builder->AddOp(after_acc_nccl_parallel_confs.at(i), after_acc_nccl_op_confs.at(i))); - } - } + return Maybe::Ok(); } - - } // namespace REGISTER_JOB_PASS("LogicalChainPass", LogicalChainPass); From d3c9b093c6e247803e50bca0b4c3100325bc962c Mon Sep 17 00:00:00 2001 From: chengtbf <472491134@qq.com> Date: Thu, 18 Aug 2022 09:52:56 +0000 Subject: [PATCH 05/66] fix compile --- oneflow/core/common/env_var/debug_mode.h | 3 ++ oneflow/core/graph/task_graph.cpp | 54 ++----------------- .../core/job_rewriter/logical_chain_pass.cpp | 7 ++- 3 files changed, 10 insertions(+), 54 deletions(-) diff --git a/oneflow/core/common/env_var/debug_mode.h b/oneflow/core/common/env_var/debug_mode.h index 9f2790e49c8..4a9dcc546b2 100644 --- 
a/oneflow/core/common/env_var/debug_mode.h +++ b/oneflow/core/common/env_var/debug_mode.h @@ -25,6 +25,9 @@ DEFINE_ENV_BOOL(ONEFLOW_DEBUG, false); inline bool IsInDebugMode() { return EnvBool() || EnvBool(); } +DEFINE_ENV_BOOL(ENABLE_LOGICAL_CHAIN, true); +inline bool EnableLogicalChain() { return EnvBool(); } + } // namespace oneflow #endif // ONEFLOW_CORE_COMMON_ENV_VAR_DEBUG_MODE_H_ diff --git a/oneflow/core/graph/task_graph.cpp b/oneflow/core/graph/task_graph.cpp index 3aa281713c3..d8a7be6ae35 100644 --- a/oneflow/core/graph/task_graph.cpp +++ b/oneflow/core/graph/task_graph.cpp @@ -15,6 +15,7 @@ limitations under the License. */ #include "oneflow/core/graph/task_graph.h" #include "oneflow/core/common/util.h" +#include "oneflow/core/common/env_var/debug_mode.h" #include "oneflow/core/graph/inplace_lbi_graph.h" #include "oneflow/core/register/blob_desc.h" #include "oneflow/core/job/global_for.h" @@ -563,57 +564,10 @@ void TaskGraph::SetOrderInGraphForEachNode() { TopoForEachNode(SetOrderInGraph); } -void GetLogicalChains(std::vector>* ret, const OpGraph& op_graph, - const std::vector& order) { - HashSet visited; - - for (const OpNode* seed_node : order) { - if (visited.find(seed_node) != visited.end()) { continue; } - CHECK(visited.insert(seed_node).second); - const ParallelDesc& seed_parallel_desc = seed_node->parallel_desc(); - // NOTE(chengcheng): ONLY consider GPU op and parallel num > 1. - if (seed_parallel_desc.device_type() != DeviceType::kCUDA) { continue; } - if (seed_parallel_desc.parallel_num() <= 1) { continue; } - if (IsBreakpointOpNode(seed_node)) { continue; } - - HashSet this_subgraph; - std::queue queued_nodes; - - std::shared_ptr seed_time_shape = GetOpNodeTimeShape(seed_node); - queued_nodes.push(seed_node); - while (!queued_nodes.empty()) { - const OpNode* cur_node = queued_nodes.front(); - queued_nodes.pop(); - - CHECK(cur_node->parallel_desc().EqualsIgnoringHierarchy(seed_parallel_desc)); - CHECK(this_subgraph.insert(cur_node).second); - - cur_node->ForEachNodeOnInOutEdge([&](const OpNode* next_node) { - if (visited.find(next_node) == visited.end() && (!IsBreakpointOpNode(next_node)) - && next_node->parallel_desc().EqualsIgnoringHierarchy(seed_parallel_desc) - && SharedPtrShapeEqual(GetOpNodeTimeShape(next_node), seed_time_shape)) { - CHECK(visited.insert(next_node).second); - queued_nodes.push(next_node); - } - }); - } - - if (this_subgraph.size() > 1) { - ret->emplace_back(HashSet()); - ret->back().swap(this_subgraph); - } - } - - std::sort(ret->begin(), ret->end(), - [](const HashSet& lhs, const HashSet& rhs) { - return lhs.size() > rhs.size(); - }); -} - void TaskGraph::MergeChain() { - if (EnvBool()) { - for (auto* this_node : ordered_task_nodes_) { - const auto* comp_node = dynamic_cast(node); + if (EnableLogicalChain()) { + for (TaskNode* this_node : ordered_task_nodes_) { + CompTaskNode* comp_node = dynamic_cast(this_node); const int64_t logical_chain_id = comp_node->op()->op_conf().logical_chain_id(); if (logical_chain_id != -1) { this_node->set_chain_id(logical_chain_id); } } diff --git a/oneflow/core/job_rewriter/logical_chain_pass.cpp b/oneflow/core/job_rewriter/logical_chain_pass.cpp index 16e8341a031..70421ceb5e1 100644 --- a/oneflow/core/job_rewriter/logical_chain_pass.cpp +++ b/oneflow/core/job_rewriter/logical_chain_pass.cpp @@ -27,11 +27,10 @@ limitations under the License. 
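// EDITOR NOTE (usage hint, not part of this patch): the pass below is gated by the
// ENABLE_LOGICAL_CHAIN flag added to debug_mode.h above. Assuming DEFINE_ENV_BOOL reads
// the environment variable of the same name (as it does for ONEFLOW_DEBUG), something
// like `ENABLE_LOGICAL_CHAIN=false python train.py` should disable logical chain merging
// and fall back to the previous chain-merge behavior.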
#include "oneflow/core/vm/symbol_storage.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/common/env_var/env_var.h" +#include "oneflow/core/common/env_var/debug_mode.h" namespace oneflow { -DEFINE_ENV_BOOL(ENABLE_LOGICAL_CHAIN, true); - namespace { class LogicalChainPass final : public JobPass { @@ -47,7 +46,7 @@ class LogicalChainPass final : public JobPass { return Apply(op_graph, &job_builder); } - bool IsEnabled(const JobPassCtx& ctx) const { return EnvBool(); } + bool IsEnabled(const JobPassCtx& ctx) const { return EnableLogicalChain(); } Maybe Apply(const OpGraph& op_graph, JobBuilder* job_builder) const; }; @@ -294,7 +293,7 @@ Maybe LogicalChainPass::Apply(const OpGraph& op_graph, JobBuilder* job_bui it->second.ordered_acc_op_nodes.emplace_back(this_node); } } - mut_op_name2conf.at(this_node->op().op_name()) + JUST(MapAt(mut_op_name2conf, this_node->op().op_name())) .set_logical_order(JUST(MapAt(op_node2global_order, this_node))); } From 6f7ed2ca1480420fea63fb1bd25ebaf97d07b64e Mon Sep 17 00:00:00 2001 From: chengtbf <472491134@qq.com> Date: Fri, 19 Aug 2022 04:36:56 +0000 Subject: [PATCH 06/66] logical chain runnable --- oneflow/core/graph/task_graph.cpp | 5 ++++- oneflow/core/job_rewriter/job_completer.cpp | 2 ++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/oneflow/core/graph/task_graph.cpp b/oneflow/core/graph/task_graph.cpp index d8a7be6ae35..7362ccb3ab9 100644 --- a/oneflow/core/graph/task_graph.cpp +++ b/oneflow/core/graph/task_graph.cpp @@ -192,7 +192,8 @@ MakePredicatorIsLbiAllConsumersReachable( IsOpNameDataOrCtrlReachable) { auto IsDataOrCtrlReachable = [IsOpNameDataOrCtrlReachable](const TaskNode* src_node, const TaskNode* dst_node) -> bool { - if (src_node->chain_id() == dst_node->chain_id() + if (src_node->chain_id() != -1 && dst_node->chain_id() != -1 + && src_node->chain_id() == dst_node->chain_id() && src_node->order_in_graph() <= dst_node->order_in_graph()) { return true; } @@ -568,6 +569,7 @@ void TaskGraph::MergeChain() { if (EnableLogicalChain()) { for (TaskNode* this_node : ordered_task_nodes_) { CompTaskNode* comp_node = dynamic_cast(this_node); + if (!comp_node) { continue; } const int64_t logical_chain_id = comp_node->op()->op_conf().logical_chain_id(); if (logical_chain_id != -1) { this_node->set_chain_id(logical_chain_id); } } @@ -595,6 +597,7 @@ void TaskGraph::BuildCtrlRegstDescInSameChain() { for (auto* node : ordered_task_nodes_) { if (IsConnectToTickOp(node)) { continue; } int64_t chain_id = node->chain_id(); + if (chain_id == -1) { continue; } // NOTE(chengcheng): skip chain id default -1. 
auto iter = chain_id2node.find(chain_id); if (iter == chain_id2node.end()) { CHECK(chain_id2node.emplace(chain_id, node).second); diff --git a/oneflow/core/job_rewriter/job_completer.cpp b/oneflow/core/job_rewriter/job_completer.cpp index 4c32127a84a..1252b4d18bf 100644 --- a/oneflow/core/job_rewriter/job_completer.cpp +++ b/oneflow/core/job_rewriter/job_completer.cpp @@ -129,6 +129,8 @@ Maybe JobCompleter::Complete(Job* job) const { JUST(JobPass4Name("DumpBlobParallelConfPass")(job, &job_pass_ctx)); } #endif // WITH_CUDA + JUST(JobPass4Name("LogicalChainPass")(job, &job_pass_ctx)); + JUST(JobPass4Name("DumpBlobParallelConfPass")(job, &job_pass_ctx)); JUST(CheckOpGraph(OpGraph(*job))); return Maybe::Ok(); } From dccee3f4dd7bb2878f8514c5164f96c46a180613 Mon Sep 17 00:00:00 2001 From: chengtbf <472491134@qq.com> Date: Fri, 19 Aug 2022 08:31:53 +0000 Subject: [PATCH 07/66] fix bug of logical chain dp --- oneflow/core/graph/task_graph.cpp | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/oneflow/core/graph/task_graph.cpp b/oneflow/core/graph/task_graph.cpp index 7362ccb3ab9..4796b3978eb 100644 --- a/oneflow/core/graph/task_graph.cpp +++ b/oneflow/core/graph/task_graph.cpp @@ -593,14 +593,18 @@ void TaskGraph::MergeChain() { } void TaskGraph::BuildCtrlRegstDescInSameChain() { - HashMap chain_id2node; + auto GenPhysicalChainId = [](TaskNode* node) { + // NOTE(chengcheng): different rank cannot use same chain id for bad ctrl link. + return (node->chain_id() << 31) | (node->machine_id()); + }; + HashMap physical_chain_id2node; for (auto* node : ordered_task_nodes_) { if (IsConnectToTickOp(node)) { continue; } - int64_t chain_id = node->chain_id(); - if (chain_id == -1) { continue; } // NOTE(chengcheng): skip chain id default -1. - auto iter = chain_id2node.find(chain_id); - if (iter == chain_id2node.end()) { - CHECK(chain_id2node.emplace(chain_id, node).second); + if (node->chain_id() == -1) { continue; } // NOTE(chengcheng): skip chain id default -1. 
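    // EDITOR NOTE (illustrative sketch, not part of this patch): GenPhysicalChainId above
    // packs the rank into the low bits so that two tasks only land in the same ctrl-regst
    // chain when both their logical chain id and their machine id match. With hypothetical
    // values, and assuming chain_id stays well below 2^31 so the fields cannot overlap:
    //   chain_id = 7, machine_id = 3
    //   physical_chain_id = (7 << 31) | 3 = 0x380000003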
+ int64_t physical_chain_id = GenPhysicalChainId(node); + auto iter = physical_chain_id2node.find(physical_chain_id); + if (iter == physical_chain_id2node.end()) { + CHECK(physical_chain_id2node.emplace(physical_chain_id, node).second); } else { TaskNode* src_node = iter->second; TaskNode* dst_node = node; From 2e750d364b7d2d705a011f8eb73a311801d2621c Mon Sep 17 00:00:00 2001 From: chengtbf <472491134@qq.com> Date: Sat, 20 Aug 2022 08:38:03 +0000 Subject: [PATCH 08/66] Part 1 : AfterGradAccChain --- .../core/job_rewriter/logical_chain_pass.cpp | 101 +++++++++++++++--- 1 file changed, 87 insertions(+), 14 deletions(-) diff --git a/oneflow/core/job_rewriter/logical_chain_pass.cpp b/oneflow/core/job_rewriter/logical_chain_pass.cpp index 70421ceb5e1..8401d1ea281 100644 --- a/oneflow/core/job_rewriter/logical_chain_pass.cpp +++ b/oneflow/core/job_rewriter/logical_chain_pass.cpp @@ -119,7 +119,7 @@ bool NeedInsertBoxingBetweenOpNodes(const OpNode* a_node, const OpNode* b_node, } void GetLogicalChainsWithTimeShape(std::vector>* ret, - const OpGraph& op_graph, const std::vector& order, + const std::vector& order, const std::shared_ptr& seed_time_shape) { HashSet visited; for (const OpNode* seed_node : order) { @@ -223,10 +223,50 @@ void InitPlacementLogicalChainsInfoFromSet( } void CreateAfterAccLogicalChain(const std::shared_ptr& after_acc_logical_chain, - const OpGraph& op_graph, const std::vector& ordered_acc_op_nodes, - const HashMap& op_node2global_order) { - // TODO(chengcheng); + const ParallelDesc& seed_parallel_desc) { + // Meta time shape (1, 1) + std::shared_ptr seed_time_shape = std::make_shared(Shape({1, 1})); + HashSet visited; + HashSet after_acc_chain_ops; + std::queue queued_nodes; + auto SearchToNextNode = [&](const OpNode* cur_node, const OpNode* next_node, const OpEdge* edge) { + if (visited.find(next_node) == visited.end() && (!IsBreakpointOpNode(next_node)) + && next_node->parallel_desc().EqualsIgnoringHierarchy(seed_parallel_desc) + && SharedPtrShapeEqual(GetOpNodeFastestTimeShape(next_node), seed_time_shape) + && !NeedInsertBoxingBetweenOpNodes(cur_node, next_node, edge)) { + CHECK(visited.insert(next_node).second); + queued_nodes.push(next_node); + } + }; + + for (const OpNode* acc_node : ordered_acc_op_nodes) { + for (const OpEdge* out_edge : acc_node->out_edges()) { + const OpNode* seed_node = out_edge->dst_node(); + SearchToNextNode(acc_node, seed_node, out_edge); + } + } + + while (!queued_nodes.empty()) { + const OpNode* cur_node = queued_nodes.front(); + queued_nodes.pop(); + + CHECK(after_acc_chain_ops.insert(cur_node).second); + + for (const OpEdge* in_edge : cur_node->in_edges()) { + SearchToNextNode(cur_node, in_edge->src_node(), in_edge); + } + for (const OpEdge* out_edge : cur_node->out_edges()) { + SearchToNextNode(cur_node, out_edge->dst_node(), out_edge); + } + } + + if (after_acc_chain_ops.size() > 1) { + for (const OpNode* node : after_acc_chain_ops) { + after_acc_logical_chain->ordered_op_nodes.push_back(node); + } + CHECK_EQ(after_acc_logical_chain->ordered_op_nodes.size(), after_acc_chain_ops.size()); + } } Maybe LogicalChainPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const { @@ -248,7 +288,7 @@ Maybe LogicalChainPass::Apply(const OpGraph& op_graph, JobBuilder* job_bui VLOG(2) << " seed time shape = " << seed_time_shape->ToString(); std::vector> logical_chains; - GetLogicalChainsWithTimeShape(&logical_chains, op_graph, ordered_op_nodes, seed_time_shape); + GetLogicalChainsWithTimeShape(&logical_chains, ordered_op_nodes, 
seed_time_shape); if (logical_chains.size() == 0) { return Maybe::Ok(); } int64_t logical_chain_id = 0; @@ -297,6 +337,26 @@ Maybe LogicalChainPass::Apply(const OpGraph& op_graph, JobBuilder* job_bui .set_logical_order(JUST(MapAt(op_node2global_order, this_node))); } + auto InsertCtrlEdgeInChain = [&](const std::vector& ordered_op_nodes) { + for (int64_t i = 1; i < ordered_op_nodes.size(); ++i) { + const OpNode* this_node = CHECK_JUST(VectorAt(ordered_op_nodes, i)); + const OpNode* prev_node = CHECK_JUST(VectorAt(ordered_op_nodes, i - 1)); + const std::string& this_op_name = this_node->op().op_name(); + const std::string& prev_op_name = prev_node->op().op_name(); + if (!IsReachable(prev_op_name, this_op_name)) { + CHECK_JUST(MapAt(mut_op_name2conf, this_op_name)).add_ctrl_in_op_name(prev_op_name); + } + } + }; + + auto InsertLogicalChainId = [&](const std::vector& ordered_op_nodes, + const int64_t logical_chain_id) { + for (const OpNode* op_node : ordered_op_nodes) { + CHECK_JUST(MapAt(mut_op_name2conf, op_node->op().op_name())) + .set_logical_chain_id(logical_chain_id); + } + }; + for (auto& pair : placement2logical_chains) { auto& info = pair.second; for (int i = 0; i < info.ordered_logical_chains.size() - 1; i++) { @@ -304,18 +364,31 @@ Maybe LogicalChainPass::Apply(const OpGraph& op_graph, JobBuilder* job_bui JUST(VectorAt(info.ordered_logical_chains, i + 1))->begin_op_global_order); } + for (auto& logical_chain : info.ordered_logical_chains) { + logical_chain->logical_chain_id = NewLogicalChainId(); + InsertLogicalChainId(logical_chain->ordered_op_nodes, logical_chain->logical_chain_id); + InsertCtrlEdgeInChain(logical_chain->ordered_op_nodes); + } + // NOTE(chengcheng): create logical chain after acc, and merge with first logical chain. const std::vector& ordered_acc_op_nodes = info.ordered_acc_op_nodes; if (!ordered_acc_op_nodes.empty()) { - CreateAfterAccLogicalChain(info.after_acc_logical_chain, op_graph, ordered_acc_op_nodes, - op_node2global_order); - } - - for (auto& logical_chain : info.ordered_logical_chains) { - logical_chain->logical_chain_id = NewLogicalChainId(); - for (const OpNode* op_node : logical_chain->ordered_op_nodes) { - JUST(MapAt(mut_op_name2conf, op_node->op().op_name())) - .set_logical_chain_id(logical_chain->logical_chain_id); + CreateAfterAccLogicalChain(info.after_acc_logical_chain, ordered_acc_op_nodes, + *info.seed_parallel_desc); + if (info.after_acc_logical_chain->ordered_op_nodes.size() > 1) { + info.after_acc_logical_chain->logical_chain_id = NewLogicalChainId(); + std::sort(info.after_acc_logical_chain->ordered_op_nodes.begin(), + info.after_acc_logical_chain->ordered_op_nodes.end(), CmpOpNodeOrder); + const auto& chain_order_ops = info.after_acc_logical_chain->ordered_op_nodes; + info.after_acc_logical_chain->begin_op = chain_order_ops.front(); + info.after_acc_logical_chain->end_op = chain_order_ops.back(); + info.after_acc_logical_chain->begin_op_global_order = + JUST(MapAt(op_node2global_order, chain_order_ops.front())); + info.after_acc_logical_chain->end_op_global_order = + JUST(MapAt(op_node2global_order, chain_order_ops.back())); + + InsertLogicalChainId(chain_order_ops, info.after_acc_logical_chain->logical_chain_id); + InsertCtrlEdgeInChain(chain_order_ops); } } } From 42e6f86f049acfd264d46df07776c70fab9f1094 Mon Sep 17 00:00:00 2001 From: chengtbf <472491134@qq.com> Date: Mon, 22 Aug 2022 07:53:17 +0000 Subject: [PATCH 09/66] fix bug of crush in acc chain infer --- oneflow/core/job_rewriter/logical_chain_pass.cpp | 1 + 1 file changed, 1 
insertion(+) diff --git a/oneflow/core/job_rewriter/logical_chain_pass.cpp b/oneflow/core/job_rewriter/logical_chain_pass.cpp index 8401d1ea281..44195a3bb72 100644 --- a/oneflow/core/job_rewriter/logical_chain_pass.cpp +++ b/oneflow/core/job_rewriter/logical_chain_pass.cpp @@ -373,6 +373,7 @@ Maybe LogicalChainPass::Apply(const OpGraph& op_graph, JobBuilder* job_bui // NOTE(chengcheng): create logical chain after acc, and merge with first logical chain. const std::vector& ordered_acc_op_nodes = info.ordered_acc_op_nodes; if (!ordered_acc_op_nodes.empty()) { + info.after_acc_logical_chain = std::make_shared(); CreateAfterAccLogicalChain(info.after_acc_logical_chain, ordered_acc_op_nodes, *info.seed_parallel_desc); if (info.after_acc_logical_chain->ordered_op_nodes.size() > 1) { From 7c35e1ae67ad0fe89cea48f7452d5835d845c14c Mon Sep 17 00:00:00 2001 From: chengtbf <472491134@qq.com> Date: Fri, 26 Aug 2022 06:35:29 +0000 Subject: [PATCH 10/66] AccCtrlTick Op/Task/Actor/Pass --- .../acc_ctrl_tick_compute_task_node.cpp | 59 +++++ oneflow/core/job/compiler.cpp | 1 + oneflow/core/job/job.proto | 5 + oneflow/core/job/job_conf.proto | 2 +- oneflow/core/job/plan_util.cpp | 4 + oneflow/core/job/plan_util.h | 1 + oneflow/core/job/task.proto | 7 +- .../insert_nccl_logical_op_pass.cpp | 3 + .../core/job_rewriter/logical_chain_pass.cpp | 29 +++ .../core/lazy/actor/acc_ctrl_tick_actor.cpp | 218 ++++++++++++++++++ oneflow/core/lazy/actor/register_slot.cpp | 7 + oneflow/core/lazy/actor/register_slot.h | 1 + oneflow/ir/include/OneFlow/OneFlowUserOps.td | 22 +- oneflow/user/ops/acc_ctrl_tick_op.cpp | 79 +++++++ 14 files changed, 432 insertions(+), 6 deletions(-) create mode 100644 oneflow/core/graph_impl/acc_ctrl_tick_compute_task_node.cpp create mode 100644 oneflow/core/lazy/actor/acc_ctrl_tick_actor.cpp create mode 100644 oneflow/user/ops/acc_ctrl_tick_op.cpp diff --git a/oneflow/core/graph_impl/acc_ctrl_tick_compute_task_node.cpp b/oneflow/core/graph_impl/acc_ctrl_tick_compute_task_node.cpp new file mode 100644 index 00000000000..80ead618b5f --- /dev/null +++ b/oneflow/core/graph_impl/acc_ctrl_tick_compute_task_node.cpp @@ -0,0 +1,59 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#include "oneflow/core/graph/compute_task_node.h" +#include "oneflow/core/graph/task_stream_index_manager.h" + +namespace oneflow { + +class AccCtrlTickCompTaskNode final : public CompTaskNode { + public: + OF_DISALLOW_COPY_AND_MOVE(AccCtrlTickCompTaskNode); + AccCtrlTickCompTaskNode() = default; + ~AccCtrlTickCompTaskNode() = default; + TaskType GetTaskType() const override { return TaskType::kAccCtrlTick; } + void ProduceAllRegstsAndBindEdges() override; + void ConsumeAllRegsts() override; + void BuildExecGphAndRegst() override; +}; + +void AccCtrlTickCompTaskNode::ProduceAllRegstsAndBindEdges() { + // std::shared_ptr regst = ProduceRegst("out", false); + // ForEachOutDataEdge([&](TaskEdge* edge) { edge->AddRegst("out", regst); }); +} + +void AccCtrlTickCompTaskNode::ConsumeAllRegsts() { + ConsumeRegst("in", SoleInDataEdge()->GetSoleRegst()); +} + +void AccCtrlTickCompTaskNode::BuildExecGphAndRegst() { + /* + std::shared_ptr in_regst = GetSoleConsumedRegst("in"); + std::shared_ptr out_regst = GetProducedRegst("out"); + std::shared_ptr op = this->op(); + ExecNode* exec_node = mut_exec_gph().NewNode(); + exec_node->mut_op() = op; + exec_node->BindBnWithRegst(op->SoleIbn(), in_regst); + out_regst->AddLbi(op->BnInOp2Lbi(op->SoleObn())); + exec_node->BindBnWithRegst(op->SoleObn(), out_regst); + exec_node->InferBlobDescs(parallel_ctx()); + */ +} + +REGISTER_COMP_TASK_STREAM_INDEX_GETTER(TaskType::kAccCtrlTick); + +REGISTER_USER_OP_COMP_TASK_NODE_TYPE("acc_ctrl_tick", AccCtrlTickCompTaskNode); + +} // namespace oneflow diff --git a/oneflow/core/job/compiler.cpp b/oneflow/core/job/compiler.cpp index 4fc11e0eb87..fb757c75216 100644 --- a/oneflow/core/job/compiler.cpp +++ b/oneflow/core/job/compiler.cpp @@ -105,6 +105,7 @@ void Compiler::Compile(Job* job, Plan* plan) const { (*job_id2job_conf)[GlobalJobDesc().job_id()] = GlobalJobDesc().job_conf(); // NOTE(chengcheng): infer mem blob id & set inplace & add ctrl IntraJobMemSharingUtil::InferMemBlockId4MemReusedRegst(plan, IsReachable); + PlanUtil::MergeMemBlockIdByLogicalChainId(plan, *job); PlanUtil::SetUniqueMemBlockId4UnreusedMemRegst(plan); Singleton::Delete(); } diff --git a/oneflow/core/job/job.proto b/oneflow/core/job/job.proto index 9be3edfb2b9..c6551ee51cc 100644 --- a/oneflow/core/job/job.proto +++ b/oneflow/core/job/job.proto @@ -25,6 +25,10 @@ message JobHelperConf { map op_name2arg_signature = 9; } +message MergedLogicalChainIdGroup { + repeated int64 logical_chain_id_list = 1; +} + message Job { optional DLNetConf net = 1; optional Placement placement = 2; @@ -32,4 +36,5 @@ message Job { optional JobParallelViewConf job_parallel_view_conf = 4; optional JobHelperConf helper = 5; map module_name2module_conf = 6; + repeated MergedLogicalChainIdGroup logical_chain_groups = 7; } diff --git a/oneflow/core/job/job_conf.proto b/oneflow/core/job/job_conf.proto index 1b7035877e1..6268ce4e4ed 100644 --- a/oneflow/core/job/job_conf.proto +++ b/oneflow/core/job/job_conf.proto @@ -167,7 +167,7 @@ message PredictConf { message MemoryAllocationAlgorithmConf { optional bool use_mem_size_first_algo = 1 [default = true]; optional bool use_mutual_exclusion_first_algo = 2 [default = true]; - optional bool use_time_line_algo = 3 [default = false]; + optional bool use_time_line_algo = 3 [default = true]; } message QatConfig { diff --git a/oneflow/core/job/plan_util.cpp b/oneflow/core/job/plan_util.cpp index cfbc19b5f83..25fc29d38ba 100644 --- a/oneflow/core/job/plan_util.cpp +++ b/oneflow/core/job/plan_util.cpp @@ -188,6 +188,10 @@ void 
GenChunkForMultiNNGraphMemoryReuseInMultiClient( } // namespace +void PlanUtil::MergeMemBlockIdByLogicalChainId(Plan* plan, const Job& job) { + // TODO(); +} + void PlanUtil::GenMemBlockAndChunkWithVariableOpNames4Plan( Plan* plan, const HashSet& variable_op_names) { HashMap> mem_block_id2mem_block; diff --git a/oneflow/core/job/plan_util.h b/oneflow/core/job/plan_util.h index 6c45db300e0..9de7142a629 100644 --- a/oneflow/core/job/plan_util.h +++ b/oneflow/core/job/plan_util.h @@ -28,6 +28,7 @@ namespace oneflow { struct PlanUtil { static RegstDescProto* GetSoleProducedDataRegst(TaskProto* task_proto); static std::function MakeGetterTaskProto4TaskId(const Plan& plan); + static void MergeMemBlockIdByLogicalChainId(Plan* plan, const Job& job); static void SetUniqueMemBlockId4UnreusedMemRegst(Plan* plan); static void GenMemBlockAndChunk4Plan(Plan* plan); static void GenMemBlockAndChunkWithVariableOpNames4Plan( diff --git a/oneflow/core/job/task.proto b/oneflow/core/job/task.proto index ef2ad9c4584..6432c5629c4 100644 --- a/oneflow/core/job/task.proto +++ b/oneflow/core/job/task.proto @@ -12,9 +12,10 @@ enum TaskType { kCopyCommNet = 13; kDeviceTick = 27; kPack = 30; - kUnpack = 32; - kRepeat = 34; - kAcc = 37; + kUnpack = 31; + kRepeat = 32; + kAcc = 33; + kAccCtrlTick = 34; kSrcSubsetTick = 38; kDstSubsetTick = 39; kSourceTick = 40; diff --git a/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp b/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp index 1679a8d3303..d8aff601abf 100644 --- a/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp +++ b/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp @@ -968,10 +968,12 @@ void InsertBwSinkAccTickAndNcclLogicalOpsInPlacementGroupAfterAcc( .OpTypeName("cast_to_tick") .Input("in", bw_sink_op_out_lbn) .Output("out") + .ScopeSymbolId(bw_sink_op->op().op_conf().scope_symbol_id()) .Build(); OperatorConf bw_sink_acc_tick_conf; bw_sink_acc_tick_conf.set_name(std::string("System-BwSinkTick-AccTick_") + NewUniqueId()); + bw_sink_acc_tick_conf.set_scope_symbol_id(bw_sink_op->op().op_conf().scope_symbol_id()); auto* acc_conf = bw_sink_acc_tick_conf.mutable_acc_tick_conf(); acc_conf->set_one(cast_to_tick_op.output("out", 0)); acc_conf->set_acc("acc"); @@ -980,6 +982,7 @@ void InsertBwSinkAccTickAndNcclLogicalOpsInPlacementGroupAfterAcc( OperatorConf bw_sink_final_tick_conf; bw_sink_final_tick_conf.set_name(std::string("System-BwSinkFinalTick-DeviceTick_") + NewUniqueId()); + bw_sink_final_tick_conf.set_scope_symbol_id(bw_sink_op->op().op_conf().scope_symbol_id()); auto* tick_conf = bw_sink_final_tick_conf.mutable_device_tick_conf(); tick_conf->add_tick(GenLogicalBlobName(bw_sink_acc_tick_conf.name(), "acc")); tick_conf->set_out("out"); diff --git a/oneflow/core/job_rewriter/logical_chain_pass.cpp b/oneflow/core/job_rewriter/logical_chain_pass.cpp index 44195a3bb72..4e81149a1f9 100644 --- a/oneflow/core/job_rewriter/logical_chain_pass.cpp +++ b/oneflow/core/job_rewriter/logical_chain_pass.cpp @@ -254,6 +254,7 @@ void CreateAfterAccLogicalChain(const std::shared_ptr& after_acc_l CHECK(after_acc_chain_ops.insert(cur_node).second); for (const OpEdge* in_edge : cur_node->in_edges()) { + // NOTE(chengcheng): maybe bad case for too early source op before repeat. 
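      // EDITOR NOTE (interpretation, not part of this patch): the "bad case" above
      // presumably means that walking backward through in-edges can reach source ops that
      // also have the (1, 1) meta time shape but sit before the repeat/acc boundary;
      // merging them into the after-acc chain would order them behind it unnecessarily.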
SearchToNextNode(cur_node, in_edge->src_node(), in_edge); } for (const OpEdge* out_edge : cur_node->out_edges()) { @@ -359,6 +360,7 @@ Maybe LogicalChainPass::Apply(const OpGraph& op_graph, JobBuilder* job_bui for (auto& pair : placement2logical_chains) { auto& info = pair.second; + CHECK_GE(info.ordered_logical_chains.size(), 1); for (int i = 0; i < info.ordered_logical_chains.size() - 1; i++) { CHECK_LT(JUST(VectorAt(info.ordered_logical_chains, i))->begin_op_global_order, JUST(VectorAt(info.ordered_logical_chains, i + 1))->begin_op_global_order); @@ -390,6 +392,33 @@ Maybe LogicalChainPass::Apply(const OpGraph& op_graph, JobBuilder* job_bui InsertLogicalChainId(chain_order_ops, info.after_acc_logical_chain->logical_chain_id); InsertCtrlEdgeInChain(chain_order_ops); + + // NOTE(chengcheng): + // 1.add acc ctrl tick between first chain src to acc chain sink for memory lock. + // 2.add acc tick between first chain sink to acc chain src for strict exec order. + const int64_t acc_num = job_builder->job().job_conf().num_gradient_accumulation_steps(); + CHECK_GT(acc_num, 1); + const OpNode* first_chain_src_op = info.ordered_logical_chains.front()->begin_op; + const auto& fcs_obns = first_chain_src_op->op().output_bns(); + CHECK(!fcs_obns.empty()); + const std::string& first_chain_src_out_lbn = + GenLogicalBlobName(first_chain_src_op->op().BnInOp2Lbi(fcs_obns.Get(0))); + + VLOG(3) << " first_chain_src_out_lbn : " << first_chain_src_out_lbn; + user_op::UserOpConfWrapper acc_ctrl_tick_op = + user_op::UserOpConfWrapperBuilder("Sys-AccCtrlTick4MergeFirstAccChain-" + NewUniqueId()) + .OpTypeName("acc_ctrl_tick") + .Input("in", first_chain_src_out_lbn) + .Output("out") + .ScopeSymbolId(first_chain_src_op->op().op_conf().scope_symbol_id()) + .Attr("max_acc_num", acc_num) + .Build(); + + JUST(MapAt(mut_op_name2conf, info.after_acc_logical_chain->end_op->op().op_name())) + .add_ctrl_in_op_name(acc_ctrl_tick_op.op_name()); + + JUST(job_builder->AddOp(first_chain_src_op->parallel_desc().parallel_conf(), + acc_ctrl_tick_op.op_conf())); } } } diff --git a/oneflow/core/lazy/actor/acc_ctrl_tick_actor.cpp b/oneflow/core/lazy/actor/acc_ctrl_tick_actor.cpp new file mode 100644 index 00000000000..c908780a22a --- /dev/null +++ b/oneflow/core/lazy/actor/acc_ctrl_tick_actor.cpp @@ -0,0 +1,218 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/lazy/actor/actor.h" +#include "oneflow/core/framework/framework.h" + +namespace oneflow { + +class AccCtrlTickActor : public Actor { + public: + OF_DISALLOW_COPY_AND_MOVE(AccCtrlTickActor); + AccCtrlTickActor() + : acc_cnt_(0), + max_acc_num_(0), + inplace_consume_(false), + consumed_tick_regst_desc_id_(-1), + produced_tick_regst_desc_id_(-1){}; + virtual ~AccCtrlTickActor() = default; + + private: + // NOTE(chengcheng): Empty rs for naive and inplace regst, all regst is customized. 
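  // EDITOR NOTE (not part of this patch): returning (kNaive, {}) from the two methods
  // below leaves the naive regst set empty, i.e. the base Actor manages nothing and this
  // actor drives its single consumed "in" regst and its produced ctrl regst itself. The
  // intended pattern, assuming max_acc_num_ == 4: for in(1)..in(3) it acts and returns the
  // input immediately; on in(4) it sends the ctrl regst downstream and holds in(4) until
  // the ctrl regst comes back, then releases it and starts the next accumulation window.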
+ std::pair> GetNaiveOrCustomizedConsumedRegstDescName() + override { + return std::make_pair(RegstNameType::kNaive, HashSet{}); + } + std::pair> GetNaiveOrCustomizedProducedRegstDescName() + override { + return std::make_pair(RegstNameType::kNaive, HashSet{}); + } + + bool IsCustomizedReadReady() const override { + bool is_ready_ready = (!inplace_consume_) && consumed_tick_rs_.IsCurSlotReady(); + LOG(INFO) << " ccActorLog: actor: " << actor_id() << " is_ready_ready: " << is_ready_ready + << " of inplace_consume_ = " << inplace_consume_ + << " consumed_tick_rs_.IsCurSlotReady = " << consumed_tick_rs_.IsCurSlotReady(); + return (!inplace_consume_) && consumed_tick_rs_.IsCurSlotReady(); + } + bool IsCustomizedWriteReady() const override { + LOG(INFO) << " ccActorLog: actor: " << actor_id() + << " is_write_ready: " << produced_tick_rs_.IsCurSlotReady(); + return produced_tick_rs_.IsCurSlotReady(); + } + + void NormalProcessCustomizedEordMsg(const ActorMsg&) override {} + bool IsCustomizedReadAlwaysUnReadyFromNow() const override { + // all Messages are flushed + return ReceiveEordMsg(consumed_tick_regst_desc_id_); + } + + void VirtualActorInit(const TaskProto& proto) override; + void Act() override; + void AsyncSendCustomizedProducedRegstMsgToConsumer() override; + void AsyncSendCustomizedConsumedRegstMsgToProducer() override; + void UpdtStateAsCustomizedProducedRegst(Regst* regst) override; + void NormalProcessCustomizedReadableRegstMsg(const ActorMsg& msg) override; + + int32_t acc_cnt_; + int32_t max_acc_num_; + bool inplace_consume_; + int64_t consumed_tick_regst_desc_id_; + int64_t produced_tick_regst_desc_id_; + RegstSlot consumed_tick_rs_; + RegstSlot produced_tick_rs_; +}; + +void AccCtrlTickActor::VirtualActorInit(const TaskProto& proto) { + acc_cnt_ = 0; + // const OperatorConf op_conf = + // proto.exec_sequence().exec_node(0).kernel_conf().op_attribute().op_conf(); + // max_acc_num_ = user_op::UserOpConfWrapper(op_conf).attr("max_acc_num"); + + // NOTE(chengcheng): check time shape equal max_acc_num + const Shape& in_time_shape = Singleton::Get() + ->RegstDesc4RegstDescId(Name2SoleRegstDescId("in")) + .data_regst_time_shape(); + max_acc_num_ = in_time_shape.elem_cnt(); + CHECK_GT(max_acc_num_, 1); + + /* + const Shape& out_time_shape = Singleton::Get() + ->RegstDesc4RegstDescId(Name2SoleRegstDescId("out")) + .data_regst_time_shape(); + CHECK_EQ(in_time_shape.elem_cnt() % out_time_shape.elem_cnt(), 0); + CHECK_EQ(in_time_shape.elem_cnt() / out_time_shape.elem_cnt(), max_acc_num_); + CHECK_GT(max_acc_num_, 1); + */ + + // input + const auto& consumed_ids = proto.consumed_regst_desc_id(); + CHECK_EQ(consumed_ids.size(), 1); + CHECK(consumed_ids.find("in") != consumed_ids.end()); + const auto& in_ids = consumed_ids.at("in"); + CHECK_EQ(in_ids.regst_desc_id_size(), 1); + consumed_tick_regst_desc_id_ = in_ids.regst_desc_id(0); + consumed_tick_rs_.InsertRegstDescId(consumed_tick_regst_desc_id_); + consumed_tick_rs_.InitedDone(); + + // output + CHECK_EQ(proto.produced_regst_desc().size(), 1); + for (const auto& pair : proto.produced_regst_desc()) { + const RegstDescProto& out_regst_desc = pair.second; + if (out_regst_desc.regst_desc_type().has_ctrl_regst_desc()) { + CHECK_EQ(out_regst_desc.register_num(), 1); + CHECK_EQ(produced_tick_regst_desc_id_, -1); + produced_tick_regst_desc_id_ = out_regst_desc.regst_desc_id(); + produced_tick_rs_.InsertRegstDescId(produced_tick_regst_desc_id_); + produced_tick_rs_.InitedDone(); + } + } + CHECK_NE(produced_tick_regst_desc_id_, -1); + + 
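  // EDITOR NOTE (worked example, not part of this patch): max_acc_num_ above is taken
  // from the elem_cnt of the "in" regst's time shape; with
  // num_gradient_accumulation_steps == 4 that elem_cnt should be 4, so the ctrl output
  // below is emitted once per optimizer step, i.e. on every 4th Act().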
ForEachProducedRegst([&](Regst* regst) { + // if (regst->regst_desc_id() == produced_tick_regst_desc_id_) { + CHECK_EQ(regst->regst_desc_id(), produced_tick_regst_desc_id_); + CHECK_EQ(0, produced_tick_rs_.TryPushBackRegst(regst)); + // } + }); + + LOG(INFO) << " ccActorLog: actor: " << actor_id() + << " has produced_tick_rs_ regst_descs = " << produced_tick_rs_.total_regst_desc_cnt() + << " with regsts size = " + << produced_tick_rs_.GetReadyRegstSize(produced_tick_regst_desc_id_); + LOG(INFO) << " ccActorLog: actor: " << actor_id() + << " has consumed_tick_rs_ regst_descs = " << consumed_tick_rs_.total_regst_desc_cnt() + << " with regsts size = " + << consumed_tick_rs_.GetReadyRegstSize(consumed_tick_regst_desc_id_); + LOG(INFO) + << " ccActorLog: actor: " << actor_id() + << " has inplace_consumed_rs_ regst_descs = " << inplace_consumed_rs_.total_regst_desc_cnt() + << " \nhas inplace_produced_rs_ regst_descs = " << inplace_produced_rs_.total_regst_desc_cnt() + << " \nhas naive_consumed_rs_ regst_descs = " << naive_consumed_rs_.total_regst_desc_cnt() + << " \nhas naive_produced_rs_ regst_descs = " << naive_produced_rs_.total_regst_desc_cnt(); + OF_SET_MSG_HANDLER(&AccCtrlTickActor::HandlerNormal); +} + +void AccCtrlTickActor::Act() { + acc_cnt_ += 1; + LOG(INFO) << " ccActorLog: actor: " << actor_id() << " acc_count_ = " << acc_cnt_ + << " max_acc_num = " << max_acc_num_; + if (acc_cnt_ == max_acc_num_) { + CHECK(!inplace_consume_); + inplace_consume_ = true; + LOG(INFO) << " ccActorLog: actor: " << actor_id() << " inplace_consume_ = true"; + acc_cnt_ = 0; + } +} + +void AccCtrlTickActor::AsyncSendCustomizedProducedRegstMsgToConsumer() { + if (inplace_consume_) { + CHECK(consumed_tick_rs_.IsCurSlotReady()); // inplace consume + CHECK(produced_tick_rs_.IsCurSlotReady()); + Regst* const tick_regst = produced_tick_rs_.Front(produced_tick_regst_desc_id_); + CHECK_GT(HandleRegstToConsumer(tick_regst), 0); + produced_tick_rs_.PopFrontRegsts({produced_tick_regst_desc_id_}); + + LOG(INFO) << "ccActorLog: actor: " << actor_id() << " in count: " << acc_cnt_ + << " Send ctrl_tick regst " << produced_tick_regst_desc_id_ << " to Consumer."; + } else { + LOG(INFO) << "ccActorLog: actor: " << actor_id() << " in count: " << acc_cnt_ + << " SKIP to send produced to consumer."; + } +} + +void AccCtrlTickActor::AsyncSendCustomizedConsumedRegstMsgToProducer() { + if (!inplace_consume_) { + Regst* const tick_regst = consumed_tick_rs_.Front(consumed_tick_regst_desc_id_); + CHECK_NOTNULL(tick_regst); + AsyncSendRegstMsgToProducer(tick_regst); + CHECK_EQ(0, consumed_tick_rs_.TryPopFrontRegst(consumed_tick_regst_desc_id_)); + + LOG(INFO) << "ccActorLog: actor: " << actor_id() << " in count: " << acc_cnt_ + << " return tick regst " << consumed_tick_regst_desc_id_ << " to producer."; + } else { + LOG(INFO) << "ccActorLog: actor: " << actor_id() << " in count: " << acc_cnt_ + << " NOT return tick regst for waiting inplace tick regst returned. 
"; + } +} + +void AccCtrlTickActor::UpdtStateAsCustomizedProducedRegst(Regst* regst) { + CHECK(inplace_consume_); + CHECK_EQ(regst->regst_desc_id(), produced_tick_regst_desc_id_); + CHECK_EQ(produced_tick_rs_.TryPushBackRegst(regst), 0); + LOG(INFO) << "ccActorLog: actor: " << actor_id() << " in count: " << acc_cnt_ + << " regst_desc_id: " << produced_tick_regst_desc_id_ << " ready size = " + << produced_tick_rs_.GetReadyRegstSize(produced_tick_regst_desc_id_); + + Regst* in_regst = consumed_tick_rs_.Front(consumed_tick_regst_desc_id_); + CHECK(in_regst); + AsyncSendRegstMsgToProducer(in_regst); + CHECK_EQ(0, consumed_tick_rs_.TryPopFrontRegst(consumed_tick_regst_desc_id_)); + inplace_consume_ = false; + + LOG(INFO) << "ccActorLog: actor: " << actor_id() << " in count: " << acc_cnt_ + << " consumed_regst_desc_id: " << consumed_tick_regst_desc_id_ + << " return with all produced regst."; +} + +void AccCtrlTickActor::NormalProcessCustomizedReadableRegstMsg(const ActorMsg& msg) { + CHECK_EQ(0, consumed_tick_rs_.TryPushBackRegst(msg.regst())); + LOG(INFO) << "ccActorLog: actor: " << actor_id() << " in count: " << acc_cnt_ + << " receive input regst: " << msg.regst()->regst_desc_id(); +} + +REGISTER_ACTOR(TaskType::kAccCtrlTick, AccCtrlTickActor); + +} // namespace oneflow diff --git a/oneflow/core/lazy/actor/register_slot.cpp b/oneflow/core/lazy/actor/register_slot.cpp index 713dd0b6eaa..182710eaf92 100644 --- a/oneflow/core/lazy/actor/register_slot.cpp +++ b/oneflow/core/lazy/actor/register_slot.cpp @@ -17,6 +17,13 @@ limitations under the License. namespace oneflow { +int64_t RegstSlot::GetReadyRegstSize(int64_t regst_desc_id) const { + CHECK(is_inited_); + auto it = regst_desc_id2regsts_.find(regst_desc_id); + if (it == regst_desc_id2regsts_.end()) { return -1; } + return it->second.size(); +} + bool RegstSlot::HasRegstDescId(int64_t regst_desc_id) const { CHECK(is_inited_); return regst_desc_id2regsts_.find(regst_desc_id) != regst_desc_id2regsts_.end(); diff --git a/oneflow/core/lazy/actor/register_slot.h b/oneflow/core/lazy/actor/register_slot.h index a1db92c5e24..12b67b9e47b 100644 --- a/oneflow/core/lazy/actor/register_slot.h +++ b/oneflow/core/lazy/actor/register_slot.h @@ -30,6 +30,7 @@ class RegstSlot final { size_t total_regst_desc_cnt() const { return regst_desc_id2regsts_.size(); } size_t available_regst_desc_cnt() const { return available_regst_desc_cnt_; } + int64_t GetReadyRegstSize(int64_t regst_desc_id) const; bool IsCurSlotReady() const { return available_regst_desc_cnt() == total_regst_desc_cnt(); } bool HasRegstDescId(int64_t regst_desc_id) const; const std::deque& RegstDeq4RegstDescId(int64_t regst_desc_id) const; diff --git a/oneflow/ir/include/OneFlow/OneFlowUserOps.td b/oneflow/ir/include/OneFlow/OneFlowUserOps.td index 8fdef254b3f..0415dfee985 100644 --- a/oneflow/ir/include/OneFlow/OneFlowUserOps.td +++ b/oneflow/ir/include/OneFlow/OneFlowUserOps.td @@ -8372,8 +8372,8 @@ def OneFlow_NotEqualZeroGradOp : OneFlow_BaseOp<"not_equal_zero_grad", [NoSideEf #endif // GET_ONEFLOW_TRIGONOMETRIC_OP_DEFINITIONS // Group: UNARY -// acc, affine_grid, affine_grid_grad, bernoulli, cast, cast_to_static_shape, cast_to_tick, celu, copy, count_not_finite, diag, diagonal, elu, expand, expand_dims, flatten, flip, fold, gelu, hardsigmoid, hardshrink, hardswish, leaky_relu, log2, logical_not, mish, narrow, one_hot, pack, random_mask_like, repeat, roll, selu, silu, softshrink, softsign, sort, square_sum, squeeze, threshold, transpose, tril, triu, unfold, unfold_tensor, unpack, zero_like, 
to_contiguous, isnan, isinf, isfinite, repeat_interleave, mutable_cast_once -// Total: 52 +// acc, affine_grid, affine_grid_grad, bernoulli, cast, cast_to_static_shape, cast_to_tick, celu, copy, count_not_finite, diag, diagonal, elu, expand, expand_dims, flatten, flip, fold, gelu, hardsigmoid, hardshrink, hardswish, leaky_relu, log2, logical_not, mish, narrow, one_hot, pack, random_mask_like, repeat, roll, selu, silu, softshrink, softsign, sort, square_sum, squeeze, threshold, transpose, tril, triu, unfold, unfold_tensor, unpack, zero_like, to_contiguous, isnan, isinf, isfinite, repeat_interleave, mutable_cast_once, acc_ctrl_tick +// Total: 53 #ifdef GET_ONEFLOW_UNARY_OP_DEFINITIONS @@ -8510,6 +8510,24 @@ def OneFlow_CastToTickOp : OneFlow_BaseOp<"cast_to_tick", [NoSideEffect, NoGrad, let has_nd_sbp_infer_fn = 1; } +def OneFlow_AccCtrlTickOp : OneFlow_BaseOp<"acc_ctrl_tick", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$max_acc_num + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_nd_sbp_infer_fn = 1; + let has_output_blob_time_shape_infer_fn = 1; +} + def OneFlow_CeluOp : OneFlow_BaseOp<"celu", [NoSideEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$in diff --git a/oneflow/user/ops/acc_ctrl_tick_op.cpp b/oneflow/user/ops/acc_ctrl_tick_op.cpp new file mode 100644 index 00000000000..5e3571e4d43 --- /dev/null +++ b/oneflow/user/ops/acc_ctrl_tick_op.cpp @@ -0,0 +1,79 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#include "oneflow/core/framework/framework.h" +#include "oneflow/core/operator/operator.h" +#include "oneflow/core/framework/op_generated.h" + +namespace oneflow { + +/* static */ Maybe AccCtrlTickOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + Shape* out_shape = ctx->MutOutputShape("out", 0); + *out_shape = Shape({1}); + return Maybe::Ok(); +} + +/*static*/ Maybe AccCtrlTickOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe AccCtrlTickOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} + +/* static */ Maybe AccCtrlTickOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { + const NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); + const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); + CHECK_EQ_OR_RETURN(in_dis_hint.sbp_parallel_size(), parallel_hierarchy.NumAxes()); + + NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); + NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); + in_distribution->clear_sbp_parallel(); + out_distribution->clear_sbp_parallel(); + // in use hint + in_distribution->CopyFrom(in_dis_hint); + + for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) { + // out dim1 = broadcast + out_distribution->add_sbp_parallel()->mutable_broadcast_parallel(); + } + return Maybe::Ok(); +} + +/* static */ Maybe AccCtrlTickOp::InferDataType(user_op::InferContext* ctx) { + *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/*static*/ Maybe AccCtrlTickOp::InferOutputBlobTimeShape( + user_op::InferOutputBlobTimeShapeFnContext* ctx) { + const int32_t max_acc_num = ctx->user_op_conf().attr("max_acc_num"); + const Shape& in_time_shape = ctx->TimeShape4InputArgNameAndIndex("in", 0); + DimVector time_shape_dim_vec = in_time_shape.dim_vec(); + CHECK_OR_RETURN(!time_shape_dim_vec.empty()); + if (time_shape_dim_vec.back() == max_acc_num) { + time_shape_dim_vec.pop_back(); + } else if (time_shape_dim_vec.back() % max_acc_num == 0) { + time_shape_dim_vec.back() /= max_acc_num; + } else { + const int64_t elem_cnt = in_time_shape.elem_cnt(); + time_shape_dim_vec.resize(1); + time_shape_dim_vec.back() = elem_cnt / max_acc_num; + } + *ctx->mut_output_blob_time_shape() = Shape(time_shape_dim_vec); + return Maybe::Ok(); +} + +} // namespace oneflow From 23c7721307746ff3015b9bccfc2880c815f522a8 Mon Sep 17 00:00:00 2001 From: chengtbf <472491134@qq.com> Date: Fri, 26 Aug 2022 09:14:10 +0000 Subject: [PATCH 11/66] tmp --- oneflow/core/job/compiler.cpp | 1 + oneflow/core/job/job.proto | 5 ++++ oneflow/core/job/job_conf.proto | 2 +- oneflow/core/job/plan_util.cpp | 4 +++ oneflow/core/job/plan_util.h | 1 + .../insert_nccl_logical_op_pass.cpp | 3 ++ .../core/job_rewriter/logical_chain_pass.cpp | 30 +++++++++++++++++++ 7 files changed, 45 insertions(+), 1 deletion(-) diff --git a/oneflow/core/job/compiler.cpp b/oneflow/core/job/compiler.cpp index 4fc11e0eb87..fb757c75216 100644 --- a/oneflow/core/job/compiler.cpp +++ b/oneflow/core/job/compiler.cpp @@ -105,6 +105,7 @@ void Compiler::Compile(Job* job, Plan* plan) const { (*job_id2job_conf)[GlobalJobDesc().job_id()] = GlobalJobDesc().job_conf(); // NOTE(chengcheng): infer mem blob id & set inplace & add ctrl IntraJobMemSharingUtil::InferMemBlockId4MemReusedRegst(plan, IsReachable); + PlanUtil::MergeMemBlockIdByLogicalChainId(plan, *job); PlanUtil::SetUniqueMemBlockId4UnreusedMemRegst(plan); Singleton::Delete(); } diff --git 
a/oneflow/core/job/job.proto b/oneflow/core/job/job.proto index 9be3edfb2b9..c6551ee51cc 100644 --- a/oneflow/core/job/job.proto +++ b/oneflow/core/job/job.proto @@ -25,6 +25,10 @@ message JobHelperConf { map op_name2arg_signature = 9; } +message MergedLogicalChainIdGroup { + repeated int64 logical_chain_id_list = 1; +} + message Job { optional DLNetConf net = 1; optional Placement placement = 2; @@ -32,4 +36,5 @@ message Job { optional JobParallelViewConf job_parallel_view_conf = 4; optional JobHelperConf helper = 5; map module_name2module_conf = 6; + repeated MergedLogicalChainIdGroup logical_chain_groups = 7; } diff --git a/oneflow/core/job/job_conf.proto b/oneflow/core/job/job_conf.proto index 1b7035877e1..6268ce4e4ed 100644 --- a/oneflow/core/job/job_conf.proto +++ b/oneflow/core/job/job_conf.proto @@ -167,7 +167,7 @@ message PredictConf { message MemoryAllocationAlgorithmConf { optional bool use_mem_size_first_algo = 1 [default = true]; optional bool use_mutual_exclusion_first_algo = 2 [default = true]; - optional bool use_time_line_algo = 3 [default = false]; + optional bool use_time_line_algo = 3 [default = true]; } message QatConfig { diff --git a/oneflow/core/job/plan_util.cpp b/oneflow/core/job/plan_util.cpp index cfbc19b5f83..25fc29d38ba 100644 --- a/oneflow/core/job/plan_util.cpp +++ b/oneflow/core/job/plan_util.cpp @@ -188,6 +188,10 @@ void GenChunkForMultiNNGraphMemoryReuseInMultiClient( } // namespace +void PlanUtil::MergeMemBlockIdByLogicalChainId(Plan* plan, const Job& job) { + // TODO(); +} + void PlanUtil::GenMemBlockAndChunkWithVariableOpNames4Plan( Plan* plan, const HashSet& variable_op_names) { HashMap> mem_block_id2mem_block; diff --git a/oneflow/core/job/plan_util.h b/oneflow/core/job/plan_util.h index 6c45db300e0..9de7142a629 100644 --- a/oneflow/core/job/plan_util.h +++ b/oneflow/core/job/plan_util.h @@ -28,6 +28,7 @@ namespace oneflow { struct PlanUtil { static RegstDescProto* GetSoleProducedDataRegst(TaskProto* task_proto); static std::function MakeGetterTaskProto4TaskId(const Plan& plan); + static void MergeMemBlockIdByLogicalChainId(Plan* plan, const Job& job); static void SetUniqueMemBlockId4UnreusedMemRegst(Plan* plan); static void GenMemBlockAndChunk4Plan(Plan* plan); static void GenMemBlockAndChunkWithVariableOpNames4Plan( diff --git a/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp b/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp index 1679a8d3303..d8aff601abf 100644 --- a/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp +++ b/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp @@ -968,10 +968,12 @@ void InsertBwSinkAccTickAndNcclLogicalOpsInPlacementGroupAfterAcc( .OpTypeName("cast_to_tick") .Input("in", bw_sink_op_out_lbn) .Output("out") + .ScopeSymbolId(bw_sink_op->op().op_conf().scope_symbol_id()) .Build(); OperatorConf bw_sink_acc_tick_conf; bw_sink_acc_tick_conf.set_name(std::string("System-BwSinkTick-AccTick_") + NewUniqueId()); + bw_sink_acc_tick_conf.set_scope_symbol_id(bw_sink_op->op().op_conf().scope_symbol_id()); auto* acc_conf = bw_sink_acc_tick_conf.mutable_acc_tick_conf(); acc_conf->set_one(cast_to_tick_op.output("out", 0)); acc_conf->set_acc("acc"); @@ -980,6 +982,7 @@ void InsertBwSinkAccTickAndNcclLogicalOpsInPlacementGroupAfterAcc( OperatorConf bw_sink_final_tick_conf; bw_sink_final_tick_conf.set_name(std::string("System-BwSinkFinalTick-DeviceTick_") + NewUniqueId()); + bw_sink_final_tick_conf.set_scope_symbol_id(bw_sink_op->op().op_conf().scope_symbol_id()); auto* tick_conf = 
bw_sink_final_tick_conf.mutable_device_tick_conf(); tick_conf->add_tick(GenLogicalBlobName(bw_sink_acc_tick_conf.name(), "acc")); tick_conf->set_out("out"); diff --git a/oneflow/core/job_rewriter/logical_chain_pass.cpp b/oneflow/core/job_rewriter/logical_chain_pass.cpp index 44195a3bb72..b133859652c 100644 --- a/oneflow/core/job_rewriter/logical_chain_pass.cpp +++ b/oneflow/core/job_rewriter/logical_chain_pass.cpp @@ -254,6 +254,7 @@ void CreateAfterAccLogicalChain(const std::shared_ptr& after_acc_l CHECK(after_acc_chain_ops.insert(cur_node).second); for (const OpEdge* in_edge : cur_node->in_edges()) { + // NOTE(chengcheng): maybe bad case for too early source op before repeat. SearchToNextNode(cur_node, in_edge->src_node(), in_edge); } for (const OpEdge* out_edge : cur_node->out_edges()) { @@ -359,6 +360,7 @@ Maybe LogicalChainPass::Apply(const OpGraph& op_graph, JobBuilder* job_bui for (auto& pair : placement2logical_chains) { auto& info = pair.second; + CHECK_GE(info.ordered_logical_chains.size(), 1); for (int i = 0; i < info.ordered_logical_chains.size() - 1; i++) { CHECK_LT(JUST(VectorAt(info.ordered_logical_chains, i))->begin_op_global_order, JUST(VectorAt(info.ordered_logical_chains, i + 1))->begin_op_global_order); @@ -391,6 +393,34 @@ Maybe LogicalChainPass::Apply(const OpGraph& op_graph, JobBuilder* job_bui InsertLogicalChainId(chain_order_ops, info.after_acc_logical_chain->logical_chain_id); InsertCtrlEdgeInChain(chain_order_ops); } + + // NOTE(chengcheng): + // creat repeat tick + // 1.add acc ctrl tick between first chain src to acc chain sink for memory lock. + // 2.add acc tick between first chain sink to acc chain src for strict exec order. + const int64_t acc_num = job_builder->job().job_conf().num_gradient_accumulation_steps(); + CHECK_GT(acc_num, 1); + const OpNode* first_chain_src_op = info.ordered_logical_chains.front()->begin_op; + const auto& fcs_obns = first_chain_src_op->op().output_bns(); + CHECK(!fcs_obns.empty()); + const std::string& first_chain_src_out_lbn = + GenLogicalBlobName(first_chain_src_op->op().BnInOp2Lbi(fcs_obns.Get(0))); + + VLOG(3) << " first_chain_src_out_lbn : " << first_chain_src_out_lbn; + user_op::UserOpConfWrapper acc_ctrl_tick_op = + user_op::UserOpConfWrapperBuilder("Sys-AccCtrlTick4MergeFirstAccChain-" + NewUniqueId()) + .OpTypeName("acc_ctrl_tick") + .Input("in", first_chain_src_out_lbn) + .Output("out") + .ScopeSymbolId(first_chain_src_op->op().op_conf().scope_symbol_id()) + .Attr("max_acc_num", acc_num) + .Build(); + + JUST(MapAt(mut_op_name2conf, info.after_acc_logical_chain->end_op->op().op_name())) + .add_ctrl_in_op_name(acc_ctrl_tick_op.op_name()); + + JUST(job_builder->AddOp(first_chain_src_op->parallel_desc().parallel_conf(), + acc_ctrl_tick_op.op_conf())); } } From f32d24724472d93aa459e37fbe7be26caaba3102 Mon Sep 17 00:00:00 2001 From: chengtbf <472491134@qq.com> Date: Fri, 26 Aug 2022 11:48:03 +0000 Subject: [PATCH 12/66] AccCtrlTick runnable --- .../acc_ctrl_tick_compute_task_node.cpp | 6 +-- .../core/job_rewriter/logical_chain_pass.cpp | 39 ++++++++++++++++--- .../core/lazy/actor/acc_ctrl_tick_actor.cpp | 34 +++++++++------- oneflow/user/kernels/nop_kernel.cpp | 1 + oneflow/user/ops/acc_ctrl_tick_op.cpp | 7 ++-- 5 files changed, 61 insertions(+), 26 deletions(-) diff --git a/oneflow/core/graph_impl/acc_ctrl_tick_compute_task_node.cpp b/oneflow/core/graph_impl/acc_ctrl_tick_compute_task_node.cpp index 80ead618b5f..adf7b2b89d1 100644 --- a/oneflow/core/graph_impl/acc_ctrl_tick_compute_task_node.cpp +++ 
b/oneflow/core/graph_impl/acc_ctrl_tick_compute_task_node.cpp @@ -30,8 +30,8 @@ class AccCtrlTickCompTaskNode final : public CompTaskNode { }; void AccCtrlTickCompTaskNode::ProduceAllRegstsAndBindEdges() { - // std::shared_ptr regst = ProduceRegst("out", false); - // ForEachOutDataEdge([&](TaskEdge* edge) { edge->AddRegst("out", regst); }); + std::shared_ptr regst = ProduceRegst("out", false); + ForEachOutDataEdge([&](TaskEdge* edge) { edge->AddRegst("out", regst); }); } void AccCtrlTickCompTaskNode::ConsumeAllRegsts() { @@ -39,7 +39,6 @@ void AccCtrlTickCompTaskNode::ConsumeAllRegsts() { } void AccCtrlTickCompTaskNode::BuildExecGphAndRegst() { - /* std::shared_ptr in_regst = GetSoleConsumedRegst("in"); std::shared_ptr out_regst = GetProducedRegst("out"); std::shared_ptr op = this->op(); @@ -49,7 +48,6 @@ void AccCtrlTickCompTaskNode::BuildExecGphAndRegst() { out_regst->AddLbi(op->BnInOp2Lbi(op->SoleObn())); exec_node->BindBnWithRegst(op->SoleObn(), out_regst); exec_node->InferBlobDescs(parallel_ctx()); - */ } REGISTER_COMP_TASK_STREAM_INDEX_GETTER(TaskType::kAccCtrlTick); diff --git a/oneflow/core/job_rewriter/logical_chain_pass.cpp b/oneflow/core/job_rewriter/logical_chain_pass.cpp index 4e81149a1f9..d8a0d4b93c3 100644 --- a/oneflow/core/job_rewriter/logical_chain_pass.cpp +++ b/oneflow/core/job_rewriter/logical_chain_pass.cpp @@ -76,7 +76,8 @@ bool IsBreakpointOpNode(const OpNode* node) { // TODO(chengcheng): acc node can be merged in chain. if (user_type_name == "repeat" || user_type_name == "acc" || user_type_name == "pack" || user_type_name == "unpack" || user_type_name == "identity_buffer" - || user_type_name == "copy_h2d" || user_type_name == "copy_d2h") { + || user_type_name == "copy_h2d" || user_type_name == "copy_d2h" + || user_type_name == "acc_ctrl_tick") { return true; } } @@ -359,6 +360,7 @@ Maybe LogicalChainPass::Apply(const OpGraph& op_graph, JobBuilder* job_bui }; for (auto& pair : placement2logical_chains) { + const auto& placement = pair.first; auto& info = pair.second; CHECK_GE(info.ordered_logical_chains.size(), 1); for (int i = 0; i < info.ordered_logical_chains.size() - 1; i++) { @@ -414,13 +416,40 @@ Maybe LogicalChainPass::Apply(const OpGraph& op_graph, JobBuilder* job_bui .Attr("max_acc_num", acc_num) .Build(); - JUST(MapAt(mut_op_name2conf, info.after_acc_logical_chain->end_op->op().op_name())) - .add_ctrl_in_op_name(acc_ctrl_tick_op.op_name()); + OperatorConf& consumer = + JUST(MapAt(mut_op_name2conf, info.after_acc_logical_chain->end_op->op().op_name())); + if (consumer.has_user_conf()) { + (*consumer.mutable_user_conf()->mutable_input())[user_op::kUserSourceOpTickInputArgName] + .add_s(acc_ctrl_tick_op.output("out", 0)); + JUST(job_builder->AddOp(first_chain_src_op->parallel_desc().parallel_conf(), + acc_ctrl_tick_op.op_conf())); + } + } + } + + for (const auto& logical_chain : info.ordered_logical_chains) { + VLOG(3) << " In placement: " << placement + << " logical_chain_id: " << logical_chain->logical_chain_id + << " has op num = " << logical_chain->ordered_op_nodes.size(); - JUST(job_builder->AddOp(first_chain_src_op->parallel_desc().parallel_conf(), - acc_ctrl_tick_op.op_conf())); + for (int i = 0; i < logical_chain->ordered_op_nodes.size(); ++i) { + const OpNode* ordered_op = JUST(VectorAt(logical_chain->ordered_op_nodes, i)); + VLOG(3) << " ChainId: " << logical_chain_id << " order: " << i + << " op_name: " << ordered_op->op().op_name() + << " global_order: " << JUST(MapAt(op_node2global_order, ordered_op)); } } + + VLOG(3) << " In placement: " << 
placement + << " AccLogicalChain: " << info.after_acc_logical_chain->logical_chain_id + << " has op num = " << info.after_acc_logical_chain->ordered_op_nodes.size(); + + for (int i = 0; i < info.after_acc_logical_chain->ordered_op_nodes.size(); ++i) { + const OpNode* ordered_op = JUST(VectorAt(info.after_acc_logical_chain->ordered_op_nodes, i)); + VLOG(3) << " AfterAccChainId: " << info.after_acc_logical_chain->logical_chain_id + << " order: " << i << " op_name: " << ordered_op->op().op_name() + << " global_order: " << JUST(MapAt(op_node2global_order, ordered_op)); + } } // NOTE(chengcheng): update global order and chain id for ops. diff --git a/oneflow/core/lazy/actor/acc_ctrl_tick_actor.cpp b/oneflow/core/lazy/actor/acc_ctrl_tick_actor.cpp index c908780a22a..a1459812ac6 100644 --- a/oneflow/core/lazy/actor/acc_ctrl_tick_actor.cpp +++ b/oneflow/core/lazy/actor/acc_ctrl_tick_actor.cpp @@ -77,25 +77,22 @@ class AccCtrlTickActor : public Actor { void AccCtrlTickActor::VirtualActorInit(const TaskProto& proto) { acc_cnt_ = 0; - // const OperatorConf op_conf = - // proto.exec_sequence().exec_node(0).kernel_conf().op_attribute().op_conf(); - // max_acc_num_ = user_op::UserOpConfWrapper(op_conf).attr("max_acc_num"); + const OperatorConf op_conf = + proto.exec_sequence().exec_node(0).kernel_conf().op_attribute().op_conf(); + max_acc_num_ = user_op::UserOpConfWrapper(op_conf).attr("max_acc_num"); // NOTE(chengcheng): check time shape equal max_acc_num const Shape& in_time_shape = Singleton::Get() ->RegstDesc4RegstDescId(Name2SoleRegstDescId("in")) .data_regst_time_shape(); - max_acc_num_ = in_time_shape.elem_cnt(); - CHECK_GT(max_acc_num_, 1); - - /* + // max_acc_num_ = in_time_shape.elem_cnt(); + // CHECK_GT(max_acc_num_, 1); const Shape& out_time_shape = Singleton::Get() ->RegstDesc4RegstDescId(Name2SoleRegstDescId("out")) .data_regst_time_shape(); CHECK_EQ(in_time_shape.elem_cnt() % out_time_shape.elem_cnt(), 0); CHECK_EQ(in_time_shape.elem_cnt() / out_time_shape.elem_cnt(), max_acc_num_); CHECK_GT(max_acc_num_, 1); - */ // input const auto& consumed_ids = proto.consumed_regst_desc_id(); @@ -109,17 +106,26 @@ void AccCtrlTickActor::VirtualActorInit(const TaskProto& proto) { // output CHECK_EQ(proto.produced_regst_desc().size(), 1); + /* for (const auto& pair : proto.produced_regst_desc()) { const RegstDescProto& out_regst_desc = pair.second; if (out_regst_desc.regst_desc_type().has_ctrl_regst_desc()) { - CHECK_EQ(out_regst_desc.register_num(), 1); - CHECK_EQ(produced_tick_regst_desc_id_, -1); - produced_tick_regst_desc_id_ = out_regst_desc.regst_desc_id(); - produced_tick_rs_.InsertRegstDescId(produced_tick_regst_desc_id_); - produced_tick_rs_.InitedDone(); - } + CHECK_EQ(out_regst_desc.register_num(), 1); + CHECK_EQ(produced_tick_regst_desc_id_, -1); + produced_tick_regst_desc_id_ = out_regst_desc.regst_desc_id(); + produced_tick_rs_.InsertRegstDescId(produced_tick_regst_desc_id_); + produced_tick_rs_.InitedDone(); + } } CHECK_NE(produced_tick_regst_desc_id_, -1); + */ + const auto& produced_ids = proto.produced_regst_desc(); + CHECK_EQ(produced_ids.size(), 1); + CHECK(produced_ids.find("out") != produced_ids.end()); + const RegstDescProto& out_regst_desc = produced_ids.at("out"); + produced_tick_regst_desc_id_ = out_regst_desc.regst_desc_id(); + produced_tick_rs_.InsertRegstDescId(produced_tick_regst_desc_id_); + produced_tick_rs_.InitedDone(); ForEachProducedRegst([&](Regst* regst) { // if (regst->regst_desc_id() == produced_tick_regst_desc_id_) { diff --git 
a/oneflow/user/kernels/nop_kernel.cpp b/oneflow/user/kernels/nop_kernel.cpp index f5e0ee86d98..66017e492f8 100644 --- a/oneflow/user/kernels/nop_kernel.cpp +++ b/oneflow/user/kernels/nop_kernel.cpp @@ -35,6 +35,7 @@ class NopKernel final : public user_op::OpKernel { REGISTER_USER_KERNEL(op_type_name).SetCreateFn().SetIsMatchedHob(user_op::HobTrue()); REGISTER_NOP_KERNEL("cast_to_tick") +REGISTER_NOP_KERNEL("acc_ctrl_tick") } // namespace diff --git a/oneflow/user/ops/acc_ctrl_tick_op.cpp b/oneflow/user/ops/acc_ctrl_tick_op.cpp index 5e3571e4d43..9db2fce4bdc 100644 --- a/oneflow/user/ops/acc_ctrl_tick_op.cpp +++ b/oneflow/user/ops/acc_ctrl_tick_op.cpp @@ -36,7 +36,8 @@ namespace oneflow { /* static */ Maybe AccCtrlTickOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { const NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); - CHECK_EQ_OR_RETURN(in_dis_hint.sbp_parallel_size(), parallel_hierarchy.NumAxes()); + CHECK_EQ_OR_RETURN(in_dis_hint.sbp_parallel_size(), // NOLINT(maybe-need-error-msg) + parallel_hierarchy.NumAxes()); // NOLINT(maybe-need-error-msg) NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); @@ -61,8 +62,8 @@ namespace oneflow { user_op::InferOutputBlobTimeShapeFnContext* ctx) { const int32_t max_acc_num = ctx->user_op_conf().attr("max_acc_num"); const Shape& in_time_shape = ctx->TimeShape4InputArgNameAndIndex("in", 0); - DimVector time_shape_dim_vec = in_time_shape.dim_vec(); - CHECK_OR_RETURN(!time_shape_dim_vec.empty()); + DimVector time_shape_dim_vec = in_time_shape.dim_vec(); // NOLINT(maybe-need-error-msg) + CHECK_OR_RETURN(!time_shape_dim_vec.empty()); // NOLINT(maybe-need-error-msg) if (time_shape_dim_vec.back() == max_acc_num) { time_shape_dim_vec.pop_back(); } else if (time_shape_dim_vec.back() % max_acc_num == 0) { From d6c1760a38b08bb944929c2beca58fda65a89a6f Mon Sep 17 00:00:00 2001 From: chengtbf <472491134@qq.com> Date: Fri, 26 Aug 2022 12:12:16 +0000 Subject: [PATCH 13/66] rename group boxing identity and model diff scale op name --- oneflow/core/job_rewriter/autograd.cpp | 12 ++++++++---- .../job_rewriter/group_boxing_by_dst_parallel.cpp | 3 ++- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/oneflow/core/job_rewriter/autograd.cpp b/oneflow/core/job_rewriter/autograd.cpp index c840085913f..16c8c8e0c1d 100644 --- a/oneflow/core/job_rewriter/autograd.cpp +++ b/oneflow/core/job_rewriter/autograd.cpp @@ -112,7 +112,8 @@ void ScaleModelDiffByConstantLossInstanceNum(const OpGraph& op_graph, JobBuilder const LogicalBlobId& lbi = pair.first; LogicalBlobId& diff_lbi = pair.second; auto scalar_mul_op = - user_op::UserOpConfWrapperBuilder("System-ModelDiffScale-ScalarMul_" + NewUniqueId()) + user_op::UserOpConfWrapperBuilder("Sys-DiffScale-ScalarMul-" + lbi.op_name() + "_" + + lbi.blob_name() + "-" + NewUniqueId()) .Op("scalar_mul") .Input("in", GenLogicalBlobName(diff_lbi)) .Output("out") @@ -230,7 +231,8 @@ void ScaleModelDiffByDynamicLossInstanceNum( const LogicalBlobId& lbi = pair.first; LogicalBlobId& diff_lbi = pair.second; auto scalar_div_op = - user_op::UserOpConfWrapperBuilder("System-ModelDiffScale-ScalarDiv_" + NewUniqueId()) + user_op::UserOpConfWrapperBuilder("Sys-DiffScale-ScalarDiv-" + lbi.op_name() + "_" + + lbi.blob_name() + "-" + NewUniqueId()) .Op("scalar_div_by_tensor") .Input("x", GenLogicalBlobName(diff_lbi)) .Input("scalar", GenLogicalBlobName(total_loss_instance_num_lbi)) @@ -857,7 
+859,8 @@ void ScaleModelDiffByLossScale(JobPassCtx* ctx, const OpGraph& op_graph, JobBuil const LogicalBlobId& lbi = pair.first; LogicalBlobId& diff_lbi = pair.second; auto scalar_mul_op = - user_op::UserOpConfWrapperBuilder("System-ModelDiffScale-ScalarMul-" + NewUniqueId()) + user_op::UserOpConfWrapperBuilder("Sys-DiffScale-ScalarMul-" + lbi.op_name() + "_" + + lbi.blob_name() + "-" + NewUniqueId()) .Op("scalar_mul_by_tensor") .Input("x", GenLogicalBlobName(diff_lbi)) .Input("scalar", LossScale4DataType(op_graph.GetLogicalBlobDesc(lbi).data_type())) @@ -875,7 +878,8 @@ void ScaleModelDiffByLossScale(JobPassCtx* ctx, const OpGraph& op_graph, JobBuil const LogicalBlobId& lbi = pair.first; LogicalBlobId& diff_lbi = pair.second; auto scalar_mul_op = - user_op::UserOpConfWrapperBuilder("System-ModelDiffScale-ScalarMul-" + NewUniqueId()) + user_op::UserOpConfWrapperBuilder("Sys-DiffScale-ScalarMul-" + lbi.op_name() + "_" + + lbi.blob_name() + "-" + NewUniqueId()) .Op("scalar_mul") .Input("in", GenLogicalBlobName(diff_lbi)) .Output("out") diff --git a/oneflow/core/job_rewriter/group_boxing_by_dst_parallel.cpp b/oneflow/core/job_rewriter/group_boxing_by_dst_parallel.cpp index dd3e8039af7..00dc3138b40 100644 --- a/oneflow/core/job_rewriter/group_boxing_by_dst_parallel.cpp +++ b/oneflow/core/job_rewriter/group_boxing_by_dst_parallel.cpp @@ -101,7 +101,8 @@ Maybe GroupBoxingByDstParallel(const OpGraph& op_graph, JobBuilder* job_bu const ParallelDesc& dst_parallel_desc = parallel7group.first.first; const NdSbp& dst_nd_sbp = parallel7group.first.second; OperatorConf identity_op_conf{}; - identity_op_conf.set_name("System-Boxing-Identity-" + NewUniqueId()); + identity_op_conf.set_name("Sys-Boxing-GroupIdentity-" + lbi.op_name() + "_" + lbi.blob_name() + + "-" + NewUniqueId()); IdentityOpConf* identity_conf = identity_op_conf.mutable_identity_conf(); identity_conf->set_in(GenLogicalBlobName(lbi)); identity_conf->set_out("out"); From c15cbf046c23bc8f004c3d8fb5002a273c0a456b Mon Sep 17 00:00:00 2001 From: chengtbf <472491134@qq.com> Date: Fri, 26 Aug 2022 13:43:43 +0000 Subject: [PATCH 14/66] stric order by acc tick --- .../core/job_rewriter/logical_chain_pass.cpp | 63 ++++++++++++++++--- 1 file changed, 53 insertions(+), 10 deletions(-) diff --git a/oneflow/core/job_rewriter/logical_chain_pass.cpp b/oneflow/core/job_rewriter/logical_chain_pass.cpp index d8a0d4b93c3..62ec77fc44d 100644 --- a/oneflow/core/job_rewriter/logical_chain_pass.cpp +++ b/oneflow/core/job_rewriter/logical_chain_pass.cpp @@ -397,14 +397,13 @@ Maybe LogicalChainPass::Apply(const OpGraph& op_graph, JobBuilder* job_bui // NOTE(chengcheng): // 1.add acc ctrl tick between first chain src to acc chain sink for memory lock. - // 2.add acc tick between first chain sink to acc chain src for strict exec order. 
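For orientation, the strict-order tick arranged below rests on the same time-shape reduction that AccCtrlTickOp::InferOutputBlobTimeShape codifies elsewhere in this series: a per-micro-batch time shape whose trailing dimension is acc_num collapses to a per-batch one. A stand-alone sketch of that rule (ReduceTimeShape is an illustrative helper, not an OneFlow API):

#include <cassert>
#include <cstdint>
#include <vector>

// Same branch structure as AccCtrlTickOp::InferOutputBlobTimeShape in this patch series.
std::vector<int64_t> ReduceTimeShape(std::vector<int64_t> dims, int64_t max_acc_num) {
  assert(!dims.empty());
  if (dims.back() == max_acc_num) {
    dims.pop_back();  // e.g. (1, 4) with acc_num = 4 -> (1)
  } else if (dims.back() % max_acc_num == 0) {
    dims.back() /= max_acc_num;  // e.g. (1, 8) with acc_num = 4 -> (1, 2)
  } else {
    int64_t elem_cnt = 1;
    for (int64_t d : dims) { elem_cnt *= d; }
    dims.assign(1, elem_cnt / max_acc_num);  // fall back to a flat count
  }
  return dims;
}

int main() {
  assert((ReduceTimeShape({1, 4}, 4) == std::vector<int64_t>{1}));
  assert((ReduceTimeShape({1, 8}, 4) == std::vector<int64_t>{1, 2}));
  return 0;
}

This is why, with num_gradient_accumulation_steps = 4, the after-acc chain's source should be released only once per four executions of the first chain's sink.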
const int64_t acc_num = job_builder->job().job_conf().num_gradient_accumulation_steps(); CHECK_GT(acc_num, 1); const OpNode* first_chain_src_op = info.ordered_logical_chains.front()->begin_op; - const auto& fcs_obns = first_chain_src_op->op().output_bns(); - CHECK(!fcs_obns.empty()); + const auto& fc_src_obns = first_chain_src_op->op().output_bns(); + CHECK(!fc_src_obns.empty()); const std::string& first_chain_src_out_lbn = - GenLogicalBlobName(first_chain_src_op->op().BnInOp2Lbi(fcs_obns.Get(0))); + GenLogicalBlobName(first_chain_src_op->op().BnInOp2Lbi(fc_src_obns.Get(0))); VLOG(3) << " first_chain_src_out_lbn : " << first_chain_src_out_lbn; user_op::UserOpConfWrapper acc_ctrl_tick_op = @@ -418,12 +417,56 @@ Maybe LogicalChainPass::Apply(const OpGraph& op_graph, JobBuilder* job_bui OperatorConf& consumer = JUST(MapAt(mut_op_name2conf, info.after_acc_logical_chain->end_op->op().op_name())); - if (consumer.has_user_conf()) { - (*consumer.mutable_user_conf()->mutable_input())[user_op::kUserSourceOpTickInputArgName] - .add_s(acc_ctrl_tick_op.output("out", 0)); - JUST(job_builder->AddOp(first_chain_src_op->parallel_desc().parallel_conf(), - acc_ctrl_tick_op.op_conf())); - } + CHECK(consumer.has_user_conf()); + (*consumer.mutable_user_conf()->mutable_input())[user_op::kUserSourceOpTickInputArgName] + .add_s(acc_ctrl_tick_op.output("out", 0)); + JUST(job_builder->AddOp(first_chain_src_op->parallel_desc().parallel_conf(), + acc_ctrl_tick_op.op_conf())); + + // NOTE(chengcheng): + // 2.add acc tick between first chain sink to acc chain src for strict exec order. + const OpNode* first_chain_sink_op = info.ordered_logical_chains.front()->end_op; + const auto& fc_sink_obns = first_chain_sink_op->op().output_bns(); + CHECK(!fc_sink_obns.empty()); + const std::string first_chain_sink_lbn = + GenLogicalBlobName(first_chain_sink_op->op().BnInOp2Lbi(fc_sink_obns.Get(0))); + VLOG(3) << " first_chain_sink_lbn : " << first_chain_sink_lbn; + + user_op::UserOpConfWrapper cast_to_tick_op = + user_op::UserOpConfWrapperBuilder("Sys-LogicalChainSink-CastToTick-" + NewUniqueId()) + .OpTypeName("cast_to_tick") + .Input("in", first_chain_sink_lbn) + .Output("out") + .ScopeSymbolId(first_chain_sink_op->op().op_conf().scope_symbol_id()) + .Build(); + + OperatorConf sink_acc_tick_conf; + sink_acc_tick_conf.set_name(std::string("Sys-LogicalChainSink-AccTick_") + NewUniqueId()); + sink_acc_tick_conf.set_scope_symbol_id( + first_chain_sink_op->op().op_conf().scope_symbol_id()); + auto* acc_conf = sink_acc_tick_conf.mutable_acc_tick_conf(); + acc_conf->set_one(cast_to_tick_op.output("out", 0)); + acc_conf->set_acc("acc"); + acc_conf->set_max_acc_num(acc_num); + + OperatorConf sink_final_tick_conf; + sink_final_tick_conf.set_name(std::string("Sys-LogicalChainSink-FinalTick-DeviceTick_") + + NewUniqueId()); + sink_final_tick_conf.set_scope_symbol_id( + first_chain_sink_op->op().op_conf().scope_symbol_id()); + auto* tick_conf = sink_final_tick_conf.mutable_device_tick_conf(); + tick_conf->add_tick(GenLogicalBlobName(sink_acc_tick_conf.name(), "acc")); + tick_conf->set_out("out"); + + JUST(MapAt(mut_op_name2conf, info.after_acc_logical_chain->begin_op->op().op_name())) + .add_ctrl_in_op_name(sink_final_tick_conf.name()); + + CHECK_JUST(job_builder->AddOp(first_chain_sink_op->parallel_desc().parallel_conf(), + cast_to_tick_op.op_conf())); + CHECK_JUST(job_builder->AddOp(first_chain_sink_op->parallel_desc().parallel_conf(), + sink_acc_tick_conf)); + 
CHECK_JUST(job_builder->AddOp(first_chain_sink_op->parallel_desc().parallel_conf(), + sink_final_tick_conf)); } } From c4ce8fbe451a59f41c88ebe6ec0f1abfb8210a4c Mon Sep 17 00:00:00 2001 From: chengtbf <472491134@qq.com> Date: Tue, 30 Aug 2022 08:29:07 +0000 Subject: [PATCH 15/66] merge mem block by logical chain id group --- oneflow/core/job/job_builder.h | 2 + oneflow/core/job/plan_util.cpp | 98 ++++++++++++++++++- .../core/job_rewriter/logical_chain_pass.cpp | 40 ++------ 3 files changed, 109 insertions(+), 31 deletions(-) diff --git a/oneflow/core/job/job_builder.h b/oneflow/core/job/job_builder.h index a954d12ed7e..f3c78d671a5 100644 --- a/oneflow/core/job/job_builder.h +++ b/oneflow/core/job/job_builder.h @@ -43,6 +43,8 @@ class JobBuilder final { return job_->mutable_job_parallel_view_conf(); } + MergedLogicalChainIdGroup* add_logical_chain_groups() { return job_->add_logical_chain_groups(); } + Maybe OpConf4OpName(const std::string& op_name) const; Maybe MutableOpConf4OpName(const std::string& op_name); diff --git a/oneflow/core/job/plan_util.cpp b/oneflow/core/job/plan_util.cpp index 25fc29d38ba..40de2ab8d26 100644 --- a/oneflow/core/job/plan_util.cpp +++ b/oneflow/core/job/plan_util.cpp @@ -189,7 +189,103 @@ void GenChunkForMultiNNGraphMemoryReuseInMultiClient( } // namespace void PlanUtil::MergeMemBlockIdByLogicalChainId(Plan* plan, const Job& job) { - // TODO(); + if (job.logical_chain_groups_size() == 0) { return; } + HashMap> mem_block_id2regsts; + HashMap> logical_chain_id2machine_id2mem_block_id; + + // HashMap mem_block_id2machine_id; + + for (int64_t i = 0; i < plan->task_size(); ++i) { + TaskProto* task = plan->mutable_task(i); + const StreamId stream_id = PlanUtil::GetStreamId(*task); + int64_t machine_id = task->machine_id(); + DeviceType device_type = stream_id.device_id().device_type(); + // TODO(zwx): eliminate this special 'is cpu' determine + if (device_type == DeviceType::kCPU) { continue; } + if (task->task_set_info().chain_id() == -1) { continue; } + int64_t logical_chain_id = task->task_set_info().chain_id(); + + for (auto& pair : *(task->mutable_produced_regst_desc())) { + RegstDescProto* regst_desc = &pair.second; + if (regst_desc->mem_block_id() != -1 && regst_desc->enable_reuse_mem() + && regst_desc->mem_case().device_type() == device_type + && regst_desc->regst_desc_type().has_data_regst_desc()) { + int64_t mem_block_id = regst_desc->mem_block_id(); + mem_block_id2regsts[mem_block_id].insert(regst_desc); + auto* rank2blocks = &(logical_chain_id2machine_id2mem_block_id[logical_chain_id]); + if (rank2blocks->find(machine_id) == rank2blocks->end()) { + rank2blocks->emplace(machine_id, mem_block_id); + } else { + CHECK_EQ(rank2blocks->at(machine_id), mem_block_id); + } + + /* + if (mem_block_id2machine_id.find(mem_block_id) == mem_block_id2machine_id.end()) { + mem_block_id2machine_id.emplace(mem_block_id, machine_id); + } else { + CHECK_EQ(mem_block_id2machine_id.at(mem_block_id), machine_id); + } + */ + } + } + } + + HashMap mem_block_id2merged_mem_block_id; + for (const auto& logical_chain_group : job.logical_chain_groups()) { + CHECK_GE(logical_chain_group.logical_chain_id_list_size(), 2); + int64_t merged_logical_chain_id = logical_chain_group.logical_chain_id_list(0); + CHECK(logical_chain_id2machine_id2mem_block_id.find(merged_logical_chain_id) + != logical_chain_id2machine_id2mem_block_id.end()); + const auto& merged_rank2block = + logical_chain_id2machine_id2mem_block_id.at(merged_logical_chain_id); + for (int64_t i = 1; i < 
logical_chain_group.logical_chain_id_list_size(); ++i) { + int64_t this_logical_chain_id = logical_chain_group.logical_chain_id_list(i); + // NOTE(chengcheng): merge mem block id by each rank + CHECK(logical_chain_id2machine_id2mem_block_id.find(this_logical_chain_id) + != logical_chain_id2machine_id2mem_block_id.end()); + const auto& this_rank2block = + logical_chain_id2machine_id2mem_block_id.at(this_logical_chain_id); + for (const auto& pair : this_rank2block) { + int64_t this_machine_id = pair.first; + int64_t this_mem_block_id = pair.second; + CHECK(merged_rank2block.find(this_machine_id) != merged_rank2block.end()); + int64_t merged_mem_block_id = merged_rank2block.at(this_machine_id); + CHECK(mem_block_id2merged_mem_block_id.emplace(this_mem_block_id, merged_mem_block_id) + .second); + VLOG(2) << " merge mem_block_id: " << this_mem_block_id << " to " << merged_mem_block_id; + } + } + } + + for (int64_t i = 0; i < plan->task_size(); ++i) { + TaskProto* task = plan->mutable_task(i); + const StreamId stream_id = PlanUtil::GetStreamId(*task); + DeviceType device_type = stream_id.device_id().device_type(); + // TODO(zwx): eliminate this special 'is cpu' determine + if (device_type == DeviceType::kCPU) { continue; } + if (task->task_set_info().chain_id() == -1) { continue; } + + for (auto& pair : *(task->mutable_produced_regst_desc())) { + RegstDescProto* regst_desc = &pair.second; + if (regst_desc->mem_block_id() != -1 && regst_desc->enable_reuse_mem() + && regst_desc->mem_case().device_type() == device_type + && regst_desc->regst_desc_type().has_data_regst_desc()) { + int64_t mem_block_id = regst_desc->mem_block_id(); + if (mem_block_id2merged_mem_block_id.find(mem_block_id) + != mem_block_id2merged_mem_block_id.end()) { + // merge mem_block_id + int64_t merged_mem_block_id = mem_block_id2merged_mem_block_id.at(mem_block_id); + regst_desc->set_mem_block_id(merged_mem_block_id); + const auto& data_regst = regst_desc->regst_desc_type().data_regst_desc(); + CHECK_GE(data_regst.lbi2blob_desc_size(), 1); + const auto& lbi2blob_desc_pair = data_regst.lbi2blob_desc(0); + std::string tensor_name = GenLogicalBlobName(lbi2blob_desc_pair.lbi()); + VLOG(3) << " regst: " << tensor_name << " merge mem block id " << mem_block_id << " to " + << merged_mem_block_id; + } + } + } + } } void PlanUtil::GenMemBlockAndChunkWithVariableOpNames4Plan( diff --git a/oneflow/core/job_rewriter/logical_chain_pass.cpp b/oneflow/core/job_rewriter/logical_chain_pass.cpp index 1ee7d333bd1..20d8bc79250 100644 --- a/oneflow/core/job_rewriter/logical_chain_pass.cpp +++ b/oneflow/core/job_rewriter/logical_chain_pass.cpp @@ -395,11 +395,13 @@ Maybe LogicalChainPass::Apply(const OpGraph& op_graph, JobBuilder* job_bui InsertLogicalChainId(chain_order_ops, info.after_acc_logical_chain->logical_chain_id); InsertCtrlEdgeInChain(chain_order_ops); + const auto& first_chain = info.ordered_logical_chains.front(); + // NOTE(chengcheng): // 1.add acc ctrl tick between first chain src to acc chain sink for memory lock. 
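The memory-lock note above is what makes the block merge in MergeMemBlockIdByLogicalChainId (plan_util.cpp earlier in this patch) safe: the first chain cannot start its next accumulation round while the after-acc chain still owns the shared memory. The merge itself reduces to a per-rank remap of mem_block_id, roughly as in this stand-alone sketch (container and function names are illustrative, not OneFlow types):

#include <cstdint>
#include <iostream>
#include <unordered_map>
#include <vector>

using Rank2Block = std::unordered_map<int64_t, int64_t>;  // machine_id -> mem_block_id

// For each group, map every other chain's per-rank block onto the first chain's block.
std::unordered_map<int64_t, int64_t> BuildBlockRemap(
    const std::vector<std::vector<int64_t>>& chain_groups,
    const std::unordered_map<int64_t, Rank2Block>& chain2rank2block) {
  std::unordered_map<int64_t, int64_t> old2new;
  for (const auto& group : chain_groups) {
    const Rank2Block& merged = chain2rank2block.at(group.front());
    for (size_t i = 1; i < group.size(); ++i) {
      for (const auto& [rank, block] : chain2rank2block.at(group[i])) {
        old2new[block] = merged.at(rank);
      }
    }
  }
  return old2new;
}

int main() {
  // Chain 0 (before acc) and chain 7 (after acc) each own one reused block per rank.
  std::unordered_map<int64_t, Rank2Block> chain2rank2block = {
      {0, {{0, 100}, {1, 101}}},
      {7, {{0, 200}, {1, 201}}},
  };
  for (const auto& [old_id, new_id] : BuildBlockRemap({{0, 7}}, chain2rank2block)) {
    std::cout << "mem_block " << old_id << " -> " << new_id << "\n";
  }
  return 0;
}

Applying the remap then amounts to rewriting each matching regst_desc's mem_block_id, which is what the plan_util.cpp hunk above does.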
const int64_t acc_num = job_builder->job().job_conf().num_gradient_accumulation_steps(); CHECK_GT(acc_num, 1); - const OpNode* first_chain_src_op = info.ordered_logical_chains.front()->begin_op; + const OpNode* first_chain_src_op = first_chain->begin_op; const auto& fc_src_obns = first_chain_src_op->op().output_bns(); CHECK(!fc_src_obns.empty()); const std::string& first_chain_src_out_lbn = @@ -425,7 +427,7 @@ Maybe LogicalChainPass::Apply(const OpGraph& op_graph, JobBuilder* job_bui // NOTE(chengcheng): // 2.add acc tick between first chain sink to acc chain src for strict exec order. - const OpNode* first_chain_sink_op = info.ordered_logical_chains.front()->end_op; + const OpNode* first_chain_sink_op = first_chain->end_op; const auto& fc_sink_obns = first_chain_sink_op->op().output_bns(); CHECK(!fc_sink_obns.empty()); const std::string first_chain_sink_lbn = @@ -467,6 +469,12 @@ Maybe LogicalChainPass::Apply(const OpGraph& op_graph, JobBuilder* job_bui sink_acc_tick_conf)); CHECK_JUST(job_builder->AddOp(first_chain_sink_op->parallel_desc().parallel_conf(), sink_final_tick_conf)); + + // NOTE(chengcheng): + // 3. merge first chain and acc chain + MergedLogicalChainIdGroup* group = job_builder->add_logical_chain_groups(); + group->add_logical_chain_id_list(first_chain->logical_chain_id); + group->add_logical_chain_id_list(info.after_acc_logical_chain->logical_chain_id); } } @@ -481,34 +489,6 @@ Maybe LogicalChainPass::Apply(const OpGraph& op_graph, JobBuilder* job_bui << " op_name: " << ordered_op->op().op_name() << " global_order: " << JUST(MapAt(op_node2global_order, ordered_op)); } - - // NOTE(chengcheng): - // creat repeat tick - // 1.add acc ctrl tick between first chain src to acc chain sink for memory lock. - // 2.add acc tick between first chain sink to acc chain src for strict exec order. - const int64_t acc_num = job_builder->job().job_conf().num_gradient_accumulation_steps(); - CHECK_GT(acc_num, 1); - const OpNode* first_chain_src_op = info.ordered_logical_chains.front()->begin_op; - const auto& fcs_obns = first_chain_src_op->op().output_bns(); - CHECK(!fcs_obns.empty()); - const std::string& first_chain_src_out_lbn = - GenLogicalBlobName(first_chain_src_op->op().BnInOp2Lbi(fcs_obns.Get(0))); - - VLOG(3) << " first_chain_src_out_lbn : " << first_chain_src_out_lbn; - user_op::UserOpConfWrapper acc_ctrl_tick_op = - user_op::UserOpConfWrapperBuilder("Sys-AccCtrlTick4MergeFirstAccChain-" + NewUniqueId()) - .OpTypeName("acc_ctrl_tick") - .Input("in", first_chain_src_out_lbn) - .Output("out") - .ScopeSymbolId(first_chain_src_op->op().op_conf().scope_symbol_id()) - .Attr("max_acc_num", acc_num) - .Build(); - - JUST(MapAt(mut_op_name2conf, info.after_acc_logical_chain->end_op->op().op_name())) - .add_ctrl_in_op_name(acc_ctrl_tick_op.op_name()); - - JUST(job_builder->AddOp(first_chain_src_op->parallel_desc().parallel_conf(), - acc_ctrl_tick_op.op_conf())); } VLOG(3) << " In placement: " << placement From e4087e306f7544e068462cdbbb6aaed0c82a28f0 Mon Sep 17 00:00:00 2001 From: chengtbv <472491134@qq.com> Date: Sun, 4 Sep 2022 10:41:15 +0000 Subject: [PATCH 16/66] fix user op register --- oneflow/user/ops/acc_ctrl_tick_op.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/oneflow/user/ops/acc_ctrl_tick_op.cpp b/oneflow/user/ops/acc_ctrl_tick_op.cpp index 9db2fce4bdc..f97fcadd9f9 100644 --- a/oneflow/user/ops/acc_ctrl_tick_op.cpp +++ b/oneflow/user/ops/acc_ctrl_tick_op.cpp @@ -20,8 +20,7 @@ limitations under the License. 
namespace oneflow { /* static */ Maybe AccCtrlTickOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - Shape* out_shape = ctx->MutOutputShape("out", 0); - *out_shape = Shape({1}); + ctx->SetOutputShape("out", 0, Shape({1})); return Maybe::Ok(); } @@ -54,7 +53,7 @@ namespace oneflow { } /* static */ Maybe AccCtrlTickOp::InferDataType(user_op::InferContext* ctx) { - *ctx->MutOutputDType("out", 0) = ctx->InputDType("in", 0); + ctx->SetOutputDType("out", 0, ctx->InputDType("in", 0)); return Maybe::Ok(); } From 10a97558feb6123cfce1aa884294a5f6a2f1a89a Mon Sep 17 00:00:00 2001 From: chengtbv <472491134@qq.com> Date: Sun, 4 Sep 2022 10:54:18 +0000 Subject: [PATCH 17/66] fix GLOG error when no grad acc --- .../core/job_rewriter/logical_chain_pass.cpp | 41 ++++++++++--------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/oneflow/core/job_rewriter/logical_chain_pass.cpp b/oneflow/core/job_rewriter/logical_chain_pass.cpp index 20d8bc79250..f666122fe29 100644 --- a/oneflow/core/job_rewriter/logical_chain_pass.cpp +++ b/oneflow/core/job_rewriter/logical_chain_pass.cpp @@ -374,6 +374,19 @@ Maybe LogicalChainPass::Apply(const OpGraph& op_graph, JobBuilder* job_bui InsertCtrlEdgeInChain(logical_chain->ordered_op_nodes); } + for (const auto& logical_chain : info.ordered_logical_chains) { + VLOG(3) << " In placement: " << placement + << " logical_chain_id: " << logical_chain->logical_chain_id + << " has op num = " << logical_chain->ordered_op_nodes.size(); + + for (int i = 0; i < logical_chain->ordered_op_nodes.size(); ++i) { + const OpNode* ordered_op = JUST(VectorAt(logical_chain->ordered_op_nodes, i)); + VLOG(3) << " ChainId: " << logical_chain_id << " order: " << i + << " op_name: " << ordered_op->op().op_name() + << " global_order: " << JUST(MapAt(op_node2global_order, ordered_op)); + } + } + // NOTE(chengcheng): create logical chain after acc, and merge with first logical chain. 
const std::vector& ordered_acc_op_nodes = info.ordered_acc_op_nodes; if (!ordered_acc_op_nodes.empty()) { @@ -476,31 +489,19 @@ Maybe LogicalChainPass::Apply(const OpGraph& op_graph, JobBuilder* job_bui group->add_logical_chain_id_list(first_chain->logical_chain_id); group->add_logical_chain_id_list(info.after_acc_logical_chain->logical_chain_id); } - } - for (const auto& logical_chain : info.ordered_logical_chains) { VLOG(3) << " In placement: " << placement - << " logical_chain_id: " << logical_chain->logical_chain_id - << " has op num = " << logical_chain->ordered_op_nodes.size(); - - for (int i = 0; i < logical_chain->ordered_op_nodes.size(); ++i) { - const OpNode* ordered_op = JUST(VectorAt(logical_chain->ordered_op_nodes, i)); - VLOG(3) << " ChainId: " << logical_chain_id << " order: " << i - << " op_name: " << ordered_op->op().op_name() + << " AccLogicalChain: " << info.after_acc_logical_chain->logical_chain_id + << " has op num = " << info.after_acc_logical_chain->ordered_op_nodes.size(); + + for (int i = 0; i < info.after_acc_logical_chain->ordered_op_nodes.size(); ++i) { + const OpNode* ordered_op = + JUST(VectorAt(info.after_acc_logical_chain->ordered_op_nodes, i)); + VLOG(3) << " AfterAccChainId: " << info.after_acc_logical_chain->logical_chain_id + << " order: " << i << " op_name: " << ordered_op->op().op_name() << " global_order: " << JUST(MapAt(op_node2global_order, ordered_op)); } } - - VLOG(3) << " In placement: " << placement - << " AccLogicalChain: " << info.after_acc_logical_chain->logical_chain_id - << " has op num = " << info.after_acc_logical_chain->ordered_op_nodes.size(); - - for (int i = 0; i < info.after_acc_logical_chain->ordered_op_nodes.size(); ++i) { - const OpNode* ordered_op = JUST(VectorAt(info.after_acc_logical_chain->ordered_op_nodes, i)); - VLOG(3) << " AfterAccChainId: " << info.after_acc_logical_chain->logical_chain_id - << " order: " << i << " op_name: " << ordered_op->op().op_name() - << " global_order: " << JUST(MapAt(op_node2global_order, ordered_op)); - } } // NOTE(chengcheng): update global order and chain id for ops. From a1c71bc0a6307ea5934a1d73075044043ccc176a Mon Sep 17 00:00:00 2001 From: chengtbf <472491134@qq.com> Date: Sun, 4 Sep 2022 16:38:02 +0000 Subject: [PATCH 18/66] Inplace repeat variable --- oneflow/core/framework/nn_graph.cpp | 1 - .../graph_impl/repeat_compute_task_node.cpp | 11 +- oneflow/core/job/compiler.cpp | 2 + oneflow/core/job/plan_util.cpp | 97 +++++++-- oneflow/core/lazy/actor/repeat_actor.cpp | 186 ++++++++++++++---- oneflow/core/operator/device_tick_op.cpp | 1 - oneflow/user/kernels/nop_kernel.cpp | 1 + oneflow/user/kernels/repeat_kernel.cpp | 2 + 8 files changed, 236 insertions(+), 65 deletions(-) diff --git a/oneflow/core/framework/nn_graph.cpp b/oneflow/core/framework/nn_graph.cpp index cc3d3b1ab94..e86a5f81661 100644 --- a/oneflow/core/framework/nn_graph.cpp +++ b/oneflow/core/framework/nn_graph.cpp @@ -297,7 +297,6 @@ Maybe NNGraph::CompileAndInitRuntime() { PlanUtil::GenRegisterHint(&plan_); // TODO(chengcheng): test collective boxing for multi-job. PlanUtil::GenCollectiveBoxingPlan(&job_, &plan_); - // PlanUtil::SetForceInplaceMemBlock(&plan_); NOTE(chengcheng): only for ssp. 
PlanUtil::DumpCtrlRegstInfoToPlan(&plan_); PlanUtil::PlanMemoryLog(&plan_, name_); if (Singleton::Get()->enable_debug_mode()) { diff --git a/oneflow/core/graph_impl/repeat_compute_task_node.cpp b/oneflow/core/graph_impl/repeat_compute_task_node.cpp index 9ff9073b88a..8c887688d75 100644 --- a/oneflow/core/graph_impl/repeat_compute_task_node.cpp +++ b/oneflow/core/graph_impl/repeat_compute_task_node.cpp @@ -43,14 +43,23 @@ void RepeatCompTaskNode::ProduceAllRegstsAndBindEdges() { } void RepeatCompTaskNode::BuildExecGphAndRegst() { + std::shared_ptr in_regst = GetSoleConsumedRegst("in"); ExecNode* node = mut_exec_gph().NewNode(); std::shared_ptr sole_op = op(); node->mut_op() = sole_op; - node->BindBnWithRegst(sole_op->SoleIbn(), GetSoleConsumedRegst("in")); + node->BindBnWithRegst(sole_op->SoleIbn(), in_regst); std::shared_ptr out_regst = GetProducedRegst("out"); out_regst->AddLbi(sole_op->BnInOp2Lbi(sole_op->SoleObn())); node->BindBnWithRegst(sole_op->SoleObn(), out_regst); node->InferBlobDescs(parallel_ctx()); + + // NOTE(chengcheng): force inplace + CHECK_EQ(in_regst->NumOfLbi(), 1); + CHECK_EQ(out_regst->NumOfLbi(), 1); + CHECK_EQ(in_regst->min_register_num(), 1); + // NOTE(chengcheng): input need unreused mem + in_regst->set_enable_reuse_mem(false); + out_regst->set_force_inplace_consumed_regst_desc_id(in_regst->regst_desc_id()); } REGISTER_COMP_TASK_STREAM_INDEX_GETTER(TaskType::kRepeat); diff --git a/oneflow/core/job/compiler.cpp b/oneflow/core/job/compiler.cpp index fb757c75216..ef734de19b1 100644 --- a/oneflow/core/job/compiler.cpp +++ b/oneflow/core/job/compiler.cpp @@ -104,9 +104,11 @@ void Compiler::Compile(Job* job, Plan* plan) const { auto* job_id2job_conf = plan->mutable_job_confs()->mutable_job_id2job_conf(); (*job_id2job_conf)[GlobalJobDesc().job_id()] = GlobalJobDesc().job_conf(); // NOTE(chengcheng): infer mem blob id & set inplace & add ctrl + // TODO(chengcheng): set inplace hint by cpu regst IntraJobMemSharingUtil::InferMemBlockId4MemReusedRegst(plan, IsReachable); PlanUtil::MergeMemBlockIdByLogicalChainId(plan, *job); PlanUtil::SetUniqueMemBlockId4UnreusedMemRegst(plan); + PlanUtil::SetForceInplaceMemBlock(plan); Singleton::Delete(); } diff --git a/oneflow/core/job/plan_util.cpp b/oneflow/core/job/plan_util.cpp index 40de2ab8d26..6e2cfa474f3 100644 --- a/oneflow/core/job/plan_util.cpp +++ b/oneflow/core/job/plan_util.cpp @@ -55,6 +55,7 @@ std::function PlanUtil::MakeGetterTaskProto4TaskId(co void PlanUtil::SetUniqueMemBlockId4UnreusedMemRegst(Plan* plan) { for (int i = 0; i < plan->task_size(); i++) { TaskProto* task = plan->mutable_task(i); + for (auto& pair : *task->mutable_produced_regst_desc()) { RegstDescProto* regst_desc = &pair.second; if (regst_desc->mem_block_id() == -1) { @@ -62,6 +63,24 @@ void PlanUtil::SetUniqueMemBlockId4UnreusedMemRegst(Plan* plan) { regst_desc->set_mem_block_id(Singleton::Get()->NewMemBlockId()); regst_desc->set_mem_block_offset(0); } + // NOTE(chengcheng): set variable_op_name before set separated header because var regst alway + // separated. 
+ if (task->exec_sequence().exec_node_size() == 1) { + const auto& op_conf = + GetOpAttribute(plan, task->job_id(), task->exec_sequence().exec_node(0).kernel_conf()) + .op_conf(); + if (op_conf.has_variable_conf()) { regst_desc->set_variable_op_name(op_conf.name()); } + } + + RtRegstDesc rt_regst_desc(*regst_desc); + int64_t regst_separated_size = rt_regst_desc.TotalSeparatedHeaderByteSize4AllRegst(); + if (regst_separated_size > 0) { + int64_t separated_mem_block_id = Singleton::Get()->NewMemBlockId(); + regst_desc->set_separated_header_mem_block_id(separated_mem_block_id); + LOG(INFO) << "set sep id, regst: " << regst_desc->regst_desc_id() + << " , sep : " << separated_mem_block_id + << " , debug: " << regst_desc->DebugString(); + } } } } @@ -318,7 +337,6 @@ void PlanUtil::GenMemBlockAndChunkWithVariableOpNames4Plan( int64_t mem_block_offset = regst_desc->mem_block_offset(); CHECK_NE(mem_block_id, -1); CHECK_NE(mem_block_offset, -1); - CHECK_EQ(regst_desc->separated_header_mem_block_id(), -1); std::string var_name; bool is_variable_regst = IsVariableRegst(task, &var_name); @@ -353,32 +371,59 @@ void PlanUtil::GenMemBlockAndChunkWithVariableOpNames4Plan( .second); } else { MemBlockProto* mem_block = mem_block_id2mem_block.at(mem_block_id).get(); - CHECK(!mem_block->has_variable_op_name()); // variable regst mem block is unique. CHECK_EQ(mem_block->job_id(0), job_id); CHECK_EQ(mem_block->machine_id(), machine_id); CHECK(mem_block->mem_case() == regst_desc->mem_case()); CHECK_EQ(mem_block->enable_reuse_mem(), regst_desc->enable_reuse_mem()); - mem_block->set_mem_size(std::max(mem_block->mem_size(), regst_main_size + mem_block_offset)); + if (mem_block->enable_reuse_mem()) { + mem_block->set_mem_size( + std::max(mem_block->mem_size(), regst_main_size + mem_block_offset)); + } else { + CHECK_EQ(mem_block->mem_size(), regst_main_size); + CHECK_EQ(mem_block_offset, 0); + } + if (is_variable_regst) { + mem_block->set_variable_op_name(var_name); + mem_block->set_is_separated_header(false); + } } if (regst_separated_size > 0) { - int64_t separated_mem_block_id = Singleton::Get()->NewMemBlockId(); - regst_desc->set_separated_header_mem_block_id(separated_mem_block_id); - MemBlockProto mem_block; - mem_block.set_mem_block_id(separated_mem_block_id); - mem_block.add_job_id(job_id); - mem_block.set_machine_id(machine_id); - *(mem_block.mutable_mem_case()) = memory::GetPinnedHostMemoryCase(regst_desc->mem_case()); - mem_block.set_enable_reuse_mem(false); - mem_block.set_mem_size(regst_separated_size); - mem_block.set_thrd_id_hint(thrd_id); - if (is_variable_regst) { - mem_block.set_variable_op_name(var_name); - mem_block.set_is_separated_header(true); + if (regst_desc->has_separated_header_mem_block_id()) { + LOG(INFO) << "ccdebuglog: wrong, sep id, regst: " << regst_desc->regst_desc_id() + << " , debug: " << regst_desc->DebugString(); + } + CHECK(regst_desc->has_separated_header_mem_block_id()) << regst_desc->DebugString(); + int64_t separated_mem_block_id = regst_desc->separated_header_mem_block_id(); + CHECK_NE(separated_mem_block_id, -1); + if (mem_block_id2mem_block.find(separated_mem_block_id) == mem_block_id2mem_block.end()) { + MemBlockProto mem_block; + mem_block.set_mem_block_id(separated_mem_block_id); + mem_block.add_job_id(job_id); + mem_block.set_machine_id(machine_id); + *(mem_block.mutable_mem_case()) = memory::GetPinnedHostMemoryCase(regst_desc->mem_case()); + mem_block.set_enable_reuse_mem(false); + mem_block.set_mem_size(regst_separated_size); + mem_block.set_thrd_id_hint(thrd_id); + 
if (is_variable_regst) { + mem_block.set_variable_op_name(var_name); + mem_block.set_is_separated_header(true); + } + CHECK(mem_block_id2mem_block + .emplace(mem_block.mem_block_id(), std::make_unique(mem_block)) + .second); + } else { + MemBlockProto* mem_block = mem_block_id2mem_block.at(separated_mem_block_id).get(); + CHECK_EQ(mem_block->job_id(0), job_id); + CHECK_EQ(mem_block->machine_id(), machine_id); + CHECK(mem_block->mem_case() == memory::GetPinnedHostMemoryCase(regst_desc->mem_case())); + CHECK_EQ(mem_block->enable_reuse_mem(), false); + CHECK_EQ(mem_block->mem_size(), regst_separated_size); + if (is_variable_regst) { + mem_block->set_variable_op_name(var_name); + mem_block->set_is_separated_header(true); + } } - CHECK(mem_block_id2mem_block - .emplace(mem_block.mem_block_id(), std::make_unique(mem_block)) - .second); } }; @@ -756,8 +801,22 @@ void PlanUtil::SetForceInplaceMemBlock(Plan* plan) { CHECK_EQ(in_regst_desc->mem_block_offset(), 0); CHECK_EQ(regst_desc->mem_block_offset(), 0); CHECK_EQ(in_regst_desc->register_num(), regst_desc->register_num()); + CHECK(in_regst_desc->mem_case() == regst_desc->mem_case()); + RtRegstDesc in_regst_rt(*in_regst_desc); + RtRegstDesc regst_rt(*regst_desc); + CHECK_EQ(in_regst_rt.TotalByteSize4AllRegst(), regst_rt.TotalByteSize4AllRegst()); + CHECK_EQ(in_regst_rt.TotalMainByteSize4AllRegst(), regst_rt.TotalMainByteSize4AllRegst()); + CHECK_EQ(in_regst_rt.TotalSeparatedHeaderByteSize4AllRegst(), + regst_rt.TotalSeparatedHeaderByteSize4AllRegst()); regst_desc->set_mem_block_id(in_regst_desc->mem_block_id()); regst_desc->set_inplace_consumed_regst_desc_id(force_id); + if (in_regst_desc->has_separated_header_mem_block_id()) { + CHECK(regst_desc->has_separated_header_mem_block_id()); + regst_desc->set_separated_header_mem_block_id( + in_regst_desc->separated_header_mem_block_id()); + } + LOG(INFO) << " cclog: set force inplace from " << regst_desc->DebugString() << " to " + << in_regst_desc->DebugString(); } } } diff --git a/oneflow/core/lazy/actor/repeat_actor.cpp b/oneflow/core/lazy/actor/repeat_actor.cpp index c84643a719e..134870cd872 100644 --- a/oneflow/core/lazy/actor/repeat_actor.cpp +++ b/oneflow/core/lazy/actor/repeat_actor.cpp @@ -14,28 +14,81 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/lazy/actor/actor.h" +#include "oneflow/core/framework/framework.h" namespace oneflow { class RepeatActor final : public Actor { public: OF_DISALLOW_COPY_AND_MOVE(RepeatActor); - RepeatActor() = default; + RepeatActor() + : repeat_count_(0), + repeat_num_(0), + wait_all_regst_return_(false), + consumed_var_regst_desc_id_(-1), + produced_repeat_var_regst_desc_id_(-1){}; ~RepeatActor() override = default; private: + // NOTE(chengcheng): Empty rs for naive and inplace regst, all regst is customized. + std::pair> GetNaiveOrCustomizedConsumedRegstDescName() + override { + return std::make_pair(RegstNameType::kNaive, HashSet{}); + } + std::pair> GetNaiveOrCustomizedProducedRegstDescName() + override { + return std::make_pair(RegstNameType::kNaive, HashSet{}); + } + void TakeOverInplaceConsumedAndProduced( + const PbMap& produced_ids) override { + // NOTE(chengcheng): all regst is customized. 
+ inplace_consumed_rs_.InitedDone(); + inplace_produced_rs_.InitedDone(); + } + + bool IsCustomizedReadReady() const override { + bool is_ready_ready = (!wait_all_regst_return_) && consumed_var_rs_.IsCurSlotReady(); + LOG(INFO) << " ccActorLog: actor: " << actor_id() << " is_ready_ready: " << is_ready_ready + << " of wait_all_regst_return_ = " << wait_all_regst_return_ + << " consumed_var_rs_.IsCurSlotReady = " << consumed_var_rs_.IsCurSlotReady(); + return (!wait_all_regst_return_) && consumed_var_rs_.IsCurSlotReady(); + } + bool IsCustomizedWriteReady() const override { + LOG(INFO) << " ccActorLog: actor: " << actor_id() + << " is_write_ready: " << (!wait_all_regst_return_) + && produced_repeat_var_rs_.IsCurSlotReady(); + return (!wait_all_regst_return_) && produced_repeat_var_rs_.IsCurSlotReady(); + } + + void NormalProcessCustomizedEordMsg(const ActorMsg&) override {} + bool IsCustomizedReadAlwaysUnReadyFromNow() const override { + // all Messages are flushed + return ReceiveEordMsg(consumed_var_regst_desc_id_); + } + void VirtualActorInit(const TaskProto& proto) override; void Act() override; - void VirtualAsyncSendNaiveConsumedRegstMsgToProducer() override; - void VirtualAsyncSendNaiveProducedRegstMsgToConsumer() override; - bool ConsumedCtrlRegstValid(int64_t regst_desc_id) const override; - bool IsCustomizedWriteReady() const override; - - int64_t repeat_num_; - int64_t repeat_count_; + void AsyncSendCustomizedProducedRegstMsgToConsumer() override; + void AsyncSendCustomizedConsumedRegstMsgToProducer() override; + void UpdtStateAsCustomizedProducedRegst(Regst* regst) override; + void NormalProcessCustomizedReadableRegstMsg(const ActorMsg& msg) override; + + int32_t repeat_count_; + int32_t repeat_num_; + bool wait_all_regst_return_; + int64_t consumed_var_regst_desc_id_; + int64_t produced_repeat_var_regst_desc_id_; + RegstSlot consumed_var_rs_; + RegstSlot produced_repeat_var_rs_; + // Regst* var_regst_; }; void RepeatActor::VirtualActorInit(const TaskProto& proto) { + repeat_count_ = 0; + const OperatorConf op_conf = + proto.exec_sequence().exec_node(0).kernel_conf().op_attribute().op_conf(); + repeat_num_ = user_op::UserOpConfWrapper(op_conf).attr("repeat_num"); + const Shape& in_time_shape = Singleton::Get() ->RegstDesc4RegstDescId(Name2SoleRegstDescId("in")) .data_regst_time_shape(); @@ -47,63 +100,110 @@ void RepeatActor::VirtualActorInit(const TaskProto& proto) { FOR_RANGE(int64_t, i, 0, in_time_shape.NumAxes()) { CHECK_EQ(in_time_shape.At(i), out_time_shape.At(i)); } - repeat_num_ = out_time_shape.At(out_time_shape.NumAxes() - 1); - repeat_count_ = 0; - const RegstDescProto& out_regst_desc = proto.produced_regst_desc().at("out"); + CHECK_EQ(repeat_num_, out_time_shape.At(out_time_shape.NumAxes() - 1)); + + // input + const auto& consumed_ids = proto.consumed_regst_desc_id(); + CHECK_EQ(consumed_ids.size(), 1); + CHECK(consumed_ids.find("in") != consumed_ids.end()); + const auto& in_ids = consumed_ids.at("in"); + CHECK_EQ(in_ids.regst_desc_id_size(), 1); + consumed_var_regst_desc_id_ = in_ids.regst_desc_id(0); + consumed_var_rs_.InsertRegstDescId(consumed_var_regst_desc_id_); + consumed_var_rs_.InitedDone(); + + // output + const auto& produced_ids = proto.produced_regst_desc(); + CHECK_EQ(produced_ids.size(), 1); + CHECK(produced_ids.find("out") != produced_ids.end()); + const RegstDescProto& out_regst_desc = produced_ids.at("out"); CHECK(!out_regst_desc.enable_reuse_mem()); CHECK_EQ(out_regst_desc.register_num(), 1); - + // check inplace + 
CHECK_EQ(out_regst_desc.inplace_consumed_regst_desc_id(), consumed_var_regst_desc_id_); + produced_repeat_var_regst_desc_id_ = out_regst_desc.regst_desc_id(); + produced_repeat_var_rs_.InsertRegstDescId(produced_repeat_var_regst_desc_id_); + produced_repeat_var_rs_.InitedDone(); // Regst number hacking - if (naive_consumed_rs_.total_regst_desc_cnt() != 1) { - LOG(WARNING) - << "RepeatActor has more than one consumed register. This will impact performance."; + for (int64_t i = 1; i < repeat_num_; ++i) { + Singleton::Get()->NewRegsts(out_regst_desc, [this](Regst* regst) { + produced_regsts_[this->produced_repeat_var_regst_desc_id_].emplace_back(regst); + produced_regst2reading_cnt_[regst] = 0; + }); } - for (const auto& pair : proto.produced_regst_desc()) { - const RegstDescProto& regst_desc = pair.second; - int64_t regst_desc_id = regst_desc.regst_desc_id(); - // This itor begins from 1 because first regst was already inserted in TakeOverNaiveProduced - for (int64_t i = 1; i < repeat_num_; ++i) { - Singleton::Get()->NewRegsts(regst_desc, [this, regst_desc_id](Regst* regst) { - produced_regsts_[regst_desc_id].emplace_back(regst); - produced_regst2reading_cnt_[regst] = 0; - naive_produced_rs_.TryPushBackRegst(regst); - }); - } - } + ForEachProducedRegst([&](Regst* regst) { + if (regst->regst_desc_id() != produced_repeat_var_regst_desc_id_) { return; } + CHECK_EQ(0, produced_repeat_var_rs_.TryPushBackRegst(regst)); + }); OF_SET_MSG_HANDLER(&RepeatActor::HandlerNormal); } void RepeatActor::Act() { - // reset repeat_count if need - if (repeat_count_ == repeat_num_) { repeat_count_ = 0; } + repeat_count_ += 1; - if (repeat_count_ == 0) { AsyncLaunchKernel(); } + if (repeat_count_ == repeat_num_) { + wait_all_regst_return_ = true; + repeat_count_ = 0; + } - repeat_count_ += 1; + Regst* out_regst = produced_repeat_var_rs_.Front(produced_repeat_var_regst_desc_id_); + Regst* in_regst = consumed_var_rs_.Front(consumed_var_regst_desc_id_); + CHECK(out_regst && in_regst); + // LOG(WARNING) << "cclog: RepeatVarActor: " + // << out_regst->regst_desc()->regst_desc_type().DebugString(); + CHECK(out_regst->main_mem_ptr() == in_regst->main_mem_ptr()); + CHECK(out_regst->separated_header_mem_ptr() == in_regst->separated_header_mem_ptr()); + CHECK_EQ(out_regst->regst_desc()->MainByteSize4OneRegst(), + in_regst->regst_desc()->MainByteSize4OneRegst()); + CHECK_EQ(out_regst->regst_desc()->SeparatedHeaderByteSize4OneRegst(), + in_regst->regst_desc()->SeparatedHeaderByteSize4OneRegst()); } -void RepeatActor::VirtualAsyncSendNaiveConsumedRegstMsgToProducer() { - if (repeat_count_ == repeat_num_) { HandleConsumedNaiveDataRegstToProducer(); } -} +void RepeatActor::AsyncSendCustomizedProducedRegstMsgToConsumer() { + CHECK(produced_repeat_var_rs_.IsCurSlotReady()); + Regst* const repeat_var_regst = produced_repeat_var_rs_.Front(produced_repeat_var_regst_desc_id_); + CHECK_GT(HandleRegstToConsumer(repeat_var_regst), 0); + produced_repeat_var_rs_.PopFrontRegsts({produced_repeat_var_regst_desc_id_}); -void RepeatActor::VirtualAsyncSendNaiveProducedRegstMsgToConsumer() { - HandleProducedNaiveDataRegstToConsumer(); + LOG(INFO) << "ccActorLog: repeat actor: " << actor_id() << " in repeat count: " << repeat_count_ + << " Send var regst " << produced_repeat_var_regst_desc_id_ << " to Consumer."; } -bool RepeatActor::ConsumedCtrlRegstValid(int64_t regst_desc_id) const { - return repeat_count_ == repeat_num_; +void RepeatActor::AsyncSendCustomizedConsumedRegstMsgToProducer() { + // NOTE(chengcheng): do nothing. 
consumed var regst will return in inplace done. + LOG(INFO) << "ccActorLog: repeat actor: " << actor_id() << " in repeat count: " << repeat_count_ + << " NOT return var regst for waiting inplace var regst returned. "; } -bool RepeatActor::IsCustomizedWriteReady() const { - if (repeat_count_ % repeat_num_ == 0) { - return total_reading_cnt_ == 0; - } else { - return true; +void RepeatActor::UpdtStateAsCustomizedProducedRegst(Regst* regst) { + CHECK_EQ(regst->regst_desc_id(), produced_repeat_var_regst_desc_id_); + CHECK_EQ(produced_repeat_var_rs_.TryPushBackRegst(regst), 0); + LOG(INFO) << "ccActorLog: repeat actor: " << actor_id() << " in count: " << repeat_count_ + << " regst_desc_id: " << produced_repeat_var_regst_desc_id_ << " ready size = " + << produced_repeat_var_rs_.GetReadyRegstSize(produced_repeat_var_regst_desc_id_); + + if (wait_all_regst_return_ + && produced_repeat_var_rs_.GetReadyRegstSize(produced_repeat_var_regst_desc_id_) + == repeat_num_) { + Regst* in_regst = consumed_var_rs_.Front(consumed_var_regst_desc_id_); + CHECK(in_regst); + AsyncSendRegstMsgToProducer(in_regst); + CHECK_EQ(0, consumed_var_rs_.TryPopFrontRegst(consumed_var_regst_desc_id_)); + wait_all_regst_return_ = false; + + LOG(INFO) << "ccActorLog: repeat actor: " << actor_id() << " in count: " << repeat_count_ + << " consumed_var_regst_desc_id: " << consumed_var_regst_desc_id_ + << " return with all produced regst."; } } +void RepeatActor::NormalProcessCustomizedReadableRegstMsg(const ActorMsg& msg) { + CHECK_EQ(0, consumed_var_rs_.TryPushBackRegst(msg.regst())); + LOG(INFO) << "ccActorLog: actor: " << actor_id() << " in count: " << repeat_count_ + << " receive var regst: " << msg.regst()->regst_desc_id(); +} REGISTER_ACTOR(TaskType::kRepeat, RepeatActor); } // namespace oneflow diff --git a/oneflow/core/operator/device_tick_op.cpp b/oneflow/core/operator/device_tick_op.cpp index dc331fbb43f..306371f6bfa 100644 --- a/oneflow/core/operator/device_tick_op.cpp +++ b/oneflow/core/operator/device_tick_op.cpp @@ -82,7 +82,6 @@ Maybe DeviceTickOp::InferOpTimeShape( return Maybe::Ok(); } -REGISTER_OP_SAME_OUTPUT_BLOB_REGST_NUM(OperatorConf::kDeviceTickConf, 2); REGISTER_OP(OperatorConf::kDeviceTickConf, DeviceTickOp); REGISTER_TICK_TOCK_OP(OperatorConf::kDeviceTickConf); diff --git a/oneflow/user/kernels/nop_kernel.cpp b/oneflow/user/kernels/nop_kernel.cpp index 66017e492f8..a8f813cf921 100644 --- a/oneflow/user/kernels/nop_kernel.cpp +++ b/oneflow/user/kernels/nop_kernel.cpp @@ -36,6 +36,7 @@ class NopKernel final : public user_op::OpKernel { REGISTER_NOP_KERNEL("cast_to_tick") REGISTER_NOP_KERNEL("acc_ctrl_tick") +REGISTER_NOP_KERNEL("repeat") } // namespace diff --git a/oneflow/user/kernels/repeat_kernel.cpp b/oneflow/user/kernels/repeat_kernel.cpp index 4ea643a787d..c664887936f 100644 --- a/oneflow/user/kernels/repeat_kernel.cpp +++ b/oneflow/user/kernels/repeat_kernel.cpp @@ -38,6 +38,7 @@ class RepeatKernel final : public user_op::OpKernel { bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; +/* #define REGISTER_REPEAT_KERNEL(device) \ REGISTER_USER_KERNEL("repeat").SetCreateFn>().SetIsMatchedHob( \ (user_op::HobDeviceType() == device)); @@ -45,6 +46,7 @@ class RepeatKernel final : public user_op::OpKernel { OF_PP_FOR_EACH_TUPLE(REGISTER_REPEAT_KERNEL, DEVICE_TYPE_SEQ) #undef REGISTER_REPEAT_KERNEL +*/ } // namespace From cb31e5322e7a535be680df03c1923eca1a89c3c2 Mon Sep 17 00:00:00 2001 From: chengtbf <472491134@qq.com> Date: Tue, 6 Sep 2022 07:14:43 +0000 Subject: [PATCH 19/66] Inplace 
repeat support consumed/produced ctrl regst --- oneflow/core/job/plan_util.cpp | 8 +++++ oneflow/core/lazy/actor/repeat_actor.cpp | 39 ++++++++++++++++++------ 2 files changed, 37 insertions(+), 10 deletions(-) diff --git a/oneflow/core/job/plan_util.cpp b/oneflow/core/job/plan_util.cpp index 6e2cfa474f3..2384feb41b8 100644 --- a/oneflow/core/job/plan_util.cpp +++ b/oneflow/core/job/plan_util.cpp @@ -820,6 +820,14 @@ void PlanUtil::SetForceInplaceMemBlock(Plan* plan) { } } } + + // tmp debug log + for (int i = 0; i < plan->task_size(); i++) { + TaskProto* task = plan->mutable_task(i); + if (task->task_type() == kRepeat && task->produced_regst_desc().size() > 1) { + LOG(WARNING) << " bad repeat : " << task->DebugString(); + } + } } void PlanUtil::DumpCtrlRegstInfoToPlan(Plan* plan) { diff --git a/oneflow/core/lazy/actor/repeat_actor.cpp b/oneflow/core/lazy/actor/repeat_actor.cpp index 134870cd872..23ac25b3a83 100644 --- a/oneflow/core/lazy/actor/repeat_actor.cpp +++ b/oneflow/core/lazy/actor/repeat_actor.cpp @@ -104,7 +104,6 @@ void RepeatActor::VirtualActorInit(const TaskProto& proto) { // input const auto& consumed_ids = proto.consumed_regst_desc_id(); - CHECK_EQ(consumed_ids.size(), 1); CHECK(consumed_ids.find("in") != consumed_ids.end()); const auto& in_ids = consumed_ids.at("in"); CHECK_EQ(in_ids.regst_desc_id_size(), 1); @@ -114,7 +113,6 @@ void RepeatActor::VirtualActorInit(const TaskProto& proto) { // output const auto& produced_ids = proto.produced_regst_desc(); - CHECK_EQ(produced_ids.size(), 1); CHECK(produced_ids.find("out") != produced_ids.end()); const RegstDescProto& out_regst_desc = produced_ids.at("out"); CHECK(!out_regst_desc.enable_reuse_mem()); @@ -124,19 +122,40 @@ void RepeatActor::VirtualActorInit(const TaskProto& proto) { produced_repeat_var_regst_desc_id_ = out_regst_desc.regst_desc_id(); produced_repeat_var_rs_.InsertRegstDescId(produced_repeat_var_regst_desc_id_); produced_repeat_var_rs_.InitedDone(); - // Regst number hacking - for (int64_t i = 1; i < repeat_num_; ++i) { - Singleton::Get()->NewRegsts(out_regst_desc, [this](Regst* regst) { - produced_regsts_[this->produced_repeat_var_regst_desc_id_].emplace_back(regst); - produced_regst2reading_cnt_[regst] = 0; - }); + + // NOTE(chengcheng): repeat actor may has output ctrl regst. ctrl regst also need hack regst num. 
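The note above says that every produced regst desc of the repeat actor, ctrl regsts included, has to be backed by repeat_num register instances; the loop that follows pre-allocates them into slot structures. Below is a minimal standalone sketch of the slot contract those calls appear to rely on (a push that returns 0 on success and one ready queue per regst desc id). The ToySlot type and the main() driver are illustrative assumptions, not the real RegstSlot API.

#include <cstdint>
#include <deque>
#include <iostream>
#include <unordered_map>

struct ToySlot {
  // One ready queue per register description id.
  std::unordered_map<int64_t, std::deque<int64_t>> ready;

  void InsertDescId(int64_t desc_id) { ready[desc_id]; }
  // Returns 0 on success, -1 if the desc id was never inserted.
  int TryPushBack(int64_t desc_id, int64_t regst) {
    auto it = ready.find(desc_id);
    if (it == ready.end()) { return -1; }
    it->second.push_back(regst);
    return 0;
  }
  int64_t Front(int64_t desc_id) const { return ready.at(desc_id).front(); }
  void PopFront(int64_t desc_id) { ready.at(desc_id).pop_front(); }
  int64_t ReadySize(int64_t desc_id) const {
    return static_cast<int64_t>(ready.at(desc_id).size());
  }
};

int main() {
  constexpr int64_t kOutDescId = 7;
  constexpr int64_t kRepeatNum = 4;
  ToySlot produced;
  produced.InsertDescId(kOutDescId);
  // Pre-allocate kRepeatNum register instances for one description, so a copy can be handed
  // out in each of the kRepeatNum micro-batches before any of them has been returned.
  for (int64_t i = 0; i < kRepeatNum; ++i) { produced.TryPushBack(kOutDescId, /*regst=*/100 + i); }
  std::cout << "ready after init: " << produced.ReadySize(kOutDescId) << std::endl;  // 4
  produced.PopFront(kOutDescId);  // one copy handed to a consumer
  std::cout << "ready after send: " << produced.ReadySize(kOutDescId) << std::endl;  // 3
  return 0;
}
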
+  for (const auto& pair : proto.produced_regst_desc()) {
+    const RegstDescProto& regst_desc = pair.second;
+    int64_t regst_desc_id = regst_desc.regst_desc_id();
+    // This iter begins from 1 because first ctrl regst was already inserted in
+    // TakeOverNaiveProduced
+    for (int64_t i = 1; i < repeat_num_; ++i) {
+      Singleton::Get()->NewRegsts(regst_desc, [this, regst_desc_id](Regst* regst) {
+        produced_regsts_[regst_desc_id].emplace_back(regst);
+        produced_regst2reading_cnt_[regst] = 0;
+        if (regst_desc_id != produced_repeat_var_regst_desc_id_) {
+          CHECK_EQ(0, naive_produced_rs_.TryPushBackRegst(regst));
+        }
+      });
+    }
+  }
   ForEachProducedRegst([&](Regst* regst) {
-    if (regst->regst_desc_id() != produced_repeat_var_regst_desc_id_) { return; }
-    CHECK_EQ(0, produced_repeat_var_rs_.TryPushBackRegst(regst));
+    if (regst->regst_desc_id() == produced_repeat_var_regst_desc_id_) {
+      CHECK_EQ(0, produced_repeat_var_rs_.TryPushBackRegst(regst));
+    }
   });
 
+  for (const auto& pair : proto.produced_regst_desc()) {
+    const RegstDescProto& regst_desc = pair.second;
+    int64_t regst_desc_id = regst_desc.regst_desc_id();
+    if (regst_desc_id == produced_repeat_var_regst_desc_id_) {
+      CHECK_EQ(produced_repeat_var_rs_.GetReadyRegstSize(regst_desc_id), repeat_num_);
+    } else {
+      CHECK_EQ(naive_produced_rs_.GetReadyRegstSize(regst_desc_id), repeat_num_);
+    }
+  }
+
   OF_SET_MSG_HANDLER(&RepeatActor::HandlerNormal);
 }
 
From f05b9638ca7eb053e77240fa1f55de4bd733df14 Mon Sep 17 00:00:00 2001
From: cheng cheng <472491134@qq.com>
Date: Thu, 8 Sep 2022 17:43:17 +0800
Subject: [PATCH 20/66] Part-4: merge acc op in to chain for reuse memory acc
 input (#9071)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

LogicalChain can merge acc op in to chain for reuse memory acc input

Measured results: GPT's device memory matches part-3; for bert and t5, memory usage is mostly slightly lower than part-3.
https://github.com/Oneflow-Inc/OneTeam/issues/1670#issuecomment-1240468576
---
 oneflow/core/graph/task_graph.cpp             | 11 +++-
 .../core/job_rewriter/logical_chain_pass.cpp  | 52 +++++++++++--------
 2 files changed, 41 insertions(+), 22 deletions(-)

diff --git a/oneflow/core/graph/task_graph.cpp b/oneflow/core/graph/task_graph.cpp
index 830cff94a17..12c73d53d91 100644
--- a/oneflow/core/graph/task_graph.cpp
+++ b/oneflow/core/graph/task_graph.cpp
@@ -387,7 +387,16 @@ void ForEachOpGraphNecessaryCtrlEdge(
       if (dst_time_shape == nullptr) {
         dst_time_shape = CHECK_JUST(dst->op().GetOpTimeShape()).get();
       }
-      CHECK_EQ(src_time_shape->elem_cnt(), dst_time_shape->elem_cnt());
+      if (src_time_shape->elem_cnt() != dst_time_shape->elem_cnt()) {
+        // NOTE(chengcheng): acc op node can be merged and add ctrl edge.
+        CHECK(src->op().op_conf().has_user_conf()
+              && src->op().op_conf().user_conf().op_type_name() == "acc");
+        const Shape* src_input_time_shape =
+            CHECK_JUST(src->op().GetInputBlobFastestTimeShape()).get();
+        CHECK_EQ(src_input_time_shape->elem_cnt(), dst_time_shape->elem_cnt());
+      } else {
+        CHECK_EQ(src_time_shape->elem_cnt(), dst_time_shape->elem_cnt());
+      }
       Handler(src, dst);
     }
   }
diff --git a/oneflow/core/job_rewriter/logical_chain_pass.cpp b/oneflow/core/job_rewriter/logical_chain_pass.cpp
index f666122fe29..a0988cd8210 100644
--- a/oneflow/core/job_rewriter/logical_chain_pass.cpp
+++ b/oneflow/core/job_rewriter/logical_chain_pass.cpp
@@ -73,11 +73,9 @@ bool IsBreakpointOpNode(const OpNode* node) {
 
   if (op_conf.has_user_conf()) {
     const std::string& user_type_name = op_conf.user_conf().op_type_name();
-    // TODO(chengcheng): acc node can be merged in chain.
- if (user_type_name == "repeat" || user_type_name == "acc" || user_type_name == "pack" - || user_type_name == "unpack" || user_type_name == "identity_buffer" - || user_type_name == "copy_h2d" || user_type_name == "copy_d2h" - || user_type_name == "acc_ctrl_tick") { + if (user_type_name == "repeat" || user_type_name == "pack" || user_type_name == "unpack" + || user_type_name == "identity_buffer" || user_type_name == "copy_h2d" + || user_type_name == "copy_d2h" || user_type_name == "acc_ctrl_tick") { return true; } } @@ -455,14 +453,30 @@ Maybe LogicalChainPass::Apply(const OpGraph& op_graph, JobBuilder* job_bui .ScopeSymbolId(first_chain_sink_op->op().op_conf().scope_symbol_id()) .Build(); - OperatorConf sink_acc_tick_conf; - sink_acc_tick_conf.set_name(std::string("Sys-LogicalChainSink-AccTick_") + NewUniqueId()); - sink_acc_tick_conf.set_scope_symbol_id( - first_chain_sink_op->op().op_conf().scope_symbol_id()); - auto* acc_conf = sink_acc_tick_conf.mutable_acc_tick_conf(); - acc_conf->set_one(cast_to_tick_op.output("out", 0)); - acc_conf->set_acc("acc"); - acc_conf->set_max_acc_num(acc_num); + CHECK_JUST(job_builder->AddOp(first_chain_sink_op->parallel_desc().parallel_conf(), + cast_to_tick_op.op_conf())); + + std::string acc_tick_output_lbn = cast_to_tick_op.output("out", 0); + if (!IsAccOpNode(first_chain_sink_op)) { + // NOTE(chengcheng): Acc Op can be merged in fw/bw chain, if the last op is acc op, + // there is no need and CANNOT insert acc tick op. + + OperatorConf sink_acc_tick_conf; + sink_acc_tick_conf.set_name(std::string("Sys-LogicalChainSink-AccTick_") + NewUniqueId()); + sink_acc_tick_conf.set_scope_symbol_id( + first_chain_sink_op->op().op_conf().scope_symbol_id()); + auto* acc_conf = sink_acc_tick_conf.mutable_acc_tick_conf(); + acc_conf->set_one(cast_to_tick_op.output("out", 0)); + acc_conf->set_acc("acc"); + acc_conf->set_max_acc_num(acc_num); + acc_tick_output_lbn = GenLogicalBlobName(sink_acc_tick_conf.name(), "acc"); + + VLOG(3) << " insert acc tick op : " << sink_acc_tick_conf.name() + << " of last op in fw/bw chain."; + + CHECK_JUST(job_builder->AddOp(first_chain_sink_op->parallel_desc().parallel_conf(), + sink_acc_tick_conf)); + } OperatorConf sink_final_tick_conf; sink_final_tick_conf.set_name(std::string("Sys-LogicalChainSink-FinalTick-DeviceTick_") @@ -470,19 +484,15 @@ Maybe LogicalChainPass::Apply(const OpGraph& op_graph, JobBuilder* job_bui sink_final_tick_conf.set_scope_symbol_id( first_chain_sink_op->op().op_conf().scope_symbol_id()); auto* tick_conf = sink_final_tick_conf.mutable_device_tick_conf(); - tick_conf->add_tick(GenLogicalBlobName(sink_acc_tick_conf.name(), "acc")); + tick_conf->add_tick(acc_tick_output_lbn); tick_conf->set_out("out"); - JUST(MapAt(mut_op_name2conf, info.after_acc_logical_chain->begin_op->op().op_name())) - .add_ctrl_in_op_name(sink_final_tick_conf.name()); - - CHECK_JUST(job_builder->AddOp(first_chain_sink_op->parallel_desc().parallel_conf(), - cast_to_tick_op.op_conf())); - CHECK_JUST(job_builder->AddOp(first_chain_sink_op->parallel_desc().parallel_conf(), - sink_acc_tick_conf)); CHECK_JUST(job_builder->AddOp(first_chain_sink_op->parallel_desc().parallel_conf(), sink_final_tick_conf)); + JUST(MapAt(mut_op_name2conf, info.after_acc_logical_chain->begin_op->op().op_name())) + .add_ctrl_in_op_name(sink_final_tick_conf.name()); + // NOTE(chengcheng): // 3. 
merge first chain and acc chain MergedLogicalChainIdGroup* group = job_builder->add_logical_chain_groups(); From f5f852d3fabd7b50f0b1174ed283a438819447ac Mon Sep 17 00:00:00 2001 From: chengtbf <472491134@qq.com> Date: Fri, 9 Sep 2022 10:09:00 +0000 Subject: [PATCH 21/66] find first source/sink op in acc chain which can be insert ctrl --- .../core/job_rewriter/logical_chain_pass.cpp | 134 +++++++++++------- 1 file changed, 80 insertions(+), 54 deletions(-) diff --git a/oneflow/core/job_rewriter/logical_chain_pass.cpp b/oneflow/core/job_rewriter/logical_chain_pass.cpp index a0988cd8210..c1827aedd73 100644 --- a/oneflow/core/job_rewriter/logical_chain_pass.cpp +++ b/oneflow/core/job_rewriter/logical_chain_pass.cpp @@ -165,28 +165,12 @@ void GetLogicalChainsWithTimeShape(std::vector>* ret, ret->back().swap(this_subgraph); } } - - /* - std::sort(ret->begin(), ret->end(), - [](const HashSet& lhs, const HashSet& rhs) { - return lhs.size() > rhs.size(); - }); - */ } struct LogicalChain { int64_t logical_chain_id; std::vector ordered_op_nodes; - int64_t begin_op_global_order; - int64_t end_op_global_order; - const OpNode* begin_op; - const OpNode* end_op; - LogicalChain() - : logical_chain_id(-1), - begin_op_global_order(-1), - end_op_global_order(-1), - begin_op(nullptr), - end_op(nullptr) {} + LogicalChain() : logical_chain_id(-1) {} }; struct PlacementLogicalChainsInfo { @@ -213,12 +197,12 @@ void InitPlacementLogicalChainsInfoFromSet( logical_chain_ordered_nodes->assign(origin_logical_chain.begin(), origin_logical_chain.end()); std::sort(logical_chain_ordered_nodes->begin(), logical_chain_ordered_nodes->end(), CmpOpNodeOrder); - logical_chain->begin_op = logical_chain_ordered_nodes->front(); - logical_chain->end_op = logical_chain_ordered_nodes->back(); - logical_chain->begin_op_global_order = op_node2global_order.at(logical_chain->begin_op); - logical_chain->end_op_global_order = op_node2global_order.at(logical_chain->end_op); - CHECK(logical_chain->begin_op != logical_chain->end_op); - CHECK_LT(logical_chain->begin_op_global_order, logical_chain->end_op_global_order); + const OpNode* begin_op = logical_chain_ordered_nodes->front(); + const OpNode* end_op = logical_chain_ordered_nodes->back(); + int64_t begin_op_global_order = op_node2global_order.at(begin_op); + int64_t end_op_global_order = op_node2global_order.at(end_op); + CHECK(begin_op != end_op); + CHECK_LT(begin_op_global_order, end_op_global_order); } void CreateAfterAccLogicalChain(const std::shared_ptr& after_acc_logical_chain, @@ -299,7 +283,9 @@ Maybe LogicalChainPass::Apply(const OpGraph& op_graph, JobBuilder* job_bui }; auto CmpLogicalChainOrder = [&](const std::shared_ptr& lhs, const std::shared_ptr& rhs) { - return lhs->begin_op_global_order < rhs->begin_op_global_order; + int64_t lhs_begin_op_global_order = op_node2global_order.at(lhs->ordered_op_nodes.front()); + int64_t rhs_begin_op_global_order = op_node2global_order.at(rhs->ordered_op_nodes.front()); + return lhs_begin_op_global_order < rhs_begin_op_global_order; }; auto IsReachable = op_graph.MakePredicatorIsOpNameDataOrCtrlReachable(); @@ -361,11 +347,6 @@ Maybe LogicalChainPass::Apply(const OpGraph& op_graph, JobBuilder* job_bui const auto& placement = pair.first; auto& info = pair.second; CHECK_GE(info.ordered_logical_chains.size(), 1); - for (int i = 0; i < info.ordered_logical_chains.size() - 1; i++) { - CHECK_LT(JUST(VectorAt(info.ordered_logical_chains, i))->begin_op_global_order, - JUST(VectorAt(info.ordered_logical_chains, i + 1))->begin_op_global_order); - } - 
for (auto& logical_chain : info.ordered_logical_chains) { logical_chain->logical_chain_id = NewLogicalChainId(); InsertLogicalChainId(logical_chain->ordered_op_nodes, logical_chain->logical_chain_id); @@ -391,28 +372,64 @@ Maybe LogicalChainPass::Apply(const OpGraph& op_graph, JobBuilder* job_bui info.after_acc_logical_chain = std::make_shared(); CreateAfterAccLogicalChain(info.after_acc_logical_chain, ordered_acc_op_nodes, *info.seed_parallel_desc); - if (info.after_acc_logical_chain->ordered_op_nodes.size() > 1) { + auto& acc_chain_order_ops = info.after_acc_logical_chain->ordered_op_nodes; + if (acc_chain_order_ops.size() > 1) { info.after_acc_logical_chain->logical_chain_id = NewLogicalChainId(); - std::sort(info.after_acc_logical_chain->ordered_op_nodes.begin(), - info.after_acc_logical_chain->ordered_op_nodes.end(), CmpOpNodeOrder); - const auto& chain_order_ops = info.after_acc_logical_chain->ordered_op_nodes; - info.after_acc_logical_chain->begin_op = chain_order_ops.front(); - info.after_acc_logical_chain->end_op = chain_order_ops.back(); - info.after_acc_logical_chain->begin_op_global_order = - JUST(MapAt(op_node2global_order, chain_order_ops.front())); - info.after_acc_logical_chain->end_op_global_order = - JUST(MapAt(op_node2global_order, chain_order_ops.back())); - - InsertLogicalChainId(chain_order_ops, info.after_acc_logical_chain->logical_chain_id); - InsertCtrlEdgeInChain(chain_order_ops); - + std::sort(acc_chain_order_ops.begin(), acc_chain_order_ops.end(), CmpOpNodeOrder); const auto& first_chain = info.ordered_logical_chains.front(); + const OpNode* first_chain_src_op = first_chain->ordered_op_nodes.front(); + const OpNode* first_chain_sink_op = first_chain->ordered_op_nodes.back(); + + const OpNode* acc_chain_src_op = acc_chain_order_ops.front(); + const OpNode* acc_chain_sink_op = acc_chain_order_ops.back(); + + // NOTE(chengcheng): find last op can insert acc ctrl tick. + while ( + (!acc_chain_sink_op->op().op_conf().has_user_conf()) + || IsReachable(acc_chain_sink_op->op().op_name(), first_chain_src_op->op().op_name())) { + VLOG(3) << " cannot insert acc ctrl edge between: [" << first_chain_src_op->op().op_name() + << "] -> [" << acc_chain_sink_op->op().op_name() << "] , debug info :\n" + << first_chain_src_op->op().op_conf().DebugString() << "\n" + << acc_chain_sink_op->op().op_conf().DebugString() << "\n"; + + VLOG(3) << "remove op : " << acc_chain_sink_op->op().op_name() + << " from after acc logical chain: " + << info.after_acc_logical_chain->logical_chain_id; + acc_chain_order_ops.pop_back(); + if (acc_chain_order_ops.size() > 1) { + acc_chain_sink_op = acc_chain_order_ops.back(); + } else { + acc_chain_sink_op = nullptr; + break; + } + } + if (acc_chain_sink_op == nullptr) { continue; } + + // NOTE(chengcheng): find first op can insert acc tick. 
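The loop above drops ops from the tail of the after-acc chain while a ctrl edge to them would close a cycle, and the loop that follows trims the head under the symmetric condition. A minimal standalone sketch of that two-sided trimming is given below; the names are hypothetical and the extra has_user_conf() test on the sink candidate is omitted.

#include <functional>
#include <iostream>
#include <string>
#include <vector>

using Reachable = std::function<bool(const std::string& from, const std::string& to)>;

// Returns the trimmed chain; an empty result means no safe (src, sink) pair was found.
std::vector<std::string> TrimAccChain(std::vector<std::string> chain, const std::string& fw_src,
                                      const std::string& fw_sink, const Reachable& reachable) {
  // A ctrl edge fw_src -> chain.back() is only legal if chain.back() cannot already reach fw_src.
  while (!chain.empty() && reachable(chain.back(), fw_src)) { chain.pop_back(); }
  // A tick edge fw_sink -> chain.front() is only legal if chain.front() cannot reach fw_sink.
  while (!chain.empty() && reachable(chain.front(), fw_sink)) { chain.erase(chain.begin()); }
  if (chain.size() < 2) { chain.clear(); }  // the pass only merges chains with more than one op
  return chain;
}

int main() {
  // Toy reachability: only "d" can reach the fw/bw source, so it is dropped from the back.
  auto reachable = [](const std::string& from, const std::string& to) {
    return from == "d" && to == "fw_src";
  };
  auto trimmed = TrimAccChain({"a", "b", "c", "d"}, "fw_src", "fw_sink", reachable);
  for (const auto& op : trimmed) { std::cout << op << " "; }  // prints: a b c
  std::cout << std::endl;
  return 0;
}
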
+ while (IsReachable(acc_chain_src_op->op().op_name(), first_chain_sink_op->op().op_name())) { + VLOG(3) << " cannot insert acc tick edge between: [" + << first_chain_sink_op->op().op_name() << "] -> [" + << acc_chain_src_op->op().op_name() << "] , debug info :\n" + << first_chain_sink_op->op().op_conf().DebugString() << "\n" + << acc_chain_src_op->op().op_conf().DebugString() << "\n"; + + VLOG(3) << "remove op : " << acc_chain_src_op->op().op_name() + << " from after acc logical chain: " + << info.after_acc_logical_chain->logical_chain_id; + acc_chain_order_ops.erase(acc_chain_order_ops.begin()); + if (acc_chain_order_ops.size() > 1) { + acc_chain_src_op = acc_chain_order_ops.front(); + } else { + acc_chain_src_op = nullptr; + break; + } + } + if (acc_chain_src_op == nullptr) { continue; } // NOTE(chengcheng): // 1.add acc ctrl tick between first chain src to acc chain sink for memory lock. const int64_t acc_num = job_builder->job().job_conf().num_gradient_accumulation_steps(); CHECK_GT(acc_num, 1); - const OpNode* first_chain_src_op = first_chain->begin_op; const auto& fc_src_obns = first_chain_src_op->op().output_bns(); CHECK(!fc_src_obns.empty()); const std::string& first_chain_src_out_lbn = @@ -428,17 +445,19 @@ Maybe LogicalChainPass::Apply(const OpGraph& op_graph, JobBuilder* job_bui .Attr("max_acc_num", acc_num) .Build(); - OperatorConf& consumer = - JUST(MapAt(mut_op_name2conf, info.after_acc_logical_chain->end_op->op().op_name())); - CHECK(consumer.has_user_conf()); - (*consumer.mutable_user_conf()->mutable_input())[user_op::kUserSourceOpTickInputArgName] + OperatorConf& acc_chain_sink_op_conf = + JUST(MapAt(mut_op_name2conf, acc_chain_sink_op->op().op_name())); + CHECK(acc_chain_sink_op_conf.has_user_conf()); + (*acc_chain_sink_op_conf.mutable_user_conf() + ->mutable_input())[user_op::kUserSourceOpTickInputArgName] .add_s(acc_ctrl_tick_op.output("out", 0)); JUST(job_builder->AddOp(first_chain_src_op->parallel_desc().parallel_conf(), acc_ctrl_tick_op.op_conf())); + VLOG(3) << " Insert acc ctrl tick between: [" << first_chain_src_op->op().op_name() + << "] -> [" << acc_chain_sink_op->op().op_name() << "]"; // NOTE(chengcheng): // 2.add acc tick between first chain sink to acc chain src for strict exec order. - const OpNode* first_chain_sink_op = first_chain->end_op; const auto& fc_sink_obns = first_chain_sink_op->op().output_bns(); CHECK(!fc_sink_obns.empty()); const std::string first_chain_sink_lbn = @@ -490,27 +509,34 @@ Maybe LogicalChainPass::Apply(const OpGraph& op_graph, JobBuilder* job_bui CHECK_JUST(job_builder->AddOp(first_chain_sink_op->parallel_desc().parallel_conf(), sink_final_tick_conf)); - JUST(MapAt(mut_op_name2conf, info.after_acc_logical_chain->begin_op->op().op_name())) + JUST(MapAt(mut_op_name2conf, acc_chain_src_op->op().op_name())) .add_ctrl_in_op_name(sink_final_tick_conf.name()); + VLOG(3) << " Insert acc tick between: [" << first_chain_sink_op->op().op_name() << "] -> [" + << acc_chain_src_op->op().op_name() << "]"; + // NOTE(chengcheng): // 3. 
merge first chain and acc chain MergedLogicalChainIdGroup* group = job_builder->add_logical_chain_groups(); group->add_logical_chain_id_list(first_chain->logical_chain_id); group->add_logical_chain_id_list(info.after_acc_logical_chain->logical_chain_id); + VLOG(3) << " Merge acc chain : " << info.after_acc_logical_chain->logical_chain_id + << " to first logcal chain : " << first_chain->logical_chain_id; } VLOG(3) << " In placement: " << placement << " AccLogicalChain: " << info.after_acc_logical_chain->logical_chain_id - << " has op num = " << info.after_acc_logical_chain->ordered_op_nodes.size(); + << " has op num = " << acc_chain_order_ops.size(); - for (int i = 0; i < info.after_acc_logical_chain->ordered_op_nodes.size(); ++i) { - const OpNode* ordered_op = - JUST(VectorAt(info.after_acc_logical_chain->ordered_op_nodes, i)); + for (int i = 0; i < acc_chain_order_ops.size(); ++i) { + const OpNode* ordered_op = JUST(VectorAt(acc_chain_order_ops, i)); VLOG(3) << " AfterAccChainId: " << info.after_acc_logical_chain->logical_chain_id << " order: " << i << " op_name: " << ordered_op->op().op_name() << " global_order: " << JUST(MapAt(op_node2global_order, ordered_op)); } + + InsertLogicalChainId(acc_chain_order_ops, info.after_acc_logical_chain->logical_chain_id); + InsertCtrlEdgeInChain(acc_chain_order_ops); } } From 5d982161166c5a682c56d42d49c35272e92fe96a Mon Sep 17 00:00:00 2001 From: chengtbf <472491134@qq.com> Date: Fri, 9 Sep 2022 11:58:19 +0000 Subject: [PATCH 22/66] TryMergeAfterAccLogicalChainToFirstLogicalChain --- .../core/job_rewriter/logical_chain_pass.cpp | 324 +++++++++--------- 1 file changed, 167 insertions(+), 157 deletions(-) diff --git a/oneflow/core/job_rewriter/logical_chain_pass.cpp b/oneflow/core/job_rewriter/logical_chain_pass.cpp index c1827aedd73..f191d2b458b 100644 --- a/oneflow/core/job_rewriter/logical_chain_pass.cpp +++ b/oneflow/core/job_rewriter/logical_chain_pass.cpp @@ -253,6 +253,154 @@ void CreateAfterAccLogicalChain(const std::shared_ptr& after_acc_l } } +void TryMergeAfterAccLogicalChainToFirstLogicalChain( + PlacementLogicalChainsInfo* info, HashMap* mut_op_name2conf, + JobBuilder* job_builder, + const std::function& IsReachable) { + const int64_t acc_chain_id = info->after_acc_logical_chain->logical_chain_id; + auto& acc_chain_order_ops = info->after_acc_logical_chain->ordered_op_nodes; + const auto& first_chain = info->ordered_logical_chains.front(); + const OpNode* first_chain_src_op = first_chain->ordered_op_nodes.front(); + const OpNode* first_chain_sink_op = first_chain->ordered_op_nodes.back(); + + const OpNode* acc_chain_src_op = acc_chain_order_ops.front(); + const OpNode* acc_chain_sink_op = acc_chain_order_ops.back(); + + // NOTE(chengcheng): find last op can insert acc ctrl tick. 
+ while ((!acc_chain_sink_op->op().op_conf().has_user_conf()) + || IsReachable(acc_chain_sink_op->op().op_name(), first_chain_src_op->op().op_name())) { + VLOG(3) << " cannot insert acc ctrl edge between: [" << first_chain_src_op->op().op_name() + << "] -> [" << acc_chain_sink_op->op().op_name() << "] , debug info :\n" + << first_chain_src_op->op().op_conf().DebugString() << "\n" + << acc_chain_sink_op->op().op_conf().DebugString() << "\n"; + + VLOG(3) << "remove op : " << acc_chain_sink_op->op().op_name() + << " from after acc logical chain: " << acc_chain_id; + acc_chain_order_ops.pop_back(); + if (acc_chain_order_ops.size() > 1) { + acc_chain_sink_op = acc_chain_order_ops.back(); + } else { + acc_chain_sink_op = nullptr; + break; + } + } + if (acc_chain_sink_op == nullptr) { return; } + + // NOTE(chengcheng): find first op can insert acc tick. + while (IsReachable(acc_chain_src_op->op().op_name(), first_chain_sink_op->op().op_name())) { + VLOG(3) << " cannot insert acc tick edge between: [" << first_chain_sink_op->op().op_name() + << "] -> [" << acc_chain_src_op->op().op_name() << "] , debug info :\n" + << first_chain_sink_op->op().op_conf().DebugString() << "\n" + << acc_chain_src_op->op().op_conf().DebugString() << "\n"; + + VLOG(3) << "remove op : " << acc_chain_src_op->op().op_name() + << " from after acc logical chain: " << acc_chain_id; + acc_chain_order_ops.erase(acc_chain_order_ops.begin()); + if (acc_chain_order_ops.size() > 1) { + acc_chain_src_op = acc_chain_order_ops.front(); + } else { + acc_chain_src_op = nullptr; + break; + } + } + if (acc_chain_src_op == nullptr) { return; } + + // NOTE(chengcheng): + // 1.add acc ctrl tick between first chain src to acc chain sink for memory lock. + const int64_t acc_num = job_builder->job().job_conf().num_gradient_accumulation_steps(); + CHECK_GT(acc_num, 1); + const auto& fc_src_obns = first_chain_src_op->op().output_bns(); + CHECK(!fc_src_obns.empty()); + const std::string& first_chain_src_out_lbn = + GenLogicalBlobName(first_chain_src_op->op().BnInOp2Lbi(fc_src_obns.Get(0))); + + VLOG(3) << " first_chain_src_out_lbn : " << first_chain_src_out_lbn; + user_op::UserOpConfWrapper acc_ctrl_tick_op = + user_op::UserOpConfWrapperBuilder("Sys-AccCtrlTick4MergeFirstAccChain-" + NewUniqueId()) + .OpTypeName("acc_ctrl_tick") + .Input("in", first_chain_src_out_lbn) + .Output("out") + .ScopeSymbolId(first_chain_src_op->op().op_conf().scope_symbol_id()) + .Attr("max_acc_num", acc_num) + .Build(); + + OperatorConf& acc_chain_sink_op_conf = + CHECK_JUST(MapAt(*mut_op_name2conf, acc_chain_sink_op->op().op_name())); + CHECK(acc_chain_sink_op_conf.has_user_conf()); + (*acc_chain_sink_op_conf.mutable_user_conf() + ->mutable_input())[user_op::kUserSourceOpTickInputArgName] + .add_s(acc_ctrl_tick_op.output("out", 0)); + CHECK_JUST(job_builder->AddOp(first_chain_src_op->parallel_desc().parallel_conf(), + acc_ctrl_tick_op.op_conf())); + VLOG(3) << " Insert acc ctrl tick between: [" << first_chain_src_op->op().op_name() << "] -> [" + << acc_chain_sink_op->op().op_name() << "]"; + + // NOTE(chengcheng): + // 2.add acc tick between first chain sink to acc chain src for strict exec order. 
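The max_acc_num attribute given to the acc_ctrl_tick op above, and to the acc tick wiring that follows, encodes one counting rule: a single downstream tick is emitted for every max_acc_num upstream ticks, that is, once per optimizer iteration under gradient accumulation. Below is a minimal standalone sketch of that counter; AccTickToy is a hypothetical name, not an OneFlow type.

#include <cstdint>
#include <iostream>

struct AccTickToy {
  explicit AccTickToy(int64_t max_acc_num) : max_acc_num_(max_acc_num) {}
  // Feed one upstream (per-micro-batch) tick; returns true when a downstream tick fires.
  bool Step() {
    acc_cnt_ += 1;
    if (acc_cnt_ == max_acc_num_) {
      acc_cnt_ = 0;
      return true;
    }
    return false;
  }
  int64_t acc_cnt_ = 0;
  int64_t max_acc_num_;
};

int main() {
  AccTickToy acc(/*max_acc_num=*/4);
  int emitted = 0;
  for (int micro_batch = 0; micro_batch < 8; ++micro_batch) {
    if (acc.Step()) { emitted += 1; }
  }
  std::cout << "downstream ticks: " << emitted << std::endl;  // 2 iterations for 8 micro-batches
  return 0;
}
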
+ const auto& fc_sink_obns = first_chain_sink_op->op().output_bns(); + CHECK(!fc_sink_obns.empty()); + const std::string first_chain_sink_lbn = + GenLogicalBlobName(first_chain_sink_op->op().BnInOp2Lbi(fc_sink_obns.Get(0))); + VLOG(3) << " first_chain_sink_lbn : " << first_chain_sink_lbn; + + user_op::UserOpConfWrapper cast_to_tick_op = + user_op::UserOpConfWrapperBuilder("Sys-LogicalChainSink-CastToTick-" + NewUniqueId()) + .OpTypeName("cast_to_tick") + .Input("in", first_chain_sink_lbn) + .Output("out") + .ScopeSymbolId(first_chain_sink_op->op().op_conf().scope_symbol_id()) + .Build(); + + CHECK_JUST(job_builder->AddOp(first_chain_sink_op->parallel_desc().parallel_conf(), + cast_to_tick_op.op_conf())); + + std::string acc_tick_output_lbn = cast_to_tick_op.output("out", 0); + if (!IsAccOpNode(first_chain_sink_op)) { + // NOTE(chengcheng): Acc Op can be merged in fw/bw chain, if the last op is acc op, + // there is no need and CANNOT insert acc tick op. + + OperatorConf sink_acc_tick_conf; + sink_acc_tick_conf.set_name(std::string("Sys-LogicalChainSink-AccTick_") + NewUniqueId()); + sink_acc_tick_conf.set_scope_symbol_id(first_chain_sink_op->op().op_conf().scope_symbol_id()); + auto* acc_conf = sink_acc_tick_conf.mutable_acc_tick_conf(); + acc_conf->set_one(cast_to_tick_op.output("out", 0)); + acc_conf->set_acc("acc"); + acc_conf->set_max_acc_num(acc_num); + acc_tick_output_lbn = GenLogicalBlobName(sink_acc_tick_conf.name(), "acc"); + + VLOG(3) << " insert acc tick op : " << sink_acc_tick_conf.name() + << " of last op in fw/bw chain."; + + CHECK_JUST(job_builder->AddOp(first_chain_sink_op->parallel_desc().parallel_conf(), + sink_acc_tick_conf)); + } + + OperatorConf sink_final_tick_conf; + sink_final_tick_conf.set_name(std::string("Sys-LogicalChainSink-FinalTick-DeviceTick_") + + NewUniqueId()); + sink_final_tick_conf.set_scope_symbol_id(first_chain_sink_op->op().op_conf().scope_symbol_id()); + auto* tick_conf = sink_final_tick_conf.mutable_device_tick_conf(); + tick_conf->add_tick(acc_tick_output_lbn); + tick_conf->set_out("out"); + + CHECK_JUST(job_builder->AddOp(first_chain_sink_op->parallel_desc().parallel_conf(), + sink_final_tick_conf)); + + CHECK_JUST(MapAt(*mut_op_name2conf, acc_chain_src_op->op().op_name())) + .add_ctrl_in_op_name(sink_final_tick_conf.name()); + + VLOG(3) << " Insert acc tick between: [" << first_chain_sink_op->op().op_name() << "] -> [" + << acc_chain_src_op->op().op_name() << "]"; + + // NOTE(chengcheng): + // 3. merge first chain and acc chain + MergedLogicalChainIdGroup* group = job_builder->add_logical_chain_groups(); + group->add_logical_chain_id_list(first_chain->logical_chain_id); + group->add_logical_chain_id_list(acc_chain_id); + VLOG(3) << " Merge acc chain : " << acc_chain_id + << " to first logcal chain : " << first_chain->logical_chain_id; +} + Maybe LogicalChainPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const { std::vector ordered_op_nodes; HashMap op_node2global_order; @@ -347,6 +495,9 @@ Maybe LogicalChainPass::Apply(const OpGraph& op_graph, JobBuilder* job_bui const auto& placement = pair.first; auto& info = pair.second; CHECK_GE(info.ordered_logical_chains.size(), 1); + + // NOTE(chengcheng): set logical chain id for each op in each logical chain, and insert ctrl + // edge for order. 
for (auto& logical_chain : info.ordered_logical_chains) { logical_chain->logical_chain_id = NewLogicalChainId(); InsertLogicalChainId(logical_chain->ordered_op_nodes, logical_chain->logical_chain_id); @@ -376,167 +527,26 @@ Maybe LogicalChainPass::Apply(const OpGraph& op_graph, JobBuilder* job_bui if (acc_chain_order_ops.size() > 1) { info.after_acc_logical_chain->logical_chain_id = NewLogicalChainId(); std::sort(acc_chain_order_ops.begin(), acc_chain_order_ops.end(), CmpOpNodeOrder); - const auto& first_chain = info.ordered_logical_chains.front(); - const OpNode* first_chain_src_op = first_chain->ordered_op_nodes.front(); - const OpNode* first_chain_sink_op = first_chain->ordered_op_nodes.back(); - - const OpNode* acc_chain_src_op = acc_chain_order_ops.front(); - const OpNode* acc_chain_sink_op = acc_chain_order_ops.back(); - - // NOTE(chengcheng): find last op can insert acc ctrl tick. - while ( - (!acc_chain_sink_op->op().op_conf().has_user_conf()) - || IsReachable(acc_chain_sink_op->op().op_name(), first_chain_src_op->op().op_name())) { - VLOG(3) << " cannot insert acc ctrl edge between: [" << first_chain_src_op->op().op_name() - << "] -> [" << acc_chain_sink_op->op().op_name() << "] , debug info :\n" - << first_chain_src_op->op().op_conf().DebugString() << "\n" - << acc_chain_sink_op->op().op_conf().DebugString() << "\n"; - - VLOG(3) << "remove op : " << acc_chain_sink_op->op().op_name() - << " from after acc logical chain: " - << info.after_acc_logical_chain->logical_chain_id; - acc_chain_order_ops.pop_back(); - if (acc_chain_order_ops.size() > 1) { - acc_chain_sink_op = acc_chain_order_ops.back(); - } else { - acc_chain_sink_op = nullptr; - break; - } - } - if (acc_chain_sink_op == nullptr) { continue; } - - // NOTE(chengcheng): find first op can insert acc tick. - while (IsReachable(acc_chain_src_op->op().op_name(), first_chain_sink_op->op().op_name())) { - VLOG(3) << " cannot insert acc tick edge between: [" - << first_chain_sink_op->op().op_name() << "] -> [" - << acc_chain_src_op->op().op_name() << "] , debug info :\n" - << first_chain_sink_op->op().op_conf().DebugString() << "\n" - << acc_chain_src_op->op().op_conf().DebugString() << "\n"; - - VLOG(3) << "remove op : " << acc_chain_src_op->op().op_name() - << " from after acc logical chain: " - << info.after_acc_logical_chain->logical_chain_id; - acc_chain_order_ops.erase(acc_chain_order_ops.begin()); - if (acc_chain_order_ops.size() > 1) { - acc_chain_src_op = acc_chain_order_ops.front(); - } else { - acc_chain_src_op = nullptr; - break; - } - } - if (acc_chain_src_op == nullptr) { continue; } - - // NOTE(chengcheng): - // 1.add acc ctrl tick between first chain src to acc chain sink for memory lock. 
- const int64_t acc_num = job_builder->job().job_conf().num_gradient_accumulation_steps(); - CHECK_GT(acc_num, 1); - const auto& fc_src_obns = first_chain_src_op->op().output_bns(); - CHECK(!fc_src_obns.empty()); - const std::string& first_chain_src_out_lbn = - GenLogicalBlobName(first_chain_src_op->op().BnInOp2Lbi(fc_src_obns.Get(0))); - - VLOG(3) << " first_chain_src_out_lbn : " << first_chain_src_out_lbn; - user_op::UserOpConfWrapper acc_ctrl_tick_op = - user_op::UserOpConfWrapperBuilder("Sys-AccCtrlTick4MergeFirstAccChain-" + NewUniqueId()) - .OpTypeName("acc_ctrl_tick") - .Input("in", first_chain_src_out_lbn) - .Output("out") - .ScopeSymbolId(first_chain_src_op->op().op_conf().scope_symbol_id()) - .Attr("max_acc_num", acc_num) - .Build(); - - OperatorConf& acc_chain_sink_op_conf = - JUST(MapAt(mut_op_name2conf, acc_chain_sink_op->op().op_name())); - CHECK(acc_chain_sink_op_conf.has_user_conf()); - (*acc_chain_sink_op_conf.mutable_user_conf() - ->mutable_input())[user_op::kUserSourceOpTickInputArgName] - .add_s(acc_ctrl_tick_op.output("out", 0)); - JUST(job_builder->AddOp(first_chain_src_op->parallel_desc().parallel_conf(), - acc_ctrl_tick_op.op_conf())); - VLOG(3) << " Insert acc ctrl tick between: [" << first_chain_src_op->op().op_name() - << "] -> [" << acc_chain_sink_op->op().op_name() << "]"; - - // NOTE(chengcheng): - // 2.add acc tick between first chain sink to acc chain src for strict exec order. - const auto& fc_sink_obns = first_chain_sink_op->op().output_bns(); - CHECK(!fc_sink_obns.empty()); - const std::string first_chain_sink_lbn = - GenLogicalBlobName(first_chain_sink_op->op().BnInOp2Lbi(fc_sink_obns.Get(0))); - VLOG(3) << " first_chain_sink_lbn : " << first_chain_sink_lbn; - - user_op::UserOpConfWrapper cast_to_tick_op = - user_op::UserOpConfWrapperBuilder("Sys-LogicalChainSink-CastToTick-" + NewUniqueId()) - .OpTypeName("cast_to_tick") - .Input("in", first_chain_sink_lbn) - .Output("out") - .ScopeSymbolId(first_chain_sink_op->op().op_conf().scope_symbol_id()) - .Build(); - - CHECK_JUST(job_builder->AddOp(first_chain_sink_op->parallel_desc().parallel_conf(), - cast_to_tick_op.op_conf())); - - std::string acc_tick_output_lbn = cast_to_tick_op.output("out", 0); - if (!IsAccOpNode(first_chain_sink_op)) { - // NOTE(chengcheng): Acc Op can be merged in fw/bw chain, if the last op is acc op, - // there is no need and CANNOT insert acc tick op. 
- - OperatorConf sink_acc_tick_conf; - sink_acc_tick_conf.set_name(std::string("Sys-LogicalChainSink-AccTick_") + NewUniqueId()); - sink_acc_tick_conf.set_scope_symbol_id( - first_chain_sink_op->op().op_conf().scope_symbol_id()); - auto* acc_conf = sink_acc_tick_conf.mutable_acc_tick_conf(); - acc_conf->set_one(cast_to_tick_op.output("out", 0)); - acc_conf->set_acc("acc"); - acc_conf->set_max_acc_num(acc_num); - acc_tick_output_lbn = GenLogicalBlobName(sink_acc_tick_conf.name(), "acc"); - - VLOG(3) << " insert acc tick op : " << sink_acc_tick_conf.name() - << " of last op in fw/bw chain."; - - CHECK_JUST(job_builder->AddOp(first_chain_sink_op->parallel_desc().parallel_conf(), - sink_acc_tick_conf)); - } - OperatorConf sink_final_tick_conf; - sink_final_tick_conf.set_name(std::string("Sys-LogicalChainSink-FinalTick-DeviceTick_") - + NewUniqueId()); - sink_final_tick_conf.set_scope_symbol_id( - first_chain_sink_op->op().op_conf().scope_symbol_id()); - auto* tick_conf = sink_final_tick_conf.mutable_device_tick_conf(); - tick_conf->add_tick(acc_tick_output_lbn); - tick_conf->set_out("out"); - - CHECK_JUST(job_builder->AddOp(first_chain_sink_op->parallel_desc().parallel_conf(), - sink_final_tick_conf)); - - JUST(MapAt(mut_op_name2conf, acc_chain_src_op->op().op_name())) - .add_ctrl_in_op_name(sink_final_tick_conf.name()); - - VLOG(3) << " Insert acc tick between: [" << first_chain_sink_op->op().op_name() << "] -> [" - << acc_chain_src_op->op().op_name() << "]"; - - // NOTE(chengcheng): - // 3. merge first chain and acc chain - MergedLogicalChainIdGroup* group = job_builder->add_logical_chain_groups(); - group->add_logical_chain_id_list(first_chain->logical_chain_id); - group->add_logical_chain_id_list(info.after_acc_logical_chain->logical_chain_id); - VLOG(3) << " Merge acc chain : " << info.after_acc_logical_chain->logical_chain_id - << " to first logcal chain : " << first_chain->logical_chain_id; - } + TryMergeAfterAccLogicalChainToFirstLogicalChain(&info, &mut_op_name2conf, job_builder, + IsReachable); - VLOG(3) << " In placement: " << placement - << " AccLogicalChain: " << info.after_acc_logical_chain->logical_chain_id - << " has op num = " << acc_chain_order_ops.size(); + if (acc_chain_order_ops.size() <= 1) { continue; } - for (int i = 0; i < acc_chain_order_ops.size(); ++i) { - const OpNode* ordered_op = JUST(VectorAt(acc_chain_order_ops, i)); - VLOG(3) << " AfterAccChainId: " << info.after_acc_logical_chain->logical_chain_id - << " order: " << i << " op_name: " << ordered_op->op().op_name() - << " global_order: " << JUST(MapAt(op_node2global_order, ordered_op)); - } + VLOG(3) << " In placement: " << placement + << " AccLogicalChain: " << info.after_acc_logical_chain->logical_chain_id + << " has op num = " << acc_chain_order_ops.size(); - InsertLogicalChainId(acc_chain_order_ops, info.after_acc_logical_chain->logical_chain_id); - InsertCtrlEdgeInChain(acc_chain_order_ops); + for (int i = 0; i < acc_chain_order_ops.size(); ++i) { + const OpNode* ordered_op = JUST(VectorAt(acc_chain_order_ops, i)); + VLOG(3) << " AfterAccChainId: " << info.after_acc_logical_chain->logical_chain_id + << " order: " << i << " op_name: " << ordered_op->op().op_name() + << " global_order: " << JUST(MapAt(op_node2global_order, ordered_op)); + } + + InsertLogicalChainId(acc_chain_order_ops, info.after_acc_logical_chain->logical_chain_id); + InsertCtrlEdgeInChain(acc_chain_order_ops); + } } } From 55db36d171c0ce7212b8d4377ac18494300b4845 Mon Sep 17 00:00:00 2001 From: chengtbf <472491134@qq.com> Date: Fri, 9 Sep 
2022 12:05:49 +0000 Subject: [PATCH 23/66] remove debug log --- oneflow/core/job/plan_util.cpp | 10 +-- .../core/job_rewriter/logical_chain_pass.cpp | 2 - .../core/lazy/actor/acc_ctrl_tick_actor.cpp | 64 +------------------ oneflow/core/lazy/actor/repeat_actor.cpp | 22 ------- oneflow/core/operator/op_conf.proto | 1 - 5 files changed, 4 insertions(+), 95 deletions(-) diff --git a/oneflow/core/job/plan_util.cpp b/oneflow/core/job/plan_util.cpp index 2384feb41b8..337d4124b95 100644 --- a/oneflow/core/job/plan_util.cpp +++ b/oneflow/core/job/plan_util.cpp @@ -815,19 +815,11 @@ void PlanUtil::SetForceInplaceMemBlock(Plan* plan) { regst_desc->set_separated_header_mem_block_id( in_regst_desc->separated_header_mem_block_id()); } - LOG(INFO) << " cclog: set force inplace from " << regst_desc->DebugString() << " to " + VLOG(3) << " cclog: set force inplace from " << regst_desc->DebugString() << " to " << in_regst_desc->DebugString(); } } } - - // tmp debug log - for (int i = 0; i < plan->task_size(); i++) { - TaskProto* task = plan->mutable_task(i); - if (task->task_type() == kRepeat && task->produced_regst_desc().size() > 1) { - LOG(WARNING) << " bad repeat : " << task->DebugString(); - } - } } void PlanUtil::DumpCtrlRegstInfoToPlan(Plan* plan) { diff --git a/oneflow/core/job_rewriter/logical_chain_pass.cpp b/oneflow/core/job_rewriter/logical_chain_pass.cpp index f191d2b458b..4e991d9612c 100644 --- a/oneflow/core/job_rewriter/logical_chain_pass.cpp +++ b/oneflow/core/job_rewriter/logical_chain_pass.cpp @@ -467,8 +467,6 @@ Maybe LogicalChainPass::Apply(const OpGraph& op_graph, JobBuilder* job_bui it->second.ordered_acc_op_nodes.emplace_back(this_node); } } - JUST(MapAt(mut_op_name2conf, this_node->op().op_name())) - .set_logical_order(JUST(MapAt(op_node2global_order, this_node))); } auto InsertCtrlEdgeInChain = [&](const std::vector& ordered_op_nodes) { diff --git a/oneflow/core/lazy/actor/acc_ctrl_tick_actor.cpp b/oneflow/core/lazy/actor/acc_ctrl_tick_actor.cpp index a1459812ac6..cf44fca2326 100644 --- a/oneflow/core/lazy/actor/acc_ctrl_tick_actor.cpp +++ b/oneflow/core/lazy/actor/acc_ctrl_tick_actor.cpp @@ -42,14 +42,9 @@ class AccCtrlTickActor : public Actor { bool IsCustomizedReadReady() const override { bool is_ready_ready = (!inplace_consume_) && consumed_tick_rs_.IsCurSlotReady(); - LOG(INFO) << " ccActorLog: actor: " << actor_id() << " is_ready_ready: " << is_ready_ready - << " of inplace_consume_ = " << inplace_consume_ - << " consumed_tick_rs_.IsCurSlotReady = " << consumed_tick_rs_.IsCurSlotReady(); return (!inplace_consume_) && consumed_tick_rs_.IsCurSlotReady(); } bool IsCustomizedWriteReady() const override { - LOG(INFO) << " ccActorLog: actor: " << actor_id() - << " is_write_ready: " << produced_tick_rs_.IsCurSlotReady(); return produced_tick_rs_.IsCurSlotReady(); } @@ -85,8 +80,6 @@ void AccCtrlTickActor::VirtualActorInit(const TaskProto& proto) { const Shape& in_time_shape = Singleton::Get() ->RegstDesc4RegstDescId(Name2SoleRegstDescId("in")) .data_regst_time_shape(); - // max_acc_num_ = in_time_shape.elem_cnt(); - // CHECK_GT(max_acc_num_, 1); const Shape& out_time_shape = Singleton::Get() ->RegstDesc4RegstDescId(Name2SoleRegstDescId("out")) .data_regst_time_shape(); @@ -106,19 +99,7 @@ void AccCtrlTickActor::VirtualActorInit(const TaskProto& proto) { // output CHECK_EQ(proto.produced_regst_desc().size(), 1); - /* - for (const auto& pair : proto.produced_regst_desc()) { - const RegstDescProto& out_regst_desc = pair.second; - if 
(out_regst_desc.regst_desc_type().has_ctrl_regst_desc()) { - CHECK_EQ(out_regst_desc.register_num(), 1); - CHECK_EQ(produced_tick_regst_desc_id_, -1); - produced_tick_regst_desc_id_ = out_regst_desc.regst_desc_id(); - produced_tick_rs_.InsertRegstDescId(produced_tick_regst_desc_id_); - produced_tick_rs_.InitedDone(); - } - } - CHECK_NE(produced_tick_regst_desc_id_, -1); - */ + const auto& produced_ids = proto.produced_regst_desc(); CHECK_EQ(produced_ids.size(), 1); CHECK(produced_ids.find("out") != produced_ids.end()); @@ -128,37 +109,18 @@ void AccCtrlTickActor::VirtualActorInit(const TaskProto& proto) { produced_tick_rs_.InitedDone(); ForEachProducedRegst([&](Regst* regst) { - // if (regst->regst_desc_id() == produced_tick_regst_desc_id_) { CHECK_EQ(regst->regst_desc_id(), produced_tick_regst_desc_id_); CHECK_EQ(0, produced_tick_rs_.TryPushBackRegst(regst)); - // } }); - LOG(INFO) << " ccActorLog: actor: " << actor_id() - << " has produced_tick_rs_ regst_descs = " << produced_tick_rs_.total_regst_desc_cnt() - << " with regsts size = " - << produced_tick_rs_.GetReadyRegstSize(produced_tick_regst_desc_id_); - LOG(INFO) << " ccActorLog: actor: " << actor_id() - << " has consumed_tick_rs_ regst_descs = " << consumed_tick_rs_.total_regst_desc_cnt() - << " with regsts size = " - << consumed_tick_rs_.GetReadyRegstSize(consumed_tick_regst_desc_id_); - LOG(INFO) - << " ccActorLog: actor: " << actor_id() - << " has inplace_consumed_rs_ regst_descs = " << inplace_consumed_rs_.total_regst_desc_cnt() - << " \nhas inplace_produced_rs_ regst_descs = " << inplace_produced_rs_.total_regst_desc_cnt() - << " \nhas naive_consumed_rs_ regst_descs = " << naive_consumed_rs_.total_regst_desc_cnt() - << " \nhas naive_produced_rs_ regst_descs = " << naive_produced_rs_.total_regst_desc_cnt(); - OF_SET_MSG_HANDLER(&AccCtrlTickActor::HandlerNormal); + OF_SET_MSG_HANDLER(&AccCtrlTickActor::HandlerNormal); } void AccCtrlTickActor::Act() { acc_cnt_ += 1; - LOG(INFO) << " ccActorLog: actor: " << actor_id() << " acc_count_ = " << acc_cnt_ - << " max_acc_num = " << max_acc_num_; if (acc_cnt_ == max_acc_num_) { CHECK(!inplace_consume_); inplace_consume_ = true; - LOG(INFO) << " ccActorLog: actor: " << actor_id() << " inplace_consume_ = true"; acc_cnt_ = 0; } } @@ -171,12 +133,7 @@ void AccCtrlTickActor::AsyncSendCustomizedProducedRegstMsgToConsumer() { CHECK_GT(HandleRegstToConsumer(tick_regst), 0); produced_tick_rs_.PopFrontRegsts({produced_tick_regst_desc_id_}); - LOG(INFO) << "ccActorLog: actor: " << actor_id() << " in count: " << acc_cnt_ - << " Send ctrl_tick regst " << produced_tick_regst_desc_id_ << " to Consumer."; - } else { - LOG(INFO) << "ccActorLog: actor: " << actor_id() << " in count: " << acc_cnt_ - << " SKIP to send produced to consumer."; - } + } } void AccCtrlTickActor::AsyncSendCustomizedConsumedRegstMsgToProducer() { @@ -185,12 +142,6 @@ void AccCtrlTickActor::AsyncSendCustomizedConsumedRegstMsgToProducer() { CHECK_NOTNULL(tick_regst); AsyncSendRegstMsgToProducer(tick_regst); CHECK_EQ(0, consumed_tick_rs_.TryPopFrontRegst(consumed_tick_regst_desc_id_)); - - LOG(INFO) << "ccActorLog: actor: " << actor_id() << " in count: " << acc_cnt_ - << " return tick regst " << consumed_tick_regst_desc_id_ << " to producer."; - } else { - LOG(INFO) << "ccActorLog: actor: " << actor_id() << " in count: " << acc_cnt_ - << " NOT return tick regst for waiting inplace tick regst returned. 
"; } } @@ -198,25 +149,16 @@ void AccCtrlTickActor::UpdtStateAsCustomizedProducedRegst(Regst* regst) { CHECK(inplace_consume_); CHECK_EQ(regst->regst_desc_id(), produced_tick_regst_desc_id_); CHECK_EQ(produced_tick_rs_.TryPushBackRegst(regst), 0); - LOG(INFO) << "ccActorLog: actor: " << actor_id() << " in count: " << acc_cnt_ - << " regst_desc_id: " << produced_tick_regst_desc_id_ << " ready size = " - << produced_tick_rs_.GetReadyRegstSize(produced_tick_regst_desc_id_); Regst* in_regst = consumed_tick_rs_.Front(consumed_tick_regst_desc_id_); CHECK(in_regst); AsyncSendRegstMsgToProducer(in_regst); CHECK_EQ(0, consumed_tick_rs_.TryPopFrontRegst(consumed_tick_regst_desc_id_)); inplace_consume_ = false; - - LOG(INFO) << "ccActorLog: actor: " << actor_id() << " in count: " << acc_cnt_ - << " consumed_regst_desc_id: " << consumed_tick_regst_desc_id_ - << " return with all produced regst."; } void AccCtrlTickActor::NormalProcessCustomizedReadableRegstMsg(const ActorMsg& msg) { CHECK_EQ(0, consumed_tick_rs_.TryPushBackRegst(msg.regst())); - LOG(INFO) << "ccActorLog: actor: " << actor_id() << " in count: " << acc_cnt_ - << " receive input regst: " << msg.regst()->regst_desc_id(); } REGISTER_ACTOR(TaskType::kAccCtrlTick, AccCtrlTickActor); diff --git a/oneflow/core/lazy/actor/repeat_actor.cpp b/oneflow/core/lazy/actor/repeat_actor.cpp index 23ac25b3a83..f7a4ab7d893 100644 --- a/oneflow/core/lazy/actor/repeat_actor.cpp +++ b/oneflow/core/lazy/actor/repeat_actor.cpp @@ -48,15 +48,9 @@ class RepeatActor final : public Actor { bool IsCustomizedReadReady() const override { bool is_ready_ready = (!wait_all_regst_return_) && consumed_var_rs_.IsCurSlotReady(); - LOG(INFO) << " ccActorLog: actor: " << actor_id() << " is_ready_ready: " << is_ready_ready - << " of wait_all_regst_return_ = " << wait_all_regst_return_ - << " consumed_var_rs_.IsCurSlotReady = " << consumed_var_rs_.IsCurSlotReady(); return (!wait_all_regst_return_) && consumed_var_rs_.IsCurSlotReady(); } bool IsCustomizedWriteReady() const override { - LOG(INFO) << " ccActorLog: actor: " << actor_id() - << " is_write_ready: " << (!wait_all_regst_return_) - && produced_repeat_var_rs_.IsCurSlotReady(); return (!wait_all_regst_return_) && produced_repeat_var_rs_.IsCurSlotReady(); } @@ -170,8 +164,6 @@ void RepeatActor::Act() { Regst* out_regst = produced_repeat_var_rs_.Front(produced_repeat_var_regst_desc_id_); Regst* in_regst = consumed_var_rs_.Front(consumed_var_regst_desc_id_); CHECK(out_regst && in_regst); - // LOG(WARNING) << "cclog: RepeatVarActor: " - // << out_regst->regst_desc()->regst_desc_type().DebugString(); CHECK(out_regst->main_mem_ptr() == in_regst->main_mem_ptr()); CHECK(out_regst->separated_header_mem_ptr() == in_regst->separated_header_mem_ptr()); CHECK_EQ(out_regst->regst_desc()->MainByteSize4OneRegst(), @@ -185,23 +177,15 @@ void RepeatActor::AsyncSendCustomizedProducedRegstMsgToConsumer() { Regst* const repeat_var_regst = produced_repeat_var_rs_.Front(produced_repeat_var_regst_desc_id_); CHECK_GT(HandleRegstToConsumer(repeat_var_regst), 0); produced_repeat_var_rs_.PopFrontRegsts({produced_repeat_var_regst_desc_id_}); - - LOG(INFO) << "ccActorLog: repeat actor: " << actor_id() << " in repeat count: " << repeat_count_ - << " Send var regst " << produced_repeat_var_regst_desc_id_ << " to Consumer."; } void RepeatActor::AsyncSendCustomizedConsumedRegstMsgToProducer() { // NOTE(chengcheng): do nothing. consumed var regst will return in inplace done. 
- LOG(INFO) << "ccActorLog: repeat actor: " << actor_id() << " in repeat count: " << repeat_count_ - << " NOT return var regst for waiting inplace var regst returned. "; } void RepeatActor::UpdtStateAsCustomizedProducedRegst(Regst* regst) { CHECK_EQ(regst->regst_desc_id(), produced_repeat_var_regst_desc_id_); CHECK_EQ(produced_repeat_var_rs_.TryPushBackRegst(regst), 0); - LOG(INFO) << "ccActorLog: repeat actor: " << actor_id() << " in count: " << repeat_count_ - << " regst_desc_id: " << produced_repeat_var_regst_desc_id_ << " ready size = " - << produced_repeat_var_rs_.GetReadyRegstSize(produced_repeat_var_regst_desc_id_); if (wait_all_regst_return_ && produced_repeat_var_rs_.GetReadyRegstSize(produced_repeat_var_regst_desc_id_) @@ -211,17 +195,11 @@ void RepeatActor::UpdtStateAsCustomizedProducedRegst(Regst* regst) { AsyncSendRegstMsgToProducer(in_regst); CHECK_EQ(0, consumed_var_rs_.TryPopFrontRegst(consumed_var_regst_desc_id_)); wait_all_regst_return_ = false; - - LOG(INFO) << "ccActorLog: repeat actor: " << actor_id() << " in count: " << repeat_count_ - << " consumed_var_regst_desc_id: " << consumed_var_regst_desc_id_ - << " return with all produced regst."; } } void RepeatActor::NormalProcessCustomizedReadableRegstMsg(const ActorMsg& msg) { CHECK_EQ(0, consumed_var_rs_.TryPushBackRegst(msg.regst())); - LOG(INFO) << "ccActorLog: actor: " << actor_id() << " in count: " << repeat_count_ - << " receive var regst: " << msg.regst()->regst_desc_id(); } REGISTER_ACTOR(TaskType::kRepeat, RepeatActor); diff --git a/oneflow/core/operator/op_conf.proto b/oneflow/core/operator/op_conf.proto index cd77ae32f70..162b4308038 100644 --- a/oneflow/core/operator/op_conf.proto +++ b/oneflow/core/operator/op_conf.proto @@ -398,7 +398,6 @@ message OperatorConf { optional string pass_tag = 10; optional string loc = 11 [default = ""]; optional int64 logical_chain_id = 12 [default = -1]; - optional int64 logical_order = 13 [default = -1]; oneof op_type { // system op CopyCommNetOpConf copy_comm_net_conf = 106; From 317e4a521af6a51ff1168eddbac916982c06f8e0 Mon Sep 17 00:00:00 2001 From: chengtbf <472491134@qq.com> Date: Fri, 9 Sep 2022 12:06:21 +0000 Subject: [PATCH 24/66] rm old version repeat kernel --- oneflow/user/kernels/repeat_kernel.cpp | 53 -------------------------- 1 file changed, 53 deletions(-) delete mode 100644 oneflow/user/kernels/repeat_kernel.cpp diff --git a/oneflow/user/kernels/repeat_kernel.cpp b/oneflow/user/kernels/repeat_kernel.cpp deleted file mode 100644 index c664887936f..00000000000 --- a/oneflow/user/kernels/repeat_kernel.cpp +++ /dev/null @@ -1,53 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ -#include "oneflow/core/framework/framework.h" -#include "oneflow/core/kernel/new_kernel_util.h" - -namespace oneflow { - -namespace { - -template -class RepeatKernel final : public user_op::OpKernel { - public: - RepeatKernel() = default; - ~RepeatKernel() override = default; - - private: - void Compute(user_op::KernelComputeContext* ctx) const override { - const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); - user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); - CHECK_EQ(in->shape_view().elem_cnt(), out->shape_view().elem_cnt()); - CHECK_EQ(in->data_type(), out->data_type()); - Memcpy(ctx->stream(), out->mut_dptr(), in->dptr(), - in->shape_view().elem_cnt() * GetSizeOfDataType(in->data_type())); - } - bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } -}; - -/* -#define REGISTER_REPEAT_KERNEL(device) \ - REGISTER_USER_KERNEL("repeat").SetCreateFn>().SetIsMatchedHob( \ - (user_op::HobDeviceType() == device)); - -OF_PP_FOR_EACH_TUPLE(REGISTER_REPEAT_KERNEL, DEVICE_TYPE_SEQ) - -#undef REGISTER_REPEAT_KERNEL -*/ - -} // namespace - -} // namespace oneflow From 2e56c11c8a5aa1367d46eb1021e2c4bd6ec9cf9a Mon Sep 17 00:00:00 2001 From: chengtbf <472491134@qq.com> Date: Fri, 9 Sep 2022 14:09:07 +0000 Subject: [PATCH 25/66] fix format --- oneflow/core/job/plan_util.cpp | 2 +- oneflow/core/lazy/actor/acc_ctrl_tick_actor.cpp | 11 ++++------- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/oneflow/core/job/plan_util.cpp b/oneflow/core/job/plan_util.cpp index 337d4124b95..064e4635cc9 100644 --- a/oneflow/core/job/plan_util.cpp +++ b/oneflow/core/job/plan_util.cpp @@ -816,7 +816,7 @@ void PlanUtil::SetForceInplaceMemBlock(Plan* plan) { in_regst_desc->separated_header_mem_block_id()); } VLOG(3) << " cclog: set force inplace from " << regst_desc->DebugString() << " to " - << in_regst_desc->DebugString(); + << in_regst_desc->DebugString(); } } } diff --git a/oneflow/core/lazy/actor/acc_ctrl_tick_actor.cpp b/oneflow/core/lazy/actor/acc_ctrl_tick_actor.cpp index cf44fca2326..259734ade26 100644 --- a/oneflow/core/lazy/actor/acc_ctrl_tick_actor.cpp +++ b/oneflow/core/lazy/actor/acc_ctrl_tick_actor.cpp @@ -44,9 +44,7 @@ class AccCtrlTickActor : public Actor { bool is_ready_ready = (!inplace_consume_) && consumed_tick_rs_.IsCurSlotReady(); return (!inplace_consume_) && consumed_tick_rs_.IsCurSlotReady(); } - bool IsCustomizedWriteReady() const override { - return produced_tick_rs_.IsCurSlotReady(); - } + bool IsCustomizedWriteReady() const override { return produced_tick_rs_.IsCurSlotReady(); } void NormalProcessCustomizedEordMsg(const ActorMsg&) override {} bool IsCustomizedReadAlwaysUnReadyFromNow() const override { @@ -99,7 +97,7 @@ void AccCtrlTickActor::VirtualActorInit(const TaskProto& proto) { // output CHECK_EQ(proto.produced_regst_desc().size(), 1); - + const auto& produced_ids = proto.produced_regst_desc(); CHECK_EQ(produced_ids.size(), 1); CHECK(produced_ids.find("out") != produced_ids.end()); @@ -113,7 +111,7 @@ void AccCtrlTickActor::VirtualActorInit(const TaskProto& proto) { CHECK_EQ(0, produced_tick_rs_.TryPushBackRegst(regst)); }); - OF_SET_MSG_HANDLER(&AccCtrlTickActor::HandlerNormal); + OF_SET_MSG_HANDLER(&AccCtrlTickActor::HandlerNormal); } void AccCtrlTickActor::Act() { @@ -132,8 +130,7 @@ void AccCtrlTickActor::AsyncSendCustomizedProducedRegstMsgToConsumer() { Regst* const tick_regst = produced_tick_rs_.Front(produced_tick_regst_desc_id_); CHECK_GT(HandleRegstToConsumer(tick_regst), 0); 
produced_tick_rs_.PopFrontRegsts({produced_tick_regst_desc_id_}); - - } + } } void AccCtrlTickActor::AsyncSendCustomizedConsumedRegstMsgToProducer() { From 5d83933e83e83a3f340c90cf7bb71b391faa9e59 Mon Sep 17 00:00:00 2001 From: chengtbf <472491134@qq.com> Date: Fri, 9 Sep 2022 15:10:18 +0000 Subject: [PATCH 26/66] MergeChainByLogicalChainId/PhysicalTaskGraph --- 1 | 24 +++++++++++++++ oneflow/core/graph/task_graph.cpp | 51 +++++++++++++++++-------------- oneflow/core/graph/task_graph.h | 3 +- 3 files changed, 54 insertions(+), 24 deletions(-) create mode 100644 1 diff --git a/1 b/1 new file mode 100644 index 00000000000..c7906f4c585 --- /dev/null +++ b/1 @@ -0,0 +1,24 @@ +Unknown argument 2 +Usage: cmake --build [options] [-- [native-options]] + cmake --build --preset [options] [-- [native-options]] +Options: + = Project binary directory to be built. + --preset , --preset= + = Specify a build preset. + --list-presets + = List available build presets. + --parallel [], -j [] + = Build in parallel using the given number of jobs. + If is omitted the native build tool's + default number is used. + The CMAKE_BUILD_PARALLEL_LEVEL environment variable + specifies a default parallel level when this option + is not given. + --target ..., -t ... + = Build instead of default targets. + --config = For multi-configuration tools, choose . + --clean-first = Build target 'clean' first, then build. + (To clean only, use --target 'clean'.) + --verbose, -v = Enable verbose output - if supported - including + the build commands to be executed. + -- = Pass remaining options to the native tool. diff --git a/oneflow/core/graph/task_graph.cpp b/oneflow/core/graph/task_graph.cpp index 12c73d53d91..d55d5eb262c 100644 --- a/oneflow/core/graph/task_graph.cpp +++ b/oneflow/core/graph/task_graph.cpp @@ -538,7 +538,12 @@ void TaskGraph::RemoveEmptyRegsts() { } void TaskGraph::MergeChainAndAddOrderingCtrlEdgeInSameChain() { - MergeChain(); + if (EnableLogicalChain()) { + MergeChainByLogicalChainId(); + } else { + // TODO(chengcheng): erase old chain version in the future. + MergeChainByPhysicalTaskGraph(); + } BuildCtrlRegstDescInSameChain(); } @@ -552,30 +557,30 @@ void TaskGraph::SetOrderInGraphForEachNode() { TopoForEachNode(SetOrderInGraph); } -void TaskGraph::MergeChain() { - if (EnableLogicalChain()) { - for (TaskNode* this_node : ordered_task_nodes_) { - CompTaskNode* comp_node = dynamic_cast(this_node); - if (!comp_node) { continue; } - const int64_t logical_chain_id = comp_node->op()->op_conf().logical_chain_id(); - if (logical_chain_id != -1) { this_node->set_chain_id(logical_chain_id); } - } - } else { - int64_t chain_id = 0; - for (auto* this_node : ordered_task_nodes_) { - // skip if this node has been set in a chain. - if (this_node->chain_id() != -1) { continue; } - - CHECK_EQ(this_node->chain_id(), -1); - if (CanBeMergedInChain(this_node)) { - TraverseConnectedSubGraphMergeInThisChain(this_node, chain_id); - } else { - this_node->set_chain_id(chain_id); - } +void TaskGraph::MergeChainByPhysicalTaskGraph() { + int64_t chain_id = 0; + for (auto* this_node : ordered_task_nodes_) { + // skip if this node has been set in a chain. 
+ if (this_node->chain_id() != -1) { continue; } - ++chain_id; + CHECK_EQ(this_node->chain_id(), -1); + if (CanBeMergedInChain(this_node)) { + TraverseConnectedSubGraphMergeInThisChain(this_node, chain_id); + } else { + this_node->set_chain_id(chain_id); } - for (auto* node : ordered_task_nodes_) { CHECK_NE(node->chain_id(), -1); } + + ++chain_id; + } + for (auto* node : ordered_task_nodes_) { CHECK_NE(node->chain_id(), -1); } +} + +void TaskGraph::MergeChainByLogicalChainId() { + for (TaskNode* this_node : ordered_task_nodes_) { + CompTaskNode* comp_node = dynamic_cast(this_node); + if (!comp_node) { continue; } + const int64_t logical_chain_id = comp_node->op()->op_conf().logical_chain_id(); + if (logical_chain_id != -1) { this_node->set_chain_id(logical_chain_id); } } } diff --git a/oneflow/core/graph/task_graph.h b/oneflow/core/graph/task_graph.h index f537f601a36..3dd6cb6077f 100644 --- a/oneflow/core/graph/task_graph.h +++ b/oneflow/core/graph/task_graph.h @@ -81,7 +81,8 @@ class TaskGraph final : public Graph { const std::vector& dst_task_nodes); void SetOrderInGraphForEachNode(); - void MergeChain(); + void MergeChainByPhysicalTaskGraph(); + void MergeChainByLogicalChainId(); void BuildCtrlRegstDescInSameChain(); // inplace From 0e7eb729744424eaa460aca542facba5e68b59d4 Mon Sep 17 00:00:00 2001 From: chengtbf <472491134@qq.com> Date: Fri, 9 Sep 2022 15:35:36 +0000 Subject: [PATCH 27/66] IsValidChainId --- oneflow/core/graph/plan_task_graph.cpp | 1 - oneflow/core/graph/task_graph.cpp | 20 ++++++++++---------- oneflow/core/graph/task_node.cpp | 2 +- oneflow/core/graph/task_node.h | 2 ++ oneflow/core/job/plan_util.cpp | 5 +++-- 5 files changed, 16 insertions(+), 14 deletions(-) diff --git a/oneflow/core/graph/plan_task_graph.cpp b/oneflow/core/graph/plan_task_graph.cpp index 02079dbc82d..95b9cbff48f 100644 --- a/oneflow/core/graph/plan_task_graph.cpp +++ b/oneflow/core/graph/plan_task_graph.cpp @@ -19,7 +19,6 @@ namespace oneflow { int64_t PlanTaskNode::chain_id() const { int64_t chain_id = task_proto_->task_set_info().chain_id(); - CHECK_NE(chain_id, -1); return chain_id; } diff --git a/oneflow/core/graph/task_graph.cpp b/oneflow/core/graph/task_graph.cpp index d55d5eb262c..d5734530c3c 100644 --- a/oneflow/core/graph/task_graph.cpp +++ b/oneflow/core/graph/task_graph.cpp @@ -109,8 +109,8 @@ std::shared_ptr GetTaskNodeTimeShape(const TaskNode* node) { } void TraverseConnectedSubGraphMergeInThisChain(TaskNode* this_node, const int64_t this_chain_id) { - CHECK_NE(this_chain_id, -1); - CHECK_EQ(this_node->chain_id(), -1); + CHECK(IsValidChainId(this_chain_id)); + CHECK(!IsValidChainId(this_node->chain_id())); // bfs search all node can be merged in this chain std::shared_ptr seed_time_shape = GetTaskNodeTimeShape(this_node); HashSet visited_nodes; @@ -121,14 +121,14 @@ void TraverseConnectedSubGraphMergeInThisChain(TaskNode* this_node, const int64_ TaskNode* cur_node = queued_nodes.front(); queued_nodes.pop(); - CHECK_EQ(cur_node->chain_id(), -1); + CHECK(!IsValidChainId(cur_node->chain_id())); cur_node->set_chain_id(this_chain_id); cur_node->ForEachNodeOnInOutDataEdge([&](TaskNode* next_node) { if (visited_nodes.find(next_node) == visited_nodes.end() && CanBeMergedInChain(next_node) && this_node->thrd_id() == next_node->thrd_id() && (*GetTaskNodeTimeShape(next_node)) == (*seed_time_shape)) { - if (next_node->chain_id() == -1) { + if (!IsValidChainId(next_node->chain_id())) { queued_nodes.push(next_node); visited_nodes.insert(next_node); } else { @@ -170,7 +170,7 @@ 
MakePredicatorIsLbiAllConsumersReachable( IsOpNameDataOrCtrlReachable) { auto IsDataOrCtrlReachable = [IsOpNameDataOrCtrlReachable](const TaskNode* src_node, const TaskNode* dst_node) -> bool { - if (src_node->chain_id() != -1 && dst_node->chain_id() != -1 + if (IsValidChainId(src_node->chain_id()) && IsValidChainId(dst_node->chain_id()) && src_node->chain_id() == dst_node->chain_id() && src_node->order_in_graph() <= dst_node->order_in_graph()) { return true; @@ -561,9 +561,8 @@ void TaskGraph::MergeChainByPhysicalTaskGraph() { int64_t chain_id = 0; for (auto* this_node : ordered_task_nodes_) { // skip if this node has been set in a chain. - if (this_node->chain_id() != -1) { continue; } + if (IsValidChainId(this_node->chain_id())) { continue; } - CHECK_EQ(this_node->chain_id(), -1); if (CanBeMergedInChain(this_node)) { TraverseConnectedSubGraphMergeInThisChain(this_node, chain_id); } else { @@ -572,7 +571,7 @@ void TaskGraph::MergeChainByPhysicalTaskGraph() { ++chain_id; } - for (auto* node : ordered_task_nodes_) { CHECK_NE(node->chain_id(), -1); } + for (auto* node : ordered_task_nodes_) { CHECK(IsValidChainId(node->chain_id())); } } void TaskGraph::MergeChainByLogicalChainId() { @@ -580,7 +579,7 @@ void TaskGraph::MergeChainByLogicalChainId() { CompTaskNode* comp_node = dynamic_cast(this_node); if (!comp_node) { continue; } const int64_t logical_chain_id = comp_node->op()->op_conf().logical_chain_id(); - if (logical_chain_id != -1) { this_node->set_chain_id(logical_chain_id); } + if (IsValidChainId(logical_chain_id)) { this_node->set_chain_id(logical_chain_id); } } } @@ -592,7 +591,8 @@ void TaskGraph::BuildCtrlRegstDescInSameChain() { HashMap physical_chain_id2node; for (auto* node : ordered_task_nodes_) { if (IsConnectToTickOp(node)) { continue; } - if (node->chain_id() == -1) { continue; } // NOTE(chengcheng): skip chain id default -1. + // NOTE(chengcheng): skip invalid chain id + if (!IsValidChainId(node->chain_id())) { continue; } int64_t physical_chain_id = GenPhysicalChainId(node); auto iter = physical_chain_id2node.find(physical_chain_id); if (iter == physical_chain_id2node.end()) { diff --git a/oneflow/core/graph/task_node.cpp b/oneflow/core/graph/task_node.cpp index 65b741d47d7..1ea81b6c39e 100644 --- a/oneflow/core/graph/task_node.cpp +++ b/oneflow/core/graph/task_node.cpp @@ -84,7 +84,7 @@ void TaskNode::set_thrd_id(int64_t val) { } void TaskNode::set_chain_id(int64_t val) { - CHECK_EQ(chain_id_, -1); + CHECK(!IsValidChainId(chain_id_)); chain_id_ = val; } diff --git a/oneflow/core/graph/task_node.h b/oneflow/core/graph/task_node.h index f35db587bf2..c25e6261f96 100644 --- a/oneflow/core/graph/task_node.h +++ b/oneflow/core/graph/task_node.h @@ -40,6 +40,8 @@ RegstDescProto* FindOrCreateProducedCtrlRegstDesc(TaskProto* task_proto, RegstDescIdSet* FindOrCreateConsumedCtrlRegstDescIdSet(TaskProto* task_proto, const std::string& regst_desc_name); +bool inline IsValidChainId(int64_t val) { return val >= 0; } + class TaskEdge; class TaskNode : public Node { diff --git a/oneflow/core/job/plan_util.cpp b/oneflow/core/job/plan_util.cpp index 064e4635cc9..7e0856ff17f 100644 --- a/oneflow/core/job/plan_util.cpp +++ b/oneflow/core/job/plan_util.cpp @@ -27,6 +27,7 @@ limitations under the License. 
#include "oneflow/core/persistence/tee_persistent_log_stream.h" #include "oneflow/core/ep/include/device_manager_registry.h" #include "oneflow/core/operator/operator.h" +#include "oneflow/core/graph/task_node.h" namespace oneflow { @@ -221,7 +222,7 @@ void PlanUtil::MergeMemBlockIdByLogicalChainId(Plan* plan, const Job& job) { DeviceType device_type = stream_id.device_id().device_type(); // TODO(zwx): eliminate this special 'is cpu' determine if (device_type == DeviceType::kCPU) { continue; } - if (task->task_set_info().chain_id() == -1) { continue; } + if (!IsValidChainId(task->task_set_info().chain_id())) { continue; } int64_t logical_chain_id = task->task_set_info().chain_id(); for (auto& pair : *(task->mutable_produced_regst_desc())) { @@ -282,7 +283,7 @@ void PlanUtil::MergeMemBlockIdByLogicalChainId(Plan* plan, const Job& job) { DeviceType device_type = stream_id.device_id().device_type(); // TODO(zwx): eliminate this special 'is cpu' determine if (device_type == DeviceType::kCPU) { continue; } - if (task->task_set_info().chain_id() == -1) { continue; } + if (!IsValidChainId(task->task_set_info().chain_id())) { continue; } for (auto& pair : *(task->mutable_produced_regst_desc())) { RegstDescProto* regst_desc = &pair.second; From 0459319c2cfd688ee1ccec6a3db3a5de756f665e Mon Sep 17 00:00:00 2001 From: chengtbf <472491134@qq.com> Date: Fri, 9 Sep 2022 15:39:42 +0000 Subject: [PATCH 28/66] rm useless file --- 1 | 24 ------------------------ 1 file changed, 24 deletions(-) delete mode 100644 1 diff --git a/1 b/1 deleted file mode 100644 index c7906f4c585..00000000000 --- a/1 +++ /dev/null @@ -1,24 +0,0 @@ -Unknown argument 2 -Usage: cmake --build [options] [-- [native-options]] - cmake --build --preset [options] [-- [native-options]] -Options: - = Project binary directory to be built. - --preset , --preset= - = Specify a build preset. - --list-presets - = List available build presets. - --parallel [], -j [] - = Build in parallel using the given number of jobs. - If is omitted the native build tool's - default number is used. - The CMAKE_BUILD_PARALLEL_LEVEL environment variable - specifies a default parallel level when this option - is not given. - --target ..., -t ... - = Build instead of default targets. - --config = For multi-configuration tools, choose . - --clean-first = Build target 'clean' first, then build. - (To clean only, use --target 'clean'.) - --verbose, -v = Enable verbose output - if supported - including - the build commands to be executed. - -- = Pass remaining options to the native tool. 
From e554b89ffc099d16a58e2b76fe6176917fa33197 Mon Sep 17 00:00:00 2001 From: chengtbf <472491134@qq.com> Date: Fri, 9 Sep 2022 15:45:54 +0000 Subject: [PATCH 29/66] remove note --- oneflow/core/job/plan_util.cpp | 10 ---------- oneflow/core/job_rewriter/logical_chain_pass.cpp | 1 - 2 files changed, 11 deletions(-) diff --git a/oneflow/core/job/plan_util.cpp b/oneflow/core/job/plan_util.cpp index 7e0856ff17f..d7465b60e58 100644 --- a/oneflow/core/job/plan_util.cpp +++ b/oneflow/core/job/plan_util.cpp @@ -213,8 +213,6 @@ void PlanUtil::MergeMemBlockIdByLogicalChainId(Plan* plan, const Job& job) { HashMap> mem_block_id2regsts; HashMap> logical_chain_id2machine_id2mem_block_id; - // HashMap mem_block_id2machine_id; - for (int64_t i = 0; i < plan->task_size(); ++i) { TaskProto* task = plan->mutable_task(i); const StreamId stream_id = PlanUtil::GetStreamId(*task); @@ -238,14 +236,6 @@ void PlanUtil::MergeMemBlockIdByLogicalChainId(Plan* plan, const Job& job) { } else { CHECK_EQ(rank2blocks->at(machine_id), mem_block_id); } - - /* - if (mem_block_id2machine_id.find(mem_block_id) == mem_block_id2machine_id.end()) { - mem_block_id2machine_id.emplace(mem_block_id, machine_id); - } else { - CHECK_EQ(mem_block_id2machine_id.at(mem_block_id), machine_id); - } - */ } } } diff --git a/oneflow/core/job_rewriter/logical_chain_pass.cpp b/oneflow/core/job_rewriter/logical_chain_pass.cpp index 4e991d9612c..310ccfba6ef 100644 --- a/oneflow/core/job_rewriter/logical_chain_pass.cpp +++ b/oneflow/core/job_rewriter/logical_chain_pass.cpp @@ -237,7 +237,6 @@ void CreateAfterAccLogicalChain(const std::shared_ptr& after_acc_l CHECK(after_acc_chain_ops.insert(cur_node).second); for (const OpEdge* in_edge : cur_node->in_edges()) { - // NOTE(chengcheng): maybe bad case for too early source op before repeat. 
SearchToNextNode(cur_node, in_edge->src_node(), in_edge); } for (const OpEdge* out_edge : cur_node->out_edges()) { From db24bad183cbcdee4a598a4c7a00d7dc0fe2f5ba Mon Sep 17 00:00:00 2001 From: chengtbf <472491134@qq.com> Date: Sat, 10 Sep 2022 05:26:37 +0000 Subject: [PATCH 30/66] fix clang-tidy --- oneflow/core/lazy/actor/acc_ctrl_tick_actor.cpp | 1 - oneflow/core/lazy/actor/repeat_actor.cpp | 1 - 2 files changed, 2 deletions(-) diff --git a/oneflow/core/lazy/actor/acc_ctrl_tick_actor.cpp b/oneflow/core/lazy/actor/acc_ctrl_tick_actor.cpp index 259734ade26..cf68ee0c2fc 100644 --- a/oneflow/core/lazy/actor/acc_ctrl_tick_actor.cpp +++ b/oneflow/core/lazy/actor/acc_ctrl_tick_actor.cpp @@ -41,7 +41,6 @@ class AccCtrlTickActor : public Actor { } bool IsCustomizedReadReady() const override { - bool is_ready_ready = (!inplace_consume_) && consumed_tick_rs_.IsCurSlotReady(); return (!inplace_consume_) && consumed_tick_rs_.IsCurSlotReady(); } bool IsCustomizedWriteReady() const override { return produced_tick_rs_.IsCurSlotReady(); } diff --git a/oneflow/core/lazy/actor/repeat_actor.cpp b/oneflow/core/lazy/actor/repeat_actor.cpp index f7a4ab7d893..8160d1bba9a 100644 --- a/oneflow/core/lazy/actor/repeat_actor.cpp +++ b/oneflow/core/lazy/actor/repeat_actor.cpp @@ -47,7 +47,6 @@ class RepeatActor final : public Actor { } bool IsCustomizedReadReady() const override { - bool is_ready_ready = (!wait_all_regst_return_) && consumed_var_rs_.IsCurSlotReady(); return (!wait_all_regst_return_) && consumed_var_rs_.IsCurSlotReady(); } bool IsCustomizedWriteReady() const override { From 8b9760879a81368623355e63d067ce574994abb0 Mon Sep 17 00:00:00 2001 From: chengtbf <472491134@qq.com> Date: Wed, 14 Sep 2022 05:20:41 +0000 Subject: [PATCH 31/66] more IsValidChainId --- oneflow/core/job/intra_job_mem_sharing_util.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oneflow/core/job/intra_job_mem_sharing_util.cpp b/oneflow/core/job/intra_job_mem_sharing_util.cpp index f1a2fef5b31..6b8c6cc58b1 100644 --- a/oneflow/core/job/intra_job_mem_sharing_util.cpp +++ b/oneflow/core/job/intra_job_mem_sharing_util.cpp @@ -99,7 +99,7 @@ void InitMemoryChains(Plan* plan, DeviceType device_type = stream_id.device_id().device_type(); // TODO(zwx): eliminate this special 'is cpu' determine if (device_type == DeviceType::kCPU) { continue; } - if (task->task_set_info().chain_id() == -1) { continue; } + if (!IsValidChainId(task->task_set_info().chain_id())) { continue; } int64_t device_id = stream_id.device_id().device_index(); int64_t device_unique_id = GenDeviceUniqueId(machine_id, device_id); MemoryChain* mem_chain = From 448f5e96a324b159074a27d644336f412bb53349 Mon Sep 17 00:00:00 2001 From: chengtbf <472491134@qq.com> Date: Wed, 14 Sep 2022 07:02:54 +0000 Subject: [PATCH 32/66] rm debug log --- oneflow/core/job/plan_util.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/oneflow/core/job/plan_util.cpp b/oneflow/core/job/plan_util.cpp index d7465b60e58..c1466478b7e 100644 --- a/oneflow/core/job/plan_util.cpp +++ b/oneflow/core/job/plan_util.cpp @@ -78,9 +78,6 @@ void PlanUtil::SetUniqueMemBlockId4UnreusedMemRegst(Plan* plan) { if (regst_separated_size > 0) { int64_t separated_mem_block_id = Singleton::Get()->NewMemBlockId(); regst_desc->set_separated_header_mem_block_id(separated_mem_block_id); - LOG(INFO) << "set sep id, regst: " << regst_desc->regst_desc_id() - << " , sep : " << separated_mem_block_id - << " , debug: " << regst_desc->DebugString(); } } } From 5a2ff2b8869cd62949f2f784ce5a28db9f4d6336 Mon Sep 17 
00:00:00 2001 From: chengtbf <472491134@qq.com> Date: Wed, 14 Sep 2022 08:16:02 +0000 Subject: [PATCH 33/66] rm note --- oneflow/core/job/plan_util.cpp | 4 ---- oneflow/core/lazy/actor/repeat_actor.cpp | 1 - 2 files changed, 5 deletions(-) diff --git a/oneflow/core/job/plan_util.cpp b/oneflow/core/job/plan_util.cpp index c1466478b7e..6bb6fe1bec6 100644 --- a/oneflow/core/job/plan_util.cpp +++ b/oneflow/core/job/plan_util.cpp @@ -377,10 +377,6 @@ void PlanUtil::GenMemBlockAndChunkWithVariableOpNames4Plan( } if (regst_separated_size > 0) { - if (regst_desc->has_separated_header_mem_block_id()) { - LOG(INFO) << "ccdebuglog: wrong, sep id, regst: " << regst_desc->regst_desc_id() - << " , debug: " << regst_desc->DebugString(); - } CHECK(regst_desc->has_separated_header_mem_block_id()) << regst_desc->DebugString(); int64_t separated_mem_block_id = regst_desc->separated_header_mem_block_id(); CHECK_NE(separated_mem_block_id, -1); diff --git a/oneflow/core/lazy/actor/repeat_actor.cpp b/oneflow/core/lazy/actor/repeat_actor.cpp index 8160d1bba9a..899102b4245 100644 --- a/oneflow/core/lazy/actor/repeat_actor.cpp +++ b/oneflow/core/lazy/actor/repeat_actor.cpp @@ -73,7 +73,6 @@ class RepeatActor final : public Actor { int64_t produced_repeat_var_regst_desc_id_; RegstSlot consumed_var_rs_; RegstSlot produced_repeat_var_rs_; - // Regst* var_regst_; }; void RepeatActor::VirtualActorInit(const TaskProto& proto) { From 680cb0df3a1ded4182a9f591028b948c0d61c53b Mon Sep 17 00:00:00 2001 From: chengtbf <472491134@qq.com> Date: Wed, 14 Sep 2022 16:13:00 +0000 Subject: [PATCH 34/66] fix bug of cpu repeat inplace var bug --- oneflow/core/job/plan_util.cpp | 49 +++++++++++++++++++++++++++------- 1 file changed, 40 insertions(+), 9 deletions(-) diff --git a/oneflow/core/job/plan_util.cpp b/oneflow/core/job/plan_util.cpp index 6bb6fe1bec6..475c90fccb0 100644 --- a/oneflow/core/job/plan_util.cpp +++ b/oneflow/core/job/plan_util.cpp @@ -53,7 +53,46 @@ std::function PlanUtil::MakeGetterTaskProto4TaskId(co return [task_id2task_proto](int64_t task_id) { return task_id2task_proto->at(task_id); }; } +namespace { + +void SetVariableOpNamesForVariableAndRepeatRegst(Plan* plan) { + // NOTE(chengcheng): set variable_op_name before set separated header because var regst alway + // separated. 
+ HashMap regst_id2var_name; + for (int i = 0; i < plan->task_size(); i++) { + TaskProto* task = plan->mutable_task(i); + if (task->exec_sequence().exec_node_size() == 1) { + const auto& op_conf = + PlanUtil::GetOpAttribute(plan, task->job_id(), + task->exec_sequence().exec_node(0).kernel_conf()) + .op_conf(); + if (op_conf.has_variable_conf()) { + RegstDescProto* regst = PlanUtil::GetSoleProducedDataRegst(task); + regst_id2var_name.emplace(regst->regst_desc_id(), op_conf.name()); + regst->set_variable_op_name(op_conf.name()); + } + } + } + + for (int i = 0; i < plan->task_size(); i++) { + TaskProto* task = plan->mutable_task(i); + if (task->task_type() == TaskType::kRepeat) { + RegstDescProto* regst = PlanUtil::GetSoleProducedDataRegst(task); + CHECK(regst->has_force_inplace_consumed_regst_desc_id()); + int64_t force_inplace_regst_id = regst->force_inplace_consumed_regst_desc_id(); + if (regst_id2var_name.find(force_inplace_regst_id) != regst_id2var_name.end()) { + regst->set_variable_op_name(regst_id2var_name.at(force_inplace_regst_id)); + VLOG(3) << " set var op name to repeat regst : " << regst->DebugString(); + } + } + } +} + +} // namespace + void PlanUtil::SetUniqueMemBlockId4UnreusedMemRegst(Plan* plan) { + SetVariableOpNamesForVariableAndRepeatRegst(plan); + for (int i = 0; i < plan->task_size(); i++) { TaskProto* task = plan->mutable_task(i); @@ -64,14 +103,6 @@ void PlanUtil::SetUniqueMemBlockId4UnreusedMemRegst(Plan* plan) { regst_desc->set_mem_block_id(Singleton::Get()->NewMemBlockId()); regst_desc->set_mem_block_offset(0); } - // NOTE(chengcheng): set variable_op_name before set separated header because var regst alway - // separated. - if (task->exec_sequence().exec_node_size() == 1) { - const auto& op_conf = - GetOpAttribute(plan, task->job_id(), task->exec_sequence().exec_node(0).kernel_conf()) - .op_conf(); - if (op_conf.has_variable_conf()) { regst_desc->set_variable_op_name(op_conf.name()); } - } RtRegstDesc rt_regst_desc(*regst_desc); int64_t regst_separated_size = rt_regst_desc.TotalSeparatedHeaderByteSize4AllRegst(); @@ -799,7 +830,7 @@ void PlanUtil::SetForceInplaceMemBlock(Plan* plan) { regst_desc->set_separated_header_mem_block_id( in_regst_desc->separated_header_mem_block_id()); } - VLOG(3) << " cclog: set force inplace from " << regst_desc->DebugString() << " to " + VLOG(3) << " set force inplace from " << regst_desc->DebugString() << " to " << in_regst_desc->DebugString(); } } From 70b7e6cc4d9306d5a9f6f4e82d5a99b70423b0d8 Mon Sep 17 00:00:00 2001 From: chengtbf <472491134@qq.com> Date: Thu, 15 Sep 2022 09:20:32 +0000 Subject: [PATCH 35/66] fix bug of memory reuse for 0-size regst in time line algo --- oneflow/core/job/intra_job_mem_sharing_util.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/oneflow/core/job/intra_job_mem_sharing_util.cpp b/oneflow/core/job/intra_job_mem_sharing_util.cpp index 6b8c6cc58b1..2d05df47aa3 100644 --- a/oneflow/core/job/intra_job_mem_sharing_util.cpp +++ b/oneflow/core/job/intra_job_mem_sharing_util.cpp @@ -107,13 +107,14 @@ void InitMemoryChains(Plan* plan, mem_chain->sorted_tasks.emplace_back(task); for (auto& pair : *(task->mutable_produced_regst_desc())) { RegstDescProto* regst_desc = &pair.second; + int64_t regst_total_main_size = RtRegstDesc(*regst_desc).TotalMainByteSize4AllRegst(); if (regst_desc->mem_case().device_type() == device_type && regst_desc->mem_case().device_id() == device_id && regst_desc->enable_reuse_mem() && regst_desc->register_num() == 1 && regst_desc->mem_block_id() == -1 && 
regst_desc->mem_block_offset() == -1 - && regst_desc->regst_desc_type().has_data_regst_desc()) { + && regst_desc->regst_desc_type().has_data_regst_desc() && regst_total_main_size > 0) { CHECK(mem_chain->mem_reused_regsts.insert(regst_desc).second); - mem_chain->total_mem_reused_size += RtRegstDesc(*regst_desc).TotalMainByteSize4AllRegst(); + mem_chain->total_mem_reused_size += regst_total_main_size; // for time shape in mem chain Shape regst_time_shape = From 2f3b2aefa883c08ae7bf1b14b998e65ab621dd6a Mon Sep 17 00:00:00 2001 From: chengtbf <472491134@qq.com> Date: Sun, 9 Oct 2022 07:09:42 +0000 Subject: [PATCH 36/66] fix bug of acc chain merge mem guard --- .../insert_nccl_logical_op_pass.cpp | 18 +- .../core/job_rewriter/logical_chain_pass.cpp | 158 ++++++++++++------ 2 files changed, 127 insertions(+), 49 deletions(-) diff --git a/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp b/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp index d8aff601abf..c9f53a2aca3 100644 --- a/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp +++ b/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp @@ -102,6 +102,17 @@ bool SharedPtrShapeEqual(const std::shared_ptr& lhs, void FindAllConnectedSubgraphForGpuExecOrder(std::vector>* ret, const OpGraph& op_graph, const std::vector& order) { + // NOTE(chengcheng): acc subgraph may greater than fw/bw subgraph. we need use max time shape. + std::shared_ptr seed_time_shape = std::make_shared(Shape({1, 1})); + op_graph.ForEachNode([&](const OpNode* node) { + std::shared_ptr this_time_shape = GetOpNodeTimeShape(node); + if (this_time_shape->elem_cnt() > seed_time_shape->elem_cnt()) { + seed_time_shape = this_time_shape; + } + }); + + VLOG(2) << " seed time shape = " << seed_time_shape->ToString(); + HashSet visited; for (const OpNode* seed_node : order) { @@ -111,6 +122,7 @@ void FindAllConnectedSubgraphForGpuExecOrder(std::vector> // NOTE(chengcheng): ONLY consider GPU op and parallel num > 1. if (seed_parallel_desc.device_type() != DeviceType::kCUDA) { continue; } if (seed_parallel_desc.parallel_num() <= 1) { continue; } + if (!SharedPtrShapeEqual(GetOpNodeTimeShape(seed_node), seed_time_shape)) { continue; } if (IsBreakpointOpNode(seed_node)) { continue; } HashSet this_subgraph; @@ -745,6 +757,8 @@ void InsertNcclLogicalOpsAfterAcc(const OpGraph& op_graph, } } + if (nccl_op_infos.empty()) { return; } + for (const auto* node : after_acc_subgraph_nodes) { ordered_after_acc_subgraph.push_back(node); } CHECK_EQ(after_acc_subgraph_nodes.size(), ordered_after_acc_subgraph.size()); @@ -946,7 +960,9 @@ void InsertBwSinkAccTickAndNcclLogicalOpsInPlacementGroupAfterAcc( << ", we will try insert special identity and ctrl for " << " UNSAFE handle ALL nccl ops between different time shape: " << time_shape_before_acc->DebugStr() << "->acc->" << time_shape_after_acc->DebugStr() - << "\n\n"; + << "\n\n" + << " Debug: before acc op: " << bw_sink_op->op().op_conf().DebugString() + << " -> after acc op: " << first_acc_op->op().op_conf().DebugString(); CHECK_GT(time_shape_before_acc->elem_cnt(), time_shape_after_acc->elem_cnt()); CHECK_EQ(time_shape_before_acc->elem_cnt() % time_shape_after_acc->elem_cnt(), 0); diff --git a/oneflow/core/job_rewriter/logical_chain_pass.cpp b/oneflow/core/job_rewriter/logical_chain_pass.cpp index 310ccfba6ef..957d5ae3b19 100644 --- a/oneflow/core/job_rewriter/logical_chain_pass.cpp +++ b/oneflow/core/job_rewriter/logical_chain_pass.cpp @@ -31,6 +31,8 @@ limitations under the License. 
namespace oneflow { +DEFINE_ENV_BOOL(ENABLE_ACC_CHAIN_MERGE, true); + namespace { class LogicalChainPass final : public JobPass { @@ -51,6 +53,10 @@ class LogicalChainPass final : public JobPass { Maybe Apply(const OpGraph& op_graph, JobBuilder* job_builder) const; }; +bool IsTickOpConf(const OperatorConf& conf) { + return IsClassRegistered(conf.op_type_case()); +} + bool IsBreakpointOpNode(const OpNode* node) { // NOTE(chengcheng): breakpoint op is special which CANNOT merge in chain such as: // variable, tick, repeat/acc/pack/unpack change timeshape @@ -59,12 +65,8 @@ bool IsBreakpointOpNode(const OpNode* node) { // TODO(chengcheng): filter ops which has special type // TODO(chengcheng): get stream by op type - if (op_conf.has_variable_conf() /* varialbe */ - || op_conf.has_tick_conf() || op_conf.has_device_tick_conf() - || op_conf.has_src_subset_tick_conf() || op_conf.has_dst_subset_tick_conf() - || op_conf.has_source_tick_conf() || op_conf.has_sink_tick_conf() - || op_conf.has_acc_tick_conf() || op_conf.has_critical_section_wait_tick_conf() - || op_conf.has_critical_section_callback_tick_conf() /* tick */ + if (op_conf.has_variable_conf() /* varialbe */ + || IsTickOpConf(op_conf) /* tick */ || op_conf.has_input_conf() || op_conf.has_output_conf() /* io */ || op_conf.has_wait_and_send_ids_conf() || op_conf.has_callback_notify_conf() /* ctrl */ || op_conf.has_image_decoder_random_crop_resize_conf() /* gpu decode */) { @@ -82,6 +84,13 @@ bool IsBreakpointOpNode(const OpNode* node) { return false; } +bool IsAccOrPackOpNode(const OpNode* node) { + const auto& op_conf = node->op().op_conf(); + return op_conf.has_user_conf() + && (op_conf.user_conf().op_type_name() == "acc" + || op_conf.user_conf().op_type_name() == "pack"); +} + bool IsAccOpNode(const OpNode* node) { return node->op().op_conf().has_user_conf() && node->op().op_conf().user_conf().op_type_name() == "acc"; @@ -178,6 +187,7 @@ struct PlacementLogicalChainsInfo { std::vector ordered_acc_op_nodes; std::shared_ptr after_acc_logical_chain; const ParallelDesc* seed_parallel_desc; + std::shared_ptr seed_time_shape; PlacementLogicalChainsInfo() : seed_parallel_desc(nullptr) {} }; @@ -252,25 +262,40 @@ void CreateAfterAccLogicalChain(const std::shared_ptr& after_acc_l } } -void TryMergeAfterAccLogicalChainToFirstLogicalChain( +void TryMergeAfterAccLogicalChainToLastLogicalChain( PlacementLogicalChainsInfo* info, HashMap* mut_op_name2conf, JobBuilder* job_builder, const std::function& IsReachable) { + if (!EnvBool()) { return; } + const int64_t acc_chain_id = info->after_acc_logical_chain->logical_chain_id; auto& acc_chain_order_ops = info->after_acc_logical_chain->ordered_op_nodes; - const auto& first_chain = info->ordered_logical_chains.front(); - const OpNode* first_chain_src_op = first_chain->ordered_op_nodes.front(); - const OpNode* first_chain_sink_op = first_chain->ordered_op_nodes.back(); + const auto& last_chain = info->ordered_logical_chains.back(); + const OpNode* last_chain_src_op = last_chain->ordered_op_nodes.front(); + const OpNode* last_chain_sink_op = last_chain->ordered_op_nodes.back(); + HashSet last_chain_ops(last_chain->ordered_op_nodes.begin(), + last_chain->ordered_op_nodes.end()); const OpNode* acc_chain_src_op = acc_chain_order_ops.front(); const OpNode* acc_chain_sink_op = acc_chain_order_ops.back(); + // NOTE(chengcheng): find all nontrivial sink consumer ops + HashSet nontrivial_sink_consumers; + for (const OpNode* chain_op : last_chain->ordered_op_nodes) { + chain_op->ForEachNodeOnOutEdge([&](const OpNode* 
out_node) { + if (last_chain_ops.find(out_node) == last_chain_ops.end() + && !IsTickOpConf(out_node->op().op_conf()) + && SharedPtrShapeEqual(GetOpNodeFastestTimeShape(out_node), info->seed_time_shape)) { + nontrivial_sink_consumers.insert(out_node); + } + }); + } // NOTE(chengcheng): find last op can insert acc ctrl tick. while ((!acc_chain_sink_op->op().op_conf().has_user_conf()) - || IsReachable(acc_chain_sink_op->op().op_name(), first_chain_src_op->op().op_name())) { - VLOG(3) << " cannot insert acc ctrl edge between: [" << first_chain_src_op->op().op_name() + || IsReachable(acc_chain_sink_op->op().op_name(), last_chain_src_op->op().op_name())) { + VLOG(3) << " cannot insert acc ctrl edge between: [" << last_chain_src_op->op().op_name() << "] -> [" << acc_chain_sink_op->op().op_name() << "] , debug info :\n" - << first_chain_src_op->op().op_conf().DebugString() << "\n" + << last_chain_src_op->op().op_conf().DebugString() << "\n" << acc_chain_sink_op->op().op_conf().DebugString() << "\n"; VLOG(3) << "remove op : " << acc_chain_sink_op->op().op_name() @@ -285,11 +310,11 @@ void TryMergeAfterAccLogicalChainToFirstLogicalChain( } if (acc_chain_sink_op == nullptr) { return; } - // NOTE(chengcheng): find first op can insert acc tick. - while (IsReachable(acc_chain_src_op->op().op_name(), first_chain_sink_op->op().op_name())) { - VLOG(3) << " cannot insert acc tick edge between: [" << first_chain_sink_op->op().op_name() + // NOTE(chengcheng): find last op can insert acc tick. + while (IsReachable(acc_chain_src_op->op().op_name(), last_chain_sink_op->op().op_name())) { + VLOG(3) << " cannot insert acc tick edge between: [" << last_chain_sink_op->op().op_name() << "] -> [" << acc_chain_src_op->op().op_name() << "] , debug info :\n" - << first_chain_sink_op->op().op_conf().DebugString() << "\n" + << last_chain_sink_op->op().op_conf().DebugString() << "\n" << acc_chain_src_op->op().op_conf().DebugString() << "\n"; VLOG(3) << "remove op : " << acc_chain_src_op->op().op_name() @@ -305,21 +330,21 @@ void TryMergeAfterAccLogicalChainToFirstLogicalChain( if (acc_chain_src_op == nullptr) { return; } // NOTE(chengcheng): - // 1.add acc ctrl tick between first chain src to acc chain sink for memory lock. + // 1.add acc ctrl tick between last chain src to acc chain sink for memory lock. 
const int64_t acc_num = job_builder->job().job_conf().num_gradient_accumulation_steps(); CHECK_GT(acc_num, 1); - const auto& fc_src_obns = first_chain_src_op->op().output_bns(); + const auto& fc_src_obns = last_chain_src_op->op().output_bns(); CHECK(!fc_src_obns.empty()); - const std::string& first_chain_src_out_lbn = - GenLogicalBlobName(first_chain_src_op->op().BnInOp2Lbi(fc_src_obns.Get(0))); + const std::string& last_chain_src_out_lbn = + GenLogicalBlobName(last_chain_src_op->op().BnInOp2Lbi(fc_src_obns.Get(0))); - VLOG(3) << " first_chain_src_out_lbn : " << first_chain_src_out_lbn; + VLOG(3) << " last_chain_src_out_lbn : " << last_chain_src_out_lbn; user_op::UserOpConfWrapper acc_ctrl_tick_op = - user_op::UserOpConfWrapperBuilder("Sys-AccCtrlTick4MergeFirstAccChain-" + NewUniqueId()) + user_op::UserOpConfWrapperBuilder("Sys-AccCtrlTick4MergeLastAccChain-" + NewUniqueId()) .OpTypeName("acc_ctrl_tick") - .Input("in", first_chain_src_out_lbn) + .Input("in", last_chain_src_out_lbn) .Output("out") - .ScopeSymbolId(first_chain_src_op->op().op_conf().scope_symbol_id()) + .ScopeSymbolId(last_chain_src_op->op().op_conf().scope_symbol_id()) .Attr("max_acc_num", acc_num) .Build(); @@ -329,40 +354,43 @@ void TryMergeAfterAccLogicalChainToFirstLogicalChain( (*acc_chain_sink_op_conf.mutable_user_conf() ->mutable_input())[user_op::kUserSourceOpTickInputArgName] .add_s(acc_ctrl_tick_op.output("out", 0)); - CHECK_JUST(job_builder->AddOp(first_chain_src_op->parallel_desc().parallel_conf(), + CHECK_JUST(job_builder->AddOp(last_chain_src_op->parallel_desc().parallel_conf(), acc_ctrl_tick_op.op_conf())); - VLOG(3) << " Insert acc ctrl tick between: [" << first_chain_src_op->op().op_name() << "] -> [" + VLOG(3) << " Insert acc ctrl tick between: [" << last_chain_src_op->op().op_name() << "] -> [" << acc_chain_sink_op->op().op_name() << "]"; // NOTE(chengcheng): - // 2.add acc tick between first chain sink to acc chain src for strict exec order. - const auto& fc_sink_obns = first_chain_sink_op->op().output_bns(); + // 2.add acc tick between last chain sink to acc chain src for strict exec order. + const auto& fc_sink_obns = last_chain_sink_op->op().output_bns(); CHECK(!fc_sink_obns.empty()); - const std::string first_chain_sink_lbn = - GenLogicalBlobName(first_chain_sink_op->op().BnInOp2Lbi(fc_sink_obns.Get(0))); - VLOG(3) << " first_chain_sink_lbn : " << first_chain_sink_lbn; + const std::string last_chain_sink_lbn = + GenLogicalBlobName(last_chain_sink_op->op().BnInOp2Lbi(fc_sink_obns.Get(0))); + VLOG(3) << " last_chain_sink_lbn : " << last_chain_sink_lbn; + /* user_op::UserOpConfWrapper cast_to_tick_op = user_op::UserOpConfWrapperBuilder("Sys-LogicalChainSink-CastToTick-" + NewUniqueId()) .OpTypeName("cast_to_tick") - .Input("in", first_chain_sink_lbn) + .Input("in", last_chain_sink_lbn) .Output("out") - .ScopeSymbolId(first_chain_sink_op->op().op_conf().scope_symbol_id()) + .ScopeSymbolId(last_chain_sink_op->op().op_conf().scope_symbol_id()) .Build(); - CHECK_JUST(job_builder->AddOp(first_chain_sink_op->parallel_desc().parallel_conf(), + CHECK_JUST(job_builder->AddOp(last_chain_sink_op->parallel_desc().parallel_conf(), cast_to_tick_op.op_conf())); std::string acc_tick_output_lbn = cast_to_tick_op.output("out", 0); - if (!IsAccOpNode(first_chain_sink_op)) { + */ + std::string acc_tick_output_lbn = last_chain_sink_lbn; + if (!IsAccOrPackOpNode(last_chain_sink_op)) { // NOTE(chengcheng): Acc Op can be merged in fw/bw chain, if the last op is acc op, // there is no need and CANNOT insert acc tick op. 
OperatorConf sink_acc_tick_conf; sink_acc_tick_conf.set_name(std::string("Sys-LogicalChainSink-AccTick_") + NewUniqueId()); - sink_acc_tick_conf.set_scope_symbol_id(first_chain_sink_op->op().op_conf().scope_symbol_id()); + sink_acc_tick_conf.set_scope_symbol_id(last_chain_sink_op->op().op_conf().scope_symbol_id()); auto* acc_conf = sink_acc_tick_conf.mutable_acc_tick_conf(); - acc_conf->set_one(cast_to_tick_op.output("out", 0)); + acc_conf->set_one(last_chain_sink_lbn); acc_conf->set_acc("acc"); acc_conf->set_max_acc_num(acc_num); acc_tick_output_lbn = GenLogicalBlobName(sink_acc_tick_conf.name(), "acc"); @@ -370,34 +398,67 @@ void TryMergeAfterAccLogicalChainToFirstLogicalChain( VLOG(3) << " insert acc tick op : " << sink_acc_tick_conf.name() << " of last op in fw/bw chain."; - CHECK_JUST(job_builder->AddOp(first_chain_sink_op->parallel_desc().parallel_conf(), + CHECK_JUST(job_builder->AddOp(last_chain_sink_op->parallel_desc().parallel_conf(), sink_acc_tick_conf)); } OperatorConf sink_final_tick_conf; sink_final_tick_conf.set_name(std::string("Sys-LogicalChainSink-FinalTick-DeviceTick_") + NewUniqueId()); - sink_final_tick_conf.set_scope_symbol_id(first_chain_sink_op->op().op_conf().scope_symbol_id()); + sink_final_tick_conf.set_scope_symbol_id(last_chain_sink_op->op().op_conf().scope_symbol_id()); auto* tick_conf = sink_final_tick_conf.mutable_device_tick_conf(); tick_conf->add_tick(acc_tick_output_lbn); tick_conf->set_out("out"); - CHECK_JUST(job_builder->AddOp(first_chain_sink_op->parallel_desc().parallel_conf(), + // NOTE(chengcheng): + // 3. Important Tips: If there have nontrivial_sink_consumers, there must insert ctrl + // between sink consumer with acc chain for exec order. + for (const OpNode* sink_consumer : nontrivial_sink_consumers) { + VLOG(2) << " insert acc tick between nontrivial_sink_consumer: [" + << sink_consumer->op().op_name() << "] -> [" << sink_final_tick_conf.name() + << "] for mem safe guard."; + CHECK(!IsReachable(acc_chain_src_op->op().op_name(), sink_consumer->op().op_name())); + const auto& sink_consumer_obns = sink_consumer->op().output_bns(); + CHECK(!sink_consumer_obns.empty()); + std::string sink_consumer_acc_tick_lbn = + GenLogicalBlobName(sink_consumer->op().BnInOp2Lbi(sink_consumer_obns.Get(0))); + if (!IsAccOrPackOpNode(sink_consumer)) { + OperatorConf sink_consumer_acc_tick_conf; + sink_consumer_acc_tick_conf.set_name(std::string("Sys-LogicalChainSinkConsumer-AccTick_") + + NewUniqueId()); + sink_consumer_acc_tick_conf.set_scope_symbol_id( + acc_chain_src_op->op().op_conf().scope_symbol_id()); + auto* acc_conf = sink_consumer_acc_tick_conf.mutable_acc_tick_conf(); + acc_conf->set_one(sink_consumer_acc_tick_lbn); + acc_conf->set_acc("acc"); + acc_conf->set_max_acc_num(acc_num); + sink_consumer_acc_tick_lbn = GenLogicalBlobName(sink_consumer_acc_tick_conf.name(), "acc"); + + VLOG(3) << " insert acc tick op : " << sink_consumer_acc_tick_conf.name() + << " of nontrivial_sink_consumer in fw/bw chain."; + + CHECK_JUST(job_builder->AddOp(last_chain_sink_op->parallel_desc().parallel_conf(), + sink_consumer_acc_tick_conf)); + } + tick_conf->add_tick(sink_consumer_acc_tick_lbn); + } + + CHECK_JUST(job_builder->AddOp(last_chain_sink_op->parallel_desc().parallel_conf(), sink_final_tick_conf)); CHECK_JUST(MapAt(*mut_op_name2conf, acc_chain_src_op->op().op_name())) .add_ctrl_in_op_name(sink_final_tick_conf.name()); - VLOG(3) << " Insert acc tick between: [" << first_chain_sink_op->op().op_name() << "] -> [" + VLOG(3) << " Insert acc tick between: [" << 
last_chain_sink_op->op().op_name() << "] -> [" << acc_chain_src_op->op().op_name() << "]"; // NOTE(chengcheng): - // 3. merge first chain and acc chain + // 4. merge last chain and acc chain MergedLogicalChainIdGroup* group = job_builder->add_logical_chain_groups(); - group->add_logical_chain_id_list(first_chain->logical_chain_id); + group->add_logical_chain_id_list(last_chain->logical_chain_id); group->add_logical_chain_id_list(acc_chain_id); VLOG(3) << " Merge acc chain : " << acc_chain_id - << " to first logcal chain : " << first_chain->logical_chain_id; + << " to last logcal chain : " << last_chain->logical_chain_id; } Maybe LogicalChainPass::Apply(const OpGraph& op_graph, JobBuilder* job_builder) const { @@ -445,6 +506,7 @@ Maybe LogicalChainPass::Apply(const OpGraph& op_graph, JobBuilder* job_bui if (it == placement2logical_chains.end()) { it = placement2logical_chains.emplace(key, PlacementLogicalChainsInfo()).first; it->second.seed_parallel_desc = &this_parallel_desc; + it->second.seed_time_shape = seed_time_shape; } auto& info = it->second; info.ordered_logical_chains.emplace_back(std::make_shared()); @@ -514,7 +576,7 @@ Maybe LogicalChainPass::Apply(const OpGraph& op_graph, JobBuilder* job_bui } } - // NOTE(chengcheng): create logical chain after acc, and merge with first logical chain. + // NOTE(chengcheng): create logical chain after acc, and merge with last logical chain. const std::vector& ordered_acc_op_nodes = info.ordered_acc_op_nodes; if (!ordered_acc_op_nodes.empty()) { info.after_acc_logical_chain = std::make_shared(); @@ -525,8 +587,8 @@ Maybe LogicalChainPass::Apply(const OpGraph& op_graph, JobBuilder* job_bui info.after_acc_logical_chain->logical_chain_id = NewLogicalChainId(); std::sort(acc_chain_order_ops.begin(), acc_chain_order_ops.end(), CmpOpNodeOrder); - TryMergeAfterAccLogicalChainToFirstLogicalChain(&info, &mut_op_name2conf, job_builder, - IsReachable); + TryMergeAfterAccLogicalChainToLastLogicalChain(&info, &mut_op_name2conf, job_builder, + IsReachable); if (acc_chain_order_ops.size() <= 1) { continue; } From 724fe49b500ca47b45e1893b18a48b4cc6aebf39 Mon Sep 17 00:00:00 2001 From: chengtbv <472491134@qq.com> Date: Tue, 11 Oct 2022 05:23:49 +0000 Subject: [PATCH 37/66] reuse cast to tick op --- .../core/job_rewriter/logical_chain_pass.cpp | 32 ++++++++++++++----- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/oneflow/core/job_rewriter/logical_chain_pass.cpp b/oneflow/core/job_rewriter/logical_chain_pass.cpp index 957d5ae3b19..ddd5b0a35ee 100644 --- a/oneflow/core/job_rewriter/logical_chain_pass.cpp +++ b/oneflow/core/job_rewriter/logical_chain_pass.cpp @@ -53,8 +53,13 @@ class LogicalChainPass final : public JobPass { Maybe Apply(const OpGraph& op_graph, JobBuilder* job_builder) const; }; -bool IsTickOpConf(const OperatorConf& conf) { - return IsClassRegistered(conf.op_type_case()); +bool IsTickOpConf(const OperatorConf& op_conf) { + if (IsClassRegistered(op_conf.op_type_case())) { return true; } + if (op_conf.has_user_conf()) { + const std::string& user_type_name = op_conf.user_conf().op_type_name(); + if (user_type_name == "cast_to_tick" || user_type_name == "acc_ctrl_tick") { return true; } + } + return false; } bool IsBreakpointOpNode(const OpNode* node) { @@ -77,7 +82,7 @@ bool IsBreakpointOpNode(const OpNode* node) { const std::string& user_type_name = op_conf.user_conf().op_type_name(); if (user_type_name == "repeat" || user_type_name == "pack" || user_type_name == "unpack" || user_type_name == "identity_buffer" || user_type_name == 
"copy_h2d" - || user_type_name == "copy_d2h" || user_type_name == "acc_ctrl_tick") { + || user_type_name == "copy_d2h") { return true; } } @@ -367,7 +372,6 @@ void TryMergeAfterAccLogicalChainToLastLogicalChain( GenLogicalBlobName(last_chain_sink_op->op().BnInOp2Lbi(fc_sink_obns.Get(0))); VLOG(3) << " last_chain_sink_lbn : " << last_chain_sink_lbn; - /* user_op::UserOpConfWrapper cast_to_tick_op = user_op::UserOpConfWrapperBuilder("Sys-LogicalChainSink-CastToTick-" + NewUniqueId()) .OpTypeName("cast_to_tick") @@ -380,8 +384,6 @@ void TryMergeAfterAccLogicalChainToLastLogicalChain( cast_to_tick_op.op_conf())); std::string acc_tick_output_lbn = cast_to_tick_op.output("out", 0); - */ - std::string acc_tick_output_lbn = last_chain_sink_lbn; if (!IsAccOrPackOpNode(last_chain_sink_op)) { // NOTE(chengcheng): Acc Op can be merged in fw/bw chain, if the last op is acc op, // there is no need and CANNOT insert acc tick op. @@ -420,8 +422,22 @@ void TryMergeAfterAccLogicalChainToLastLogicalChain( CHECK(!IsReachable(acc_chain_src_op->op().op_name(), sink_consumer->op().op_name())); const auto& sink_consumer_obns = sink_consumer->op().output_bns(); CHECK(!sink_consumer_obns.empty()); - std::string sink_consumer_acc_tick_lbn = + + std::string sink_consumer_output_lbn = GenLogicalBlobName(sink_consumer->op().BnInOp2Lbi(sink_consumer_obns.Get(0))); + user_op::UserOpConfWrapper sink_consumer_cast_to_tick_op = + user_op::UserOpConfWrapperBuilder("Sys-LogicalChainSinkConsumer-CastToTick-" + + NewUniqueId()) + .OpTypeName("cast_to_tick") + .Input("in", sink_consumer_output_lbn) + .Output("out") + .ScopeSymbolId(sink_consumer->op().op_conf().scope_symbol_id()) + .Build(); + + CHECK_JUST(job_builder->AddOp(sink_consumer->parallel_desc().parallel_conf(), + sink_consumer_cast_to_tick_op.op_conf())); + + std::string sink_consumer_acc_tick_lbn = sink_consumer_cast_to_tick_op.output("out", 0); if (!IsAccOrPackOpNode(sink_consumer)) { OperatorConf sink_consumer_acc_tick_conf; sink_consumer_acc_tick_conf.set_name(std::string("Sys-LogicalChainSinkConsumer-AccTick_") @@ -437,7 +453,7 @@ void TryMergeAfterAccLogicalChainToLastLogicalChain( VLOG(3) << " insert acc tick op : " << sink_consumer_acc_tick_conf.name() << " of nontrivial_sink_consumer in fw/bw chain."; - CHECK_JUST(job_builder->AddOp(last_chain_sink_op->parallel_desc().parallel_conf(), + CHECK_JUST(job_builder->AddOp(sink_consumer->parallel_desc().parallel_conf(), sink_consumer_acc_tick_conf)); } tick_conf->add_tick(sink_consumer_acc_tick_lbn); From 54cb129d774b3b7565a8931eae6a4c75d3abd4be Mon Sep 17 00:00:00 2001 From: chengtbv <472491134@qq.com> Date: Tue, 11 Oct 2022 09:32:18 +0000 Subject: [PATCH 38/66] fix bug of acc different stream hint cause sync backward compute --- oneflow/core/graph/task_graph.cpp | 7 +++-- .../core/job_rewriter/logical_chain_pass.cpp | 28 ++++++++++++++++++- 2 files changed, 31 insertions(+), 4 deletions(-) diff --git a/oneflow/core/graph/task_graph.cpp b/oneflow/core/graph/task_graph.cpp index 7d1c72d1d23..76e4eb33e7d 100644 --- a/oneflow/core/graph/task_graph.cpp +++ b/oneflow/core/graph/task_graph.cpp @@ -421,9 +421,10 @@ void ForEachOpGraphNecessaryCtrlEdge( dst_time_shape = CHECK_JUST(dst->op().GetOpTimeShape()).get(); } if (src_time_shape->elem_cnt() != dst_time_shape->elem_cnt()) { - // NOTE(chengcheng): acc op node can be merged and add ctrl edge. 
- CHECK(src->op().op_conf().has_user_conf() - && src->op().op_conf().user_conf().op_type_name() == "acc"); + // NOTE(chengcheng): acc / pack op node can be merged and add ctrl edge. + CHECK(src->op().op_conf().has_user_conf()); + const std::string& op_type_name = src->op().op_conf().user_conf().op_type_name(); + CHECK(op_type_name == "acc" || op_type_name == "pack"); const Shape* src_input_time_shape = CHECK_JUST(src->op().GetInputBlobFastestTimeShape()).get(); CHECK_EQ(src_input_time_shape->elem_cnt(), dst_time_shape->elem_cnt()); diff --git a/oneflow/core/job_rewriter/logical_chain_pass.cpp b/oneflow/core/job_rewriter/logical_chain_pass.cpp index ddd5b0a35ee..1783d6c0f55 100644 --- a/oneflow/core/job_rewriter/logical_chain_pass.cpp +++ b/oneflow/core/job_rewriter/logical_chain_pass.cpp @@ -80,7 +80,7 @@ bool IsBreakpointOpNode(const OpNode* node) { if (op_conf.has_user_conf()) { const std::string& user_type_name = op_conf.user_conf().op_type_name(); - if (user_type_name == "repeat" || user_type_name == "pack" || user_type_name == "unpack" + if (user_type_name == "repeat" || user_type_name == "unpack" || user_type_name == "identity_buffer" || user_type_name == "copy_h2d" || user_type_name == "copy_d2h") { return true; @@ -566,6 +566,30 @@ Maybe LogicalChainPass::Apply(const OpGraph& op_graph, JobBuilder* job_bui } }; + auto FixLogicalChainOpStreamHint = [&](const std::vector& ordered_op_nodes) { + std::string stream_index_name = ""; + for (const OpNode* op_node : ordered_op_nodes) { + const OperatorConf& op_conf = op_node->op().op_conf(); + if (op_conf.has_stream_name_hint() && !op_conf.stream_name_hint().empty()) { + if (stream_index_name.empty()) { + stream_index_name = op_conf.stream_name_hint(); + } else { + CHECK_EQ(stream_index_name, op_conf.stream_name_hint()); + } + } + } + + if (!stream_index_name.empty()) { + for (const OpNode* op_node : ordered_op_nodes) { + OperatorConf& op_conf = CHECK_JUST(MapAt(mut_op_name2conf, op_node->op().op_name())); + if (!op_conf.has_stream_name_hint()) { + op_conf.set_stream_name_hint(stream_index_name); + VLOG(3) << " Op: " << op_conf.name() << " fix stream name hint : " << stream_index_name; + } + } + } + }; + for (auto& pair : placement2logical_chains) { const auto& placement = pair.first; auto& info = pair.second; @@ -576,6 +600,8 @@ Maybe LogicalChainPass::Apply(const OpGraph& op_graph, JobBuilder* job_bui for (auto& logical_chain : info.ordered_logical_chains) { logical_chain->logical_chain_id = NewLogicalChainId(); InsertLogicalChainId(logical_chain->ordered_op_nodes, logical_chain->logical_chain_id); + // TODO(chengcheng): rm fix hint and use thrd id in logical op node. 
+ FixLogicalChainOpStreamHint(logical_chain->ordered_op_nodes); InsertCtrlEdgeInChain(logical_chain->ordered_op_nodes); } From 0527ce18590c979ea0fc9e268a12c35a941bd194 Mon Sep 17 00:00:00 2001 From: chengtbv <472491134@qq.com> Date: Tue, 11 Oct 2022 09:34:38 +0000 Subject: [PATCH 39/66] actor name log --- oneflow/core/lazy/actor/actor.cpp | 18 ++++++++++++++++++ oneflow/core/lazy/actor/actor.h | 2 ++ oneflow/core/lazy/actor/light_actor.cpp | 19 +++++++++++++++++++ 3 files changed, 39 insertions(+) diff --git a/oneflow/core/lazy/actor/actor.cpp b/oneflow/core/lazy/actor/actor.cpp index d3cf3a21f04..c3e15a7b215 100644 --- a/oneflow/core/lazy/actor/actor.cpp +++ b/oneflow/core/lazy/actor/actor.cpp @@ -132,11 +132,13 @@ void Actor::Init(const JobDesc* job_desc, ActorContext* actor_ctx) { actor_id_ = task_proto.task_id(); thrd_id_ = ThrdId4ActorId(actor_id_); job_id_ = task_proto.job_id(); + op_name_ = "NULL_OP"; for (const ExecNodeProto& node : task_proto.exec_sequence().exec_node()) { ExecKernel ek; ek.kernel_ctx.reset(new KernelContextImpl(actor_ctx)); ek.kernel = ConstructKernel(node.kernel_conf(), ek.kernel_ctx.get()); exec_kernel_vec_.emplace_back(std::move(ek)); + op_name_ = node.kernel_conf().op_attribute().op_conf().name(); } is_kernel_launch_synchronized_ = @@ -144,6 +146,15 @@ void Actor::Init(const JobDesc* job_desc, ActorContext* actor_ctx) { [](const ExecKernel& ek) { return ek.kernel->IsKernelLaunchSynchronized(); }); if (!is_kernel_launch_synchronized_) { CHECK_EQ(exec_kernel_vec_.size(), 1); } + /* + if (is_kernel_launch_synchronized_ == 0) { + LOG(WARNING) << "ccdebuglog: actor_id: " << actor_id_ + << " IsKernelLaunchSynchronized: " << is_kernel_launch_synchronized_; + } + LOG(INFO) << "ccdebuglog: actor_id: " << actor_id_ + << " IsKernelLaunchSynchronized: " << is_kernel_launch_synchronized_; + */ + remaining_eord_cnt_ = 0; msg_handler_ = nullptr; eord_regst_desc_ids_.clear(); @@ -662,6 +673,13 @@ void Actor::EnqueueAsyncMsg(const ActorMsg& msg) { Singleton::Get()->SendMsg(msg); } else { async_msg_queue_.emplace_back(msg); + /* + LOG(INFO) << "actor async post msg by: " << op_name_ << " actor_id: " << actor_id_ + << " dst_actor_id: " << msg.dst_actor_id() + << " is_kernel_launch_synchronized: " << is_kernel_launch_synchronized_ + << " thrd_id_ = " << thrd_id_ + << " dst_thrd_id = " << ThrdId4ActorId(msg.dst_actor_id()); + */ } } diff --git a/oneflow/core/lazy/actor/actor.h b/oneflow/core/lazy/actor/actor.h index 898366af446..535b861c78b 100644 --- a/oneflow/core/lazy/actor/actor.h +++ b/oneflow/core/lazy/actor/actor.h @@ -215,6 +215,8 @@ class Actor : public ActorBase { std::deque async_msg_queue_; bool is_kernel_launch_synchronized_; std::vector tmp_regst_desc_id_vec_; + + std::string op_name_; }; } // namespace oneflow diff --git a/oneflow/core/lazy/actor/light_actor.cpp b/oneflow/core/lazy/actor/light_actor.cpp index 52922c0fabb..107917694fb 100644 --- a/oneflow/core/lazy/actor/light_actor.cpp +++ b/oneflow/core/lazy/actor/light_actor.cpp @@ -225,9 +225,11 @@ class LightActor : public ActorBase, public KernelContext, public ActorContextPr void Init(const JobDesc* job_desc, ActorContext* actor_ctx) override { const TaskProto& task_proto = actor_ctx->task_proto(); CHECK_EQ(task_proto.exec_sequence().exec_node_size(), 1); + op_name_ = "NULL_OP"; if (exec_kernel) { kernel_info_[0].reset(new KernelInfo()); const KernelConf& kernel_conf = task_proto.exec_sequence().exec_node(0).kernel_conf(); + op_name_ = kernel_conf.op_attribute().op_conf().name(); kernel_info_[0]->kernel = 
ConstructKernel(kernel_conf, this); #ifdef WITH_CUDA_GRAPHS auto* cuda_stream = dynamic_cast(actor_ctx->stream_ctx()->stream()); @@ -347,6 +349,14 @@ class LightActor : public ActorBase, public KernelContext, public ActorContextPr const bool is_kernel_launch_synchronized = (!exec_kernel) || kernel_info_[0]->kernel->IsKernelLaunchSynchronized(); const int64_t actor_id = actor_ctx_->task_proto().task_id(); + /* + if (is_kernel_launch_synchronized == 0) { + LOG(WARNING) << "ccdebuglog: actor_id: " << actor_id + << " IsKernelLaunchSynchronized: " << is_kernel_launch_synchronized; + } + LOG(INFO) << "ccdebuglog: actor_id: " << actor_id << " op_name: " << op_name_; + << " IsKernelLaunchSynchronized: " << is_kernel_launch_synchronized; + */ const int64_t thrd_id = ThrdId4ActorId(actor_id); auto IsSyncMsg = [&](const ActorMsg& msg) { return is_kernel_launch_synchronized && thrd_id == ThrdId4ActorId(msg.dst_actor_id()); @@ -356,6 +366,13 @@ class LightActor : public ActorBase, public KernelContext, public ActorContextPr sync_post_act_msgs_.emplace_back(msg); } else { async_post_act_msgs_.emplace_back(msg); + /* + LOG(INFO) << "actor async post msg by: " << op_name_ << " actor_id: " << actor_id + << " dst_actor_id: " << msg.dst_actor_id() + << " is_kernel_launch_synchronized: " << is_kernel_launch_synchronized + << " thrd_id_ = " << thrd_id + << " dst_thrd_id = " << ThrdId4ActorId(msg.dst_actor_id()); + */ } }; std::vector index2regst_desc_id; @@ -605,6 +622,8 @@ class LightActor : public ActorBase, public KernelContext, public ActorContextPr std::vector sync_post_act_msgs_; std::vector async_post_act_msgs_; KernelObserver* stream_kernel_observer_; + + std::string op_name_; }; template Date: Tue, 11 Oct 2022 10:02:38 +0000 Subject: [PATCH 40/66] fix for review --- oneflow/core/job/plan_util.cpp | 5 +++-- oneflow/core/lazy/actor/acc_ctrl_tick_actor.cpp | 13 +++++++------ oneflow/core/lazy/actor/repeat_actor.cpp | 13 +++++++------ oneflow/user/ops/acc_ctrl_tick_op.cpp | 1 + 4 files changed, 18 insertions(+), 14 deletions(-) diff --git a/oneflow/core/job/plan_util.cpp b/oneflow/core/job/plan_util.cpp index 475c90fccb0..da3887d0350 100644 --- a/oneflow/core/job/plan_util.cpp +++ b/oneflow/core/job/plan_util.cpp @@ -80,8 +80,9 @@ void SetVariableOpNamesForVariableAndRepeatRegst(Plan* plan) { RegstDescProto* regst = PlanUtil::GetSoleProducedDataRegst(task); CHECK(regst->has_force_inplace_consumed_regst_desc_id()); int64_t force_inplace_regst_id = regst->force_inplace_consumed_regst_desc_id(); - if (regst_id2var_name.find(force_inplace_regst_id) != regst_id2var_name.end()) { - regst->set_variable_op_name(regst_id2var_name.at(force_inplace_regst_id)); + auto var_name_it = regst_id2var_name.find(force_inplace_regst_id); + if (var_name_it != regst_id2var_name.end()) { + regst->set_variable_op_name(var_name_it->second); VLOG(3) << " set var op name to repeat regst : " << regst->DebugString(); } } diff --git a/oneflow/core/lazy/actor/acc_ctrl_tick_actor.cpp b/oneflow/core/lazy/actor/acc_ctrl_tick_actor.cpp index cf68ee0c2fc..e879f3364aa 100644 --- a/oneflow/core/lazy/actor/acc_ctrl_tick_actor.cpp +++ b/oneflow/core/lazy/actor/acc_ctrl_tick_actor.cpp @@ -87,10 +87,10 @@ void AccCtrlTickActor::VirtualActorInit(const TaskProto& proto) { // input const auto& consumed_ids = proto.consumed_regst_desc_id(); CHECK_EQ(consumed_ids.size(), 1); - CHECK(consumed_ids.find("in") != consumed_ids.end()); - const auto& in_ids = consumed_ids.at("in"); - CHECK_EQ(in_ids.regst_desc_id_size(), 1); - consumed_tick_regst_desc_id_ = 
in_ids.regst_desc_id(0); + auto in_it = consumed_ids.find("in"); + CHECK(in_it != consumed_ids.end()); + CHECK_EQ(in_it->second.regst_desc_id_size(), 1); + consumed_tick_regst_desc_id_ = in_it->second.regst_desc_id(0); consumed_tick_rs_.InsertRegstDescId(consumed_tick_regst_desc_id_); consumed_tick_rs_.InitedDone(); @@ -99,8 +99,9 @@ void AccCtrlTickActor::VirtualActorInit(const TaskProto& proto) { const auto& produced_ids = proto.produced_regst_desc(); CHECK_EQ(produced_ids.size(), 1); - CHECK(produced_ids.find("out") != produced_ids.end()); - const RegstDescProto& out_regst_desc = produced_ids.at("out"); + auto out_it = produced_ids.find("out"); + CHECK(out_it != produced_ids.end()); + const RegstDescProto& out_regst_desc = out_it->second; produced_tick_regst_desc_id_ = out_regst_desc.regst_desc_id(); produced_tick_rs_.InsertRegstDescId(produced_tick_regst_desc_id_); produced_tick_rs_.InitedDone(); diff --git a/oneflow/core/lazy/actor/repeat_actor.cpp b/oneflow/core/lazy/actor/repeat_actor.cpp index 899102b4245..884f1928fea 100644 --- a/oneflow/core/lazy/actor/repeat_actor.cpp +++ b/oneflow/core/lazy/actor/repeat_actor.cpp @@ -96,17 +96,18 @@ void RepeatActor::VirtualActorInit(const TaskProto& proto) { // input const auto& consumed_ids = proto.consumed_regst_desc_id(); - CHECK(consumed_ids.find("in") != consumed_ids.end()); - const auto& in_ids = consumed_ids.at("in"); - CHECK_EQ(in_ids.regst_desc_id_size(), 1); - consumed_var_regst_desc_id_ = in_ids.regst_desc_id(0); + auto in_it = consumed_ids.find("in"); + CHECK(in_it != consumed_ids.end()); + CHECK_EQ(in_it->second.regst_desc_id_size(), 1); + consumed_var_regst_desc_id_ = in_it->second.regst_desc_id(0); consumed_var_rs_.InsertRegstDescId(consumed_var_regst_desc_id_); consumed_var_rs_.InitedDone(); // output const auto& produced_ids = proto.produced_regst_desc(); - CHECK(produced_ids.find("out") != produced_ids.end()); - const RegstDescProto& out_regst_desc = produced_ids.at("out"); + auto out_it = produced_ids.find("out"); + CHECK(out_it != produced_ids.end()); + const RegstDescProto& out_regst_desc = out_it->second; CHECK(!out_regst_desc.enable_reuse_mem()); CHECK_EQ(out_regst_desc.register_num(), 1); // check inplace diff --git a/oneflow/user/ops/acc_ctrl_tick_op.cpp b/oneflow/user/ops/acc_ctrl_tick_op.cpp index f97fcadd9f9..dbbecdf0e3c 100644 --- a/oneflow/user/ops/acc_ctrl_tick_op.cpp +++ b/oneflow/user/ops/acc_ctrl_tick_op.cpp @@ -69,6 +69,7 @@ namespace oneflow { time_shape_dim_vec.back() /= max_acc_num; } else { const int64_t elem_cnt = in_time_shape.elem_cnt(); + CHECK_EQ_OR_RETURN(elem_cnt % max_acc_num, 0); time_shape_dim_vec.resize(1); time_shape_dim_vec.back() = elem_cnt / max_acc_num; } From fe49a4798aa1c5b67218c90822887ba9e2df0f12 Mon Sep 17 00:00:00 2001 From: chengtbv <472491134@qq.com> Date: Tue, 11 Oct 2022 10:06:35 +0000 Subject: [PATCH 41/66] remove log --- oneflow/core/lazy/actor/actor.cpp | 18 ------------------ oneflow/core/lazy/actor/actor.h | 2 -- oneflow/core/lazy/actor/light_actor.cpp | 19 ------------------- 3 files changed, 39 deletions(-) diff --git a/oneflow/core/lazy/actor/actor.cpp b/oneflow/core/lazy/actor/actor.cpp index c3e15a7b215..d3cf3a21f04 100644 --- a/oneflow/core/lazy/actor/actor.cpp +++ b/oneflow/core/lazy/actor/actor.cpp @@ -132,13 +132,11 @@ void Actor::Init(const JobDesc* job_desc, ActorContext* actor_ctx) { actor_id_ = task_proto.task_id(); thrd_id_ = ThrdId4ActorId(actor_id_); job_id_ = task_proto.job_id(); - op_name_ = "NULL_OP"; for (const ExecNodeProto& node : 
task_proto.exec_sequence().exec_node()) { ExecKernel ek; ek.kernel_ctx.reset(new KernelContextImpl(actor_ctx)); ek.kernel = ConstructKernel(node.kernel_conf(), ek.kernel_ctx.get()); exec_kernel_vec_.emplace_back(std::move(ek)); - op_name_ = node.kernel_conf().op_attribute().op_conf().name(); } is_kernel_launch_synchronized_ = @@ -146,15 +144,6 @@ void Actor::Init(const JobDesc* job_desc, ActorContext* actor_ctx) { [](const ExecKernel& ek) { return ek.kernel->IsKernelLaunchSynchronized(); }); if (!is_kernel_launch_synchronized_) { CHECK_EQ(exec_kernel_vec_.size(), 1); } - /* - if (is_kernel_launch_synchronized_ == 0) { - LOG(WARNING) << "ccdebuglog: actor_id: " << actor_id_ - << " IsKernelLaunchSynchronized: " << is_kernel_launch_synchronized_; - } - LOG(INFO) << "ccdebuglog: actor_id: " << actor_id_ - << " IsKernelLaunchSynchronized: " << is_kernel_launch_synchronized_; - */ - remaining_eord_cnt_ = 0; msg_handler_ = nullptr; eord_regst_desc_ids_.clear(); @@ -673,13 +662,6 @@ void Actor::EnqueueAsyncMsg(const ActorMsg& msg) { Singleton::Get()->SendMsg(msg); } else { async_msg_queue_.emplace_back(msg); - /* - LOG(INFO) << "actor async post msg by: " << op_name_ << " actor_id: " << actor_id_ - << " dst_actor_id: " << msg.dst_actor_id() - << " is_kernel_launch_synchronized: " << is_kernel_launch_synchronized_ - << " thrd_id_ = " << thrd_id_ - << " dst_thrd_id = " << ThrdId4ActorId(msg.dst_actor_id()); - */ } } diff --git a/oneflow/core/lazy/actor/actor.h b/oneflow/core/lazy/actor/actor.h index 535b861c78b..898366af446 100644 --- a/oneflow/core/lazy/actor/actor.h +++ b/oneflow/core/lazy/actor/actor.h @@ -215,8 +215,6 @@ class Actor : public ActorBase { std::deque async_msg_queue_; bool is_kernel_launch_synchronized_; std::vector tmp_regst_desc_id_vec_; - - std::string op_name_; }; } // namespace oneflow diff --git a/oneflow/core/lazy/actor/light_actor.cpp b/oneflow/core/lazy/actor/light_actor.cpp index 107917694fb..52922c0fabb 100644 --- a/oneflow/core/lazy/actor/light_actor.cpp +++ b/oneflow/core/lazy/actor/light_actor.cpp @@ -225,11 +225,9 @@ class LightActor : public ActorBase, public KernelContext, public ActorContextPr void Init(const JobDesc* job_desc, ActorContext* actor_ctx) override { const TaskProto& task_proto = actor_ctx->task_proto(); CHECK_EQ(task_proto.exec_sequence().exec_node_size(), 1); - op_name_ = "NULL_OP"; if (exec_kernel) { kernel_info_[0].reset(new KernelInfo()); const KernelConf& kernel_conf = task_proto.exec_sequence().exec_node(0).kernel_conf(); - op_name_ = kernel_conf.op_attribute().op_conf().name(); kernel_info_[0]->kernel = ConstructKernel(kernel_conf, this); #ifdef WITH_CUDA_GRAPHS auto* cuda_stream = dynamic_cast(actor_ctx->stream_ctx()->stream()); @@ -349,14 +347,6 @@ class LightActor : public ActorBase, public KernelContext, public ActorContextPr const bool is_kernel_launch_synchronized = (!exec_kernel) || kernel_info_[0]->kernel->IsKernelLaunchSynchronized(); const int64_t actor_id = actor_ctx_->task_proto().task_id(); - /* - if (is_kernel_launch_synchronized == 0) { - LOG(WARNING) << "ccdebuglog: actor_id: " << actor_id - << " IsKernelLaunchSynchronized: " << is_kernel_launch_synchronized; - } - LOG(INFO) << "ccdebuglog: actor_id: " << actor_id << " op_name: " << op_name_; - << " IsKernelLaunchSynchronized: " << is_kernel_launch_synchronized; - */ const int64_t thrd_id = ThrdId4ActorId(actor_id); auto IsSyncMsg = [&](const ActorMsg& msg) { return is_kernel_launch_synchronized && thrd_id == ThrdId4ActorId(msg.dst_actor_id()); @@ -366,13 +356,6 @@ class 
LightActor : public ActorBase, public KernelContext, public ActorContextPr sync_post_act_msgs_.emplace_back(msg); } else { async_post_act_msgs_.emplace_back(msg); - /* - LOG(INFO) << "actor async post msg by: " << op_name_ << " actor_id: " << actor_id - << " dst_actor_id: " << msg.dst_actor_id() - << " is_kernel_launch_synchronized: " << is_kernel_launch_synchronized - << " thrd_id_ = " << thrd_id - << " dst_thrd_id = " << ThrdId4ActorId(msg.dst_actor_id()); - */ } }; std::vector index2regst_desc_id; @@ -622,8 +605,6 @@ class LightActor : public ActorBase, public KernelContext, public ActorContextPr std::vector sync_post_act_msgs_; std::vector async_post_act_msgs_; KernelObserver* stream_kernel_observer_; - - std::string op_name_; }; template Date: Tue, 11 Oct 2022 10:09:04 +0000 Subject: [PATCH 42/66] fix note --- oneflow/core/job/compiler.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oneflow/core/job/compiler.cpp b/oneflow/core/job/compiler.cpp index f7f238b4978..217aad0dd2f 100644 --- a/oneflow/core/job/compiler.cpp +++ b/oneflow/core/job/compiler.cpp @@ -104,7 +104,7 @@ void Compiler::Compile(Job* job, Plan* plan) const { auto* job_id2job_conf = plan->mutable_job_confs()->mutable_job_id2job_conf(); (*job_id2job_conf)[GlobalJobDesc().job_id()] = GlobalJobDesc().job_conf(); // NOTE(chengcheng): infer mem blob id & set inplace & add ctrl - // TODO(chengcheng): set inplace hint by cpu regst + // TODO(chengcheng): set inplace hint for cpu regst IntraJobMemSharingUtil::InferMemBlockId4MemReusedRegst(plan, IsReachable); PlanUtil::MergeMemBlockIdByLogicalChainId(plan, *job); PlanUtil::SetUniqueMemBlockId4UnreusedMemRegst(plan); From ff85c5bdad56d948978929def4da9ae844d12325 Mon Sep 17 00:00:00 2001 From: chengtbv <472491134@qq.com> Date: Mon, 17 Oct 2022 11:45:11 +0000 Subject: [PATCH 43/66] fix bug of connect to cast to tick op --- oneflow/core/job_rewriter/logical_chain_pass.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oneflow/core/job_rewriter/logical_chain_pass.cpp b/oneflow/core/job_rewriter/logical_chain_pass.cpp index 1783d6c0f55..2139e93b34c 100644 --- a/oneflow/core/job_rewriter/logical_chain_pass.cpp +++ b/oneflow/core/job_rewriter/logical_chain_pass.cpp @@ -392,7 +392,7 @@ void TryMergeAfterAccLogicalChainToLastLogicalChain( sink_acc_tick_conf.set_name(std::string("Sys-LogicalChainSink-AccTick_") + NewUniqueId()); sink_acc_tick_conf.set_scope_symbol_id(last_chain_sink_op->op().op_conf().scope_symbol_id()); auto* acc_conf = sink_acc_tick_conf.mutable_acc_tick_conf(); - acc_conf->set_one(last_chain_sink_lbn); + acc_conf->set_one(cast_to_tick_op.output("out", 0)); acc_conf->set_acc("acc"); acc_conf->set_max_acc_num(acc_num); acc_tick_output_lbn = GenLogicalBlobName(sink_acc_tick_conf.name(), "acc"); From ee1f717942a634f1917dc4775de8c01746445bcd Mon Sep 17 00:00:00 2001 From: wyg1997 Date: Mon, 24 Oct 2022 17:09:32 +0800 Subject: [PATCH 44/66] refactor(RanddomOp): refactor random op with consistent data --- oneflow/core/functional/impl/common.cpp | 12 ++++ oneflow/core/functional/impl/common.h | 2 + oneflow/core/functional/impl/nn_functor.cpp | 21 ++++-- .../distributions/normal_distribution.cu | 3 +- oneflow/user/kernels/random_seed_util.cpp | 66 +++++++++---------- oneflow/user/kernels/random_seed_util.h | 8 ++- 6 files changed, 70 insertions(+), 42 deletions(-) diff --git a/oneflow/core/functional/impl/common.cpp b/oneflow/core/functional/impl/common.cpp index 342a7890021..a70a8d6c5a0 100644 --- a/oneflow/core/functional/impl/common.cpp 
+++ b/oneflow/core/functional/impl/common.cpp @@ -17,6 +17,8 @@ limitations under the License. #include "oneflow/core/functional/impl/common.h" #include "oneflow/core/autograd/autograd_mode.h" #include "oneflow/core/common/wrap_dim_utils.h" +#include "oneflow/core/ccl/ccl.h" +#include "oneflow/core/job/rank_group.h" namespace oneflow { namespace one { @@ -220,6 +222,16 @@ Maybe> InferUnifiedShapeForBroadcasting(const Shap return std::make_tuple(target, need_to_broadcast.first, need_to_broadcast.second); } +Maybe BroadcastSeedToAllRanks(uint64_t* seed, int64_t root) { + CHECK_NOTNULL_OR_RETURN(seed) << "seed is not allowed to be nullptr"; + const auto& rank_group = JUST(RankGroup::DefaultRankGroup()); + const auto& parallel_desc = JUST(RankGroup::GetDefaultParallelDesc(DeviceType::kCPU, rank_group)); + const auto& meta_transport_token = + JUST(TransportToken::NewTransportToken(kTransportTokenTypeMeta)); + JUST(ccl::CpuBroadcast(seed, seed, sizeof(*seed), root, parallel_desc, meta_transport_token)); + return Maybe::Ok(); +} + } // namespace functional } // namespace one } // namespace oneflow diff --git a/oneflow/core/functional/impl/common.h b/oneflow/core/functional/impl/common.h index 9a8ae1ee471..d9567667210 100644 --- a/oneflow/core/functional/impl/common.h +++ b/oneflow/core/functional/impl/common.h @@ -43,6 +43,8 @@ Maybe InferShapeUnspecifiedDim(const int64_t& elem_count, const Shape& sh Maybe> InferUnifiedShapeForBroadcasting(const Shape& input_shape, const Shape& other_shape); +Maybe BroadcastSeedToAllRanks(uint64_t* seed, int64_t root = 0); + } // namespace functional } // namespace one } // namespace oneflow diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp index 3aa81d2430f..478e634519b 100644 --- a/oneflow/core/functional/impl/nn_functor.cpp +++ b/oneflow/core/functional/impl/nn_functor.cpp @@ -26,6 +26,8 @@ limitations under the License. 
#include "oneflow/user/kernels/dropout_kernel.h" #include "oneflow/core/common/container_util.h" #include "oneflow/user/kernels/distributions/common.h" +#include "oneflow/user/kernels/random_seed_util.h" +#include "oneflow/core/rpc/include/global_process_ctx.h" namespace oneflow { namespace one { @@ -2105,20 +2107,31 @@ class GlobalNormalFunctor { dtype = output_tensor_dtype; } - const auto gen = optional_generator.value_or(JUST(one::DefaultAutoGenerator())); auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("mean", "std", "shape", "dtype", "seed", "nd_sbp"); - const auto& distribution_state = std::make_shared(gen); const auto& nd_sbp = JUST(GetNdSbp(sbp_tuple)); + std::shared_ptr gen = optional_generator.value_or(JUST(one::DefaultAutoGenerator())); + uint64_t init_seed = JUST(gen->Get(0))->engine()(); + if (LazyMode::is_enabled()) { attrs.SetAllAttrs(static_cast(mean), static_cast(std), shape, - dtype->data_type(), static_cast(gen->current_seed()), + dtype->data_type(), static_cast(init_seed), *JUST(GetNdSbpStrList(nd_sbp))); } else { + uint64_t rank_seed = 0; + { + JUST(BroadcastSeedToAllRanks(&init_seed, /*root=*/0)); + rank_seed = + JUST(GetRandomSeedForRank(*placement, *nd_sbp, init_seed, GlobalProcessCtx::Rank())); + } attrs.SetAllAttrs(static_cast(mean), static_cast(std), shape, - dtype->data_type(), static_cast(gen->current_seed()), NullOpt); + dtype->data_type(), static_cast(rank_seed), NullOpt); + gen = JUST(MakeGenerator(placement->device_type())); + gen->set_current_seed(rank_seed); } + const auto& distribution_state = std::make_shared(gen); + if (out.has_value()) { std::shared_ptr outputs = std::make_shared(1); (*outputs)[0] = JUST(out); diff --git a/oneflow/user/kernels/distributions/normal_distribution.cu b/oneflow/user/kernels/distributions/normal_distribution.cu index e55965a6d05..97fe85dccf9 100644 --- a/oneflow/user/kernels/distributions/normal_distribution.cu +++ b/oneflow/user/kernels/distributions/normal_distribution.cu @@ -64,8 +64,7 @@ void NormalDistribution::operator()( ep::Stream* stream, const int64_t elem_cnt, T* dptr, const std::shared_ptr& generator) const { CHECK_GE(elem_cnt, 0); - const auto device_index = stream->device()->device_index(); - auto gen = CHECK_JUST(generator->Get(device_index)); + auto gen = CHECK_JUST(generator->Get()); int32_t block_num = gen->max_block_num(); int32_t thread_num = gen->max_thread_num(); auto* curand_states = gen->curand_states(); diff --git a/oneflow/user/kernels/random_seed_util.cpp b/oneflow/user/kernels/random_seed_util.cpp index 77c5c410a46..69daa515673 100644 --- a/oneflow/user/kernels/random_seed_util.cpp +++ b/oneflow/user/kernels/random_seed_util.cpp @@ -18,47 +18,47 @@ limitations under the License. 
namespace oneflow { -Maybe GetOpKernelRandomSeed(const user_op::KernelInitContext* ctx) { +Maybe GetOpKernelRandomSeed(const user_op::KernelInitContext* ctx) { int64_t seed = ctx->Attr("seed"); if (!ctx->Attr("has_seed")) { seed = NewRandomSeed(); } return GetOpKernelRandomSeedInCurrentRank(ctx, seed); } -Maybe GetOpKernelRandomSeedInCurrentRank(const user_op::KernelInitContext* ctx, - int64_t init_seed) { - int64_t seed = init_seed; - int64_t parallel_num = ctx->parallel_ctx().parallel_num(); - const auto& outputs = ctx->outputs(); - CHECK_EQ(outputs.size(), 1); - if (parallel_num > 1) { - const Shape& hierarchy = *ctx->parallel_desc().hierarchy(); - int64_t parallel_id = ctx->parallel_ctx().parallel_id(); - const NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex(JUST(VectorAt(outputs, 0)).first, - JUST(VectorAt(outputs, 0)).second); - std::vector coordinate(hierarchy.NumAxes()); - int64_t seed_idx = 0; - int64_t stride = 1; - for (int i = nd_sbp.sbp_parallel_size() - 1; i >= 0; --i) { - // coordinate at axis i - int coord = parallel_id % hierarchy.At(i); - parallel_id = (parallel_id - coord) / hierarchy.At(i); - // coordinate reset to 0 if broadcast - if (nd_sbp.sbp_parallel(i).has_broadcast_parallel()) { - // do nothing - } else if (nd_sbp.sbp_parallel(i).has_split_parallel()) { - seed_idx += coord * stride; - stride *= hierarchy.At(i); - } else { - // other sbp is not allowed - return Error::RuntimeError() << "random source op only support broadcast or split"; - } +Maybe GetRandomSeedForRank(const ParallelDesc& placement, const NdSbp& nd_sbp, + uint64_t init_seed, int64_t rank_id) { + uint64_t seed = init_seed; + const Shape& hierarchy = *placement.hierarchy(); + std::vector coordinate(hierarchy.NumAxes()); + int64_t seed_idx = 0; + int64_t stride = 1; + for (int i = nd_sbp.sbp_parallel_size() - 1; i >= 0; --i) { + // coordinate at axis i + int coord = rank_id % hierarchy.At(i); + rank_id = (rank_id - coord) / hierarchy.At(i); + // coordinate reset to 0 if broadcast + if (nd_sbp.sbp_parallel(i).has_broadcast_parallel()) { + // do nothing + } else if (nd_sbp.sbp_parallel(i).has_split_parallel()) { + seed_idx += coord * stride; + stride *= hierarchy.At(i); + } else { + // other sbp is not allowed + return Error::RuntimeError() << "random source op only support broadcast or split"; } - std::seed_seq seq{init_seed}; - std::vector seeds(stride); - seq.generate(seeds.begin(), seeds.end()); - seed = JUST(VectorAt(seeds, seed_idx)); } + std::seed_seq seq{init_seed}; + std::vector seeds(stride); + seq.generate(seeds.begin(), seeds.end()); + seed = JUST(VectorAt(seeds, seed_idx)); return seed; } +Maybe GetOpKernelRandomSeedInCurrentRank(const user_op::KernelInitContext* ctx, + uint64_t init_seed) { + const auto& outputs = ctx->outputs(); + CHECK_EQ(outputs.size(), 1); + return GetRandomSeedForRank(ctx->parallel_desc(), ctx->NdSbp4ArgNameAndIndex("out", 0), init_seed, + ctx->parallel_ctx().parallel_id()); +} + } // namespace oneflow diff --git a/oneflow/user/kernels/random_seed_util.h b/oneflow/user/kernels/random_seed_util.h index 7f22010886b..233f8384d51 100644 --- a/oneflow/user/kernels/random_seed_util.h +++ b/oneflow/user/kernels/random_seed_util.h @@ -20,10 +20,12 @@ limitations under the License. 
namespace oneflow { -Maybe GetOpKernelRandomSeed(const user_op::KernelInitContext* ctx); +Maybe GetRandomSeedForRank(const ParallelDesc& placement, const NdSbp& nd_sbp, + uint64_t init_seed, int64_t rank_id); -Maybe GetOpKernelRandomSeedInCurrentRank(const user_op::KernelInitContext* ctx, - int64_t init_seed); +Maybe GetOpKernelRandomSeed(const user_op::KernelInitContext* ctx); +Maybe GetOpKernelRandomSeedInCurrentRank(const user_op::KernelInitContext* ctx, + uint64_t init_seed); } // namespace oneflow From d7c11bcd1316f9aaf43253321cdd8765bc7881c5 Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Mon, 24 Oct 2022 21:33:55 +0800 Subject: [PATCH 45/66] Add a GetSbpSignature with use parallel num instead of parallel description --- oneflow/core/operator/operator.cpp | 12 ++++++++++++ oneflow/core/operator/operator.h | 9 +++++++++ oneflow/core/operator/user_op.cpp | 14 +++++++------- oneflow/core/operator/user_op.h | 2 +- 4 files changed, 29 insertions(+), 8 deletions(-) diff --git a/oneflow/core/operator/operator.cpp b/oneflow/core/operator/operator.cpp index a48f166e278..4e7f44113a9 100644 --- a/oneflow/core/operator/operator.cpp +++ b/oneflow/core/operator/operator.cpp @@ -495,6 +495,7 @@ Maybe Operator::GetInputOutputFastestTimeShape() const { return input_output_fastest_time_shape_; } +// TODO: Delete this function. We never use parallel_desc in GetSbpSignature Maybe Operator::GetSbpSignaturesIf( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, const ParallelDesc& parallel_desc, SbpSignatureList* sbp_sig_list) const { @@ -506,6 +507,17 @@ Maybe Operator::GetSbpSignaturesIf( return Maybe::Ok(); } +Maybe Operator::GetSbpSignaturesIf( + const std::function(const std::string&)>& LogicalBlobDesc4Ibn, + int32_t parallel_num, SbpSignatureList* sbp_sig_list) const { + JUST(GetSbpSignatures(LogicalBlobDesc4Ibn, parallel_num, sbp_sig_list)); + SbpSignatureBuilder() + .Broadcast(input_bns()) + .Broadcast(output_bns()) + .Build(sbp_sig_list->mutable_sbp_signature()->Add()); + return Maybe::Ok(); +} + Maybe Operator::GetNdSbpSignatureList( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, const ParallelDesc& parallel_desc, std::vector* nd_sbp_sig_list) const { diff --git a/oneflow/core/operator/operator.h b/oneflow/core/operator/operator.h index afed6ba8d2f..18ef8ed5fc7 100644 --- a/oneflow/core/operator/operator.h +++ b/oneflow/core/operator/operator.h @@ -173,6 +173,9 @@ class Operator { Maybe GetSbpSignaturesIf( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, const ParallelDesc& parallel_desc, SbpSignatureList* sbp_sig_list) const; + Maybe GetSbpSignaturesIf( + const std::function(const std::string&)>& LogicalBlobDesc4Ibn, + int32_t parallel_num, SbpSignatureList* sbp_sig_list) const; virtual Maybe GetNdSbpSignatureList( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, const ParallelDesc& parallel_desc, std::vector* nd_sbp_sig_list) const; @@ -212,11 +215,17 @@ class Operator { virtual Maybe InferInternalBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const JobDesc* job_desc) const; + // TODO: Delete this function. 
We never use parallel_desc in GetSbpSignature virtual Maybe GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, const ParallelDesc& parallel_desc, SbpSignatureList* sbp_sig_list) const { return GetSbpSignatures(LogicalBlobDesc4Ibn, sbp_sig_list); } + virtual Maybe GetSbpSignatures( + const std::function(const std::string&)>& LogicalBlobDesc4Ibn, + int32_t parallel_num, SbpSignatureList* sbp_sig_list) const { + return GetSbpSignatures(LogicalBlobDesc4Ibn, sbp_sig_list); + } virtual Maybe GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, SbpSignatureList* sbp_sig_list) const { diff --git a/oneflow/core/operator/user_op.cpp b/oneflow/core/operator/user_op.cpp index db311f600ad..b8765c22b7e 100644 --- a/oneflow/core/operator/user_op.cpp +++ b/oneflow/core/operator/user_op.cpp @@ -345,8 +345,9 @@ class UserOpSbpContext : public user_op::SbpContext { using ArgVec = std::vector>; UserOpSbpContext(const UserOp* op, SbpSignatureList* sbp_sig_list, - std::function(const std::string&)> LogicalBlobDesc4Ibn) - : op_(op), sbp_sig_list_(sbp_sig_list) { + std::function(const std::string&)> LogicalBlobDesc4Ibn, + int32_t parallel_num) + : op_(op), sbp_sig_list_(sbp_sig_list), parallel_num_(parallel_num) { const auto& user_op_conf = op->op_conf().user_conf(); for (auto it = user_op_conf.input().begin(); it != user_op_conf.input().end(); ++it) { const std::string& arg_name = it->first; @@ -375,14 +376,13 @@ class UserOpSbpContext : public user_op::SbpContext { DeviceType device_type() const override { return op_->device_type(); } - int64_t parallel_num() const override { - return CHECK_JUST(op_->GetOpParallelDesc())->parallel_num(); - } + int64_t parallel_num() const override { return parallel_num_; } private: const UserOp* op_; SbpSignatureList* sbp_sig_list_; HashMap, user_op::NaiveTensorDesc> arg2tensor_desc_; + int32_t parallel_num_; }; class UserOpInferSbpSignatureFnContext : public user_op::InferSbpSignatureFnContext { @@ -876,10 +876,10 @@ Maybe UserOp::InferSbpSignature( Maybe UserOp::GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, - const ParallelDesc& parallel_desc, SbpSignatureList* sbp_sig_list) const { + int32_t parallel_num, SbpSignatureList* sbp_sig_list) const { CHECK_OR_RETURN(val_ != nullptr) << "cannot find op_type: " << op_conf().user_conf().op_type_name() << " in op registry!"; - UserOpSbpContext sbp_ctx(this, sbp_sig_list, LogicalBlobDesc4Ibn); + UserOpSbpContext sbp_ctx(this, sbp_sig_list, LogicalBlobDesc4Ibn, parallel_num); JUST(val_->get_sbp_fn(&sbp_ctx)); // Add Broadcast for source user op tick input if (val_->op_def.input_size() == 1 && input_bns().size() == 1 diff --git a/oneflow/core/operator/user_op.h b/oneflow/core/operator/user_op.h index d0f39c8fce1..892111524cd 100644 --- a/oneflow/core/operator/user_op.h +++ b/oneflow/core/operator/user_op.h @@ -64,7 +64,7 @@ class UserOp final : public Operator { const ParallelDesc& parallel_desc) const override; Maybe GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, - const ParallelDesc& parallel_desc, SbpSignatureList* sbp_sig_list) const override; + int32_t parallel_num, SbpSignatureList* sbp_sig_list) const override; Maybe GetNdSbpSignatureList( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, const ParallelDesc& parallel_desc, From f30f29dd8b9103b5e44e4f98aed69206ab16554a Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Mon, 24 Oct 2022 22:15:49 +0800 Subject: [PATCH 46/66] Get sbp_sig_list for each 
dimension of hierarchy --- oneflow/core/framework/sbp_infer_util.cpp | 9 +++++--- oneflow/core/framework/sbp_infer_util.h | 3 ++- oneflow/core/operator/operator.cpp | 26 ++++++++++++++++------- 3 files changed, 26 insertions(+), 12 deletions(-) diff --git a/oneflow/core/framework/sbp_infer_util.cpp b/oneflow/core/framework/sbp_infer_util.cpp index 3ec7562dd51..8211f98caf8 100644 --- a/oneflow/core/framework/sbp_infer_util.cpp +++ b/oneflow/core/framework/sbp_infer_util.cpp @@ -603,14 +603,17 @@ void SetNdSbpSignature(NdSbpSignature* nd_sbp_signature, const SbpSignature& sbp } void DfsGetNdSbpSignature(NdSbpSignature& nd_sbp_sig, int32_t depth, int32_t dims, - const SbpSignatureList& sbp_sig_list, + const Shape& hierarchy, + const HashMap& hierarchy_num2sbp_sig_list, std::vector* nd_sbp_sig_list) { if (depth == dims) { nd_sbp_sig_list->push_back(nd_sbp_sig); } else { - for (const auto& sbp_signature : sbp_sig_list.sbp_signature()) { + for (const auto& sbp_signature : + hierarchy_num2sbp_sig_list.at(hierarchy.At(depth)).sbp_signature()) { SetNdSbpSignature(&nd_sbp_sig, sbp_signature, depth); - DfsGetNdSbpSignature(nd_sbp_sig, depth + 1, dims, sbp_sig_list, nd_sbp_sig_list); + DfsGetNdSbpSignature(nd_sbp_sig, depth + 1, dims, hierarchy, hierarchy_num2sbp_sig_list, + nd_sbp_sig_list); } } } diff --git a/oneflow/core/framework/sbp_infer_util.h b/oneflow/core/framework/sbp_infer_util.h index 21d7da6ae90..afff052a4b5 100644 --- a/oneflow/core/framework/sbp_infer_util.h +++ b/oneflow/core/framework/sbp_infer_util.h @@ -62,7 +62,8 @@ void SetNdSbpSignature(NdSbpSignature* nd_sbp_signature, const SbpSignature& sbp int32_t sbp_axis); void DfsGetNdSbpSignature(NdSbpSignature& nd_sbp_sig, int32_t depth, int32_t dims, - const SbpSignatureList& sbp_sig_list, + const Shape& hierarchy, + const HashMap& hierarchy_num2sbp_sig_list, std::vector* nd_sbp_sig_list); // Compute storage for given NdSbp diff --git a/oneflow/core/operator/operator.cpp b/oneflow/core/operator/operator.cpp index 4e7f44113a9..124c6257111 100644 --- a/oneflow/core/operator/operator.cpp +++ b/oneflow/core/operator/operator.cpp @@ -499,7 +499,7 @@ Maybe Operator::GetInputOutputFastestTimeShape() const { Maybe Operator::GetSbpSignaturesIf( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, const ParallelDesc& parallel_desc, SbpSignatureList* sbp_sig_list) const { - JUST(GetSbpSignatures(LogicalBlobDesc4Ibn, parallel_desc, sbp_sig_list)); + JUST(GetSbpSignatures(LogicalBlobDesc4Ibn, parallel_desc.parallel_num(), sbp_sig_list)); SbpSignatureBuilder() .Broadcast(input_bns()) .Broadcast(output_bns()) @@ -522,18 +522,27 @@ Maybe Operator::GetNdSbpSignatureList( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, const ParallelDesc& parallel_desc, std::vector* nd_sbp_sig_list) const { // Get 1D sbp signature list - SbpSignatureList sbp_sig_list; - JUST(GetSbpSignaturesIf(LogicalBlobDesc4Ibn, parallel_desc, &sbp_sig_list)); - CHECK_GT_OR_RETURN(sbp_sig_list.sbp_signature_size(), 0) - << op_name() << " gets no sbp signature from GetSbpSignaturesIf function!"; + HashMap hierarchy_num2sbp_sig_list; + for (int32_t hierarchy_num : *parallel_desc.hierarchy()) { + if (hierarchy_num2sbp_sig_list.find(hierarchy_num) == hierarchy_num2sbp_sig_list.end()) { + auto* sbp_sig_list = &hierarchy_num2sbp_sig_list[hierarchy_num]; + JUST(GetSbpSignaturesIf(LogicalBlobDesc4Ibn, parallel_desc.parallel_num(), sbp_sig_list)); + CHECK_GT_OR_RETURN(sbp_sig_list->sbp_signature_size(), 0) + << op_name() + << " gets no sbp signature from 
GetSbpSignaturesIf function for hierarchy num: " + << hierarchy_num; + } + } int32_t sbp_dimension = parallel_desc.hierarchy()->NumAxes(); NdSbpSignature nd_sbp_sig; - SbpSignatureToNdSbpSignature(sbp_sig_list.sbp_signature(0), &nd_sbp_sig); + SbpSignatureToNdSbpSignature(hierarchy_num2sbp_sig_list.begin()->second.sbp_signature(0), + &nd_sbp_sig); ResizeNdSbpSignature(nd_sbp_sig, sbp_dimension); // ND sbp signature list would be direct product of 1D sbp signatures CHECK_OR_RETURN(nd_sbp_sig_list->empty()); - DfsGetNdSbpSignature(nd_sbp_sig, 0, sbp_dimension, sbp_sig_list, nd_sbp_sig_list); + DfsGetNdSbpSignature(nd_sbp_sig, 0, sbp_dimension, *parallel_desc.hierarchy(), + hierarchy_num2sbp_sig_list, nd_sbp_sig_list); return Maybe::Ok(); } @@ -845,7 +854,8 @@ Maybe Operator::InferSbpSignature( SbpSignatureList valid_sbp_sig_list; { SbpSignatureList sbp_sig_candidates; - JUST(GetSbpSignaturesIf(LogicalBlobDesc4Ibn, parallel_desc, &sbp_sig_candidates)); + JUST( + GetSbpSignaturesIf(LogicalBlobDesc4Ibn, parallel_desc.parallel_num(), &sbp_sig_candidates)); // filter sbp signatures by logical shape JUST(FilterAndCheckValidSbpSignatureListByLogicalShape(sbp_sig_candidates, LogicalBlobDesc4Ibn, parallel_desc, &valid_sbp_sig_list)); From fdc7ee8558cab68aa9fa152cf1ba2a6dc2b4554e Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Mon, 24 Oct 2022 22:20:21 +0800 Subject: [PATCH 47/66] Add test script and print out information --- oneflow/core/operator/operator.cpp | 55 +++++++++++++++++++++++------- 1 file changed, 42 insertions(+), 13 deletions(-) diff --git a/oneflow/core/operator/operator.cpp b/oneflow/core/operator/operator.cpp index 124c6257111..e2a05e0f548 100644 --- a/oneflow/core/operator/operator.cpp +++ b/oneflow/core/operator/operator.cpp @@ -17,6 +17,7 @@ limitations under the License. 
#include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/common/decorator.h" +#include "oneflow/core/rpc/include/global_process_ctx.h" #include "oneflow/core/vm/symbol_storage.h" #include "oneflow/core/framework/instructions_builder.h" #include "oneflow/core/framework/to_string.h" @@ -37,6 +38,12 @@ namespace oneflow { namespace { +std::string ParallelDesc2String(const ParallelDesc& parallel_desc) { + std::ostringstream out; + out << "hierarchy: " << *parallel_desc.hierarchy() << ", device: " << parallel_desc.device_tag(); + return out.str(); +} + DataType GetDataTypeFromBnInOpVec( std::function GetBlobDesc4BnInOp, const PbRpf& bn_in_ops) { @@ -787,6 +794,23 @@ Maybe Operator::GreedilyFindMinCopyCostNdSbp( producer_infer_hint4ibn->parallel_desc(), *JUST(GetParallelDesc4BnInOp(ibn)), requires_same_sbp[ibn_id]); sum_priority_ratio += priority_ratio; + + if (GlobalProcessCtx::Rank() == 0 + && op_name().find("model.t5_model.encoder.layers.0.self_attention-reshape-29") + != std::string::npos) { + if (i == 0) { + std::cout << "Producer " << NdSbpToString(producer_infer_hint4ibn->nd_sbp()) + << ", placement: " + << ParallelDesc2String(producer_infer_hint4ibn->parallel_desc()) + << std::endl; + std::cout << "Shape: " << producer_infer_hint4ibn->logical_blob_desc().shape() + << std::endl; + } + std::cout << "idx: " << i << ", sbp: " + << NdSbpToString(JUST(VectorAt(nd_sbp_sig_list, i)).bn_in_op2nd_sbp().at(ibn)) + << ", placement: " << ParallelDesc2String(*JUST(GetParallelDesc4BnInOp(ibn))) + << std::endl; + } // We do not accept any blob which has a priority ratio greater than 1 if (priority_ratio > 1.5) { total_copy_cost = GetMaxVal(); @@ -820,21 +844,26 @@ Maybe Operator::GreedilyFindMinCopyCostNdSbp( } } // Can't find any available sbp - if (select_sbp_idx == -1) { - std::ostringstream err; - err << "op: `" << op_name() << "` can't find available sbp signature." << std::endl; - err << "candidate nd sbp signature are: " - << *JUST(NdSbpSignatureListAsString(nd_sbp_sig_list, input_bns(), output_bns())); - err << ", but inputs sbp are:"; - for (int32_t ibn_id = 0; ibn_id < input_bns().size(); ibn_id++) { - const auto& ibn = input_bns().at(ibn_id); - const NdSbp& nd_sbp = JUST(NdSbpInferHint4Ibn(ibn))->nd_sbp(); - err << " " << ibn << ": " << NdSbpToString(nd_sbp); - if (requires_same_sbp[ibn_id]) { err << " [ transfer disabled ]"; } - err << ";"; + std::ostringstream err; + err << "op: `" << op_name() << "` can't find available sbp signature." 
<< std::endl; + err << "candidate nd sbp signature are: " + << *JUST(NdSbpSignatureListAsString(nd_sbp_sig_list, input_bns(), output_bns())); + err << ", but inputs sbp are:"; + for (int32_t ibn_id = 0; ibn_id < input_bns().size(); ibn_id++) { + const auto& ibn = input_bns().at(ibn_id); + const NdSbp& nd_sbp = JUST(NdSbpInferHint4Ibn(ibn))->nd_sbp(); + err << " " << ibn << ": " << NdSbpToString(nd_sbp); + if (requires_same_sbp[ibn_id]) { err << " [ transfer disabled ]"; } + err << ";"; + + if (GlobalProcessCtx::Rank() == 0 + && op_name().find("model.t5_model.encoder.layers.0.self_attention-reshape-29") + != std::string::npos) { + std::cout << err.str() << std::endl; + std::cout << "select idx: " << select_sbp_idx << std::endl; } - return Error::RuntimeError() << err.str(); + if (select_sbp_idx == -1) { return Error::RuntimeError() << err.str(); } } } nd_sbp_signature->CopyFrom(nd_sbp_sig_list.at(select_sbp_idx)); From e1b4a96cf9c14b265b953ac4207d3f77e720a76f Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Mon, 24 Oct 2022 22:47:33 +0800 Subject: [PATCH 48/66] Remove parallel description in GetSbpSignature() --- .../optimizer_placement_optimization_pass.cpp | 3 ++- oneflow/core/operator/dynamic_reshape_op.cpp | 4 ++-- oneflow/core/operator/operator.cpp | 12 ------------ oneflow/core/operator/operator.h | 9 --------- 4 files changed, 4 insertions(+), 24 deletions(-) diff --git a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp index 2c6e16a8bb8..e62d2e7f3c5 100644 --- a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp +++ b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp @@ -297,7 +297,8 @@ bool IsS0SignatureSupported(const OpNode* node) { auto LogicalBlobDesc4Ibn = [&](const std::string& bn) -> Maybe { return Maybe(node->LogicalBlobDesc4Lbi(node->op().BnInOp2Lbi(bn))); }; - CHECK_JUST(node->op().GetSbpSignaturesIf(LogicalBlobDesc4Ibn, node->parallel_desc(), &list)); + CHECK_JUST(node->op().GetSbpSignaturesIf(LogicalBlobDesc4Ibn, + node->parallel_desc().parallel_num(), &list)); const auto IsInOutS0Parallel = [&](const SbpSignature& signature) { return IsS0Parallel(signature, node->op().SoleIbn()) && IsS0Parallel(signature, node->op().SoleObn()); diff --git a/oneflow/core/operator/dynamic_reshape_op.cpp b/oneflow/core/operator/dynamic_reshape_op.cpp index 34e90416d96..72ee5dc47c8 100644 --- a/oneflow/core/operator/dynamic_reshape_op.cpp +++ b/oneflow/core/operator/dynamic_reshape_op.cpp @@ -104,7 +104,7 @@ class DynamicReshapeOp final : public Operator { private: Maybe GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, - const ParallelDesc& parallel_desc, SbpSignatureList* sbp_sig_list) const override { + SbpSignatureList* sbp_sig_list) const override { SbpSignatureBuilder() .Split(input_bns(), 0) .Split(output_bns(), 0) @@ -144,7 +144,7 @@ class DynamicReshapeLikeOp final : public Operator { private: Maybe GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, - const ParallelDesc& parallel_desc, SbpSignatureList* sbp_sig_list) const override { + SbpSignatureList* sbp_sig_list) const override { SbpSignatureBuilder() .Split(input_bns(), 0) .Split(output_bns(), 0) diff --git a/oneflow/core/operator/operator.cpp b/oneflow/core/operator/operator.cpp index e2a05e0f548..788c1b95d8f 100644 --- a/oneflow/core/operator/operator.cpp +++ b/oneflow/core/operator/operator.cpp @@ -502,18 +502,6 @@ Maybe 
Operator::GetInputOutputFastestTimeShape() const { return input_output_fastest_time_shape_; } -// TODO: Delete this function. We never use parallel_desc in GetSbpSignature -Maybe Operator::GetSbpSignaturesIf( - const std::function(const std::string&)>& LogicalBlobDesc4Ibn, - const ParallelDesc& parallel_desc, SbpSignatureList* sbp_sig_list) const { - JUST(GetSbpSignatures(LogicalBlobDesc4Ibn, parallel_desc.parallel_num(), sbp_sig_list)); - SbpSignatureBuilder() - .Broadcast(input_bns()) - .Broadcast(output_bns()) - .Build(sbp_sig_list->mutable_sbp_signature()->Add()); - return Maybe::Ok(); -} - Maybe Operator::GetSbpSignaturesIf( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, int32_t parallel_num, SbpSignatureList* sbp_sig_list) const { diff --git a/oneflow/core/operator/operator.h b/oneflow/core/operator/operator.h index 18ef8ed5fc7..6fa19f13068 100644 --- a/oneflow/core/operator/operator.h +++ b/oneflow/core/operator/operator.h @@ -170,9 +170,6 @@ class Operator { Maybe NdSbp4BnInOp(const std::string& bn_in_op) const; Maybe OptLocalParallel4BnInOp(const std::string& bn_in_op) const; - Maybe GetSbpSignaturesIf( - const std::function(const std::string&)>& LogicalBlobDesc4Ibn, - const ParallelDesc& parallel_desc, SbpSignatureList* sbp_sig_list) const; Maybe GetSbpSignaturesIf( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, int32_t parallel_num, SbpSignatureList* sbp_sig_list) const; @@ -215,12 +212,6 @@ class Operator { virtual Maybe InferInternalBlobDescs( const std::function& GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const JobDesc* job_desc) const; - // TODO: Delete this function. We never use parallel_desc in GetSbpSignature - virtual Maybe GetSbpSignatures( - const std::function(const std::string&)>& LogicalBlobDesc4Ibn, - const ParallelDesc& parallel_desc, SbpSignatureList* sbp_sig_list) const { - return GetSbpSignatures(LogicalBlobDesc4Ibn, sbp_sig_list); - } virtual Maybe GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, int32_t parallel_num, SbpSignatureList* sbp_sig_list) const { From dc23ff7802dc4227a7204744f06ef555c63a0e76 Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Mon, 24 Oct 2022 23:09:59 +0800 Subject: [PATCH 49/66] Fix small bug --- oneflow/core/operator/operator.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oneflow/core/operator/operator.cpp b/oneflow/core/operator/operator.cpp index 788c1b95d8f..fc3300675cb 100644 --- a/oneflow/core/operator/operator.cpp +++ b/oneflow/core/operator/operator.cpp @@ -521,7 +521,7 @@ Maybe Operator::GetNdSbpSignatureList( for (int32_t hierarchy_num : *parallel_desc.hierarchy()) { if (hierarchy_num2sbp_sig_list.find(hierarchy_num) == hierarchy_num2sbp_sig_list.end()) { auto* sbp_sig_list = &hierarchy_num2sbp_sig_list[hierarchy_num]; - JUST(GetSbpSignaturesIf(LogicalBlobDesc4Ibn, parallel_desc.parallel_num(), sbp_sig_list)); + JUST(GetSbpSignaturesIf(LogicalBlobDesc4Ibn, hierarchy_num, sbp_sig_list)); CHECK_GT_OR_RETURN(sbp_sig_list->sbp_signature_size(), 0) << op_name() << " gets no sbp signature from GetSbpSignaturesIf function for hierarchy num: " From 195b0ea149c77374737751356b97f6bf2da240ff Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Mon, 24 Oct 2022 23:15:49 +0800 Subject: [PATCH 50/66] Disable InferNdSbp for reshape op --- oneflow/user/ops/reshape_user_op_util.cpp | 99 ----------------------- 1 file changed, 99 deletions(-) diff --git a/oneflow/user/ops/reshape_user_op_util.cpp b/oneflow/user/ops/reshape_user_op_util.cpp index 
32fab5354e9..78dab917c98 100644 --- a/oneflow/user/ops/reshape_user_op_util.cpp +++ b/oneflow/user/ops/reshape_user_op_util.cpp @@ -174,103 +174,4 @@ Maybe ReshapeUserOpUtil::GetReshapeUserOpSbpSignatures( return Maybe::Ok(); } -namespace { - -Maybe GetInputNdSbp(user_op::InferNdSbpFnContext* ctx, const user_op::OpArg& in_arg, - NdSbp* distribution) { - *distribution = ctx->NdSbpHint4InputArgNameAndIndex(in_arg.name(), in_arg.index()); - const auto& constraints = ctx->nd_sbp_constraints(); - if (constraints.bn_in_op2nd_sbp_size() != 0) { - const auto it = - constraints.bn_in_op2nd_sbp().find(GenRepeatedBn(in_arg.name(), in_arg.index())); - if (it != constraints.bn_in_op2nd_sbp().end()) { *distribution = it->second; } - } - return Maybe::Ok(); -} - -Maybe ApplySbpParallel(const SbpParallel& sbp, const int64_t parallel_num, Shape* shape) { - if (sbp.has_split_parallel()) { - const int64_t axis = sbp.split_parallel().axis(); - CHECK_EQ_OR_RETURN(shape->At(axis) % parallel_num, 0) - << Error::RuntimeError() << "The size of tensor in the " << axis - << " must be an integer multiple of parallel_num, " - << "but got " << shape->At(axis) << " and " << parallel_num; - shape->Set(axis, shape->At(axis) / parallel_num); - } - return Maybe::Ok(); -} - -} // namespace - -Maybe ReshapeUserOpUtil::InferNdSbp(user_op::InferNdSbpFnContext* ctx, - const Shape& logical_in_shape, - const Shape& logical_out_shape) { - const std::string& op_type_name = ctx->user_op_conf().op_type_name(); - CHECK_OR_RETURN(op_type_name == "reshape" || op_type_name == "reshape_like") - << Error::RuntimeError() << "The op_type_name must be \"reshape\" or \"reshape_like\", " - << "but got " << op_type_name; - const bool is_reshape_like = (op_type_name == "reshape_like"); - std::vector in_args({{"in", 0}}); - if (is_reshape_like) { in_args.emplace_back(user_op::OpArg("like", 0)); } - HashMap ibn2nd_sbp; - ibn2nd_sbp.reserve(in_args.size()); - for (const auto& arg : in_args) { - NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex(arg.name(), arg.index()); - JUST(GetInputNdSbp(ctx, arg, in_distribution)); - CHECK_OR_RETURN( - ibn2nd_sbp.emplace(GenRepeatedBn(arg.name(), arg.index()), *in_distribution).second) - << "emplace error"; // NOLINT(maybe-need-error-msg) - } - NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); - - Shape in_shape = logical_in_shape; - Shape out_shape = logical_out_shape; - const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); - for (int64_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) { - SbpSignatureList sbp_sig_list; - user_op::UserOpSbpSignatureBuilder builder(&sbp_sig_list); - builder.Broadcast(in_args).Broadcast(user_op::OpArg("out", 0)).Build(); - if (is_reshape_like) { - builder.PartialSum(user_op::OpArg("like", 0)) - .Broadcast(user_op::OpArg("in", 0)) - .Broadcast(user_op::OpArg("out", 0)) - .Build(); - builder.Broadcast(user_op::OpArg("like", 0)) - .PartialSum(user_op::OpArg("in", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - JUST(GetReshapeUserOpSbpSignatures(in_shape, out_shape, {{"in", 0}}, - {{"like", 0}, {"out", 0}}, parallel_hierarchy.At(i), - &builder)); - } else { - JUST(GetReshapeUserOpSbpSignatures(in_shape, out_shape, {{"in", 0}}, {{"out", 0}}, - parallel_hierarchy.At(i), &builder)); - } - - const SbpSignature* matched_sbp_signature = nullptr; - for (const auto& sbp_signature : sbp_sig_list.sbp_signature()) { - bool all_match = true; - for (const auto& in_arg : in_args) { - std::string ibn = GenRepeatedBn(in_arg.name(), in_arg.index()); - if 
(sbp_signature.bn_in_op2sbp_parallel().at(ibn) != ibn2nd_sbp.at(ibn).sbp_parallel(i)) { - all_match = false; - break; - } - } - if (all_match) { - matched_sbp_signature = &sbp_signature; - break; - } - } - CHECK_OR_RETURN(matched_sbp_signature != nullptr) - << "FusedLstmCellGrad::Pointer to the matched sbp signature is nullptr"; - SbpParallel out_sbp = matched_sbp_signature->bn_in_op2sbp_parallel().at("out_0"); - JUST(ApplySbpParallel(matched_sbp_signature->bn_in_op2sbp_parallel().at("in_0"), - parallel_hierarchy.At(i), &in_shape)); - JUST(ApplySbpParallel(out_sbp, parallel_hierarchy.At(i), &out_shape)); - *(out_distribution->add_sbp_parallel()) = out_sbp; - } - return Maybe::Ok(); -} - } // namespace oneflow From 3e39ce6e89c540dc8e80d59a9b1818c9da4fd0f1 Mon Sep 17 00:00:00 2001 From: wyg1997 Date: Tue, 25 Oct 2022 12:51:09 +0800 Subject: [PATCH 51/66] test(RandomOp): add data consistent test --- .../modules/test_global_random_op_data.py | 102 ++++++++++++++++++ 1 file changed, 102 insertions(+) create mode 100644 python/oneflow/test/modules/test_global_random_op_data.py diff --git a/python/oneflow/test/modules/test_global_random_op_data.py b/python/oneflow/test/modules/test_global_random_op_data.py new file mode 100644 index 00000000000..03f1e135159 --- /dev/null +++ b/python/oneflow/test/modules/test_global_random_op_data.py @@ -0,0 +1,102 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import unittest +from collections import OrderedDict + +import oneflow as flow +import numpy as np +import oneflow.unittest +from oneflow.test_utils.automated_test_util import * + +from oneflow.test_utils.test_util import GenArgDict + + +_fn_param = { + "normal": {"mean": 0.0, "std": 1.0}, +} + + +def _call_fn(fn, shape, placement, sbp): + return eval(f"flow.{fn}")(size=shape, **_fn_param[fn], placement=placement, sbp=sbp) + + +def _test_data_consistent(test_case, shape, placement, sbp, fn): + # lazy result + class GlobalRandnGraph(flow.nn.Graph): + def __init__(self): + super().__init__() + + def build(self): + flow.manual_seed(233) + x = _call_fn(fn, shape, placement, sbp) + return x + + model = GlobalRandnGraph() + lazy_x = model() + + # eager result + flow.manual_seed(233) + eager_x = _call_fn(fn, shape, placement, sbp) + + test_case.assertTrue(np.array_equal(lazy_x.to_local().numpy(), eager_x.to_local().numpy())) + + +class TestGlobalRandomOpData(flow.unittest.TestCase): + @globaltest + def test_random_op_data_consistent_with_eager_and_lazy(test_case): + shape = (8, 8) + + for placement in all_placement(): + for sbp in all_sbp(placement, max_dim=2, except_partial_sum=True): + for fn in _fn_param.keys(): + _test_data_consistent(test_case, shape, placement, sbp, fn=fn) + + @globaltest + @oneflow.unittest.skip_unless_1n4d() + @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") + def test_random_op_data_correctness(test_case): + shape = (8, 8) + sbp = [flow.sbp.split(0), flow.sbp.broadcast] + + for device in ["cpu", "cuda"]: + placement = flow.placement(device, [[0, 1], [2, 3]]) + + for fn in _fn_param.keys(): + flow.manual_seed(233) + np_x_local = _call_fn(fn, shape, placement, sbp).to_local().numpy() + np.save(f"/tmp/{fn}_{flow.env.get_rank()}_local.npy", np_x_local) + flow.comm.barrier() + + # compare result in rank0 + if flow.env.get_rank() == 0: + np_local = [np.load(f"/tmp/{fn}_{int(i)}_local.npy") for i in range(4)] + # rank0 == rank1 + test_case.assertTrue(np.array_equal(np_local[0], np_local[1])) + # rank2 == rank3 + test_case.assertTrue(np.array_equal(np_local[2], np_local[3])) + # rank0 != rank2 + test_case.assertFalse(np.array_equal(np_local[0], np_local[2])) + # rank1 != rank3 + test_case.assertFalse(np.array_equal(np_local[1], np_local[3])) + + # clean data in rank0 + for i in range(4): + os.remove(f"/tmp/{fn}_{int(i)}_local.npy") + + +if __name__ == "__main__": + unittest.main() From f7d29d12052fd040a87b9c5dfc2ebeafabb1b78e Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Tue, 25 Oct 2022 14:29:05 +0800 Subject: [PATCH 52/66] Revert "Add test script and print out information" This reverts commit fdc7ee8558cab68aa9fa152cf1ba2a6dc2b4554e. --- oneflow/core/operator/operator.cpp | 55 +++++++----------------------- 1 file changed, 13 insertions(+), 42 deletions(-) diff --git a/oneflow/core/operator/operator.cpp b/oneflow/core/operator/operator.cpp index fc3300675cb..d99127a1286 100644 --- a/oneflow/core/operator/operator.cpp +++ b/oneflow/core/operator/operator.cpp @@ -17,7 +17,6 @@ limitations under the License. 
#include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/common/decorator.h" -#include "oneflow/core/rpc/include/global_process_ctx.h" #include "oneflow/core/vm/symbol_storage.h" #include "oneflow/core/framework/instructions_builder.h" #include "oneflow/core/framework/to_string.h" @@ -38,12 +37,6 @@ namespace oneflow { namespace { -std::string ParallelDesc2String(const ParallelDesc& parallel_desc) { - std::ostringstream out; - out << "hierarchy: " << *parallel_desc.hierarchy() << ", device: " << parallel_desc.device_tag(); - return out.str(); -} - DataType GetDataTypeFromBnInOpVec( std::function GetBlobDesc4BnInOp, const PbRpf& bn_in_ops) { @@ -782,23 +775,6 @@ Maybe Operator::GreedilyFindMinCopyCostNdSbp( producer_infer_hint4ibn->parallel_desc(), *JUST(GetParallelDesc4BnInOp(ibn)), requires_same_sbp[ibn_id]); sum_priority_ratio += priority_ratio; - - if (GlobalProcessCtx::Rank() == 0 - && op_name().find("model.t5_model.encoder.layers.0.self_attention-reshape-29") - != std::string::npos) { - if (i == 0) { - std::cout << "Producer " << NdSbpToString(producer_infer_hint4ibn->nd_sbp()) - << ", placement: " - << ParallelDesc2String(producer_infer_hint4ibn->parallel_desc()) - << std::endl; - std::cout << "Shape: " << producer_infer_hint4ibn->logical_blob_desc().shape() - << std::endl; - } - std::cout << "idx: " << i << ", sbp: " - << NdSbpToString(JUST(VectorAt(nd_sbp_sig_list, i)).bn_in_op2nd_sbp().at(ibn)) - << ", placement: " << ParallelDesc2String(*JUST(GetParallelDesc4BnInOp(ibn))) - << std::endl; - } // We do not accept any blob which has a priority ratio greater than 1 if (priority_ratio > 1.5) { total_copy_cost = GetMaxVal(); @@ -832,26 +808,21 @@ Maybe Operator::GreedilyFindMinCopyCostNdSbp( } } // Can't find any available sbp - std::ostringstream err; - err << "op: `" << op_name() << "` can't find available sbp signature." << std::endl; - err << "candidate nd sbp signature are: " - << *JUST(NdSbpSignatureListAsString(nd_sbp_sig_list, input_bns(), output_bns())); - err << ", but inputs sbp are:"; - for (int32_t ibn_id = 0; ibn_id < input_bns().size(); ibn_id++) { - const auto& ibn = input_bns().at(ibn_id); - const NdSbp& nd_sbp = JUST(NdSbpInferHint4Ibn(ibn))->nd_sbp(); - err << " " << ibn << ": " << NdSbpToString(nd_sbp); - if (requires_same_sbp[ibn_id]) { err << " [ transfer disabled ]"; } - err << ";"; - - if (GlobalProcessCtx::Rank() == 0 - && op_name().find("model.t5_model.encoder.layers.0.self_attention-reshape-29") - != std::string::npos) { - std::cout << err.str() << std::endl; - std::cout << "select idx: " << select_sbp_idx << std::endl; + if (select_sbp_idx == -1) { + std::ostringstream err; + err << "op: `" << op_name() << "` can't find available sbp signature." 
<< std::endl; + err << "candidate nd sbp signature are: " + << *JUST(NdSbpSignatureListAsString(nd_sbp_sig_list, input_bns(), output_bns())); + err << ", but inputs sbp are:"; + for (int32_t ibn_id = 0; ibn_id < input_bns().size(); ibn_id++) { + const auto& ibn = input_bns().at(ibn_id); + const NdSbp& nd_sbp = JUST(NdSbpInferHint4Ibn(ibn))->nd_sbp(); + err << " " << ibn << ": " << NdSbpToString(nd_sbp); + if (requires_same_sbp[ibn_id]) { err << " [ transfer disabled ]"; } + err << ";"; } - if (select_sbp_idx == -1) { return Error::RuntimeError() << err.str(); } + return Error::RuntimeError() << err.str(); } } nd_sbp_signature->CopyFrom(nd_sbp_sig_list.at(select_sbp_idx)); From 4f084665bb03e58fc68544b0323169cdb4a3773f Mon Sep 17 00:00:00 2001 From: wyg1997 Date: Tue, 25 Oct 2022 15:44:34 +0800 Subject: [PATCH 53/66] refactor(Initializer): refactor normal with oneflow kernel --- oneflow/core/job/initializer_conf.proto | 13 ++---- python/oneflow/nn/init.py | 53 +++++++++++++--------- python/oneflow/ops/initializer_register.py | 46 ------------------- 3 files changed, 34 insertions(+), 78 deletions(-) diff --git a/oneflow/core/job/initializer_conf.proto b/oneflow/core/job/initializer_conf.proto index 3ceac710cff..f21fd508040 100644 --- a/oneflow/core/job/initializer_conf.proto +++ b/oneflow/core/job/initializer_conf.proto @@ -24,12 +24,6 @@ message RandomNormalInitializerConf { optional float std = 2 [default = 1]; } -message TruncNormalInitializerConf { - required RandomNormalInitializerConf norm_conf = 1; - optional float min = 2 [default = -2.0]; - optional float max = 3 [default = 2.0]; -} - //output[D_0 ... D_(axis - 1) i D_(axis + 1) ... D_n] = start + i * stride message RangeInitializerConf { optional double start = 1 [default = 0]; @@ -53,10 +47,9 @@ message InitializerConf { RandomUniformInitializerConf random_uniform_conf = 3; RandomUniformIntInitializerConf random_uniform_int_conf = 4; RandomNormalInitializerConf random_normal_conf = 5; - TruncNormalInitializerConf trunc_normal_conf = 6; - RangeInitializerConf range_conf = 7; - IntRangeInitializerConf int_range_conf = 8; - EmptyInitializerConf empty_conf = 9; + RangeInitializerConf range_conf = 6; + IntRangeInitializerConf int_range_conf = 7; + EmptyInitializerConf empty_conf = 8; } } diff --git a/python/oneflow/nn/init.py b/python/oneflow/nn/init.py index 48967b6fbc1..907b8315695 100644 --- a/python/oneflow/nn/init.py +++ b/python/oneflow/nn/init.py @@ -14,11 +14,16 @@ limitations under the License. 
""" import os +import math import numpy as np import oneflow as flow -from oneflow.ops.util.initializer_util import calc_gain as calculate_gain +from oneflow.ops.util.initializer_util import ( + calc_gain as calculate_gain, + calc_fan, + get_data_format, +) from oneflow.framework.tensor import Tensor import oneflow.framework.dtype as dtype_util import oneflow.ops.initializer_register as initializer_register @@ -98,8 +103,18 @@ def normal_(tensor, mean=0.0, std=1.0): >>> w = flow.empty(3, 5) >>> nn.init.normal_(w) """ - initializer_conf = initializer_register.random_normal_initializer(mean, std) - return _init_by_initializer_conf(tensor, initializer_conf) + with flow.no_grad(): + if tensor.is_local: + return flow.normal(mean=mean, std=std, size=tensor.shape, out=tensor) + else: + return flow.normal( + mean=mean, + std=std, + size=tensor.shape, + out=tensor, + placement=tensor.placement, + sbp=tensor.sbp, + ) def xavier_uniform_(tensor, gain=1.0, *, data_format="NCHW"): @@ -156,10 +171,11 @@ def xavier_normal_(tensor, gain=1.0, *, data_format="NCHW"): >>> w = flow.empty(3, 5) >>> nn.init.xavier_normal_(w) """ - initializer_conf = initializer_register.xavier_initializer( - tensor.shape, gain=gain, data_format=data_format, distribution="random_normal" - ) - return _init_by_initializer_conf(tensor, initializer_conf) + if os.getenv("ONEFLOW_ENABLE_NHWC") == "1": + data_format = "NHWC" + fan = calc_fan(tensor.shape, "fan_sum", get_data_format(data_format)) + std = gain * math.sqrt(2.0 / fan) + return normal_(tensor, 0.0, std) def orthogonal_(tensor, gain=1.0): @@ -251,8 +267,7 @@ def kaiming_normal_( Args: tensor: an n-dimensional `oneflow.Tensor` - a: the negative slope of the rectifier used after this layer (only - used with ``'leaky_relu'``) + a: the negative slope of the rectifier used after this layer. mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'`` preserves the magnitude of the variance of the weights in the forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the @@ -266,22 +281,16 @@ def kaiming_normal_( """ if os.getenv("ONEFLOW_ENABLE_NHWC") == "1": data_format = "NHWC" - initializer_conf = initializer_register.kaiming_initializer( - tensor.shape, - a=a, - mode=mode, - nonlinearity=nonlinearity, - data_format=data_format, - distribution="random_normal", - ) - return _init_by_initializer_conf(tensor, initializer_conf) + assert mode in ["fan_in", "fan_out"] + fan = calc_fan(tensor.shape, mode, get_data_format(data_format)) + gain = calculate_gain(nonlinearity, a) + std = gain / math.sqrt(fan) + return normal_(tensor, 0.0, std) def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): - initializer_conf = initializer_register.truncated_normal_initializer( - mean=mean, std=std, a=a, b=b, - ) - return _init_by_initializer_conf(tensor, initializer_conf) + with flow.no_grad(): + return tensor.normal_(mean, std).clamp_(a, b) def constant_(tensor, val): diff --git a/python/oneflow/ops/initializer_register.py b/python/oneflow/ops/initializer_register.py index 0aaaa9d53ba..e80b9bebfc2 100644 --- a/python/oneflow/ops/initializer_register.py +++ b/python/oneflow/ops/initializer_register.py @@ -260,35 +260,6 @@ def kaiming_initializer( raise NotImplementedError("Only support normal and uniform distribution") -def truncated_normal_initializer( - mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0, -): - r"""Fills the input Tensor with values drawn from a truncated - normal distribution. 
The values are effectively drawn from the - normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` - with values outside :math:`[a, b]` redrawn until they are within - the bounds. The method used for generating the random values works - best when :math:`a \leq \text{mean} \leq b`. - - Args: - tensor: an n-dimensional `torch.Tensor` - mean (float, optional): the mean of the normal distribution - std (float, optional): the standard deviation of the normal distribution - a (float, optional): the minimum cutoff value - b (float, optional): the maximum cutoff value - """ - initializer = initializer_conf_util.InitializerConf() - trunc_normal_conf = getattr(initializer, "trunc_normal_conf") - # set norm_conf - norm_conf = getattr(trunc_normal_conf, "norm_conf") - setattr(norm_conf, "mean", float(mean)) - setattr(norm_conf, "std", float(std)) - # set max/min - setattr(trunc_normal_conf, "min", float(a)) - setattr(trunc_normal_conf, "max", float(b)) - return initializer - - @register_initializer("constant_conf") @register_initializer("constant_int_conf") def ConstantInitializerImpl( @@ -340,23 +311,6 @@ def RandomUniformIntInitializerImpl( ) -@register_initializer("trunc_normal_conf") -def TruncNormalInitializerImpl( - initializer_conf: initializer_conf_util.TruncNormalInitializerConf, - random_seed: int, - var_blob_shape: Sequence[int], -): - rng = np.random.default_rng(random_seed) - norm_conf = getattr(initializer_conf, "norm_conf") - mean = getattr(norm_conf, "mean") - std = getattr(norm_conf, "std") - min = getattr(initializer_conf, "min") - max = getattr(initializer_conf, "max") - return lambda length: np.clip( - rng.normal(loc=mean, scale=std, size=length), a_min=min, a_max=max - ) - - @register_initializer("empty_conf") def EmptyInitializerImpl( initializer_conf: initializer_conf_util.EmptyInitializerConf, From d63a59b6fcd90cb689f5430956ad35d654b085ce Mon Sep 17 00:00:00 2001 From: wyg1997 Date: Wed, 26 Oct 2022 16:01:29 +0800 Subject: [PATCH 54/66] fix(RandomSeed): fix parallel_num==1 --- oneflow/user/kernels/random_seed_util.cpp | 1 + .../oneflow/test/modules/test_global_random_op_data.py | 10 +++++++--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/oneflow/user/kernels/random_seed_util.cpp b/oneflow/user/kernels/random_seed_util.cpp index 69daa515673..753e69d68ca 100644 --- a/oneflow/user/kernels/random_seed_util.cpp +++ b/oneflow/user/kernels/random_seed_util.cpp @@ -55,6 +55,7 @@ Maybe GetRandomSeedForRank(const ParallelDesc& placement, const NdSbp& Maybe GetOpKernelRandomSeedInCurrentRank(const user_op::KernelInitContext* ctx, uint64_t init_seed) { + if (ctx->parallel_ctx().parallel_num() == 1) { return init_seed; } const auto& outputs = ctx->outputs(); CHECK_EQ(outputs.size(), 1); return GetRandomSeedForRank(ctx->parallel_desc(), ctx->NdSbp4ArgNameAndIndex("out", 0), init_seed, diff --git a/python/oneflow/test/modules/test_global_random_op_data.py b/python/oneflow/test/modules/test_global_random_op_data.py index 03f1e135159..073678ba006 100644 --- a/python/oneflow/test/modules/test_global_random_op_data.py +++ b/python/oneflow/test/modules/test_global_random_op_data.py @@ -52,7 +52,9 @@ def build(self): flow.manual_seed(233) eager_x = _call_fn(fn, shape, placement, sbp) - test_case.assertTrue(np.array_equal(lazy_x.to_local().numpy(), eager_x.to_local().numpy())) + test_case.assertTrue( + np.array_equal(lazy_x.to_local().numpy(), eager_x.to_local().numpy()) + ) class TestGlobalRandomOpData(flow.unittest.TestCase): @@ -74,7 +76,7 @@ def 
test_random_op_data_correctness(test_case): for device in ["cpu", "cuda"]: placement = flow.placement(device, [[0, 1], [2, 3]]) - + for fn in _fn_param.keys(): flow.manual_seed(233) np_x_local = _call_fn(fn, shape, placement, sbp).to_local().numpy() @@ -83,7 +85,9 @@ def test_random_op_data_correctness(test_case): # compare result in rank0 if flow.env.get_rank() == 0: - np_local = [np.load(f"/tmp/{fn}_{int(i)}_local.npy") for i in range(4)] + np_local = [ + np.load(f"/tmp/{fn}_{int(i)}_local.npy") for i in range(4) + ] # rank0 == rank1 test_case.assertTrue(np.array_equal(np_local[0], np_local[1])) # rank2 == rank3 From 5f379569be245a2be8bd95b4b9a5011b1aeb6f9d Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Wed, 26 Oct 2022 20:23:41 +0800 Subject: [PATCH 55/66] Add hierarchy value --- oneflow/core/framework/sbp_context.h | 6 ++++++ oneflow/core/framework/sbp_infer_util.cpp | 6 +++--- oneflow/core/framework/sbp_infer_util.h | 2 +- oneflow/core/operator/operator.cpp | 23 ++++++++++++++--------- oneflow/core/operator/user_op.cpp | 16 ++++++++++------ oneflow/core/operator/user_op.h | 2 +- oneflow/user/ops/reshape_op.cpp | 2 +- 7 files changed, 36 insertions(+), 21 deletions(-) diff --git a/oneflow/core/framework/sbp_context.h b/oneflow/core/framework/sbp_context.h index 680ae94d5ba..bac1db925f6 100644 --- a/oneflow/core/framework/sbp_context.h +++ b/oneflow/core/framework/sbp_context.h @@ -76,6 +76,12 @@ class SbpContext : public SbpContextBase { SbpContext() = default; ~SbpContext() override = default; + // hierarchy value is the value at the dimension corresponding to the current SBP + // For example, 2 machines, 4 gpus per machine, hierarchy = [2, 4] + // Suppose we have nd_sbp = (S0, B) + // The hierarchy value corresponding to S0 is 2 + // The hierarchy value corresponding to B is 4. 
+ virtual int64_t hierarchy_value() const = 0; virtual UserOpSbpSignatureBuilder NewBuilder() = 0; }; diff --git a/oneflow/core/framework/sbp_infer_util.cpp b/oneflow/core/framework/sbp_infer_util.cpp index 8211f98caf8..f3a25167c97 100644 --- a/oneflow/core/framework/sbp_infer_util.cpp +++ b/oneflow/core/framework/sbp_infer_util.cpp @@ -604,15 +604,15 @@ void SetNdSbpSignature(NdSbpSignature* nd_sbp_signature, const SbpSignature& sbp void DfsGetNdSbpSignature(NdSbpSignature& nd_sbp_sig, int32_t depth, int32_t dims, const Shape& hierarchy, - const HashMap& hierarchy_num2sbp_sig_list, + const HashMap& hierarchy_value2sbp_sig_list, std::vector* nd_sbp_sig_list) { if (depth == dims) { nd_sbp_sig_list->push_back(nd_sbp_sig); } else { for (const auto& sbp_signature : - hierarchy_num2sbp_sig_list.at(hierarchy.At(depth)).sbp_signature()) { + hierarchy_value2sbp_sig_list.at(hierarchy.At(depth)).sbp_signature()) { SetNdSbpSignature(&nd_sbp_sig, sbp_signature, depth); - DfsGetNdSbpSignature(nd_sbp_sig, depth + 1, dims, hierarchy, hierarchy_num2sbp_sig_list, + DfsGetNdSbpSignature(nd_sbp_sig, depth + 1, dims, hierarchy, hierarchy_value2sbp_sig_list, nd_sbp_sig_list); } } diff --git a/oneflow/core/framework/sbp_infer_util.h b/oneflow/core/framework/sbp_infer_util.h index afff052a4b5..fabb13edbfa 100644 --- a/oneflow/core/framework/sbp_infer_util.h +++ b/oneflow/core/framework/sbp_infer_util.h @@ -63,7 +63,7 @@ void SetNdSbpSignature(NdSbpSignature* nd_sbp_signature, const SbpSignature& sbp void DfsGetNdSbpSignature(NdSbpSignature& nd_sbp_sig, int32_t depth, int32_t dims, const Shape& hierarchy, - const HashMap& hierarchy_num2sbp_sig_list, + const HashMap& hierarchy_value2sbp_sig_list, std::vector* nd_sbp_sig_list); // Compute storage for given NdSbp diff --git a/oneflow/core/operator/operator.cpp b/oneflow/core/operator/operator.cpp index d99127a1286..2e92ef4322e 100644 --- a/oneflow/core/operator/operator.cpp +++ b/oneflow/core/operator/operator.cpp @@ -510,27 +510,32 @@ Maybe Operator::GetNdSbpSignatureList( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, const ParallelDesc& parallel_desc, std::vector* nd_sbp_sig_list) const { // Get 1D sbp signature list - HashMap hierarchy_num2sbp_sig_list; - for (int32_t hierarchy_num : *parallel_desc.hierarchy()) { - if (hierarchy_num2sbp_sig_list.find(hierarchy_num) == hierarchy_num2sbp_sig_list.end()) { - auto* sbp_sig_list = &hierarchy_num2sbp_sig_list[hierarchy_num]; - JUST(GetSbpSignaturesIf(LogicalBlobDesc4Ibn, hierarchy_num, sbp_sig_list)); + HashMap hierarchy_value2sbp_sig_list; + // hierarchy value is the value at the dimension corresponding to the current SBP + // For example, 2 machines, 4 gpus per machine, hierarchy = [2, 4] + // Suppose we have nd_sbp = (S0, B) + // The hierarchy value corresponding to S0 is 2 + // The hierarchy value corresponding to B is 4. 
+ for (int32_t hierarchy_value : *parallel_desc.hierarchy()) { + if (hierarchy_value2sbp_sig_list.find(hierarchy_value) == hierarchy_value2sbp_sig_list.end()) { + auto* sbp_sig_list = &hierarchy_value2sbp_sig_list[hierarchy_value]; + JUST(GetSbpSignaturesIf(LogicalBlobDesc4Ibn, hierarchy_value, sbp_sig_list)); CHECK_GT_OR_RETURN(sbp_sig_list->sbp_signature_size(), 0) << op_name() - << " gets no sbp signature from GetSbpSignaturesIf function for hierarchy num: " - << hierarchy_num; + << " gets no sbp signature from GetSbpSignaturesIf function for hierarchy value: " + << hierarchy_value; } } int32_t sbp_dimension = parallel_desc.hierarchy()->NumAxes(); NdSbpSignature nd_sbp_sig; - SbpSignatureToNdSbpSignature(hierarchy_num2sbp_sig_list.begin()->second.sbp_signature(0), + SbpSignatureToNdSbpSignature(hierarchy_value2sbp_sig_list.begin()->second.sbp_signature(0), &nd_sbp_sig); ResizeNdSbpSignature(nd_sbp_sig, sbp_dimension); // ND sbp signature list would be direct product of 1D sbp signatures CHECK_OR_RETURN(nd_sbp_sig_list->empty()); DfsGetNdSbpSignature(nd_sbp_sig, 0, sbp_dimension, *parallel_desc.hierarchy(), - hierarchy_num2sbp_sig_list, nd_sbp_sig_list); + hierarchy_value2sbp_sig_list, nd_sbp_sig_list); return Maybe::Ok(); } diff --git a/oneflow/core/operator/user_op.cpp b/oneflow/core/operator/user_op.cpp index b8765c22b7e..34e4fd2ecbc 100644 --- a/oneflow/core/operator/user_op.cpp +++ b/oneflow/core/operator/user_op.cpp @@ -346,8 +346,8 @@ class UserOpSbpContext : public user_op::SbpContext { UserOpSbpContext(const UserOp* op, SbpSignatureList* sbp_sig_list, std::function(const std::string&)> LogicalBlobDesc4Ibn, - int32_t parallel_num) - : op_(op), sbp_sig_list_(sbp_sig_list), parallel_num_(parallel_num) { + int32_t hierarchy_value) + : op_(op), sbp_sig_list_(sbp_sig_list), hierarchy_value_(hierarchy_value) { const auto& user_op_conf = op->op_conf().user_conf(); for (auto it = user_op_conf.input().begin(); it != user_op_conf.input().end(); ++it) { const std::string& arg_name = it->first; @@ -376,13 +376,17 @@ class UserOpSbpContext : public user_op::SbpContext { DeviceType device_type() const override { return op_->device_type(); } - int64_t parallel_num() const override { return parallel_num_; } + int64_t parallel_num() const override { + return CHECK_JUST(op_->GetOpParallelDesc())->parallel_num(); + } + + int64_t hierarchy_value() const { return hierarchy_value_; } private: const UserOp* op_; SbpSignatureList* sbp_sig_list_; HashMap, user_op::NaiveTensorDesc> arg2tensor_desc_; - int32_t parallel_num_; + int32_t hierarchy_value_; }; class UserOpInferSbpSignatureFnContext : public user_op::InferSbpSignatureFnContext { @@ -876,10 +880,10 @@ Maybe UserOp::InferSbpSignature( Maybe UserOp::GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, - int32_t parallel_num, SbpSignatureList* sbp_sig_list) const { + int32_t hierarchy_value, SbpSignatureList* sbp_sig_list) const { CHECK_OR_RETURN(val_ != nullptr) << "cannot find op_type: " << op_conf().user_conf().op_type_name() << " in op registry!"; - UserOpSbpContext sbp_ctx(this, sbp_sig_list, LogicalBlobDesc4Ibn, parallel_num); + UserOpSbpContext sbp_ctx(this, sbp_sig_list, LogicalBlobDesc4Ibn, hierarchy_value); JUST(val_->get_sbp_fn(&sbp_ctx)); // Add Broadcast for source user op tick input if (val_->op_def.input_size() == 1 && input_bns().size() == 1 diff --git a/oneflow/core/operator/user_op.h b/oneflow/core/operator/user_op.h index 892111524cd..5399438fd4e 100644 --- a/oneflow/core/operator/user_op.h +++ 
b/oneflow/core/operator/user_op.h @@ -64,7 +64,7 @@ class UserOp final : public Operator { const ParallelDesc& parallel_desc) const override; Maybe GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, - int32_t parallel_num, SbpSignatureList* sbp_sig_list) const override; + int32_t hierarchy_value, SbpSignatureList* sbp_sig_list) const override; Maybe GetNdSbpSignatureList( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, const ParallelDesc& parallel_desc, diff --git a/oneflow/user/ops/reshape_op.cpp b/oneflow/user/ops/reshape_op.cpp index 822c7f90bf7..916b66089f5 100644 --- a/oneflow/user/ops/reshape_op.cpp +++ b/oneflow/user/ops/reshape_op.cpp @@ -27,7 +27,7 @@ namespace oneflow { const auto& outshape = JUST(ReshapeUserOpUtil::GetLogicalOutBlobShape(in_shape, shape)); user_op::UserOpSbpSignatureBuilder builder = ctx->NewBuilder(); return ReshapeUserOpUtil::GetReshapeUserOpSbpSignatures( - in_shape, *outshape, {{"in", 0}}, {{"out", 0}}, ctx->parallel_num(), &builder); + in_shape, *outshape, {{"in", 0}}, {{"out", 0}}, ctx->hierarchy_value(), &builder); } /*static*/ Maybe ReshapeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { From f92e33072dbfc24dcb94a67dec7cefed9e0c33b7 Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Wed, 26 Oct 2022 20:30:03 +0800 Subject: [PATCH 56/66] Address comments --- oneflow/core/operator/operator.cpp | 5 +++-- oneflow/core/operator/operator.h | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/oneflow/core/operator/operator.cpp b/oneflow/core/operator/operator.cpp index 2e92ef4322e..bc1cd858ad5 100644 --- a/oneflow/core/operator/operator.cpp +++ b/oneflow/core/operator/operator.cpp @@ -497,8 +497,8 @@ Maybe Operator::GetInputOutputFastestTimeShape() const { Maybe Operator::GetSbpSignaturesIf( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, - int32_t parallel_num, SbpSignatureList* sbp_sig_list) const { - JUST(GetSbpSignatures(LogicalBlobDesc4Ibn, parallel_num, sbp_sig_list)); + int32_t hierarchy_value, SbpSignatureList* sbp_sig_list) const { + JUST(GetSbpSignatures(LogicalBlobDesc4Ibn, hierarchy_value, sbp_sig_list)); SbpSignatureBuilder() .Broadcast(input_bns()) .Broadcast(output_bns()) @@ -847,6 +847,7 @@ Maybe Operator::InferSbpSignature( SbpSignatureList valid_sbp_sig_list; { SbpSignatureList sbp_sig_candidates; + // For 1d sbp, hierarchy value = parallel num JUST( GetSbpSignaturesIf(LogicalBlobDesc4Ibn, parallel_desc.parallel_num(), &sbp_sig_candidates)); // filter sbp signatures by logical shape diff --git a/oneflow/core/operator/operator.h b/oneflow/core/operator/operator.h index 6fa19f13068..64c82a9baf1 100644 --- a/oneflow/core/operator/operator.h +++ b/oneflow/core/operator/operator.h @@ -172,7 +172,7 @@ class Operator { Maybe GetSbpSignaturesIf( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, - int32_t parallel_num, SbpSignatureList* sbp_sig_list) const; + int32_t hierarchy_value, SbpSignatureList* sbp_sig_list) const; virtual Maybe GetNdSbpSignatureList( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, const ParallelDesc& parallel_desc, std::vector* nd_sbp_sig_list) const; @@ -214,7 +214,7 @@ class Operator { const ParallelContext* parallel_ctx, const JobDesc* job_desc) const; virtual Maybe GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, - int32_t parallel_num, SbpSignatureList* sbp_sig_list) const { + int32_t hierarchy_value, SbpSignatureList* sbp_sig_list) const { return 
GetSbpSignatures(LogicalBlobDesc4Ibn, sbp_sig_list); } virtual Maybe GetSbpSignatures( From 0c0954c079aca1a4e8abbadfe84d586b1f68b9d1 Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Wed, 26 Oct 2022 12:37:02 +0000 Subject: [PATCH 57/66] parallel num j-> hierarchy value for reshape op --- oneflow/user/ops/reshape_user_op_util.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/oneflow/user/ops/reshape_user_op_util.cpp b/oneflow/user/ops/reshape_user_op_util.cpp index 78dab917c98..873cf209cba 100644 --- a/oneflow/user/ops/reshape_user_op_util.cpp +++ b/oneflow/user/ops/reshape_user_op_util.cpp @@ -92,7 +92,7 @@ Maybe ReshapeUserOpUtil::Squeeze(const Shape& origin, Shape* shape, } Maybe ReshapeUserOpUtil::GetGroupStartInAxis2OutAxis( - const Shape& in_shape, const Shape& out_shape, const int64_t parallel_num, + const Shape& in_shape, const Shape& out_shape, const int64_t hierarchy_value, HashMap* group_start_in_axis2out_axis) { CHECK_GE_OR_RETURN(in_shape.NumAxes(), 0) << Error::RuntimeError() @@ -128,8 +128,8 @@ Maybe ReshapeUserOpUtil::GetGroupStartInAxis2OutAxis( if (in_shape_count == out_shape_count) { // Record split axises if (in_shape.At(in_axis) == out_shape.At(out_axis) - || (in_shape.At(in_axis) % parallel_num == 0 - && out_shape.At(out_axis) % parallel_num == 0)) { + || (in_shape.At(in_axis) % hierarchy_value == 0 + && out_shape.At(out_axis) % hierarchy_value == 0)) { (*group_start_in_axis2out_axis)[in_axis] = out_axis; } // Move forward @@ -147,7 +147,7 @@ Maybe ReshapeUserOpUtil::GetGroupStartInAxis2OutAxis( Maybe ReshapeUserOpUtil::GetReshapeUserOpSbpSignatures( const Shape& in_shape, const Shape& out_shape, std::vector in_args, - std::vector out_args, const int64_t parallel_num, + std::vector out_args, const int64_t hierarchy_value, user_op::UserOpSbpSignatureBuilder* builder) { if (in_shape.NumAxes() == 0 || in_shape.elem_cnt() == 0) { return Maybe::Ok(); @@ -162,7 +162,7 @@ Maybe ReshapeUserOpUtil::GetReshapeUserOpSbpSignatures( JUST(ReshapeUserOpUtil::Squeeze(out_shape, &squeezed_out_shape, &out_squeezed_axis2original_axis)); JUST(ReshapeUserOpUtil::GetGroupStartInAxis2OutAxis(squeezed_in_shape, squeezed_out_shape, - parallel_num, + hierarchy_value, &squeezed_group_start_in_axis2out_axis)); } for (const auto& pair : squeezed_group_start_in_axis2out_axis) { From 16973bbdb4501f15161ef9a5eb1be57aafdfb09d Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Wed, 26 Oct 2022 12:57:07 +0000 Subject: [PATCH 58/66] Static analysis --- oneflow/user/ops/reshape_user_op_util.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/oneflow/user/ops/reshape_user_op_util.cpp b/oneflow/user/ops/reshape_user_op_util.cpp index 873cf209cba..a5597e4ccab 100644 --- a/oneflow/user/ops/reshape_user_op_util.cpp +++ b/oneflow/user/ops/reshape_user_op_util.cpp @@ -146,8 +146,8 @@ Maybe ReshapeUserOpUtil::GetGroupStartInAxis2OutAxis( } Maybe ReshapeUserOpUtil::GetReshapeUserOpSbpSignatures( - const Shape& in_shape, const Shape& out_shape, std::vector in_args, - std::vector out_args, const int64_t hierarchy_value, + const Shape& in_shape, const Shape& out_shape, const std::vector& in_args, + const std::vector& out_args, const int64_t hierarchy_value, user_op::UserOpSbpSignatureBuilder* builder) { if (in_shape.NumAxes() == 0 || in_shape.elem_cnt() == 0) { return Maybe::Ok(); From ccd8c57a0f671643a6d0fcc3f5e766b28430f453 Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Wed, 26 Oct 2022 13:00:16 +0000 Subject: [PATCH 59/66] refine --- 
oneflow/user/ops/reshape_user_op_util.h | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/oneflow/user/ops/reshape_user_op_util.h b/oneflow/user/ops/reshape_user_op_util.h index 8178803fbc1..b1fa968f3d0 100644 --- a/oneflow/user/ops/reshape_user_op_util.h +++ b/oneflow/user/ops/reshape_user_op_util.h @@ -26,15 +26,13 @@ struct ReshapeUserOpUtil { static Maybe Squeeze(const Shape& origin, Shape* shape, HashMap* squeezed_axis2origin_axis); static Maybe GetGroupStartInAxis2OutAxis(const Shape& in_shape, const Shape& out_shape, - const int64_t parallel_num, + const int64_t hierarchy_value, HashMap* group_start_in_axis2out_axis); static Maybe GetReshapeUserOpSbpSignatures(const Shape& in_shape, const Shape& out_shape, - std::vector in_args, - std::vector out_args, - const int64_t parallel_num, + const std::vector& in_args, + const std::vector& out_args, + const int64_t hierarchy_value, user_op::UserOpSbpSignatureBuilder* builder); - static Maybe InferNdSbp(user_op::InferNdSbpFnContext* ctx, const Shape& logical_in_shape, - const Shape& logical_out_shape); }; } // namespace oneflow From f471774c3c28b749a8f0a524619a30cc5eff619b Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Thu, 27 Oct 2022 00:53:14 +0800 Subject: [PATCH 60/66] Update user_op.cpp --- oneflow/core/operator/user_op.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oneflow/core/operator/user_op.cpp b/oneflow/core/operator/user_op.cpp index 34e4fd2ecbc..3845da8f566 100644 --- a/oneflow/core/operator/user_op.cpp +++ b/oneflow/core/operator/user_op.cpp @@ -380,7 +380,7 @@ class UserOpSbpContext : public user_op::SbpContext { return CHECK_JUST(op_->GetOpParallelDesc())->parallel_num(); } - int64_t hierarchy_value() const { return hierarchy_value_; } + int64_t hierarchy_value() const override { return hierarchy_value_; } private: const UserOp* op_; From 64832e43196067d67f70094a8d35664a805a5891 Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Thu, 27 Oct 2022 00:54:26 +0800 Subject: [PATCH 61/66] Update operator.cpp --- oneflow/core/operator/operator.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oneflow/core/operator/operator.cpp b/oneflow/core/operator/operator.cpp index bc1cd858ad5..e5da2add5fc 100644 --- a/oneflow/core/operator/operator.cpp +++ b/oneflow/core/operator/operator.cpp @@ -529,7 +529,7 @@ Maybe Operator::GetNdSbpSignatureList( int32_t sbp_dimension = parallel_desc.hierarchy()->NumAxes(); NdSbpSignature nd_sbp_sig; - SbpSignatureToNdSbpSignature(hierarchy_value2sbp_sig_list.begin()->second.sbp_signature(0), + SbpSignatureToNdSbpSignature(hierarchy_value2sbp_sig_list.front().sbp_signature(0), &nd_sbp_sig); ResizeNdSbpSignature(nd_sbp_sig, sbp_dimension); // ND sbp signature list would be direct product of 1D sbp signatures From 582d20ae00f6b63309b9528aefaec1a3238ad742 Mon Sep 17 00:00:00 2001 From: oneflow-ci-bot Date: Wed, 26 Oct 2022 16:56:07 +0000 Subject: [PATCH 62/66] auto format by CI --- oneflow/core/operator/operator.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/oneflow/core/operator/operator.cpp b/oneflow/core/operator/operator.cpp index e5da2add5fc..28efd312792 100644 --- a/oneflow/core/operator/operator.cpp +++ b/oneflow/core/operator/operator.cpp @@ -529,8 +529,7 @@ Maybe Operator::GetNdSbpSignatureList( int32_t sbp_dimension = parallel_desc.hierarchy()->NumAxes(); NdSbpSignature nd_sbp_sig; - SbpSignatureToNdSbpSignature(hierarchy_value2sbp_sig_list.front().sbp_signature(0), - &nd_sbp_sig); + 
SbpSignatureToNdSbpSignature(hierarchy_value2sbp_sig_list.front().sbp_signature(0), &nd_sbp_sig); ResizeNdSbpSignature(nd_sbp_sig, sbp_dimension); // ND sbp signature list would be direct product of 1D sbp signatures CHECK_OR_RETURN(nd_sbp_sig_list->empty()); From 778da7646847cafdf9ce8b1da7e8ff498bc5ea2c Mon Sep 17 00:00:00 2001 From: wyg1997 Date: Thu, 27 Oct 2022 11:06:14 +0800 Subject: [PATCH 63/66] test(initializer): add initializer data test --- python/oneflow/nn/init.py | 3 +- .../oneflow/test/modules/test_initializer.py | 81 +++++++++++++++++++ .../oneflow/test/tensor/test_tensor_part_1.py | 32 ++++---- 3 files changed, 100 insertions(+), 16 deletions(-) create mode 100644 python/oneflow/test/modules/test_initializer.py diff --git a/python/oneflow/nn/init.py b/python/oneflow/nn/init.py index 907b8315695..7c907ed1ba7 100644 --- a/python/oneflow/nn/init.py +++ b/python/oneflow/nn/init.py @@ -267,7 +267,8 @@ def kaiming_normal_( Args: tensor: an n-dimensional `oneflow.Tensor` - a: the negative slope of the rectifier used after this layer. + a: the negative slope of the rectifier used after this layer (only + used with ``'leaky_relu'``) mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'`` preserves the magnitude of the variance of the weights in the forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the diff --git a/python/oneflow/test/modules/test_initializer.py b/python/oneflow/test/modules/test_initializer.py new file mode 100644 index 00000000000..6976cf9e8e7 --- /dev/null +++ b/python/oneflow/test/modules/test_initializer.py @@ -0,0 +1,81 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import unittest +import numpy as np + +import oneflow as flow +from oneflow.test_utils.automated_test_util import * +import oneflow.unittest + + +class DataChecker: + check_list = [ + "mean", "std", "min", "max", "value", "lambda", + ] + def __init__(self, **kwargs): + self.checkers = {} + for key in self.check_list: + if key in kwargs: + self.checkers[key] = kwargs[key] + + def __call__(self, test_case, tensor): + for func in ["mean", "std", "min", "max"]: + if func in self.checkers: + of_res = eval(f"tensor.{func}")().numpy() + checker_res = self.checkers[func] + test_case.assertTrue(np.allclose(of_res, checker_res, rtol=1e-2, atol=1e-2), f"{func} not equal, {of_res} vs {checker_res}") + + if "value" in self.checkers: + test_case.assertTrue(np.all(tensor.numpy() == self.checkers["value"])) + + if "lambda" in self.checkers: + test_case.assertTrue(np.allclose(tensor.numpy(), self.checkers["lambda"](tensor.shape), rtol=1e-4, atol=1e-4)) + + +# NOTE(wyg): register initializers to this list +check_func_list = [ + # oneflow.nn.init.normal_ + {"func": flow.nn.init.normal_, "params": {"mean": 0.0, "std": 1.0}, "checker": DataChecker(mean=0.0, std=1.0)}, + # oneflow.nn.init.xavier_normal_ + {"func": flow.nn.init.xavier_normal_, "params": {"gain": 1.0}, "checker": DataChecker(mean=0.0, std=0.0625)}, + # oneflow.nn.init.kaiming_normal_ + {"func": flow.nn.init.kaiming_normal_, "params": {"mode": "fan_in"}, "checker": DataChecker(mean=0.0, std=0.0883883476)}, + {"func": flow.nn.init.kaiming_normal_, "params": {"mode": "fan_out"}, "checker": DataChecker(mean=0.0, std=0.0883883476)}, + {"func": flow.nn.init.kaiming_normal_, "params": {"mode": "fan_in", "a": 2.0, "nonlinearity": "leaky_relu"}, "checker": DataChecker(mean=0.0, std=0.0395284708)}, + {"func": flow.nn.init.kaiming_normal_, "params": {"mode": "fan_in", "a": 2.0, "nonlinearity": "linear"}, "checker": DataChecker(mean=0.0, std=0.0625)}, + # TODO: test more initializer +] + + +@oneflow.unittest.skip_unless_1n1d() +@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") +class TestInitializer(flow.unittest.TestCase): + def test_initializer(test_case): + default_shape = (256, 256) + for device in ["cpu", "cuda"]: + for check_func in check_func_list: + tensor = flow.empty(*default_shape, device=flow.device(device)) + check_func["func"](tensor, **check_func["params"]) + try: + check_func["checker"](test_case, tensor) + except AssertionError as e: + print(f"Failed: {check_func['func'].__name__} {check_func['params']}") + raise e + + +if __name__ == "__main__": + unittest.main() diff --git a/python/oneflow/test/tensor/test_tensor_part_1.py b/python/oneflow/test/tensor/test_tensor_part_1.py index 296f025b2dc..d5024d845f5 100644 --- a/python/oneflow/test/tensor/test_tensor_part_1.py +++ b/python/oneflow/test/tensor/test_tensor_part_1.py @@ -251,24 +251,26 @@ def _test_non_contiguous_tensor_init_methods(test_case, tensor_creator, get_nump @flow.unittest.skip_unless_1n1d() def test_local_tensor_init_methods(test_case): - test_case._test_tensor_init_methods( - lambda *args, **kwargs: flow.Tensor(*args, **kwargs), lambda x: x.numpy() - ) - test_case._test_non_contiguous_tensor_init_methods( - lambda *args, **kwargs: flow.Tensor(*args, **kwargs), lambda x: x.numpy() - ) + for device in ["cpu", "cuda"]: + test_case._test_tensor_init_methods( + lambda *args, **kwargs: flow.Tensor(*args, **kwargs, device=device), lambda x: x.numpy() + ) + test_case._test_non_contiguous_tensor_init_methods( + lambda *args, **kwargs: 
flow.Tensor(*args, **kwargs, device=device), lambda x: x.numpy() + ) @flow.unittest.skip_unless_1n2d() def test_global_tensor_init_methods(test_case): - test_case._test_tensor_init_methods( - lambda *args, **kwargs: flow.Tensor( - *args, - **kwargs, - sbp=flow.sbp.broadcast, - placement=flow.placement("cuda", range(2)) - ), - lambda x: x.to_global(sbp=flow.sbp.broadcast).to_local().numpy(), - ) + for device in ["cpu", "cuda"]: + test_case._test_tensor_init_methods( + lambda *args, **kwargs: flow.Tensor( + *args, + **kwargs, + sbp=flow.sbp.broadcast, + placement=flow.placement(device, range(2)) + ), + lambda x: x.to_global(sbp=flow.sbp.broadcast).to_local().numpy(), + ) @flow.unittest.skip_unless_1n1d() def test_tensor_with_single_int(test_case): From 2d24db41a5c7a48fee721ea0021c6eb51c966745 Mon Sep 17 00:00:00 2001 From: wyg1997 Date: Thu, 27 Oct 2022 11:07:13 +0800 Subject: [PATCH 64/66] format code --- .../oneflow/test/modules/test_initializer.py | 62 ++++++++++++++++--- .../oneflow/test/tensor/test_tensor_part_1.py | 6 +- 2 files changed, 56 insertions(+), 12 deletions(-) diff --git a/python/oneflow/test/modules/test_initializer.py b/python/oneflow/test/modules/test_initializer.py index 6976cf9e8e7..82f11a702d5 100644 --- a/python/oneflow/test/modules/test_initializer.py +++ b/python/oneflow/test/modules/test_initializer.py @@ -24,8 +24,14 @@ class DataChecker: check_list = [ - "mean", "std", "min", "max", "value", "lambda", + "mean", + "std", + "min", + "max", + "value", + "lambda", ] + def __init__(self, **kwargs): self.checkers = {} for key in self.check_list: @@ -37,26 +43,60 @@ def __call__(self, test_case, tensor): if func in self.checkers: of_res = eval(f"tensor.{func}")().numpy() checker_res = self.checkers[func] - test_case.assertTrue(np.allclose(of_res, checker_res, rtol=1e-2, atol=1e-2), f"{func} not equal, {of_res} vs {checker_res}") + test_case.assertTrue( + np.allclose(of_res, checker_res, rtol=1e-2, atol=1e-2), + f"{func} not equal, {of_res} vs {checker_res}", + ) if "value" in self.checkers: test_case.assertTrue(np.all(tensor.numpy() == self.checkers["value"])) if "lambda" in self.checkers: - test_case.assertTrue(np.allclose(tensor.numpy(), self.checkers["lambda"](tensor.shape), rtol=1e-4, atol=1e-4)) + test_case.assertTrue( + np.allclose( + tensor.numpy(), + self.checkers["lambda"](tensor.shape), + rtol=1e-4, + atol=1e-4, + ) + ) # NOTE(wyg): register initializers to this list check_func_list = [ # oneflow.nn.init.normal_ - {"func": flow.nn.init.normal_, "params": {"mean": 0.0, "std": 1.0}, "checker": DataChecker(mean=0.0, std=1.0)}, + { + "func": flow.nn.init.normal_, + "params": {"mean": 0.0, "std": 1.0}, + "checker": DataChecker(mean=0.0, std=1.0), + }, # oneflow.nn.init.xavier_normal_ - {"func": flow.nn.init.xavier_normal_, "params": {"gain": 1.0}, "checker": DataChecker(mean=0.0, std=0.0625)}, + { + "func": flow.nn.init.xavier_normal_, + "params": {"gain": 1.0}, + "checker": DataChecker(mean=0.0, std=0.0625), + }, # oneflow.nn.init.kaiming_normal_ - {"func": flow.nn.init.kaiming_normal_, "params": {"mode": "fan_in"}, "checker": DataChecker(mean=0.0, std=0.0883883476)}, - {"func": flow.nn.init.kaiming_normal_, "params": {"mode": "fan_out"}, "checker": DataChecker(mean=0.0, std=0.0883883476)}, - {"func": flow.nn.init.kaiming_normal_, "params": {"mode": "fan_in", "a": 2.0, "nonlinearity": "leaky_relu"}, "checker": DataChecker(mean=0.0, std=0.0395284708)}, - {"func": flow.nn.init.kaiming_normal_, "params": {"mode": "fan_in", "a": 2.0, "nonlinearity": "linear"}, 
"checker": DataChecker(mean=0.0, std=0.0625)}, + { + "func": flow.nn.init.kaiming_normal_, + "params": {"mode": "fan_in"}, + "checker": DataChecker(mean=0.0, std=0.0883883476), + }, + { + "func": flow.nn.init.kaiming_normal_, + "params": {"mode": "fan_out"}, + "checker": DataChecker(mean=0.0, std=0.0883883476), + }, + { + "func": flow.nn.init.kaiming_normal_, + "params": {"mode": "fan_in", "a": 2.0, "nonlinearity": "leaky_relu"}, + "checker": DataChecker(mean=0.0, std=0.0395284708), + }, + { + "func": flow.nn.init.kaiming_normal_, + "params": {"mode": "fan_in", "a": 2.0, "nonlinearity": "linear"}, + "checker": DataChecker(mean=0.0, std=0.0625), + }, # TODO: test more initializer ] @@ -73,7 +113,9 @@ def test_initializer(test_case): try: check_func["checker"](test_case, tensor) except AssertionError as e: - print(f"Failed: {check_func['func'].__name__} {check_func['params']}") + print( + f"Failed: {check_func['func'].__name__} {check_func['params']}" + ) raise e diff --git a/python/oneflow/test/tensor/test_tensor_part_1.py b/python/oneflow/test/tensor/test_tensor_part_1.py index d5024d845f5..e66f2566fb6 100644 --- a/python/oneflow/test/tensor/test_tensor_part_1.py +++ b/python/oneflow/test/tensor/test_tensor_part_1.py @@ -253,10 +253,12 @@ def _test_non_contiguous_tensor_init_methods(test_case, tensor_creator, get_nump def test_local_tensor_init_methods(test_case): for device in ["cpu", "cuda"]: test_case._test_tensor_init_methods( - lambda *args, **kwargs: flow.Tensor(*args, **kwargs, device=device), lambda x: x.numpy() + lambda *args, **kwargs: flow.Tensor(*args, **kwargs, device=device), + lambda x: x.numpy(), ) test_case._test_non_contiguous_tensor_init_methods( - lambda *args, **kwargs: flow.Tensor(*args, **kwargs, device=device), lambda x: x.numpy() + lambda *args, **kwargs: flow.Tensor(*args, **kwargs, device=device), + lambda x: x.numpy(), ) @flow.unittest.skip_unless_1n2d() From 8f7ca2ffbc785f5be15dbf93e330dd7a30015358 Mon Sep 17 00:00:00 2001 From: Yipeng Li Date: Thu, 27 Oct 2022 14:09:10 +0800 Subject: [PATCH 65/66] Revert Update operator.cpp This commit revert 64832e43196067d67f70094a8d35664a805a5891 --- oneflow/core/operator/operator.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/oneflow/core/operator/operator.cpp b/oneflow/core/operator/operator.cpp index 28efd312792..bc1cd858ad5 100644 --- a/oneflow/core/operator/operator.cpp +++ b/oneflow/core/operator/operator.cpp @@ -529,7 +529,8 @@ Maybe Operator::GetNdSbpSignatureList( int32_t sbp_dimension = parallel_desc.hierarchy()->NumAxes(); NdSbpSignature nd_sbp_sig; - SbpSignatureToNdSbpSignature(hierarchy_value2sbp_sig_list.front().sbp_signature(0), &nd_sbp_sig); + SbpSignatureToNdSbpSignature(hierarchy_value2sbp_sig_list.begin()->second.sbp_signature(0), + &nd_sbp_sig); ResizeNdSbpSignature(nd_sbp_sig, sbp_dimension); // ND sbp signature list would be direct product of 1D sbp signatures CHECK_OR_RETURN(nd_sbp_sig_list->empty()); From 4a7a6b1633a2f2b2bbdbf7a5746271f29c2741f1 Mon Sep 17 00:00:00 2001 From: daquexian Date: Fri, 28 Oct 2022 11:06:53 +0800 Subject: [PATCH 66/66] boxing to cpu first in flow.save Signed-off-by: daquexian --- python/oneflow/framework/check_point_v2.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/oneflow/framework/check_point_v2.py b/python/oneflow/framework/check_point_v2.py index 8fc3036d1f9..b2f760943bd 100644 --- a/python/oneflow/framework/check_point_v2.py +++ b/python/oneflow/framework/check_point_v2.py @@ -172,7 +172,8 @@ def 
tensor_getstate(self): rel_dir_name = f"global_tensor_{self.global_id()}" abs_dir_name = save_load_path / rel_dir_name - tensor = self.to_global( + # Boxing to cpu firstly to avoid extra gpu memory usage + tensor = self.to_global(sbp=self.sbp, placement=flow.placement("cpu", self.placement.ranks)).to_global( sbp=flow.sbp.broadcast, placement=flow.placement("cpu", [global_src_dsk_rank]), ).to_local()
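
The final commit above rewrites the save path so that a global tensor is first re-placed on CPU while keeping its original SBP, and only afterwards re-boxed to broadcast on the single source rank; the intermediate full copy therefore lives in host memory instead of on the GPU. A minimal sketch of that two-step boxing, assuming a global tensor `t` and a source rank of 0 (the helper name and its argument are illustrative, not part of the patch):

    import oneflow as flow

    def gather_to_cpu_broadcast(t, src_rank=0):
        # Illustrative helper (not from the patch): mirrors the two-step boxing
        # added to tensor_getstate above.
        # Step 1: keep the tensor's own SBP but move it to the equivalent CPU
        # placement, so the re-layout below happens in host memory.
        t_cpu = t.to_global(
            sbp=t.sbp, placement=flow.placement("cpu", t.placement.ranks)
        )
        # Step 2: broadcast on CPU to the single source rank and take the local piece.
        return t_cpu.to_global(
            sbp=flow.sbp.broadcast, placement=flow.placement("cpu", [src_rank])
        ).to_local()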