From 76779e82ea159d394a51918713f5c50d82d30442 Mon Sep 17 00:00:00 2001 From: jianyizh Date: Thu, 26 Dec 2024 16:01:25 +0800 Subject: [PATCH] Add at::_safe_softmax op (#1180) Fuse safe softmax into the XPU softmax forward kernels: a new is_safe_softmax template flag makes a slice whose inputs are all -inf yield 0. See https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/attention.cpp#L656 --------- Co-authored-by: Yutao Xu --- src/ATen/native/xpu/SoftMax.cpp | 11 ++ src/ATen/native/xpu/sycl/SoftMaxKernels.cpp | 115 +++++++++++++++----- src/ATen/native/xpu/sycl/SoftMaxKernels.h | 3 + test/regressions/test_safe_softmax.py | 44 ++++++++ yaml/native/native_functions.yaml | 4 + 5 files changed, 152 insertions(+), 25 deletions(-) create mode 100644 test/regressions/test_safe_softmax.py diff --git a/src/ATen/native/xpu/SoftMax.cpp b/src/ATen/native/xpu/SoftMax.cpp index e816d48c8..f155165ce 100644 --- a/src/ATen/native/xpu/SoftMax.cpp +++ b/src/ATen/native/xpu/SoftMax.cpp @@ -76,6 +76,17 @@ TORCH_IMPL_FUNC(log_softmax_xpu_out) xpu::_log_softmax_kernel(input, dim, half_to_float, output); } +Tensor _safe_softmax_xpu( + const Tensor& self, + int64_t dim, + std::optional dtype) { + // TODO: uncomment after XPU softmax supports half_to_float=true + // if (self.scalar_type() == ScalarType::Half && dtype == ScalarType::Float) + // return xpu::_safe_softmax_kernel(self, dim, true); + Tensor converted = dtype.has_value() ? 
self.toType(dtype.value()) : self; + return xpu::_safe_softmax_kernel(converted, dim, false); +} + Tensor masked_softmax_xpu( const Tensor& input_, const Tensor& mask_, diff --git a/src/ATen/native/xpu/sycl/SoftMaxKernels.cpp b/src/ATen/native/xpu/sycl/SoftMaxKernels.cpp index 28d812f2c..0a0c7e718 100644 --- a/src/ATen/native/xpu/sycl/SoftMaxKernels.cpp +++ b/src/ATen/native/xpu/sycl/SoftMaxKernels.cpp @@ -210,7 +210,8 @@ template < int outer_loop, bool is_masked, typename calc_t, - typename vec_t> + typename vec_t, + bool is_safe_softmax> struct DispatchSoftmaxForwardKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { [[intel::reqd_sub_group_size(SIMD)]] void operator()( @@ -240,7 +241,8 @@ struct DispatchSoftmaxForwardKernelFunctor if (index >= dim_size_) break; - reg_in[i] = *(reinterpret_cast(in_data_ + group_offset + index)); + reg_in[i] = + *(reinterpret_cast(in_data_ + group_offset + index)); if constexpr (is_masked) { auto vec_offset = group_offset + index; #pragma unroll(vec_size) @@ -309,6 +311,10 @@ struct DispatchSoftmaxForwardKernelFunctor if constexpr (LogSoftMax) { reg_in[i][j] = static_cast(reg_in[i][j] - max_value - sum_value); + } else if ( + is_safe_softmax && + max_value == std::numeric_limits::lowest()) { + reg_in[i][j] = static_cast(0); } else if (sum_value == 0) { reg_in[i][j] = nan_; } else { @@ -386,7 +392,8 @@ template < bool LogSoftMax, int outer_loop, bool is_masked = false, - typename calc_t = decltype(nullptr)> + typename calc_t = decltype(nullptr), + bool is_safe_softmax = false> bool dispatch_softmax_forward_kernel( const scalar_t* in_data, scalar_t* out_data, @@ -412,7 +419,8 @@ bool dispatch_softmax_forward_kernel( outer_loop, is_masked, calc_t, - vec_t>; + vec_t, + /*is_safe_softmax = */ false>; int sub_group_num, global_size_row, local_size_row, range, local_size; int max_group_size = @@ -460,8 +468,8 @@ bool dispatch_softmax_forward_kernel( outer_loop, is_masked, DummyFunctor, - vec_t>; - + vec_t, + is_safe_softmax>; int 
sub_group_num, global_size_row, local_size_row, range, local_size; int max_group_size = get_wgroup_size( @@ -506,7 +514,8 @@ template < typename IndexType, bool LogSoftMax, typename vec_t, - int align_bytes> + int align_bytes, + bool is_safe_softmax> struct SoftmaxForwardKernelFunctor { void operator()(sycl::nd_item<1> item) const { IndexType local_id = item.get_local_id(0); @@ -562,6 +571,10 @@ struct SoftmaxForwardKernelFunctor { if (LogSoftMax) out_data_[group_offset + linear_idx] = static_cast( in_data_[group_offset + linear_idx] - max_value - sum_value); + else if ( + is_safe_softmax && + max_value == std::numeric_limits::lowest()) + out_data_[group_offset + linear_idx] = static_cast(0); else out_data_[group_offset + linear_idx] = static_cast( std::exp(in_data_[group_offset + linear_idx] - max_value) * @@ -576,6 +589,10 @@ struct SoftmaxForwardKernelFunctor { if (LogSoftMax) in_val[j] = static_cast(in_val[j] - max_value - sum_value); + else if ( + is_safe_softmax && + max_value == std::numeric_limits::lowest()) + in_val[j] = static_cast(0); else in_val[j] = static_cast( std::exp(in_val[j] - max_value) * sum_value); @@ -610,7 +627,8 @@ template < typename scalar_t, typename accscalar_t, typename IndexType, - bool LogSoftMax> + bool LogSoftMax, + bool is_safe_softmax> void softmax_forward_kernel( const scalar_t* in_data, scalar_t* out_data, @@ -625,7 +643,8 @@ void softmax_forward_kernel( IndexType, LogSoftMax, vec_t, - align_bytes>; + align_bytes, + is_safe_softmax>; int local_size = std::min( (dim_size + vec_size - 1) / vec_size, @@ -645,7 +664,8 @@ template < typename accscalar_t, typename IndexType, bool LogSoftMax, - typename vec_t> + typename vec_t, + bool is_safe_softmax> struct SpatialSoftmaxForwardKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { void operator()(sycl::nd_item<3> item) const { @@ -658,14 +678,16 @@ struct SpatialSoftmaxForwardKernelFunctor // get max value accscalar_t max_value[vec_size]; auto offset = local_row_id * inner_size_ + 
global_col * vec_size; - vec_t value = *(reinterpret_cast(in_data_ + group_offset + offset)); + vec_t value = + *(reinterpret_cast(in_data_ + group_offset + offset)); #pragma unroll(vec_size) for (int j = 0; j < vec_size; ++j) { max_value[j] = accscalar_t(value[j]); } for (int i = local_row_id + block_row_; i < dim_size_; i += block_row_) { offset = i * inner_size_ + global_col * vec_size; - value = *(reinterpret_cast(in_data_ + group_offset + offset)); + value = + *(reinterpret_cast(in_data_ + group_offset + offset)); #pragma unroll(vec_size) for (int j = 0; j < vec_size; ++j) { max_value[j] = std::max(max_value[j], accscalar_t(value[j])); @@ -695,7 +717,8 @@ struct SpatialSoftmaxForwardKernelFunctor } for (int i = local_row_id + block_row_; i < dim_size_; i += block_row_) { offset = i * inner_size_ + global_col * vec_size; - value = *(reinterpret_cast(in_data_ + group_offset + offset)); + value = + *(reinterpret_cast(in_data_ + group_offset + offset)); #pragma unroll(vec_size) for (int j = 0; j < vec_size; ++j) { sum_value[j] += std::exp(value[j] - max_value[j]); @@ -736,6 +759,10 @@ struct SpatialSoftmaxForwardKernelFunctor if (LogSoftMax) in_val[j] = static_cast(in_val[j] - max_value[j] - sum_value[j]); + else if ( + is_safe_softmax && + max_value[j] == -std::numeric_limits::infinity()) + in_val[j] = static_cast(0); else in_val[j] = static_cast( std::exp(in_val[j] - max_value[j]) * sum_value[j]); @@ -787,7 +814,8 @@ template < typename scalar_t, typename accscalar_t, typename IndexType, - bool LogSoftMax> + bool LogSoftMax, + bool is_safe_softmax> void spatial_softmax_forward( const scalar_t* in_data, scalar_t* out_data, @@ -801,7 +829,8 @@ void spatial_softmax_forward( accscalar_t, IndexType, LogSoftMax, - vec_t>; + vec_t, + is_safe_softmax>; int local_size, block_row; get_wgroup_size_spatial( @@ -818,7 +847,8 @@ void spatial_softmax_forward( accscalar_t, IndexType, LogSoftMax, - vec_t>( + vec_t, + is_safe_softmax>( in_data, out_data, dim_size, @@ -827,7 
+857,6 @@ void spatial_softmax_forward( local_size, block_row, group_num); - auto& queue = getCurrentSYCLQueue(); sycl_kernel_submit(global_range, local_range, queue, kfn); } @@ -1387,7 +1416,11 @@ void spatial_softmax_backward_kernel( sycl_kernel_submit(global_range, local_range, queue, kfn); } -template +template < + typename scalar_t, + typename accscalar_t, + bool LogSoftMax, + bool is_safe_softmax> void spatial_softmax_forward( const Tensor& output, const Tensor& input, @@ -1432,7 +1465,10 @@ void spatial_softmax_forward( accscalar_t, \ uint32_t, \ LogSoftMax, \ - outer_loop>( \ + outer_loop, \ + /*is_masked = */ false, \ + /*calc_t = */ decltype(nullptr), \ + /*is_safe_softmax = */ is_safe_softmax>( \ input.const_data_ptr(), \ output.mutable_data_ptr(), \ dim_size, \ @@ -1446,7 +1482,8 @@ void spatial_softmax_forward( scalar_t, \ accscalar_t, \ IndexType, \ - LogSoftMax>( \ + LogSoftMax, \ + is_safe_softmax>( \ input.const_data_ptr(), \ output.mutable_data_ptr(), \ dim_size, \ @@ -1460,7 +1497,8 @@ void spatial_softmax_forward( scalar_t, \ accscalar_t, \ IndexType, \ - LogSoftMax>( \ + LogSoftMax, \ + is_safe_softmax>( \ input.const_data_ptr(), \ output.mutable_data_ptr(), \ dim_size, \ @@ -1749,7 +1787,8 @@ Tensor& masked_softmax_forward( LogSoftMax, \ outer_loop, \ true, \ - decltype(input_calc)>( \ + decltype(input_calc), \ + /*is_safe_softmax = */ false>( \ input.const_data_ptr(), \ output.mutable_data_ptr(), \ dim_size, \ @@ -1922,7 +1961,7 @@ void masked_softmax_backward( #undef SIMD32 } // namespace impl -template +template void host_softmax( const Tensor& input_, const int64_t dim_, @@ -1953,8 +1992,11 @@ void host_softmax( "host_softmax", [&] { using accscalar_t = acc_type_device; - impl::spatial_softmax_forward( - output, input, dim); + impl::spatial_softmax_forward< + scalar_t, + accscalar_t, + LogSoftMax, + is_safe_softmax>(output, input, dim); }); } // return output; @@ -2045,6 +2087,29 @@ void _log_softmax_backward_kernel( grad.contiguous(), 
output.contiguous(), dim, half_to_float, grad_input); } +Tensor _safe_softmax_kernel( + const Tensor& self, + int64_t dim, + const bool half_to_float) { + auto output_options = + self.options().memory_format(LEGACY_CONTIGUOUS_MEMORY_FORMAT); + if (half_to_float) { + output_options = output_options.dtype(ScalarType::Float); + } + Tensor output = at::empty_like(self, output_options); + AT_DISPATCH_FLOATING_TYPES_AND2( + ScalarType::Half, + ScalarType::BFloat16, + self.scalar_type(), + "_safe_softmax", + [&] { + host_softmax( + self.contiguous(), dim, half_to_float, output); + }); + + return output; +} + Tensor masked_softmax_kernel( const Tensor& input_, const Tensor& mask_, diff --git a/src/ATen/native/xpu/sycl/SoftMaxKernels.h b/src/ATen/native/xpu/sycl/SoftMaxKernels.h index 0fc08496b..fc26fec3e 100644 --- a/src/ATen/native/xpu/sycl/SoftMaxKernels.h +++ b/src/ATen/native/xpu/sycl/SoftMaxKernels.h @@ -32,6 +32,9 @@ TORCH_XPU_API void _log_softmax_backward_kernel( bool half_to_float, const Tensor& grad_input); +TORCH_XPU_API Tensor +_safe_softmax_kernel(const Tensor& self, int64_t dim, const bool half_to_float); + TORCH_XPU_API Tensor masked_softmax_kernel( const Tensor& input_, const Tensor& mask_, diff --git a/test/regressions/test_safe_softmax.py b/test/regressions/test_safe_softmax.py new file mode 100644 index 000000000..7b390080a --- /dev/null +++ b/test/regressions/test_safe_softmax.py @@ -0,0 +1,44 @@ +import torch +from torch.testing._internal.common_utils import TestCase + +cpu_device = torch.device("cpu") +xpu_device = torch.device("xpu") + + +class TestSafeSoftMax(TestCase): + def test_sm(self): + for dtype in [torch.float, torch.float16, torch.bfloat16]: + x_cpu = torch.randn(128,128,128).to(dtype) + x_xpu = x_cpu.to(xpu_device) + r_cpu = torch.ops.aten._safe_softmax(x_cpu, -1) + r_xpu = torch.ops.aten._safe_softmax(x_xpu, -1) + self.assertEqual(r_xpu.to(cpu_device), r_cpu) + x_cpu[0,0,:] = -float("inf") + x_xpu = x_cpu.to(xpu_device) + r_cpu = 
torch.ops.aten._safe_softmax(x_cpu, -1) + r_xpu = torch.ops.aten._safe_softmax(x_xpu, -1) + self.assertEqual(r_xpu.to(cpu_device), r_cpu) + + x_cpu = torch.randn(128,128,128).to(dtype) + x_xpu = x_cpu.to(xpu_device) + r_cpu = torch.ops.aten._safe_softmax(x_cpu, 1) + r_xpu = torch.ops.aten._safe_softmax(x_xpu, 1) + self.assertEqual(r_xpu.to(cpu_device), r_cpu) + x_cpu[0,:,0] = -float("inf") + x_xpu = x_cpu.to(xpu_device) + r_cpu = torch.ops.aten._safe_softmax(x_cpu, 1) + r_xpu = torch.ops.aten._safe_softmax(x_xpu, 1) + self.assertEqual(r_xpu.to(cpu_device), r_cpu) + + x_cpu = torch.randn(128,128,128).to(dtype) + x_xpu = x_cpu.to(xpu_device) + r_cpu = torch.ops.aten._safe_softmax(x_cpu, 0) + r_xpu = torch.ops.aten._safe_softmax(x_xpu, 0) + self.assertEqual(r_xpu.to(cpu_device), r_cpu) + x_cpu[:,0,0] = -float("inf") + x_xpu = x_cpu.to(xpu_device) + r_cpu = torch.ops.aten._safe_softmax(x_cpu, 0) + r_xpu = torch.ops.aten._safe_softmax(x_xpu, 0) + self.assertEqual(r_xpu.to(cpu_device), r_cpu) + + diff --git a/yaml/native/native_functions.yaml b/yaml/native/native_functions.yaml index cbd57c762..7a257f0fd 100644 --- a/yaml/native/native_functions.yaml +++ b/yaml/native/native_functions.yaml @@ -2017,6 +2017,10 @@ dispatch: XPU: softmax_xpu_out +- func: _safe_softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor + dispatch: + XPU: _safe_softmax_xpu + - func: _softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor structured_delegate: _softmax_backward_data.out