diff --git a/src/ATen/native/xpu/SoftMax.cpp b/src/ATen/native/xpu/SoftMax.cpp index 95824577a..e816d48c8 100644 --- a/src/ATen/native/xpu/SoftMax.cpp +++ b/src/ATen/native/xpu/SoftMax.cpp @@ -76,4 +76,20 @@ TORCH_IMPL_FUNC(log_softmax_xpu_out) xpu::_log_softmax_kernel(input, dim, half_to_float, output); } +Tensor masked_softmax_xpu( + const Tensor& input_, + const Tensor& mask_, + const std::optional dim_, + const std::optional mask_type_) { + return xpu::masked_softmax_kernel(input_, mask_, dim_, mask_type_); +} + +Tensor masked_softmax_backward_xpu( + const Tensor& grad_, + const Tensor& output_, + const Tensor& mask_, + const std::optional dim_) { + return xpu::masked_softmax_backward_kernel(grad_, output_, mask_, dim_); +} + } // namespace at::native diff --git a/src/ATen/native/xpu/sycl/SoftMaxKernels.cpp b/src/ATen/native/xpu/sycl/SoftMaxKernels.cpp index 45c9cb016..81be06363 100644 --- a/src/ATen/native/xpu/sycl/SoftMaxKernels.cpp +++ b/src/ATen/native/xpu/sycl/SoftMaxKernels.cpp @@ -8,7 +8,9 @@ #include #include +#include #include +#include #include @@ -125,7 +127,7 @@ static inline void softmax_group_reduce_spatial( } template -static inline void get_wgroup_size( +static inline int get_wgroup_size( uint64_t dim_size, int outer_size, int& sub_group_num, @@ -146,7 +148,7 @@ static inline void get_wgroup_size( local_size_col = 1; local_size_row = SIMD; global_size_row = (outer_size + local_size_row - 1) / local_size_row; - return; + return maxWGSize; } // if outer_size is too large and local_size_col is small, @@ -165,6 +167,8 @@ static inline void get_wgroup_size( while (sub_group_num <= (range >> 1)) { range = range >> 1; } + + return maxWGSize; } // this method help to divide the computation resource for spatial_softmax @@ -383,7 +387,7 @@ template < int outer_loop, bool is_masked = false, typename calc_t = decltype(nullptr)> -void dispatch_softmax_forward_kernel( +bool dispatch_softmax_forward_kernel( scalar_t* in_data, scalar_t* out_data, int dim_size, @@ -411,14 +415,20 @@ void dispatch_softmax_forward_kernel( vec_t>; int sub_group_num, global_size_row, local_size_row, range, local_size; - get_wgroup_size( - dim_size, - outer_size, - sub_group_num, - range, - global_size_row, - local_size_row, - local_size); + int max_group_size = + get_wgroup_size( + dim_size, + outer_size, + sub_group_num, + range, + global_size_row, + local_size_row, + local_size); + + if (max_group_size * INNER_LOOP < dim_size) { + return false; + } + int64_t local_range{local_size_row * local_size}; int64_t global_range{global_size_row * local_size_row * local_size}; @@ -453,14 +463,20 @@ void dispatch_softmax_forward_kernel( vec_t>; int sub_group_num, global_size_row, local_size_row, range, local_size; - get_wgroup_size( - dim_size, - outer_size, - sub_group_num, - range, - global_size_row, - local_size_row, - local_size); + int max_group_size = + get_wgroup_size( + dim_size, + outer_size, + sub_group_num, + range, + global_size_row, + local_size_row, + local_size); + + if (max_group_size * INNER_LOOP < dim_size) { + return false; + } + int64_t local_range{local_size_row * local_size}; int64_t global_range{global_size_row * local_size_row * local_size}; @@ -480,6 +496,7 @@ void dispatch_softmax_forward_kernel( nan); sycl_kernel_submit(global_range, local_range, queue, kfn); } + return true; } template < @@ -962,7 +979,7 @@ template < bool LogSoftMax, bool is_masked = false, typename calc_t = decltype(nullptr)> -void dispatch_softmax_backward_kernel( +bool dispatch_softmax_backward_kernel( scalar_t* 
gradInput, scalar_t* output, scalar_t* gradOutput, @@ -989,7 +1006,7 @@ void dispatch_softmax_backward_kernel( vec_t, NUM>; - get_wgroup_size( + int max_group_size = get_wgroup_size( dim_size, outer_size, sub_group_num, @@ -998,6 +1015,10 @@ void dispatch_softmax_backward_kernel( local_size_row, local_size); + if (max_group_size * INNER_LOOP < dim_size) { + return false; + } + auto kfn = KernelClass( gradInput, output, @@ -1031,7 +1052,7 @@ void dispatch_softmax_backward_kernel( vec_t, NUM>; - get_wgroup_size( + int max_group_size = get_wgroup_size( dim_size, outer_size, sub_group_num, @@ -1040,6 +1061,10 @@ void dispatch_softmax_backward_kernel( local_size_row, local_size); + if (max_group_size * INNER_LOOP < dim_size) { + return false; + } + auto kfn = KernelClass( gradInput, output, @@ -1059,6 +1084,8 @@ void dispatch_softmax_backward_kernel( sycl_kernel_submit(global_range, local_range, queue, kfn); } + + return true; } template < @@ -1396,7 +1423,7 @@ void spatial_softmax_forward( #define DISPATCH_SOFTMAX_FORWARD_IMPL(vec_size, SIMD, outer_loop) \ { \ - dispatch_softmax_forward_kernel< \ + use_slow_path = !dispatch_softmax_forward_kernel< \ INNER_LOOP, \ vec_size, \ SIMD, \ @@ -1444,29 +1471,8 @@ void spatial_softmax_forward( // if the element number is smaller than max_work_group_size * INNER_LOOP, // the fast path (dispatch_softmax_forward) will be selected. // otherwise, the general path (softmax_forward_kernel) will be selected. - - // Query the smallest max work group size of the kernel template. The kernel - // instance with the largest register pressure will have the smallest max - // work group size. Memory spill probably occurs more severely than - // any other instances, then compiler probably chooses less SIMD width to - // mitgate register pressure. Actual max work group size of these kernel - // template allowed by the compiler is less than device allowed max work - // group size. - using DispatchSoftmaxForwardKernel = DispatchSoftmaxForwardKernelFunctor< - INNER_LOOP, - max_vec_size, - SIMD32, - scalar_t, - accscalar_t, - uint32_t, - LogSoftMax, - INNER_LOOP / max_vec_size, - false, - DummyFunctor, - vec_t>; - int max_group_size = syclMaxWorkGroupSize(); - - if (can_use_32bit_index && max_group_size * INNER_LOOP >= dim_size) { + bool use_slow_path = true; + if (can_use_32bit_index) { // it assumes vec_size * outer_loop * work_group_size >= dim_size if (SIMD == SIMD32) { @@ -1508,7 +1514,9 @@ void spatial_softmax_forward( /*vec_size*/ 1, /*SIMD*/ SIMD16, outer_loop); } } - } else { + } + + if (use_slow_path) { if (can_use_32bit_index) { // the start psition of tensor pointer should be the same // the kernel can handle the non-aligned status @@ -1589,7 +1597,7 @@ void spatial_softmax_backward( #define DISPATCH_SOFTMAX_BACKWARD_IMPL(vec_size, SIMD) \ { \ - dispatch_softmax_backward_kernel< \ + use_slow_path = !dispatch_softmax_backward_kernel< \ INNER_LOOP, \ vec_size, \ SIMD, \ @@ -1626,34 +1634,12 @@ void spatial_softmax_backward( outer_size); if (inner_size == 1) { - // Query the smallest max work group size of the kernel template. The kernel - // instance with the largest register pressure will have the smallest max - // work group size. Memory spill probably occurs more severely than - // any other instances, then compiler probably chooses less SIMD width to - // mitgate register pressure. Actual max work group size of these kernel - // template allowed by the compiler is less than device allowed max work - // group size. 
- constexpr int NUM = INNER_LOOP / max_vec_size /* * (SIMD32 / SIMD32) */; - using DispatchSoftmaxBackwardKernel = DispatchSoftmaxBackwardKernelFunctor< - INNER_LOOP, - max_vec_size, - SIMD32, - scalar_t, - accscalar_t, - uint32_t, - LogSoftMax, - false, /* No instance for true */ - DummyFunctor, - vec_t, - NUM>; - - int max_group_size = syclMaxWorkGroupSize(); - // if the element number is smaller than max_work_group_size * INNER_LOOP // / 2, (2 indicates reading two tensors: output and gradOutput) the fast // path (dispatch_softmax_backward) will be selected. otherwise, the // general path (softmax_backward_kernel) will be selected. - if (can_use_32bit_index && max_group_size * INNER_LOOP >= dim_size) { + bool use_slow_path = true; + if (can_use_32bit_index) { if (SIMD == SIMD32) { if (gradin_start == 0 && output_start == 0 && gradoutput_start == 0 && dim_size % max_vec_size == 0) { @@ -1671,7 +1657,9 @@ void spatial_softmax_backward( DISPATCH_SOFTMAX_BACKWARD_IMPL(/*vec_size*/ 1, /*SIMD*/ SIMD16); } } - } else { + } + + if (use_slow_path) { if (can_use_32bit_index) { if (gradin_start == output_start && gradin_start == gradoutput_start) { SOFTMAX_BACKWARD_IMPL( @@ -1712,6 +1700,222 @@ void spatial_softmax_backward( #undef SPATIAL_SOFTMAX_BACKWARD_IMPL } +template +Tensor& masked_softmax_forward( + Tensor& output, + Tensor& input, + int dim, + const Tensor mask) { + auto inner_size = input.stride(dim); + auto dim_size = input.size(dim); + auto outer_size = input.numel() / (inner_size * dim_size); + + constexpr int float4_size = sizeof(float) * 4; + constexpr int max_vec_size = float4_size / sizeof(scalar_t); + constexpr int INNER_LOOP = max_vec_size * 2; + + // decide vec_size: max_vec_size or 1 + using vec_t = at::native::memory::aligned_vector; + constexpr int align_bytes = alignof(vec_t); + int input_start = + ((uint64_t)input.data_ptr()) % align_bytes / sizeof(scalar_t); + int output_start = + ((uint64_t)output.data_ptr()) % align_bytes / sizeof(scalar_t); + + // decide indexing range: uint32_t (4GB) or uint64_t (>4GB) + bool can_use_32bit_index = + canUse32BitIndexMath(input) && canUse32BitIndexMath(output); + + // decide SIMD: SIMD32 or SIMD16 + auto* dev_prop = + at::xpu::getDeviceProperties(at::xpu::getDeviceIndexOfCurrentQueue()); + auto sub_group_size = dev_prop->sub_group_sizes; + int SIMD = sub_group_size[1]; + if (SIMD == SIMD32) { + if (dim_size < SIMD16 * INNER_LOOP) + SIMD = SIMD16; + } + +#define DISPATCH_MASK_SOFTMAX_FORWARD_IMPL(vec_size, SIMD, outer_loop) \ + { \ + use_slow_path = !dispatch_softmax_forward_kernel< \ + INNER_LOOP, \ + vec_size, \ + SIMD, \ + scalar_t, \ + accscalar_t, \ + uint32_t, \ + LogSoftMax, \ + outer_loop, \ + true, \ + decltype(input_calc)>( \ + input.data_ptr(), \ + output.data_ptr(), \ + dim_size, \ + outer_size, \ + mask.data_ptr(), \ + input_calc); \ + } + + bool use_slow_path = true; + if (inner_size == 1 && can_use_32bit_index) { + // if the element number is smaller than max_work_group_size * INNER_LOOP, + // the fast path (dispatch_softmax_forward) will be selected. + // otherwise, the general path (softmax_forward_kernel) will be selected. 
+ // it assumes vec_size * outer_loop * work_group_size >= dim_size + auto iter = TensorIterator::binary_op(output, input, mask); + auto input_calc = make_input_offset_calculator<2>(iter); + + if (SIMD == SIMD32) { + // Ensure input/output tensor are aligned with max_vec_size + if (input_start == 0 && output_start == 0 && + dim_size % max_vec_size == 0) { + constexpr int outer_loop = INNER_LOOP / max_vec_size; + DISPATCH_MASK_SOFTMAX_FORWARD_IMPL( + /*vec_size*/ max_vec_size, /*SIMD*/ SIMD32, outer_loop); + } else { + constexpr int outer_loop = INNER_LOOP; + DISPATCH_MASK_SOFTMAX_FORWARD_IMPL( + /*vec_size*/ 1, /*SIMD*/ SIMD32, outer_loop); + } + } else { + if (input_start == 0 && output_start == 0 && + dim_size % max_vec_size == 0) { + if (max_vec_size >= 4 && dim_size <= 4 * SIMD) { + // if vec_size >= 4 and dim_size <= 4 * SIMD, take smaller vec_size + // and 1 outer_loop + constexpr int outer_loop = 1; + DISPATCH_MASK_SOFTMAX_FORWARD_IMPL( + /*vec_size*/ 4, /*SIMD*/ SIMD16, outer_loop); + } else if (dim_size <= max_vec_size * SIMD) { + // if dim_size <= max_vec_size * SIMD , take 1 outer_loop + constexpr int outer_loop = 1; + DISPATCH_MASK_SOFTMAX_FORWARD_IMPL( + /*vec_size*/ max_vec_size, /*SIMD*/ SIMD16, outer_loop); + } else { + // SIMD16 will use less register numbers than SIMD32 + // if the SIMD = SIMD16, then outer_loop will be enlarged 2x + constexpr int outer_loop = INNER_LOOP / max_vec_size * 2; + DISPATCH_MASK_SOFTMAX_FORWARD_IMPL( + /*vec_size*/ max_vec_size, /*SIMD*/ SIMD16, outer_loop); + } + } else { + constexpr int outer_loop = INNER_LOOP * 2; + DISPATCH_MASK_SOFTMAX_FORWARD_IMPL( + /*vec_size*/ 1, /*SIMD*/ SIMD16, outer_loop); + } + } + } + + if (use_slow_path) { + auto mask_expand = mask.expand(input.sizes()); + output = at::softmax_out( + output, + input.masked_fill( + mask_expand, -std::numeric_limits::infinity()), + dim); + } + return output; +#undef DISPATCH_MASK_SOFTMAX_FORWARD_IMPL +} + +template +void masked_softmax_backward( + Tensor& gradInput, + Tensor& output, + Tensor& gradOutput, + Tensor& mask, + int dim) { + auto inner_size = output.stride(dim); + auto dim_size = output.size(dim); + auto outer_size = output.numel() / (dim_size * inner_size); + + constexpr int float4_size = sizeof(float) * 4; + constexpr int max_vec_size = float4_size / sizeof(scalar_t); + constexpr int INNER_LOOP = max_vec_size; + + // decide vec_size: max_vec_size or 1 + using vec_t = at::native::memory::aligned_vector; + constexpr int align_bytes = alignof(vec_t); + int gradin_start = + ((uint64_t)gradInput.data_ptr()) % align_bytes / sizeof(scalar_t); + int output_start = + ((uint64_t)output.data_ptr()) % align_bytes / sizeof(scalar_t); + int gradoutput_start = + ((uint64_t)gradOutput.data_ptr()) % align_bytes / sizeof(scalar_t); + + // decide indexing range: uint32_t (4GB) or uint64_t (>4GB) + bool can_use_32bit_index = canUse32BitIndexMath(gradInput) && + canUse32BitIndexMath(output) && canUse32BitIndexMath(gradOutput); + + // decide SIMD: SIMD32 or SIMD16 + auto* dev_prop = + at::xpu::getDeviceProperties(at::xpu::getDeviceIndexOfCurrentQueue()); + auto sub_group_size = dev_prop->sub_group_sizes; + int SIMD = sub_group_size[1]; + if (SIMD == SIMD32) { + if (dim_size < SIMD16 * max_vec_size) + SIMD = SIMD16; + } + +#define DISPATCH_MASK_SOFTMAX_BACKWARD_IMPL(vec_size, SIMD) \ + { \ + use_slow_path = !dispatch_softmax_backward_kernel< \ + INNER_LOOP, \ + vec_size, \ + SIMD, \ + scalar_t, \ + accscalar_t, \ + uint32_t, \ + LogSoftMax, \ + true, \ + decltype(input_calc)>( \ + 
gradInput.data_ptr(), \ + output.data_ptr(), \ + gradOutput.data_ptr(), \ + dim_size, \ + outer_size, \ + mask.data_ptr(), \ + input_calc); \ + } + + bool use_slow_path = true; + if (inner_size == 1 && can_use_32bit_index) { + auto iter = TensorIterator::binary_op(gradInput, gradOutput, mask); + auto input_calc = make_input_offset_calculator<2>(iter); + // if the element number is smaller than max_work_group_size * INNER_LOOP + // / 2, (2 indicates reading two tensors: output and gradOutput) the fast + // path (dispatch_softmax_backward) will be selected. otherwise, the + // general path (softmax_backward_kernel) will be selected. + if (SIMD == SIMD32) { + if (gradin_start == 0 && output_start == 0 && gradoutput_start == 0 && + dim_size % max_vec_size == 0) { + DISPATCH_MASK_SOFTMAX_BACKWARD_IMPL( + /*vec_size*/ max_vec_size, /*SIMD*/ SIMD32); + } else { + DISPATCH_MASK_SOFTMAX_BACKWARD_IMPL(/*vec_size*/ 1, /*SIMD*/ SIMD32); + } + } else { + if (gradin_start == 0 && output_start == 0 && gradoutput_start == 0 && + dim_size % max_vec_size == 0) { + DISPATCH_MASK_SOFTMAX_BACKWARD_IMPL( + /*vec_size*/ max_vec_size, /*SIMD*/ SIMD16); + } else { + DISPATCH_MASK_SOFTMAX_BACKWARD_IMPL(/*vec_size*/ 1, /*SIMD*/ SIMD16); + } + } + } + if (use_slow_path) { + gradInput = at::_softmax_backward_data_out( + gradInput, + gradOutput, + output.masked_fill(mask, 0), + dim, + gradOutput.scalar_type()); + } +#undef DISPATCH_SOFTMAX_BACKWARD_IMPL +} + #undef MIN_WG_NUM #undef SIMD16 #undef SIMD32 @@ -1840,6 +2044,108 @@ void _log_softmax_backward_kernel( grad.contiguous(), output.contiguous(), dim, half_to_float, grad_input); } +Tensor masked_softmax_kernel( + const Tensor& input_, + const Tensor& mask_, + const c10::optional dim_, + const c10::optional mask_type_) { + Tensor output = at::empty_like(input_, input_.options()); + TORCH_CHECK( + mask_.scalar_type() == ScalarType::Bool, + "Mask should be a boolean tensor"); + + TORCH_CHECK(mask_type_.has_value(), "Mask Type should be defined"); + int64_t mask_type = mask_type_.value(); + TORCH_CHECK( + (mask_type == 0) || (mask_type == 1) || (mask_type == 2), + "Mask Type should be 0 (src_mask), 1 (src_key_padding_mask), or 2 (default_mask)"); + + // If input is [B, H, T, T] and mask is [B, T] + // we have special fast kernel + // mask_type == 1 => mask_ is a src_key_padding_mask + bool is_BxT_mask = (mask_type == 1) && + (input_.dim() == 4 && mask_.dim() == 2 && + input_.size(0) == mask_.size(0) && input_.size(2) == mask_.size(1) && + input_.size(3) == mask_.size(1)); + + // If input is [B, H, T, T] and mask is [T, T] + // expand mask to [B, H, T, T] and treat it like regular mask + // TODO We should have special fast kernel for TxT mask as well + // mask_type == 0 => mask_ is a src_mask + bool is_TxT_mask = (mask_type == 0) && input_.dim() == 4 && + mask_.dim() == 2 && input_.size(3) == mask_.size(1) && + input_.size(2) == mask_.size(0) && mask_.size(0) == mask_.size(1); + // If mask_type == 2, then mask_.sizes() must equal input_.sizes() + TORCH_CHECK( + mask_.sizes() == input_.sizes() || is_BxT_mask || is_TxT_mask, + "Mask shape should match input. mask: ", + mask_.sizes(), + " input: ", + input_.sizes()); + + auto input = input_.dim() == 0 ? input_.view(1) : input_; + auto mask = mask_.dim() == 0 ? mask_.view(1) : mask_; + int64_t dim = dim_.has_value() ? 
dim_.value() : input.dim() - 1; + + if (is_BxT_mask) { + mask = mask.view({mask_.size(0), 1, 1, mask_.size(1)}); + } + // Here assumes that the mask is broadcastable for input + AT_DISPATCH_FLOATING_TYPES_AND2( + ScalarType::Half, + ScalarType::BFloat16, + input.scalar_type(), + "masked_softmax", + [&] { + using accscalar_t = acc_type_device; + impl::masked_softmax_forward( + output, input, dim, mask); + }); + return output; +} + +Tensor masked_softmax_backward_kernel( + const Tensor& grad_, + const Tensor& output_, + const Tensor& mask_, + const c10::optional dim_) { + Tensor grad_input = at::empty_like(grad_, grad_.options()); + if (grad_.numel() == 0) { + return grad_input; + } + + auto grad = grad_.contiguous(); + auto output = output_.contiguous(); + auto mask = mask_.contiguous(); + int64_t dim = dim_.has_value() ? maybe_wrap_dim(dim_.value(), output.dim()) + : output.dim() - 1; + + grad = grad.dim() == 0 ? grad.view(1) : grad; + mask = mask.dim() == 0 ? mask.view(1) : mask; + output = output.dim() == 0 ? output.view(1) : output; + + TORCH_CHECK( + dim >= 0 && dim < grad.dim(), + "dim must be non-negative and less than input dimensions"); + TORCH_CHECK( + grad.sizes() == mask.sizes(), "Mask shape should match grad shape"); + TORCH_CHECK( + mask.scalar_type() == ScalarType::Bool, + "Mask should be a boolean tensor"); + + AT_DISPATCH_FLOATING_TYPES_AND2( + ScalarType::Half, + ScalarType::BFloat16, + grad_input.scalar_type(), + "masked_softmax_backward", + [&] { + using accscalar_t = acc_type_device; + impl::masked_softmax_backward( + grad_input, output, grad, mask, dim); + }); + return grad_input; +} + } // namespace xpu } // namespace native } // namespace at diff --git a/src/ATen/native/xpu/sycl/SoftMaxKernels.h b/src/ATen/native/xpu/sycl/SoftMaxKernels.h index 87d205442..0fc08496b 100644 --- a/src/ATen/native/xpu/sycl/SoftMaxKernels.h +++ b/src/ATen/native/xpu/sycl/SoftMaxKernels.h @@ -32,6 +32,18 @@ TORCH_XPU_API void _log_softmax_backward_kernel( bool half_to_float, const Tensor& grad_input); +TORCH_XPU_API Tensor masked_softmax_kernel( + const Tensor& input_, + const Tensor& mask_, + const c10::optional dim_, + const c10::optional mask_type_); + +TORCH_XPU_API Tensor masked_softmax_backward_kernel( + const Tensor& grad_, + const Tensor& output_, + const Tensor& mask_, + const c10::optional dim_); + } // namespace xpu } // namespace native } // namespace at diff --git a/test/xpu/xpu_test_utils.py b/test/xpu/xpu_test_utils.py index a2c7522cd..49bbf1c1a 100644 --- a/test/xpu/xpu_test_utils.py +++ b/test/xpu/xpu_test_utils.py @@ -180,6 +180,7 @@ "cat", "log_softmax", "softmax", + "_softmax_backward_data", "scatter", "gather", "nn.functional.adaptive_max_pool2d", diff --git a/yaml/native/native_functions.yaml b/yaml/native/native_functions.yaml index 1b5d5d01c..dd45febf6 100644 --- a/yaml/native/native_functions.yaml +++ b/yaml/native/native_functions.yaml @@ -1826,6 +1826,16 @@ dispatch: XPU: log_softmax_backward_xpu_out +- func: _masked_softmax(Tensor self, Tensor mask, int? dim=None, int? mask_type=None) -> Tensor + dispatch: + XPU: masked_softmax_xpu + autogen: _masked_softmax.out + +- func: _masked_softmax_backward(Tensor grad_output, Tensor output, Tensor mask, int? dim=None) -> Tensor + dispatch: + XPU: masked_softmax_backward_xpu + autogen: _masked_softmax_backward.out + - func: exp(Tensor self) -> Tensor device_check: NoCheck # TensorIterator structured_delegate: exp.out
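
For reference, a minimal sketch of how the forward op registered above in `native_functions.yaml` can be exercised and checked against the slow path it falls back to in `masked_softmax_forward` (`at::softmax_out` on `input.masked_fill(mask, -inf)`). This is hedged: it assumes a build of this tree with the XPU backend available and that the private `torch._masked_softmax` binding is exposed. `mask_type=2` is the case where the mask has the same shape as the input, and the mask here is constructed so no row is fully masked (a fully masked row softmaxes to NaN in the reference).

```python
import torch

# Hedged sketch: assumes an XPU build with the _masked_softmax schema
# from native_functions.yaml registered, and that the private
# torch._masked_softmax binding is available.
device = "xpu"
B, H, T = 2, 4, 8
x = torch.randn(B, H, T, T, device=device)             # [B, H, T, T]
mask = torch.zeros(B, H, T, T, dtype=torch.bool, device=device)
mask[..., T // 2:] = True                               # True = masked out; every row keeps T//2 elements
dim = x.dim() - 1
mask_type = 2                                           # 2: mask has the same shape as the input

out = torch._masked_softmax(x, mask, dim, mask_type)

# Reference mirrors the slow path in masked_softmax_forward:
# softmax(input.masked_fill(mask, -inf), dim).
ref = torch.softmax(x.masked_fill(mask, float("-inf")), dim)
torch.testing.assert_close(out, ref)
```

With a small `dim_size` like this the dispatcher should take the fast path (`dispatch_softmax_forward_kernel` now returns `false` when `max_group_size * INNER_LOOP < dim_size`, in which case `use_slow_path` kicks in), but the check above is path-agnostic. For `mask_type=1` (a `[B, T]` src_key_padding_mask) the kernel reshapes the mask to `[B, 1, 1, T]` and broadcasts it, so the same reference check applies after expanding the mask by hand.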
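
A matching sketch for the backward op, again checked against the slow path in `masked_softmax_backward` (`_softmax_backward_data` on `output.masked_fill(mask, 0)`). The `torch._masked_softmax_backward` and `torch._softmax_backward_data` bindings are assumed to be available; note the backward kernel requires `mask.sizes() == grad.sizes()`.

```python
import torch

# Hedged sketch: assumes the same XPU build and private bindings as above.
device = "xpu"
B, H, T = 2, 4, 8
x = torch.randn(B, H, T, T, device=device)
mask = torch.zeros(B, H, T, T, dtype=torch.bool, device=device)
mask[..., T // 2:] = True                  # backward requires mask.sizes() == grad.sizes()
dim = x.dim() - 1

out = torch._masked_softmax(x, mask, dim, 2)
grad_out = torch.randn_like(x)

grad_in = torch._masked_softmax_backward(grad_out, out, mask, dim)

# Reference mirrors the slow path in masked_softmax_backward:
# _softmax_backward_data(grad, output.masked_fill(mask, 0), dim, grad.dtype).
ref = torch._softmax_backward_data(
    grad_out, out.masked_fill(mask, 0), dim, grad_out.dtype
)
torch.testing.assert_close(grad_in, ref)
```

Because the forward output is already zero at masked positions, this reference agrees with the fast-path kernel whether the mask is applied to `output` or to `gradOutput`, so the same check covers both dispatch outcomes.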