Fix ddp multi-output bug (#6310)
* fix(ddp): fix ddp multi-output bug

* add JUST

* resolve warning

Co-authored-by: oneflow-ci-bot <[email protected]>
wyg1997 and oneflow-ci-bot authored Sep 16, 2021
1 parent 2fe6db3 commit 7388d21
Showing 10 changed files with 50 additions and 32 deletions.
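
Context for the fix: before this commit, the DDP post-forward hook funneled the module output through select_first, which keeps exactly one tensor, so wrapping a module whose forward returns a tuple broke. select_top_n keeps the first n tensors, so multi-output modules now work. A minimal, hedged repro sketch (TwoHeadModule and the shapes are made up for illustration; it assumes oneflow's nn.parallel.DistributedDataParallel wrapper from ddp.py below and would normally be run under oneflow's distributed launcher):

    import oneflow as flow
    import oneflow.nn as nn
    from oneflow.nn.parallel import DistributedDataParallel as ddp

    class TwoHeadModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(4, 4)

        def forward(self, x):
            y = self.linear(x)
            return y, y * 2  # multiple outputs: the case this commit fixes

    m = ddp(TwoHeadModule())
    out1, out2 = m(flow.randn(2, 4))      # previously failed in the post-forward hook
    (out1.sum() + out2.sum()).backward()  # gradients still reach the wrapped parameters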
@@ -22,30 +22,32 @@ limitations under the License.
 namespace oneflow {
 namespace one {
 
-struct SelectFirstCaptureState : public AutoGradCaptureState {
+struct SelectTopNCaptureState : public AutoGradCaptureState {
   TensorTuple inputs;
   std::vector<bool> requires_grad;
+  int32_t top_n = 0;
 };
 
-class SelectFirst : public OpExprGradFunction<SelectFirstCaptureState> {
+class SelectTopN : public OpExprGradFunction<SelectTopNCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }
 
-  Maybe<void> Capture(SelectFirstCaptureState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(SelectTopNCaptureState* ctx, const TensorTuple& inputs,
                       const TensorTuple& outputs, const AttrMap& attrs) const override {
     ctx->inputs = inputs;
-    CHECK_OR_RETURN(ctx->inputs.at(0)->requires_grad());
+    ctx->top_n = JUST(attrs.GetAttr<int32_t>("top_n"));
     ctx->requires_grad.resize(inputs.size());
     for (int i = 0; i < ctx->requires_grad.size(); ++i) {
       ctx->requires_grad.at(i) = inputs.at(i)->requires_grad();
     }
     return Maybe<void>::Ok();
   }
 
-  Maybe<void> Apply(const SelectFirstCaptureState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const SelectTopNCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
-    in_grads->at(0) = out_grads.at(0);
-    for (int i = 1; i < ctx->inputs.size(); i++) {
+    CHECK_EQ_OR_RETURN(ctx->top_n, out_grads.size());
+    for (int i = 0; i < ctx->top_n; ++i) { in_grads->at(i) = out_grads.at(i); }
+    for (int i = ctx->top_n; i < ctx->inputs.size(); ++i) {
       if (!ctx->requires_grad.at(i)) { continue; }
       const auto& tensor = ctx->inputs.at(i);
       in_grads->at(i) = JUST(StaticZerosTensor::MakeTensor(
@@ -55,7 +57,7 @@ class SelectFirst : public OpExprGradFunction<SelectFirstCaptureState> {
   }
 };
 
-REGISTER_OP_EXPR_GRAD_FUNCTION("select_first", SelectFirst);
+REGISTER_OP_EXPR_GRAD_FUNCTION("select_top_n", SelectTopN);
 
 }  // namespace one
 }  // namespace oneflow
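
To make the new backward rule concrete: the first top_n entries of in_grads take the incoming out_grads unchanged, and every remaining input (the DDP parameters appended after the real outputs) gets a zero gradient of the same shape, skipping inputs that do not require grad. A hedged plain-Python sketch, with lists standing in for tensors and the requires_grad bookkeeping left out:

    def select_top_n_backward(inputs, out_grads, top_n):
        # mirrors CHECK_EQ_OR_RETURN(ctx->top_n, out_grads.size())
        assert len(out_grads) == top_n
        in_grads = list(out_grads)            # pass-through for the first top_n inputs
        for x in inputs[top_n:]:
            in_grads.append([0.0] * len(x))   # zeros_like for the trailing parameters
        return in_grads

    print(select_top_n_backward(
        inputs=[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
        out_grads=[[0.1, 0.2]],
        top_n=1,
    ))  # [[0.1, 0.2], [0.0, 0.0], [0.0, 0.0]]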
4 changes: 2 additions & 2 deletions oneflow/core/framework/op_expr.cpp
@@ -610,9 +610,9 @@ Maybe<OpExprGradClosure> BuiltinOpExprImpl<DistributeAddOpConf>::GetOrCreateOpGr
   UNIMPLEMENTED_THEN_RETURN();
 }
 
-Maybe<OpExprGradClosure> SelectFirstOpExpr::GetOrCreateOpGradClosure() const {
+Maybe<OpExprGradClosure> SelectTopNOpExpr::GetOrCreateOpGradClosure() const {
   if (!op_grad_func_.get()) {
-    op_grad_func_.reset(NewObj<std::string, OpExprGradFunctionIf>("select_first"));
+    op_grad_func_.reset(NewObj<std::string, OpExprGradFunctionIf>("select_top_n"));
     CHECK_NOTNULL_OR_RETURN(op_grad_func_.get());
     JUST(op_grad_func_->Init(*this));
   }
15 changes: 9 additions & 6 deletions oneflow/core/framework/op_expr.h
@@ -239,14 +239,14 @@ using DistributeCloneOpExpr = BuiltinOpExprImpl<DistributeCloneOpConf>;
 using DistributeConcatOpExpr = BuiltinOpExprImpl<DistributeConcatOpConf>;
 using DistributeAddOpExpr = BuiltinOpExprImpl<DistributeAddOpConf>;
 
-class SelectFirstOpExpr final : public OpExpr {
+class SelectTopNOpExpr final : public OpExpr {
  public:
-  static Maybe<SelectFirstOpExpr> New() {
-    return std::shared_ptr<SelectFirstOpExpr>(new SelectFirstOpExpr());
+  static Maybe<SelectTopNOpExpr> New() {
+    return std::shared_ptr<SelectTopNOpExpr>(new SelectTopNOpExpr());
   }
 
   const std::string& op_type_name() const override {
-    static const std::string kOpTypeName = "select_first";
+    static const std::string kOpTypeName = "select_top_n";
     return kOpTypeName;
   }
 
@@ -255,14 +255,17 @@ class SelectFirstOpExpr final : public OpExpr {
     return 0;
   }
 
-  int output_size() const override { return 1; }
+  int output_size() const override {
+    // output should be resized in apply function
+    return 0;
+  }
 
   Maybe<bool> IsGradDisabled() const override { return false; }
 
   Maybe<OpExprGradClosure> GetOrCreateOpGradClosure() const override;
 
  private:
-  SelectFirstOpExpr() = default;
+  SelectTopNOpExpr() = default;
 
   mutable std::shared_ptr<OpExprGradFunctionIf> op_grad_func_;
 };
2 changes: 1 addition & 1 deletion oneflow/core/framework/op_interpreter.h
@@ -79,7 +79,7 @@ class OpExprInterpreter {
 
 #define FOR_EACH_BUILTIN_OPS(_macro) \
   _macro(UserOp); \
-  _macro(SelectFirstOp); \
+  _macro(SelectTopNOp); \
   _macro(VariableOp); \
   _macro(CastToMirroredOp); \
   _macro(CastFromMirroredOp); \
@@ -265,7 +265,7 @@ Maybe<void> EagerConsistentInterpreter::ApplyImpl(const DistributeAddOpExpr& op_
   OF_UNIMPLEMENTED();
 }
 
-Maybe<void> EagerConsistentInterpreter::ApplyImpl(const SelectFirstOpExpr& op_expr,
+Maybe<void> EagerConsistentInterpreter::ApplyImpl(const SelectTopNOpExpr& op_expr,
                                                    const TensorTuple& inputs, TensorTuple* outputs,
                                                    const OpExprInterpContext& ctx) const {
   OF_UNIMPLEMENTED();
@@ -429,11 +429,11 @@ Maybe<void> EagerMirroredInterpreter::ApplyImpl(const DistributeAddOpExpr& op_ex
   return BuildAndRunDistributeConcatAndAddInstruction(op_expr, inputs, outputs);
 }
 
-Maybe<void> EagerMirroredInterpreter::ApplyImpl(const SelectFirstOpExpr& op_expr,
+Maybe<void> EagerMirroredInterpreter::ApplyImpl(const SelectTopNOpExpr& op_expr,
                                                  const TensorTuple& inputs, TensorTuple* outputs,
                                                  const OpExprInterpContext& ctx) const {
-  CHECK_EQ_OR_RETURN(outputs->size(), 1);
-  outputs->at(0) = inputs.at(0);
+  int top_n = JUST(ctx.attrs.GetAttr<int32_t>("top_n"));
+  outputs->assign(inputs.begin(), inputs.begin() + top_n);
   return Maybe<void>::Ok();
 }
 
2 changes: 1 addition & 1 deletion oneflow/core/framework/op_interpreter/op_interpreter.cpp
@@ -66,7 +66,7 @@ Maybe<void> EagerInterpreter::Apply(const OpExpr& op_expr, const TensorTuple& in
   APPLY_IF(DistributeConcatOp);
   APPLY_IF(DistributeAddOp);
   APPLY_IF(FunctionOp);
-  APPLY_IF(SelectFirstOp)
+  APPLY_IF(SelectTopNOp)
 #undef APPLY_IF
 
   OF_UNIMPLEMENTED() << "The type " << op_expr.op_type_name()
4 changes: 2 additions & 2 deletions oneflow/core/functional/functional_api.yaml
@@ -1082,8 +1082,8 @@
   signature: "Tensor (Tensor x, SbpList out_sbp) => ConsistentS2S"
   bind_python: False
 
-- name: "select_first"
-  signature: "Tensor (TensorTuple inputs) => SelectFirst"
+- name: "select_top_n"
+  signature: "TensorTuple (TensorTuple inputs, Int32 n) => SelectTopN"
   bind_python: True
 
 - name: "cast_like"
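
As the updated signature indicates, select_top_n takes the number of leading tensors to keep and returns a TensorTuple rather than a single Tensor. A hedged sketch of the Python-side call (the convert_to_tensor_tuple helper and the flow._C entry point are taken from the ddp.py change below; the tensors are made-up examples):

    import oneflow as flow
    from oneflow.framework.tensor_tuple_util import convert_to_tensor_tuple

    a, b, c = flow.ones(2), flow.zeros(2), flow.ones(2) * 3
    outs = flow._C.select_top_n(convert_to_tensor_tuple([a, b, c]), n=2)
    assert len(outs) == 2  # only the first two tensors come back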
12 changes: 7 additions & 5 deletions oneflow/core/functional/impl/math_functor.cpp
@@ -469,12 +469,14 @@ class ClampGradFunctor {
   std::shared_ptr<OpExpr> clip_max_op_;
 };
 
-class SelectFirstFunctor {
+class SelectTopNFunctor {
  public:
-  SelectFirstFunctor() { op_ = CHECK_JUST(one::SelectFirstOpExpr::New()); }
+  SelectTopNFunctor() { op_ = CHECK_JUST(one::SelectTopNOpExpr::New()); }
 
-  Maybe<Tensor> operator()(const TensorTuple& inputs) const {
-    const auto& output = JUST(OpInterpUtil::Dispatch<one::Tensor>(*op_, inputs));
+  Maybe<TensorTuple> operator()(const TensorTuple& inputs, int32_t n) const {
+    MutableAttrMap attr;
+    JUST(attr.SetAttr<int32_t>("top_n", n));
+    const auto& output = JUST(OpInterpUtil::Dispatch<one::TensorTuple>(*op_, inputs, attr));
     return output;
   }
 
@@ -684,7 +686,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
   m.add_functor<CastFunctor>("Cast");
   m.add_functor<ClampFunctor>("Clamp");
   m.add_functor<ClampGradFunctor>("ClampGrad");
-  m.add_functor<SelectFirstFunctor>("SelectFirst");
+  m.add_functor<SelectTopNFunctor>("SelectTopN");
   m.add_functor<MinimumFunctor>("Minimum");
   m.add_functor<MaximumFunctor>("Maximum");
   m.add_functor<ScalarFModFunctor>("ScalarFMod");
17 changes: 14 additions & 3 deletions python/oneflow/nn/parallel/ddp.py
@@ -64,9 +64,20 @@ def post_forward_hook(module, input, output):
         ddp_state_for_reversed_params = module._ddp_state_for_reversed_params
         for state in ddp_state_for_reversed_params.values():
             state[0], state[1] = False, False
-        output = flow._C.select_first(
-            convert_to_tensor_tuple([output, *ddp_state_for_reversed_params.keys()])
-        )
+        if isinstance(output, tuple):
+            output = flow._C.select_top_n(
+                convert_to_tensor_tuple(
+                    [*output, *ddp_state_for_reversed_params.keys()]
+                ),
+                n=len(output),
+            )
+        else:
+            output = flow._C.select_top_n(
+                convert_to_tensor_tuple(
+                    [output, *ddp_state_for_reversed_params.keys()]
+                ),
+                n=1,
+            )[0]
         return output
 
     module.register_forward_hook(post_forward_hook)
