Fix low performance in GroupNorm kernels (#1116)
Primarily adopt a vectorized approach in the `GroupNormForward` kernel.
xytintel authored Nov 24, 2024
1 parent 9b224ee commit f7ca0ae
Showing 2 changed files with 159 additions and 17 deletions.
9 changes: 5 additions & 4 deletions src/ATen/native/xpu/sycl/ActivationSoftshrinkKernels.cpp
@@ -1,16 +1,17 @@
#include <ATen/Dispatch.h>
#include <ATen/NumericUtils.h>
#include <ATen/native/TensorIterator.h>

#include <ATen/native/xpu/sycl/Loops.h>

#include <ATen/native/xpu/sycl/ActivationSoftshrinkKernels.h>
#include <ATen/native/xpu/sycl/Loops.h>

namespace at::native::xpu {

template <typename scalar_t>
struct SoftshrinkFunctor {
scalar_t operator()(scalar_t a) const {
return a > lambd_ ? a - lambd_ : (a < -lambd_ ? a + lambd_ : scalar_t(0));
return at::_isnan(a)
? a
: (a > lambd_ ? a - lambd_ : (a < -lambd_ ? a + lambd_ : scalar_t(0)));
}

SoftshrinkFunctor(scalar_t lambd) : lambd_(lambd) {}
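The two adjacent return statements above are the removed and the added version of the same line. The motivation: every ordered comparison against NaN evaluates to false, so the old ternary chain fell through to scalar_t(0) and silently zeroed NaN inputs, while the new at::_isnan guard (the reason for the added ATen/NumericUtils.h include) propagates them. A minimal standalone illustration of the difference, using std::isnan in place of at::_isnan:

#include <cmath>
#include <cstdio>

int main() {
  const float lambd = 0.5f;
  const float a = std::nanf("");
  // Old behavior: a > lambd and a < -lambd are both false for NaN,
  // so the expression falls through to the zero branch.
  const float old_r = a > lambd ? a - lambd : (a < -lambd ? a + lambd : 0.0f);
  // New behavior: NaN is detected first and passed through unchanged.
  const float new_r = std::isnan(a)
      ? a
      : (a > lambd ? a - lambd : (a < -lambd ? a + lambd : 0.0f));
  std::printf("old: %f, new: %f\n", old_r, new_r); // old: 0.000000, new: nan
  return 0;
}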
167 changes: 154 additions & 13 deletions src/ATen/native/xpu/sycl/GroupNormKernels.cpp
@@ -88,6 +88,106 @@ struct GNRowwiseMomentsFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
sycl_local_acc_t<WelfordType> shared_;
};

template <typename T, int SIMD, int VEC_SIZE>
struct GNRowwiseMomentsVectorizedFunctor
: public __SYCL_KER_CONFIG_CONVENTION__ {
using T_ACC = acc_type_device<T, kXPU>;
using WelfordType = WelfordData<T_ACC, int64_t>;
using WelfordOp =
WelfordOpsXPU<T_ACC, T_ACC, int64_t, std::pair<T_ACC, T_ACC>>;
using vec_t = memory::aligned_vector<T, VEC_SIZE>;

[[intel::reqd_sub_group_size(SIMD)]] void operator()(
sycl::nd_item<1> item) const {
WelfordType val[VEC_SIZE];
WelfordOp welford_op = {/*correction=*/0, /*take_sqrt=*/false, item};
auto g_start = item.get_group(0) * VEC_SIZE;

#pragma unroll
for (int v = 0; v < VEC_SIZE; ++v) {
const int64_t i = g_start + v;
if (i < G_) {
for (int64_t j = item.get_local_id(0) * VEC_SIZE; j < N_;
j += item.get_local_range(0) * VEC_SIZE) {
const int64_t vec_index = i * N_ + j;
auto remaining = N_ - j;
if (remaining < VEC_SIZE) {
for (int iv = 0; iv < remaining; ++iv) {
val[v] = welford_op.reduce(
val[v],
static_cast<T_ACC>(X_[vec_index + iv]),
vec_index + iv);
}
} else {
vec_t vec_in =
*reinterpret_cast<vec_t*>(const_cast<T*>(X_) + vec_index);
#pragma unroll
for (int iv = 0; iv < VEC_SIZE; ++iv) {
val[v] = welford_op.reduce(
val[v], static_cast<T_ACC>(vec_in[iv]), vec_index + iv);
}
}
}
}
}

#pragma unroll
for (int v = 0; v < VEC_SIZE; ++v) {
val[v] = GroupReduceWithoutBroadcast<WelfordType, WelfordOp, SIMD>(
item, val[v], welford_op, shared_);
}

if (item.get_local_id(0) == 0) {
auto remaining = G_ - g_start;
if (remaining < VEC_SIZE) {
for (int v = 0; v < remaining; ++v) {
T_ACC m1;
T_ACC m2;
std::tie(m2, m1) = welford_op.project(val[v]);
mean_[g_start + v] = m1;
rstd_[g_start + v] =
c10::xpu::compat::rsqrt(m2 + static_cast<T_ACC>(eps_));
}
} else {
vec_t mean_vec;
vec_t rstd_vec;
#pragma unroll
for (int v = 0; v < VEC_SIZE; ++v) {
T_ACC m1;
T_ACC m2;
std::tie(m2, m1) = welford_op.project(val[v]);
mean_vec[v] = m1;
rstd_vec[v] = c10::xpu::compat::rsqrt(m2 + static_cast<T_ACC>(eps_));
}
*(reinterpret_cast<vec_t*>(mean_ + g_start)) = mean_vec;
*(reinterpret_cast<vec_t*>(rstd_ + g_start)) = rstd_vec;
}
}
}

void sycl_ker_config_convention(sycl::handler& cgh) {
shared_ = sycl_local_acc_t<WelfordType>(SIMD, cgh);
}

GNRowwiseMomentsVectorizedFunctor(
int64_t N,
int64_t G,
T eps,
const T* X,
T* mean,
T* rstd)
: N_(N), G_(G), eps_(eps), X_(X), mean_(mean), rstd_(rstd) {}

private:
int64_t N_;
int64_t G_;
T eps_;
const T* X_;
T* mean_;
T* rstd_;
sycl_local_acc_t<WelfordType> shared_;
};
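GNRowwiseMomentsVectorizedFunctor computes the same Welford-based moments as the scalar GNRowwiseMomentsFunctor, but it loads VEC_SIZE contiguous elements per memory transaction through memory::aligned_vector, and each work-group produces the mean/rstd for VEC_SIZE rows so that the results can also be stored with a single vector write. WelfordData and WelfordOpsXPU are defined elsewhere in the tree; the following single-threaded sketch (an illustration under those assumptions, not the kernel itself) shows the arithmetic being vectorized:

#include <cmath>
#include <cstdint>

// Running mean/M2 accumulator; with correction = 0, as in the kernel,
// the variance is m2 / n.
struct Welford {
  double mean = 0.0, m2 = 0.0;
  std::int64_t n = 0;
  void reduce(double x) {
    ++n;
    const double delta = x - mean;
    mean += delta / n;
    m2 += delta * (x - mean);
  }
};

// Moments of one row with VEC_SIZE-wide loads and a scalar tail. `row`
// is assumed to be suitably aligned, mirroring the can_use_vectorization
// check performed before the vectorized kernel is launched.
template <typename T, int VEC_SIZE>
void row_moments(const T* row, std::int64_t N, double eps,
                 double& mean, double& rstd) {
  struct alignas(sizeof(T) * VEC_SIZE) vec_t { T val[VEC_SIZE]; };
  Welford w;
  std::int64_t j = 0;
  for (; j + VEC_SIZE <= N; j += VEC_SIZE) {
    const vec_t v = *reinterpret_cast<const vec_t*>(row + j); // one wide load
    for (int k = 0; k < VEC_SIZE; ++k) w.reduce(static_cast<double>(v.val[k]));
  }
  for (; j < N; ++j) w.reduce(static_cast<double>(row[j])); // scalar remainder
  mean = w.mean;
  rstd = 1.0 / std::sqrt(w.m2 / w.n + eps); // rsqrt(variance + eps)
}

In the kernel proper the row is shared by a whole work-group: each work-item accumulates a strided slice, GroupReduceWithoutBroadcast merges the per-item accumulators, and only local id 0 projects and stores the final mean and rstd.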

template <typename T, int SIMD>
struct GNRowwiseMomentsNHWCFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
using T_ACC = acc_type_device<T, kXPU>;
@@ -306,6 +406,11 @@ struct GroupNormFunctor {
}
};

template <typename T>
bool can_use_vectorization(T* p, int vec_size) {
return memory::can_vectorize_up_to<T>((char*)p) >= vec_size;
}
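can_use_vectorization defers to memory::can_vectorize_up_to, which reports how wide a vector access the pointer supports; in essence this is an address-alignment test. A rough equivalent of the idea (an illustration, not the library's actual implementation):

#include <cstddef>
#include <cstdint>

// A pointer can serve VEC_SIZE-wide vector loads/stores only if its
// address is a multiple of the vector width in bytes.
template <typename T>
bool is_aligned_for_vec(const T* p, int vec_size) {
  const auto addr = reinterpret_cast<std::uintptr_t>(p);
  return addr % (sizeof(T) * static_cast<std::size_t>(vec_size)) == 0;
}

X_data, mean_data, and rstd_data are all checked below because the vectorized functor both reads the input and writes the per-group results with the same vector width.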

template <typename T>
void group_norm_kernel_impl(
const Tensor& X,
@@ -338,7 +443,7 @@ void group_norm_kernel_impl(

auto& queue = getCurrentSYCLQueue();
int64_t simd = syclMaxSubGroupSize();
const int64_t wg_size = D * HxW < get_group_reduce_group_size(simd)
int64_t wg_size = D * HxW < get_group_reduce_group_size(simd)
? simd
: get_group_reduce_group_size(simd);
int64_t nwg = N * G;
@@ -350,18 +455,54 @@

switch (x_format) {
case MemoryFormat::Contiguous: {
group_norm_kernel_simd_choice_and_launch<
GNRowwiseMomentsFunctor<T, SIMD16>,
GNRowwiseMomentsFunctor<T, SIMD32>>(
simd,
global_range,
local_range,
queue,
D * HxW,
eps,
X_data,
mean_data,
rstd_data);
constexpr int VEC_SIZE =
2; // To reduce the register pressure caused by WelfordData, we apply
// 2-way vectorization only to half-precision data.
if (sizeof(T) < sizeof(float) &&
can_use_vectorization(X_data, VEC_SIZE) &&
can_use_vectorization(mean_data, VEC_SIZE) &&
can_use_vectorization(rstd_data, VEC_SIZE)) {
using FUNC_T_SIMD16 =
GNRowwiseMomentsVectorizedFunctor<T, SIMD16, VEC_SIZE>;
using FUNC_T_SIMD32 =
GNRowwiseMomentsVectorizedFunctor<T, SIMD32, VEC_SIZE>;
if (simd == SIMD16) {
wg_size = syclMaxWorkGroupSize<FUNC_T_SIMD16>();
} else if (simd == SIMD32) {
wg_size = syclMaxWorkGroupSize<FUNC_T_SIMD32>();
} else {
TORCH_INTERNAL_ASSERT(
false,
"The GroupNorm kernel currently only supports SIMD16 or SIMD32.");
}
auto global_range_ =
sycl::range<1>((N * G + VEC_SIZE - 1) / VEC_SIZE * wg_size);
auto local_range_ = sycl::range<1>(wg_size);
group_norm_kernel_simd_choice_and_launch<FUNC_T_SIMD16, FUNC_T_SIMD32>(
simd,
global_range_,
local_range_,
queue,
D * HxW,
N * G,
eps,
X_data,
mean_data,
rstd_data);
} else {
group_norm_kernel_simd_choice_and_launch<
GNRowwiseMomentsFunctor<T, SIMD16>,
GNRowwiseMomentsFunctor<T, SIMD32>>(
simd,
global_range,
local_range,
queue,
D * HxW,
eps,
X_data,
mean_data,
rstd_data);
}
break;
}
case MemoryFormat::ChannelsLast: {
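In the vectorized Contiguous path above, each work-group computes the mean/rstd for VEC_SIZE rows, so the number of work-groups drops from N * G to ceil(N * G / VEC_SIZE) while wg_size grows to the device maximum for the functor; the scalar fallback keeps the original geometry of N * G work-groups. A worked example of the launch arithmetic with hypothetical sizes (the values below are assumptions, not taken from the commit):

#include <cstdint>
#include <cstdio>

int main() {
  constexpr std::int64_t VEC_SIZE = 2;
  const std::int64_t N = 32, G = 16; // hypothetical batch and group counts
  const std::int64_t wg_size = 512;  // hypothetical syclMaxWorkGroupSize<...>()
  const std::int64_t rows = N * G;                                  // 512 rows of (mean, rstd)
  const std::int64_t num_groups = (rows + VEC_SIZE - 1) / VEC_SIZE; // 256 work-groups
  const std::int64_t global_range = num_groups * wg_size;           // 131072 work-items
  std::printf("%lld groups, global range %lld\n",
              (long long)num_groups, (long long)global_range);
  return 0;
}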
