Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ddp multi-output bug #6310

Merged
merged 10 commits into from
Sep 16, 2021
Original file line number Diff line number Diff line change
Expand Up @@ -22,30 +22,32 @@ limitations under the License.
namespace oneflow {
namespace one {

struct SelectFirstCaptureState : public AutoGradCaptureState {
struct SelectTopNCaptureState : public AutoGradCaptureState {
TensorTuple inputs;
std::vector<bool> requires_grad;
int32_t top_n;
};

class SelectFirst : public OpExprGradFunction<SelectFirstCaptureState> {
class SelectTopN : public OpExprGradFunction<SelectTopNCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

Maybe<void> Capture(SelectFirstCaptureState* ctx, const TensorTuple& inputs,
Maybe<void> Capture(SelectTopNCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
ctx->inputs = inputs;
CHECK_OR_RETURN(ctx->inputs.at(0)->requires_grad());
ctx->top_n = JUST(attrs.GetAttr<int32_t>("top_n"));
ctx->requires_grad.resize(inputs.size());
for (int i = 0; i < ctx->requires_grad.size(); ++i) {
ctx->requires_grad.at(i) = inputs.at(i)->requires_grad();
}
return Maybe<void>::Ok();
}

Maybe<void> Apply(const SelectFirstCaptureState* ctx, const TensorTuple& out_grads,
Maybe<void> Apply(const SelectTopNCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
in_grads->at(0) = out_grads.at(0);
for (int i = 1; i < ctx->inputs.size(); i++) {
CHECK_EQ_OR_RETURN(ctx->top_n, out_grads.size());
for (int i = 0; i < ctx->top_n; ++i) { in_grads->at(i) = out_grads.at(i); }
for (int i = ctx->top_n; i < ctx->inputs.size(); ++i) {
if (!ctx->requires_grad.at(i)) { continue; }
const auto& tensor = ctx->inputs.at(i);
in_grads->at(i) = JUST(StaticZerosTensor::MakeTensor(
Expand All @@ -55,7 +57,7 @@ class SelectFirst : public OpExprGradFunction<SelectFirstCaptureState> {
}
};

REGISTER_OP_EXPR_GRAD_FUNCTION("select_first", SelectFirst);
REGISTER_OP_EXPR_GRAD_FUNCTION("select_top_n", SelectTopN);

} // namespace one
} // namespace oneflow
4 changes: 2 additions & 2 deletions oneflow/core/framework/op_expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -610,9 +610,9 @@ Maybe<OpExprGradClosure> BuiltinOpExprImpl<DistributeAddOpConf>::GetOrCreateOpGr
UNIMPLEMENTED_THEN_RETURN();
}

Maybe<OpExprGradClosure> SelectFirstOpExpr::GetOrCreateOpGradClosure() const {
Maybe<OpExprGradClosure> SelectTopNOpExpr::GetOrCreateOpGradClosure() const {
if (!op_grad_func_.get()) {
op_grad_func_.reset(NewObj<std::string, OpExprGradFunctionIf>("select_first"));
op_grad_func_.reset(NewObj<std::string, OpExprGradFunctionIf>("select_top_n"));
CHECK_NOTNULL_OR_RETURN(op_grad_func_.get());
JUST(op_grad_func_->Init(*this));
}
Expand Down
15 changes: 9 additions & 6 deletions oneflow/core/framework/op_expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -239,14 +239,14 @@ using DistributeCloneOpExpr = BuiltinOpExprImpl<DistributeCloneOpConf>;
using DistributeConcatOpExpr = BuiltinOpExprImpl<DistributeConcatOpConf>;
using DistributeAddOpExpr = BuiltinOpExprImpl<DistributeAddOpConf>;

class SelectFirstOpExpr final : public OpExpr {
class SelectTopNOpExpr final : public OpExpr {
public:
static Maybe<SelectFirstOpExpr> New() {
return std::shared_ptr<SelectFirstOpExpr>(new SelectFirstOpExpr());
static Maybe<SelectTopNOpExpr> New() {
return std::shared_ptr<SelectTopNOpExpr>(new SelectTopNOpExpr());
}

const std::string& op_type_name() const override {
static const std::string kOpTypeName = "select_first";
static const std::string kOpTypeName = "select_top_n";
return kOpTypeName;
}

Expand All @@ -255,14 +255,17 @@ class SelectFirstOpExpr final : public OpExpr {
return 0;
}

int output_size() const override { return 1; }
int output_size() const override {
// output should be resized in apply function
return 0;
}

Maybe<bool> IsGradDisabled() const override { return false; }

Maybe<OpExprGradClosure> GetOrCreateOpGradClosure() const override;

private:
SelectFirstOpExpr() = default;
SelectTopNOpExpr() = default;

mutable std::shared_ptr<OpExprGradFunctionIf> op_grad_func_;
};
Expand Down
2 changes: 1 addition & 1 deletion oneflow/core/framework/op_interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class OpExprInterpreter {

#define FOR_EACH_BUILTIN_OPS(_macro) \
_macro(UserOp); \
_macro(SelectFirstOp); \
_macro(SelectTopNOp); \
_macro(VariableOp); \
_macro(CastToMirroredOp); \
_macro(CastFromMirroredOp); \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ Maybe<void> EagerConsistentInterpreter::ApplyImpl(const DistributeAddOpExpr& op_
OF_UNIMPLEMENTED();
}

Maybe<void> EagerConsistentInterpreter::ApplyImpl(const SelectFirstOpExpr& op_expr,
Maybe<void> EagerConsistentInterpreter::ApplyImpl(const SelectTopNOpExpr& op_expr,
const TensorTuple& inputs, TensorTuple* outputs,
const OpExprInterpContext& ctx) const {
OF_UNIMPLEMENTED();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -429,11 +429,11 @@ Maybe<void> EagerMirroredInterpreter::ApplyImpl(const DistributeAddOpExpr& op_ex
return BuildAndRunDistributeConcatAndAddInstruction(op_expr, inputs, outputs);
}

Maybe<void> EagerMirroredInterpreter::ApplyImpl(const SelectFirstOpExpr& op_expr,
Maybe<void> EagerMirroredInterpreter::ApplyImpl(const SelectTopNOpExpr& op_expr,
const TensorTuple& inputs, TensorTuple* outputs,
const OpExprInterpContext& ctx) const {
CHECK_EQ_OR_RETURN(outputs->size(), 1);
outputs->at(0) = inputs.at(0);
int top_n = JUST(ctx.attrs.GetAttr<int32_t>("top_n"));
outputs->assign(inputs.begin(), inputs.begin() + top_n);
return Maybe<void>::Ok();
}

Expand Down
2 changes: 1 addition & 1 deletion oneflow/core/framework/op_interpreter/op_interpreter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ Maybe<void> EagerInterpreter::Apply(const OpExpr& op_expr, const TensorTuple& in
APPLY_IF(DistributeConcatOp);
APPLY_IF(DistributeAddOp);
APPLY_IF(FunctionOp);
APPLY_IF(SelectFirstOp)
APPLY_IF(SelectTopNOp)
#undef APPLY_IF

OF_UNIMPLEMENTED() << "The type " << op_expr.op_type_name()
Expand Down
4 changes: 2 additions & 2 deletions oneflow/core/functional/functional_api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1082,8 +1082,8 @@
signature: "Tensor (Tensor x, SbpList out_sbp) => ConsistentS2S"
bind_python: False

- name: "select_first"
signature: "Tensor (TensorTuple inputs) => SelectFirst"
- name: "select_top_n"
signature: "TensorTuple (TensorTuple inputs, Int32 n) => SelectTopN"
bind_python: True

- name: "cast_like"
Expand Down
12 changes: 7 additions & 5 deletions oneflow/core/functional/impl/math_functor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -467,12 +467,14 @@ class ClampGradFunctor {
std::shared_ptr<OpExpr> clip_max_op_;
};

class SelectFirstFunctor {
class SelectTopNFunctor {
public:
SelectFirstFunctor() { op_ = CHECK_JUST(one::SelectFirstOpExpr::New()); }
SelectTopNFunctor() { op_ = CHECK_JUST(one::SelectTopNOpExpr::New()); }

Maybe<Tensor> operator()(const TensorTuple& inputs) const {
const auto& output = JUST(OpInterpUtil::Dispatch<one::Tensor>(*op_, inputs));
Maybe<TensorTuple> operator()(const TensorTuple& inputs, int32_t n) const {
MutableAttrMap attr;
attr.SetAttr<int32_t>("top_n", n);
const auto& output = JUST(OpInterpUtil::Dispatch<one::TensorTuple>(*op_, inputs, attr));
return output;
}

Expand Down Expand Up @@ -682,7 +684,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<CastFunctor>("Cast");
m.add_functor<ClampFunctor>("Clamp");
m.add_functor<ClampGradFunctor>("ClampGrad");
m.add_functor<SelectFirstFunctor>("SelectFirst");
m.add_functor<SelectTopNFunctor>("SelectTopN");
m.add_functor<MinimumFunctor>("Minimum");
m.add_functor<MaximumFunctor>("Maximum");
m.add_functor<ScalarFModFunctor>("ScalarFMod");
Expand Down
17 changes: 14 additions & 3 deletions python/oneflow/nn/parallel/ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,20 @@ def post_forward_hook(module, input, output):
ddp_state_for_reversed_params = module._ddp_state_for_reversed_params
for state in ddp_state_for_reversed_params.values():
state[0], state[1] = False, False
output = flow._C.select_first(
convert_to_tensor_tuple([output, *ddp_state_for_reversed_params.keys()])
)
if isinstance(output, tuple):
output = flow._C.select_top_n(
convert_to_tensor_tuple(
[*output, *ddp_state_for_reversed_params.keys()]
),
n=len(output),
)
else:
output = flow._C.select_top_n(
convert_to_tensor_tuple(
[output, *ddp_state_for_reversed_params.keys()]
),
n=1,
)[0]
return output

module.register_forward_hook(post_forward_hook)
Expand Down