From af605625a96c2b201e06bc67bc6c992403165a81 Mon Sep 17 00:00:00 2001 From: Alex Marin Date: Mon, 4 Nov 2024 13:21:10 -0800 Subject: [PATCH] update skip layer norm --- .../contrib_ops/cpu/skip_layer_norm.cc | 116 +++++++++++------- onnxruntime/contrib_ops/cpu/skip_layer_norm.h | 8 +- 2 files changed, 75 insertions(+), 49 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc index c9ee9e2cb760d..116759c3f31f0 100644 --- a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc @@ -42,14 +42,15 @@ namespace { template || std::is_same_v, void>> void ComputeJob( const T* input_data, - const T* skip_data, const T* gamma_data, const T* beta_data, const T* bias_data, - IAllocatorUniquePtr& skip_float_uptr, - IAllocatorUniquePtr& gamma_float_uptr, - IAllocatorUniquePtr& beta_float_uptr, - IAllocatorUniquePtr& bias_float_uptr, + const T* skip_data, + const float* gamma_float_ptr, + const float* beta_float_ptr, + const float* bias_float_ptr, + float* skip_float_ptr, + bool should_convert_skip, ptrdiff_t task_idx, int hidden_size, int64_t skip_size, @@ -58,10 +59,11 @@ void ComputeJob( T* output_data, T* skip_input_bias_add_output_data, AllocatorPtr alloc) { - ORT_UNUSED_PARAMETER(skip_float_uptr); // only used in MLFloat16 overload - ORT_UNUSED_PARAMETER(gamma_float_uptr); // only used in MLFloat16 overload - ORT_UNUSED_PARAMETER(beta_float_uptr); // only used in MLFloat16 overload - ORT_UNUSED_PARAMETER(bias_float_uptr); // only used in MLFloat16 overload + ORT_UNUSED_PARAMETER(gamma_float_ptr); // only used in MLFloat16 overload + ORT_UNUSED_PARAMETER(beta_float_ptr); // only used in MLFloat16 overload + ORT_UNUSED_PARAMETER(bias_float_ptr); // only used in MLFloat16 overload + ORT_UNUSED_PARAMETER(skip_float_ptr); // only used in MLFloat16 overload + ORT_UNUSED_PARAMETER(should_convert_skip); // only used in MLFloat16 overload ORT_UNUSED_PARAMETER(alloc); auto offset = task_idx * hidden_size; @@ -109,14 +111,15 @@ void ComputeJob( void ComputeJob( const MLFloat16* input_data, - const MLFloat16* skip_data, const MLFloat16* gamma_data, const MLFloat16* beta_data, const MLFloat16* bias_data, - IAllocatorUniquePtr& skip_float_uptr, - IAllocatorUniquePtr& gamma_float_uptr, - IAllocatorUniquePtr& beta_float_uptr, - IAllocatorUniquePtr& bias_float_uptr, + const MLFloat16* skip_data, + const float* gamma_float_ptr, + const float* beta_float_ptr, + const float* bias_float_ptr, + float* skip_float_ptr, + bool should_convert_skip, ptrdiff_t task_idx, int hidden_size, int64_t skip_size, @@ -125,6 +128,11 @@ void ComputeJob( MLFloat16* output_data, MLFloat16* skip_input_bias_add_output_data, AllocatorPtr alloc) { + ORT_UNUSED_PARAMETER(skip_data); // only used in double/float overload + ORT_UNUSED_PARAMETER(gamma_data); // only used in double/float overload + ORT_UNUSED_PARAMETER(beta_data); // only used in double/float overload + ORT_UNUSED_PARAMETER(bias_data); // only used in double/float overload + auto offset = task_idx * hidden_size; const MLFloat16* p_input = input_data + offset; const MLFloat16* p_skip = skip_data + (offset % skip_size); @@ -138,26 +146,18 @@ void ComputeJob( IAllocatorUniquePtr input_float_uptr = IAllocator::MakeUniquePtr(alloc, num_elems); MlasConvertHalfToFloatBuffer(p_input, input_float_uptr.get(), num_elems); - if (!skip_float_uptr) { - skip_float_uptr = IAllocator::MakeUniquePtr(alloc, num_elems); - MlasConvertHalfToFloatBuffer(p_skip, skip_float_uptr.get(), num_elems); - } - - if (bias_data && !bias_float_uptr) { - bias_float_uptr = IAllocator::MakeUniquePtr(alloc, num_elems); - MlasConvertHalfToFloatBuffer(bias_data, bias_float_uptr.get(), num_elems); + if (should_convert_skip) { + MlasConvertHalfToFloatBuffer(p_skip, skip_float_ptr, num_elems); } IAllocatorUniquePtr output_float_uptr = IAllocator::MakeUniquePtr(alloc, num_elems); float* output_float_ptr = output_float_uptr.get(); const float* input_float_ptr = input_float_uptr.get(); - const float* skip_float_ptr = skip_float_uptr.get(); - const float* bias_float_ptr = bias_float_uptr.get(); for (size_t h = 0; h < num_elems; h++) { float val = input_float_ptr[h] + skip_float_ptr[h]; - if (bias_float_uptr) { + if (bias_float_ptr) { val += bias_float_ptr[h]; } @@ -177,22 +177,10 @@ void ComputeJob( mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon); } - if (!gamma_float_uptr) { - gamma_float_uptr = std::move(input_float_uptr); // overwrite input with gamma values, since they have the same size - MlasConvertHalfToFloatBuffer(gamma_data, gamma_float_uptr.get(), num_elems); - } - - if (beta_data && !beta_float_uptr) { - beta_float_uptr = IAllocator::MakeUniquePtr(alloc, num_elems); - MlasConvertHalfToFloatBuffer(beta_data, beta_float_uptr.get(), num_elems); - } - - const float* gamma_float_ptr = gamma_float_uptr.get(); - const float* beta_float_ptr = beta_float_uptr.get(); for (size_t h = 0; h < num_elems; h++) { if (simplified) { output_float_ptr[h] = output_float_ptr[h] / mean_square * gamma_float_ptr[h]; - } else if (nullptr == beta_float_uptr) { + } else if (nullptr == beta_float_ptr) { output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * gamma_float_ptr[h]; } else { output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * gamma_float_ptr[h] + beta_float_ptr[h]; @@ -218,7 +206,11 @@ void ConvertMLFloat16ToFloatIfNeeded(const Tensor& tensor, AllocatorPtr alloc, I template SkipLayerNorm::SkipLayerNorm(const OpKernelInfo& op_kernel_info) - : OpKernel(op_kernel_info), skip_fp32_(nullptr), gamma_fp32_(nullptr), beta_fp32_(nullptr), bias_fp32_(nullptr) { + : OpKernel(op_kernel_info), + prepacked_gamma_fp32_data_(nullptr), + prepacked_beta_fp32_data_(nullptr), + prepacked_bias_fp32_data_(nullptr), + prepacked_skip_fp32_data_(nullptr) { ORT_ENFORCE(op_kernel_info.GetAttr("epsilon", &epsilon_).IsOK()); ORT_ENFORCE(epsilon_ >= 0); } @@ -264,11 +256,45 @@ Status SkipLayerNorm::Compute(OpKernelContext* p_ctx) const { AllocatorPtr alloc; ORT_RETURN_IF_ERROR(p_ctx->GetTempSpaceAllocator(&alloc)); + IAllocatorUniquePtr gamma_fp32; + IAllocatorUniquePtr beta_fp32; + IAllocatorUniquePtr bias_fp32; + IAllocatorUniquePtr skip_fp32; + bool should_convert_skip = false; + if constexpr (std::is_same_v) { + const size_t num_elems = static_cast(hidden_size); + + if (prepacked_gamma_fp32_data_ == nullptr) { + gamma_fp32 = IAllocator::MakeUniquePtr(alloc, num_elems); + MlasConvertHalfToFloatBuffer(gamma_data, gamma_fp32.get(), num_elems); + } + + if (prepacked_beta_fp32_data_ == nullptr && beta_data) { + beta_fp32 = IAllocator::MakeUniquePtr(alloc, num_elems); + MlasConvertHalfToFloatBuffer(beta_data, beta_fp32.get(), num_elems); + } + + if (prepacked_bias_fp32_data_ == nullptr && bias_data) { + bias_fp32 = IAllocator::MakeUniquePtr(alloc, num_elems); + MlasConvertHalfToFloatBuffer(bias_data, bias_fp32.get(), num_elems); + } + + if (prepacked_skip_fp32_data_ == nullptr) { + skip_fp32 = IAllocator::MakeUniquePtr(alloc, num_elems); + should_convert_skip = true; + // skip data will be converted inside ComputeJob, because it needs to use the offset. + } + } + concurrency::ThreadPool::TryBatchParallelFor( p_ctx->GetOperatorThreadPool(), static_cast(task_count), [&](ptrdiff_t task_idx) { - ComputeJob(input_data, skip_data, gamma_data, beta_data, bias_data, skip_fp32_, gamma_fp32_, beta_fp32_, - bias_fp32_, task_idx, hidden_size, skip_size, epsilon_, simplified, output_data, + ComputeJob(input_data, gamma_data, beta_data, bias_data, skip_data, + prepacked_gamma_fp32_data_ ? prepacked_gamma_fp32_data_.get() : gamma_fp32.get(), + prepacked_beta_fp32_data_ ? prepacked_beta_fp32_data_.get() : beta_fp32.get(), + prepacked_bias_fp32_data_ ? prepacked_bias_fp32_data_.get() : bias_fp32.get(), + prepacked_skip_fp32_data_ ? prepacked_skip_fp32_data_.get() : skip_fp32.get(), + should_convert_skip, task_idx, hidden_size, skip_size, epsilon_, simplified, output_data, skip_input_bias_add_output_data, alloc); }, 0); @@ -284,13 +310,13 @@ Status SkipLayerNorm::PrePack(const Tensor& tensor, int input_idx is_packed = false; if (input_idx == 1) { // skip - ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, skip_fp32_, is_packed); + ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_skip_fp32_data_, is_packed); } else if (input_idx == 2) { // gamma - ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, gamma_fp32_, is_packed); + ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_gamma_fp32_data_, is_packed); } else if (input_idx == 3) { // beta - ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, beta_fp32_, is_packed); + ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_beta_fp32_data_, is_packed); } else if (input_idx == 4) { // bias - ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, bias_fp32_, is_packed); + ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_bias_fp32_data_, is_packed); } return Status::OK(); diff --git a/onnxruntime/contrib_ops/cpu/skip_layer_norm.h b/onnxruntime/contrib_ops/cpu/skip_layer_norm.h index d904c14857437..83deaf0be8098 100644 --- a/onnxruntime/contrib_ops/cpu/skip_layer_norm.h +++ b/onnxruntime/contrib_ops/cpu/skip_layer_norm.h @@ -21,10 +21,10 @@ class SkipLayerNorm final : public OpKernel { private: float epsilon_; - mutable IAllocatorUniquePtr skip_fp32_; - mutable IAllocatorUniquePtr gamma_fp32_; - mutable IAllocatorUniquePtr beta_fp32_; - mutable IAllocatorUniquePtr bias_fp32_; + IAllocatorUniquePtr prepacked_gamma_fp32_data_; + IAllocatorUniquePtr prepacked_beta_fp32_data_; + IAllocatorUniquePtr prepacked_bias_fp32_data_; + IAllocatorUniquePtr prepacked_skip_fp32_data_; }; } // namespace contrib