Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor get sbp signature #9304

Merged
merged 18 commits into from
Oct 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions oneflow/core/framework/sbp_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,12 @@ class SbpContext : public SbpContextBase {
SbpContext() = default;
~SbpContext() override = default;

// hierarchy value is the value at the dimension corresponding to the current SBP
// For example, 2 machines, 4 gpus per machine, hierarchy = [2, 4]
// Suppose we have nd_sbp = (S0, B)
// The hierarchy value corresponding to S0 is 2
// The hierarchy value corresponding to B is 4.
virtual int64_t hierarchy_value() const = 0;
Copy link
Contributor

@leaves-zwx leaves-zwx Oct 26, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hierarchy_value 这个名字我觉得还是有点怪,hierarchy 本身就不是太精确,后来 pytorch 取了 device mesh 这个名字。你的专业英文比较好,感觉应该能想出更加精确的名字来?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hierarchy直译就是分级的意思,emmm,hierarchy的一个分量,就直译了hierarchy value了,从一开始我其实也不知道取什么名,第一版随便取了一个hierarchy num,然后现在改成value,阔以说只要oneflow内部还在用着这个hierarchy,基本上没什么好的选择了

我记得以前还在周会上讨论过这个问题,当时也有建议过mesh,后来决定在外部用rank,内部就无所谓了

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我记得以前还在周会上讨论过这个问题,当时也有建议过mesh,后来决定在外部用rank,内部就无所谓了

我同意在未来将 hierarchy 重命名为 device mesh

virtual UserOpSbpSignatureBuilder NewBuilder() = 0;
};

Expand Down
9 changes: 6 additions & 3 deletions oneflow/core/framework/sbp_infer_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -603,14 +603,17 @@ void SetNdSbpSignature(NdSbpSignature* nd_sbp_signature, const SbpSignature& sbp
}

