Skip to content

Commit

Permalink
Add at::_safe_softmax op (#1180)
Browse files Browse the repository at this point in the history
  • Loading branch information
2 people authored and ZhiweiYan-96 committed Jan 16, 2025
1 parent a8e5162 commit 76779e8
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 25 deletions.
11 changes: 11 additions & 0 deletions src/ATen/native/xpu/SoftMax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,17 @@ TORCH_IMPL_FUNC(log_softmax_xpu_out)
xpu::_log_softmax_kernel(input, dim, half_to_float, output);
}

// XPU entry point registered for aten::_safe_softmax.
//
// If `dtype` is set, `self` is converted to that type first; the (possibly
// converted) tensor and `dim` are then forwarded to
// xpu::_safe_softmax_kernel. The kernel is always called with
// half_to_float == false.
Tensor _safe_softmax_xpu(
    const Tensor& self,
    int64_t dim,
    std::optional<ScalarType> dtype) {
  // TODO: enable the fused Half -> Float path once XPU softmax supports
  // half_to_float=true:
  //   if (self.scalar_type() == ScalarType::Half && dtype == ScalarType::Float)
  //     return xpu::_safe_softmax_kernel(self, dim, true);
  constexpr bool kHalfToFloat = false;

  // No explicit dtype requested: run directly on the input tensor.
  if (!dtype) {
    return xpu::_safe_softmax_kernel(self, dim, kHalfToFloat);
  }

  // Explicit dtype requested: cast first, then run the kernel on the cast.
  return xpu::_safe_softmax_kernel(self.toType(*dtype), dim, kHalfToFloat);
}

