From eb5606e7598e832efa6285b30d5f8047c1d4af6e Mon Sep 17 00:00:00 2001 From: TFLM-bot Date: Fri, 18 Oct 2024 14:02:45 +0000 Subject: [PATCH] Sync from upstream TF. --- tensorflow/lite/array.cc | 2 ++ .../lite/kernels/internal/reference/batch_matmul.h | 9 +++++++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/tensorflow/lite/array.cc b/tensorflow/lite/array.cc index 1b1ff2e4557..21d704a76c4 100644 --- a/tensorflow/lite/array.cc +++ b/tensorflow/lite/array.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/lite/array.h" +#include "tensorflow/lite/c/common.h" + namespace tflite { namespace array_internal { diff --git a/tensorflow/lite/kernels/internal/reference/batch_matmul.h b/tensorflow/lite/kernels/internal/reference/batch_matmul.h index 767ad6ab0af..d83696219c2 100644 --- a/tensorflow/lite/kernels/internal/reference/batch_matmul.h +++ b/tensorflow/lite/kernels/internal/reference/batch_matmul.h @@ -111,7 +111,8 @@ inline void BatchMatMul(const RuntimeShape& lhs_shape, const int8_t* lhs_data, const float* scaling_factors, const int32_t* input_offset, int32_t* row_sums, const RuntimeShape& output_shape, float* output_data, - bool* compute_row_sums) { + bool* compute_row_sums, + const float* per_channel_scales) { const RuntimeShape extended_lhs_shape = RuntimeShape::ExtendedShape(5, lhs_shape); const RuntimeShape extended_rhs_shape = @@ -188,7 +189,11 @@ inline void BatchMatMul(const RuntimeShape& lhs_shape, const int8_t* lhs_data, int32_t row_sum = woff_ptr2[i]; total -= row_sum * batch_offset; int idx = lhs_rows * j + i; - out_ptr[idx] += batch_scaling_factor * total; + float scale = batch_scaling_factor; + if (per_channel_scales) { + scale *= per_channel_scales[i]; + } + out_ptr[idx] += scale * total; } } }