diff --git a/oneflow/core/autograd/gradient_funcs/layer_norm.cpp b/oneflow/core/autograd/gradient_funcs/layer_norm.cpp
index 4a0835247ee..c30fe64741f 100644
--- a/oneflow/core/autograd/gradient_funcs/layer_norm.cpp
+++ b/oneflow/core/autograd/gradient_funcs/layer_norm.cpp
@@ -18,6 +18,9 @@ limitations under the License.
 #include "oneflow/core/functional/functional.h"
 
 namespace oneflow {
+
+DEFINE_ENV_BOOL(ONEFLOW_USE_FUSE_LAYER_NORM_GRAD, false);
+
 namespace one {
 
 struct LayerNormCaptureState : public AutoGradCaptureState {
@@ -107,22 +110,37 @@ Maybe<void> LayerNorm::Apply(const LayerNormCaptureState* ctx, const TensorTuple
   std::shared_ptr<Tensor> mean = saved_tensors.at(ctx->mean_index);
   std::shared_ptr<Tensor> inv_variance = saved_tensors.at(ctx->inv_variance_index);
 
-  if (ctx->has_affine) {
-    // Use LayerNormParamGrad(Tensor dy, Tensor x, Tensor mean, Tensor inv_variance,
-    // Int64 begin_params_axis)
-    const auto& results =
-        JUST(functional::LayerNormParamGrad(dy, x, mean, inv_variance, begin_params_axis));
-    in_grads->at(1) = results->at(0);  // For gamma.
-    in_grads->at(2) = results->at(1);  // For beta.
-  }
-  if (ctx->x_requires_grad) {
-    if (ctx->scale) {
-      std::shared_ptr<Tensor> gamma = saved_tensors.at(ctx->gamma_index);
-      in_grads->at(0) = JUST(functional::LayerNormAffineGrad(dy, x, mean, inv_variance, gamma,
-                                                             begin_norm_axis, ctx->epsilon));
-    } else {
-      in_grads->at(0) =
-          JUST(functional::LayerNormGrad(dy, x, mean, inv_variance, begin_norm_axis, ctx->epsilon));
+  if (EnvBool<ONEFLOW_USE_FUSE_LAYER_NORM_GRAD>()) {
+    // just for npu
+    CHECK(ctx->has_affine) << "LayerNorm::Apply must has_affine for NPU GPT2 test";
+    if (ctx->x_requires_grad) {
+      if (ctx->scale) {
+        std::shared_ptr<Tensor> gamma = saved_tensors.at(ctx->gamma_index);
+        *in_grads = *JUST(functional::FuseLayerNormAffineGrad(
+            dy, x, mean, inv_variance, gamma, begin_norm_axis, begin_params_axis, ctx->epsilon));
+      } else {
+        *in_grads = *JUST(functional::FuseLayerNormGrad(dy, x, mean, inv_variance, begin_norm_axis,
+                                                        begin_params_axis, ctx->epsilon));
+      }
+    }
+  } else {
+    if (ctx->has_affine) {
+      // Use LayerNormParamGrad(Tensor dy, Tensor x, Tensor mean, Tensor inv_variance,
+      // Int64 begin_params_axis)
+      const auto& results =
+          JUST(functional::LayerNormParamGrad(dy, x, mean, inv_variance, begin_params_axis));
+      in_grads->at(1) = results->at(0);  // For gamma.
+      in_grads->at(2) = results->at(1);  // For beta.
+    }
+    if (ctx->x_requires_grad) {
+      if (ctx->scale) {
+        std::shared_ptr<Tensor> gamma = saved_tensors.at(ctx->gamma_index);
+        in_grads->at(0) = JUST(functional::LayerNormAffineGrad(dy, x, mean, inv_variance, gamma,
+                                                               begin_norm_axis, ctx->epsilon));
+      } else {
+        in_grads->at(0) = JUST(
+            functional::LayerNormGrad(dy, x, mean, inv_variance, begin_norm_axis, ctx->epsilon));
+      }
     }
   }
   return Maybe<void>::Ok();
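
[Note: illustrative, not part of the patch.] The new branch is gated by DEFINE_ENV_BOOL(ONEFLOW_USE_FUSE_LAYER_NORM_GRAD, false), a process-wide boolean switch that defaults to false, so the existing LayerNormParamGrad/LayerNormGrad path remains the default unless the variable is set. Below is a minimal, self-contained C++ sketch of that kind of toggle; ReadEnvBool is a hypothetical stand-in for OneFlow's EnvBool<...>() helper and is not an actual OneFlow API.

#include <cstdlib>
#include <cstring>
#include <iostream>

// Hypothetical helper modelling a default-false boolean environment switch.
bool ReadEnvBool(const char* name, bool default_value) {
  const char* val = std::getenv(name);
  if (val == nullptr) { return default_value; }
  return std::strcmp(val, "0") != 0 && std::strcmp(val, "false") != 0;
}

int main() {
  // Mirrors the gating in LayerNorm::Apply: the fused gradient path is taken
  // only when ONEFLOW_USE_FUSE_LAYER_NORM_GRAD is explicitly enabled.
  if (ReadEnvBool("ONEFLOW_USE_FUSE_LAYER_NORM_GRAD", /*default_value=*/false)) {
    std::cout << "fused layer_norm grad path (dx, gamma_diff, beta_diff from one op)\n";
  } else {
    std::cout << "separate LayerNormParamGrad + LayerNormGrad path\n";
  }
  return 0;
}
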
diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml
index 8b05bf73a44..8e067b0cbd1 100644
--- a/oneflow/core/functional/functional_api.yaml
+++ b/oneflow/core/functional/functional_api.yaml
@@ -1558,6 +1558,14 @@
   signature: "Tensor (Tensor dy, Tensor x, Tensor mean, Tensor inv_variance, Tensor gamma, Int64 begin_norm_axis, Double epsilon) => LayerNormAffineGrad"
   bind_python: False
 
+- name: "fuse_layer_norm_grad"
+  signature: "TensorTuple (Tensor dy, Tensor x, Tensor mean, Tensor inv_variance, Int64 begin_norm_axis, Int64 begin_params_axis, Double epsilon) => FuseLayerNormGrad"
+  bind_python: False
+
+- name: "fuse_layer_norm_affine_grad"
+  signature: "TensorTuple (Tensor dy, Tensor x, Tensor mean, Tensor inv_variance, Tensor gamma, Int64 begin_norm_axis, Int64 begin_params_axis, Double epsilon) => FuseLayerNormAffineGrad"
+  bind_python: False
+
 - name: "layer_norm_param_grad"
   signature: "TensorTuple (Tensor dy, Tensor x, Tensor mean, Tensor inv_variance, Int64 begin_params_axis) => LayerNormParamGrad"
   bind_python: False
diff --git a/oneflow/core/functional/impl/nn_grad_functor.cpp b/oneflow/core/functional/impl/nn_grad_functor.cpp
index e0cf9e2ff34..17a56beb40b 100644
--- a/oneflow/core/functional/impl/nn_grad_functor.cpp
+++ b/oneflow/core/functional/impl/nn_grad_functor.cpp
@@ -983,6 +983,64 @@ class LayerNormAffineGradFunctor {
   std::shared_ptr<OpExpr> op_;
 };
 
+class FuseLayerNormGradFunctor {
+ public:
+  FuseLayerNormGradFunctor() {
+    op_ = CHECK_JUST(one::OpBuilder("fuse_layer_norm_grad")
+                         .Input("dy")
+                         .Input("x")
+                         .Input("mean")
+                         .Input("inv_variance")
+                         .Output("dx")
+                         .Output("gamma_diff")
+                         .Output("beta_diff")
+                         .Build());
+  }
+  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& dy,
+                                const std::shared_ptr<one::Tensor>& x,
+                                const std::shared_ptr<one::Tensor>& mean,
+                                const std::shared_ptr<one::Tensor>& inv_variance,
+                                const int64_t& begin_norm_axis, const int64_t& begin_params_axis,
+                                const double& epsilon) const {
+    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("begin_norm_axis", "begin_params_axis", "epsilon");
+    attrs.SetAllAttrs(begin_norm_axis, begin_params_axis, epsilon);
+    return OpInterpUtil::Dispatch<TensorTuple>(*op_, {dy, x, mean, inv_variance}, attrs);
+  }
+
+ private:
+  std::shared_ptr<OpExpr> op_;
+};
+
+class FuseLayerNormAffineGradFunctor {
+ public:
+  FuseLayerNormAffineGradFunctor() {
+    op_ = CHECK_JUST(one::OpBuilder("fuse_layer_norm_grad")
+                         .Input("dy")
+                         .Input("x")
+                         .Input("mean")
+                         .Input("inv_variance")
+                         .Input("gamma")
+                         .Output("dx")
+                         .Output("gamma_diff")
+                         .Output("beta_diff")
+                         .Build());
+  }
+  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& dy,
+                                const std::shared_ptr<one::Tensor>& x,
+                                const std::shared_ptr<one::Tensor>& mean,
+                                const std::shared_ptr<one::Tensor>& inv_variance,
+                                const std::shared_ptr<one::Tensor>& gamma,
+                                const int64_t& begin_norm_axis, const int64_t& begin_params_axis,
+                                const double& epsilon) const {
+    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("begin_norm_axis", "begin_params_axis", "epsilon");
+    attrs.SetAllAttrs(begin_norm_axis, begin_params_axis, epsilon);
+    return OpInterpUtil::Dispatch<TensorTuple>(*op_, {dy, x, mean, inv_variance, gamma}, attrs);
+  }
+
+ private:
+  std::shared_ptr<OpExpr> op_;
+};
+
 class LayerNormParamGradFunctor {
  public:
   LayerNormParamGradFunctor() {
@@ -1707,6 +1765,8 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
   m.add_functor<impl::LayerNormGradFunctor>("LayerNormGrad");
   m.add_functor<impl::LayerNormAffineGradFunctor>("LayerNormAffineGrad");
   m.add_functor<impl::LayerNormParamGradFunctor>("LayerNormParamGrad");
+  m.add_functor<impl::FuseLayerNormGradFunctor>("FuseLayerNormGrad");
+  m.add_functor<impl::FuseLayerNormAffineGradFunctor>("FuseLayerNormAffineGrad");
   m.add_functor<impl::GroupNormGradFunctor>("GroupNormGrad");
   m.add_functor<impl::GroupNormParamGradFunctor>("GroupNormParamGrad");
   m.add_functor<impl::BroadcastMatmulGradBFunctor>("BroadcastMatmulGradB");
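
[Note: illustrative, not part of the patch.] Both functors dispatch the same "fuse_layer_norm_grad" op and return a TensorTuple; the affine variant additionally feeds gamma. Going by the YAML signatures and the Output(...) order declared in the OpBuilder above, a call site inside the OneFlow sources would look roughly like the following sketch (tensor variables and attribute values are placeholders).

// Sketch only; assumes it lives inside the OneFlow codebase where
// functional::FuseLayerNormAffineGrad is registered by this patch.
// Result order follows the OpBuilder outputs: dx, gamma_diff, beta_diff.
const auto& grads = JUST(functional::FuseLayerNormAffineGrad(
    dy, x, mean, inv_variance, gamma,
    /*begin_norm_axis=*/2, /*begin_params_axis=*/2, /*epsilon=*/1e-5));
const auto& dx = grads->at(0);
const auto& gamma_diff = grads->at(1);
const auto& beta_diff = grads->at(2);
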
diff --git a/oneflow/ir/include/OneFlow/OneFlowUserOps.td b/oneflow/ir/include/OneFlow/OneFlowUserOps.td
index b50a1ceceab..7532cfe441e 100644
--- a/oneflow/ir/include/OneFlow/OneFlowUserOps.td
+++ b/oneflow/ir/include/OneFlow/OneFlowUserOps.td
@@ -7071,6 +7071,35 @@ def OneFlow_LayerNormGradOp : OneFlow_BaseOp<"layer_norm_grad", [NoMemoryEffect,
   let has_data_type_infer_fn = 1;
 }
 
+def OneFlow_FuseLayerNormGradOp : OneFlow_BaseOp<"fuse_layer_norm_grad", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    OneFlow_Tensor:$dy,
+    OneFlow_Tensor:$x,
+    OneFlow_Tensor:$mean,
+    OneFlow_Tensor:$inv_variance,
+    Optional<OneFlow_Tensor>:$gamma,
+    Optional<OneFlow_Tensor>:$_add_to_output
+  );
+  let output = (outs
+    OneFlow_Tensor:$dx,
+    OneFlow_Tensor:$gamma_diff,
+    OneFlow_Tensor:$beta_diff
+  );
+  let attrs = (ins
+    DefaultValuedAttr<SI64Attr, "0">:$begin_norm_axis,
+    DefaultValuedAttr<SI64Attr, "0">:$begin_params_axis,
+    DefaultValuedAttr<F64Attr, "0.">:$epsilon
+  );
+  let trait_attrs = (ins
+    DenseI32ArrayAttr:$operand_segment_sizes,
+    DenseI32ArrayAttr:$result_segment_sizes
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
 def OneFlow_LayerNormParamGradOp : OneFlow_BaseOp<"layer_norm_param_grad", [NoMemoryEffect, AttrSizedResultSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
   let input = (ins
     OneFlow_Tensor:$dy,
diff --git a/oneflow/user/kernels/fuse_layer_norm_cpu_kernel.cpp b/oneflow/user/kernels/fuse_layer_norm_cpu_kernel.cpp
new file mode 100644
index 00000000000..af70ef77bf8
--- /dev/null
+++ b/oneflow/user/kernels/fuse_layer_norm_cpu_kernel.cpp
@@ -0,0 +1,40 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/framework/framework.h"
+
+namespace oneflow {
+
+template<typename T>
+class FuseLayerNormGradCpuKernel final : public user_op::OpKernel {
+ public:
+  FuseLayerNormGradCpuKernel() = default;
+  ~FuseLayerNormGradCpuKernel() = default;
+
+ private:
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+  void Compute(user_op::KernelComputeContext* ctx) const override { TODO(); };
+};
+
+#define REGISTER_FUSE_LAYER_NORM_GRAD_CPU_KERNEL(dtype)               \
+  REGISTER_USER_KERNEL("fuse_layer_norm_grad")                        \
+      .SetCreateFn<FuseLayerNormGradCpuKernel<dtype>>()               \
+      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \
+                       && (user_op::HobDataType("dy", 0) == GetDataType<dtype>::value));
+
+REGISTER_FUSE_LAYER_NORM_GRAD_CPU_KERNEL(float)
+REGISTER_FUSE_LAYER_NORM_GRAD_CPU_KERNEL(double)
+
+}  // namespace oneflow
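
[Note: illustrative, not part of the patch.] The CPU kernel above is only a registration stub; its Compute() hits TODO() at runtime, and the real implementation is expected to come from a device-specific (NPU) kernel. For reference, here is a naive, self-contained scalar sketch of what a fused layer-norm backward produces, assuming begin_norm_axis == begin_params_axis (which the SBP rule below also requires) so the tensors can be viewed as [rows, cols]. This is standard layer-norm backward math, not the kernel shipped by this patch.

#include <cstdint>
#include <vector>

// Reference only: computes dx, gamma_diff and beta_diff in one pass,
// which is exactly the output set of the fused op.
void FuseLayerNormGradReference(const std::vector<float>& dy, const std::vector<float>& x,
                                const std::vector<float>& mean,
                                const std::vector<float>& inv_variance,
                                const std::vector<float>& gamma, int64_t rows, int64_t cols,
                                std::vector<float>* dx, std::vector<float>* gamma_diff,
                                std::vector<float>* beta_diff) {
  dx->assign(rows * cols, 0.f);
  gamma_diff->assign(cols, 0.f);
  beta_diff->assign(cols, 0.f);
  for (int64_t r = 0; r < rows; ++r) {
    // Per-row reductions of the gamma-scaled upstream gradient.
    float sum_dyg = 0.f;
    float sum_dyg_xhat = 0.f;
    for (int64_t c = 0; c < cols; ++c) {
      const int64_t i = r * cols + c;
      const float x_hat = (x[i] - mean[r]) * inv_variance[r];
      const float dyg = dy[i] * gamma[c];
      sum_dyg += dyg;
      sum_dyg_xhat += dyg * x_hat;
      (*gamma_diff)[c] += dy[i] * x_hat;  // dgamma: reduce over rows
      (*beta_diff)[c] += dy[i];           // dbeta: reduce over rows
    }
    // dx = inv_std * (dyg - mean(dyg) - x_hat * mean(dyg * x_hat)), per row.
    for (int64_t c = 0; c < cols; ++c) {
      const int64_t i = r * cols + c;
      const float x_hat = (x[i] - mean[r]) * inv_variance[r];
      const float dyg = dy[i] * gamma[c];
      (*dx)[i] = inv_variance[r] * (dyg - sum_dyg / cols - x_hat * sum_dyg_xhat / cols);
    }
  }
}
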
diff --git a/oneflow/user/ops/layer_norm_op.cpp b/oneflow/user/ops/layer_norm_op.cpp
index 6a6d9f85e70..55c8156b2cf 100644
--- a/oneflow/user/ops/layer_norm_op.cpp
+++ b/oneflow/user/ops/layer_norm_op.cpp
@@ -268,4 +268,127 @@ oneflow::DataType InferBnParamDataType(const DataType x_data_type) {
   return Maybe<void>::Ok();
 }
 
+/* static */ Maybe<void> FuseLayerNormGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0);
+  const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0);
+  const user_op::TensorDesc& mean = ctx->InputTensorDesc("mean", 0);
+  const user_op::TensorDesc& inv_variance = ctx->InputTensorDesc("inv_variance", 0);
+  user_op::TensorDesc* dx = ctx->MutOutputTensorDesc("dx", 0);
+  CHECK_EQ_OR_RETURN(dy.shape(), x.shape());
+  const int64_t begin_norm_axis = ctx->Attr<int64_t>("begin_norm_axis");
+  CHECK_GT_OR_RETURN(begin_norm_axis, 0);
+  const Shape& bn_param_shape = InferBnParamShape(x.shape(), begin_norm_axis);
+  CHECK_EQ_OR_RETURN(mean.shape(), bn_param_shape);
+  CHECK_EQ_OR_RETURN(inv_variance.shape(), bn_param_shape);
+  dx->set_shape(dy.shape());
+  dx->set_is_dynamic(dy.is_dynamic());
+  if (ctx->has_input("_add_to_output", 0)) {
+    const auto& add_to_output = ctx->InputTensorDesc("_add_to_output", 0);
+    CHECK_EQ_OR_RETURN(add_to_output.shape(), dx->shape());
+  }
+
+  auto has_tensor = [ctx](const std::string& bn) -> bool {
+    bool ret = false;
+    for (const auto& t : ctx->inputs()) {
+      if (bn == t.first) { return true; }
+    }
+    for (const auto& t : ctx->outputs()) {
+      if (bn == t.first) { return true; }
+    }
+    return ret;
+  };
+  const int64_t begin_params_axis = ctx->Attr<int64_t>("begin_params_axis");
+  const bool has_beta_diff = has_tensor("beta_diff");
+  const bool has_gamma_diff = has_tensor("gamma_diff");
+  CHECK_GE_OR_RETURN(begin_params_axis, 1);
+  CHECK_LT_OR_RETURN(begin_params_axis, dy.shape().NumAxes());
+  DimVector param_shape_dim_vec;
+  param_shape_dim_vec.insert(param_shape_dim_vec.end(),
+                             dy.shape().dim_vec().cbegin() + begin_params_axis,
+                             dy.shape().dim_vec().cend());
+  const Shape param_shape(param_shape_dim_vec);
+  if (has_beta_diff) {
+    user_op::TensorDesc* beta_diff = ctx->MutOutputTensorDesc("beta_diff", 0);
+    beta_diff->set_shape(param_shape);
+  }
+  if (has_gamma_diff) {
+    user_op::TensorDesc* gamma_diff = ctx->MutOutputTensorDesc("gamma_diff", 0);
+    gamma_diff->set_shape(param_shape);
+  }
+  return Maybe<void>::Ok();
+}
+
+/*static*/ Maybe<void> FuseLayerNormGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+
+/* static */ Maybe<void> FuseLayerNormGradOp::GetSbp(user_op::SbpContext* ctx) {
+  std::vector<user_op::OpArg> broadcast_args;
+  if (ctx->user_op_conf().has_input("gamma", 0)) {
+    broadcast_args.emplace_back(user_op::OpArg("gamma", 0));
+  }
+  int64_t begin_norm_axis = ctx->Attr<int64_t>("begin_norm_axis");
+  int64_t begin_params_axis = ctx->Attr<int64_t>("begin_params_axis");
+  CHECK_EQ(begin_norm_axis, begin_params_axis)
+      << "begin_norm_axis and begin_params_axis must be equal, but got " << begin_norm_axis
+      << " and " << begin_params_axis;
+  for (int i = 0; i < begin_norm_axis; ++i) {
+    ctx->NewBuilder()
+        .Split(ctx->inputs(), i)
+        .Split(user_op::OpArg("dx", 0), i)
+        .PartialSum(user_op::OpArg("gamma_diff", 0))
+        .PartialSum(user_op::OpArg("beta_diff", 0))
+        .Broadcast(broadcast_args)
+        .Build();
+  }
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> FuseLayerNormGradOp::InferDataType(user_op::InferContext* ctx) {
+  const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0);
+  const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0);
+  CHECK_EQ_OR_RETURN(dy.data_type(), x.data_type())
+      << "InferDataType Failed. Expected " << DataType_Name(x.data_type()) << ", but got "
+      << DataType_Name(dy.data_type());
+  const user_op::TensorDesc& mean = ctx->InputTensorDesc("mean", 0);
+  const user_op::TensorDesc& inv_variance = ctx->InputTensorDesc("inv_variance", 0);
+  DataType bn_param_data_type = InferBnParamDataType(x.data_type());
+  CHECK_EQ_OR_RETURN(mean.data_type(), bn_param_data_type)
+      << "InferDataType Failed. Expected " << DataType_Name(bn_param_data_type) << ", but got "
+      << DataType_Name(mean.data_type());
+  CHECK_EQ_OR_RETURN(inv_variance.data_type(), bn_param_data_type)
+      << "InferDataType Failed. Expected " << DataType_Name(bn_param_data_type) << ", but got "
+      << DataType_Name(inv_variance.data_type());
+  user_op::TensorDesc* dx = ctx->MutOutputTensorDesc("dx", 0);
+  dx->set_data_type(dy.data_type());
+  if (ctx->has_input("_add_to_output", 0)) {
+    const auto& add_to_output = ctx->InputTensorDesc("_add_to_output", 0);
+    CHECK_EQ_OR_RETURN(add_to_output.data_type(), dx->data_type())
+        << "InferDataType Failed. Expected " << DataType_Name(dx->data_type()) << ", but got "
+        << DataType_Name(add_to_output.data_type());
+  }
+
+  auto has_tensor = [ctx](const std::string& bn) -> bool {
+    bool ret = false;
+    for (auto& t : ctx->inputs()) {
+      if (bn == t.first) { return true; }
+    }
+    for (auto& t : ctx->outputs()) {
+      if (bn == t.first) { return true; }
+    }
+    return ret;
+  };
+  const bool has_beta_diff = has_tensor("beta_diff");
+  const bool has_gamma_diff = has_tensor("gamma_diff");
+  if (has_beta_diff) {
+    user_op::TensorDesc* beta_diff = ctx->MutOutputTensorDesc("beta_diff", 0);
+    beta_diff->set_data_type(dy.data_type());
+  }
+  if (has_gamma_diff) {
+    user_op::TensorDesc* gamma_diff = ctx->MutOutputTensorDesc("gamma_diff", 0);
+    gamma_diff->set_data_type(dy.data_type());
+  }
+  return Maybe<void>::Ok();
+}
+
 }  // namespace oneflow
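
[Note: illustrative, not part of the patch.] A worked example of the shape and SBP rules above, using a hypothetical GPT-2-like activation of shape (batch=4, seq=128, hidden=768) with begin_norm_axis == begin_params_axis == 2: mean and inv_variance carry the leading axes (4, 128), gamma_diff and beta_diff carry the trailing axes (768), and dx matches dy. Any axis in front of begin_norm_axis may be split across ranks, in which case gamma_diff and beta_diff come out PartialSum and must be reduced, while gamma is broadcast. A small self-contained check of the shape arithmetic:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const std::vector<int64_t> x_shape = {4, 128, 768};
  const int64_t begin_norm_axis = 2;
  const int64_t begin_params_axis = 2;

  // mean / inv_variance keep the axes in front of begin_norm_axis: (4, 128).
  const std::vector<int64_t> bn_param_shape(x_shape.begin(), x_shape.begin() + begin_norm_axis);

  // gamma_diff / beta_diff keep the axes from begin_params_axis on: (768).
  const std::vector<int64_t> param_shape(x_shape.begin() + begin_params_axis, x_shape.end());

  for (int64_t d : bn_param_shape) { std::cout << d << ' '; }
  std::cout << "| ";
  for (int64_t d : param_shape) { std::cout << d << ' '; }
  std::cout << '\n';  // prints: 4 128 | 768
  return 0;
}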