Skip to content

Commit

Permalink
ml: a cmdline tool to use tflite-micro.
Browse files Browse the repository at this point in the history
Signed-off-by: jihandong <[email protected]>
  • Loading branch information
jihandong authored and xiaoxiang781216 committed Oct 18, 2024
1 parent b91adbb commit 7d87768
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 1 deletion.
17 changes: 16 additions & 1 deletion mlearning/tflite-micro/Kconfig
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,21 @@ config TFLITEMICRO

if TFLITEMICRO
config TFLITEMICRO_DEBUG
bool "TFLITEMICRO_DEBUG"
bool "Print tflite-micro's debug message"
default n

config TFLITEMICRO_TOOL
bool "tflite-micro cmdline tool"
default n

if TFLITEMICRO_TOOL
config TFLITEMICRO_TOOL_PRIORITY
int "tflite-micro tool priority"
default 100

config TFLITEMICRO_TOOL_STACKSIZE
int "tflite-micro tool stacksize"
default 4096

endif
endif
7 changes: 7 additions & 0 deletions mlearning/tflite-micro/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,13 @@ endif
# extra hardware support.
-include $(TFLM_DIR)/tensorflow/lite/micro/nuttx/Makefile

ifneq ($(CONFIG_TFLITEMICRO_TOOL),)
MAINSRC = tflm_tool.cc
PROGNAME = tflm
PRIORITY = $(CONFIG_TFLITEMICRO_TOOL_PRIORITY)
STACKSIZE = $(CONFIG_TFLITEMICRO_TOOL_STACKSIZE)
endif

CFLAGS += ${COMMON_FLAGS}
CXXFLAGS += ${COMMON_FLAGS}

Expand Down
150 changes: 150 additions & 0 deletions mlearning/tflite-micro/tflm_tool.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
/****************************************************************************
* apps/mlearning/tflite-micro/tflm_tool.cc
*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership. The
* ASF licenses this file to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance with the
* License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
****************************************************************************/

/****************************************************************************
* Included Files
****************************************************************************/

#include <unistd.h>

#include <cstdint>
#include <fstream>
#include <memory>

#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/micro/micro_profiler.h"

/****************************************************************************
* Private Functions
****************************************************************************/

static void usage(void)
{
printf("\nUtility to use tflite micro on nuttx.\n"
"[ -C ] Compile tflite model into c++ codes.\n"
"[ -E ] Do once evaluation (for profiling).\n"
"[ -i <str> ] Readable model file path.\n"
"[ -o <str> ] Writable c++ file path.\n"
"[ -p <str> ] Prefix of compiled code.\n"
"[ -a <int> ] Arena size (mempool).\n"
"[ -h ] Print this message.\n");
}

/****************************************************************************
* Public Functions
****************************************************************************/

extern "C" int main(int argc, FAR char* argv[])
{
const char* modelFileName = nullptr;
const char* codeFileName = nullptr;
const char* prefix = "NXAI";
bool need_compile = false;
bool need_invoke = false;
int arenaSize = 1024 * 8;

int ch;
while ((ch = getopt(argc, argv, "CEhi:o:p:a:")) != EOF)
{
switch (ch)
{
case 'C':
need_compile = true;
break;
case 'E':
need_invoke = true;
break;
case 'p':
prefix = optarg;
break;
case 'i':
modelFileName = optarg;
break;
case 'o':
codeFileName = optarg;
break;
case 'a':
arenaSize = strtol(optarg, NULL, 0);
break;
case 'h':
default:
usage();
return -1;
}
}

if (!modelFileName || !codeFileName)
{
usage();
return -1;
}

std::ifstream ifs(modelFileName, std::ios::binary);
ifs.seekg(0, std::ios::end);
size_t modelSize = ifs.tellg();
std::unique_ptr<uint8_t[]> pModel(new uint8_t[modelSize]);

ifs.seekg(0, std::ios::beg);
ifs.read(reinterpret_cast<char*>(pModel.get()), modelSize);
ifs.close();

/* HACK: can change operators here. */

tflite::MicroMutableOpResolver<8> resolver;
resolver.AddConv2D(tflite::Register_CONV_2D_INT8());
resolver.AddMaxPool2D(tflite::Register_MAX_POOL_2D_INT8());
resolver.AddQuantize(tflite::Register_QUANTIZE_FLOAT32_INT8());
resolver.AddDequantize(tflite::Register_DEQUANTIZE_INT8());
resolver.AddMean(tflite::Register_MEAN_INT8());
resolver.AddReshape();
resolver.AddFullyConnected(tflite::Register_FULLY_CONNECTED_INT8());
resolver.AddSoftmax(tflite::Register_SOFTMAX_INT8());

std::unique_ptr<uint8_t[]> pArena(new uint8_t[arenaSize]);

tflite::MicroProfiler profiler;
tflite::MicroInterpreter interpreter(tflite::GetModel(pModel.get()),
resolver, pArena.get(), arenaSize, nullptr,
reinterpret_cast<tflite::MicroProfilerInterface*>(&profiler));

/* HACK: can add testcases here. */

if (need_invoke)
{
interpreter.Invoke();
profiler.LogCsv();
profiler.LogTicksPerTagCsv();
}

if (need_compile)
{
#ifdef TFLITE_MODEL_COMPILER
std::ofstream ofs(codeFileName);
interpreter.Compile(ofs, prefix);
ofs.close();
#else
printf("Not supported compiling %s.\n", prefix);
#endif
}

printf("nxai done!\n");
return 0;
}

0 comments on commit 7d87768

Please sign in to comment.