
Commit

update skip layer norm
amarin16 committed Nov 4, 2024
1 parent c64459f commit af60562
Showing 2 changed files with 75 additions and 49 deletions.
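
In short, as read from the diff below: instead of each ComputeJob lazily converting the fp16 gamma/beta/bias/skip tensors into shared, mutable member buffers, Compute now converts gamma, beta, and bias to fp32 once up front (unless PrePack already cached them), allocates the skip buffer, and passes plain float pointers into the per-task jobs; only the skip rows are still converted inside ComputeJob, because that conversion needs the per-task offset. Below is a minimal, standalone sketch of the hoisted-conversion pattern; it is not the ONNX Runtime code. half_t, HalfToFloat, ConvertBuffer, RowJob, and RunSkipLayerNorm are illustrative names standing in for MLFloat16, MlasConvertHalfToFloatBuffer, ComputeJob, and the TryBatchParallelFor loop, and the sketch omits the optional bias, skip broadcasting, and the per-row input conversion that the real kernel performs.

#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

using half_t = std::uint16_t;  // stand-in for MLFloat16 in this sketch

// Naive IEEE 754 half -> float conversion (normals only; denormals flush to
// signed zero, inf/NaN unhandled). Stand-in for MlasConvertHalfToFloatBuffer.
inline float HalfToFloat(half_t h) {
  const std::uint32_t sign = static_cast<std::uint32_t>(h & 0x8000u) << 16;
  const std::uint32_t exp = (h >> 10) & 0x1Fu;
  const std::uint32_t man = h & 0x3FFu;
  const std::uint32_t bits =
      (exp == 0) ? sign : (sign | ((exp + 112u) << 23) | (man << 13));
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

inline void ConvertBuffer(const half_t* src, float* dst, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) dst[i] = HalfToFloat(src[i]);
}

// Per-row job: reads only pre-converted float buffers; it neither allocates
// nor writes any state shared with other rows.
void RowJob(const float* x, const float* skip, const float* gamma,
            const float* beta, std::size_t hidden, float epsilon, float* out) {
  float mean = 0.f, sq = 0.f;
  for (std::size_t h = 0; h < hidden; ++h) {
    const float v = x[h] + skip[h];
    out[h] = v;
    mean += v;
    sq += v * v;
  }
  mean /= static_cast<float>(hidden);
  const float inv_std =
      1.f / std::sqrt(sq / static_cast<float>(hidden) - mean * mean + epsilon);
  for (std::size_t h = 0; h < hidden; ++h)
    out[h] = (out[h] - mean) * inv_std * gamma[h] + (beta ? beta[h] : 0.f);
}

// Hoisted conversion: gamma/beta are converted once per call (or could be
// cached once, as PrePack does), then the rows run with read-only pointers.
// The loop is serial here; the real kernel uses TryBatchParallelFor.
void RunSkipLayerNorm(const float* input, const float* skip,
                      const half_t* gamma_fp16, const half_t* beta_fp16,
                      std::size_t rows, std::size_t hidden, float epsilon,
                      float* output) {
  std::vector<float> gamma(hidden), beta(hidden);
  ConvertBuffer(gamma_fp16, gamma.data(), hidden);
  ConvertBuffer(beta_fp16, beta.data(), hidden);
  for (std::size_t r = 0; r < rows; ++r)
    RowJob(input + r * hidden, skip + r * hidden, gamma.data(), beta.data(),
           hidden, epsilon, output + r * hidden);
}

The constant tensors depend only on hidden_size, so converting them once per Compute call (or once via PrePack) avoids redoing that work in every task; it also lines up with the header change at the bottom of the diff, which drops the mutable qualifier from the cached fp32 members now that the parallel jobs no longer write them.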
116 changes: 71 additions & 45 deletions onnxruntime/contrib_ops/cpu/skip_layer_norm.cc
@@ -42,14 +42,15 @@ namespace {
template <typename T, typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double>, void>>
void ComputeJob(
const T* input_data,
const T* skip_data,
const T* gamma_data,
const T* beta_data,
const T* bias_data,
IAllocatorUniquePtr<float>& skip_float_uptr,
IAllocatorUniquePtr<float>& gamma_float_uptr,
IAllocatorUniquePtr<float>& beta_float_uptr,
IAllocatorUniquePtr<float>& bias_float_uptr,
const T* skip_data,
const float* gamma_float_ptr,
const float* beta_float_ptr,
const float* bias_float_ptr,
float* skip_float_ptr,
bool should_convert_skip,
ptrdiff_t task_idx,
int hidden_size,
int64_t skip_size,
@@ -58,10 +59,11 @@ void ComputeJob(
T* output_data,
T* skip_input_bias_add_output_data,
AllocatorPtr alloc) {
ORT_UNUSED_PARAMETER(skip_float_uptr); // only used in MLFloat16 overload
ORT_UNUSED_PARAMETER(gamma_float_uptr); // only used in MLFloat16 overload
ORT_UNUSED_PARAMETER(beta_float_uptr); // only used in MLFloat16 overload
ORT_UNUSED_PARAMETER(bias_float_uptr); // only used in MLFloat16 overload
ORT_UNUSED_PARAMETER(gamma_float_ptr); // only used in MLFloat16 overload
ORT_UNUSED_PARAMETER(beta_float_ptr); // only used in MLFloat16 overload
ORT_UNUSED_PARAMETER(bias_float_ptr); // only used in MLFloat16 overload
ORT_UNUSED_PARAMETER(skip_float_ptr); // only used in MLFloat16 overload
ORT_UNUSED_PARAMETER(should_convert_skip); // only used in MLFloat16 overload
ORT_UNUSED_PARAMETER(alloc);

auto offset = task_idx * hidden_size;
@@ -109,14 +111,15 @@ void ComputeJob(

void ComputeJob(
const MLFloat16* input_data,
const MLFloat16* skip_data,
const MLFloat16* gamma_data,
const MLFloat16* beta_data,
const MLFloat16* bias_data,
IAllocatorUniquePtr<float>& skip_float_uptr,
IAllocatorUniquePtr<float>& gamma_float_uptr,
IAllocatorUniquePtr<float>& beta_float_uptr,
IAllocatorUniquePtr<float>& bias_float_uptr,
const MLFloat16* skip_data,
const float* gamma_float_ptr,
const float* beta_float_ptr,
const float* bias_float_ptr,
float* skip_float_ptr,
bool should_convert_skip,
ptrdiff_t task_idx,
int hidden_size,
int64_t skip_size,
@@ -125,6 +128,11 @@ void ComputeJob(
MLFloat16* output_data,
MLFloat16* skip_input_bias_add_output_data,
AllocatorPtr alloc) {
ORT_UNUSED_PARAMETER(skip_data); // only used in double/float overload
ORT_UNUSED_PARAMETER(gamma_data); // only used in double/float overload
ORT_UNUSED_PARAMETER(beta_data); // only used in double/float overload
ORT_UNUSED_PARAMETER(bias_data); // only used in double/float overload

auto offset = task_idx * hidden_size;
const MLFloat16* p_input = input_data + offset;
const MLFloat16* p_skip = skip_data + (offset % skip_size);
@@ -138,26 +146,18 @@ void ComputeJob(
IAllocatorUniquePtr<float> input_float_uptr = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
MlasConvertHalfToFloatBuffer(p_input, input_float_uptr.get(), num_elems);

if (!skip_float_uptr) {
skip_float_uptr = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
MlasConvertHalfToFloatBuffer(p_skip, skip_float_uptr.get(), num_elems);
}

if (bias_data && !bias_float_uptr) {
bias_float_uptr = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
MlasConvertHalfToFloatBuffer(bias_data, bias_float_uptr.get(), num_elems);
if (should_convert_skip) {
MlasConvertHalfToFloatBuffer(p_skip, skip_float_ptr, num_elems);
}

IAllocatorUniquePtr<float> output_float_uptr = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
float* output_float_ptr = output_float_uptr.get();

const float* input_float_ptr = input_float_uptr.get();
const float* skip_float_ptr = skip_float_uptr.get();
const float* bias_float_ptr = bias_float_uptr.get();
for (size_t h = 0; h < num_elems; h++) {
float val = input_float_ptr[h] + skip_float_ptr[h];

if (bias_float_uptr) {
if (bias_float_ptr) {
val += bias_float_ptr[h];
}

@@ -177,22 +177,10 @@ void ComputeJob(
mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon);
}

if (!gamma_float_uptr) {
gamma_float_uptr = std::move(input_float_uptr); // overwrite input with gamma values, since they have the same size
MlasConvertHalfToFloatBuffer(gamma_data, gamma_float_uptr.get(), num_elems);
}

if (beta_data && !beta_float_uptr) {
beta_float_uptr = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
MlasConvertHalfToFloatBuffer(beta_data, beta_float_uptr.get(), num_elems);
}

const float* gamma_float_ptr = gamma_float_uptr.get();
const float* beta_float_ptr = beta_float_uptr.get();
for (size_t h = 0; h < num_elems; h++) {
if (simplified) {
output_float_ptr[h] = output_float_ptr[h] / mean_square * gamma_float_ptr[h];
} else if (nullptr == beta_float_uptr) {
} else if (nullptr == beta_float_ptr) {
output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * gamma_float_ptr[h];
} else {
output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * gamma_float_ptr[h] + beta_float_ptr[h];
@@ -218,7 +206,11 @@ void ConvertMLFloat16ToFloatIfNeeded(const Tensor& tensor, AllocatorPtr alloc, I

template <typename T, bool simplified>
SkipLayerNorm<T, simplified>::SkipLayerNorm(const OpKernelInfo& op_kernel_info)
: OpKernel(op_kernel_info), skip_fp32_(nullptr), gamma_fp32_(nullptr), beta_fp32_(nullptr), bias_fp32_(nullptr) {
: OpKernel(op_kernel_info),
prepacked_gamma_fp32_data_(nullptr),
prepacked_beta_fp32_data_(nullptr),
prepacked_bias_fp32_data_(nullptr),
prepacked_skip_fp32_data_(nullptr) {
ORT_ENFORCE(op_kernel_info.GetAttr<float>("epsilon", &epsilon_).IsOK());
ORT_ENFORCE(epsilon_ >= 0);
}
@@ -264,11 +256,45 @@ Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {
AllocatorPtr alloc;
ORT_RETURN_IF_ERROR(p_ctx->GetTempSpaceAllocator(&alloc));

IAllocatorUniquePtr<float> gamma_fp32;
IAllocatorUniquePtr<float> beta_fp32;
IAllocatorUniquePtr<float> bias_fp32;
IAllocatorUniquePtr<float> skip_fp32;
bool should_convert_skip = false;
if constexpr (std::is_same_v<T, MLFloat16>) {
const size_t num_elems = static_cast<size_t>(hidden_size);

if (prepacked_gamma_fp32_data_ == nullptr) {
gamma_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
MlasConvertHalfToFloatBuffer(gamma_data, gamma_fp32.get(), num_elems);
}

if (prepacked_beta_fp32_data_ == nullptr && beta_data) {
beta_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
MlasConvertHalfToFloatBuffer(beta_data, beta_fp32.get(), num_elems);
}

if (prepacked_bias_fp32_data_ == nullptr && bias_data) {
bias_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
MlasConvertHalfToFloatBuffer(bias_data, bias_fp32.get(), num_elems);
}

if (prepacked_skip_fp32_data_ == nullptr) {
skip_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
should_convert_skip = true;
// skip data will be converted inside ComputeJob, because it needs to use the offset.
}
}

concurrency::ThreadPool::TryBatchParallelFor(
p_ctx->GetOperatorThreadPool(), static_cast<int32_t>(task_count),
[&](ptrdiff_t task_idx) {
ComputeJob(input_data, skip_data, gamma_data, beta_data, bias_data, skip_fp32_, gamma_fp32_, beta_fp32_,
bias_fp32_, task_idx, hidden_size, skip_size, epsilon_, simplified, output_data,
ComputeJob(input_data, gamma_data, beta_data, bias_data, skip_data,
prepacked_gamma_fp32_data_ ? prepacked_gamma_fp32_data_.get() : gamma_fp32.get(),
prepacked_beta_fp32_data_ ? prepacked_beta_fp32_data_.get() : beta_fp32.get(),
prepacked_bias_fp32_data_ ? prepacked_bias_fp32_data_.get() : bias_fp32.get(),
prepacked_skip_fp32_data_ ? prepacked_skip_fp32_data_.get() : skip_fp32.get(),
should_convert_skip, task_idx, hidden_size, skip_size, epsilon_, simplified, output_data,
skip_input_bias_add_output_data, alloc);
},
0);
@@ -284,13 +310,13 @@ Status SkipLayerNorm<T, simplified>::PrePack(const Tensor& tensor, int input_idx

is_packed = false;
if (input_idx == 1) { // skip
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, skip_fp32_, is_packed);
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_skip_fp32_data_, is_packed);
} else if (input_idx == 2) { // gamma
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, gamma_fp32_, is_packed);
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_gamma_fp32_data_, is_packed);
} else if (input_idx == 3) { // beta
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, beta_fp32_, is_packed);
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_beta_fp32_data_, is_packed);
} else if (input_idx == 4) { // bias
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, bias_fp32_, is_packed);
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_bias_fp32_data_, is_packed);
}

return Status::OK();
8 changes: 4 additions & 4 deletions onnxruntime/contrib_ops/cpu/skip_layer_norm.h
@@ -21,10 +21,10 @@ class SkipLayerNorm final : public OpKernel {

private:
float epsilon_;
mutable IAllocatorUniquePtr<float> skip_fp32_;
mutable IAllocatorUniquePtr<float> gamma_fp32_;
mutable IAllocatorUniquePtr<float> beta_fp32_;
mutable IAllocatorUniquePtr<float> bias_fp32_;
IAllocatorUniquePtr<float> prepacked_gamma_fp32_data_;
IAllocatorUniquePtr<float> prepacked_beta_fp32_data_;
IAllocatorUniquePtr<float> prepacked_bias_fp32_data_;
IAllocatorUniquePtr<float> prepacked_skip_fp32_data_;
};

} // namespace contrib
