diff --git a/tensorflow/lite/micro/kernels/BUILD b/tensorflow/lite/micro/kernels/BUILD
index 1615d774907..8562d8bb53f 100644
--- a/tensorflow/lite/micro/kernels/BUILD
+++ b/tensorflow/lite/micro/kernels/BUILD
@@ -333,6 +333,7 @@ tflm_kernel_cc_library(
         "logistic.h",
         "lstm_eval.h",
         "lstm_shared.h",
+        "maximum_minimum.h",
         "micro_ops.h",
         "mul.h",
         "pad.h",
diff --git a/tensorflow/lite/micro/kernels/cmsis_nn/maximum_minimum.cc b/tensorflow/lite/micro/kernels/cmsis_nn/maximum_minimum.cc
new file mode 100644
index 00000000000..a6affaa11bb
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/cmsis_nn/maximum_minimum.cc
@@ -0,0 +1,247 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/micro/kernels/maximum_minimum.h"
+
+#include "Include/arm_nnfunctions.h"
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/kernels/op_macros.h"
+#include "tensorflow/lite/micro/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/micro_log.h"
+
+namespace tflite {
+
+namespace {
+
+// Packs a tensor shape of rank <= 4 into the NHWC cmsis_nn_dims layout,
+// left-padding missing leading dimensions with 1.
+cmsis_nn_dims FillVariableShape(int32_t rank, int32_t* tensor_dims) {
+  if (rank == 4) {
+    return {tensor_dims[0], tensor_dims[1], tensor_dims[2], tensor_dims[3]};
+  } else if (rank == 3) {
+    return {1, tensor_dims[0], tensor_dims[1], tensor_dims[2]};
+  } else if (rank == 2) {
+    return {1, 1, tensor_dims[0], tensor_dims[1]};
+  } else {
+    return {1, 1, 1, 1};
+  }
+}
+
+TfLiteStatus EvalMaximum(TfLiteContext* context, TfLiteNode* node) {
+  OpContext op_context(context, node);
+  const TfLiteEvalTensor* input1 =
+      tflite::micro::GetEvalInput(context, node, kInputTensor1);
+  const TfLiteEvalTensor* input2 =
+      tflite::micro::GetEvalInput(context, node, kInputTensor2);
+  TfLiteEvalTensor* output =
+      tflite::micro::GetEvalOutput(context, node, kOutputTensor);
+
+  RuntimeShape input_1_shape = tflite::micro::GetTensorShape(input1);
+  RuntimeShape input_2_shape = tflite::micro::GetTensorShape(input2);
+  RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
+
+  cmsis_nn_dims input_1_dims = FillVariableShape(
+      input_1_shape.DimensionsCount(), input_1_shape.DimsData());
+  cmsis_nn_dims input_2_dims = FillVariableShape(
+      input_2_shape.DimensionsCount(), input_2_shape.DimsData());
+  cmsis_nn_dims output_dims = FillVariableShape(output_shape.DimensionsCount(),
+                                                output_shape.DimsData());
+
+  switch (op_context.output->type) {
+    case kTfLiteInt8:
+      cmsis_nn_context ctx;
+      ctx.buf = nullptr;
+      ctx.size = 0;
+
+      arm_maximum_s8(
+          &ctx, tflite::micro::GetTensorData<int8_t>(input1), &input_1_dims,
+          tflite::micro::GetTensorData<int8_t>(input2), &input_2_dims,
+          tflite::micro::GetTensorData<int8_t>(output), &output_dims);
+      break;
+    case kTfLiteFloat32:
+      TFLiteOperation<float, MaximumOp>(context, node, op_context);
+      break;
+    case kTfLiteInt16:
+      TFLiteOperation<int16_t, MaximumOp>(context, node, op_context);
+      break;
+    case kTfLiteInt32:
+      TFLiteOperation<int32_t, MaximumOp>(context, node, op_context);
+      break;
+    case kTfLiteInt64:
+      TFLiteOperation<int64_t, MaximumOp>(context, node, op_context);
+      break;
+    default:
+      MicroPrintf("Type %s (%d) is not supported by Maximum/Minimum.",
+                  TfLiteTypeGetName(op_context.output->type),
+                  op_context.output->type);
+      return kTfLiteError;
+  }
+  return kTfLiteOk;
+}
+
+TfLiteStatus EvalMaximumInt8(TfLiteContext* context, TfLiteNode* node) {
+  OpContext op_context(context, node);
+  const TfLiteEvalTensor* input1 =
+      tflite::micro::GetEvalInput(context, node, kInputTensor1);
+  const TfLiteEvalTensor* input2 =
+      tflite::micro::GetEvalInput(context, node, kInputTensor2);
+  TfLiteEvalTensor* output =
+      tflite::micro::GetEvalOutput(context, node, kOutputTensor);
+
+  RuntimeShape input_1_shape = tflite::micro::GetTensorShape(input1);
+  RuntimeShape input_2_shape = tflite::micro::GetTensorShape(input2);
+  RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
+
+  cmsis_nn_dims input_1_dims = FillVariableShape(
+      input_1_shape.DimensionsCount(), input_1_shape.DimsData());
+  cmsis_nn_dims input_2_dims = FillVariableShape(
+      input_2_shape.DimensionsCount(), input_2_shape.DimsData());
+  cmsis_nn_dims output_dims = FillVariableShape(output_shape.DimensionsCount(),
+                                                output_shape.DimsData());
+
+  switch (op_context.output->type) {
+    case kTfLiteInt8:
+      cmsis_nn_context ctx;
+      ctx.buf = nullptr;
+      ctx.size = 0;
+
+      arm_maximum_s8(
+          &ctx, tflite::micro::GetTensorData<int8_t>(input1), &input_1_dims,
+          tflite::micro::GetTensorData<int8_t>(input2), &input_2_dims,
+          tflite::micro::GetTensorData<int8_t>(output), &output_dims);
+      break;
+    default:
+      MicroPrintf("Type %s (%d) is not supported by Maximum Int8 Registration.",
+                  TfLiteTypeGetName(op_context.output->type),
+                  op_context.output->type);
+      return kTfLiteError;
+  }
+  return kTfLiteOk;
+}
+
+TfLiteStatus EvalMinimum(TfLiteContext* context, TfLiteNode* node) {
+  OpContext op_context(context, node);
+  const TfLiteEvalTensor* input1 =
+      tflite::micro::GetEvalInput(context, node, kInputTensor1);
+  const TfLiteEvalTensor* input2 =
+      tflite::micro::GetEvalInput(context, node, kInputTensor2);
+  TfLiteEvalTensor* output =
+      tflite::micro::GetEvalOutput(context, node, kOutputTensor);
+
+  RuntimeShape input_1_shape = tflite::micro::GetTensorShape(input1);
+  RuntimeShape input_2_shape = tflite::micro::GetTensorShape(input2);
+  RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
+
+  cmsis_nn_dims input_1_dims = FillVariableShape(
+      input_1_shape.DimensionsCount(), input_1_shape.DimsData());
+  cmsis_nn_dims input_2_dims = FillVariableShape(
+      input_2_shape.DimensionsCount(), input_2_shape.DimsData());
+  cmsis_nn_dims output_dims = FillVariableShape(output_shape.DimensionsCount(),
+                                                output_shape.DimsData());
+
+  switch (op_context.output->type) {
+    case kTfLiteInt8:
+      cmsis_nn_context ctx;
+      ctx.buf = nullptr;
+      ctx.size = 0;
+
+      arm_minimum_s8(
+          &ctx, tflite::micro::GetTensorData<int8_t>(input1), &input_1_dims,
+          tflite::micro::GetTensorData<int8_t>(input2), &input_2_dims,
+          tflite::micro::GetTensorData<int8_t>(output), &output_dims);
+      break;
+    case kTfLiteFloat32:
+      TFLiteOperation<float, MinimumOp>(context, node, op_context);
+      break;
+    case kTfLiteInt16:
+      TFLiteOperation<int16_t, MinimumOp>(context, node, op_context);
+      break;
+    case kTfLiteInt32:
+      TFLiteOperation<int32_t, MinimumOp>(context, node, op_context);
+      break;
+    case kTfLiteInt64:
+      TFLiteOperation<int64_t, MinimumOp>(context, node, op_context);
+      break;
+    default:
+      MicroPrintf("Type %s (%d) is not supported by Maximum/Minimum.",
+                  TfLiteTypeGetName(op_context.output->type),
+                  op_context.output->type);
+      return kTfLiteError;
+  }
+  return kTfLiteOk;
+}
+
+TfLiteStatus EvalMinimumInt8(TfLiteContext* context, TfLiteNode* node) {
+  OpContext op_context(context, node);
+  const TfLiteEvalTensor* input1 =
+      tflite::micro::GetEvalInput(context, node, kInputTensor1);
+  const TfLiteEvalTensor* input2 =
+      tflite::micro::GetEvalInput(context, node, kInputTensor2);
+  TfLiteEvalTensor* output =
+      tflite::micro::GetEvalOutput(context, node, kOutputTensor);
+
+  RuntimeShape input_1_shape = tflite::micro::GetTensorShape(input1);
+  RuntimeShape input_2_shape = tflite::micro::GetTensorShape(input2);
+  RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
+
+  cmsis_nn_dims input_1_dims = FillVariableShape(
+      input_1_shape.DimensionsCount(), input_1_shape.DimsData());
+  cmsis_nn_dims input_2_dims = FillVariableShape(
+      input_2_shape.DimensionsCount(), input_2_shape.DimsData());
+  cmsis_nn_dims output_dims = FillVariableShape(output_shape.DimensionsCount(),
+                                                output_shape.DimsData());
+
+  switch (op_context.output->type) {
+    case kTfLiteInt8:
+      cmsis_nn_context ctx;
+      ctx.buf = nullptr;
+      ctx.size = 0;
+
+      arm_minimum_s8(
+          &ctx, tflite::micro::GetTensorData<int8_t>(input1), &input_1_dims,
+          tflite::micro::GetTensorData<int8_t>(input2), &input_2_dims,
+          tflite::micro::GetTensorData<int8_t>(output), &output_dims);
+      break;
+    default:
+      MicroPrintf("Type %s (%d) is not supported by Minimum Int8 registration.",
+                  TfLiteTypeGetName(op_context.output->type),
+                  op_context.output->type);
+      return kTfLiteError;
+  }
+  return kTfLiteOk;
+}
+
+}  // namespace
+
+TFLMRegistration Register_MAXIMUM() {
+  return tflite::micro::RegisterOp(nullptr, nullptr, EvalMaximum);
+}
+
+TFLMRegistration Register_MINIMUM() {
+  return tflite::micro::RegisterOp(nullptr, nullptr, EvalMinimum);
+}
+
+TFLMRegistration Register_MAXIMUM_INT8() {
+  return tflite::micro::RegisterOp(nullptr, nullptr, EvalMaximumInt8);
+}
+
+TFLMRegistration Register_MINIMUM_INT8() {
+  return tflite::micro::RegisterOp(nullptr, nullptr, EvalMinimumInt8);
+}
+
+}  // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/maximum_minimum.cc b/tensorflow/lite/micro/kernels/maximum_minimum.cc
index 4dc87b40148..ef4a0a6a522 100644
--- a/tensorflow/lite/micro/kernels/maximum_minimum.cc
+++ b/tensorflow/lite/micro/kernels/maximum_minimum.cc
@@ -1,4 +1,4 @@
-/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -23,59 +23,13 @@ limitations under the License.
 #include "tensorflow/lite/kernels/kernel_util.h"
 #include "tensorflow/lite/kernels/op_macros.h"
 #include "tensorflow/lite/micro/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/kernels/maximum_minimum.h"
 #include "tensorflow/lite/micro/micro_log.h"
 
 namespace tflite {
 
 namespace {
 
-// This file has a reference implementation of TFMaximum/TFMinimum.
-enum KernelType {
-  kReference,
-};
-
-constexpr int kInputTensor1 = 0;
-constexpr int kInputTensor2 = 1;
-constexpr int kOutputTensor = 0;
-
-struct OpContext {
-  OpContext(TfLiteContext* context, TfLiteNode* node) {
-    input1 = tflite::micro::GetEvalInput(context, node, kInputTensor1);
-    input2 = tflite::micro::GetEvalInput(context, node, kInputTensor2);
-    output = tflite::micro::GetEvalOutput(context, node, kOutputTensor);
-  }
-  const TfLiteEvalTensor* input1;
-  const TfLiteEvalTensor* input2;
-  TfLiteEvalTensor* output;
-};
-
-struct MaximumOp {
-  template <typename data_type>
-  static data_type op(data_type el1, data_type el2) {
-    return el1 > el2 ? el1 : el2;
-  }
-};
-
-struct MinimumOp {
-  template <typename data_type>
-  static data_type op(data_type el1, data_type el2) {
-    return el1 < el2 ? el1 : el2;
-  }
-};
-
-template <typename data_type, typename op_type>
-void TFLiteOperation(TfLiteContext* context, TfLiteNode* node,
-                     const OpContext& op_context) {
-  reference_ops::MaximumMinimumBroadcastSlow(
-      tflite::micro::GetTensorShape(op_context.input1),
-      tflite::micro::GetTensorData<data_type>(op_context.input1),
-      tflite::micro::GetTensorShape(op_context.input2),
-      tflite::micro::GetTensorData<data_type>(op_context.input2),
-      tflite::micro::GetTensorShape(op_context.output),
-      tflite::micro::GetTensorData<data_type>(op_context.output),
-      op_type::template op<data_type>);
-}
-
 template <KernelType kernel_type, typename OpType>
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   OpContext op_context(context, node);
diff --git a/tensorflow/lite/micro/kernels/maximum_minimum.h b/tensorflow/lite/micro/kernels/maximum_minimum.h
new file mode 100644
index 00000000000..34d7e2399f3
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/maximum_minimum.h
@@ -0,0 +1,105 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_MICRO_KERNELS_MAXIMUM_MINIMUM_H_
+#define TENSORFLOW_LITE_MICRO_KERNELS_MAXIMUM_MINIMUM_H_
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/reference/maximum_minimum.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/kernels/op_macros.h"
+#include "tensorflow/lite/micro/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/micro_log.h"
+
+namespace tflite {
+
+// This file has a reference implementation of TFMaximum/TFMinimum.
+enum KernelType {
+  kReference,
+};
+
+constexpr int kInputTensor1 = 0;
+constexpr int kInputTensor2 = 1;
+constexpr int kOutputTensor = 0;
+
+struct OpContext {
+  OpContext(TfLiteContext* context, TfLiteNode* node) {
+    input1 = tflite::micro::GetEvalInput(context, node, kInputTensor1);
+    input2 = tflite::micro::GetEvalInput(context, node, kInputTensor2);
+    output = tflite::micro::GetEvalOutput(context, node, kOutputTensor);
+  }
+  const TfLiteEvalTensor* input1;
+  const TfLiteEvalTensor* input2;
+  TfLiteEvalTensor* output;
+};
+
+struct MaximumOp {
+  template <typename data_type>
+  static data_type op(data_type el1, data_type el2) {
+    return el1 > el2 ? el1 : el2;
+  }
+};
+
+struct MinimumOp {
+  template <typename data_type>
+  static data_type op(data_type el1, data_type el2) {
+    return el1 < el2 ? el1 : el2;
+  }
+};
+
+template <typename data_type, typename op_type>
+void TFLiteOperation(TfLiteContext* context, TfLiteNode* node,
+                     const OpContext& op_context) {
+  reference_ops::MaximumMinimumBroadcastSlow(
+      tflite::micro::GetTensorShape(op_context.input1),
+      tflite::micro::GetTensorData<data_type>(op_context.input1),
+      tflite::micro::GetTensorShape(op_context.input2),
+      tflite::micro::GetTensorData<data_type>(op_context.input2),
+      tflite::micro::GetTensorShape(op_context.output),
+      tflite::micro::GetTensorData<data_type>(op_context.output),
+      op_type::template op<data_type>);
+}
+
+TFLMRegistration Register_MAXIMUM();
+
+TFLMRegistration Register_MINIMUM();
+
+#if defined(CMSIS_NN)
+// Returns a TFLMRegistration struct for kernel variant that only supports
+// int8.
+TFLMRegistration Register_MAXIMUM_INT8();
+
+// Returns a TFLMRegistration struct for kernel variant that only supports
+// int8.
+TFLMRegistration Register_MINIMUM_INT8();
+
+#else
+// Note that while this block gets used for both reference and optimized kernels
+// that do not have any specialized implementations, the only goal here is to
+// define fallback implementation that allow reference kernels to still be used
+// from applications that call a more specific kernel variant.
+inline TFLMRegistration Register_MAXIMUM_INT8() { return Register_MAXIMUM(); }
+
+inline TFLMRegistration Register_MINIMUM_INT8() { return Register_MINIMUM(); }
+
+#endif
+
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_MICRO_KERNELS_MAXIMUM_MINIMUM_H_
diff --git a/tensorflow/lite/micro/micro_mutable_op_resolver.h b/tensorflow/lite/micro/micro_mutable_op_resolver.h
index f5f6e38e003..ad642ddbc06 100644
--- a/tensorflow/lite/micro/micro_mutable_op_resolver.h
+++ b/tensorflow/lite/micro/micro_mutable_op_resolver.h
@@ -1,4 +1,4 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -28,6 +28,7 @@ limitations under the License.
 #include "tensorflow/lite/micro/kernels/depthwise_conv.h"
 #include "tensorflow/lite/micro/kernels/ethosu.h"
 #include "tensorflow/lite/micro/kernels/fully_connected.h"
+#include "tensorflow/lite/micro/kernels/maximum_minimum.h"
 #include "tensorflow/lite/micro/kernels/micro_ops.h"
 #include "tensorflow/lite/micro/kernels/mul.h"
 #include "tensorflow/lite/micro/kernels/pooling.h"
@@ -414,9 +415,9 @@ class MicroMutableOpResolver : public MicroOpResolver {
                       tflite::Register_LOG_SOFTMAX(), ParseLogSoftmax);
   }
 
-  TfLiteStatus AddMaximum() {
-    return AddBuiltin(BuiltinOperator_MAXIMUM, Register_MAXIMUM(),
-                      ParseMaximum);
+  TfLiteStatus AddMaximum(
+      const TFLMRegistration& registration = Register_MAXIMUM()) {
+    return AddBuiltin(BuiltinOperator_MAXIMUM, registration, ParseMaximum);
   }
 
   TfLiteStatus AddMaxPool2D(
@@ -433,9 +434,9 @@ class MicroMutableOpResolver : public MicroOpResolver {
     return AddBuiltin(BuiltinOperator_MEAN, Register_MEAN(), ParseReducer);
   }
 
-  TfLiteStatus AddMinimum() {
-    return AddBuiltin(BuiltinOperator_MINIMUM, Register_MINIMUM(),
-                      ParseMinimum);
+  TfLiteStatus AddMinimum(
+      const TFLMRegistration& registration = Register_MINIMUM()) {
+    return AddBuiltin(BuiltinOperator_MINIMUM, registration, ParseMinimum);
   }
 
   TfLiteStatus AddMul(const TFLMRegistration& registration = Register_MUL()) {
@@ -452,7 +453,8 @@ class MicroMutableOpResolver : public MicroOpResolver {
   }
 
   TfLiteStatus AddOverlapAdd() {
-    // TODO(b/286250473): change back name to "OverlapAdd" and remove namespace
+    // TODO(b/286250473): change back name to "OverlapAdd" and remove
+    // namespace
     return AddCustom("SignalOverlapAdd",
                      tflite::tflm_signal::Register_OVERLAP_ADD());
   }
@@ -684,8 +686,8 @@ class MicroMutableOpResolver : public MicroOpResolver {
     }
 
     registrations_[registrations_len_] = registration;
-    // Strictly speaking, the builtin_code is not necessary for TFLM but filling
-    // it in regardless.
+    // Strictly speaking, the builtin_code is not necessary for TFLM but
+    // filling it in regardless.
    registrations_[registrations_len_].builtin_code = op;
    registrations_len_++;