Add at::_safe_softmax op #1180

Merged: 8 commits, Dec 26, 2024
11 changes: 11 additions & 0 deletions src/ATen/native/xpu/SoftMax.cpp
@@ -76,6 +76,17 @@ TORCH_IMPL_FUNC(log_softmax_xpu_out)
xpu::_log_softmax_kernel(input, dim, half_to_float, output);
}

Tensor _safe_softmax_xpu(
const Tensor& self,
int64_t dim,
std::optional<ScalarType> dtype) {
// TODO: uncomment after XPU softmax supports half_to_float=true
// if (self.scalar_type() == ScalarType::Half && dtype == ScalarType::Float)
//   return xpu::_safe_softmax_kernel(self, dim, true);
Tensor converted = dtype.has_value() ? self.toType(dtype.value()) : self;
return xpu::_safe_softmax_kernel(converted, dim, false);
}

Tensor masked_softmax_xpu(
const Tensor& input_,
const Tensor& mask_,
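Note on the new _safe_softmax_xpu wrapper above: the optional dtype conversion is done on the host via toType, and the half_to_float fast path stays disabled until the XPU softmax kernels support it. A minimal Python sketch of the op-level contract it implements (the helper name is hypothetical, not part of the PR); the only difference from a plain softmax is that fully masked rows come back as zeros instead of NaN:

import torch

def safe_softmax_reference(x, dim, dtype=None):
    # Same as softmax, except rows that are entirely -inf along `dim`
    # (fully masked rows) produce zeros instead of NaN.
    if dtype is not None:
        x = x.to(dtype)
    out = torch.softmax(x, dim)
    fully_masked = (x == float("-inf")).all(dim=dim, keepdim=True)
    return torch.where(fully_masked, torch.zeros_like(out), out)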
115 changes: 90 additions & 25 deletions src/ATen/native/xpu/sycl/SoftMaxKernels.cpp
@@ -210,7 +210,8 @@ template <
int outer_loop,
bool is_masked,
typename calc_t,
typename vec_t>
typename vec_t,
bool is_safe_softmax>
struct DispatchSoftmaxForwardKernelFunctor
: public __SYCL_KER_CONFIG_CONVENTION__ {
[[intel::reqd_sub_group_size(SIMD)]] void operator()(
@@ -240,7 +241,8 @@ struct DispatchSoftmaxForwardKernelFunctor
if (index >= dim_size_)
break;

reg_in[i] = *(reinterpret_cast<const vec_t*>(in_data_ + group_offset + index));
reg_in[i] =
*(reinterpret_cast<const vec_t*>(in_data_ + group_offset + index));
if constexpr (is_masked) {
auto vec_offset = group_offset + index;
#pragma unroll(vec_size)
@@ -309,6 +311,10 @@ struct DispatchSoftmaxForwardKernelFunctor
if constexpr (LogSoftMax) {
reg_in[i][j] =
static_cast<scalar_t>(reg_in[i][j] - max_value - sum_value);
} else if (
is_safe_softmax &&
max_value == std::numeric_limits<accscalar_t>::lowest()) {
reg_in[i][j] = static_cast<scalar_t>(0);
} else if (sum_value == 0) {
reg_in[i][j] = nan_;
} else {
@@ -386,7 +392,8 @@ template <
bool LogSoftMax,
int outer_loop,
bool is_masked = false,
typename calc_t = decltype(nullptr)>
typename calc_t = decltype(nullptr),
bool is_safe_softmax = false>
bool dispatch_softmax_forward_kernel(
const scalar_t* in_data,
scalar_t* out_data,
@@ -412,7 +419,8 @@ bool dispatch_softmax_forward_kernel(
outer_loop,
is_masked,
calc_t,
vec_t>;
vec_t,
/*is_safe_softmax = */ false>;

int sub_group_num, global_size_row, local_size_row, range, local_size;
int max_group_size =
@@ -460,8 +468,8 @@ bool dispatch_softmax_forward_kernel(
outer_loop,
is_masked,
DummyFunctor,
vec_t>;

vec_t,
is_safe_softmax>;
int sub_group_num, global_size_row, local_size_row, range, local_size;
int max_group_size =
get_wgroup_size<SIMD, vec_size, outer_loop, KernelClass>(
@@ -506,7 +514,8 @@ template <
typename IndexType,
bool LogSoftMax,
typename vec_t,
int align_bytes>
int align_bytes,
bool is_safe_softmax>
struct SoftmaxForwardKernelFunctor {
void operator()(sycl::nd_item<1> item) const {
IndexType local_id = item.get_local_id(0);
@@ -562,6 +571,10 @@ struct SoftmaxForwardKernelFunctor {
if (LogSoftMax)
out_data_[group_offset + linear_idx] = static_cast<scalar_t>(
in_data_[group_offset + linear_idx] - max_value - sum_value);
else if (
is_safe_softmax &&
max_value == std::numeric_limits<accscalar_t>::lowest())
out_data_[group_offset + linear_idx] = static_cast<scalar_t>(0);
else
out_data_[group_offset + linear_idx] = static_cast<scalar_t>(
std::exp(in_data_[group_offset + linear_idx] - max_value) *
@@ -576,6 +589,10 @@ struct SoftmaxForwardKernelFunctor {
if (LogSoftMax)
in_val[j] =
static_cast<scalar_t>(in_val[j] - max_value - sum_value);
else if (
is_safe_softmax &&
max_value == std::numeric_limits<accscalar_t>::lowest())
in_val[j] = static_cast<scalar_t>(0);
else
in_val[j] = static_cast<scalar_t>(
std::exp(in_val[j] - max_value) * sum_value);
@@ -610,7 +627,8 @@ template <
typename scalar_t,
typename accscalar_t,
typename IndexType,
bool LogSoftMax>
bool LogSoftMax,
bool is_safe_softmax>
void softmax_forward_kernel(
const scalar_t* in_data,
scalar_t* out_data,
@@ -625,7 +643,8 @@ void softmax_forward_kernel(
IndexType,
LogSoftMax,
vec_t,
align_bytes>;
align_bytes,
is_safe_softmax>;

int local_size = std::min(
(dim_size + vec_size - 1) / vec_size,
@@ -645,7 +664,8 @@ template <
typename accscalar_t,
typename IndexType,
bool LogSoftMax,
typename vec_t>
typename vec_t,
bool is_safe_softmax>
struct SpatialSoftmaxForwardKernelFunctor
: public __SYCL_KER_CONFIG_CONVENTION__ {
void operator()(sycl::nd_item<3> item) const {
@@ -658,14 +678,16 @@ struct SpatialSoftmaxForwardKernelFunctor
// get max value
accscalar_t max_value[vec_size];
auto offset = local_row_id * inner_size_ + global_col * vec_size;
vec_t value = *(reinterpret_cast<const vec_t*>(in_data_ + group_offset + offset));
vec_t value =
*(reinterpret_cast<const vec_t*>(in_data_ + group_offset + offset));
#pragma unroll(vec_size)
for (int j = 0; j < vec_size; ++j) {
max_value[j] = accscalar_t(value[j]);
}
for (int i = local_row_id + block_row_; i < dim_size_; i += block_row_) {
offset = i * inner_size_ + global_col * vec_size;
value = *(reinterpret_cast<const vec_t*>(in_data_ + group_offset + offset));
value =
*(reinterpret_cast<const vec_t*>(in_data_ + group_offset + offset));
#pragma unroll(vec_size)
for (int j = 0; j < vec_size; ++j) {
max_value[j] = std::max(max_value[j], accscalar_t(value[j]));
@@ -695,7 +717,8 @@ struct SpatialSoftmaxForwardKernelFunctor
}
for (int i = local_row_id + block_row_; i < dim_size_; i += block_row_) {
offset = i * inner_size_ + global_col * vec_size;
value = *(reinterpret_cast<const vec_t*>(in_data_ + group_offset + offset));
value =
*(reinterpret_cast<const vec_t*>(in_data_ + group_offset + offset));
#pragma unroll(vec_size)
for (int j = 0; j < vec_size; ++j) {
sum_value[j] += std::exp(value[j] - max_value[j]);
@@ -736,6 +759,10 @@ struct SpatialSoftmaxForwardKernelFunctor
if (LogSoftMax)
in_val[j] =
static_cast<scalar_t>(in_val[j] - max_value[j] - sum_value[j]);
else if (
is_safe_softmax &&
max_value[j] == -std::numeric_limits<scalar_t>::infinity())
in_val[j] = static_cast<scalar_t>(0);
else
in_val[j] = static_cast<scalar_t>(
std::exp(in_val[j] - max_value[j]) * sum_value[j]);
@@ -787,7 +814,8 @@ template <
typename scalar_t,
typename accscalar_t,
typename IndexType,
bool LogSoftMax>
bool LogSoftMax,
bool is_safe_softmax>
void spatial_softmax_forward(
const scalar_t* in_data,
scalar_t* out_data,
@@ -801,7 +829,8 @@ void spatial_softmax_forward(
accscalar_t,
IndexType,
LogSoftMax,
vec_t>;
vec_t,
is_safe_softmax>;

int local_size, block_row;
get_wgroup_size_spatial<vec_size, KernelClass>(
@@ -818,7 +847,8 @@
accscalar_t,
IndexType,
LogSoftMax,
vec_t>(
vec_t,
is_safe_softmax>(
in_data,
out_data,
dim_size,
Expand All @@ -827,7 +857,6 @@ void spatial_softmax_forward(
local_size,
block_row,
group_num);

auto& queue = getCurrentSYCLQueue();
sycl_kernel_submit(global_range, local_range, queue, kfn);
}
@@ -1387,7 +1416,11 @@ void spatial_softmax_backward_kernel(
sycl_kernel_submit(global_range, local_range, queue, kfn);
}

template <typename scalar_t, typename accscalar_t, bool LogSoftMax>
template <
typename scalar_t,
typename accscalar_t,
bool LogSoftMax,
bool is_safe_softmax>
void spatial_softmax_forward(
const Tensor& output,
const Tensor& input,
@@ -1432,7 +1465,10 @@ void spatial_softmax_forward(
accscalar_t, \
uint32_t, \
LogSoftMax, \
outer_loop>( \
outer_loop, \
/*is_masked = */ false, \
/*calc_t = */ decltype(nullptr), \
/*is_safe_softmax = */ is_safe_softmax>( \
input.const_data_ptr<scalar_t>(), \
output.mutable_data_ptr<scalar_t>(), \
dim_size, \
@@ -1446,7 +1482,8 @@ void spatial_softmax_forward(
scalar_t, \
accscalar_t, \
IndexType, \
LogSoftMax>( \
LogSoftMax, \
is_safe_softmax>( \
input.const_data_ptr<scalar_t>(), \
output.mutable_data_ptr<scalar_t>(), \
dim_size, \
@@ -1460,7 +1497,8 @@ void spatial_softmax_forward(
scalar_t, \
accscalar_t, \
IndexType, \
LogSoftMax>( \
LogSoftMax, \
is_safe_softmax>( \
input.const_data_ptr<scalar_t>(), \
output.mutable_data_ptr<scalar_t>(), \
dim_size, \
@@ -1749,7 +1787,8 @@ Tensor& masked_softmax_forward(
LogSoftMax, \
outer_loop, \
true, \
decltype(input_calc)>( \
decltype(input_calc), \
/*is_safe_softmax = */ false>( \
input.const_data_ptr<scalar_t>(), \
output.mutable_data_ptr<scalar_t>(), \
dim_size, \
@@ -1922,7 +1961,7 @@ void masked_softmax_backward(
#undef SIMD32
} // namespace impl

template <bool LogSoftMax>
template <bool LogSoftMax, bool is_safe_softmax = false>
void host_softmax(
const Tensor& input_,
const int64_t dim_,
@@ -1953,8 +1992,11 @@ void host_softmax(
"host_softmax",
[&] {
using accscalar_t = acc_type_device<scalar_t, kXPU>;
impl::spatial_softmax_forward<scalar_t, accscalar_t, LogSoftMax>(
output, input, dim);
impl::spatial_softmax_forward<
scalar_t,
accscalar_t,
LogSoftMax,
is_safe_softmax>(output, input, dim);
});
}
// return output;
@@ -2045,6 +2087,29 @@ void _log_softmax_backward_kernel(
grad.contiguous(), output.contiguous(), dim, half_to_float, grad_input);
}

Tensor _safe_softmax_kernel(
const Tensor& self,
int64_t dim,
const bool half_to_float) {
auto output_options =
self.options().memory_format(LEGACY_CONTIGUOUS_MEMORY_FORMAT);
if (half_to_float) {
output_options = output_options.dtype(ScalarType::Float);
}
Tensor output = at::empty_like(self, output_options);
AT_DISPATCH_FLOATING_TYPES_AND2(
ScalarType::Half,
ScalarType::BFloat16,
self.scalar_type(),
"_safe_softmax",
[&] {
host_softmax<false, true>(
self.contiguous(), dim, half_to_float, output);
});

return output;
}

Tensor masked_softmax_kernel(
const Tensor& input_,
const Tensor& mask_,
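Note on the forward-kernel changes in this file: they all share one idea. The new is_safe_softmax branches fire only when the row maximum is still at the reduction identity (std::numeric_limits<accscalar_t>::lowest() in the group kernels, -infinity in the spatial kernel), which in practice only happens when every element along the softmax dimension is -inf; those rows get 0 written out instead of the NaN that exp(-inf - (-inf)) would otherwise produce. A quick illustration of the behavioral difference, assuming a PyTorch build that already exposes the op:

import torch

x = torch.full((2, 4), float("-inf"))
print(torch.softmax(x, dim=-1))             # every entry is NaN (0/0 per element)
print(torch.ops.aten._safe_softmax(x, -1))  # fully masked rows become zeros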
3 changes: 3 additions & 0 deletions src/ATen/native/xpu/sycl/SoftMaxKernels.h
@@ -32,6 +32,9 @@ TORCH_XPU_API void _log_softmax_backward_kernel(
bool half_to_float,
const Tensor& grad_input);

TORCH_XPU_API Tensor
_safe_softmax_kernel(const Tensor& self, int64_t dim, const bool half_to_float);

TORCH_XPU_API Tensor masked_softmax_kernel(
const Tensor& input_,
const Tensor& mask_,
44 changes: 44 additions & 0 deletions test/regressions/test_safe_softmax.py
@@ -0,0 +1,44 @@
import torch
from torch.testing._internal.common_utils import TestCase

cpu_device = torch.device("cpu")
xpu_device = torch.device("xpu")


class TestSafeSoftMax(TestCase):
    def test_sm(self):
        for dtype in [torch.float, torch.float16, torch.bfloat16]:
            x_cpu = torch.randn(128, 128, 128).to(dtype)
            x_xpu = x_cpu.to(xpu_device)
            r_cpu = torch.ops.aten._safe_softmax(x_cpu, -1)
            r_xpu = torch.ops.aten._safe_softmax(x_xpu, -1)
            self.assertEqual(r_xpu.to(cpu_device), r_cpu)
            x_cpu[0, 0, :] = -float("inf")
            x_xpu = x_cpu.to(xpu_device)
            r_cpu = torch.ops.aten._safe_softmax(x_cpu, -1)
            r_xpu = torch.ops.aten._safe_softmax(x_xpu, -1)
            self.assertEqual(r_xpu.to(cpu_device), r_cpu)

            x_cpu = torch.randn(128, 128, 128).to(dtype)
            x_xpu = x_cpu.to(xpu_device)
            r_cpu = torch.ops.aten._safe_softmax(x_cpu, 1)
            r_xpu = torch.ops.aten._safe_softmax(x_xpu, 1)
            self.assertEqual(r_xpu.to(cpu_device), r_cpu)
            x_cpu[0, :, 0] = -float("inf")
            x_xpu = x_cpu.to(xpu_device)
            r_cpu = torch.ops.aten._safe_softmax(x_cpu, 1)
            r_xpu = torch.ops.aten._safe_softmax(x_xpu, 1)
            self.assertEqual(r_xpu.to(cpu_device), r_cpu)

            x_cpu = torch.randn(128, 128, 128).to(dtype)
            x_xpu = x_cpu.to(xpu_device)
            r_cpu = torch.ops.aten._safe_softmax(x_cpu, 0)
            r_xpu = torch.ops.aten._safe_softmax(x_xpu, 0)
            self.assertEqual(r_xpu.to(cpu_device), r_cpu)
            x_cpu[:, 0, 0] = -float("inf")
            x_xpu = x_cpu.to(xpu_device)
            r_cpu = torch.ops.aten._safe_softmax(x_cpu, 0)
            r_xpu = torch.ops.aten._safe_softmax(x_xpu, 0)
            self.assertEqual(r_xpu.to(cpu_device), r_cpu)


4 changes: 4 additions & 0 deletions yaml/native/native_functions.yaml
@@ -2017,6 +2017,10 @@
dispatch:
XPU: softmax_xpu_out

- func: _safe_softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor
dispatch:
XPU: _safe_softmax_xpu

- func: _softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor
structured_delegate: _softmax_backward_data.out

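With the dispatch entry above in place, aten::_safe_softmax called on XPU tensors routes to _safe_softmax_xpu. A short usage sketch, assuming an XPU-enabled PyTorch build:

import torch

x = torch.randn(4, 8, device="xpu")
x[0, :] = -float("inf")                    # fully masked row
y = torch.ops.aten._safe_softmax(x, -1)
print(y[0].cpu())                          # zeros, not NaN
print(y[1].cpu().sum())                    # ordinary rows still sum to 1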