Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor get sbp signature #9304

Merged
merged 18 commits into from
Oct 27, 2022
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions oneflow/core/framework/sbp_infer_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -603,14 +603,17 @@ void SetNdSbpSignature(NdSbpSignature* nd_sbp_signature, const SbpSignature& sbp
}

void DfsGetNdSbpSignature(NdSbpSignature& nd_sbp_sig, int32_t depth, int32_t dims,
const SbpSignatureList& sbp_sig_list,
const Shape& hierarchy,
const HashMap<int32_t, SbpSignatureList>& hierarchy_num2sbp_sig_list,
Yipeng1994 marked this conversation as resolved.
Show resolved Hide resolved
std::vector<NdSbpSignature>* nd_sbp_sig_list) {
if (depth == dims) {
nd_sbp_sig_list->push_back(nd_sbp_sig);
} else {
for (const auto& sbp_signature : sbp_sig_list.sbp_signature()) {
for (const auto& sbp_signature :
hierarchy_num2sbp_sig_list.at(hierarchy.At(depth)).sbp_signature()) {
SetNdSbpSignature(&nd_sbp_sig, sbp_signature, depth);
DfsGetNdSbpSignature(nd_sbp_sig, depth + 1, dims, sbp_sig_list, nd_sbp_sig_list);
DfsGetNdSbpSignature(nd_sbp_sig, depth + 1, dims, hierarchy, hierarchy_num2sbp_sig_list,
nd_sbp_sig_list);
}
}
}
Expand Down
3 changes: 2 additions & 1 deletion oneflow/core/framework/sbp_infer_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ void SetNdSbpSignature(NdSbpSignature* nd_sbp_signature, const SbpSignature& sbp
int32_t sbp_axis);

void DfsGetNdSbpSignature(NdSbpSignature& nd_sbp_sig, int32_t depth, int32_t dims,
const SbpSignatureList& sbp_sig_list,
const Shape& hierarchy,
const HashMap<int32_t, SbpSignatureList>& hierarchy_num2sbp_sig_list,
std::vector<NdSbpSignature>* nd_sbp_sig_list);

// Compute storage for given NdSbp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,8 @@ bool IsS0SignatureSupported(const OpNode* node) {
auto LogicalBlobDesc4Ibn = [&](const std::string& bn) -> Maybe<const BlobDesc&> {
return Maybe<const BlobDesc&>(node->LogicalBlobDesc4Lbi(node->op().BnInOp2Lbi(bn)));
};
CHECK_JUST(node->op().GetSbpSignaturesIf(LogicalBlobDesc4Ibn, node->parallel_desc(), &list));
CHECK_JUST(node->op().GetSbpSignaturesIf(LogicalBlobDesc4Ibn,
node->parallel_desc().parallel_num(), &list));
const auto IsInOutS0Parallel = [&](const SbpSignature& signature) {
return IsS0Parallel(signature, node->op().SoleIbn())
&& IsS0Parallel(signature, node->op().SoleObn());
Expand Down
4 changes: 2 additions & 2 deletions oneflow/core/operator/dynamic_reshape_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ class DynamicReshapeOp final : public Operator {
private:
Maybe<void> GetSbpSignatures(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
const ParallelDesc& parallel_desc, SbpSignatureList* sbp_sig_list) const override {
SbpSignatureList* sbp_sig_list) const override {
SbpSignatureBuilder()
.Split(input_bns(), 0)
.Split(output_bns(), 0)
Expand Down Expand Up @@ -144,7 +144,7 @@ class DynamicReshapeLikeOp final : public Operator {
private:
Maybe<void> GetSbpSignatures(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
const ParallelDesc& parallel_desc, SbpSignatureList* sbp_sig_list) const override {
SbpSignatureList* sbp_sig_list) const override {
SbpSignatureBuilder()
.Split(input_bns(), 0)
.Split(output_bns(), 0)
Expand Down
28 changes: 19 additions & 9 deletions oneflow/core/operator/operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -497,8 +497,8 @@ Maybe<const Shape> Operator::GetInputOutputFastestTimeShape() const {

Maybe<void> Operator::GetSbpSignaturesIf(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
const ParallelDesc& parallel_desc, SbpSignatureList* sbp_sig_list) const {
JUST(GetSbpSignatures(LogicalBlobDesc4Ibn, parallel_desc, sbp_sig_list));
int32_t parallel_num, SbpSignatureList* sbp_sig_list) const {
JUST(GetSbpSignatures(LogicalBlobDesc4Ibn, parallel_num, sbp_sig_list));
SbpSignatureBuilder()
.Broadcast(input_bns())
.Broadcast(output_bns())
Expand All @@ -510,18 +510,27 @@ Maybe<void> Operator::GetNdSbpSignatureList(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
const ParallelDesc& parallel_desc, std::vector<NdSbpSignature>* nd_sbp_sig_list) const {
// Get 1D sbp signature list
SbpSignatureList sbp_sig_list;
JUST(GetSbpSignaturesIf(LogicalBlobDesc4Ibn, parallel_desc, &sbp_sig_list));
CHECK_GT_OR_RETURN(sbp_sig_list.sbp_signature_size(), 0)
<< op_name() << " gets no sbp signature from GetSbpSignaturesIf function!";
HashMap<int32_t, SbpSignatureList> hierarchy_num2sbp_sig_list;
for (int32_t hierarchy_num : *parallel_desc.hierarchy()) {
if (hierarchy_num2sbp_sig_list.find(hierarchy_num) == hierarchy_num2sbp_sig_list.end()) {
Yipeng1994 marked this conversation as resolved.
Show resolved Hide resolved
auto* sbp_sig_list = &hierarchy_num2sbp_sig_list[hierarchy_num];
JUST(GetSbpSignaturesIf(LogicalBlobDesc4Ibn, hierarchy_num, sbp_sig_list));
CHECK_GT_OR_RETURN(sbp_sig_list->sbp_signature_size(), 0)
<< op_name()
<< " gets no sbp signature from GetSbpSignaturesIf function for hierarchy num: "
<< hierarchy_num;
}
}

int32_t sbp_dimension = parallel_desc.hierarchy()->NumAxes();
NdSbpSignature nd_sbp_sig;
SbpSignatureToNdSbpSignature(sbp_sig_list.sbp_signature(0), &nd_sbp_sig);
SbpSignatureToNdSbpSignature(hierarchy_num2sbp_sig_list.begin()->second.sbp_signature(0),
&nd_sbp_sig);
ResizeNdSbpSignature(nd_sbp_sig, sbp_dimension);
// ND sbp signature list would be direct product of 1D sbp signatures
CHECK_OR_RETURN(nd_sbp_sig_list->empty());
DfsGetNdSbpSignature(nd_sbp_sig, 0, sbp_dimension, sbp_sig_list, nd_sbp_sig_list);
DfsGetNdSbpSignature(nd_sbp_sig, 0, sbp_dimension, *parallel_desc.hierarchy(),
hierarchy_num2sbp_sig_list, nd_sbp_sig_list);
return Maybe<void>::Ok();
}

Expand Down Expand Up @@ -833,7 +842,8 @@ Maybe<void> Operator::InferSbpSignature(
SbpSignatureList valid_sbp_sig_list;
{
SbpSignatureList sbp_sig_candidates;
JUST(GetSbpSignaturesIf(LogicalBlobDesc4Ibn, parallel_desc, &sbp_sig_candidates));
JUST(
GetSbpSignaturesIf(LogicalBlobDesc4Ibn, parallel_desc.parallel_num(), &sbp_sig_candidates));
// filter sbp signatures by logical shape
JUST(FilterAndCheckValidSbpSignatureListByLogicalShape(sbp_sig_candidates, LogicalBlobDesc4Ibn,
parallel_desc, &valid_sbp_sig_list));
Expand Down
4 changes: 2 additions & 2 deletions oneflow/core/operator/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ class Operator {

Maybe<void> GetSbpSignaturesIf(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
const ParallelDesc& parallel_desc, SbpSignatureList* sbp_sig_list) const;
int32_t parallel_num, SbpSignatureList* sbp_sig_list) const;
Yipeng1994 marked this conversation as resolved.
Show resolved Hide resolved
virtual Maybe<void> GetNdSbpSignatureList(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
const ParallelDesc& parallel_desc, std::vector<NdSbpSignature>* nd_sbp_sig_list) const;
Expand Down Expand Up @@ -214,7 +214,7 @@ class Operator {
const ParallelContext* parallel_ctx, const JobDesc* job_desc) const;
virtual Maybe<void> GetSbpSignatures(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
const ParallelDesc& parallel_desc, SbpSignatureList* sbp_sig_list) const {
int32_t parallel_num, SbpSignatureList* sbp_sig_list) const {
return GetSbpSignatures(LogicalBlobDesc4Ibn, sbp_sig_list);
}
virtual Maybe<void> GetSbpSignatures(
Expand Down
14 changes: 7 additions & 7 deletions oneflow/core/operator/user_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -345,8 +345,9 @@ class UserOpSbpContext : public user_op::SbpContext {
using ArgVec = std::vector<std::pair<std::string, int32_t>>;

UserOpSbpContext(const UserOp* op, SbpSignatureList* sbp_sig_list,
std::function<Maybe<const BlobDesc&>(const std::string&)> LogicalBlobDesc4Ibn)
: op_(op), sbp_sig_list_(sbp_sig_list) {
std::function<Maybe<const BlobDesc&>(const std::string&)> LogicalBlobDesc4Ibn,
int32_t parallel_num)
: op_(op), sbp_sig_list_(sbp_sig_list), parallel_num_(parallel_num) {
const auto& user_op_conf = op->op_conf().user_conf();
for (auto it = user_op_conf.input().begin(); it != user_op_conf.input().end(); ++it) {
const std::string& arg_name = it->first;
Expand Down Expand Up @@ -375,14 +376,13 @@ class UserOpSbpContext : public user_op::SbpContext {

DeviceType device_type() const override { return op_->device_type(); }

int64_t parallel_num() const override {
return CHECK_JUST(op_->GetOpParallelDesc())->parallel_num();
}
int64_t parallel_num() const override { return parallel_num_; }
Yipeng1994 marked this conversation as resolved.
Show resolved Hide resolved

private:
const UserOp* op_;
SbpSignatureList* sbp_sig_list_;
HashMap<std::pair<std::string, int32_t>, user_op::NaiveTensorDesc> arg2tensor_desc_;
int32_t parallel_num_;
};

class UserOpInferSbpSignatureFnContext : public user_op::InferSbpSignatureFnContext {
Expand Down Expand Up @@ -876,10 +876,10 @@ Maybe<void> UserOp::InferSbpSignature(

Maybe<void> UserOp::GetSbpSignatures(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
const ParallelDesc& parallel_desc, SbpSignatureList* sbp_sig_list) const {
int32_t parallel_num, SbpSignatureList* sbp_sig_list) const {
CHECK_OR_RETURN(val_ != nullptr)
<< "cannot find op_type: " << op_conf().user_conf().op_type_name() << " in op registry!";
UserOpSbpContext sbp_ctx(this, sbp_sig_list, LogicalBlobDesc4Ibn);
UserOpSbpContext sbp_ctx(this, sbp_sig_list, LogicalBlobDesc4Ibn, parallel_num);
JUST(val_->get_sbp_fn(&sbp_ctx));
// Add Broadcast for source user op tick input
if (val_->op_def.input_size() == 1 && input_bns().size() == 1
Expand Down
2 changes: 1 addition & 1 deletion oneflow/core/operator/user_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class UserOp final : public Operator {
const ParallelDesc& parallel_desc) const override;
Maybe<void> GetSbpSignatures(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
const ParallelDesc& parallel_desc, SbpSignatureList* sbp_sig_list) const override;
int32_t parallel_num, SbpSignatureList* sbp_sig_list) const override;
Maybe<void> GetNdSbpSignatureList(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
const ParallelDesc& parallel_desc,
Expand Down
99 changes: 0 additions & 99 deletions oneflow/user/ops/reshape_user_op_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,103 +174,4 @@ Maybe<void> ReshapeUserOpUtil::GetReshapeUserOpSbpSignatures(
return Maybe<void>::Ok();
}

namespace {

Maybe<void> GetInputNdSbp(user_op::InferNdSbpFnContext* ctx, const user_op::OpArg& in_arg,
NdSbp* distribution) {
*distribution = ctx->NdSbpHint4InputArgNameAndIndex(in_arg.name(), in_arg.index());
const auto& constraints = ctx->nd_sbp_constraints();
if (constraints.bn_in_op2nd_sbp_size() != 0) {
const auto it =
constraints.bn_in_op2nd_sbp().find(GenRepeatedBn(in_arg.name(), in_arg.index()));
if (it != constraints.bn_in_op2nd_sbp().end()) { *distribution = it->second; }
}
return Maybe<void>::Ok();
}

Maybe<void> ApplySbpParallel(const SbpParallel& sbp, const int64_t parallel_num, Shape* shape) {
if (sbp.has_split_parallel()) {
const int64_t axis = sbp.split_parallel().axis();
CHECK_EQ_OR_RETURN(shape->At(axis) % parallel_num, 0)
<< Error::RuntimeError() << "The size of tensor in the " << axis
<< " must be an integer multiple of parallel_num, "
<< "but got " << shape->At(axis) << " and " << parallel_num;
shape->Set(axis, shape->At(axis) / parallel_num);
}
return Maybe<void>::Ok();
}

} // namespace

Maybe<void> ReshapeUserOpUtil::InferNdSbp(user_op::InferNdSbpFnContext* ctx,
const Shape& logical_in_shape,
const Shape& logical_out_shape) {
const std::string& op_type_name = ctx->user_op_conf().op_type_name();
CHECK_OR_RETURN(op_type_name == "reshape" || op_type_name == "reshape_like")
<< Error::RuntimeError() << "The op_type_name must be \"reshape\" or \"reshape_like\", "
<< "but got " << op_type_name;
const bool is_reshape_like = (op_type_name == "reshape_like");
std::vector<user_op::OpArg> in_args({{"in", 0}});
if (is_reshape_like) { in_args.emplace_back(user_op::OpArg("like", 0)); }
HashMap<std::string, NdSbp> ibn2nd_sbp;
ibn2nd_sbp.reserve(in_args.size());
for (const auto& arg : in_args) {
NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex(arg.name(), arg.index());
JUST(GetInputNdSbp(ctx, arg, in_distribution));
CHECK_OR_RETURN(
ibn2nd_sbp.emplace(GenRepeatedBn(arg.name(), arg.index()), *in_distribution).second)
<< "emplace error"; // NOLINT(maybe-need-error-msg)
}
NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0);

Shape in_shape = logical_in_shape;
Shape out_shape = logical_out_shape;
const Shape& parallel_hierarchy = ctx->parallel_hierarchy();
for (int64_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) {
SbpSignatureList sbp_sig_list;
user_op::UserOpSbpSignatureBuilder builder(&sbp_sig_list);
builder.Broadcast(in_args).Broadcast(user_op::OpArg("out", 0)).Build();
if (is_reshape_like) {
builder.PartialSum(user_op::OpArg("like", 0))
.Broadcast(user_op::OpArg("in", 0))
.Broadcast(user_op::OpArg("out", 0))
.Build();
builder.Broadcast(user_op::OpArg("like", 0))
.PartialSum(user_op::OpArg("in", 0))
.PartialSum(user_op::OpArg("out", 0))
.Build();
JUST(GetReshapeUserOpSbpSignatures(in_shape, out_shape, {{"in", 0}},
{{"like", 0}, {"out", 0}}, parallel_hierarchy.At(i),
&builder));
} else {
JUST(GetReshapeUserOpSbpSignatures(in_shape, out_shape, {{"in", 0}}, {{"out", 0}},
parallel_hierarchy.At(i), &builder));
}

const SbpSignature* matched_sbp_signature = nullptr;
for (const auto& sbp_signature : sbp_sig_list.sbp_signature()) {
bool all_match = true;
for (const auto& in_arg : in_args) {
std::string ibn = GenRepeatedBn(in_arg.name(), in_arg.index());
if (sbp_signature.bn_in_op2sbp_parallel().at(ibn) != ibn2nd_sbp.at(ibn).sbp_parallel(i)) {
all_match = false;
break;
}
}
if (all_match) {
matched_sbp_signature = &sbp_signature;
break;
}
}
CHECK_OR_RETURN(matched_sbp_signature != nullptr)
<< "FusedLstmCellGrad::Pointer to the matched sbp signature is nullptr";
SbpParallel out_sbp = matched_sbp_signature->bn_in_op2sbp_parallel().at("out_0");
JUST(ApplySbpParallel(matched_sbp_signature->bn_in_op2sbp_parallel().at("in_0"),
parallel_hierarchy.At(i), &in_shape));
JUST(ApplySbpParallel(out_sbp, parallel_hierarchy.At(i), &out_shape));
*(out_distribution->add_sbp_parallel()) = out_sbp;
}
return Maybe<void>::Ok();
}

} // namespace oneflow