Tensor masked_softmax_xpu(
const Tensor& input_,
const Tensor& mask_,
Expand Down
115 changes: 90 additions & 25 deletions src/ATen/native/xpu/sycl/SoftMaxKernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,8 @@ template <
int outer_loop,
bool is_masked,
typename calc_t,
typename vec_t>
typename vec_t,
bool is_safe_softmax>
struct DispatchSoftmaxForwardKernelFunctor
: public __SYCL_KER_CONFIG_CONVENTION__ {
[[intel::reqd_sub_group_size(SIMD)]] void operator()(
Expand Down Expand Up @@ -240,7 +241,8 @@ struct DispatchSoftmaxForwardKernelFunctor
if (index >= dim_size_)
break;

reg_in[i] = *(reinterpret_cast<const vec_t*>(in_data_ + group_offset + index));
reg_in[i] =
*(reinterpret_cast<const vec_t*>(in_data_ + group_offset + index));
if constexpr (is_masked) {
auto vec_offset = group_offset + index;
#pragma unroll(vec_size)
Expand Down Expand Up @@ -309,6 +311,10 @@ struct DispatchSoftmaxForwardKernelFunctor
if constexpr (LogSoftMax) {
reg_in[i][j] =
static_cast<scalar_t>(reg_in[i][j] - max_value - sum_value);
} else if (
is_safe_softmax &&
max_value == std::numeric_limits<accscalar_t>::lowest()) {
reg_in[i][j] = static_cast<scalar_t>(0);
} else if (sum_value == 0) {
reg_in[i][j] = nan_;
} else {
Expand Down Expand Up @@ -386,7 +392,8 @@ template <
bool LogSoftMax,
int outer_loop,
bool is_masked = false,
typename calc_t = decltype(nullptr)>
typename calc_t = decltype(nullptr),
bool is_safe_softmax = false>
bool dispatch_softmax_forward_kernel(
const scalar_t* in_data,
scalar_t* out_data,
Expand All @@ -412,7 +419,8 @@ bool dispatch_softmax_forward_kernel(
outer_loop,
is_masked,
calc_t,
vec_t>;
vec_t,
/*is_safe_softmax = */ false>;

int sub_group_num, global_size_row, local_size_row, range, local_size;
int max_group_size =
Expand Down Expand Up @@ -460,8 +468,8 @@ bool dispatch_softmax_forward_kernel(
outer_loop,
is_masked,
DummyFunctor,
vec_t>;

vec_t,
is_safe_softmax>;
int sub_group_num, global_size_row, local_size_row, range, local_size;
int max_group_size =
get_wgroup_size<SIMD, vec_size, outer_loop, KernelClass>(
Expand Down Expand Up @@ -506,7 +514,8 @@ template <
typename IndexType,
bool LogSoftMax,
typename vec_t,
int align_bytes>
int align_bytes,
bool is_safe_softmax>
struct SoftmaxForwardKernelFunctor {
void operator()(sycl::nd_item<1> item) const {
IndexType local_id = item.get_local_id(0);
Expand Down Expand Up @@ -562,6 +571,10 @@ struct SoftmaxForwardKernelFunctor {
if (LogSoftMax)
out_data_[group_offset + linear_idx] = static_cast<scalar_t>(
in_data_[group_offset + linear_idx] - max_value - sum_value);
else if (
is_safe_softmax &&
max_value == std::numeric_limits<accscalar_t>::lowest())
out_data_[group_offset + linear_idx] = static_cast<scalar_t>(0);
else
out_data_[group_offset + linear_idx] = static_cast<scalar_t>(
std::exp(in_data_[group_offset + linear_idx] - max_value) *
Expand All @@ -576,6 +589,10 @@ struct SoftmaxForwardKernelFunctor {
if (LogSoftMax)
in_val[j] =
static_cast<scalar_t>(in_val[j] - max_value - sum_value);
else if (
is_safe_softmax &&
max_value == std::numeric_limits<accscalar_t>::lowest())
in_val[j] = static_cast<scalar_t>(0);
else
in_val[j] = static_cast<scalar_t>(
std::exp(in_val[j] - max_value) * sum_value);
Expand Down Expand Up @@ -610,7 +627,8 @@ template <
typename scalar_t,
typename accscalar_t,
typename IndexType,
bool LogSoftMax>
bool LogSoftMax,
bool is_safe_softmax>
void softmax_forward_kernel(
const scalar_t* in_data,
scalar_t* out_data,
Expand All @@ -625,7 +643,8 @@ void softmax_forward_kernel(
IndexType,
LogSoftMax,
vec_t,
align_bytes>;
align_bytes,
is_safe_softmax>;

int local_size = std::min(
(dim_size + vec_size - 1) / vec_size,
Expand All @@ -645,7 +664,8 @@ template <
typename accscalar_t,
typename IndexType,
bool LogSoftMax,
typename vec_t>
typename vec_t,
bool is_safe_softmax>
struct SpatialSoftmaxForwardKernelFunctor
: public __SYCL_KER_CONFIG_CONVENTION__ {
void operator()(sycl::nd_item<3> item) const {
Expand All @@ -658,14 +678,16 @@ struct SpatialSoftmaxForwardKernelFunctor
// get max value
accscalar_t max_value[vec_size];
auto offset = local_row_id * inner_size_ + global_col * vec_size;
vec_t value = *(reinterpret_cast<const vec_t*>(in_data_ + group_offset + offset));
vec_t value =
*(reinterpret_cast<const vec_t*>(in_data_ + group_offset + offset));
#pragma unroll(vec_size)
for (int j = 0; j < vec_size; ++j) {
max_value[j] = accscalar_t(value[j]);
}
for (int i = local_row_id + block_row_; i < dim_size_; i += block_row_) {
offset = i * inner_size_ + global_col * vec_size;
value = *(reinterpret_cast<const vec_t*>(in_data_ + group_offset + offset));
value =
*(reinterpret_cast<const vec_t*>(in_data_ + group_offset + offset));
#pragma unroll(vec_size)
for (int j = 0; j < vec_size; ++j) {
max_value[j] = std::max(max_value[j], accscalar_t(value[j]));
Expand Down Expand Up @@ -695,7 +717,8 @@ struct SpatialSoftmaxForwardKernelFunctor
}
for (int i = local_row_id + block_row_; i < dim_size_; i += block_row_) {
offset = i * inner_size_ + global_col * vec_size;
value = *(reinterpret_cast<const vec_t*>(in_data_ + group_offset + offset));
value =
*(reinterpret_cast<const vec_t*>(in_data_ + group_offset + offset));
#pragma unroll(vec_size)
for (int j = 0; j < vec_size; ++j) {
sum_value[j] += std::exp(value[j] - max_value[j]);
Expand Down Expand Up @@ -736,6 +759,10 @@ struct SpatialSoftmaxForwardKernelFunctor
if (LogSoftMax)
in_val[j] =
static_cast<scalar_t>(in_val[j] - max_value[j] - sum_value[j]);
else if (
is_safe_softmax &&
max_value[j] == -std::numeric_limits<scalar_t>::infinity())
in_val[j] = static_cast<scalar_t>(0);
else
in_val[j] = static_cast<scalar_t>(
std::exp(in_val[j] - max_value[j]) * sum_value[j]);
Expand Down Expand Up @@ -787,7 +814,8 @@ template <
typename scalar_t,
typename accscalar_t,
typename IndexType,
bool LogSoftMax>
bool LogSoftMax,
bool is_safe_softmax>
void spatial_softmax_forward(
const scalar_t* in_data,
scalar_t* out_data,
Expand All @@ -801,7 +829,8 @@ void spatial_softmax_forward(
accscalar_t,
IndexType,
LogSoftMax,
vec_t>;
vec_t,
is_safe_softmax>;

int local_size, block_row;
get_wgroup_size_spatial<vec_size, KernelClass>(
Expand All @@ -818,7 +847,8 @@ void spatial_softmax_forward(
accscalar_t,
IndexType,
LogSoftMax,
vec_t>(
vec_t,
is_safe_softmax>(
in_data,
out_data,
dim_size,
Expand All @@ -827,7 +857,6 @@ void spatial_softmax_forward(
local_size,
block_row,
group_num);

auto& queue = getCurrentSYCLQueue();
sycl_kernel_submit(global_range, local_range, queue, kfn);
}
Expand Down Expand Up @@ -1387,7 +1416,11 @@ void spatial_softmax_backward_kernel(
sycl_kernel_submit(global_range, local_range, queue, kfn);
}

template <typename scalar_t, typename accscalar_t, bool LogSoftMax>
template <
typename scalar_t,
typename accscalar_t,
bool LogSoftMax,
bool is_safe_softmax>
void spatial_softmax_forward(
const Tensor& output,
const Tensor& input,
Expand Down Expand Up @@ -1432,7 +1465,10 @@ void spatial_softmax_forward(
accscalar_t, \
uint32_t, \
LogSoftMax, \
outer_loop>( \
outer_loop, \
/*is_masked = */ false, \
/*calc_t = */ decltype(nullptr), \
/*is_safe_softmax = */ is_safe_softmax>( \
input.const_data_ptr<scalar_t>(), \
output.mutable_data_ptr<scalar_t>(), \
dim_size, \
Expand All @@ -1446,7 +1482,8 @@ void spatial_softmax_forward(
scalar_t, \
accscalar_t, \
IndexType, \
LogSoftMax>( \
LogSoftMax, \
is_safe_softmax>( \
input.const_data_ptr<scalar_t>(), \
output.mutable_data_ptr<scalar_t>(), \
dim_size, \
Expand All @@ -1460,7 +1497,8 @@ void spatial_softmax_forward(
scalar_t, \
accscalar_t, \
IndexType, \
LogSoftMax>( \
LogSoftMax, \
is_safe_softmax>( \
input.const_data_ptr<scalar_t>(), \
output.mutable_data_ptr<scalar_t>(), \
dim_size, \
Expand Down Expand Up @@ -1749,7 +1787,8 @@ Tensor& masked_softmax_forward(
LogSoftMax, \
outer_loop, \
true, \
decltype(input_calc)>( \
decltype(input_calc), \
/*is_safe_softmax = */ false>( \
input.const_data_ptr<scalar_t>(), \
output.mutable_data_ptr<scalar_t>(), \
dim_size, \
Expand Down Expand Up @@ -1922,7 +1961,7 @@ void masked_softmax_backward(
#undef SIMD32
} // namespace impl

template <bool LogSoftMax>
template <bool LogSoftMax, bool is_safe_softmax = false>
void host_softmax(
const Tensor& input_,
const int64_t dim_,
Expand Down Expand Up @@ -1953,8 +1992,11 @@ void host_softmax(
"host_softmax",
[&] {
using accscalar_t = acc_type_device<scalar_t, kXPU>;
impl::spatial_softmax_forward<scalar_t, accscalar_t, LogSoftMax>(
output, input, dim);
impl::spatial_softmax_forward<
scalar_t,
accscalar_t,
LogSoftMax,
is_safe_softmax>(output, input, dim);
});
}
// return output;
Expand Down Expand Up @@ -2045,6 +2087,29 @@ void _log_softmax_backward_kernel(
grad.contiguous(), output.contiguous(), dim, half_to_float, grad_input);
}

// Allocates the result tensor and runs the safe-softmax forward pass.
//
// The output is created with at::empty_like(self) using
// LEGACY_CONTIGUOUS_MEMORY_FORMAT; its dtype is switched to Float when
// `half_to_float` is true. The computation itself is delegated to
// host_softmax<false, true>, i.e. LogSoftMax = false, is_safe_softmax = true,
// invoked on self.contiguous().
Tensor _safe_softmax_kernel(
    const Tensor& self,
    int64_t dim,
    const bool half_to_float) {
  const auto base_options =
      self.options().memory_format(LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  Tensor output = at::empty_like(
      self,
      half_to_float ? base_options.dtype(ScalarType::Float) : base_options);

  // Dispatch over the floating types plus Half and BFloat16, keyed on the
  // input's scalar type.
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::Half,
      ScalarType::BFloat16,
      self.scalar_type(),
      "_safe_softmax",
      [&] {
        host_softmax</*LogSoftMax=*/false, /*is_safe_softmax=*/true>(
            self.contiguous(), dim, half_to_float, output);
      });
  return output;
}

Tensor masked_softmax_kernel(
const Tensor& input_,
const Tensor& mask_,
Expand Down
3 changes: 3 additions & 0 deletions src/ATen/native/xpu/sycl/SoftMaxKernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ TORCH_XPU_API void _log_softmax_backward_kernel(
bool half_to_float,
const Tensor& grad_input);

// Safe-softmax forward along `dim` of `self`. Returns a newly allocated
// tensor shaped like `self`; when `half_to_float` is true the output dtype is
// Float instead of the input's dtype.
TORCH_XPU_API Tensor
_safe_softmax_kernel(const Tensor& self, int64_t dim, const bool half_to_float);

TORCH_XPU_API Tensor masked_softmax_kernel(
const Tensor& input_,
const Tensor& mask_,
Expand Down
44 changes: 44 additions & 0 deletions test/regressions/test_safe_softmax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import torch
from torch.testing._internal.common_utils import TestCase

cpu_device = torch.device("cpu")
xpu_device = torch.device("xpu")


class TestSafeSoftMax(TestCase):
    """Compares aten._safe_softmax on XPU against the CPU result."""

    def test_sm(self):
        """Check dims -1, 1 and 0, with and without a fully -inf slice."""

        def check_matches_cpu(x_cpu, dim):
            # Same order as a hand-written check: move to XPU, run CPU, run
            # XPU, then compare on the CPU side.
            x_xpu = x_cpu.to(xpu_device)
            r_cpu = torch.ops.aten._safe_softmax(x_cpu, dim)
            r_xpu = torch.ops.aten._safe_softmax(x_xpu, dim)
            self.assertEqual(r_xpu.to(cpu_device), r_cpu)

        whole = slice(None)
        # (softmax dim, index of one complete slice along that dim). Filling
        # the slice with -inf makes every element being reduced -inf.
        cases = (
            (-1, (0, 0, whole)),
            (1, (0, whole, 0)),
            (0, (whole, 0, 0)),
        )

        for dtype in [torch.float, torch.float16, torch.bfloat16]:
            for dim, masked_slice in cases:
                x_cpu = torch.randn(128, 128, 128).to(dtype)
                check_matches_cpu(x_cpu, dim)
                x_cpu[masked_slice] = -float("inf")
                check_matches_cpu(x_cpu, dim)


4 changes: 4 additions & 0 deletions yaml/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2017,6 +2017,10 @@
dispatch:
XPU: softmax_xpu_out

# aten::_safe_softmax: on XPU this dispatches to _safe_softmax_xpu
# (src/ATen/native/xpu/SoftMax.cpp).
- func: _safe_softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor
dispatch:
XPU: _safe_softmax_xpu

- func: _softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor
structured_delegate: _softmax_backward_data.out

Expand Down

0 comments on commit 76779e8

Please sign in to comment.