Prepack skip, gamma, beta, bias in skip_layer_norm
amarin16 committed Oct 9, 2024
1 parent f04aac0 commit 1eaa63f
Showing 3 changed files with 89 additions and 19 deletions.
95 changes: 79 additions & 16 deletions onnxruntime/contrib_ops/cpu/skip_layer_norm.cc
@@ -46,13 +46,21 @@ Status ComputeJob(
const T* gamma_data,
const T* beta_data,
const T* bias_data,
const IAllocatorUniquePtr<float>& skip_fp32,
const IAllocatorUniquePtr<float>& gamma_fp32,
const IAllocatorUniquePtr<float>& beta_fp32,
const IAllocatorUniquePtr<float>& bias_fp32,
ptrdiff_t task_idx,
int hidden_size,
int64_t skip_size,
float epsilon,
bool simplified,
T* output_data,
T* skip_input_bias_add_output_data) {
ORT_UNUSED_PARAMETER(skip_fp32); // only used in MLFloat16 overload
ORT_UNUSED_PARAMETER(gamma_fp32); // only used in MLFloat16 overload
ORT_UNUSED_PARAMETER(beta_fp32); // only used in MLFloat16 overload
ORT_UNUSED_PARAMETER(bias_fp32); // only used in MLFloat16 overload
auto offset = task_idx * hidden_size;
const T* p_input = input_data + offset;
const T* p_skip = skip_data + (offset % skip_size);
@@ -104,6 +112,10 @@ Status ComputeJob(
const MLFloat16* gamma_data,
const MLFloat16* beta_data,
const MLFloat16* bias_data,
const IAllocatorUniquePtr<float>& skip_fp32,
const IAllocatorUniquePtr<float>& gamma_fp32,
const IAllocatorUniquePtr<float>& beta_fp32,
const IAllocatorUniquePtr<float>& bias_fp32,
ptrdiff_t task_idx,
int hidden_size,
int64_t skip_size,
@@ -124,13 +136,20 @@ Status ComputeJob(
float* float_input = new float[num_elems];
MlasConvertHalfToFloatBuffer(p_input, float_input, num_elems);

float* float_skip = new float[num_elems];
MlasConvertHalfToFloatBuffer(p_skip, float_skip, num_elems);
float* float_skip = skip_fp32.get();
if (nullptr == float_skip) {
float_skip = new float[num_elems];
MlasConvertHalfToFloatBuffer(p_skip, float_skip, num_elems);
}

float* float_bias = nullptr;
if (bias_data != nullptr) {
float_bias = new float[num_elems];
MlasConvertHalfToFloatBuffer(bias_data, float_bias, num_elems);
if (bias_data) {
if (nullptr != bias_fp32) {
float_bias = bias_fp32.get();
} else {
float_bias = new float[num_elems];
MlasConvertHalfToFloatBuffer(bias_data, float_bias, num_elems);
}
}

float* float_output = new float[num_elems];
@@ -147,7 +166,7 @@ Status ComputeJob(
mean_square += val * val;
}

if (float_bias != nullptr) {
if (float_bias && (nullptr == bias_fp32)) {
delete[] float_bias;
}

@@ -162,13 +181,20 @@ Status ComputeJob(
mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon);
}

float* float_gamma = float_input; // overwrite float_input with gamma values, since they have the same size
MlasConvertHalfToFloatBuffer(gamma_data, float_gamma, num_elems);
float* float_gamma = gamma_fp32.get();
if (nullptr == float_gamma) {
float_gamma = float_input; // overwrite float_input with gamma values, since they have the same size
MlasConvertHalfToFloatBuffer(gamma_data, float_gamma, num_elems);
}

float* float_beta = nullptr; // overwrite float_skip with beta values, since they have the same size
float* float_beta = nullptr;
if (beta_data) {
float_beta = float_skip;
MlasConvertHalfToFloatBuffer(beta_data, float_beta, num_elems);
if (nullptr != beta_fp32) {
float_beta = beta_fp32.get();
} else {
float_beta = new float[num_elems];
MlasConvertHalfToFloatBuffer(beta_data, float_beta, num_elems);
}
}

for (size_t h = 0; h < num_elems; h++) {
@@ -180,20 +206,37 @@ Status ComputeJob(
float_output[h] = (float_output[h] - mean) / mean_square * float_gamma[h] + float_beta[h];
}
}
delete[] float_gamma; // also deletes float_input
delete[] float_skip; // also deletes float_beta if used
delete[] float_input; // also takes care of float_gamma if reused
if (float_skip && (nullptr == skip_fp32)) {
delete[] float_skip;
}
if (beta_data && (nullptr == beta_fp32)) {
delete[] float_beta;
}

MlasConvertFloatToHalfBuffer(float_output, p_output, num_elems);
delete[] float_output;

return Status::OK();
}

void ConvertMLFloat16ToFloatIfNeeded(const Tensor& tensor, AllocatorPtr alloc, IAllocatorUniquePtr<float>& dest, bool& is_packed) {
if (tensor.GetElementType() == utils::ToTensorProtoElementType<MLFloat16>()) {
auto tensor_data_ptr = tensor.Data<MLFloat16>();
auto tensor_size = static_cast<size_t>(tensor.Shape().Size());
auto float_ptr = IAllocator::MakeUniquePtr<float>(alloc, tensor_size, true);

MlasConvertHalfToFloatBuffer(tensor_data_ptr, float_ptr.get(), tensor_size);
dest = std::move(float_ptr);
is_packed = true;
}
}

} // namespace

template <typename T, bool simplified>
SkipLayerNorm<T, simplified>::SkipLayerNorm(const OpKernelInfo& op_kernel_info)
: OpKernel(op_kernel_info) {
: OpKernel(op_kernel_info), skip_fp32_(nullptr), gamma_fp32_(nullptr), beta_fp32_(nullptr), bias_fp32_(nullptr) {
ORT_ENFORCE(op_kernel_info.GetAttr<float>("epsilon", &epsilon_).IsOK());
ORT_ENFORCE(epsilon_ >= 0);
}
@@ -240,8 +283,9 @@ Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {
concurrency::ThreadPool::TryBatchParallelFor(
p_ctx->GetOperatorThreadPool(), static_cast<int32_t>(task_count),
[&](ptrdiff_t task_idx) {
auto status = ComputeJob(input_data, skip_data, gamma_data, beta_data, bias_data, task_idx, hidden_size,
skip_size, epsilon_, simplified, output_data, skip_input_bias_add_output_data);
auto status = ComputeJob(input_data, skip_data, gamma_data, beta_data, bias_data, skip_fp32_, gamma_fp32_,
beta_fp32_, bias_fp32_, task_idx, hidden_size, skip_size, epsilon_, simplified,
output_data, skip_input_bias_add_output_data);
if (status != Status::OK()) {
return_status = status;
}
@@ -251,5 +295,24 @@ Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {
return return_status;
}

template <typename T, bool simplified>
Status SkipLayerNorm<T, simplified>::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
bool& is_packed, PrePackedWeights* prepacked_weights) {
ORT_UNUSED_PARAMETER(prepacked_weights);

is_packed = false;
if (input_idx == 1) { // skip
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, skip_fp32_, is_packed);
} else if (input_idx == 2) { // gamma
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, gamma_fp32_, is_packed);
} else if (input_idx == 3) { // beta
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, beta_fp32_, is_packed);
} else if (input_idx == 4) { // bias
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, bias_fp32_, is_packed);
}

return Status::OK();
}

} // namespace contrib
} // namespace onnxruntime
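For readers unfamiliar with the prepacking hook, the sketch below illustrates the pattern this commit applies, in plain C++ with illustrative stand-ins (`Half`, `HalfToFloat`, and `Kernel` are mock names, not the onnxruntime API): constant half-precision inputs are converted to float once and cached, and the per-call compute path falls back to on-the-fly conversion only when no cached buffer exists — which is what the new `skip_fp32_` / `gamma_fp32_` / `beta_fp32_` / `bias_fp32_` members and the `nullptr` checks in `ComputeJob` do.

```cpp
// Minimal sketch of the prepack-and-cache pattern (mock types, not the
// onnxruntime API). The real kernel caches skip, gamma, beta and bias.
#include <cstddef>
#include <cstdint>
#include <memory>
#include <vector>

using Half = uint16_t;  // stand-in for MLFloat16

// Stand-in for MlasConvertHalfToFloatBuffer; a real conversion would decode
// IEEE 754 half precision, which is omitted here.
void HalfToFloat(const Half* src, float* dst, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) dst[i] = static_cast<float>(src[i]);
}

class Kernel {
 public:
  // Called once, when a constant initializer input is seen ("prepack").
  void PrePack(const Half* gamma, std::size_t n) {
    gamma_fp32_ = std::make_unique<float[]>(n);
    HalfToFloat(gamma, gamma_fp32_.get(), n);
  }

  // Called on every inference: reuse the cached fp32 buffer when present,
  // otherwise convert into per-call scratch (the pre-change behavior).
  void Compute(const Half* gamma, std::vector<float>& inout) const {
    const float* g = gamma_fp32_.get();
    std::unique_ptr<float[]> scratch;
    if (g == nullptr) {
      scratch = std::make_unique<float[]>(inout.size());
      HalfToFloat(gamma, scratch.get(), inout.size());
      g = scratch.get();
    }
    for (std::size_t i = 0; i < inout.size(); ++i) inout[i] *= g[i];
  }

 private:
  std::unique_ptr<float[]> gamma_fp32_;  // filled by PrePack, otherwise null
};
```

In the commit itself the cached buffers are `IAllocatorUniquePtr<float>` members filled by the new `ConvertMLFloat16ToFloatIfNeeded` helper, and `PrePack` selects which member to fill from `input_idx` (1 = skip, 2 = gamma, 3 = beta, 4 = bias).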
7 changes: 7 additions & 0 deletions onnxruntime/contrib_ops/cpu/skip_layer_norm.h
@@ -16,8 +16,15 @@ class SkipLayerNorm final : public OpKernel {
SkipLayerNorm(const OpKernelInfo& op_kernel_info);
Status Compute(OpKernelContext* p_op_kernel_context) const override;

Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
bool& is_packed, PrePackedWeights* prepacked_weights) override;

private:
float epsilon_;
IAllocatorUniquePtr<float> skip_fp32_;
IAllocatorUniquePtr<float> gamma_fp32_;
IAllocatorUniquePtr<float> beta_fp32_;
IAllocatorUniquePtr<float> bias_fp32_;
};

} // namespace contrib
6 changes: 3 additions & 3 deletions onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc
@@ -115,14 +115,14 @@ Status ComputeJob(
}

float* float_scale = scale_fp32.get();
if (float_scale == nullptr) {
if (nullptr == float_scale) {
float_scale = float_input; // overwrite float_input with scale values, since they have the same size
MlasConvertHalfToFloatBuffer(scale_data, float_scale, num_elems);
}

float* float_bias = nullptr;
if (bias_data) {
if (bias_fp32 != nullptr) {
if (nullptr != bias_fp32) {
float_bias = bias_fp32.get();
} else {
float_bias = new float[num_elems];
@@ -141,7 +141,7 @@
}

delete[] float_input; // also takes care of float_scale if reused
if (float_bias && (bias_fp32 == nullptr)) {
if (float_bias && (nullptr == bias_fp32)) {
delete[] float_bias;
}

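A note on the cleanup logic: in both skip_layer_norm.cc and layer_norm_impl.cc, a scratch fp32 buffer is now freed only when it was allocated inside the call, i.e. when the corresponding prepacked `*_fp32` buffer is null (for example `if (float_bias && (nullptr == bias_fp32)) { delete[] float_bias; }`); the layer_norm_impl.cc hunks above only reorder those comparisons to the `nullptr == x` form, with no change in behavior. A small self-contained sketch of that ownership rule, with illustrative names only:

```cpp
#include <cstddef>
#include <memory>

// Return the prepacked buffer when available; otherwise allocate per-call
// scratch, which the caller-owned unique_ptr frees automatically.
const float* SelectOrAllocate(const float* prepacked, std::size_t n,
                              std::unique_ptr<float[]>& scratch) {
  if (prepacked != nullptr) return prepacked;  // cached: nothing to free
  scratch = std::make_unique<float[]>(n);      // freed when scratch is destroyed
  // ...a real caller would convert the half-precision source into scratch here...
  return scratch.get();
}
```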
