diff --git a/oneflow/core/framework/sbp_context.h b/oneflow/core/framework/sbp_context.h index 680ae94d5ba..bac1db925f6 100644 --- a/oneflow/core/framework/sbp_context.h +++ b/oneflow/core/framework/sbp_context.h @@ -76,6 +76,12 @@ class SbpContext : public SbpContextBase { SbpContext() = default; ~SbpContext() override = default; + // hierarchy value is the value at the dimension corresponding to the current SBP + // For example, 2 machines, 4 gpus per machine, hierarchy = [2, 4] + // Suppose we have nd_sbp = (S0, B) + // The hierarchy value corresponding to S0 is 2 + // The hierarchy value corresponding to B is 4. + virtual int64_t hierarchy_value() const = 0; virtual UserOpSbpSignatureBuilder NewBuilder() = 0; }; diff --git a/oneflow/core/framework/sbp_infer_util.cpp b/oneflow/core/framework/sbp_infer_util.cpp index 3ec7562dd51..f3a25167c97 100644 --- a/oneflow/core/framework/sbp_infer_util.cpp +++ b/oneflow/core/framework/sbp_infer_util.cpp @@ -603,14 +603,17 @@ void SetNdSbpSignature(NdSbpSignature* nd_sbp_signature, const SbpSignature& sbp } void DfsGetNdSbpSignature(NdSbpSignature& nd_sbp_sig, int32_t depth, int32_t dims, - const SbpSignatureList& sbp_sig_list, + const Shape& hierarchy, + const HashMap& hierarchy_value2sbp_sig_list, std::vector* nd_sbp_sig_list) { if (depth == dims) { nd_sbp_sig_list->push_back(nd_sbp_sig); } else { - for (const auto& sbp_signature : sbp_sig_list.sbp_signature()) { + for (const auto& sbp_signature : + hierarchy_value2sbp_sig_list.at(hierarchy.At(depth)).sbp_signature()) { SetNdSbpSignature(&nd_sbp_sig, sbp_signature, depth); - DfsGetNdSbpSignature(nd_sbp_sig, depth + 1, dims, sbp_sig_list, nd_sbp_sig_list); + DfsGetNdSbpSignature(nd_sbp_sig, depth + 1, dims, hierarchy, hierarchy_value2sbp_sig_list, + nd_sbp_sig_list); } } } diff --git a/oneflow/core/framework/sbp_infer_util.h b/oneflow/core/framework/sbp_infer_util.h index 21d7da6ae90..fabb13edbfa 100644 --- a/oneflow/core/framework/sbp_infer_util.h +++ b/oneflow/core/framework/sbp_infer_util.h @@ -62,7 +62,8 @@ void SetNdSbpSignature(NdSbpSignature* nd_sbp_signature, const SbpSignature& sbp int32_t sbp_axis); void DfsGetNdSbpSignature(NdSbpSignature& nd_sbp_sig, int32_t depth, int32_t dims, - const SbpSignatureList& sbp_sig_list, + const Shape& hierarchy, + const HashMap& hierarchy_value2sbp_sig_list, std::vector* nd_sbp_sig_list); // Compute storage for given NdSbp diff --git a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp index 2c6e16a8bb8..e62d2e7f3c5 100644 --- a/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp +++ b/oneflow/core/job_rewriter/optimizer_placement_optimization_pass.cpp @@ -297,7 +297,8 @@ bool IsS0SignatureSupported(const OpNode* node) { auto LogicalBlobDesc4Ibn = [&](const std::string& bn) -> Maybe { return Maybe(node->LogicalBlobDesc4Lbi(node->op().BnInOp2Lbi(bn))); }; - CHECK_JUST(node->op().GetSbpSignaturesIf(LogicalBlobDesc4Ibn, node->parallel_desc(), &list)); + CHECK_JUST(node->op().GetSbpSignaturesIf(LogicalBlobDesc4Ibn, + node->parallel_desc().parallel_num(), &list)); const auto IsInOutS0Parallel = [&](const SbpSignature& signature) { return IsS0Parallel(signature, node->op().SoleIbn()) && IsS0Parallel(signature, node->op().SoleObn()); diff --git a/oneflow/core/operator/dynamic_reshape_op.cpp b/oneflow/core/operator/dynamic_reshape_op.cpp index 34e90416d96..72ee5dc47c8 100644 --- a/oneflow/core/operator/dynamic_reshape_op.cpp +++ b/oneflow/core/operator/dynamic_reshape_op.cpp @@ -104,7 +104,7 @@ class DynamicReshapeOp final : public Operator { private: Maybe GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, - const ParallelDesc& parallel_desc, SbpSignatureList* sbp_sig_list) const override { + SbpSignatureList* sbp_sig_list) const override { SbpSignatureBuilder() .Split(input_bns(), 0) .Split(output_bns(), 0) @@ -144,7 +144,7 @@ class DynamicReshapeLikeOp final : public Operator { private: Maybe GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, - const ParallelDesc& parallel_desc, SbpSignatureList* sbp_sig_list) const override { + SbpSignatureList* sbp_sig_list) const override { SbpSignatureBuilder() .Split(input_bns(), 0) .Split(output_bns(), 0) diff --git a/oneflow/core/operator/operator.cpp b/oneflow/core/operator/operator.cpp index a48f166e278..bc1cd858ad5 100644 --- a/oneflow/core/operator/operator.cpp +++ b/oneflow/core/operator/operator.cpp @@ -497,8 +497,8 @@ Maybe Operator::GetInputOutputFastestTimeShape() const { Maybe Operator::GetSbpSignaturesIf( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, - const ParallelDesc& parallel_desc, SbpSignatureList* sbp_sig_list) const { - JUST(GetSbpSignatures(LogicalBlobDesc4Ibn, parallel_desc, sbp_sig_list)); + int32_t hierarchy_value, SbpSignatureList* sbp_sig_list) const { + JUST(GetSbpSignatures(LogicalBlobDesc4Ibn, hierarchy_value, sbp_sig_list)); SbpSignatureBuilder() .Broadcast(input_bns()) .Broadcast(output_bns()) @@ -510,18 +510,32 @@ Maybe Operator::GetNdSbpSignatureList( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, const ParallelDesc& parallel_desc, std::vector* nd_sbp_sig_list) const { // Get 1D sbp signature list - SbpSignatureList sbp_sig_list; - JUST(GetSbpSignaturesIf(LogicalBlobDesc4Ibn, parallel_desc, &sbp_sig_list)); - CHECK_GT_OR_RETURN(sbp_sig_list.sbp_signature_size(), 0) - << op_name() << " gets no sbp signature from GetSbpSignaturesIf function!"; + HashMap hierarchy_value2sbp_sig_list; + // hierarchy value is the value at the dimension corresponding to the current SBP + // For example, 2 machines, 4 gpus per machine, hierarchy = [2, 4] + // Suppose we have nd_sbp = (S0, B) + // The hierarchy value corresponding to S0 is 2 + // The hierarchy value corresponding to B is 4. + for (int32_t hierarchy_value : *parallel_desc.hierarchy()) { + if (hierarchy_value2sbp_sig_list.find(hierarchy_value) == hierarchy_value2sbp_sig_list.end()) { + auto* sbp_sig_list = &hierarchy_value2sbp_sig_list[hierarchy_value]; + JUST(GetSbpSignaturesIf(LogicalBlobDesc4Ibn, hierarchy_value, sbp_sig_list)); + CHECK_GT_OR_RETURN(sbp_sig_list->sbp_signature_size(), 0) + << op_name() + << " gets no sbp signature from GetSbpSignaturesIf function for hierarchy value: " + << hierarchy_value; + } + } int32_t sbp_dimension = parallel_desc.hierarchy()->NumAxes(); NdSbpSignature nd_sbp_sig; - SbpSignatureToNdSbpSignature(sbp_sig_list.sbp_signature(0), &nd_sbp_sig); + SbpSignatureToNdSbpSignature(hierarchy_value2sbp_sig_list.begin()->second.sbp_signature(0), + &nd_sbp_sig); ResizeNdSbpSignature(nd_sbp_sig, sbp_dimension); // ND sbp signature list would be direct product of 1D sbp signatures CHECK_OR_RETURN(nd_sbp_sig_list->empty()); - DfsGetNdSbpSignature(nd_sbp_sig, 0, sbp_dimension, sbp_sig_list, nd_sbp_sig_list); + DfsGetNdSbpSignature(nd_sbp_sig, 0, sbp_dimension, *parallel_desc.hierarchy(), + hierarchy_value2sbp_sig_list, nd_sbp_sig_list); return Maybe::Ok(); } @@ -833,7 +847,9 @@ Maybe Operator::InferSbpSignature( SbpSignatureList valid_sbp_sig_list; { SbpSignatureList sbp_sig_candidates; - JUST(GetSbpSignaturesIf(LogicalBlobDesc4Ibn, parallel_desc, &sbp_sig_candidates)); + // For 1d sbp, hierarchy value = parallel num + JUST( + GetSbpSignaturesIf(LogicalBlobDesc4Ibn, parallel_desc.parallel_num(), &sbp_sig_candidates)); // filter sbp signatures by logical shape JUST(FilterAndCheckValidSbpSignatureListByLogicalShape(sbp_sig_candidates, LogicalBlobDesc4Ibn, parallel_desc, &valid_sbp_sig_list)); diff --git a/oneflow/core/operator/operator.h b/oneflow/core/operator/operator.h index afed6ba8d2f..64c82a9baf1 100644 --- a/oneflow/core/operator/operator.h +++ b/oneflow/core/operator/operator.h @@ -172,7 +172,7 @@ class Operator { Maybe GetSbpSignaturesIf( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, - const ParallelDesc& parallel_desc, SbpSignatureList* sbp_sig_list) const; + int32_t hierarchy_value, SbpSignatureList* sbp_sig_list) const; virtual Maybe GetNdSbpSignatureList( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, const ParallelDesc& parallel_desc, std::vector* nd_sbp_sig_list) const; @@ -214,7 +214,7 @@ class Operator { const ParallelContext* parallel_ctx, const JobDesc* job_desc) const; virtual Maybe GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, - const ParallelDesc& parallel_desc, SbpSignatureList* sbp_sig_list) const { + int32_t hierarchy_value, SbpSignatureList* sbp_sig_list) const { return GetSbpSignatures(LogicalBlobDesc4Ibn, sbp_sig_list); } virtual Maybe GetSbpSignatures( diff --git a/oneflow/core/operator/user_op.cpp b/oneflow/core/operator/user_op.cpp index db311f600ad..3845da8f566 100644 --- a/oneflow/core/operator/user_op.cpp +++ b/oneflow/core/operator/user_op.cpp @@ -345,8 +345,9 @@ class UserOpSbpContext : public user_op::SbpContext { using ArgVec = std::vector>; UserOpSbpContext(const UserOp* op, SbpSignatureList* sbp_sig_list, - std::function(const std::string&)> LogicalBlobDesc4Ibn) - : op_(op), sbp_sig_list_(sbp_sig_list) { + std::function(const std::string&)> LogicalBlobDesc4Ibn, + int32_t hierarchy_value) + : op_(op), sbp_sig_list_(sbp_sig_list), hierarchy_value_(hierarchy_value) { const auto& user_op_conf = op->op_conf().user_conf(); for (auto it = user_op_conf.input().begin(); it != user_op_conf.input().end(); ++it) { const std::string& arg_name = it->first; @@ -379,10 +380,13 @@ class UserOpSbpContext : public user_op::SbpContext { return CHECK_JUST(op_->GetOpParallelDesc())->parallel_num(); } + int64_t hierarchy_value() const override { return hierarchy_value_; } + private: const UserOp* op_; SbpSignatureList* sbp_sig_list_; HashMap, user_op::NaiveTensorDesc> arg2tensor_desc_; + int32_t hierarchy_value_; }; class UserOpInferSbpSignatureFnContext : public user_op::InferSbpSignatureFnContext { @@ -876,10 +880,10 @@ Maybe UserOp::InferSbpSignature( Maybe UserOp::GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, - const ParallelDesc& parallel_desc, SbpSignatureList* sbp_sig_list) const { + int32_t hierarchy_value, SbpSignatureList* sbp_sig_list) const { CHECK_OR_RETURN(val_ != nullptr) << "cannot find op_type: " << op_conf().user_conf().op_type_name() << " in op registry!"; - UserOpSbpContext sbp_ctx(this, sbp_sig_list, LogicalBlobDesc4Ibn); + UserOpSbpContext sbp_ctx(this, sbp_sig_list, LogicalBlobDesc4Ibn, hierarchy_value); JUST(val_->get_sbp_fn(&sbp_ctx)); // Add Broadcast for source user op tick input if (val_->op_def.input_size() == 1 && input_bns().size() == 1 diff --git a/oneflow/core/operator/user_op.h b/oneflow/core/operator/user_op.h index d0f39c8fce1..5399438fd4e 100644 --- a/oneflow/core/operator/user_op.h +++ b/oneflow/core/operator/user_op.h @@ -64,7 +64,7 @@ class UserOp final : public Operator { const ParallelDesc& parallel_desc) const override; Maybe GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, - const ParallelDesc& parallel_desc, SbpSignatureList* sbp_sig_list) const override; + int32_t hierarchy_value, SbpSignatureList* sbp_sig_list) const override; Maybe GetNdSbpSignatureList( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, const ParallelDesc& parallel_desc, diff --git a/oneflow/user/ops/reshape_op.cpp b/oneflow/user/ops/reshape_op.cpp index 822c7f90bf7..916b66089f5 100644 --- a/oneflow/user/ops/reshape_op.cpp +++ b/oneflow/user/ops/reshape_op.cpp @@ -27,7 +27,7 @@ namespace oneflow { const auto& outshape = JUST(ReshapeUserOpUtil::GetLogicalOutBlobShape(in_shape, shape)); user_op::UserOpSbpSignatureBuilder builder = ctx->NewBuilder(); return ReshapeUserOpUtil::GetReshapeUserOpSbpSignatures( - in_shape, *outshape, {{"in", 0}}, {{"out", 0}}, ctx->parallel_num(), &builder); + in_shape, *outshape, {{"in", 0}}, {{"out", 0}}, ctx->hierarchy_value(), &builder); } /*static*/ Maybe ReshapeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { diff --git a/oneflow/user/ops/reshape_user_op_util.cpp b/oneflow/user/ops/reshape_user_op_util.cpp index 32fab5354e9..a5597e4ccab 100644 --- a/oneflow/user/ops/reshape_user_op_util.cpp +++ b/oneflow/user/ops/reshape_user_op_util.cpp @@ -92,7 +92,7 @@ Maybe ReshapeUserOpUtil::Squeeze(const Shape& origin, Shape* shape, } Maybe ReshapeUserOpUtil::GetGroupStartInAxis2OutAxis( - const Shape& in_shape, const Shape& out_shape, const int64_t parallel_num, + const Shape& in_shape, const Shape& out_shape, const int64_t hierarchy_value, HashMap* group_start_in_axis2out_axis) { CHECK_GE_OR_RETURN(in_shape.NumAxes(), 0) << Error::RuntimeError() @@ -128,8 +128,8 @@ Maybe ReshapeUserOpUtil::GetGroupStartInAxis2OutAxis( if (in_shape_count == out_shape_count) { // Record split axises if (in_shape.At(in_axis) == out_shape.At(out_axis) - || (in_shape.At(in_axis) % parallel_num == 0 - && out_shape.At(out_axis) % parallel_num == 0)) { + || (in_shape.At(in_axis) % hierarchy_value == 0 + && out_shape.At(out_axis) % hierarchy_value == 0)) { (*group_start_in_axis2out_axis)[in_axis] = out_axis; } // Move forward @@ -146,8 +146,8 @@ Maybe ReshapeUserOpUtil::GetGroupStartInAxis2OutAxis( } Maybe ReshapeUserOpUtil::GetReshapeUserOpSbpSignatures( - const Shape& in_shape, const Shape& out_shape, std::vector in_args, - std::vector out_args, const int64_t parallel_num, + const Shape& in_shape, const Shape& out_shape, const std::vector& in_args, + const std::vector& out_args, const int64_t hierarchy_value, user_op::UserOpSbpSignatureBuilder* builder) { if (in_shape.NumAxes() == 0 || in_shape.elem_cnt() == 0) { return Maybe::Ok(); @@ -162,7 +162,7 @@ Maybe ReshapeUserOpUtil::GetReshapeUserOpSbpSignatures( JUST(ReshapeUserOpUtil::Squeeze(out_shape, &squeezed_out_shape, &out_squeezed_axis2original_axis)); JUST(ReshapeUserOpUtil::GetGroupStartInAxis2OutAxis(squeezed_in_shape, squeezed_out_shape, - parallel_num, + hierarchy_value, &squeezed_group_start_in_axis2out_axis)); } for (const auto& pair : squeezed_group_start_in_axis2out_axis) { @@ -174,103 +174,4 @@ Maybe ReshapeUserOpUtil::GetReshapeUserOpSbpSignatures( return Maybe::Ok(); } -namespace { - -Maybe GetInputNdSbp(user_op::InferNdSbpFnContext* ctx, const user_op::OpArg& in_arg, - NdSbp* distribution) { - *distribution = ctx->NdSbpHint4InputArgNameAndIndex(in_arg.name(), in_arg.index()); - const auto& constraints = ctx->nd_sbp_constraints(); - if (constraints.bn_in_op2nd_sbp_size() != 0) { - const auto it = - constraints.bn_in_op2nd_sbp().find(GenRepeatedBn(in_arg.name(), in_arg.index())); - if (it != constraints.bn_in_op2nd_sbp().end()) { *distribution = it->second; } - } - return Maybe::Ok(); -} - -Maybe ApplySbpParallel(const SbpParallel& sbp, const int64_t parallel_num, Shape* shape) { - if (sbp.has_split_parallel()) { - const int64_t axis = sbp.split_parallel().axis(); - CHECK_EQ_OR_RETURN(shape->At(axis) % parallel_num, 0) - << Error::RuntimeError() << "The size of tensor in the " << axis - << " must be an integer multiple of parallel_num, " - << "but got " << shape->At(axis) << " and " << parallel_num; - shape->Set(axis, shape->At(axis) / parallel_num); - } - return Maybe::Ok(); -} - -} // namespace - -Maybe ReshapeUserOpUtil::InferNdSbp(user_op::InferNdSbpFnContext* ctx, - const Shape& logical_in_shape, - const Shape& logical_out_shape) { - const std::string& op_type_name = ctx->user_op_conf().op_type_name(); - CHECK_OR_RETURN(op_type_name == "reshape" || op_type_name == "reshape_like") - << Error::RuntimeError() << "The op_type_name must be \"reshape\" or \"reshape_like\", " - << "but got " << op_type_name; - const bool is_reshape_like = (op_type_name == "reshape_like"); - std::vector in_args({{"in", 0}}); - if (is_reshape_like) { in_args.emplace_back(user_op::OpArg("like", 0)); } - HashMap ibn2nd_sbp; - ibn2nd_sbp.reserve(in_args.size()); - for (const auto& arg : in_args) { - NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex(arg.name(), arg.index()); - JUST(GetInputNdSbp(ctx, arg, in_distribution)); - CHECK_OR_RETURN( - ibn2nd_sbp.emplace(GenRepeatedBn(arg.name(), arg.index()), *in_distribution).second) - << "emplace error"; // NOLINT(maybe-need-error-msg) - } - NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); - - Shape in_shape = logical_in_shape; - Shape out_shape = logical_out_shape; - const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); - for (int64_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) { - SbpSignatureList sbp_sig_list; - user_op::UserOpSbpSignatureBuilder builder(&sbp_sig_list); - builder.Broadcast(in_args).Broadcast(user_op::OpArg("out", 0)).Build(); - if (is_reshape_like) { - builder.PartialSum(user_op::OpArg("like", 0)) - .Broadcast(user_op::OpArg("in", 0)) - .Broadcast(user_op::OpArg("out", 0)) - .Build(); - builder.Broadcast(user_op::OpArg("like", 0)) - .PartialSum(user_op::OpArg("in", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - JUST(GetReshapeUserOpSbpSignatures(in_shape, out_shape, {{"in", 0}}, - {{"like", 0}, {"out", 0}}, parallel_hierarchy.At(i), - &builder)); - } else { - JUST(GetReshapeUserOpSbpSignatures(in_shape, out_shape, {{"in", 0}}, {{"out", 0}}, - parallel_hierarchy.At(i), &builder)); - } - - const SbpSignature* matched_sbp_signature = nullptr; - for (const auto& sbp_signature : sbp_sig_list.sbp_signature()) { - bool all_match = true; - for (const auto& in_arg : in_args) { - std::string ibn = GenRepeatedBn(in_arg.name(), in_arg.index()); - if (sbp_signature.bn_in_op2sbp_parallel().at(ibn) != ibn2nd_sbp.at(ibn).sbp_parallel(i)) { - all_match = false; - break; - } - } - if (all_match) { - matched_sbp_signature = &sbp_signature; - break; - } - } - CHECK_OR_RETURN(matched_sbp_signature != nullptr) - << "FusedLstmCellGrad::Pointer to the matched sbp signature is nullptr"; - SbpParallel out_sbp = matched_sbp_signature->bn_in_op2sbp_parallel().at("out_0"); - JUST(ApplySbpParallel(matched_sbp_signature->bn_in_op2sbp_parallel().at("in_0"), - parallel_hierarchy.At(i), &in_shape)); - JUST(ApplySbpParallel(out_sbp, parallel_hierarchy.At(i), &out_shape)); - *(out_distribution->add_sbp_parallel()) = out_sbp; - } - return Maybe::Ok(); -} - } // namespace oneflow diff --git a/oneflow/user/ops/reshape_user_op_util.h b/oneflow/user/ops/reshape_user_op_util.h index 8178803fbc1..b1fa968f3d0 100644 --- a/oneflow/user/ops/reshape_user_op_util.h +++ b/oneflow/user/ops/reshape_user_op_util.h @@ -26,15 +26,13 @@ struct ReshapeUserOpUtil { static Maybe Squeeze(const Shape& origin, Shape* shape, HashMap* squeezed_axis2origin_axis); static Maybe GetGroupStartInAxis2OutAxis(const Shape& in_shape, const Shape& out_shape, - const int64_t parallel_num, + const int64_t hierarchy_value, HashMap* group_start_in_axis2out_axis); static Maybe GetReshapeUserOpSbpSignatures(const Shape& in_shape, const Shape& out_shape, - std::vector in_args, - std::vector out_args, - const int64_t parallel_num, + const std::vector& in_args, + const std::vector& out_args, + const int64_t hierarchy_value, user_op::UserOpSbpSignatureBuilder* builder); - static Maybe InferNdSbp(user_op::InferNdSbpFnContext* ctx, const Shape& logical_in_shape, - const Shape& logical_out_shape); }; } // namespace oneflow