From 8dafde823d08f6039e1e45c4f36459dcb75992ff Mon Sep 17 00:00:00 2001
From: xytintel
Date: Thu, 11 Jul 2024 02:36:38 +0000
Subject: [PATCH 1/2] add mish ops

---
 src/ATen/native/xpu/Activation.cpp            | 29 ++++
 src/ATen/native/xpu/XPUFallback.template      |  1 -
 .../native/xpu/sycl/ActivationMishKernels.cpp | 56 +++++++++++++++++++
 .../native/xpu/sycl/ActivationMishKernels.h   | 11 ++++
 src/comm/XPUMathCompat.h                      |  8 +++
 test/xpu/xpu_test_utils.py                    |  1 +
 yaml/xpu_functions.yaml                       |  4 ++
 7 files changed, 109 insertions(+), 1 deletion(-)
 create mode 100644 src/ATen/native/xpu/sycl/ActivationMishKernels.cpp
 create mode 100644 src/ATen/native/xpu/sycl/ActivationMishKernels.h

diff --git a/src/ATen/native/xpu/Activation.cpp b/src/ATen/native/xpu/Activation.cpp
index ee673fef6..ebb69b7f5 100644
--- a/src/ATen/native/xpu/Activation.cpp
+++ b/src/ATen/native/xpu/Activation.cpp
@@ -9,6 +9,7 @@
 #include
 #include
 #include
+#include <ATen/native/xpu/sycl/ActivationMishKernels.h>
 #include
 #include
 #include
@@ -632,4 +633,32 @@ Tensor& XPUNativeFunctions::softshrink_backward_out(
   return grad_input;
 }
 
+Tensor XPUNativeFunctions::mish(const Tensor& self) {
+  Tensor out;
+  auto iter = TensorIterator::unary_op(out, self);
+  native::xpu::mish_kernel(iter);
+  return iter.output();
+}
+
+Tensor& XPUNativeFunctions::mish_out(const Tensor& self, Tensor& out) {
+  auto iter = TensorIterator::unary_op(out, self);
+  native::xpu::mish_kernel(iter);
+  return out;
+}
+
+Tensor& XPUNativeFunctions::mish_(Tensor& self) {
+  auto iter = TensorIterator::unary_op(self, self);
+  native::xpu::mish_kernel(iter);
+  return self;
+}
+
+Tensor XPUNativeFunctions::mish_backward(
+    const Tensor& grad_output,
+    const Tensor& input) {
+  Tensor grad_input = at::empty({0}, input.options());
+  auto iter = TensorIterator::binary_op(grad_input, grad_output, input);
+  native::xpu::mish_backward_kernel(iter);
+  return grad_input;
+}
+
 } // namespace at
diff --git a/src/ATen/native/xpu/XPUFallback.template b/src/ATen/native/xpu/XPUFallback.template
index 7bfdd6abd..d27687919 100644
--- a/src/ATen/native/xpu/XPUFallback.template
+++ b/src/ATen/native/xpu/XPUFallback.template
@@ -276,7 +276,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
       "max_unpool2d",
       "max_unpool3d",
       "median",
-      "mish.out",
       "mode",
       "multilabel_margin_loss_backward",
       "multilabel_margin_loss_forward",
diff --git a/src/ATen/native/xpu/sycl/ActivationMishKernels.cpp b/src/ATen/native/xpu/sycl/ActivationMishKernels.cpp
new file mode 100644
index 000000000..0a1b43347
--- /dev/null
+++ b/src/ATen/native/xpu/sycl/ActivationMishKernels.cpp
@@ -0,0 +1,56 @@
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+
+namespace at::native::xpu {
+
+template <typename scalar_t>
+struct MishFunctor {
+  scalar_t operator()(scalar_t x) const {
+    using opmath_t = at::opmath_type<scalar_t>;
+    const opmath_t x_acc = static_cast<opmath_t>(x);
+    return x_acc *
+        c10::xpu::compat::tanh(
+            c10::xpu::compat::log1p(c10::xpu::compat::exp(x_acc)));
+  }
+};
+
+void mish_kernel(TensorIteratorBase& iter) {
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      at::ScalarType::Half,
+      at::ScalarType::BFloat16,
+      iter.dtype(),
+      "mish_xpu",
+      [&]() { gpu_kernel(iter, MishFunctor<scalar_t>()); });
+}
+
+template <typename scalar_t>
+struct MishBackwardFunctor {
+  scalar_t operator()(scalar_t dy, scalar_t x) const {
+    using opmath_t = at::opmath_type<scalar_t>;
+    const opmath_t dy_acc = static_cast<opmath_t>(dy);
+    const opmath_t x_acc = static_cast<opmath_t>(x);
+    const opmath_t s_acc =
+        opmath_t(1) / (opmath_t(1) + c10::xpu::compat::exp(-x_acc));
+    const opmath_t t_acc = c10::xpu::compat::tanh(
+        c10::xpu::compat::log1p(c10::xpu::compat::exp(x_acc)));
+    return dy_acc * (t_acc + x_acc * s_acc * (opmath_t(1) - t_acc * t_acc));
+  }
+};
+
+void mish_backward_kernel(TensorIteratorBase& iter) {
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      at::ScalarType::Half,
+      at::ScalarType::BFloat16,
+      iter.dtype(),
+      "mish_backward_xpu",
+      [&]() { gpu_kernel(iter, MishBackwardFunctor<scalar_t>()); });
+}
+
+} // namespace at::native::xpu
diff --git a/src/ATen/native/xpu/sycl/ActivationMishKernels.h b/src/ATen/native/xpu/sycl/ActivationMishKernels.h
new file mode 100644
index 000000000..bd68197cd
--- /dev/null
+++ b/src/ATen/native/xpu/sycl/ActivationMishKernels.h
@@ -0,0 +1,11 @@
+#pragma once
+
+#include
+
+namespace at::native::xpu {
+
+void mish_kernel(TensorIteratorBase& iter);
+
+void mish_backward_kernel(TensorIteratorBase& iter);
+
+} // namespace at::native::xpu
diff --git a/src/comm/XPUMathCompat.h b/src/comm/XPUMathCompat.h
index a685286b5..0b86dc334 100644
--- a/src/comm/XPUMathCompat.h
+++ b/src/comm/XPUMathCompat.h
@@ -29,6 +29,14 @@ __MATH_FUNCTIONS_DECL__ double rsqrt(double x) {
   return sycl::rsqrt(x);
 }
 
+__MATH_FUNCTIONS_DECL__ float log1p(float x) {
+  return ::log1pf(x);
+}
+
+__MATH_FUNCTIONS_DECL__ double log1p(double x) {
+  return ::log1p(x);
+}
+
 // To walk around SYCL compiler optimization on data type promotion.
 // c10::Half gets data type promotion in +-*/ operations. See
 // c10/util/Half-inl.h. XPU implementation gets worse precision on half div,
diff --git a/test/xpu/xpu_test_utils.py b/test/xpu/xpu_test_utils.py
index 35c29d96b..acf28fb76 100644
--- a/test/xpu/xpu_test_utils.py
+++ b/test/xpu/xpu_test_utils.py
@@ -61,6 +61,7 @@
     "gt",
     "hardtanh",
     "hardswish",
+    "nn.functional.mish",
     "index_add",
     "index_put",
     "index_select",
diff --git a/yaml/xpu_functions.yaml b/yaml/xpu_functions.yaml
index 2ecc6790b..7d212fa99 100644
--- a/yaml/xpu_functions.yaml
+++ b/yaml/xpu_functions.yaml
@@ -118,6 +118,10 @@ supported:
   - softshrink.out
   - softshrink_backward
   - softshrink_backward.grad_input
+  - mish
+  - mish.out
+  - mish_
+  - mish_backward
   - gelu
   - gelu_
   - gelu.out

From 463cc2f3656e80b0df78fc2a93914d89f251176a Mon Sep 17 00:00:00 2001
From: xytintel
Date: Fri, 12 Jul 2024 06:56:32 +0000
Subject: [PATCH 2/2] refine code

---
 src/ATen/native/xpu/sycl/ActivationMishKernels.cpp | 10 +++-------
 src/comm/XPUMathCompat.h                           |  8 --------
 2 files changed, 3 insertions(+), 15 deletions(-)

diff --git a/src/ATen/native/xpu/sycl/ActivationMishKernels.cpp b/src/ATen/native/xpu/sycl/ActivationMishKernels.cpp
index 0a1b43347..fe0154326 100644
--- a/src/ATen/native/xpu/sycl/ActivationMishKernels.cpp
+++ b/src/ATen/native/xpu/sycl/ActivationMishKernels.cpp
@@ -15,9 +15,7 @@ struct MishFunctor {
   scalar_t operator()(scalar_t x) const {
     using opmath_t = at::opmath_type<scalar_t>;
     const opmath_t x_acc = static_cast<opmath_t>(x);
-    return x_acc *
-        c10::xpu::compat::tanh(
-            c10::xpu::compat::log1p(c10::xpu::compat::exp(x_acc)));
+    return x_acc * std::tanh(std::log1p(std::exp(x_acc)));
   }
 };
 
@@ -36,10 +34,8 @@ struct MishBackwardFunctor {
     using opmath_t = at::opmath_type<scalar_t>;
     const opmath_t dy_acc = static_cast<opmath_t>(dy);
     const opmath_t x_acc = static_cast<opmath_t>(x);
-    const opmath_t s_acc =
-        opmath_t(1) / (opmath_t(1) + c10::xpu::compat::exp(-x_acc));
-    const opmath_t t_acc = c10::xpu::compat::tanh(
-        c10::xpu::compat::log1p(c10::xpu::compat::exp(x_acc)));
+    const opmath_t s_acc = opmath_t(1) / (opmath_t(1) + std::exp(-x_acc));
+    const opmath_t t_acc = std::tanh(std::log1p(std::exp(x_acc)));
     return dy_acc * (t_acc + x_acc * s_acc * (opmath_t(1) - t_acc * t_acc));
   }
 };
diff --git a/src/comm/XPUMathCompat.h b/src/comm/XPUMathCompat.h
index 0b86dc334..a685286b5 100644
--- a/src/comm/XPUMathCompat.h
+++ b/src/comm/XPUMathCompat.h
@@ -29,14 +29,6 @@ __MATH_FUNCTIONS_DECL__ double rsqrt(double x) {
   return sycl::rsqrt(x);
 }
 
-__MATH_FUNCTIONS_DECL__ float log1p(float x) {
-  return ::log1pf(x);
-}
-
-__MATH_FUNCTIONS_DECL__ double log1p(double x) {
-  return ::log1p(x);
-}
-
 // To walk around SYCL compiler optimization on data type promotion.
 // c10::Half gets data type promotion in +-*/ operations. See
 // c10/util/Half-inl.h. XPU implementation gets worse precision on half div,
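
Reviewer note, not part of the patch: the two functors compute mish(x) = x * tanh(softplus(x)) = x * tanh(log1p(exp(x))) and its gradient dy * (t + x * s * (1 - t^2)), where s = sigmoid(x) and t = tanh(softplus(x)). The standalone C++ sketch below restates that math against plain <cmath>, with no ATen, SYCL, or opmath types, and checks the analytic gradient against a central finite difference. The helper names mish_ref and mish_backward_ref are illustrative only and do not appear anywhere in the patch.

// Standalone sketch (not part of the patch): the same math the new XPU kernels
// compute, written with plain <cmath> so the formulas are easy to verify by hand.
#include <cmath>
#include <cstdio>

// Forward: mish(x) = x * tanh(softplus(x)) = x * tanh(log1p(exp(x)))
// (mirrors MishFunctor)
double mish_ref(double x) {
  return x * std::tanh(std::log1p(std::exp(x)));
}

// Backward: dy * (t + x * s * (1 - t^2)) with s = sigmoid(x), t = tanh(softplus(x))
// (mirrors MishBackwardFunctor, which calls these s_acc and t_acc)
double mish_backward_ref(double dy, double x) {
  const double s = 1.0 / (1.0 + std::exp(-x));
  const double t = std::tanh(std::log1p(std::exp(x)));
  return dy * (t + x * s * (1.0 - t * t));
}

int main() {
  // Sanity check: the analytic gradient should match a central finite difference.
  const double x = 0.5;
  const double h = 1e-6;
  const double numeric = (mish_ref(x + h) - mish_ref(x - h)) / (2.0 * h);
  std::printf("mish(%.1f)    = %.6f\n", x, mish_ref(x));
  std::printf("analytic grad = %.6f\n", mish_backward_ref(1.0, x));
  std::printf("numeric grad  = %.6f\n", numeric);
  return 0;
}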