From 484dd400226567b922b1e6afb49b9b00dbeb9f61 Mon Sep 17 00:00:00 2001 From: lezcano Date: Thu, 29 Dec 2022 16:32:00 +0000 Subject: [PATCH] Implement PReLU in a compositional way (#91238) The PReLU implementation was all over the place. This lead to a number of bugs like https://github.com/pytorch/pytorch/issues/68760. We fix it by: - Keeping the weird broadcasting logic it has as a CompositeImplicit kernel that calls into a second kernel - This second kernel is just a good-ol' pointwise kernel. - We implement the derivative for the pointwise kernel via TI as well for speed. - We implement the second derivative for the pointwise kernel and the forward AD derivatives compositionally This fixes a number of issues: - We don't perform copies any more when the inputs are not contiguous - The derivatives are now correct - We fix vmap and many other functorch-related issues. - CPU and CUDA now share the relevant broadcasting logic - The implementation is about 1/3 the length. Fixes https://github.com/pytorch/pytorch/issues/68760 Fixes https://github.com/pytorch/pytorch/issues/89895 Pull Request resolved: https://github.com/pytorch/pytorch/pull/91238 Approved by: https://github.com/kshitij12345, https://github.com/jbschlosser, https://github.com/albanD --- .../ATen/functorch/BatchRulesActivation.cpp | 169 ----------------- .../ATen/functorch/BatchRulesBinaryOps.cpp | 3 +- .../functorch/BatchRulesDecompositions.cpp | 1 + aten/src/ATen/native/Activation.cpp | 151 +++++---------- aten/src/ATen/native/Activation.h | 4 +- aten/src/ATen/native/cpu/Activation.cpp | 69 ++----- aten/src/ATen/native/cuda/Activation.cpp | 85 --------- aten/src/ATen/native/cuda/Activation.h | 11 -- .../ATen/native/cuda/ActivationPreluKernel.cu | 159 ++-------------- aten/src/ATen/native/mkldnn/Prelu.cpp | 7 - .../ATen/native/mps/operations/Activation.mm | 93 +--------- aten/src/ATen/native/native_functions.yaml | 16 +- .../cpu/kernels/QuantizedOpKernels.cpp | 18 +- aten/src/ATen/native/quantized/cpu/qrelu.cpp | 8 +- .../check_forward_backward_compatibility.py | 3 + test/functorch/test_aotdispatch.py | 1 - test/functorch/test_ops.py | 6 - test/test_cuda.py | 2 +- tools/autograd/derivatives.yaml | 22 ++- torch/_decomp/decompositions.py | 33 ++-- torch/csrc/autograd/FunctionsManual.cpp | 172 ------------------ torch/csrc/autograd/FunctionsManual.h | 23 --- .../_internal/common_methods_invocations.py | 12 +- 23 files changed, 152 insertions(+), 916 deletions(-) diff --git a/aten/src/ATen/functorch/BatchRulesActivation.cpp b/aten/src/ATen/functorch/BatchRulesActivation.cpp index d96ab08a7e2fa9..2f93ae5b768f00 100644 --- a/aten/src/ATen/functorch/BatchRulesActivation.cpp +++ b/aten/src/ATen/functorch/BatchRulesActivation.cpp @@ -48,178 +48,9 @@ std::tuple> glu_backward_batch_rule( return std::make_tuple(res, 0); } -std::tuple> prelu_batch_rule( - const Tensor& input, optional input_bdim, - const Tensor& weight, optional weight_bdim) { - if (!weight_bdim && weight.dim() == 0) { - return std::make_tuple(at::prelu(input, weight), input_bdim); - } - - const auto input_ = moveBatchDimToFront(input, input_bdim); - auto weight_flatten = moveBatchDimToFront(weight, weight_bdim); - - const auto weight_logical_dim = rankWithoutBatchDim(weight, weight_bdim); - TORCH_CHECK(weight_logical_dim == 0 || weight_logical_dim == 1, - "prelu: Expected `weight` to be a scalar or 1D tensor, but got ndim = ", - weight_logical_dim); - - if (weight_flatten.dim() > 1) { - // for an input [N, C, ...] 
- // weight can be a non-vector but the total number of elements must be the same as C - weight_flatten = at::flatten(weight_flatten, weight_bdim.has_value() ? 1 : 0, -1); - } - - const int64_t input_logical_rank = rankWithoutBatchDim(input, input_bdim); - VmapDimVector new_shape(weight_flatten.sizes().begin(), weight_flatten.sizes().end()); - const int64_t final_size = weight_bdim ? (input_logical_rank + 1) : input_logical_rank; - new_shape.reserve(final_size); - - if (weight_flatten.dim() == 2 || !weight_bdim) { - // if weight (without batching) is not a scalar, its size must match the "channel dimension" of input. To do the - // decomposition, we pad the weight to - - // copies checks from prelu if the weight (without vmap) is not a scalar - TORCH_CHECK(input_logical_rank > 0, "Not allow zero-dim input tensor."); - - int64_t channel_size = 1; // channel_size default to 1 - if (input_logical_rank > 1) { - const auto channel_dim = input_bdim ? 2 : 1; - channel_size = input_.size(channel_dim); - } - const auto weight_num = weight_flatten.size(-1); - TORCH_CHECK(channel_size == weight_num, - "Mismatch of parameter numbers and input channel size. Found parameter numbers = ", weight_num, - " and channel size = ", channel_size, "."); - - // pads to the left so that the flattened shape matches up with the channel - if (!weight_bdim) { - new_shape.insert(new_shape.begin(), 1); - } else { - new_shape.insert(new_shape.begin() + 1, 1); - } - } - - for (int64_t i = new_shape.size(); i < final_size; i ++) { - new_shape.push_back(1); - } - TORCH_INTERNAL_ASSERT((int64_t)new_shape.size() == final_size); - const auto weight_padded = weight_flatten.view(new_shape); - auto zero_tensor = at::zeros(1, input.options()); - - // decomposes function, - auto res = at::maximum(zero_tensor, input_) + weight_padded * at::minimum(zero_tensor, input_); - return std::make_tuple(res, 0); -} - -VmapDimVector ensure_shape_with_bdim(const Tensor& input, const bool has_bdim, const int64_t batch_size) { - // helper function that get the size of input, ensuring that there's batch dim, without expanding input - if (has_bdim) { - // sad to have to copy but got garbage if tried to return an IntArrayRef and just do input.sizes() - VmapDimVector new_shape(input.sizes().begin(), input.sizes().end()); - return new_shape; - } - VmapDimVector new_shape(1, batch_size); - new_shape.reserve(input.dim() + 1); - new_shape.insert(new_shape.end(), input.sizes().begin(), input.sizes().end()); - return new_shape; -} - -VmapDimVector shape_maybe_with_bdim(const Tensor& input, const bool need_bdim, const bool has_bdim, const int64_t batch_size) { - // if need_bdim, will return the input with a guaranteed bdim. 
If not, will return the input logical size (no batch dim) - if (need_bdim) { - return ensure_shape_with_bdim(input, has_bdim, batch_size); - } else if (has_bdim) { // !need_bdim && has_bdim - VmapDimVector new_shape(input.sizes().begin() + 1, input.sizes().end()); - return new_shape; - } else { // !need_bdim && !has_bdim - VmapDimVector new_shape(input.sizes().begin(), input.sizes().end()); - return new_shape; - } -} - -std::tuple prelu_backward_batched( - const Tensor& grad_out, const Tensor& self, const Tensor& weight, - const VmapDimVector& self_grad_shape, const VmapDimVector& weight_grad_padded_shape, const VmapDimVector& weight_grad_shape) { - // helper function that produces a batched gradient for prelu using a decomposition inspired by the AOTAutograd ones - const auto input_grad_collector = at::where(self > 0, grad_out, weight * grad_out); - const auto input_grad = native::sum_to_size(input_grad_collector, self_grad_shape); - const auto weight_grad_collector = at::where(self > 0, at::zeros(1, self.options()), self * grad_out); - const auto weight_grad_collector_2 = native::sum_to_size(weight_grad_collector, weight_grad_padded_shape); - const auto weight_grad = weight_grad_collector_2.view(weight_grad_shape); - return std::make_tuple(input_grad, weight_grad); -} - -std::tuple,Tensor,optional> prelu_backward_batch_rule( - const Tensor& grad_out, optional grad_out_bdim, - const Tensor& self, optional self_bdim, - const Tensor& weight, optional weight_bdim) { - const auto batch_size = get_bdim_size3(grad_out, grad_out_bdim, self, self_bdim, weight, weight_bdim); - const auto grad_out_ = moveBatchDimToFront(grad_out, grad_out_bdim); - const auto self_ = moveBatchDimToFront(self, self_bdim); - const auto self_size_with_bdim = ensure_shape_with_bdim(self_, self_bdim.has_value(), batch_size); - if (!weight_bdim && weight.dim() == 0) { - VmapDimVector weight_grad_shape(1, batch_size); - VmapDimVector weight_grad_shape_padded(self_bdim.has_value() ? self.dim() : self.dim() + 1, 1); - weight_grad_shape_padded[0] = batch_size; - const auto grads = prelu_backward_batched(grad_out_, self_, weight, self_size_with_bdim, weight_grad_shape_padded, weight_grad_shape); - return std::make_tuple(std::get<0>(grads), 0, std::get<1>(grads), 0); - } - const auto weight_ = moveBatchDimToFront(weight, weight_bdim); - auto weight_flatten = weight_; - if (weight_flatten.dim() > 1) { - // for an input [N, C, ...] - // weight can be a non-vector but the total number of elements must be the same as C - weight_flatten = at::flatten(weight_flatten, weight_bdim.has_value() ? 1 : 0, -1); - } - - const int64_t self_logical_rank = rankWithoutBatchDim(self, self_bdim); - VmapDimVector new_shape(weight_flatten.sizes().begin(), weight_flatten.sizes().end()); - const int64_t final_size = weight_bdim ? (self_logical_rank + 1) : self_logical_rank; - new_shape.reserve(final_size); - - if (weight_flatten.dim() == 2 || !weight_bdim) { - // if weight (without batching) is not a scalar, its size must match the "channel dimension" of input. To do the - // decomposition, we pad the weight to - - // copies checks from prelu if the weight (without vmap) is not a scalar - TORCH_CHECK(self_logical_rank > 0, "Not allow zero-dim input tensor."); - - int64_t channel_size = 1; // channel_size default to 1 - if (self_logical_rank > 1) { - channel_size = self_.size(self_bdim.has_value() ? 
2 : 1); - } - - const auto weight_num = weight_flatten.size(-1); - TORCH_CHECK(channel_size == weight_num, - "Mismatch of parameter numbers and input channel size. Found parameter numbers = ", weight_num, - " and channel size = ", channel_size, "."); - - // pads to the left so that the flattened shape matches up with the channel - if (!weight_bdim) { - new_shape.insert(new_shape.begin(), 1); - } else { - new_shape.insert(new_shape.begin() + 1, 1); - } - } - - for (int64_t i = new_shape.size(); i < final_size; i ++) { - new_shape.push_back(1); - } - // weight grad does not depend on weight values. It is batched iff grad_out or self are batched - const auto weight_grad_is_batched = grad_out_bdim.has_value() || self_bdim.has_value(); - - const auto weight_padded = weight_flatten.view(new_shape); - const auto weight_grad_shape = shape_maybe_with_bdim(weight_, weight_grad_is_batched, weight_bdim.has_value(), batch_size); - const auto weight_padded_grad_shape = shape_maybe_with_bdim(weight_padded, weight_grad_is_batched, weight_bdim.has_value(), batch_size); - - const auto grads = prelu_backward_batched(grad_out_, self_, weight_padded, self_size_with_bdim, weight_padded_grad_shape, weight_grad_shape); - return std::make_tuple(std::get<0>(grads), 0, std::get<1>(grads), (weight_grad_is_batched ? optional(0) : nullopt)); -} TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { VMAP_SUPPORT(glu_backward, glu_backward_batch_rule); VMAP_SUPPORT(glu, glu_batch_rule); - VMAP_SUPPORT(prelu, prelu_batch_rule) - VMAP_SUPPORT(prelu_backward, prelu_backward_batch_rule) } }} // namespace at::functorch diff --git a/aten/src/ATen/functorch/BatchRulesBinaryOps.cpp b/aten/src/ATen/functorch/BatchRulesBinaryOps.cpp index db601d3b0b8f11..cc478faef7c513 100644 --- a/aten/src/ATen/functorch/BatchRulesBinaryOps.cpp +++ b/aten/src/ATen/functorch/BatchRulesBinaryOps.cpp @@ -428,7 +428,8 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { BINARY_POINTWISE(hardtanh_backward); BINARY_POINTWISE(hardshrink_backward); BINARY_POINTWISE(hardswish_backward); - // BINARY_POINTWISE(infinitely_differentiable_gelu_backward); + BINARY_POINTWISE(_prelu_kernel); + VARIADIC_BDIMS_BOXED(_prelu_kernel_backward); BINARY_POINTWISE(leaky_relu_backward); BINARY_POINTWISE(logit_backward); VMAP_SUPPORT(log_sigmoid_backward, log_sigmoid_backward_batch_rule); diff --git a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp index 65d4b80a5ead16..b8e8069cfc8d0f 100644 --- a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp +++ b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp @@ -198,6 +198,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) { OP_DECOMPOSE(resolve_neg); OP_DECOMPOSE(row_stack); OP_DECOMPOSE(rrelu); + OP_DECOMPOSE(prelu); OP_DECOMPOSE2(softmax, int); OP_DECOMPOSE(special_gammainc); OP_DECOMPOSE(special_gammaincc); diff --git a/aten/src/ATen/native/Activation.cpp b/aten/src/ATen/native/Activation.cpp index 4c1d3cc215a7c5..c845e6755f835e 100644 --- a/aten/src/ATen/native/Activation.cpp +++ b/aten/src/ATen/native/Activation.cpp @@ -13,6 +13,7 @@ #include #include +#include #if AT_MKLDNN_ENABLED() #include #include @@ -52,8 +53,10 @@ #include #include #include -#include #include +#include +#include +#include #include #include #include @@ -75,6 +78,7 @@ #include #include +#include #endif namespace at { @@ -262,8 +266,8 @@ DEFINE_DISPATCH(silu_stub); DEFINE_DISPATCH(silu_backward_stub); DEFINE_DISPATCH(mish_stub); DEFINE_DISPATCH(mish_backward_stub); 
-DEFINE_DISPATCH(prelu_cpu_stub); -DEFINE_DISPATCH(prelu_backward_cpu_stub); +DEFINE_DISPATCH(prelu_stub); +DEFINE_DISPATCH(prelu_backward_stub); TORCH_IMPL_FUNC(elu_out) ( const Tensor& self, const Scalar& alpha, const Scalar& scale, const Scalar& input_scale, const Tensor& result @@ -676,124 +680,63 @@ TORCH_IMPL_FUNC(threshold_backward_out)(const Tensor& grad, const Tensor& self, threshold_stub(device_type(), *this, threshold, 0); } -Tensor prelu_cpu(const Tensor& self, const Tensor& weight_) { - int64_t weight_num = weight_.numel(); - Tensor result = at::empty_like(self, self.suggest_memory_format()); +Tensor prelu(const Tensor& self, const Tensor& weight_) { TORCH_INTERNAL_ASSERT(weight_.defined()); - - if (weight_num != 1) { - int64_t input_ndim = self.dim(); - TORCH_CHECK(input_ndim > 0, "Not allow zero-dim input tensor."); - - int64_t channel_size = 1; // channel_size default to 1 - if (input_ndim > 1) { - channel_size = self.size(1); // channel is the 2nd dim of input - } - TORCH_CHECK(channel_size == weight_num, - "Mismatch of parameter numbers and input channel size. Found parameter numbers = ", weight_num, + auto self_dim = self.dim(); + TORCH_CHECK(self.scalar_type() == weight_.scalar_type(), + "prelu: Type promoting not supported. Got ", + self.scalar_type(), " and ", weight_.scalar_type()); + if (weight_.sym_numel() != 1) { + TORCH_CHECK(self_dim > 0, "Not allow zero-dim input tensor."); + + auto channel_size = self_dim > 1 ? self.sym_size(1) : 1; // channel_size default to 1 + TORCH_CHECK(channel_size == weight_.sym_numel(), + "Mismatch of parameter numbers and input channel size. Found parameter numbers = ", weight_.numel(), " and channel size = ", channel_size, "."); } - const int64_t ndim = self.dim(); - // Helper to convert 1d tensors or scalar tensor to an nd tensor that broadcasts with input - // All elements go into the channel dimension - DimVector sizes(ndim, 1), strides(ndim, 0); - auto as_nd = [&](const Tensor& t) { - TORCH_CHECK( - t.dim() == 1 || t.dim() == 0, - "prelu: Expected `weight` to be a scalar or 1D tensor, but got: ndim = ", t.dim()); - if (ndim >= 2) { - sizes[1] = t.dim() == 1 ? t.size(0) : 1; - strides[1] = t.dim() == 1 ? 
t.stride(0) : 0; - return t.as_strided(sizes, strides); + TORCH_CHECK( + weight_.dim() <= 1, + "prelu: Expected `weight` to be a scalar or 1D tensor, but got: ndim = ", weight_.dim()); + // Adjust weight to broadcast over self and have weight.ndim == self.ndim + auto weight = weight_; + if (self_dim != weight.dim()) { + SymDimVector dim_w(self_dim, 1); + if (self_dim > 1) { + dim_w[1] = weight_.sym_numel(); } - return t.as_strided(sizes, strides); - }; - Tensor w; - if (self.scalar_type() == ScalarType::BFloat16) { - auto w_bf16 = at::empty(weight_.sizes(), weight_.options().dtype(ScalarType::BFloat16)); - w_bf16.copy_(weight_); - w = as_nd(w_bf16); - } else { - w = as_nd(weight_); + // This will always be a view in CPU/CUDA, but some backends + // like MKLDNN do not support views + weight = weight.reshape_symint(dim_w); } + return at::_prelu_kernel(self, weight); +} + +Tensor _prelu_kernel(const Tensor& self, const Tensor& weight) { + // Weight broadcasts over self and they have the same dtype + auto result = at::empty_like(self); auto iter = TensorIteratorConfig() .add_output(result) .add_input(self) - .add_input(w) + .add_input(weight) .build(); - prelu_cpu_stub(iter.device_type(), iter); + prelu_stub(iter.device_type(), iter); return result; } -std::tuple prelu_backward_cpu(const Tensor& grad_out_, const Tensor& self, const Tensor& weight_) { - int64_t weight_num = weight_.numel(); - - Tensor input_grad = at::empty_like(self, self.suggest_memory_format()); - Tensor weight_grad = at::empty_like(weight_, at::MemoryFormat::Contiguous); - Tensor weight_grad_collector = at::empty_like(self, at::MemoryFormat::Contiguous); - - if (weight_num != 1) { - int64_t input_ndim = self.dim(); - TORCH_CHECK(input_ndim > 0, "Not allow zero-dim input tensor."); - - int64_t channel_size = 1; // channel_size default to 1 - if (input_ndim > 1) { - channel_size = self.size(1); // channel is the 2nd dim of input - } - TORCH_CHECK(channel_size == weight_num, - "Mismatch of parameter numbers and input channel size. Found parameter numbers = ", weight_num, - " and channel size = ", channel_size, "."); - } - - const int64_t ndim = self.dim(); - // Helper to convert 1d tensor or scalar tensor to an nd tensor that broadcasts with input - // All elements go into the channel dimension - DimVector sizes(ndim, 1), strides(ndim, 0); - auto as_nd = [&](const Tensor& t) { - TORCH_INTERNAL_ASSERT(t.defined() && (t.dim() == 1 || t.dim() == 0)); - if (ndim >= 2) { - sizes[1] = t.dim() == 1 ? t.sizes()[0] : 1; - strides[1] = t.dim() == 1 ? t.strides()[0] : 0; - return t.as_strided(sizes, strides); - } - return t.as_strided(sizes, strides); - }; - Tensor w; - if (self.scalar_type() == ScalarType::BFloat16) { - auto w_bf16 = at::empty(weight_.sizes(), weight_.options().dtype(ScalarType::BFloat16)); - w_bf16.copy_(weight_); - w = weight_.defined() ? as_nd(w_bf16) : - at::detail::scalar_tensor_static(1, self.scalar_type(), kCPU); - } else { - w = weight_.defined() ? 
as_nd(weight_) : - at::detail::scalar_tensor_static(1, self.scalar_type(), kCPU); - } - +std::tuple _prelu_kernel_backward(const Tensor& grad_out, const Tensor& self, const Tensor& weight) { + Tensor grad_self = at::empty({0}, self.options()); + Tensor grad_weight = at::empty({0}, weight.options()); auto iter = TensorIteratorConfig() - .add_output(input_grad) - .add_output(weight_grad_collector) + .add_output(grad_self) + .add_output(grad_weight) .add_input(self) - .add_input(grad_out_) - .add_input(w) + .add_input(weight) + .add_input(grad_out) .build(); - - prelu_backward_cpu_stub(iter.device_type(), iter); - - if (weight_num == 1) { - weight_grad.fill_(weight_grad_collector.sum()); - } else { - // update weight_grad - std::vector reduce_dims; - int64_t input_ndim = self.dim(); - reduce_dims.push_back(0); - if (input_ndim > 2) { - for(int64_t i = 2; i < input_ndim; i++) reduce_dims.push_back(i); - } - weight_grad = weight_grad_collector.sum(reduce_dims); - } - return std::tuple{input_grad, weight_grad}; + prelu_backward_stub(iter.device_type(), iter); + return {grad_self, grad_weight}; } Tensor infinitely_differentiable_gelu_backward( diff --git a/aten/src/ATen/native/Activation.h b/aten/src/ATen/native/Activation.h index 64f6c6a6dceb7f..4e85d5bfd71f3b 100644 --- a/aten/src/ATen/native/Activation.h +++ b/aten/src/ATen/native/Activation.h @@ -84,8 +84,8 @@ DECLARE_DISPATCH(structured_activation_fn, silu_stub); DECLARE_DISPATCH(structured_activation_backward_fn, silu_backward_stub); DECLARE_DISPATCH(structured_activation_fn, mish_stub); DECLARE_DISPATCH(activation_backward_fn, mish_backward_stub); -DECLARE_DISPATCH(activation_fn, prelu_cpu_stub); -DECLARE_DISPATCH(activation_backward_fn, prelu_backward_cpu_stub); +DECLARE_DISPATCH(activation_fn, prelu_stub); +DECLARE_DISPATCH(activation_backward_fn, prelu_backward_stub); } // namespace native diff --git a/aten/src/ATen/native/cpu/Activation.cpp b/aten/src/ATen/native/cpu/Activation.cpp index 728ea62f1898fb..fc7fcafdb0d7bc 100644 --- a/aten/src/ATen/native/cpu/Activation.cpp +++ b/aten/src/ATen/native/cpu/Activation.cpp @@ -5,10 +5,12 @@ #include + #include #include #include +#include #include #include #include @@ -1202,63 +1204,30 @@ void mish_backward_kernel(TensorIterator& iter) { } } -void prelu_cpu_kernel(TensorIterator& iter) { - if (iter.common_dtype() == kBFloat16) { - auto zero_vec = Vectorized((float)(0)); - auto one_vec = Vectorized((float)(1)); - cpu_kernel_vec( - iter, - [=](BFloat16 input, BFloat16 weight) -> BFloat16 { - return (float(input) > float(0)) ? float(input) : float(weight) * float(input); - }, - [=](Vectorized input, Vectorized weight) -> Vectorized { - Vectorized input0, input1; - Vectorized weight0, weight1; - std::tie(input0, input1) = convert_bfloat16_float(input); - std::tie(weight0, weight1) = convert_bfloat16_float(weight); - - auto res0 = input0 * (Vectorized::blendv(weight0, one_vec, input0 > zero_vec)); - auto res1 = input1 * (Vectorized::blendv(weight1, one_vec, input1 > zero_vec)); - return convert_float_bfloat16(res0, res1); - }); - } else { - AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "prelu_cpu", [&]() { +void prelu_kernel(TensorIterator& iter) { + AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, iter.dtype(), "prelu_cpu", [&]() { using Vec = Vectorized; - auto zero_vec = Vec((scalar_t)(0)); - auto one_vec = Vec((scalar_t)(1)); cpu_kernel_vec( iter, - [=](scalar_t input, scalar_t weight) { + [](scalar_t input, scalar_t weight) { return (input > scalar_t(0)) ? 
input : weight * input; }, - [=](Vec input, Vec weight) { - auto r = Vec::blendv(weight, one_vec, input > zero_vec); - return input * r; + [](Vec input, Vec weight) { + return Vec::blendv(weight * input, input, input > Vec(0)); }); - }); - } + }); } -void prelu_backward_cpu_kernel(TensorIterator& iter) { - if (iter.common_dtype() == kBFloat16) { - cpu_kernel_multiple_outputs( - iter, - [=](BFloat16 input, BFloat16 grad_out, BFloat16 weight) -> std::tuple { - float input_grad = (float(input) > float(0)) ? float(grad_out) : float(weight) * float(grad_out); - float weight_grad_collector = (float(input) > float(0)) ? float(0) : float(input) * float(grad_out); - return std::tuple(input_grad, weight_grad_collector); +void prelu_backward_kernel(TensorIterator& iter) { + AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, iter.dtype(), "prelu_backward_cpu", [&]() { + cpu_kernel_multiple_outputs(iter, + [](scalar_t input, scalar_t weight, scalar_t grad) -> std::tuple { + auto mask = input > scalar_t{0}; + auto grad_input = mask ? grad : weight * grad; + auto grad_weight = mask ? scalar_t{0} : input * grad; + return {grad_input, grad_weight}; }); - } else { - AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "prelu_backward_cpu", [&]() { - cpu_kernel_multiple_outputs( - iter, - [=](scalar_t input, scalar_t grad_out, scalar_t weight) -> std::tuple { - scalar_t input_grad = (input > scalar_t(0)) ? grad_out : weight * grad_out; - scalar_t weight_grad_collector = (input > scalar_t(0)) ? scalar_t(0) : input * grad_out; - return std::tuple(input_grad, weight_grad_collector); - }); - }); - } + }); } } // namespace @@ -1289,8 +1258,8 @@ REGISTER_DISPATCH(silu_stub, &silu_kernel); REGISTER_DISPATCH(silu_backward_stub, &silu_backward_kernel); REGISTER_DISPATCH(mish_stub, &mish_kernel); REGISTER_DISPATCH(mish_backward_stub, &mish_backward_kernel); -REGISTER_DISPATCH(prelu_cpu_stub, &prelu_cpu_kernel); -REGISTER_DISPATCH(prelu_backward_cpu_stub, &prelu_backward_cpu_kernel); +REGISTER_DISPATCH(prelu_stub, &prelu_kernel); +REGISTER_DISPATCH(prelu_backward_stub, &prelu_backward_kernel); } // namespace native } // namespace at diff --git a/aten/src/ATen/native/cuda/Activation.cpp b/aten/src/ATen/native/cuda/Activation.cpp index 31926b353b4a3e..d5e79e9643d187 100644 --- a/aten/src/ATen/native/cuda/Activation.cpp +++ b/aten/src/ATen/native/cuda/Activation.cpp @@ -18,8 +18,6 @@ #include #include #include -#include -#include #endif namespace at { namespace native { @@ -95,89 +93,6 @@ std::tuple log_sigmoid_forward_cuda(const Tensor& input) { return std::forward_as_tuple(result, buffer); } -// ----------------------------------- -// prelu forward -// ----------------------------------- - -Tensor prelu_cuda(const Tensor& self, const Tensor& weight_) { - TORCH_CHECK(self.is_cuda()); - TORCH_CHECK(weight_.is_cuda()); - - auto input = self.contiguous(); - auto weight = weight_.contiguous(); - - TORCH_CHECK(input.is_contiguous()); - TORCH_CHECK(weight.is_contiguous()); - - int64_t weight_num = weight.numel(); - int64_t weight_dim = weight.dim(); - Tensor result = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); - - TORCH_CHECK(weight_dim == 0 || weight_dim == 1, - "prelu: Expected `weight` to be a scalar or 1D tensor, but got: ndim = ", - weight_dim); - - // case1: shared weight for all channels - if (weight_num == 1) { - auto iter = TensorIterator::unary_op(result, input); - launch_prelu_cuda_kernel_share_weights(iter, weight); - } - else { // case2: multiple weights, one for each channel - 
launch_prelu_cuda_kernel_multi_weights(result, input, weight); - } - return result; -} - -// ----------------------------------- -// prelu backward -// ----------------------------------- - -std::tuple prelu_backward_cuda(const Tensor& grad_out_, const Tensor& self, const Tensor& weight_) { - TORCH_CHECK(grad_out_.is_cuda()); - TORCH_CHECK(self.is_cuda()); - TORCH_CHECK(weight_.is_cuda()); - - auto input = self.contiguous(); - auto grad_out = grad_out_.contiguous(); - auto weight = weight_.contiguous(); - - TORCH_CHECK(input.is_contiguous()); - TORCH_CHECK(weight.is_contiguous()); - TORCH_CHECK(grad_out.is_contiguous()); - - int64_t weight_num = weight.numel(); - auto dims = input.dim(); - Tensor input_grad = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); - Tensor weight_grad = at::empty_like(weight, LEGACY_CONTIGUOUS_MEMORY_FORMAT); - Tensor weight_grad_collector = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); - // case1: shared parameter for all channels - if (weight_num == 1) { - at::TensorIterator iter = TensorIteratorConfig() - .add_output(input_grad) - .add_output(weight_grad_collector) - .add_input(input) - .add_input(grad_out) - .build(); - - launch_prelu_cuda_backward_kernel_share_weights(iter, weight); - weight_grad.fill_(weight_grad_collector.sum()); - } - else { // case2: multiple parameters, one for each channel - launch_prelu_cuda_backward_kernel_multi_weights( - input, weight, grad_out, input_grad, weight_grad_collector); - // update weight_grad - std::vector reduce_dims; - reduce_dims.push_back(0); - if (dims > 2) { - for (const auto i : c10::irange(2, dims)) { - reduce_dims.push_back(i); - } - } - weight_grad = weight_grad_collector.sum(reduce_dims); - } - return std::tuple{input_grad, weight_grad}; -} - TORCH_IMPL_FUNC(gelu_out_cuda) ( const Tensor& /*self*/, c10::string_view approximate, const Tensor& /*result*/ ) { diff --git a/aten/src/ATen/native/cuda/Activation.h b/aten/src/ATen/native/cuda/Activation.h index 5fc52ff257ce19..5fbfe0d2c65569 100644 --- a/aten/src/ATen/native/cuda/Activation.h +++ b/aten/src/ATen/native/cuda/Activation.h @@ -14,17 +14,6 @@ void launch_glu_backward_kernel(const TensorIteratorBase& iter, void launch_log_sigmoid_forward_kernel(TensorIteratorBase& iter); -void launch_prelu_cuda_kernel_share_weights( - TensorIteratorBase &iter, const TensorBase &weight); -void launch_prelu_cuda_kernel_multi_weights( - const TensorBase &result, const TensorBase &input, const TensorBase &weight); - -void launch_prelu_cuda_backward_kernel_share_weights( - TensorIteratorBase &iter, const TensorBase &weight); -void launch_prelu_cuda_backward_kernel_multi_weights( - const TensorBase &input, const TensorBase &weight, const TensorBase &grad_out, - const TensorBase &input_grad, const TensorBase &weight_grad_collector); - void GeluCUDAKernelImpl(TensorIteratorBase& it, GeluType approximate); void GeluBackwardCUDAKernelImpl(TensorIteratorBase& it, GeluType approximate); diff --git a/aten/src/ATen/native/cuda/ActivationPreluKernel.cu b/aten/src/ATen/native/cuda/ActivationPreluKernel.cu index 0d8f09714698ee..7ae748599da083 100644 --- a/aten/src/ATen/native/cuda/ActivationPreluKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationPreluKernel.cu @@ -20,156 +20,31 @@ namespace at { namespace native { // ----------------------------------- -// prelu forward +// prelu // ----------------------------------- -void launch_prelu_cuda_kernel_share_weights(TensorIteratorBase &iter, const TensorBase &weight) { - AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, 
iter.input_dtype(), "prelu_cuda", [&] { - const auto *weight_data = weight.data_ptr(); - at::native::gpu_kernel(iter, - [weight_data] GPU_LAMBDA (scalar_t input_val) { - return (input_val > 0) ? input_val : *weight_data * input_val; - }); - }); -} - -template -__global__ void prelu_cuda_kernel_multi_weights( - scalar_t* result_data, - const scalar_t* input_data, - const scalar_t* weight_data, - int64_t input_stride0, - int64_t input_stride1, - int64_t input_numel) { - - int64_t linearId = blockIdx.x * blockDim.x + threadIdx.x; - if (linearId >= input_numel) return; - - // multiply values at each channel with weight[channel_index] - int64_t channel = (linearId % input_stride0) / input_stride1; - scalar_t input_data_val = input_data[linearId]; - result_data[linearId] = (input_data_val > 0) ? input_data_val : weight_data[channel] * input_data_val; -} - -void launch_prelu_cuda_kernel_multi_weights( - const TensorBase &result, const TensorBase &input, const TensorBase &weight) { - int64_t input_ndim = input.dim(); - TORCH_CHECK(input_ndim > 0, "Not allow zero-dim input tensor."); - - int64_t channel_size = 1; // channel_size default to 1 - int64_t input_stride0 = 1, input_stride1 = 1; - - if (input_ndim > 1) { - channel_size = input.size(1); // channel is the 2nd dim of input - auto strides = input.strides(); - input_stride0 = strides[0]; - input_stride1 = strides[1]; - } - const int64_t weight_num = weight.numel(); - TORCH_CHECK(channel_size == weight_num, - "Mismatch of parameter numbers and input channel size. Found parameter numbers = ", weight_num, - " and channel size = ", channel_size, "."); - - // config to run cuda kernel - int64_t input_numel = input.numel(); - const dim3 block = dim3(std::min(static_cast(cuda::getApplyBlock().x), input_numel)); - dim3 grid; - int curDevice = -1; - cudaGetDevice(&curDevice); - cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); - TORCH_CHECK(cuda::getApplyGrid(input_numel, grid, curDevice), "prelu: input too large or too many dimensions"); - - AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, input.scalar_type(), "prelu_cuda", [&] { - prelu_cuda_kernel_multi_weights - <<>>( - result.data_ptr(), - input.data_ptr(), - weight.data_ptr(), - input_stride0, - input_stride1, - input_numel); - C10_CUDA_KERNEL_LAUNCH_CHECK(); +void prelu_kernel(TensorIterator &iter) { + AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "prelu_cuda", [&] { + gpu_kernel(iter, + [] GPU_LAMBDA (scalar_t input, scalar_t weight) -> scalar_t { + return (input > 0) ? input : weight * input; + }); }); } -// ----------------------------------- -// prelu backward -// ----------------------------------- -void launch_prelu_cuda_backward_kernel_share_weights( - TensorIteratorBase &iter, const TensorBase &weight) { - // N.B. `std::tuple` does not support `::operator=` on device code. - AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, iter.input_dtype(), "prelu_backward_cuda", [&] { - const auto *weight_data = weight.data_ptr(); - gpu_kernel_multiple_outputs(iter, [=] GPU_LAMBDA (scalar_t input, scalar_t grad_out) -> thrust::tuple { - scalar_t input_grad = input > 0 ? grad_out : (*weight_data) * grad_out; - scalar_t weight_grad_collector = input > 0 ? 
scalar_t(0) : input * grad_out; - return {input_grad, weight_grad_collector}; +void prelu_backward_kernel(TensorIterator &iter) { + AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "prelu_backward_cuda", [&] { + gpu_kernel_multiple_outputs(iter, + [] GPU_LAMBDA (scalar_t input, scalar_t weight, scalar_t grad) -> thrust::tuple { + auto mask = input > 0; + auto grad_input = mask ? grad : weight * grad; + auto grad_weight = mask ? scalar_t{0} : input * grad; + return {grad_input, grad_weight}; }); }); } -template -__global__ void prelu_cuda_backward_kernel_multi_weights( - const scalar_t* input_data, - const scalar_t* weight_data, - const scalar_t* grad_out_data, - scalar_t* input_grad_data, - scalar_t* weight_grad_collector, - int64_t input_stride0, - int64_t input_stride1, - int64_t input_numel) { - - int64_t linearId = blockIdx.x * blockDim.x + threadIdx.x; - if (linearId >= input_numel) return; - int64_t channel = (linearId % input_stride0) / input_stride1; - scalar_t input_data_val = input_data[linearId]; - scalar_t grad_out_data_val = grad_out_data[linearId]; - input_grad_data[linearId] = (input_data_val > 0) ? grad_out_data_val : weight_data[channel] * grad_out_data_val; - weight_grad_collector[linearId] = (input_data_val > 0) ? scalar_t(0) : input_data_val * grad_out_data_val; -} - -void launch_prelu_cuda_backward_kernel_multi_weights( - const TensorBase &input, const TensorBase &weight, const TensorBase &grad_out, - const TensorBase &input_grad, const TensorBase &weight_grad_collector) { - int64_t input_ndim = input.dim(); - TORCH_CHECK(input_ndim > 0, "Not allow zero-dim input tensor."); - - int64_t channel_size = 1; // channel_size default to 1 - int64_t input_stride0 = 1, input_stride1 = 1; - - if (input_ndim > 1) { - channel_size = input.size(1); // channel is the 2nd dim of input - auto strides = input.strides(); - input_stride0 = strides[0]; - input_stride1 = strides[1]; - } - const int64_t weight_num = weight.numel(); - TORCH_CHECK(channel_size == weight_num, - "Mismatch of parameter numbers and input channel size. 
Found parameter numbers = ", weight_num, - " and channel size = ", channel_size, "."); - - // config to run cuda kernel - int64_t input_numel = input.numel(); - const dim3 block = dim3(std::min(static_cast(cuda::getApplyBlock().x), input_numel)); - dim3 grid; - int curDevice = -1; - cudaGetDevice(&curDevice); - cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); - TORCH_CHECK(cuda::getApplyGrid(input_numel, grid, curDevice), "prelu_backward_cuda: input too large or too many dimensions"); - - AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, input.scalar_type(), "prelu_backward_cuda", [&] { - prelu_cuda_backward_kernel_multi_weights - <<>>( - input.data_ptr(), - weight.data_ptr(), - grad_out.data_ptr(), - input_grad.data_ptr(), - weight_grad_collector.data_ptr(), - input_stride0, - input_stride1, - input_numel); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - }); -} +REGISTER_DISPATCH(prelu_stub, &prelu_kernel); +REGISTER_DISPATCH(prelu_backward_stub, &prelu_backward_kernel); } // namespace native } // namespace at diff --git a/aten/src/ATen/native/mkldnn/Prelu.cpp b/aten/src/ATen/native/mkldnn/Prelu.cpp index dc7d239da7b68e..0a17d0384aa83c 100644 --- a/aten/src/ATen/native/mkldnn/Prelu.cpp +++ b/aten/src/ATen/native/mkldnn/Prelu.cpp @@ -30,13 +30,6 @@ Tensor mkldnn_prelu(const Tensor& input, const Tensor& weight) { "mkldnn_relu: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq"); } - int64_t weight_num = weight.numel(); - if (weight_num != 1) { - int64_t channel_size = input.dim() > 1 ? input.size(1) : 1; - TORCH_CHECK(channel_size == weight_num, - "Mismatch of parameter numbers and input channel size. Found parameter numbers = ", weight_num, - " and channel size = ", channel_size, "."); - } const ideep::tensor& x = itensor_from_mkldnn(input); const ideep::tensor& w = itensor_from_tensor(weight); diff --git a/aten/src/ATen/native/mps/operations/Activation.mm b/aten/src/ATen/native/mps/operations/Activation.mm index 9d024eb73ae57b..5f38ac72c8d223 100644 --- a/aten/src/ATen/native/mps/operations/Activation.mm +++ b/aten/src/ATen/native/mps/operations/Activation.mm @@ -1597,7 +1597,6 @@ Tensor glu_backward_mps (const Tensor& grad_output, Tensor prelu_mps(const Tensor& self, const Tensor& weight_) { using namespace mps; - int64_t weight_num = weight_.numel(); Tensor result = at::empty_like(self, self.suggest_memory_format()); TORCH_INTERNAL_ASSERT(weight_.defined()); @@ -1605,31 +1604,6 @@ Tensor prelu_mps(const Tensor& self, const Tensor& weight_) { return result; } - TORCH_CHECK( - weight_.dim() == 1 || weight_.dim() == 0, - "prelu: Expected `weight` to be a scalar or 1D tensor, but got ndim = ", weight_.dim() - ); - - int64_t input_ndim = self.dim(); - NSMutableArray * expand_dims = [NSMutableArray new]; - - if (weight_num != 1) { - TORCH_CHECK(input_ndim > 0, "Not allow zero-dim input tensor."); - - int64_t channel_size = 1; // channel_size default to 1 - if (input_ndim > 1) { - channel_size = self.size(1); // channel is the 2nd dim of input - } - TORCH_CHECK(channel_size == weight_num, - "Mismatch of parameter numbers and input channel size. 
Found parameter numbers = ", weight_num, - " and channel size = ", channel_size, "."); - - for (const auto i : c10::irange(input_ndim)) { - if (i == 1) continue; - [expand_dims addObject:[NSNumber numberWithInt:i]]; - } - } - struct CachedGraph : public MPSCachedGraph { CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} @@ -1643,8 +1617,7 @@ Tensor prelu_mps(const Tensor& self, const Tensor& weight_) { MPSStream* stream = getCurrentMPSStream(); @autoreleasepool { - NSString* expand_dims_key = [[expand_dims valueForKey:@"description"] componentsJoinedByString:@","]; - string key = "prelu_mps:" + getTensorsStringKey({self, weight_}) + string([expand_dims_key UTF8String]); + string key = "prelu_mps:" + getTensorsStringKey({self, weight_}); CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); if(!cachedGraph) { @@ -1671,18 +1644,9 @@ Tensor prelu_mps(const Tensor& self, const Tensor& weight_) { truePredicateTensor: inputTensor falsePredicateTensor: zeroTensor name: nil]; - if (weight_num != 1) { - MPSGraphTensor *expandedWeightTensor = [mpsGraph expandDimsOfTensor:weightTensor - axes:expand_dims - name:nil]; - weightedTensor = [mpsGraph multiplicationWithPrimaryTensor:weightedTensor - secondaryTensor:expandedWeightTensor - name:nil]; - }else{ - weightedTensor = [mpsGraph multiplicationWithPrimaryTensor:weightedTensor - secondaryTensor:weightTensor - name:nil]; - } + weightedTensor = [mpsGraph multiplicationWithPrimaryTensor:weightedTensor + secondaryTensor:weightTensor + name:nil]; MPSGraphTensor *outputTensor = [mpsGraph additionWithPrimaryTensor:reluTensor secondaryTensor:weightedTensor name:nil]; @@ -1715,35 +1679,9 @@ Tensor prelu_mps(const Tensor& self, const Tensor& weight_) { std::tuple prelu_backward_mps(const Tensor& grad_output, const Tensor& self, const Tensor& weight_) { using namespace mps; - int64_t weight_num = weight_.numel(); - NSMutableArray * reduce_dims = [NSMutableArray new]; Tensor grad_input = at::empty_like(self, self.suggest_memory_format()); Tensor weight_grad = at::empty_like(weight_, at::MemoryFormat::Contiguous); - TORCH_CHECK( - weight_.dim() == 1 || weight_.dim() == 0, - "prelu: Expected `weight` to be a scalar or 1D tensor, but got ndim = ", weight_.dim() - ); - - if (weight_num != 1) { - int64_t input_ndim = self.dim(); - TORCH_CHECK(input_ndim > 0, "Not allow zero-dim input tensor."); - - int64_t channel_size = 1; // channel_size default to 1 - if (input_ndim > 1) { - channel_size = self.size(1); // channel is the 2nd dim of input - } - TORCH_CHECK(channel_size == weight_num, - "Mismatch of parameter numbers and input channel size. Found parameter numbers = ", weight_num, - " and channel size = ", channel_size, "." 
- ); - - for (const auto i : c10::irange(input_ndim)) { - if (i == 1) continue; - [reduce_dims addObject:[NSNumber numberWithInt:i]]; - } - } - struct CachedGraph : public MPSCachedGraph { CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} @@ -1759,8 +1697,7 @@ Tensor prelu_mps(const Tensor& self, const Tensor& weight_) { MPSStream* stream = getCurrentMPSStream(); @autoreleasepool { - NSString* reduce_dims_key = [[reduce_dims valueForKey:@"description"] componentsJoinedByString:@","]; - string key = "prelu_backward_mps:" + getTensorsStringKey({grad_output, self, weight_}) + ":" + string([reduce_dims_key UTF8String]); + string key = "prelu_backward_mps:" + getTensorsStringKey({grad_output, self, weight_}); CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); if(!cachedGraph) { @@ -1781,19 +1718,9 @@ Tensor prelu_mps(const Tensor& self, const Tensor& weight_) { MPSGraphTensor *zeroTensor = [mpsGraph constantWithScalar: 0.0 shape:@[@1] dataType: inputTensor.dataType]; - MPSGraphTensor* weightedGradOutputTensor = nil; - if (weight_num != 1) { - MPSGraphTensor *expandedWeightTensor = [mpsGraph expandDimsOfTensor:weightTensor - axes:reduce_dims - name:nil]; - weightedGradOutputTensor = [mpsGraph multiplicationWithPrimaryTensor:expandedWeightTensor - secondaryTensor:gradOutputTensor - name:nil]; - } else { - weightedGradOutputTensor = [mpsGraph multiplicationWithPrimaryTensor:weightTensor - secondaryTensor:gradOutputTensor - name:nil]; - } + MPSGraphTensor* weightedGradOutputTensor = [mpsGraph multiplicationWithPrimaryTensor:weightTensor + secondaryTensor:gradOutputTensor + name:nil]; MPSGraphTensor* inputGradOutputTensor = [mpsGraph multiplicationWithPrimaryTensor:inputTensor secondaryTensor:gradOutputTensor name:nil]; @@ -1808,10 +1735,6 @@ Tensor prelu_mps(const Tensor& self, const Tensor& weight_) { truePredicateTensor: zeroTensor falsePredicateTensor: inputGradOutputTensor name: nil]; - weightedGradTensor = [mpsGraph reductionSumWithTensor:weightedGradTensor - axes:reduce_dims - name:nil]; - newCachedGraph->gradOutputTensor_ = gradOutputTensor; newCachedGraph->inputTensor_ = inputTensor; newCachedGraph->weightTensor_ = weightTensor; diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 71fbc6f5f2c2cc..ddecf3160f6f82 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -4654,22 +4654,20 @@ - func: prelu(Tensor self, Tensor weight) -> Tensor variants: function, method + autogen: prelu.out + +- func: _prelu_kernel(Tensor self, Tensor weight) -> Tensor dispatch: + CPU, CUDA: _prelu_kernel + QuantizedCPU: _prelu_kernel_quantized_cpu MkldnnCPU: mkldnn_prelu - CPU: prelu_cpu - CUDA: prelu_cuda MPS: prelu_mps - QuantizedCPU: prelu_quantized_cpu - autogen: prelu.out -- func: prelu_backward(Tensor grad_output, Tensor self, Tensor weight) -> (Tensor, Tensor) - variants: function, method +- func: _prelu_kernel_backward(Tensor grad_output, Tensor self, Tensor weight) -> (Tensor, Tensor) dispatch: + CPU, CUDA: _prelu_kernel_backward MkldnnCPU: mkldnn_prelu_backward - CPU: prelu_backward_cpu - CUDA: prelu_backward_cuda MPS: prelu_backward_mps - autogen: prelu_backward.out - func: gelu.out(Tensor self, *, str approximate='none', Tensor(a!) out) -> Tensor(a!) 
structured: True diff --git a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp index a1f8f0d7c24579..aabd980c9f00c2 100644 --- a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp +++ b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp @@ -680,14 +680,18 @@ static void qprelu_out_kernel(Tensor& out, int64_t input_ndim = qx.dim(); TORCH_CHECK(input_ndim > 0, "qprelu: zero-dim input tensor is not allowed."); - // Weight should be a 1d or scalar tensor - // Reshape it to an nd tensor that broadcasts with input - // All elements go into the channel dimension - DimVector sizes(input_ndim, 1); - if (input_ndim > 1) { - sizes[1] = qw.numel(); + // This logic is present in at::prelu and repeated here, as this path can be + // hit via quantized::prelu, which is registered under quantized/cpu/qprelu.cpu + auto qw_nd = qw; + if (input_ndim != qw_nd.dim()) { + DimVector dim_w(input_ndim, 1); + if (input_ndim > 1) { + dim_w[1] = qw.numel(); + } + // This will always be a view in CPU/CUDA, but some backends + // like MKLDNN do not support views + qw_nd = qw_nd.reshape(dim_w); } - auto qw_nd = qw.reshape(sizes); auto iter = TensorIteratorConfig() .add_output(out) diff --git a/aten/src/ATen/native/quantized/cpu/qrelu.cpp b/aten/src/ATen/native/quantized/cpu/qrelu.cpp index fcdfb0e9260c68..739b5c5af30228 100644 --- a/aten/src/ATen/native/quantized/cpu/qrelu.cpp +++ b/aten/src/ATen/native/quantized/cpu/qrelu.cpp @@ -150,7 +150,7 @@ Tensor& leaky_relu_quantized_cpu_(Tensor& self, const Scalar& negval) { return self; } -Tensor prelu_quantized_cpu_impl(const Tensor& self, const Tensor& weight, +Tensor _prelu_kernel_quantized_cpu_impl(const Tensor& self, const Tensor& weight, double output_scale, int64_t output_zero_point) { auto ndim = self.dim(); // for ndim < 1 or > 5, go to reference path @@ -172,8 +172,8 @@ Tensor prelu_quantized_cpu_impl(const Tensor& self, const Tensor& weight, return qy; } -Tensor prelu_quantized_cpu(const Tensor& self, const Tensor& weight) { - return prelu_quantized_cpu_impl(self, weight, self.q_scale(), self.q_zero_point()); +Tensor _prelu_kernel_quantized_cpu(const Tensor& self, const Tensor& weight) { + return _prelu_kernel_quantized_cpu_impl(self, weight, self.q_scale(), self.q_zero_point()); } namespace { @@ -220,7 +220,7 @@ class QLeakyRelu final { class QPRelu final { public: static Tensor run(Tensor self, const Tensor& weight, double output_scale, int64_t output_zero_point) { - return prelu_quantized_cpu_impl(self, weight, output_scale, output_zero_point); + return _prelu_kernel_quantized_cpu_impl(self, weight, output_scale, output_zero_point); } }; diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py index 4d10ea835fee08..b5116e373cb8f1 100644 --- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py +++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py @@ -98,6 +98,9 @@ ("aten::col2im_backward", datetime.date(2022, 12, 1)), ("aten::im2col_backward", datetime.date(2022, 12, 1)), ("aten::diag_backward", datetime.date(2022, 12, 1)), + ("aten::prelu.out", datetime.date(2023, 3, 1)), + ("aten::prelu_backward", datetime.date(2023, 3, 1)), + ("aten::prelu_backward.out", datetime.date(2023, 3, 1)), ("aten::solve", datetime.date(9999, 1, 1)), ("aten::solve.solution", datetime.date(9999, 1, 1)), 
("aten::_solve_helper", datetime.date(9999, 1, 1)), diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 7e2d001f10a321..f806f0f09f2b49 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -2103,7 +2103,6 @@ def forward(self, x): xfail('nn.functional.pdist', ''), # could not find kernel xfail('nn.functional.pixel_shuffle', ''), # aten.pixel_shuffle.default - couldn't find symbolic meta fun... xfail('nn.functional.pixel_unshuffle', ''), # aten.pixel_unshuffle.default - couldn't find symbolic meta... - xfail('nn.functional.prelu', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('nn.functional.rrelu', ''), # aten.rrelu_with_noise.default - couldn't find symbolic meta function... xfail('nn.functional.smooth_l1_loss', ''), # could not find kernel xfail('nn.functional.unfold', ''), # Cannot call sizes() on tensor with symbolic sizes/strides diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py index f2ed272919b1ff..723b8a4354a9b2 100644 --- a/test/functorch/test_ops.py +++ b/test/functorch/test_ops.py @@ -785,7 +785,6 @@ def fn(inp, *args, **kwargs): xfail("to_sparse"), xfail("native_batch_norm"), xfail("_native_batch_norm_legit"), - xfail('nn.functional.prelu'), })) @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,)) @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)}) @@ -903,7 +902,6 @@ def vjp_of_vjp(*args_and_cotangents): xfail('sparse.sampled_addmm', ''), xfail('as_strided_scatter', ''), # calls as_strided xfail('index_reduce', ''), # .item() call - xfail('nn.functional.prelu'), # --------------------------------------------------------------------- }) @@ -998,7 +996,6 @@ def test_vmapvjp(self, device, dtype, op): xfail("native_batch_norm"), xfail("_native_batch_norm_legit"), - xfail('nn.functional.prelu'), # ---------------------------------------------------------------------- } @@ -1279,7 +1276,6 @@ def test(): xfail("native_batch_norm"), xfail("_native_batch_norm_legit"), xfail('as_strided', 'partial_views'), - xfail('nn.functional.prelu'), })) def test_vjpvmap(self, device, dtype, op): # NB: there is no vjpvmap_has_batch_rule test because that is almost @@ -1546,7 +1542,6 @@ def reference(primals, cotangents, primals_tangents, cotangents_tangents): xfail("native_batch_norm"), xfail("_native_batch_norm_legit"), xfail('native_dropout_backward'), - xfail('nn.functional.prelu'), })) @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,)) @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)}) @@ -1786,7 +1781,6 @@ def fn(input, weight, bias): skip('linalg.multi_dot', '', device_type='cpu'), skip('sparse.sampled_addmm', ''), skip('native_layer_norm', '', device_type='cpu'), - xfail('nn.functional.prelu'), }) @opsToleranceOverride('TestOperators', 'test_vmap_autograd_grad', ( tol1('linalg.householder_product', diff --git a/test/test_cuda.py b/test/test_cuda.py index ffb2b860203d10..605feb8b3367f3 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -2942,7 +2942,7 @@ def test_autocast_torch_bf16(self): ('TORCH_CUDNN_V8_API_DISABLED' in os.environ and int(os.environ['TORCH_CUDNN_V8_API_DISABLED']) or torch.cuda.get_device_capability() < (8, 0)) - should_error_from_not_implemented = should_error_from_cudnn or 'prelu' in op or 'thnn' in op \ + should_error_from_not_implemented = should_error_from_cudnn or 'thnn' in op \ or 'fused' in op or 'gru' in op or op == '_thnn_fused_lstm_cell' or op == 
'lstm_cell' if not skip_test: if should_error_from_not_implemented: diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 54a9c50ce6c5af..81ee3a2815853f 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -1987,20 +1987,24 @@ self: _masked_softmax_backward(grad, result, mask, dim) mask: non_differentiable -- name: prelu(Tensor self, Tensor weight) -> Tensor - self, weight: "grad.defined() ? prelu_backward(grad, self, weight) : std::tuple()" - result: prelu_jvp(self_p, self_t, weight_p, weight_t) - -- name: prelu_backward(Tensor grad_output, Tensor self, Tensor weight) -> (Tensor, Tensor) - grad_output, self, weight: prelu_double_backward(grads[0], grads[1], grad_output, self, weight) - result0: prelu_backward_self_jvp(self_p, weight_p, weight_t, grad_output_p, grad_output_t) - result1: prelu_backward_weight_jvp(weight_p, self_p, self_t, grad_output_p, grad_output_t) +- name: _prelu_kernel(Tensor self, Tensor weight) -> Tensor + self, weight: "grad.defined() ? _prelu_kernel_backward(grad, self, weight) : std::tuple()" + result: at::where(self_p >= 0, self_t, weight_p * self_t + weight_t * self_p) + +- name: _prelu_kernel_backward(Tensor grad_output, Tensor self, Tensor weight) -> (Tensor, Tensor) + grad_output: "grads[0].defined() ? + (grads[1].defined() ? at::where(self >= 0, grads[0], grads[0] * weight + grads[1] * self) + : at::where(self >= 0, grads[0], grads[0] * weight)) + : at::where(self >= 0, at::zeros({}, grad_output.options()), grads[1] * self)" + self: "grads[1].defined() ? at::where(self >= 0, at::zeros({}, self.options()), grad_output * grads[1]) : zeros_like(self)" + weight: "grads[0].defined() ? at::where(self >= 0, at::zeros({}, weight.options()), grad_output * grads[0]) : zeros_like(self)" + result0: at::where(self_p >= 0, grad_output_t, grad_output_t * weight_p + grad_output_p * weight_t) + result1: at::where(self_p >= 0, at::zeros({}, self_p.options()), grad_output_p * self_t + grad_output_t * self_p) - name: rrelu_with_noise(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor self: rrelu_with_noise_backward(grad, self, noise, lower, upper, training, false) result: auto_element_wise - - name: rrelu_with_noise_(Tensor(a!) self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!) self: rrelu_with_noise_backward(grad, result, noise, lower, upper, training, true) diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index 24a0ee7388d9bf..ae41eedc740e96 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -277,27 +277,20 @@ def softshrink_backward(grad_output: Tensor, self: Tensor, lambd: float) -> Tens return torch.where((self >= -lambd) & (self <= lambd), 0.0, grad_output) -@register_decomposition(aten.prelu_backward) -@pw_cast_for_opmath -def prelu_backward( - grad_output: Tensor, self: Tensor, weight: Tensor +@register_decomposition(aten._prelu_kernel) +def _prelu_kernel(self: Tensor, weight: Tensor) -> Tensor: + return torch.where(self > 0, self, weight * self) + + +@register_decomposition(aten._prelu_kernel_backward) +def _prelu_kernel_backward( + grad_output: Tensor, + self: Tensor, + weight: Tensor, ) -> Tuple[Tensor, Tensor]: - # Logic is more complicated than I would like. 
Basically, weight can either - # be a scalar or a vector of size [C], and in the forward pass it's - # broadcast against [N, C, ...]. So now, we need to do the corresponding - # reduction, which is harder than we'd like... - cur_weight = weight - for _ in range(2, grad_output.dim()): - cur_weight = cur_weight.unsqueeze(-1) - input_grad = torch.where(self > 0, grad_output, cur_weight * grad_output) - weight_grad_collector = torch.where(self > 0, 0.0, self * grad_output) - if len(self.shape) == 0: - out = weight_grad_collector.view(cur_weight.shape) - else: - out = weight_grad_collector.sum_to_size(cur_weight.shape) - while out.dim() > weight.dim(): - out = out.squeeze(-1) - return (input_grad.view_as(self), out) + input_grad = torch.where(self > 0, grad_output, weight * grad_output) + weight_grad = torch.where(self > 0, 0.0, self * grad_output) + return (input_grad, weight_grad) @register_decomposition(aten.rrelu_with_noise_backward) diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index fa06d23faee8d5..741339c6f34321 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -2930,178 +2930,6 @@ std::tuple atan2_backward( output_mask[1] ? grad * -self * recip : Tensor()}; } -Tensor prelu_jvp( - const Tensor& x, - const Tensor& dx, - const Tensor& w, - const Tensor& dw) { - const auto ndim = x.dim(); - auto as_nd = [ndim](const Tensor& t) { - std::vector sizes(ndim, 1), strides(ndim, 0); - if (ndim >= 2) { - sizes[1] = t.dim() == 1 ? t.sizes()[0] : 1; - strides[1] = t.dim() == 1 ? t.strides()[0] : 0; - return t.as_strided(sizes, strides); - } - return t.as_strided(sizes, strides); - }; - auto w_ = as_nd(w); - auto dw_ = as_nd(dw); - return at::where(x >= 0, dx, w_ * dx + dw_ * x); -} - -// TODO: Seriously consider writing the derivative formulas for -// each output separately; there is not all that much sharing -// of computation going on here. -std::tuple prelu_double_backward( - const Tensor& grad_grad_input, - const Tensor& grad_grad_weight, - const Tensor& grad_out, - const Tensor& input_, - const Tensor& weight_) { - if (!(grad_grad_input.defined() || grad_grad_weight.defined() || - grad_out.defined())) { - return std::tuple(Tensor(), Tensor(), Tensor()); - } - auto input = input_.contiguous(); - auto weight = weight_.contiguous(); - - // Zero-fill undefined grads (TODO: do this more efficiently) - auto ggI = grad_grad_input.defined() - ? grad_grad_input.contiguous() - : at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); - auto ggW = grad_grad_weight.defined() - ? grad_grad_weight.contiguous() - : at::zeros_like(weight, LEGACY_CONTIGUOUS_MEMORY_FORMAT); - auto gO = grad_out.defined() - ? grad_out.contiguous() - : at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); - - auto positive_mask = (input > 0).type_as(ggI); - auto nonpositive_mask = (input <= 0).type_as(ggW); - - // Explanation: Let input be i, weight be w, grad_output be gO. - // f(i, w) = i if i > 0 - // = w * i if i <= 0 - // gI = df/di * gO = gO if i > 0 gW = df/dw * gO = 0 if i > 0 - // = gO * w if i <= 0 = gO * i if i <= 0 - // The rest is taking derivatives of these wrt i, w, gO and summing/expanding - // properly. - - if (weight.numel() == 1) { - // from PReLU.forward: num_parameters == 0 is used indicate that a - // single weight is shared among all input channels. 
- - // this is a little tricky because PReLU currently doesn't take a shape so - // the weight may be 1-d when the input is a scalar (and there isn't a good - // Parameter API for that anyway until Variable and tensor are merged). So, - // use weight and ggW as 0-dim in this case. - bool scalar_input_1d_weight = - (positive_mask.dim() == 0 && weight.dim() == 1); - auto weight_maybe_squeeze = - scalar_input_1d_weight ? weight.squeeze() : weight; - auto ggW_maybe_squeeze = scalar_input_1d_weight ? ggW.squeeze() : ggW; - - auto mask = positive_mask + - nonpositive_mask * weight_maybe_squeeze.expand_as(input); - auto ggO = ggI * mask + - ggW_maybe_squeeze.expand_as(gO) * (nonpositive_mask * input); - return std::tuple( - ggO, - ggW_maybe_squeeze.expand_as(gO) * gO * nonpositive_mask, - (ggI * gO * nonpositive_mask).sum().expand_as(weight)); - } else { - // Expand ggW to match size of ggI; a simple expand doesn't work because - // ggW is the size of the input channel (dim==1 unless there is only 1 - // dimension). For example, let ggI be size (3,4,5,6,7) and ggW be size - // (4). Then we unsqueeze ggW to be size (4,1,1,1) so the expand succeeds. - auto dims_to_unsqueeze = std::max(input.dim() - 2, 0); - auto ggW_expanded = ggW; - for (const auto i : c10::irange(dims_to_unsqueeze)) { - (void)i; // Suppress unused variable warning - ggW_expanded = ggW_expanded.unsqueeze(1); - } - ggW_expanded = ggW_expanded.expand_as(ggI); - - auto gI = ggW_expanded * gO * nonpositive_mask; - - auto gW = ggI * gO * nonpositive_mask; - if (input.dim() > 1) { - gW = gW.sum(0); - } - while (gW.dim() > 1) { - gW = gW.sum(1); - } - - Tensor ggO; - // areAnyTensorSubclassLike check necessary for composite compiance: - // e.g. it's possible that grad_out/gO is a BatchedTensor wrapping - // some Tensor that does require grad - if (areAnyTensorSubclassLike({grad_out}) || gO.requires_grad()) { - // expand weight as input as in ggW/ggI above - auto weight_expanded = weight; - for (const auto i : c10::irange(dims_to_unsqueeze)) { - (void)i; // Suppress unused variable warning - weight_expanded = weight_expanded.unsqueeze(1); - } - weight_expanded = weight_expanded.expand_as(input); - - auto mask = positive_mask + nonpositive_mask * weight_expanded; - ggO = ggI * mask + ggW_expanded * nonpositive_mask * input; - } - return std::tuple{ggO, gI, gW}; - } -} - -Tensor prelu_backward_self_jvp( - const Tensor& x, - const Tensor& w, - const Tensor& dw, - const Tensor& g, - const Tensor& dg) { - const auto ndim = x.dim(); - auto as_nd = [ndim](const Tensor& t) { - std::vector sizes(ndim, 1), strides(ndim, 0); - if (ndim >= 2) { - sizes[1] = t.dim() == 1 ? t.sizes()[0] : 1; - strides[1] = t.dim() == 1 ? t.strides()[0] : 0; - return t.as_strided(sizes, strides); - } - return t.as_strided(sizes, strides); - }; - auto w_ = as_nd(w); - auto dw_ = as_nd(dw); - return at::where(x >= 0, dg, dg * w_ + g * dw_); -} - -Tensor prelu_backward_weight_jvp( - const Tensor& w, - const Tensor& x, - const Tensor& dx, - const Tensor& g, - const Tensor& dg) { - const auto dw_full = - at::where(x >= 0, at::zeros({}, x.options()), g * dx + dg * x); - - const auto ndim = x.dim(); - std::vector reduction_dims; - reduction_dims.reserve(ndim); - // we always reduce over the 0th dim. - reduction_dims.push_back(0); - if (ndim >= 2) { - // reduce over the 1th dim if w is a 0-dim tensor - if (!w.dim()) { - reduction_dims.push_back(1); - } - // reduce over dims which are >= 2. 
- for (int64_t i = 2; i < ndim; ++i) { - reduction_dims.push_back(i); - } - } - const auto dw = dw_full.sum(reduction_dims); - return dw.view_as(w); -} - Tensor gelu_double_backward( const Tensor& ggI, const Tensor& gO, diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h index e1475ecb2b9684..8acc29a1b24677 100644 --- a/torch/csrc/autograd/FunctionsManual.h +++ b/torch/csrc/autograd/FunctionsManual.h @@ -704,29 +704,6 @@ infinitely_differentiable_native_group_norm_backward( int64_t group, double eps, std::array grad_input_mask); -Tensor prelu_jvp( - const Tensor& x, - const Tensor& dx, - const Tensor& w, - const Tensor& dw); -std::tuple prelu_double_backward( - const Tensor& grad_grad_input, - const Tensor& grad_grad_weight, - const Tensor& grad_out, - const Tensor& input_, - const Tensor& weight_); -Tensor prelu_backward_self_jvp( - const Tensor& x, - const Tensor& w, - const Tensor& dw, - const Tensor& g, - const Tensor& dg); -Tensor prelu_backward_weight_jvp( - const Tensor& w, - const Tensor& x, - const Tensor& dx, - const Tensor& g, - const Tensor& dg); Tensor gelu_double_backward( const Tensor& ggI, const Tensor& gO, diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index fe71b36645cd93..6a33afd06fa7f5 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -12022,14 +12022,16 @@ def reference_flatten(input, start_dim=0, end_dim=-1): 'TestUnaryUfuncs', device_type='cuda', ), ], ), + # Marked as a Unary function because it has some rather odd broadcasting semantics in its + # second argument UnaryUfuncInfo( 'nn.functional.prelu', - aten_backward_name='prelu_backward', + aten_backward_name='_prelu_kernel_backward', ref=lambda x, weight: np.maximum(0., x) + np.minimum(0., x) * (weight if x.ndim == 1 else weight.reshape([weight.size if i == 1 else 1 for i in range(0, x.ndim)])), dtypes=floating_types_and(torch.bfloat16), - dtypesIfCUDA=floating_types_and(torch.float16), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), supports_forward_ad=True, supports_fwgrad_bwgrad=True, supports_autograd=True, @@ -12042,12 +12044,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1): sample_inputs_func=sample_inputs_prelu, reference_inputs_func=reference_inputs_prelu, decorators=[ - # https://github.com/pytorch/pytorch/issues/89895 - DecorateInfo(unittest.expectedFailure, "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"), - DecorateInfo(unittest.expectedFailure, "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"), - # FIXME: second derivative is implemented but seems to be incorrect - # https://github.com/pytorch/pytorch/issues/68760 - DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_fn_gradgrad'), # RuntimeError: Cannot insert a Tensor that requires grad as a constant. # Consider making it a parameter or input, or detaching the gradient # https://github.com/pytorch/pytorch/issues/68752
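
Reviewer note (outside the diff): below is a minimal Python sketch of the semantics that the new CompositeImplicit at::prelu plus the pointwise at::_prelu_kernel implement, mirroring the weight-reshape logic added to aten/src/ATen/native/Activation.cpp and the decomposition in torch/_decomp/decompositions.py. The helper names prelu_reference and _prelu_kernel_reference are illustrative only and do not appear in the patch.

import torch

def _prelu_kernel_reference(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Pointwise kernel: by the time we get here, `w` already broadcasts
    # against `x` and both share a dtype, so this is a plain `where`.
    return torch.where(x > 0, x, w * x)

def prelu_reference(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # CompositeImplicit wrapper: keep PReLU's historical broadcasting rule,
    # where all of `w`'s elements line up with the channel dim (dim 1).
    assert x.dtype == w.dtype, "prelu: type promotion is not supported"
    assert w.dim() <= 1, "weight must be a scalar or a 1D tensor"
    if w.numel() != 1:
        assert x.dim() > 0, "zero-dim input tensor is not allowed"
        channel_size = x.size(1) if x.dim() > 1 else 1
        assert channel_size == w.numel(), "weight must match the channel size"
    if w.dim() != x.dim():
        shape = [1] * x.dim()
        if x.dim() > 1:
            shape[1] = w.numel()
        w = w.reshape(shape)  # a view on CPU/CUDA
    return _prelu_kernel_reference(x, w)

x = torch.randn(2, 3, 4)
w = torch.randn(3)
torch.testing.assert_close(prelu_reference(x, w), torch.nn.functional.prelu(x, w))

Keeping the broadcasting quirk in the composite wrapper is what lets the backward, the functorch batching rule (BINARY_POINTWISE(_prelu_kernel)) and the quantized path treat _prelu_kernel as an ordinary binary pointwise op.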
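
A similar sketch, under the same caveat, of the pointwise backward computed by _prelu_kernel_backward and of the reduction that autograd now performs for the composite prelu through the broadcasted reshape (previously this sum over the batch and trailing dimensions was hand-written per backend). The shapes below are chosen arbitrarily.

import torch

def _prelu_kernel_backward_reference(x, w, grad_out):
    # Pointwise gradients w.r.t. `x` and the (already broadcast) weight.
    grad_x = torch.where(x > 0, grad_out, w * grad_out)
    grad_w = torch.where(x > 0, torch.zeros((), dtype=x.dtype), x * grad_out)
    return grad_x, grad_w

x = torch.randn(2, 3, 4, dtype=torch.double, requires_grad=True)
w = torch.randn(3, dtype=torch.double, requires_grad=True)
out = torch.nn.functional.prelu(x, w)
grad_out = torch.randn_like(out)
gx, gw = torch.autograd.grad(out, (x, w), grad_out)

# The composite op reshaped `w` to [1, C, 1, ...] before the pointwise kernel,
# so autograd reduces the per-element weight gradient over every broadcast dim.
ref_gx, ref_gw = _prelu_kernel_backward_reference(x.detach(), w.detach().reshape(1, 3, 1), grad_out)
torch.testing.assert_close(gx, ref_gx)
torch.testing.assert_close(gw, ref_gw.sum(dim=(0, 2)))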
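
Finally, a quick smoke-test sketch for the two linked issues against a build containing this patch: gradcheck/gradgradcheck exercise the previously-incorrect second derivative (#68760), and vmap exercises the previously-missing batching rule (#89895). This assumes torch.func.vmap is available (use functorch.vmap on older builds); it is not one of the tests added in the PR.

import torch
from torch.autograd import gradcheck, gradgradcheck

x = torch.randn(2, 3, 4, dtype=torch.double, requires_grad=True)
w = torch.randn(3, dtype=torch.double, requires_grad=True)

# Second derivative (https://github.com/pytorch/pytorch/issues/68760)
assert gradcheck(torch.nn.functional.prelu, (x, w))
assert gradgradcheck(torch.nn.functional.prelu, (x, w))

# vmap (https://github.com/pytorch/pytorch/issues/89895): compare against a per-sample loop
batched = torch.randn(5, 2, 3, 4, dtype=torch.double)
vmapped = torch.func.vmap(torch.nn.functional.prelu, in_dims=(0, None))(batched, w.detach())
looped = torch.stack([torch.nn.functional.prelu(b, w.detach()) for b in batched])
torch.testing.assert_close(vmapped, looped)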