diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt
index a4725c7..ba282d7 100644
--- a/cpp/CMakeLists.txt
+++ b/cpp/CMakeLists.txt
@@ -33,8 +33,8 @@ add_executable(tflite_cpu
 
 target_link_libraries(tflite_cpu
     PRIVATE Threads::Threads
-    PRIVATE "$"
-    PRIVATE "$"
+    PRIVATE viam-cpp-sdk::viamsdk
+    PRIVATE tensorflow::tensorflowlite
 )
 
 install(
diff --git a/cpp/conanfile.py b/cpp/conanfile.py
index c82de3d..6773f60 100644
--- a/cpp/conanfile.py
+++ b/cpp/conanfile.py
@@ -41,7 +41,7 @@ def configure(self):
         self.options["*"].shared = False
 
     def requirements(self):
-        self.requires("viam-cpp-sdk/0.0.11")
+        self.requires("viam-cpp-sdk/0.0.12")
         self.requires("tensorflow-lite/2.15.0")
         self.requires("abseil/20240116.2", override=True)
 
@@ -57,11 +57,3 @@ def build(self):
 
     def layout(self):
         cmake_layout(self, src_folder=".")
-
-    def test(self):
-        if can_run(self):
-            cmd = os.path.join(self.cpp.build.bindir, "tflite_module")
-            stderr = StringIO()
-            self.run(cmd, env='conanrun', stderr=stderr, ignore_errors=True)
-            if "main failed with exception:" not in stderr.getvalue():
-                raise ConanException("Unexpected error output from test")
diff --git a/cpp/src/tflite_cpu.cpp b/cpp/src/tflite_cpu.cpp
index 24d0dd8..4113722 100644
--- a/cpp/src/tflite_cpu.cpp
+++ b/cpp/src/tflite_cpu.cpp
@@ -28,7 +28,7 @@
 #include 
 #include 
 
-#include <viam/sdk/common/proto_type.hpp>
+#include <viam/sdk/common/proto_value.hpp>
 #include 
 #include 
 #include 
@@ -70,7 +70,7 @@ class MLModelServiceTFLite : public vsdk::MLModelService,
         // drain.
     }
 
-    void stop(const vsdk::AttributeMap& extra) noexcept final {
+    void stop(const vsdk::ProtoStruct& extra) noexcept final {
         return stop();
     }
 
@@ -135,7 +135,7 @@ class MLModelServiceTFLite : public vsdk::MLModelService,
     }
 
     std::shared_ptr<named_tensor_views> infer(const named_tensor_views& inputs,
-                                              const vsdk::AttributeMap& extra) final {
+                                              const vsdk::ProtoStruct& extra) final {
         auto state = lease_state_();
 
         // We serialize access to the interpreter. We use a
@@ -244,7 +244,7 @@ class MLModelServiceTFLite : public vsdk::MLModelService,
         return {std::move(inference_result), views};
     }
 
-    struct metadata metadata(const vsdk::AttributeMap& extra) final {
+    struct metadata metadata(const vsdk::ProtoStruct& extra) final {
         // Just return a copy of our metadata from leased state.
         return lease_state_()->metadata;
     }
@@ -284,14 +284,14 @@ class MLModelServiceTFLite : public vsdk::MLModelService,
         // Now we can begin parsing and validating the provided `configuration`.
        // Pull the model path out of the configuration.
         const auto& attributes = state->configuration.attributes();
-        auto model_path = attributes->find("model_path");
-        if (model_path == attributes->end()) {
+        auto model_path = attributes.find("model_path");
+        if (model_path == attributes.end()) {
             std::ostringstream buffer;
             buffer << service_name
                    << ": Required parameter `model_path` not found in configuration";
             throw std::invalid_argument(buffer.str());
         }
-        const auto* const model_path_string = model_path->second->get<std::string>();
+        const auto* const model_path_string = model_path->second.get<std::string>();
         if (!model_path_string || model_path_string->empty()) {
             std::ostringstream buffer;
             buffer << service_name
@@ -300,9 +300,9 @@ class MLModelServiceTFLite : public vsdk::MLModelService,
             throw std::invalid_argument(buffer.str());
         }
         std::string label_path_string = "";  // default value for label_path
-        auto label_path = attributes->find("label_path");
-        if (label_path != attributes->end()) {
-            const auto* const lp_string = label_path->second->get<std::string>();
+        auto label_path = attributes.find("label_path");
+        if (label_path != attributes.end()) {
+            const auto* const lp_string = label_path->second.get<std::string>();
             if (!lp_string) {
                 std::ostringstream buffer;
                 buffer << service_name
@@ -360,9 +360,9 @@ class MLModelServiceTFLite : public vsdk::MLModelService,
         // If present, extract and validate the number of threads to
         // use in the interpreter and create an interpreter options
         // object to carry that information.
-        auto num_threads = attributes->find("num_threads");
-        if (num_threads != attributes->end()) {
-            const auto* num_threads_double = num_threads->second->get<double>();
+        auto num_threads = attributes.find("num_threads");
+        if (num_threads != attributes.end()) {
+            const auto* num_threads_double = num_threads->second.get<double>();
             if (!num_threads_double || !std::isnormal(*num_threads_double) ||
                 (*num_threads_double < 0) ||
                 (*num_threads_double >= std::numeric_limits<std::int32_t>::max()) ||
@@ -459,11 +459,7 @@ class MLModelServiceTFLite : public vsdk::MLModelService,
                 output_info.shape.push_back(TfLiteTensorDim(tensor, j));
             }
             if (state->label_path != "") {
-                if (!output_info.extra) {
-                    output_info.extra = std::make_shared<std::unordered_map<std::string, std::shared_ptr<vsdk::ProtoType>>>();
-                }
-                auto protoValue = std::make_shared<vsdk::ProtoType>(std::string(state->label_path));
-                output_info.extra->insert({"labels", protoValue});
+                output_info.extra.insert({"labels", state->label_path});
             }
             state->output_tensor_indices_by_name[output_info.name] = i;
             state->metadata.outputs.emplace_back(std::move(output_info));
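
Migration note: vsdk::ProtoStruct (viam-cpp-sdk 0.0.12, as pinned in cpp/conanfile.py above) is a plain map value rather than the shared_ptr-of-map AttributeMap, so lookups move from attributes->find(...) to attributes.find(...) and typed access moves from value->get<T>() to value.get<T>(). Below is a minimal sketch of the new read pattern, using only calls that appear in the hunks above; the helper name required_string_attribute and the error strings are illustrative, not part of this change:

#include <stdexcept>
#include <string>

#include <viam/sdk/common/proto_value.hpp>

namespace vsdk = viam::sdk;

// Reads a required, non-empty string attribute the way the new
// `model_path` handling above does: find the key, then ask the stored
// ProtoValue for a std::string, which yields nullptr on a type mismatch.
std::string required_string_attribute(const vsdk::ProtoStruct& attributes,
                                      const std::string& key) {
    const auto it = attributes.find(key);  // was: attributes->find(key)
    if (it == attributes.end()) {          // was: attributes->end()
        throw std::invalid_argument("Required parameter `" + key + "` not found");
    }
    // was: it->second->get<std::string>()
    const auto* const value = it->second.get<std::string>();
    if (!value || value->empty()) {
        throw std::invalid_argument("Parameter `" + key + "` must be a non-empty string");
    }
    return *value;
}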
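
The last hunk simplifies for the same reason: metadata `extra` fields are now ProtoStruct values instead of shared_ptr-wrapped maps of shared_ptr<ProtoType>, so no allocation step is needed and a std::string converts directly into the stored ProtoValue. A self-contained sketch under that assumption; the labels path is a placeholder:

#include <string>

#include <viam/sdk/common/proto_value.hpp>

namespace vsdk = viam::sdk;

int main() {
    // No make_shared setup as in the removed lines: default-construct the
    // map-like struct and insert, as output_info.extra.insert(...) does above.
    vsdk::ProtoStruct extra;
    extra.insert({"labels", std::string{"/path/to/labels.txt"}});
    return 0;
}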