void DfsGetNdSbpSignature(NdSbpSignature& nd_sbp_sig, int32_t depth, int32_t dims,
const SbpSignatureList& sbp_sig_list,
const Shape& hierarchy,
const HashMap<int32_t, SbpSignatureList>& hierarchy_value2sbp_sig_list,
std::vector<NdSbpSignature>* nd_sbp_sig_list) {
if (depth == dims) {
nd_sbp_sig_list->push_back(nd_sbp_sig);
} else {
for (const auto& sbp_signature : sbp_sig_list.sbp_signature()) {
for (const auto& sbp_signature :
hierarchy_value2sbp_sig_list.at(hierarchy.At(depth)).sbp_signature()) {
SetNdSbpSignature(&nd_sbp_sig, sbp_signature, depth);
DfsGetNdSbpSignature(nd_sbp_sig, depth + 1, dims, sbp_sig_list, nd_sbp_sig_list);
DfsGetNdSbpSignature(nd_sbp_sig, depth + 1, dims, hierarchy, hierarchy_value2sbp_sig_list,
nd_sbp_sig_list);
}
}
}
Expand Down
3 changes: 2 additions & 1 deletion oneflow/core/framework/sbp_infer_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ void SetNdSbpSignature(NdSbpSignature* nd_sbp_signature, const SbpSignature& sbp
int32_t sbp_axis);

void DfsGetNdSbpSignature(NdSbpSignature& nd_sbp_sig, int32_t depth, int32_t dims,
const SbpSignatureList& sbp_sig_list,
const Shape& hierarchy,
const HashMap<int32_t, SbpSignatureList>& hierarchy_value2sbp_sig_list,
std::vector<NdSbpSignature>* nd_sbp_sig_list);

// Compute storage for given NdSbp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,8 @@ bool IsS0SignatureSupported(const OpNode* node) {
auto LogicalBlobDesc4Ibn = [&](const std::string& bn) -> Maybe<const BlobDesc&> {
return Maybe<const BlobDesc&>(node->LogicalBlobDesc4Lbi(node->op().BnInOp2Lbi(bn)));
};
CHECK_JUST(node->op().GetSbpSignaturesIf(LogicalBlobDesc4Ibn, node->parallel_desc(), &list));
CHECK_JUST(node->op().GetSbpSignaturesIf(LogicalBlobDesc4Ibn,
node->parallel_desc().parallel_num(), &list));
const auto IsInOutS0Parallel = [&](const SbpSignature& signature) {
return IsS0Parallel(signature, node->op().SoleIbn())
&& IsS0Parallel(signature, node->op().SoleObn());
Expand Down
4 changes: 2 additions & 2 deletions oneflow/core/operator/dynamic_reshape_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ class DynamicReshapeOp final : public Operator {
private:
Maybe<void> GetSbpSignatures(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
const ParallelDesc& parallel_desc, SbpSignatureList* sbp_sig_list) const override {
SbpSignatureList* sbp_sig_list) const override {
SbpSignatureBuilder()
.Split(input_bns(), 0)
.Split(output_bns(), 0)
Expand Down Expand Up @@ -144,7 +144,7 @@ class DynamicReshapeLikeOp final : public Operator {
private:
Maybe<void> GetSbpSignatures(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
const ParallelDesc& parallel_desc, SbpSignatureList* sbp_sig_list) const override {
SbpSignatureList* sbp_sig_list) const override {
SbpSignatureBuilder()
.Split(input_bns(), 0)
.Split(output_bns(), 0)
Expand Down
34 changes: 25 additions & 9 deletions oneflow/core/operator/operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -497,8 +497,8 @@ Maybe<const Shape> Operator::GetInputOutputFastestTimeShape() const {

Maybe<void> Operator::GetSbpSignaturesIf(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
const ParallelDesc& parallel_desc, SbpSignatureList* sbp_sig_list) const {
JUST(GetSbpSignatures(LogicalBlobDesc4Ibn, parallel_desc, sbp_sig_list));
int32_t hierarchy_value, SbpSignatureList* sbp_sig_list) const {
JUST(GetSbpSignatures(LogicalBlobDesc4Ibn, hierarchy_value, sbp_sig_list));
SbpSignatureBuilder()
.Broadcast(input_bns())
.Broadcast(output_bns())
Expand All @@ -510,18 +510,32 @@ Maybe<void> Operator::GetNdSbpSignatureList(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
const ParallelDesc& parallel_desc, std::vector<NdSbpSignature>* nd_sbp_sig_list) const {
// Get 1D sbp signature list
SbpSignatureList sbp_sig_list;
JUST(GetSbpSignaturesIf(LogicalBlobDesc4Ibn, parallel_desc, &sbp_sig_list));
CHECK_GT_OR_RETURN(sbp_sig_list.sbp_signature_size(), 0)
<< op_name() << " gets no sbp signature from GetSbpSignaturesIf function!";
HashMap<int32_t, SbpSignatureList> hierarchy_value2sbp_sig_list;
// hierarchy value is the value at the dimension corresponding to the current SBP
// For example, 2 machines, 4 gpus per machine, hierarchy = [2, 4]
// Suppose we have nd_sbp = (S0, B)
// The hierarchy value corresponding to S0 is 2
// The hierarchy value corresponding to B is 4.
for (int32_t hierarchy_value : *parallel_desc.hierarchy()) {
if (hierarchy_value2sbp_sig_list.find(hierarchy_value) == hierarchy_value2sbp_sig_list.end()) {
auto* sbp_sig_list = &hierarchy_value2sbp_sig_list[hierarchy_value];
JUST(GetSbpSignaturesIf(LogicalBlobDesc4Ibn, hierarchy_value, sbp_sig_list));
CHECK_GT_OR_RETURN(sbp_sig_list->sbp_signature_size(), 0)
<< op_name()
<< " gets no sbp signature from GetSbpSignaturesIf function for hierarchy value: "
<< hierarchy_value;
}
}

int32_t sbp_dimension = parallel_desc.hierarchy()->NumAxes();
NdSbpSignature nd_sbp_sig;
SbpSignatureToNdSbpSignature(sbp_sig_list.sbp_signature(0), &nd_sbp_sig);
SbpSignatureToNdSbpSignature(hierarchy_value2sbp_sig_list.begin()->second.sbp_signature(0),
Yipeng1994 marked this conversation as resolved.
Show resolved Hide resolved
&nd_sbp_sig);
ResizeNdSbpSignature(nd_sbp_sig, sbp_dimension);
// ND sbp signature list would be direct product of 1D sbp signatures
CHECK_OR_RETURN(nd_sbp_sig_list->empty());
DfsGetNdSbpSignature(nd_sbp_sig, 0, sbp_dimension, sbp_sig_list, nd_sbp_sig_list);
DfsGetNdSbpSignature(nd_sbp_sig, 0, sbp_dimension, *parallel_desc.hierarchy(),
hierarchy_value2sbp_sig_list, nd_sbp_sig_list);
return Maybe<void>::Ok();
}

Expand Down Expand Up @@ -833,7 +847,9 @@ Maybe<void> Operator::InferSbpSignature(
SbpSignatureList valid_sbp_sig_list;
{
SbpSignatureList sbp_sig_candidates;
JUST(GetSbpSignaturesIf(LogicalBlobDesc4Ibn, parallel_desc, &sbp_sig_candidates));
// For 1d sbp, hierarchy value = parallel num
JUST(
GetSbpSignaturesIf(LogicalBlobDesc4Ibn, parallel_desc.parallel_num(), &sbp_sig_candidates));
// filter sbp signatures by logical shape
JUST(FilterAndCheckValidSbpSignatureListByLogicalShape(sbp_sig_candidates, LogicalBlobDesc4Ibn,
parallel_desc, &valid_sbp_sig_list));
Expand Down
4 changes: 2 additions & 2 deletions oneflow/core/operator/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ class Operator {

Maybe<void> GetSbpSignaturesIf(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
const ParallelDesc& parallel_desc, SbpSignatureList* sbp_sig_list) const;
int32_t hierarchy_value, SbpSignatureList* sbp_sig_list) const;
virtual Maybe<void> GetNdSbpSignatureList(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
const ParallelDesc& parallel_desc, std::vector<NdSbpSignature>* nd_sbp_sig_list) const;
Expand Down Expand Up @@ -214,7 +214,7 @@ class Operator {
const ParallelContext* parallel_ctx, const JobDesc* job_desc) const;
virtual Maybe<void> GetSbpSignatures(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
const ParallelDesc& parallel_desc, SbpSignatureList* sbp_sig_list) const {
int32_t hierarchy_value, SbpSignatureList* sbp_sig_list) const {
return GetSbpSignatures(LogicalBlobDesc4Ibn, sbp_sig_list);
}
virtual Maybe<void> GetSbpSignatures(
Expand Down
12 changes: 8 additions & 4 deletions oneflow/core/operator/user_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -345,8 +345,9 @@ class UserOpSbpContext : public user_op::SbpContext {
using ArgVec = std::vector<std::pair<std::string, int32_t>>;

UserOpSbpContext(const UserOp* op, SbpSignatureList* sbp_sig_list,
std::function<Maybe<const BlobDesc&>(const std::string&)> LogicalBlobDesc4Ibn)
: op_(op), sbp_sig_list_(sbp_sig_list) {
std::function<Maybe<const BlobDesc&>(const std::string&)> LogicalBlobDesc4Ibn,
int32_t hierarchy_value)
: op_(op), sbp_sig_list_(sbp_sig_list), hierarchy_value_(hierarchy_value) {
const auto& user_op_conf = op->op_conf().user_conf();
for (auto it = user_op_conf.input().begin(); it != user_op_conf.input().end(); ++it) {
const std::string& arg_name = it->first;
Expand Down Expand Up @@ -379,10 +380,13 @@ class UserOpSbpContext : public user_op::SbpContext {
return CHECK_JUST(op_->GetOpParallelDesc())->parallel_num();
}

int64_t hierarchy_value() const override { return hierarchy_value_; }

private:
const UserOp* op_;
SbpSignatureList* sbp_sig_list_;
HashMap<std::pair<std::string, int32_t>, user_op::NaiveTensorDesc> arg2tensor_desc_;
int32_t hierarchy_value_;
};

class UserOpInferSbpSignatureFnContext : public user_op::InferSbpSignatureFnContext {
Expand Down Expand Up @@ -876,10 +880,10 @@ Maybe<void> UserOp::InferSbpSignature(

Maybe<void> UserOp::GetSbpSignatures(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
const ParallelDesc& parallel_desc, SbpSignatureList* sbp_sig_list) const {
int32_t hierarchy_value, SbpSignatureList* sbp_sig_list) const {
CHECK_OR_RETURN(val_ != nullptr)
<< "cannot find op_type: " << op_conf().user_conf().op_type_name() << " in op registry!";
UserOpSbpContext sbp_ctx(this, sbp_sig_list, LogicalBlobDesc4Ibn);
UserOpSbpContext sbp_ctx(this, sbp_sig_list, LogicalBlobDesc4Ibn, hierarchy_value);
JUST(val_->get_sbp_fn(&sbp_ctx));
// Add Broadcast for source user op tick input
if (val_->op_def.input_size() == 1 && input_bns().size() == 1
Expand Down
2 changes: 1 addition & 1 deletion oneflow/core/operator/user_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class UserOp final : public Operator {
const ParallelDesc& parallel_desc) const override;
Maybe<void> GetSbpSignatures(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
const ParallelDesc& parallel_desc, SbpSignatureList* sbp_sig_list) const override;
int32_t hierarchy_value, SbpSignatureList* sbp_sig_list) const override;
Maybe<void> GetNdSbpSignatureList(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
const ParallelDesc& parallel_desc,
Expand Down
2 changes: 1 addition & 1 deletion oneflow/user/ops/reshape_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ namespace oneflow {
const auto& outshape = JUST(ReshapeUserOpUtil::GetLogicalOutBlobShape(in_shape, shape));
user_op::UserOpSbpSignatureBuilder builder = ctx->NewBuilder();
return ReshapeUserOpUtil::GetReshapeUserOpSbpSignatures(
in_shape, *outshape, {{"in", 0}}, {{"out", 0}}, ctx->parallel_num(), &builder);
in_shape, *outshape, {{"in", 0}}, {{"out", 0}}, ctx->hierarchy_value(), &builder);
}

/*static*/ Maybe<void> ReshapeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
Expand Down
111 changes: 6 additions & 105 deletions oneflow/user/ops/reshape_user_op_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ Maybe<void> ReshapeUserOpUtil::Squeeze(const Shape& origin, Shape* shape,
}

Maybe<void> ReshapeUserOpUtil::GetGroupStartInAxis2OutAxis(
const Shape& in_shape, const Shape& out_shape, const int64_t parallel_num,
const Shape& in_shape, const Shape& out_shape, const int64_t hierarchy_value,
HashMap<int, int>* group_start_in_axis2out_axis) {
CHECK_GE_OR_RETURN(in_shape.NumAxes(), 0)
<< Error::RuntimeError()
Expand Down Expand Up @@ -128,8 +128,8 @@ Maybe<void> ReshapeUserOpUtil::GetGroupStartInAxis2OutAxis(
if (in_shape_count == out_shape_count) {
// Record split axises
if (in_shape.At(in_axis) == out_shape.At(out_axis)
|| (in_shape.At(in_axis) % parallel_num == 0
&& out_shape.At(out_axis) % parallel_num == 0)) {
|| (in_shape.At(in_axis) % hierarchy_value == 0
&& out_shape.At(out_axis) % hierarchy_value == 0)) {
(*group_start_in_axis2out_axis)[in_axis] = out_axis;
}
// Move forward
Expand All @@ -146,8 +146,8 @@ Maybe<void> ReshapeUserOpUtil::GetGroupStartInAxis2OutAxis(
}

Maybe<void> ReshapeUserOpUtil::GetReshapeUserOpSbpSignatures(
const Shape& in_shape, const Shape& out_shape, std::vector<user_op::OpArg> in_args,
std::vector<user_op::OpArg> out_args, const int64_t parallel_num,
const Shape& in_shape, const Shape& out_shape, const std::vector<user_op::OpArg>& in_args,
const std::vector<user_op::OpArg>& out_args, const int64_t hierarchy_value,
user_op::UserOpSbpSignatureBuilder* builder) {
if (in_shape.NumAxes() == 0 || in_shape.elem_cnt() == 0) {
return Maybe<void>::Ok();
Expand All @@ -162,7 +162,7 @@ Maybe<void> ReshapeUserOpUtil::GetReshapeUserOpSbpSignatures(
JUST(ReshapeUserOpUtil::Squeeze(out_shape, &squeezed_out_shape,
&out_squeezed_axis2original_axis));
JUST(ReshapeUserOpUtil::GetGroupStartInAxis2OutAxis(squeezed_in_shape, squeezed_out_shape,
parallel_num,
hierarchy_value,
&squeezed_group_start_in_axis2out_axis));
}
for (const auto& pair : squeezed_group_start_in_axis2out_axis) {
Expand All @@ -174,103 +174,4 @@ Maybe<void> ReshapeUserOpUtil::GetReshapeUserOpSbpSignatures(
return Maybe<void>::Ok();
}

namespace {

Maybe<void> GetInputNdSbp(user_op::InferNdSbpFnContext* ctx, const user_op::OpArg& in_arg,
NdSbp* distribution) {
*distribution = ctx->NdSbpHint4InputArgNameAndIndex(in_arg.name(), in_arg.index());
const auto& constraints = ctx->nd_sbp_constraints();
if (constraints.bn_in_op2nd_sbp_size() != 0) {
const auto it =
constraints.bn_in_op2nd_sbp().find(GenRepeatedBn(in_arg.name(), in_arg.index()));
if (it != constraints.bn_in_op2nd_sbp().end()) { *distribution = it->second; }
}
return Maybe<void>::Ok();
}

Maybe<void> ApplySbpParallel(const SbpParallel& sbp, const int64_t parallel_num, Shape* shape) {
if (sbp.has_split_parallel()) {
const int64_t axis = sbp.split_parallel().axis();
CHECK_EQ_OR_RETURN(shape->At(axis) % parallel_num, 0)
<< Error::RuntimeError() << "The size of tensor in the " << axis
<< " must be an integer multiple of parallel_num, "
<< "but got " << shape->At(axis) << " and " << parallel_num;
shape->Set(axis, shape->At(axis) / parallel_num);
}
return Maybe<void>::Ok();
}

} // namespace

Maybe<void> ReshapeUserOpUtil::InferNdSbp(user_op::InferNdSbpFnContext* ctx,
const Shape& logical_in_shape,
const Shape& logical_out_shape) {
const std::string& op_type_name = ctx->user_op_conf().op_type_name();
CHECK_OR_RETURN(op_type_name == "reshape" || op_type_name == "reshape_like")
<< Error::RuntimeError() << "The op_type_name must be \"reshape\" or \"reshape_like\", "
<< "but got " << op_type_name;
const bool is_reshape_like = (op_type_name == "reshape_like");
std::vector<user_op::OpArg> in_args({{"in", 0}});
if (is_reshape_like) { in_args.emplace_back(user_op::OpArg("like", 0)); }
HashMap<std::string, NdSbp> ibn2nd_sbp;
ibn2nd_sbp.reserve(in_args.size());
for (const auto& arg : in_args) {
NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex(arg.name(), arg.index());
JUST(GetInputNdSbp(ctx, arg, in_distribution));
CHECK_OR_RETURN(
ibn2nd_sbp.emplace(GenRepeatedBn(arg.name(), arg.index()), *in_distribution).second)
<< "emplace error"; // NOLINT(maybe-need-error-msg)
}
NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0);

Shape in_shape = logical_in_shape;
Shape out_shape = logical_out_shape;
const Shape& parallel_hierarchy = ctx->parallel_hierarchy();
for (int64_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) {
SbpSignatureList sbp_sig_list;
user_op::UserOpSbpSignatureBuilder builder(&sbp_sig_list);
builder.Broadcast(in_args).Broadcast(user_op::OpArg("out", 0)).Build();
if (is_reshape_like) {
builder.PartialSum(user_op::OpArg("like", 0))
.Broadcast(user_op::OpArg("in", 0))
.Broadcast(user_op::OpArg("out", 0))
.Build();
builder.Broadcast(user_op::OpArg("like", 0))
.PartialSum(user_op::OpArg("in", 0))
.PartialSum(user_op::OpArg("out", 0))
.Build();
JUST(GetReshapeUserOpSbpSignatures(in_shape, out_shape, {{"in", 0}},
{{"like", 0}, {"out", 0}}, parallel_hierarchy.At(i),
&builder));
} else {
JUST(GetReshapeUserOpSbpSignatures(in_shape, out_shape, {{"in", 0}}, {{"out", 0}},
parallel_hierarchy.At(i), &builder));
}

const SbpSignature* matched_sbp_signature = nullptr;
for (const auto& sbp_signature : sbp_sig_list.sbp_signature()) {
bool all_match = true;
for (const auto& in_arg : in_args) {
std::string ibn = GenRepeatedBn(in_arg.name(), in_arg.index());
if (sbp_signature.bn_in_op2sbp_parallel().at(ibn) != ibn2nd_sbp.at(ibn).sbp_parallel(i)) {
all_match = false;
break;
}
}
if (all_match) {
matched_sbp_signature = &sbp_signature;
break;
}
}
CHECK_OR_RETURN(matched_sbp_signature != nullptr)
<< "FusedLstmCellGrad::Pointer to the matched sbp signature is nullptr";
SbpParallel out_sbp = matched_sbp_signature->bn_in_op2sbp_parallel().at("out_0");
JUST(ApplySbpParallel(matched_sbp_signature->bn_in_op2sbp_parallel().at("in_0"),
parallel_hierarchy.At(i), &in_shape));
JUST(ApplySbpParallel(out_sbp, parallel_hierarchy.At(i), &out_shape));
*(out_distribution->add_sbp_parallel()) = out_sbp;
}
return Maybe<void>::Ok();
}

} // namespace oneflow
10 changes: 4 additions & 6 deletions oneflow/user/ops/reshape_user_op_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,13 @@ struct ReshapeUserOpUtil {
static Maybe<void> Squeeze(const Shape& origin, Shape* shape,
HashMap<int, int>* squeezed_axis2origin_axis);
static Maybe<void> GetGroupStartInAxis2OutAxis(const Shape& in_shape, const Shape& out_shape,
const int64_t parallel_num,
const int64_t hierarchy_value,
HashMap<int, int>* group_start_in_axis2out_axis);
static Maybe<void> GetReshapeUserOpSbpSignatures(const Shape& in_shape, const Shape& out_shape,
std::vector<user_op::OpArg> in_args,
std::vector<user_op::OpArg> out_args,
const int64_t parallel_num,
const std::vector<user_op::OpArg>& in_args,
const std::vector<user_op::OpArg>& out_args,
const int64_t hierarchy_value,
user_op::UserOpSbpSignatureBuilder* builder);
static Maybe<void> InferNdSbp(user_op::InferNdSbpFnContext* ctx, const Shape& logical_in_shape,
const Shape& logical_out_shape);
};
} // namespace oneflow

Expand Down