WIP: Tflite cpp static build #7

Open: wants to merge 1 commit into base: rsdk-8636-make-tflite-cpu-module
4 changes: 2 additions & 2 deletions cpp/CMakeLists.txt
@@ -33,8 +33,8 @@ add_executable(tflite_cpu
 
 target_link_libraries(tflite_cpu
     PRIVATE Threads::Threads
-    PRIVATE "$<LINK_LIBRARY:WHOLE_ARCHIVE,viam-cpp-sdk::viamsdk>"
-    PRIVATE "$<LINK_LIBRARY:WHOLE_ARCHIVE,tensorflow::tensorflowlite>"
+    PRIVATE viam-cpp-sdk::viamsdk
+    PRIVATE tensorflow::tensorflowlite
 )
 
 install(
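Context for the change above: the "$<LINK_LIBRARY:WHOLE_ARCHIVE,...>" generator expression forces the linker to keep every object file from the named static library, which only matters when a library relies on self-registering globals that nothing references directly. Whether viamsdk or tensorflowlite actually depends on that pattern, and therefore whether plain PRIVATE linking keeps all needed symbols, is not shown by this diff; the snippet below is only a minimal C++ illustration of the pattern whole-archive linking exists to preserve.

// Illustration only, not code from this PR. Nothing outside this translation
// unit references `registrar`, so when this file lives in a static library a
// linker may drop the whole object file unless whole-archive linking forces
// it in, and the registration then silently never runs.
#include <functional>
#include <map>
#include <string>

std::map<std::string, std::function<void()>>& registry() {
    static std::map<std::string, std::function<void()>> instance;
    return instance;
}

namespace {
struct Registrar {
    Registrar() {
        registry().emplace("tflite_cpu", [] { /* construct and register the service */ });
    }
} registrar;  // runs during static initialization, if this object file is linked in
}  // namespace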
10 changes: 1 addition & 9 deletions cpp/conanfile.py
@@ -41,7 +41,7 @@ def configure(self):
         self.options["*"].shared = False
 
     def requirements(self):
-        self.requires("viam-cpp-sdk/0.0.11")
+        self.requires("viam-cpp-sdk/0.0.12")
         self.requires("tensorflow-lite/2.15.0")
         self.requires("abseil/20240116.2", override=True)
@@ -57,11 +57,3 @@ def build(self):
 
     def layout(self):
         cmake_layout(self, src_folder=".")
-
-    def test(self):
-        if can_run(self):
-            cmd = os.path.join(self.cpp.build.bindir, "tflite_module")
-            stderr = StringIO()
-            self.run(cmd, env='conanrun', stderr=stderr, ignore_errors=True)
-            if "main failed with exception:" not in stderr.getvalue():
-                raise ConanException("Unexpected error output from test")
32 changes: 14 additions & 18 deletions cpp/src/tflite_cpu.cpp
@@ -28,7 +28,7 @@
 
 #include <viam/sdk/components/component.hpp>
 #include <viam/sdk/config/resource.hpp>
-#include <viam/sdk/common/proto_type.hpp>
+#include <viam/sdk/common/proto_value.hpp>
 #include <viam/sdk/module/service.hpp>
 #include <viam/sdk/registry/registry.hpp>
 #include <viam/sdk/resource/reconfigurable.hpp>
@@ -70,7 +70,7 @@ class MLModelServiceTFLite : public vsdk::MLModelService,
         // drain.
     }
 
-    void stop(const vsdk::AttributeMap& extra) noexcept final {
+    void stop(const vsdk::ProtoStruct& extra) noexcept final {
         return stop();
     }
 
@@ -135,7 +135,7 @@ class MLModelServiceTFLite : public vsdk::MLModelService,
     }
 
     std::shared_ptr<named_tensor_views> infer(const named_tensor_views& inputs,
-                                              const vsdk::AttributeMap& extra) final {
+                                              const vsdk::ProtoStruct& extra) final {
        auto state = lease_state_();
 
        // We serialize access to the interpreter. We use a
@@ -244,7 +244,7 @@ class MLModelServiceTFLite : public vsdk::MLModelService,
        return {std::move(inference_result), views};
    }
 
-    struct metadata metadata(const vsdk::AttributeMap& extra) final {
+    struct metadata metadata(const vsdk::ProtoStruct& extra) final {
        // Just return a copy of our metadata from leased state.
        return lease_state_()->metadata;
    }
@@ -284,14 +284,14 @@ class MLModelServiceTFLite : public vsdk::MLModelService,
        // Now we can begin parsing and validating the provided `configuration`.
        // Pull the model path out of the configuration.
        const auto& attributes = state->configuration.attributes();
-        auto model_path = attributes->find("model_path");
-        if (model_path == attributes->end()) {
+        auto model_path = attributes.find("model_path");
+        if (model_path == attributes.end()) {
            std::ostringstream buffer;
            buffer << service_name
                   << ": Required parameter `model_path` not found in configuration";
            throw std::invalid_argument(buffer.str());
        }
-        const auto* const model_path_string = model_path->second->get<std::string>();
+        const auto* const model_path_string = model_path->second.get<std::string>();
        if (!model_path_string || model_path_string->empty()) {
            std::ostringstream buffer;
            buffer << service_name
@@ -300,9 +300,9 @@ class MLModelServiceTFLite : public vsdk::MLModelService,
            throw std::invalid_argument(buffer.str());
        }
        std::string label_path_string = ""; // default value for label_path
-        auto label_path = attributes->find("label_path");
-        if (label_path != attributes->end()) {
-            const auto* const lp_string = label_path->second->get<std::string>();
+        auto label_path = attributes.find("label_path");
+        if (label_path != attributes.end()) {
+            const auto* const lp_string = label_path->second.get<std::string>();
            if (!lp_string) {
                std::ostringstream buffer;
                buffer << service_name
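The mechanical pattern in the two hunks above: the configuration's attributes are now a ProtoStruct accessed by value (so `attributes->find` becomes `attributes.find`), and each mapped ProtoValue's get<T>() still returns a pointer that is null when the value is missing or holds a different type. A small helper distilling that lookup, offered as an illustrative sketch rather than code from this PR:

#include <optional>
#include <string>

#include <viam/sdk/common/proto_value.hpp>

namespace vsdk = viam::sdk;

// Illustrative helper, not from this PR: look up a string-valued attribute,
// distinguishing "absent" from "present but not a string" the same way the
// model_path / label_path handling above does.
std::optional<std::string> find_string_attribute(const vsdk::ProtoStruct& attrs,
                                                 const std::string& key) {
    const auto it = attrs.find(key);
    if (it == attrs.end()) {
        return std::nullopt;  // attribute not provided
    }
    const auto* const value = it->second.get<std::string>();
    if (!value) {
        return std::nullopt;  // provided, but not a string
    }
    return *value;
}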
@@ -360,9 +360,9 @@ class MLModelServiceTFLite : public vsdk::MLModelService,
        // If present, extract and validate the number of threads to
        // use in the interpreter and create an interpreter options
        // object to carry that information.
-        auto num_threads = attributes->find("num_threads");
-        if (num_threads != attributes->end()) {
-            const auto* num_threads_double = num_threads->second->get<double>();
+        auto num_threads = attributes.find("num_threads");
+        if (num_threads != attributes.end()) {
+            const auto* num_threads_double = num_threads->second.get<double>();
            if (!num_threads_double || !std::isnormal(*num_threads_double) ||
                (*num_threads_double < 0) ||
                (*num_threads_double >= std::numeric_limits<std::int32_t>::max()) ||
@@ -459,11 +459,7 @@ class MLModelServiceTFLite : public vsdk::MLModelService,
                output_info.shape.push_back(TfLiteTensorDim(tensor, j));
            }
            if (state->label_path != "") {
-                if (!output_info.extra) {
-                    output_info.extra = std::make_shared<std::unordered_map<std::string, std::shared_ptr<vsdk::ProtoType>>>();
-                }
-                auto protoValue = std::make_shared<vsdk::ProtoType>(std::string(state->label_path));
-                output_info.extra->insert({"labels", protoValue});
+                output_info.extra.insert({"labels", state->label_path});
            }
            state->output_tensor_indices_by_name[output_info.name] = i;
            state->metadata.outputs.emplace_back(std::move(output_info));
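The final hunk applies the same migration to the per-output `extra` metadata: it no longer has to be allocated as a shared_ptr-to-map before use, and the label path no longer needs to be wrapped in a vsdk::ProtoType, since ProtoValue converts directly from std::string. A sketch of the resulting shape, assuming (from this diff rather than the SDK headers) that the output descriptor's `extra` member is now a value-typed ProtoStruct:

#include <string>

#include <viam/sdk/common/proto_value.hpp>

namespace vsdk = viam::sdk;

// Sketch only, not code from this PR: build the "labels" entry the way the
// new line above does. The ProtoValue is constructed from the std::string in
// place; no std::make_shared<vsdk::ProtoType> wrapper is involved.
vsdk::ProtoStruct make_labels_extra(const std::string& label_path) {
    vsdk::ProtoStruct extra;
    extra.insert({"labels", label_path});
    return extra;
}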