diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 2e7fee2..b07fe31 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -175,7 +175,7 @@ jobs: -DTRITON_BACKEND_REPO_TAG=${{env.TRITON_REPO_TAG}} \ -DTRITON_CORE_REPO_TAG=${{env.TRITON_REPO_TAG}} \ -DTRITON_COMMON_REPO_TAG=${{env.TRITON_REPO_TAG}} \ - -PAPI_PROFILING_ENABLE=ON \ + -DPAPI_PROFILING_ENABLE=ON \ -DTRITON_ENABLE_MALI_GPU=${{env.TRITON_ENABLE_MALI_GPU}} \ -DTFLITE_ENABLE_RUY=${{env.TFLITE_ENABLE_RUY}} \ -DTFLITE_BAZEL_BUILD=${{env.TFLITE_BAZEL_BUILD}} \ diff --git a/.gitignore b/.gitignore index f772ec4..1346996 100644 --- a/.gitignore +++ b/.gitignore @@ -9,5 +9,5 @@ /.devcontainer /**/triton_qa_models /**/armnn_tflite_backend_triton_model_repo.tar.gz -**/papi_hl_output* +*.csv diff --git a/CMakeLists.txt b/CMakeLists.txt index 940188a..17dbb63 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -6,6 +6,10 @@ if(NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE Release) endif() +set(CMAKE_CXX_STANDARD 17) + +SET(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake" "${CMAKE_MODULE_PATH}") + set(TARGET_ARCH ${CMAKE_HOST_SYSTEM_PROCESSOR}) # Triton Options @@ -118,6 +122,12 @@ if(NOT (ACL_VERSION VERSION_GREATER "21.05")) list(APPEND ACL_BUILD_FLAGS "internal_only=0") endif() +# Enable REPROC++ +set(REPROC++ ON) + +# Numa +option(LIBNUMA_ENABLE "Enable libnuma usage" OFF) + # # Dependencies # @@ -141,8 +151,18 @@ FetchContent_Declare( GIT_REPOSITORY https://github.com/triton-inference-server/backend.git GIT_TAG ${TRITON_BACKEND_REPO_TAG} GIT_SHALLOW ON) +FetchContent_Declare( + tensorpipe + GIT_REPOSITORY https://github.com/pytorch/tensorpipe.git + GIT_TAG bb1473a4b38b18268e8693044afdb8635bc8351b + GIT_SHALLOW ON) +FetchContent_Declare( + reproc + GIT_REPOSITORY https://github.com/DaanDeMeyer/reproc + GIT_TAG v14.2.4 + GIT_SHALLOW ON) -set(MAKE_AVAILABLE_LIST repo-common repo-core repo-backend) +set(MAKE_AVAILABLE_LIST repo-common repo-core repo-backend tensorpipe reproc) if(NOT TFLITE_BAZEL_BUILD) FetchContent_Declare( @@ -169,6 +189,19 @@ configure_file(src/libtriton_armnn_tflite.ldscript include(ExternalProject) +# Handle hwloc +ExternalProject_Add( + hwloc + GIT_REPOSITORY https://github.com/open-mpi/hwloc + GIT_TAG hwloc-2.8.0 + GIT_SHALLOW ON + SOURCE_DIR ${CMAKE_BINARY_DIR}/hwloc + BINARY_DIR ${CMAKE_BINARY_DIR}/hwloc + CONFIGURE_COMMAND ./autogen.sh && ./configure --prefix= --enable-debug=$,"1","0" + BUILD_COMMAND make -j$(nproc) + UPDATE_COMMAND "" + INSTALL_COMMAND make install) + set(TFLITE_LOCATION ${CMAKE_CURRENT_BINARY_DIR}/external/tensorflow_lite) if(TFLITE_BAZEL_BUILD) @@ -335,58 +368,187 @@ if (PAPI_PROFILING_ENABLE) BUILD_COMMAND make -j$(nproc) UPDATE_COMMAND "" INSTALL_COMMAND make install - TEST_COMMAND make test - ) + TEST_COMMAND make test) endif() # -# Handle libs for TFLite Backend +# Handle libs for Model Instance standalone executable # -set(BACKEND_SRCS src/tflite.cc src/tflite_utils.cc src/tflite_utils.h) - +set(MODEL_INSTANCE_SRCS + src/model_instance/model_instance_main.cc + src/model_instance/model_instance.cc + src/model_instance/model_instance.h + src/model_instance/model_instance_utils.h) if(PAPI_PROFILING_ENABLE) - list(APPEND BACKEND_SRCS src/papi_profiler.cc) + list(APPEND MODEL_INSTANCE_SRCS src/model_instance/papi_profiler.cc) endif() -add_library(triton-armnn-tflite-backend SHARED ${BACKEND_SRCS}) - +add_executable(model_instance ${MODEL_INSTANCE_SRCS}) + +set(MODEL_INSTANCE_LINK_LIBS + tensorpipe + triton-core-serverstub + triton-backend-utils) + +# 
Handle discovery of libnuma +if(LIBNUMA_ENABLE) + find_package(Numa) + if(NUMA_FOUND) + # Here we just make numa available to all of our targets + link_directories(${NUMA_LIBRARY_DIR}) + list(APPEND CMAKE_REQUIRED_LIBRARIES numa) + list(APPEND CMAKE_REQUIRED_INCLUDES ${NUMA_INCLUDE_DIR}) + list(APPEND CMAKE_REQUIRED_LINK_OPTIONS "-L${NUMA_LIBRARY_DIR}") + check_symbol_exists(numa_node_of_cpu "numa.h" NUMA_V2) + if(NUMA_V2) + add_definitions(-DHAVE_LIBNUMA) + message(STATUS "libnuma found, building with support for NUMA nodes") + list(APPEND MODEL_INSTANCE_LINK_LIBS numa) + include_directories(SYSTEM ${NUMA_INCLUDE_DIR}) + else() + message(FATAL_ERROR "libnuma not found, but was requested via option LIBNUMA_ENABLE") + endif() + endif() + mark_as_advanced(NUMA_FOUND) +endif(LIBNUMA_ENABLE) + +set(MODEL_INSTANCE_INCLUDE_DIRS + ${CMAKE_CURRENT_SOURCE_DIR}/src + ${TENSORFLOW_ROOT} # for tflite headers +) if(ARMNN_DELEGATE_ENABLE) - add_dependencies(triton-armnn-tflite-backend armnn) + add_dependencies(model_instance armnn) + list(APPEND MODEL_INSTANCE_INCLUDE_DIRS + ${ARMNN_LOCATION}/include # for armnn headers + ${ARMNN_LOCATION}/src/armnn/delegate/include # for delegate headers + ) + # As per https://review.mlplatform.org/c/ml/armnn/+/7327 + if(ARMNN_VERSION VERSION_GREATER_EQUAL "22.05") + list(APPEND MODEL_INSTANCE_INCLUDE_DIRS ${ARMNN_LOCATION}/src/armnn/profiling) + endif() + target_compile_definitions(model_instance PRIVATE ARMNN_DELEGATE_ENABLE=1) + # Link the armnn lib + target_link_libraries( + model_instance PRIVATE "-L${ARMNN_LOCATION}/lib" -larmnn -larmnnDelegate) endif() if(PAPI_PROFILING_ENABLE) - add_dependencies(triton-armnn-tflite-backend papi) + add_dependencies(model_instance papi) target_compile_definitions( - triton-armnn-tflite-backend + model_instance PRIVATE PAPI_PROFILING_ENABLE=1 ) - target_include_directories(triton-armnn-tflite-backend PRIVATE ${CMAKE_BINARY_DIR}/papi-prefix/include) + list(APPEND MODEL_INSTANCE_INCLUDE_DIRS ${CMAKE_BINARY_DIR}/papi-prefix/include) # Note that linking the STATIC papi library results in a segfault on call to PAPI_library_init, use shared lib - target_link_libraries(triton-armnn-tflite-backend PRIVATE ${CMAKE_BINARY_DIR}/papi-prefix/lib/libpapi.so) + target_link_libraries(model_instance PRIVATE ${CMAKE_BINARY_DIR}/papi-prefix/lib/libpapi.so) endif() +if(LIBNUMA_ENABLE) + target_compile_definitions( + model_instance + PRIVATE LIBNUMA_ENABLE=1 + ) +endif() + +if(TFLITE_BAZEL_BUILD) + list(APPEND MODEL_INSTANCE_INCLUDE_DIRS + ${TENSORFLOW_ROOT}/bazel-tensorflow-lite/external/flatbuffers/include) + # Link the tensorflow lite library from bazel tfile build + target_link_libraries( + model_instance + PRIVATE "-L${TFLITE_LOCATION}/src/tensorflow-lite/bazel-bin/tensorflow/lite" + -ltensorflowlite) +else() + list(APPEND MODEL_INSTANCE_INCLUDE_DIRS + ${TFLITE_LIB_ROOT}/flatbuffers/include) + list(APPEND MODEL_INSTANCE_LINK_LIBS tensorflow-lite) +endif() + +target_include_directories(model_instance PRIVATE ${MODEL_INSTANCE_INCLUDE_DIRS}) +target_link_libraries(model_instance PRIVATE ${MODEL_INSTANCE_LINK_LIBS}) + +target_compile_features(model_instance PRIVATE cxx_std_11) +target_compile_options( + model_instance + PRIVATE + $<$,$,$>: + -Wall + -Wextra + -Wno-unused-parameter + -Wno-type-limits + -Wno-comment + -Werror>) + +set_target_properties( + model_instance + PROPERTIES + POSITION_INDEPENDENT_CODE ON + OUTPUT_NAME model_instance + SKIP_BUILD_RPATH TRUE + BUILD_WITH_INSTALL_RPATH TRUE + INSTALL_RPATH_USE_LINK_PATH FALSE + INSTALL_RPATH 
"$\{ORIGIN\}" + LINK_FLAGS + "-Wl,--no-as-needed") + +# +# Handle libs for TFLite Backend +# + +set(BACKEND_SRCS + src/tflite.cc + src/tflite_utils.cc + src/tflite_utils.h) + +add_library(triton-armnn-tflite-backend SHARED ${BACKEND_SRCS}) + add_library(TritonArmNNTFLiteBackend::triton-armnn-tflite-backend ALIAS triton-armnn-tflite-backend) set(BACKEND_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/src ${TENSORFLOW_ROOT} # for tflite headers - ${ARMNN_LOCATION}/include # for armnn headers - ${ARMNN_LOCATION}/src/armnn/delegate/include # for delegate headers ) -# As per https://review.mlplatform.org/c/ml/armnn/+/7327 -if(ARMNN_VERSION VERSION_GREATER_EQUAL "22.05") - list(APPEND BACKEND_INCLUDE_DIRS ${ARMNN_LOCATION}/src/armnn/profiling) -endif() +set(BACKEND_LINK_LIBS + triton-core-serverapi triton-core-backendapi triton-core-serverstub + triton-backend-utils tensorpipe reproc++ ${CMAKE_DL_LIBS}) if(TFLITE_BAZEL_BUILD) list(APPEND BACKEND_INCLUDE_DIRS ${TENSORFLOW_ROOT}/bazel-tensorflow-lite/external/flatbuffers/include) + # Link the tensorflow lite library from bazel tfile build + target_link_libraries( + triton-armnn-tflite-backend + PRIVATE "-L${TFLITE_LOCATION}/src/tensorflow-lite/bazel-bin/tensorflow/lite" + -ltensorflowlite) else() list(APPEND BACKEND_INCLUDE_DIRS ${TFLITE_LIB_ROOT}/flatbuffers/include) + list(APPEND BACKEND_LINK_LIBS tensorflow-lite) +endif() + +add_dependencies(triton-armnn-tflite-backend hwloc) +list(APPEND BACKEND_INCLUDE_DIRS ${CMAKE_BINARY_DIR}/hwloc-prefix/include) +target_link_libraries(triton-armnn-tflite-backend PRIVATE ${CMAKE_BINARY_DIR}/hwloc-prefix/lib/libhwloc.so) + +if(ARMNN_DELEGATE_ENABLE) + target_compile_definitions(triton-armnn-tflite-backend PRIVATE ARMNN_DELEGATE_ENABLE=1) +endif() + +if(PAPI_PROFILING_ENABLE) + target_compile_definitions( + triton-armnn-tflite-backend + PRIVATE PAPI_PROFILING_ENABLE=1 + ) +endif() + +if(LIBNUMA_ENABLE) + target_compile_definitions( + triton-armnn-tflite-backend + PRIVATE LIBNUMA_ENABLE=1 + ) endif() target_include_directories(triton-armnn-tflite-backend @@ -404,12 +566,6 @@ target_compile_options( -Wno-comment -Werror>) -# ARMNN_DELEGATE_ENABLE exposed in header so set PUBLIC -if(${ARMNN_DELEGATE_ENABLE}) - target_compile_definitions(triton-armnn-tflite-backend - PUBLIC ARMNN_DELEGATE_ENABLE=1) -endif() # ARMNN_DELEGATE_ENABLE - set_target_properties( triton-armnn-tflite-backend PROPERTIES @@ -423,41 +579,31 @@ set_target_properties( LINK_FLAGS "-Wl,--no-as-needed,--version-script libtriton_armnn_tflite.ldscript") -set(BACKEND_LINK_LIBS - triton-core-serverapi triton-core-backendapi triton-core-serverstub - triton-backend-utils ${CMAKE_DL_LIBS}) - -if(TFLITE_BAZEL_BUILD) - # Link the tensorflow lite library from bazel tfile build - target_link_libraries( - triton-armnn-tflite-backend - PRIVATE "-L${TFLITE_LOCATION}/src/tensorflow-lite/bazel-bin/tensorflow/lite" - -ltensorflowlite) -else() - list(APPEND BACKEND_LINK_LIBS tensorflow-lite) -endif() - target_link_libraries(triton-armnn-tflite-backend PRIVATE ${BACKEND_LINK_LIBS}) -if(ARMNN_DELEGATE_ENABLE) - # Link the armnn lib - target_link_libraries( - triton-armnn-tflite-backend PRIVATE "-L${ARMNN_LOCATION}/lib" -larmnn - -larmnnDelegate) -endif() - # # Install # include(GNUInstallDirs) set(INSTALL_CONFIGDIR ${CMAKE_INSTALL_LIBDIR}/cmake/TritonArmNNTFLiteBackend) +install( + TARGETS model_instance + DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/armnn_tflite) + install( TARGETS triton-armnn-tflite-backend EXPORT triton-armnn-tflite-backend-targets LIBRARY DESTINATION 
${CMAKE_INSTALL_PREFIX}/backends/armnn_tflite
    ARCHIVE DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/armnn_tflite)

+# Install hwloc libraries
+install(
+ DIRECTORY ${CMAKE_BINARY_DIR}/hwloc-prefix/lib/
+ DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/armnn_tflite
+ FILES_MATCHING
+ PATTERN "*.so*")
+
 if(ARMNN_DELEGATE_ENABLE)
   # Install ArmNN libraries and license
   install(
diff --git a/README.md b/README.md
index 81f54d4..1ca2d1c 100644
--- a/README.md
+++ b/README.md
@@ -194,5 +194,21 @@ instance_group [
 ```
 
 ## Enabling PAPI events
-This backend supports PAPI performance counter sampling. This is exposed through the PAPI High Level API. We support performance counter tracing at the tflite operator level using tflite tracing instrumentation. To enable this, when launching triton pass the flag `--backend-config=armnn_tflite,papi-events=PAPI_TOT_CYC,PAPI_LD_INS`. Internally, the events listed get set to the environment variable `PAPI_EVENTS` as per the PAPI High Level API documentation. Results of this will be written to a newly created `papi_hl_output` folder in the directory you launched the server from.
-Internally, the events listed get set to the environment variable `PAPI_EVENTS` as per the PAPI High Level API documentation. Results of this will be written to a newly created `papi_hl_output` folder in the directory you launched the server from.
+This backend supports PAPI performance counter sampling. We support performance counter tracing at the tflite operator level using tflite tracing instrumentation. To enable this, you can use the following in your model config:
+```
+parameters {
+ key: "papi_events"
+ value: {
+ string_value:"PAPI_TOT_CYC,PAPI_LD_INS"
+ }
+}
+parameters {
+ key: "papi_uncore_events"
+ value: {
+ string_value:"tx2_dmc0::UNC_DMC_READS:u:cpu=0"
+ }
+}
+```
+`papi_events` is used for per-core events, such as total load instructions, which can be tracked at the thread level. `papi_uncore_events` is used for uncore events, which are tracked at the socket level, such as the userspace DRAM reads for socket 0 in the example above.
+
+Internally, the events listed are set in the environment variables `PAPI_EVENTS` and `PAPI_UNCORE_EVENTS`. The results are written to a newly created `counters_*.csv` file for you to use as you wish.
diff --git a/cmake/FindNuma.cmake b/cmake/FindNuma.cmake
new file mode 100644
index 0000000..94b23c8
--- /dev/null
+++ b/cmake/FindNuma.cmake
@@ -0,0 +1,43 @@
+# Module for locating libnuma
+#
+# Read-only variables:
+# NUMA_FOUND
+# Indicates that the library has been found.
+#
+# NUMA_INCLUDE_DIR
+# Points to the libnuma include directory.
+#
+# NUMA_LIBRARY_DIR
+# Points to the directory that contains the libraries.
+# The content of this variable can be passed to link_directories.
+#
+# NUMA_LIBRARY
+# Points to the libnuma that can be passed to target_link_libraries.
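+#
+# Example usage (illustrative sketch only; "my_target" is a placeholder):
+#
+#   find_package(Numa)
+#   if(NUMA_FOUND)
+#     include_directories(SYSTEM ${NUMA_INCLUDE_DIR})
+#     target_link_libraries(my_target PRIVATE ${NUMA_LIBRARY})
+#   endif()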
+# +# Copyright (c) 2013-2020 MulticoreWare, Inc + +include(FindPackageHandleStandardArgs) + +find_path(NUMA_ROOT_DIR + NAMES include/numa.h + PATHS ENV NUMA_ROOT + DOC "NUMA root directory") + +find_path(NUMA_INCLUDE_DIR + NAMES numa.h + HINTS ${NUMA_ROOT_DIR} + PATH_SUFFIXES include + DOC "NUMA include directory") + +find_library(NUMA_LIBRARY + NAMES numa + HINTS ${NUMA_ROOT_DIR} + DOC "NUMA library") + +if (NUMA_LIBRARY) + get_filename_component(NUMA_LIBRARY_DIR ${NUMA_LIBRARY} PATH) +endif() + +mark_as_advanced(NUMA_INCLUDE_DIR NUMA_LIBRARY_DIR NUMA_LIBRARY) + +find_package_handle_standard_args(NUMA REQUIRED_VARS NUMA_ROOT_DIR NUMA_INCLUDE_DIR NUMA_LIBRARY) \ No newline at end of file diff --git a/qa/config-template.pbtxt b/qa/config-template.pbtxt index a114043..561c36a 100644 --- a/qa/config-template.pbtxt +++ b/qa/config-template.pbtxt @@ -43,6 +43,15 @@ string_value:"{{ model.papi_events }}" } {% endif %} +{% if model.papi_uncore_events %} +parameters { +key: "papi_uncore_events" +value: { +string_value:"{{ model.papi_uncore_events }}" +} +} +{% endif %} + instance_group [ {% if model.gpu > 0 %} { diff --git a/qa/helpers/triton_model_config.py b/qa/helpers/triton_model_config.py index d5d3cfe..99a8176 100644 --- a/qa/helpers/triton_model_config.py +++ b/qa/helpers/triton_model_config.py @@ -42,6 +42,7 @@ def __init__( outputs: List[Model.TensorIO], tflite_num_threads: int = None, papi_events: str = None, + papi_uncore_events: str = None, gpu: int = 0, cpu: int = 1, max_batch_size: int = 0, @@ -64,6 +65,7 @@ def __init__( ) self.tflite_num_threads = tflite_num_threads self.papi_events = papi_events + self.papi_uncore_events = papi_uncore_events self.armnn_cpu = armnn_cpu self.armnn_gpu = armnn_gpu self.armnn_cpu_parameters = armnn_cpu_parameters diff --git a/src/config.h b/src/config.h new file mode 100644 index 0000000..3bf4f24 --- /dev/null +++ b/src/config.h @@ -0,0 +1,87 @@ +// +// Copyright © 2023 Arm Ltd. All rights reserved. 
+// SPDX-License-Identifier: MIT +// + +#pragma once + +#include +#include +#include + +// This class is used to map an optimizer option to an index in an array so +// options can be sent across a tensorpipe payload +enum OptimizerOption { + TFLITE_NUM_THREADS, + XNNPACK_ENABLE, + XNNPACK_CPU_NUM_THREADS, + NUMA_ALLOC_POLICY, + NUMA_LOCAL_NODE_ID, + NUMA_REMOTE_NODE_ID, + +#ifdef ARMNN_DELEGATE_ENABLE + ARMNN_CPU_ENABLE, + ARMNN_GPU_ENABLE, + ARMNN_CPU_NUM_THREADS, + ARMNN_CPU_REDUCE_FP32_TO_FP16, + ARMNN_CPU_REDUCE_FP32_TO_BF16, + ARMNN_CPU_FAST_MATH_ENABLED, + ARMNN_GPU_FAST_MATH_ENABLED, + ARMNN_GPU_REDUCE_FP32_TO_FP16, + ARMNN_GPU_REDUCE_FP32_TO_BF16, +#endif // ARMNN_DELEGATE_ENABLE + + COUNT // Just used to track the number of options +}; + +enum class AllocationPolicy { + LOCAL, + WEIGHT_REMOTE_RESULT_LOCAL, + WEIGHT_LOCAL_RESULT_REMOTE, + REMOTE, + NONE +}; + +inline AllocationPolicy +AllocationPolicyFromString(std::string str) +{ + // Convert copy of string to uppercase + std::transform(str.begin(), str.end(), str.begin(), ::toupper); + + if (str == "LOCAL") { + return AllocationPolicy::LOCAL; + } else if (str == "WEIGHT_REMOTE_RESULT_LOCAL") { + return AllocationPolicy::WEIGHT_REMOTE_RESULT_LOCAL; + } else if (str == "WEIGHT_LOCAL_RESULT_REMOTE") { + return AllocationPolicy::WEIGHT_LOCAL_RESULT_REMOTE; + } else if (str == "REMOTE") { + return AllocationPolicy::REMOTE; + } else if (str == "NONE") { + return AllocationPolicy::NONE; + } else { + return AllocationPolicy::NONE; + } +} + +inline std::string +AllocationPolicyToString(const AllocationPolicy& alloc_policy) +{ + switch (alloc_policy) { + case AllocationPolicy::LOCAL: { + return "LOCAL"; + } + case AllocationPolicy::WEIGHT_REMOTE_RESULT_LOCAL: { + return "WEIGHT_REMOTE_RESULT_LOCAL"; + } + case AllocationPolicy::WEIGHT_LOCAL_RESULT_REMOTE: { + return "WEIGHT_LOCAL_RESULT_REMOTE"; + } + case AllocationPolicy::REMOTE: { + return "REMOTE"; + } + case AllocationPolicy::NONE: { + return "NONE"; + } + } + return "NONE"; +} \ No newline at end of file diff --git a/src/model_instance/model_instance.cc b/src/model_instance/model_instance.cc new file mode 100644 index 0000000..6ec9d9e --- /dev/null +++ b/src/model_instance/model_instance.cc @@ -0,0 +1,532 @@ +// +// Copyright © 2023 Arm Ltd and Contributors. All rights reserved. 
+// SPDX-License-Identifier: MIT +// + +#include "model_instance.h" + +#include + +#include +#include + +#include "model_instance_utils.h" + +// Triton backend headers +#include "triton/backend/backend_common.h" + +// TFLite headers +#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/type_to_tflitetype.h" + +#ifdef ARMNN_DELEGATE_ENABLE +// ArmNN headers +#include "armnn/ArmNN.hpp" +#include "armnn_delegate.hpp" +#endif // ARMNN_DELEGATE_ENABLE + +void +ModelInstance::Finalize() +{ + pipe_->close(); +} + +void +ModelInstance::Start(const std::string& addr) +{ + pipe_ = context_->connect(addr); + ReceiveFromPipe(); +} + +TfLiteStatus +ModelInstance::BuildInterpreter(tensorpipe::Descriptor descriptor) +{ + // Build the tflite interpreter + tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates resolver; + tflite::InterpreterBuilder builder(*model_, resolver); + builder(&interpreter_); + if (!interpreter_) { + return kTfLiteError; + } + + // Set interpreter threads + if (interpreter_->SetNumThreads(std::stoi( + descriptor.payloads[OptimizerOption::TFLITE_NUM_THREADS].metadata)) != + kTfLiteOk) { + return kTfLiteError; + } + + // Get range of cpus to pin to if specified + for (std::vector::iterator it = + descriptor.payloads.begin() + OptimizerOption::COUNT + 1; + it < descriptor.payloads.end(); it++) { + cpus_.push_back(std::stoi(it->metadata)); + } + + // Set numa parameters + numa_alloc_policy_ = AllocationPolicyFromString( + descriptor.payloads[OptimizerOption::NUMA_ALLOC_POLICY].metadata); + + local_numa_node_id_ = std::stoi( + descriptor.payloads[OptimizerOption::NUMA_LOCAL_NODE_ID].metadata); + + remote_numa_node_id_ = std::stoi( + descriptor.payloads[OptimizerOption::NUMA_REMOTE_NODE_ID].metadata); + +#ifdef ARMNN_DELEGATE_ENABLE + armnn::OptimizerOptions armnn_optimizer_options_cpu; + armnn::OptimizerOptions armnn_optimizer_options_gpu; + bool armnn_cpu_delegate_enabled = + descriptor.payloads[OptimizerOption::ARMNN_CPU_ENABLE].metadata == + std::string("y"); + + bool armnn_gpu_delegate_enabled = + descriptor.payloads[OptimizerOption::ARMNN_GPU_ENABLE].metadata == + std::string("y"); + + if (armnn_cpu_delegate_enabled || armnn_gpu_delegate_enabled) { + armnnDelegate::DelegateOptions armnn_delegate_options = + armnnDelegate::TfLiteArmnnDelegateOptionsDefault(); + + // Set backend prefs based on gpu or cpu selection + if (armnn_gpu_delegate_enabled) { + armnn_delegate_options.SetBackends( + {armnn::Compute::GpuAcc, armnn::Compute::CpuAcc}); + armnn_optimizer_options_gpu.m_ReduceFp32ToFp16 = + descriptor.payloads[OptimizerOption::ARMNN_GPU_REDUCE_FP32_TO_FP16] + .metadata == std::string("on"); + armnn_optimizer_options_gpu.m_ReduceFp32ToBf16 = + descriptor.payloads[OptimizerOption::ARMNN_GPU_REDUCE_FP32_TO_BF16] + .metadata == std::string("on"); + armnn::BackendOptions gpu_fast_math_option( + "GpuAcc", + {{"FastMathEnabled", + descriptor.payloads[OptimizerOption::ARMNN_GPU_FAST_MATH_ENABLED] + .metadata == std::string("on")}}); + armnn_optimizer_options_gpu.m_ModelOptions.push_back( + gpu_fast_math_option); + armnn_delegate_options.SetOptimizerOptions(armnn_optimizer_options_gpu); + } else { + // Set backend pref to Neon ACL backend + armnn_delegate_options.SetBackends({armnn::Compute::CpuAcc}); + armnn_optimizer_options_cpu.m_ReduceFp32ToFp16 = + descriptor.payloads[OptimizerOption::ARMNN_CPU_REDUCE_FP32_TO_FP16] + 
.metadata == std::string("on"); + armnn_optimizer_options_cpu.m_ReduceFp32ToBf16 = + descriptor.payloads[OptimizerOption::ARMNN_CPU_REDUCE_FP32_TO_BF16] + .metadata == std::string("on"); + armnn::BackendOptions cpu_fast_math_option( + "CpuAcc", + {{"FastMathEnabled", + descriptor.payloads[OptimizerOption::ARMNN_CPU_FAST_MATH_ENABLED] + .metadata == std::string("on")}}); + armnn_optimizer_options_cpu.m_ModelOptions.push_back( + cpu_fast_math_option); + armnn::BackendOptions num_threads_option( + "CpuAcc", + {{"NumberOfThreads", + static_cast(std::stoi( + descriptor.payloads[OptimizerOption::ARMNN_CPU_NUM_THREADS] + .metadata))}}); + armnn_optimizer_options_cpu.m_ModelOptions.push_back(num_threads_option); + armnn_delegate_options.SetOptimizerOptions(armnn_optimizer_options_cpu); + } + + // Create ArmNN Delegate with options registered in model state + std::unique_ptr< + TfLiteDelegate, decltype(&armnnDelegate::TfLiteArmnnDelegateDelete)> + armnn_delegate( + armnnDelegate::TfLiteArmnnDelegateCreate(armnn_delegate_options), + armnnDelegate::TfLiteArmnnDelegateDelete); + + // Instruct the Interpreter to use the armnnDelegate + if (interpreter_->ModifyGraphWithDelegate(std::move(armnn_delegate)) != + kTfLiteOk) { + return kTfLiteError; + } + LogDelegation("armnn"); + } else if ( + descriptor.payloads[OptimizerOption::XNNPACK_ENABLE].metadata == + std::string("y")) { +#else + if (descriptor.payloads[OptimizerOption::XNNPACK_ENABLE].metadata == + std::string("y")) { +#endif // ARMNN_DELEGATE_ENABLE + // Create the XNNPack Delegate + TfLiteXNNPackDelegateOptions options = + TfLiteXNNPackDelegateOptionsDefault(); + + options.num_threads = std::stoi( + descriptor.payloads[OptimizerOption::XNNPACK_CPU_NUM_THREADS].metadata); + + tflite::Interpreter::TfLiteDelegatePtr xnnpack_delegate( + TfLiteXNNPackDelegateCreate(&options), + [](TfLiteDelegate* xnnpack_delegate) { + TfLiteXNNPackDelegateDelete(xnnpack_delegate); + }); + + // Instruct the Interpreter to use the xnnpack + if (interpreter_->ModifyGraphWithDelegate(std::move(xnnpack_delegate)) != + kTfLiteOk) { + return kTfLiteError; + } + LogDelegation("xnnpack"); + } else { + LOG_MESSAGE(TRITONSERVER_LOG_INFO, "No delegates used for model execution"); + } + + return kTfLiteOk; +} + +void +ModelInstance::LogDelegation(const std::string& delegate_name) +{ + std::unordered_set checked_node_ids; + unsigned int num_delegated_kernels = 0; + for (uint64_t i = 0; i < interpreter_->execution_plan().size(); i++) { + int node_id = interpreter_->execution_plan()[i]; + if (checked_node_ids.find(node_id) != checked_node_ids.end()) { + continue; + } + const TfLiteNode& node = + interpreter_->node_and_registration(node_id)->first; + + if (node.delegate != nullptr) { + num_delegated_kernels++; + checked_node_ids.insert(node_id); + } + } + bool fully_delegated = + (num_delegated_kernels == 1 && + interpreter_->execution_plan().size() == 1); + + if (fully_delegated) { + LOG_MESSAGE( + TRITONSERVER_LOG_INFO, ("Applied " + delegate_name + + " delegate, and the model graph will be " + "completely executed by the delegate.") + .c_str()); + } else if (num_delegated_kernels > 0) { + LOG_MESSAGE( + TRITONSERVER_LOG_INFO, + ("Applied " + delegate_name + + " delegate, and the model graph will be paritally executed by the " + "delegate w/ " + + std::to_string(num_delegated_kernels) + " delegate kernels.") + .c_str()); + } else { + LOG_MESSAGE( + TRITONSERVER_LOG_INFO, ("Though " + delegate_name + + " delegate is applied, the model graph will " + "not be executed by the 
delegate.") + .c_str()); + } +} + +void +ModelInstance::ReceiveFromPipe() +{ + pipe_->readDescriptor([this]( + const tensorpipe::Error& error, + tensorpipe::Descriptor descriptor) { + if (error) { + if (error.isOfType()) { + // Expected. + LOG_MESSAGE( + TRITONSERVER_LOG_INFO, + (std::string("Remote side hungup: ") + error.what()).c_str()); + } else { + LOG_MESSAGE( + TRITONSERVER_LOG_ERROR, + (std::string("Unexpected error when reading from accepted pipe: ") + + error.what()) + .c_str()); + } + return; + } + if (descriptor.metadata == "model_load") { + LoadModelFromPipe(descriptor); + } else if (descriptor.metadata == "model_input") { + Infer(descriptor); + } + }); +} + +void +ModelInstance::LoadModelFromPipe(tensorpipe::Descriptor descriptor) +{ + // TODO: Make sure this can only be called once as it loads the model and + // builds the interpreter + tensorpipe::Allocation allocation; + allocation.payloads.resize(descriptor.payloads.size()); + allocation.payloads[OptimizerOption::COUNT].data = + new char[descriptor.payloads[OptimizerOption::COUNT].length]; + pipe_->read( + allocation, + [this, descriptor, allocation](const tensorpipe::Error& error) { + if (error) { + LOG_MESSAGE( + TRITONSERVER_LOG_ERROR, + ("Failed to read model from pipe with err:" + error.what()) + .c_str()); + return; + } + // Load the tflite model from the buffer + tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates + builtin_op_resolver; + model_ = tflite::FlatBufferModel::BuildFromBuffer( + reinterpret_cast( + allocation.payloads[OptimizerOption::COUNT].data), + descriptor.payloads[OptimizerOption::COUNT].length); + + // Initalize the interpreter after loading the flatbuffers model + tensorpipe::Message tp_msg; + bool success = BuildInterpreter(descriptor) == kTfLiteOk; + tp_msg.metadata = success ? 
"success" : "fail"; + pipe_->write(tp_msg, [this, success](const tensorpipe::Error& error) { + if (error) { + LOG_MESSAGE( + TRITONSERVER_LOG_ERROR, + ("Failed send model load ack:" + error.what()).c_str()); + return; + } +#ifdef LIBNUMA_ENABLE + if (success) { + // Model is loaded, apply the numa policy + InitNuma(local_numa_node_id_, remote_numa_node_id_); + } +#endif // LIBNUMA_ENABLE + }); + + // Arm for getting more data + ReceiveFromPipe(); + }); +} + +void +ModelInstance::Infer(tensorpipe::Descriptor& descriptor) +{ + bool allocate_tensors = first_inference_; + bool success = true; + + if (first_inference_) { + allocation_.tensors.resize(descriptor.tensors.size()); + tp_response_msg_.tensors.resize(interpreter_->outputs().size()); + } + + // Get model inputs from request and ready the buffers (Allocation obj) to + // write tensor data + for (uint64_t i = 0; i < descriptor.tensors.size(); ++i) { + // If the size of the incoming tensor + // is different from the last call, tell the interpreter to resize the + // input tensor and note that we are going to have to make another call to + // AllocateTensors below + + // First element of tensor_info is input tensor index, remaining is the dims + // of the input tensor + int input_tensor_index = std::stoi(descriptor.tensors[i].metadata); + + // incoming_length holds the num bytes of the incoming vector + int incoming_length = descriptor.tensors[i].length; + + int tflite_input_tensor_len = + interpreter_->tensor(input_tensor_index)->bytes; + + if (incoming_length != tflite_input_tensor_len) { + // Resize input tensors based on current total batch size + TfLiteIntArray* tflite_input_tensor_dims = + interpreter_->tensor(input_tensor_index)->dims; + std::vector tflite_input_shape( + tflite_input_tensor_dims->data, + (tflite_input_tensor_dims->data + tflite_input_tensor_dims->size)); + + allocate_tensors = true; + + // Set the new batch size + tflite_input_shape[0] = incoming_length > tflite_input_tensor_len + ? 
incoming_length / tflite_input_tensor_len + : tflite_input_tensor_len / incoming_length; + + interpreter_->ResizeInputTensor(input_tensor_index, tflite_input_shape); + } + } + + // Once we have resized all input tensors in the loop above, + // now we can allocate the memory plan within the tflite runtime if + // necessary + if (allocate_tensors || first_inference_) { + if (interpreter_->AllocateTensors() != kTfLiteOk) { + success = false; + } + + // Assign Cpu buffers to read incoming tensor bytes into after allocate + // tensors is called + for (uint64_t i = 0; i < descriptor.tensors.size(); ++i) { + allocation_.tensors[i].buffer = tensorpipe::CpuBuffer{ + .ptr = interpreter_->tensor(std::stoi(descriptor.tensors[i].metadata)) + ->data.raw}; + } + } + + pipe_->read( + allocation_, + [this, &success, allocate_tensors](const tensorpipe::Error& error) { + success = !error; + + // At this point our input tensors should be written to by the read + // function, now we invoke the interpreter and read the output + if (interpreter_->Invoke() != kTfLiteOk) { + success = false; + } else { + // After the first inference, all threads should be alive + if (first_inference_) { +#ifdef PAPI_PROFILING_ENABLE + papi_profiler_ = MaybeCreatePapiProfiler(); + interpreter_->AddProfiler(papi_profiler_.get()); +#endif // PAPI_PROFILING_ENABLE + + // If cpus are specified pin the inference threads + if (cpus_.size() > 0) { + int i = 0; + for (pid_t& tid : InferenceThreadIds()) { + cpu_set_t cpuset; + CPU_ZERO(&cpuset); + // Selected cpu loops around if more threads than cpus + CPU_SET(cpus_[i++ % cpus_.size()], &cpuset); + int rc = sched_setaffinity(tid, sizeof(cpu_set_t), &cpuset); + if (rc != 0) { + std::cout << "Error calling sched_setaffinity: " << rc + << "\n"; + } + } + } + } + } + + first_inference_ = false; + + // Write output back to client + if (!success) { + tp_response_msg_.metadata = "f"; + } else if (allocate_tensors) { + // If we (re)allocated tensors then we need to update response message + for (uint64_t i = 0; i < interpreter_->outputs().size(); ++i) { + int output_index = interpreter_->outputs()[i]; + TfLiteTensor* output_tensor = interpreter_->tensor(output_index); + tensorpipe::Message::Tensor tensor; + // We use the output tensor name as the metadata in the request + tensor.metadata = std::string(output_tensor->name); + tensor.length = output_tensor->bytes; + tensor.buffer = + tensorpipe::CpuBuffer{.ptr = output_tensor->data.raw}; + tp_response_msg_.tensors[i] = tensor; + } + } + pipe_->write(tp_response_msg_, [](const tensorpipe::Error& error) { + if (error) { + LOG_MESSAGE( + TRITONSERVER_LOG_ERROR, + ("Failed to send inference response to client. Details:" + + error.what()) + .c_str()); + } + }); + // Arm for getting more data + ReceiveFromPipe(); + }); +} + +#ifdef LIBNUMA_ENABLE +void +ModelInstance::InitNuma(int local_node_id, int remote_node_id) +{ + if (numa_alloc_policy_ == AllocationPolicy::NONE) { + LOG_MESSAGE( + TRITONSERVER_LOG_WARN, "Allocation policy ignored, policy is NONE"); + return; + } + + if (numa_available() < 0) { + LOG_MESSAGE( + TRITONSERVER_LOG_WARN, + "System does not support NUMA API, Allocation policy ignored"); + return; + } else if (num_numa_nodes_ < 2) { + LOG_MESSAGE( + TRITONSERVER_LOG_WARN, + "Only one numa node available to system. 
Allocation policy " + "ignored\n"); + return; + } + + // Set numa mem pollicies + // In the case of the split policies, we need to explictly move the pages of + // the weights to the target numa node, as it goes against the set memory + // policy this process was launched with + switch (numa_alloc_policy_) { + case AllocationPolicy::WEIGHT_LOCAL_RESULT_REMOTE: { + MoveModelWeights(local_node_id); + break; + } + case AllocationPolicy::WEIGHT_REMOTE_RESULT_LOCAL: { + MoveModelWeights(remote_node_id); + break; + } + default: { + break; + } + } +} + +void +ModelInstance::MoveModelWeights(int numa_node_id) +{ + // Get pointer to base of mmapped model file + const void* model_file_base = model_->allocation()->base(); + int page_size = getpagesize(); + + std::vector pages(model_->allocation()->bytes() / page_size + 1); + + char* begin = (char*)model_file_base; + char* end = begin + model_->allocation()->bytes(); + + int i = 0; + for (char* piter = (char*)AlignPage(model_file_base); piter < end; + piter += page_size) { + pages[i++] = (void*)piter; + } + + std::vector dst(pages.size(), numa_node_id); + std::vector status(pages.size(), 0); + + // Touch all pages of the file to force mapping to phys mem + volatile char c; + for (char* piter = (char*)AlignPage(model_file_base); piter < end; + piter += page_size) { + c = *piter; + } + // This is just to avoid the unused var compiler warning + (void)c; + + // With all pages mapped, now move them to target numa node + int ret = numa_move_pages( + 0, pages.size(), pages.data(), dst.data(), status.data(), + MPOL_MF_MOVE_ALL); + + if (ret < 0) { + LOG_MESSAGE( + TRITONSERVER_LOG_ERROR, + (std::string("Numa move page error: ") + strerror(errno)).c_str()); + for (auto& i : status) { + if (i < 0) { + LOG_MESSAGE( + TRITONSERVER_LOG_ERROR, + (std::string("Page error status: ") + strerror(i)).c_str()); + } + } + } +} +#endif // LIBNUMA_ENABLE diff --git a/src/model_instance/model_instance.h b/src/model_instance/model_instance.h new file mode 100644 index 0000000..52dd9ee --- /dev/null +++ b/src/model_instance/model_instance.h @@ -0,0 +1,119 @@ +// +// Copyright © 2023 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "config.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/optional_debug_tools.h" +#include "tensorpipe/tensorpipe.h" + + +#ifdef PAPI_PROFILING_ENABLE +#include "papi.h" +#include "papi_profiler.h" +#endif // PAPI_PROFILING_ENABLE + +#ifdef LIBNUMA_ENABLE +// Lib Numa headers +#include +#include +#endif // LIBNUMA_ENABLE + +// ModelInstance for backend end execution of model +class ModelInstance { + public: + ModelInstance() + { + context_ = std::make_shared(); + auto transportContext = tensorpipe::transport::shm::create(); + context_->registerTransport(0 /* priority */, "shm", transportContext); + // Register cma shm channel + auto cmaChannel = tensorpipe::channel::cma::create(); + context_->registerChannel(0 /* low priority */, "cma", cmaChannel); + } + + ~ModelInstance() { Finalize(); } + + // Start model instance and attempt to connect to passed address + void Start(const std::string& addr); + + // Cleanup + void Finalize(); + + // Issue a receive request pipe + void ReceiveFromPipe(); + + private: + // Callback for new connection is accepted. + void OnAccepted(const tensorpipe::Error&, std::shared_ptr); + + // Callback for loading a tflite model. 
+ void LoadModelFromPipe(tensorpipe::Descriptor descriptor); + + // Builds the tflite interpreter based on passed descriptor + TfLiteStatus BuildInterpreter(tensorpipe::Descriptor descriptor); + + void LogDelegation(const std::string& delegate_name); + + // Callback for inferencing on a loaded tflite model. + void Infer(tensorpipe::Descriptor& descriptor); + + // Numa policy for instance + AllocationPolicy numa_alloc_policy_; + + // Local numa node id + int local_numa_node_id_ = 0; + + // remote numa node id + int remote_numa_node_id_ = 1; + + // thread ids + std::vector inference_thread_ids_; + +#ifdef LIBNUMA_ENABLE + // Initalize numa policy for this model + void InitNuma(int local_node_id, int remote_node_id); + + // Move model weights to target numa node + void MoveModelWeights(int numa_node_id); + + // Numa nodes available to the instance + const int num_numa_nodes_ = numa_max_node() + 1; +#endif // LIBNUMA_ENABLE + + // Global tensorpipe context + std::shared_ptr context_; + + // Pipe for client connection + std::shared_ptr pipe_; + + // Tflite interpreter + std::unique_ptr interpreter_; + + // Tflite model + std::unique_ptr model_; + + // Unique model instance name + std::string model_instance_name_; + + // State variable to register whether inference has been called at least once + bool first_inference_ = true; + + // Tensorpipe allocation that we can reuse to write inputs into + tensorpipe::Allocation allocation_; + + // Tensorpipe response message we can reuse to write outputs into + tensorpipe::Message tp_response_msg_; + + // CPU Range + std::vector cpus_; + +#ifdef PAPI_PROFILING_ENABLE + std::unique_ptr papi_profiler_; +#endif // PAPI_PROFILING_ENABLE +}; \ No newline at end of file diff --git a/src/model_instance/model_instance_main.cc b/src/model_instance/model_instance_main.cc new file mode 100644 index 0000000..6918375 --- /dev/null +++ b/src/model_instance/model_instance_main.cc @@ -0,0 +1,89 @@ +// +// Copyright © 2023 Arm Ltd and Contributors. All rights reserved. 
+// SPDX-License-Identifier: MIT +// + +#include + +#include +#include + +#include "model_instance.h" + +// Triton backend headers +#include "triton/backend/backend_common.h" + +#ifdef PAPI_PROFILING_ENABLE +#include "papi.h" +#endif // PAPI_PROFILING_ENABLE + +int +main(int argc, char* argv[]) +{ +#ifdef PAPI_PROFILING_ENABLE + // Init PAPI library + if (PAPI_library_init(PAPI_VER_CURRENT) != PAPI_VER_CURRENT) { + LOG_MESSAGE(TRITONSERVER_LOG_ERROR, "Failed to init PAPI lib"); + return 1; + } + if (PAPI_multiplex_init() != PAPI_OK) { + LOG_MESSAGE( + TRITONSERVER_LOG_ERROR, "Failed to init multiplexing for PAPI lib"); + return 1; + } + if (PAPI_thread_init(pthread_self) != PAPI_OK) { + LOG_MESSAGE(TRITONSERVER_LOG_ERROR, "Failed to init PAPI thread lib"); + return 1; + } +#endif // PAPI_PROFILING_ENABLE + + // Parse listen address + if (argc != 2) { + LOG_MESSAGE( + TRITONSERVER_LOG_ERROR, + "Args should be model_instance "); + + return 1; + } + const char* addr = argv[1]; + + // block signals in this thread and subsequently + // spawned threads + sigset_t sigset; + sigemptyset(&sigset); + sigaddset(&sigset, SIGINT); + sigaddset(&sigset, SIGTERM); + pthread_sigmask(SIG_BLOCK, &sigset, nullptr); + + std::atomic shutdown_requested(false); + std::mutex cv_mutex; + std::condition_variable cv; + + auto signal_handler = [&shutdown_requested, &cv, &sigset]() { + int signum = 0; + // wait until a signal is delivered: + sigwait(&sigset, &signum); + shutdown_requested.store(true); + // notify all waiting workers to check their predicate: + cv.notify_all(); + return signum; + }; + + auto ft_signal_handler = std::async(std::launch::async, signal_handler); + + ModelInstance model_instance; + + // Will connect to the address provided as the first argument in the list + model_instance.Start(std::string(addr)); + + LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, "Model instance running..."); + + // wait for signal handler to complete + int signal = ft_signal_handler.get(); + LOG_MESSAGE( + TRITONSERVER_LOG_INFO, + (std::string("Received signal: ") + std::to_string(signal)).c_str()); + + return 0; +} \ No newline at end of file diff --git a/src/model_instance/model_instance_utils.h b/src/model_instance/model_instance_utils.h new file mode 100644 index 0000000..edc3a35 --- /dev/null +++ b/src/model_instance/model_instance_utils.h @@ -0,0 +1,73 @@ +// +// Copyright © 2023 Arm Ltd and Contributors. All rights reserved. 
+// SPDX-License-Identifier: MIT +// + +#pragma once + +#include + +#include +#include + +// Triton backend headers +#include "triton/backend/backend_common.h" + +#ifdef PAPI_PROFILING_ENABLE +#include "papi.h" + +inline bool +PAPIEventValid(std::string& event_name) +{ + int event_set = PAPI_NULL; + bool valid = false; + if (PAPI_create_eventset(&event_set) == PAPI_OK) { + valid = PAPI_add_named_event(event_set, event_name.c_str()) == PAPI_OK; + if (valid) { + if (PAPI_cleanup_eventset(event_set) != PAPI_OK) { + } + } + if (PAPI_destroy_eventset(&event_set) != PAPI_OK) { + } + } + return valid; +} +#endif // PAPI_PROFILING_ENABLE + +inline std::vector +CurrentThreadIds() +{ + std::vector r; + for (auto& p : std::filesystem::directory_iterator("/proc/self/task")) { + if (p.is_directory()) { + r.push_back(std::stoi(p.path().filename().string())); + } + } + return r; +} + +inline std::vector +InferenceThreadIds() +{ + // We only care about the 4th thread in the process on, as these are used + // for inference + std::vector current_threads = CurrentThreadIds(); + return std::vector(current_threads.begin() + 3, current_threads.end()); +} + +inline void +LogThreads() +{ + for (auto pid : CurrentThreadIds()) { + LOG_MESSAGE( + TRITONSERVER_LOG_INFO, ("Thread id: " + std::to_string(pid)).c_str()); + } +} + +// Get base page address for given pointer +inline static void* +AlignPage(const void* ptr) +{ + static uintptr_t PAGE_MASK = ~(uintptr_t(getpagesize() - 1)); + return (void*)(((uintptr_t)ptr) & PAGE_MASK); +} \ No newline at end of file diff --git a/src/model_instance/papi_profiler.cc b/src/model_instance/papi_profiler.cc new file mode 100644 index 0000000..aeba119 --- /dev/null +++ b/src/model_instance/papi_profiler.cc @@ -0,0 +1,345 @@ +// +// Copyright © 2023 Arm Ltd. All rights reserved. 
+// SPDX-License-Identifier: MIT +// + +#include "papi_profiler.h" + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +// Triton backend headers +#include "model_instance_utils.h" +#include "papi.h" +#include "triton/backend/backend_common.h" + +constexpr uint32_t kInvalidEventHandle = static_cast(~0) - 1; + +void +handle_error(int retval, int line, const std::string& file) +{ + LOG_MESSAGE( + TRITONSERVER_LOG_ERROR, + ("PAPI error at line " + file + ":" + std::to_string(line) + " " + + std::to_string(retval) + ", " + PAPI_strerror(retval)) + .c_str()); + + // TODO: graceful exit here + exit(1); +} + +class PapiProfiler : public tflite::Profiler { + public: + PapiProfiler( + const std::vector& papi_events, + const std::vector& papi_uncore_events, + const std::vector inf_thread_ids) + : supported_event_types_( + static_cast(EventType::DELEGATE_OPERATOR_INVOKE_EVENT) + + static_cast(EventType::OPERATOR_INVOKE_EVENT)), + papi_events_(papi_events), papi_uncore_events_(papi_uncore_events), + inf_thread_ids_(inf_thread_ids) + { + // Reserve space for recording the data ahead of time + papi_regions_.reserve(1000); + timings_.reserve(1000); + + int retval; + + // Handle core specific events per inference thread + if (!papi_events_.empty()) { + for (uint64_t i = 0; i < inf_thread_ids_.size(); ++i) { + event_sets_.push_back(PAPI_NULL); + retval = PAPI_create_eventset(&event_sets_.back()); + if (retval != PAPI_OK) { + handle_error(retval, __LINE__, __FILE__); + } + for (auto& event_name : papi_events_) { + retval = PAPI_add_named_event(event_sets_.back(), event_name.c_str()); + if (retval != PAPI_OK) + handle_error(retval, __LINE__, __FILE__); + } + + // Attach event to thread + LOG_MESSAGE( + TRITONSERVER_LOG_INFO, + ("Attaching to " + std::to_string(inf_thread_ids_[i])).c_str()); + retval = PAPI_attach(event_sets_.back(), inf_thread_ids_[i]); + if (retval != PAPI_OK) + handle_error(retval, __LINE__, __FILE__); + + // Start eventset + retval = PAPI_start(event_sets_.back()); + if (retval != PAPI_OK) + handle_error(retval, __LINE__, __FILE__); + } + event_values_.resize(papi_events_.size()); + } + + // Handle uncore events separately + if (!papi_uncore_events_.empty()) { + retval = PAPI_create_eventset(&uncore_event_set_); + if (retval != PAPI_OK) { + handle_error(retval, __LINE__, __FILE__); + } + for (auto& event_name : papi_uncore_events_) { + retval = PAPI_add_named_event(uncore_event_set_, event_name.c_str()); + if (retval != PAPI_OK) + handle_error(retval, __LINE__, __FILE__); + } + uncore_event_values_.resize(papi_uncore_events_.size()); + // Start uncore eventset + retval = PAPI_start(uncore_event_set_); + if (retval != PAPI_OK) + handle_error(retval, __LINE__, __FILE__); + } + } + + ~PapiProfiler() + { + // Save results to file + std::ofstream myfile; + auto now = std::chrono::system_clock::now(); + auto utc = + std::chrono::duration_cast(now.time_since_epoch()) + .count(); + + myfile.open(("counters_" + std::to_string(utc) + ".csv").c_str()); + // Header + myfile << "op_id,thread_id,sample_id,papi_event,value\n"; + // Iterate over map keyed on tflite operation id, with values being a vector + // of counter values for each tracked perf event + for (auto& event : results_) { + for (uint64_t i = 0; i < event.second.size(); ++i) { + myfile << event.first << "," + << inf_thread_ids_[i / papi_events_.size() % event_sets_.size()] + << "," << i / (papi_events_.size() * event_sets_.size()) << "," + << papi_events_[i % papi_events_.size()] << "," 
+ << event.second[i] << "\n"; + } + } + + for (auto& event : results_uncore_) { + // Now write the uncore events with a dummy thread id of -1 + for (uint64_t i = 0; i < results_uncore_[event.first].size(); ++i) { + myfile << event.first << "," << -1 << "," + << i / papi_uncore_events_.size() << "," + << papi_uncore_events_[i % papi_uncore_events_.size()] << "," + << results_uncore_[event.first][i] << "\n"; + } + } + + for (auto& event : results_timings_) { + // Now write the timing events with a dummy thread id of -1 + for (uint64_t i = 0; i < results_timings_[event.first].size(); ++i) { + myfile << event.first << "," << -1 << "," << i << "," + << "TIME_NS" + << "," << results_timings_[event.first][i] << "\n"; + } + } + myfile.close(); + + for (auto& event_set : event_sets_) { + PAPI_cleanup_eventset(event_set); + PAPI_destroy_eventset(&event_set); + } + } + + // This function wants to return a handle to the profile event, which seems to + // be a unique value. Because we are interested in the op names, we just has + // the op tag to generate the event handle value. + // In the case of and Op event, metadata1 holds the node index, and metadata 2 + // holds the subgraph index + uint32_t BeginEvent( + const char* tag, EventType event_type, int64_t event_metadata1, + int64_t event_metadata2) override + { + if (!ShouldAddEvent(event_type)) { + return kInvalidEventHandle; + } + + // Get a unique name for the papi computation region + std::string trace_event_tag = tag; + trace_event_tag += ("_" + std::to_string(event_metadata1)); + + int retval; + + if (!papi_events_.empty()) { // Reset event set attached to each thread + for (uint64_t i = 0; i < event_sets_.size(); ++i) { + // Reset counters + retval = PAPI_reset(event_sets_[i]); + if (retval != PAPI_OK) + handle_error(retval, __LINE__, __FILE__); + } + } + + // Handle uncore events + if (!papi_uncore_events_.empty()) { + // Reset counters + retval = PAPI_reset(uncore_event_set_); + if (retval != PAPI_OK) + handle_error(retval, __LINE__, __FILE__); + } + + event_index_++; + papi_regions_[event_index_] = std::move(trace_event_tag); + timings_[event_index_] = PAPI_get_real_nsec(); + return event_index_; + } + + void EndEvent(uint32_t event_handle) override + { + if (event_handle == kInvalidEventHandle) { + return; + } + + // Push back the op timing + results_timings_[papi_regions_[event_handle]].push_back( + PAPI_get_real_nsec() - timings_[event_handle]); + + // For performance reserve space for 10000 elements for each perf event in + // results + if (results_[papi_regions_[event_handle]].empty()) { + results_[papi_regions_[event_handle]].reserve( + papi_events_.size() * 10000); + results_timings_.reserve(10000); + } + if (results_uncore_[papi_regions_[event_handle]].empty()) { + results_uncore_[papi_regions_[event_handle]].reserve( + papi_uncore_events_.size() * 10000); + } + + int retval; + + if (!papi_events_.empty()) { // For each thread we are profiling + for (uint64_t i = 0; i < event_sets_.size(); ++i) { + retval = PAPI_read(event_sets_[i], event_values_.data()); + if (retval != PAPI_OK) + handle_error(retval, __LINE__, __FILE__); + // Write event counter values to end of results vector for current op + results_[papi_regions_[event_handle]].insert( + results_[papi_regions_[event_handle]].end(), event_values_.begin(), + event_values_.end()); + } + } + // Handle uncore events + if (!papi_uncore_events_.empty()) { + retval = PAPI_read(uncore_event_set_, uncore_event_values_.data()); + if (retval != PAPI_OK) + handle_error(retval, __LINE__, 
__FILE__); + // For each of the events we collected a counter value for + results_uncore_[papi_regions_[event_handle]].insert( + results_uncore_[papi_regions_[event_handle]].end(), + uncore_event_values_.begin(), uncore_event_values_.end()); + } + } + + protected: + inline bool ShouldAddEvent(EventType event_type) + { + return (static_cast(event_type) & supported_event_types_) != 0; + } + + private: + uint32_t event_index_ = 0; + std::unordered_map papi_regions_; + std::unordered_map timings_; + const uint64_t supported_event_types_; + + // Vector holding the papi event names we are tracking for each core/thread + std::vector papi_events_; + + // Vector holding the papi event names we are tracking which are socket + // specific + std::vector papi_uncore_events_; + + // Vector holding papi event set data structures (one per tracked inf thread) + std::vector event_sets_; + + // Vector holding papi event set data structures for our uncore events because + // this is per socket, we only need one event set + int uncore_event_set_ = PAPI_NULL; + + // We only care about the 4th thread in the process on, as these are used for + // inference + std::vector inf_thread_ids_; + + // Vector to hold papi counter values when we read them + std::vector event_values_; + + // Vector to hold papi uncore values when we read them + std::vector uncore_event_values_; + + // Vector holding all per core counter values to be processed at end + std::unordered_map> results_; + + // Vector holding all per core counter values to be processed at end + std::unordered_map> results_uncore_; + + // Vector holding op timings + std::unordered_map> results_timings_; +}; + +std::unique_ptr +MaybeCreatePapiProfiler() +{ + // Per core events + char* papi_events = getenv("PAPI_EVENTS"); + std::vector papi_events_vec; + if (papi_events != NULL) { + // Parse out all papi events indivdually + std::stringstream ss(papi_events); + while (ss.good()) { + std::string substr; + std::getline(ss, substr, ','); + if (!PAPIEventValid(substr)) { + LOG_MESSAGE( + TRITONSERVER_LOG_WARN, + ("Event: " + substr + " invalid, op level profiling disabled!") + .c_str()); + return nullptr; + } + papi_events_vec.push_back(substr); + } + } + + // Uncore events + char* papi_uncore_events = getenv("PAPI_UNCORE_EVENTS"); + std::vector papi_uncore_events_vec; + if (papi_uncore_events != NULL) { + // Parse out all papi events indivdually + std::stringstream ss(papi_uncore_events); + while (ss.good()) { + std::string substr; + std::getline(ss, substr, ','); + if (!PAPIEventValid(substr)) { + LOG_MESSAGE( + TRITONSERVER_LOG_WARN, + ("Event: " + substr + " invalid, op level profiling disabled!") + .c_str()); + return nullptr; + } + papi_uncore_events_vec.push_back(substr); + } + } + + if ((papi_events == NULL) && (papi_uncore_events == NULL)) { + LOG_MESSAGE( + TRITONSERVER_LOG_WARN, + "PAPI_EVENTS nor PAPI_UNCORE_EVENTS specified, op level profiling " + "disabled!"); + return nullptr; + } + + return std::unique_ptr(new PapiProfiler( + papi_events_vec, papi_uncore_events_vec, InferenceThreadIds())); +} diff --git a/src/papi_profiler.h b/src/model_instance/papi_profiler.h similarity index 90% rename from src/papi_profiler.h rename to src/model_instance/papi_profiler.h index 434f90f..470bb40 100644 --- a/src/papi_profiler.h +++ b/src/model_instance/papi_profiler.h @@ -8,7 +8,6 @@ #include #include "tensorflow/lite/core/api/profiler.h" -#include "triton/core/tritonbackend.h" // Creates a profiler which reports the papi traced events. 
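For context, here is a minimal sketch (not part of the patch) of how the new profiler is wired up, mirroring what model_instance.cc does after the first inference, once the interpreter's worker threads exist. The helper name `AttachPapiProfiler` and the `std::unique_ptr<tflite::Profiler>` return type of `MaybeCreatePapiProfiler()` are assumptions for illustration.
```
// Illustrative sketch only: attach the PAPI profiler to a tflite::Interpreter.
#include <memory>

#include "papi_profiler.h"
#include "tensorflow/lite/interpreter.h"

// Assumed helper; the caller keeps ownership of the profiler (as
// model_instance.cc does with its papi_profiler_ member) because the
// interpreter does not take ownership.
void
AttachPapiProfiler(
    tflite::Interpreter* interpreter,
    std::unique_ptr<tflite::Profiler>* profiler)
{
  // MaybeCreatePapiProfiler() returns nullptr when neither PAPI_EVENTS nor
  // PAPI_UNCORE_EVENTS is set, or when any listed event fails validation.
  *profiler = MaybeCreatePapiProfiler();
  if (*profiler != nullptr) {
    // Counters are sampled around each operator invocation and written to
    // counters_<timestamp>.csv when the profiler is destroyed.
    interpreter->AddProfiler(profiler->get());
  }
}
```
The profiler is attached only after the first inference so that the per-thread event sets can be pinned to inference threads that actually exist.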
diff --git a/src/papi_profiler.cc b/src/papi_profiler.cc deleted file mode 100644 index 36b63cd..0000000 --- a/src/papi_profiler.cc +++ /dev/null @@ -1,99 +0,0 @@ -// -// Copyright © 2023 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#include "papi_profiler.h" - -#include -#include -#include -#include - -#include -#include -#include - -#include "triton/backend/backend_model.h" - -constexpr uint32_t kInvalidEventHandle = static_cast(~0) - 1; - -void -handle_error(int retval) -{ - throw triton::backend::BackendModelException(TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INTERNAL, - ("PAPI error: " + std::string(PAPI_strerror(retval))).c_str())); -} - -class PapiProfiler : public tflite::Profiler { - public: - PapiProfiler() - : supported_event_types_( - static_cast(EventType::DELEGATE_OPERATOR_INVOKE_EVENT) + - static_cast(EventType::OPERATOR_INVOKE_EVENT)) - { - } - - ~PapiProfiler() { PAPI_hl_stop(); } - - // This function wants to return a handle to the profile event, which seems to - // be a unique value. Because we are interested in the op names, we just has - // the op tag to generate the event handle value. - // In the case of and Op event, metadata1 holds the node index, and metadata 2 - // holds the subgraph index - uint32_t BeginEvent( - const char* tag, EventType event_type, int64_t event_metadata1, - int64_t event_metadata2) override - { - if (!ShouldAddEvent(event_type)) { - return kInvalidEventHandle; - } - - // Get a unique name for the papi computation region - std::string trace_event_tag = tag; - trace_event_tag += ("_" + std::to_string(event_metadata1)); - - // Begin tracking counters - int retval = PAPI_hl_region_begin(trace_event_tag.c_str()); - if (retval != PAPI_OK) - handle_error(retval); - - uint32_t event_handle = event_index_++; - papi_regions_[event_handle] = trace_event_tag; - return event_handle; - } - - void EndEvent(uint32_t event_handle) override - { - if (event_handle == kInvalidEventHandle) { - return; - } - int retval = PAPI_hl_region_end(papi_regions_[event_handle].c_str()); - if (retval != PAPI_OK) - handle_error(retval); - } - - protected: - inline bool ShouldAddEvent(EventType event_type) - { - return (static_cast(event_type) & supported_event_types_) != 0; - } - - private: - uint32_t event_index_ = 0; - std::unordered_map papi_regions_; - const uint64_t supported_event_types_; -}; - -std::unique_ptr -MaybeCreatePapiProfiler() -{ - if (getenv("PAPI_EVENTS") == NULL) { - LOG_MESSAGE( - TRITONSERVER_LOG_WARN, - "PAPI_EVENTS not specified, op level profiling disabled"); - return nullptr; - } - return std::unique_ptr(new PapiProfiler()); -} diff --git a/src/tflite.cc b/src/tflite.cc index 134516a..272ae71 100644 --- a/src/tflite.cc +++ b/src/tflite.cc @@ -3,21 +3,35 @@ // SPDX-License-Identifier: MIT // +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif + +#include #include #include #include #include +#include #include #include #include +#include #include #include #include #include #include +// Local headers +#include "config.h" #include "tflite_utils.h" + +// Tensorpipe headers +#include "tensorpipe/tensorpipe.h" + +// Triton headers #include "triton/backend/backend_common.h" #include "triton/backend/backend_input_collector.h" #include "triton/backend/backend_memory.h" @@ -33,17 +47,14 @@ #include "tensorflow/lite/model.h" #include "tensorflow/lite/type_to_tflitetype.h" -#ifdef ARMNN_DELEGATE_ENABLE -// ArmNN headers -#include "armnn/ArmNN.hpp" -#include "armnn_delegate.hpp" -#endif // ARMNN_DELEGATE_ENABLE - -#ifdef 
PAPI_PROFILING_ENABLE -#include +// Reproc headers +#include "reproc++/reproc.hpp" -#include "papi_profiler.h" -#endif // PAPI_PROFILING_ENABLE +#ifdef LIBNUMA_ENABLE +// Lib Numa headers +#include +#include +#endif // LIBNUMA_ENABLE // // TFLite Backend that implements the TRITONBACKEND API. @@ -51,6 +62,31 @@ namespace triton { namespace backend { namespace tensorflowlite { +// Custom object to store global state for this backend +struct ArmNNTFLiteBackendState { + // Map managing list of avail cpus in system, keyed on socket + // TODO: Change this to a bitmap + std::unordered_map> avail_cpus_; + std::unordered_map> used_cpus_; + + explicit ArmNNTFLiteBackendState(const std::vector cpus_to_use) + { + // Start with list of all available CPUs on system + PopulateCpusMap(avail_cpus_); + + // If we have a cpu restriction, modify avail_cpus accordingly + if (!cpus_to_use.empty()) { + for (auto& [socket_id, cpus] : avail_cpus_) { + std::vector valid_cpus; + std::set_union( + cpus_to_use.begin(), cpus_to_use.end(), cpus.begin(), cpus.end(), + std::back_inserter(valid_cpus)); + cpus = std::move(valid_cpus); + } + } + } +}; + // // ModelState // @@ -61,22 +97,18 @@ namespace triton { namespace backend { namespace tensorflowlite { class ModelState : public BackendModel { public: static TRITONSERVER_Error* Create( - TRITONBACKEND_Model* triton_model, ModelState** state, - int32_t* armnn_threads); + TRITONBACKEND_Model* triton_model, ModelState** state); ~ModelState(); - // Load a serialized tflite model using 'artifact_name' as the name for the - // tflite model file. Return in 'model_path' the full path to the - // tflite model file. Return in 'model' the TFLite network, - // representing the model. - TRITONSERVER_Error* LoadModel( - const std::string& artifact_name, std::string* model_path, - common::TritonJson::Value& model_config, - std::unique_ptr* model); + TRITONSERVER_Error* LoadModel(); + + TRITONSERVER_Error* InitConfig(); // Validate that model configuration is supported by this backend. TRITONSERVER_Error* ValidateModelConfig(); + void InitTensorPipe(); + // Default TFLite runtime options int32_t tflite_num_threads_ = static_cast(std::thread::hardware_concurrency()); @@ -85,9 +117,16 @@ class ModelState : public BackendModel { // ArmNN Delegate options bool use_armnn_delegate_cpu_ = false; bool use_armnn_delegate_gpu_ = false; - armnn::OptimizerOptions armnn_optimizer_options_cpu_; - armnn::OptimizerOptions armnn_optimizer_options_gpu_; - int32_t* armnn_threads_; + + int32_t armnn_cpu_num_threads_ = + static_cast(std::thread::hardware_concurrency()); + std::string armnn_cpu_reduce_fp32_to_fp16_ = "off"; + std::string armnn_cpu_reduce_fp32_to_bf16_ = "off"; + std::string armnn_cpu_fast_math_enabled_ = "off"; + + std::string armnn_gpu_fast_math_enabled_ = "off"; + std::string armnn_gpu_reduce_fp32_to_fp16_ = "off"; + std::string armnn_gpu_reduce_fp32_to_bf16_ = "off"; #endif // ARMNN_DELEGATE_ENABLE // XNNPACK Delegate options @@ -104,6 +143,40 @@ class ModelState : public BackendModel { // that output in the model. 
std::unordered_map output_index_map_; std::unordered_map output_dtype_map_; + std::unordered_map> output_shape_map_; + + // Pointer to shared backend state + ArmNNTFLiteBackendState* backend_state_; + + // The pointer to the tflite network + std::unique_ptr model_; + + // Global context of tensorpipe + std::shared_ptr context_; + + // Path string for the model_instance binary + const char* model_instance_location_; + +#ifdef PAPI_PROFILING_ENABLE + // String holding comma-separated list of events for child inference process + std::string papi_events_ = ""; + + // String holding comma-separated list of uncore events for child inference + // process + std::string papi_uncore_events_ = ""; +#endif // PAPI_PROFILING_ENABLE + + // Numa policy for instance + AllocationPolicy numa_alloc_policy_ = AllocationPolicy::NONE; + + // Local numa node id + int local_numa_node_id_ = 0; + + // Remote numa node id + int remote_numa_node_id_ = 1; + + // pin threads + bool pin_threads_ = false; private: ModelState(TRITONBACKEND_Model* triton_model); @@ -112,9 +185,7 @@ class ModelState : public BackendModel { TRITONSERVER_Error* -ModelState::Create( - TRITONBACKEND_Model* triton_model, ModelState** state, - int32_t* armnn_threads) +ModelState::Create(TRITONBACKEND_Model* triton_model, ModelState** state) { try { *state = new ModelState(triton_model); @@ -126,9 +197,15 @@ ModelState::Create( RETURN_IF_ERROR(ex.err_); } -#ifdef ARMNN_DELEGATE_ENABLE - (*state)->armnn_threads_ = armnn_threads; -#endif + TRITONBACKEND_Backend* backend; + RETURN_IF_ERROR(TRITONBACKEND_ModelBackend(triton_model, &backend)); + void* vbackendstate; + RETURN_IF_ERROR(TRITONBACKEND_BackendState(backend, &vbackendstate)); + RETURN_ERROR_IF_TRUE( + vbackendstate == nullptr, TRITONSERVER_ERROR_INTERNAL, + std::string("unexpected nullptr state in TRITONBACKEND_ModelInitialize")); + (*state)->backend_state_ = + reinterpret_cast(vbackendstate); return nullptr; // success } @@ -136,47 +213,26 @@ ModelState::Create( ModelState::ModelState(TRITONBACKEND_Model* triton_model) : BackendModel(triton_model) { // Here we can add information to the model state that can be shared across - // model instances. See onnx backend for example. MALI GPU optimization level - // may be candidate. + // model instances. See onnx backend for example. MALI GPU optimization level + // may be candidate. + InitTensorPipe(); + THROW_IF_BACKEND_MODEL_ERROR(InitConfig()); + THROW_IF_BACKEND_MODEL_ERROR(LoadModel()); + + // Get the directory of the backend to find the path to the model instance + // binary + TRITONBACKEND_Backend* backend; + TRITONBACKEND_ArtifactType artifact_type; + TRITONBACKEND_ModelBackend(triton_model, &backend); + TRITONBACKEND_BackendArtifacts( + backend, &artifact_type, &model_instance_location_); } ModelState::~ModelState() {} TRITONSERVER_Error* -ModelState::LoadModel( - const std::string& artifact_name, std::string* model_path, - common::TritonJson::Value& model_config, - std::unique_ptr* model) +ModelState::InitConfig() { - // Find the TFLite model file that describes the model. If the model - // configuration doesn't have an explicit model file specified then - // use the default name ("model.tflite"). 
- std::string cc_model_filename = artifact_name; - if (cc_model_filename.empty()) { - cc_model_filename = "model.tflite"; - } - - *model_path = JoinPath( - {RepositoryPath(), std::to_string(Version()), cc_model_filename}); - - { - bool exists; - RETURN_IF_ERROR(FileExists(*model_path, &exists)); - RETURN_ERROR_IF_FALSE( - exists, TRITONSERVER_ERROR_UNAVAILABLE, - std::string("unable to find '") + *model_path + - "' for model instance '" + Name() + "'"); - } - - // Load the Tflite FlatBufferModel into memory - *model = tflite::FlatBufferModel::BuildFromFile((*model_path).c_str()); - - if (!*model) { - return TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INTERNAL, - ("failed to load model " + Name()).c_str()); - } - // Handle tflite default interpeter options set in parameters { triton::common::TritonJson::Value params; @@ -204,6 +260,98 @@ ModelState::LoadModel( .c_str()); } } + + // Handle pin_threads parameter + err = GetParameterValue(params, "pin_threads", &value_str); + // pin_threads is not required so clear error if not found + if (err != nullptr) { + if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) { + return err; + } else { + TRITONSERVER_ErrorDelete(err); + } + } else { + if (value_str == "on") { + pin_threads_ = true; + } else if (value_str == "off") { + pin_threads_ = false; + } else { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string("parameter 'pin_threads' must be 'on' or 'off' ") + + Name() + "'") + .c_str()); + } + } + + // Handle numa parameters + err = GetParameterValue(params, "numa_alloc_policy", &value_str); + + // numa_alloc_policy is not required so clear error if not found + if (err != nullptr) { + if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) { + return err; + } else { + TRITONSERVER_ErrorDelete(err); + } + } else { + numa_alloc_policy_ = AllocationPolicyFromString(value_str); + } + +#ifdef LIBNUMA_ENABLE + err = GetParameterValue(params, "local_numa_node_id", &value_str); + + // local_numa_node_id is not required so clear error if not found + if (err != nullptr) { + if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) { + return err; + } else { + TRITONSERVER_ErrorDelete(err); + } + } else { + RETURN_IF_ERROR(ParseIntValue(value_str, &local_numa_node_id_)); + if (local_numa_node_id_ < 0 || local_numa_node_id_ > numa_max_node()) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string( + "parameter 'local_numa_node_id_' must be non-negative " + "or less than max numa node id for tflite model '") + + Name() + "'") + .c_str()); + } + } + + // Handle remote_numa_node_id parameter + err = GetParameterValue(params, "remote_numa_node_id", &value_str); + + // remote_numa_node_id is not required so clear error if not found + if (err != nullptr) { + if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) { + return err; + } else { + TRITONSERVER_ErrorDelete(err); + } + } else { + RETURN_IF_ERROR(ParseIntValue(value_str, &remote_numa_node_id_)); + if (remote_numa_node_id_ < 0 || + remote_numa_node_id_ > numa_max_node()) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string( + "parameter 'remote_numa_node_id_' must be non-negative " + "or less than max numa node id for tflite model '") + + Name() + "'") + .c_str()); + } + } + +#else + RETURN_ERROR_IF_TRUE( + numa_alloc_policy_ != AllocationPolicy::NONE, + TRITONSERVER_ERROR_INVALID_ARG, + std::string("Backend built without NUMA support, only valid " + "allocation policy is 'NONE'")); 
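+      // Illustrative config.pbtxt entries for the parameters handled above
+      // (the policy strings shown are assumptions; the accepted values are
+      // whatever AllocationPolicyFromString() recognizes):
+      //   parameters: { key: "pin_threads"          value: { string_value: "on" } }
+      //   parameters: { key: "numa_alloc_policy"    value: { string_value: "local" } }
+      //   parameters: { key: "local_numa_node_id"   value: { string_value: "0" } }
+      //   parameters: { key: "remote_numa_node_id"  value: { string_value: "1" } }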
+#endif // LIBNUMA_ENABLE } } @@ -241,10 +389,8 @@ ModelState::LoadModel( if (param_key == "reduce_fp32_to_fp16") { RETURN_IF_ERROR(params.MemberAsString( param_key.c_str(), &value_string)); - if (value_string == "on") { - armnn_optimizer_options_cpu_.m_ReduceFp32ToFp16 = true; - } else if (value_string == "off") { - armnn_optimizer_options_cpu_.m_ReduceFp32ToFp16 = false; + if (value_string == "on" || value_string == "off") { + armnn_cpu_reduce_fp32_to_fp16_ = value_string; } else { RETURN_ERROR_IF_FALSE( false, TRITONSERVER_ERROR_INVALID_ARG, @@ -255,10 +401,8 @@ ModelState::LoadModel( } else if (param_key == "reduce_fp32_to_bf16") { RETURN_IF_ERROR(params.MemberAsString( param_key.c_str(), &value_string)); - if (value_string == "on") { - armnn_optimizer_options_cpu_.m_ReduceFp32ToBf16 = true; - } else if (value_string == "off") { - armnn_optimizer_options_cpu_.m_ReduceFp32ToBf16 = false; + if (value_string == "on" || value_string == "off") { + armnn_cpu_reduce_fp32_to_bf16_ = value_string; } else { RETURN_ERROR_IF_FALSE( false, TRITONSERVER_ERROR_INVALID_ARG, @@ -269,16 +413,8 @@ ModelState::LoadModel( } else if (param_key == "fast_math_enabled") { RETURN_IF_ERROR(params.MemberAsString( param_key.c_str(), &value_string)); - if (value_string == "on") { - armnn::BackendOptions option( - "CpuAcc", {{"FastMathEnabled", true}}); - armnn_optimizer_options_cpu_.m_ModelOptions.push_back( - option); - } else if (value_string == "off") { - armnn::BackendOptions option( - "CpuAcc", {{"FastMathEnabled", false}}); - armnn_optimizer_options_cpu_.m_ModelOptions.push_back( - option); + if (value_string == "on" || value_string == "off") { + armnn_cpu_fast_math_enabled_ = value_string; } else { RETURN_ERROR_IF_FALSE( false, TRITONSERVER_ERROR_INVALID_ARG, @@ -287,41 +423,18 @@ ModelState::LoadModel( value_string + "' is requested"); } } else if (param_key == "num_threads") { - int32_t num_threads; RETURN_IF_ERROR(params.MemberAsString( param_key.c_str(), &value_string)); - RETURN_IF_ERROR(ParseIntValue(value_string, &num_threads)); - if (num_threads < 0) { + RETURN_IF_ERROR( + ParseIntValue(value_string, &armnn_cpu_num_threads_)); + if (armnn_cpu_num_threads_ < -1) { return TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_INVALID_ARG, std::string( "armnn thread count '" + value_string + - "' is not in range [1-64]") + "' is not in range [-1-64]") .c_str()); } - - // Here we do an ugly hack to prevent armnn/acl thread - // issues For now we make sure the next armnn accelerated - // model loaded does not request more threads than the - // previous, as this creates a segfault - if (num_threads > *armnn_threads_) { - num_threads = *armnn_threads_; - LOG_MESSAGE( - TRITONSERVER_LOG_INFO, - (std::string("Model threads requested larger than " - "that of first model loaded: ") + - value_string + " > " + - std::to_string(*armnn_threads_) + - ". 
Using smaller thread value instead.") - .c_str()); - } else { - *armnn_threads_ = num_threads; - } - armnn::BackendOptions option( - "CpuAcc", {{"NumberOfThreads", - static_cast(num_threads)}}); - armnn_optimizer_options_cpu_.m_ModelOptions.push_back( - option); } else { return TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_INVALID_ARG, @@ -386,7 +499,6 @@ ModelState::LoadModel( RETURN_IF_ERROR(ea.MemberAsString("name", &name)); if (name == "armnn") { use_armnn_delegate_gpu_ = true; - armnn::OptimizerOptions armnn_optimizer_options_gpu_; LOG_MESSAGE( TRITONSERVER_LOG_VERBOSE, (std::string( @@ -403,10 +515,8 @@ ModelState::LoadModel( if (param_key == "reduce_fp32_to_fp16") { RETURN_IF_ERROR(params.MemberAsString( param_key.c_str(), &value_string)); - if (value_string == "on") { - armnn_optimizer_options_gpu_.m_ReduceFp32ToFp16 = true; - } else if (value_string == "off") { - armnn_optimizer_options_gpu_.m_ReduceFp32ToFp16 = false; + if (value_string == "on" || value_string == "off") { + armnn_gpu_reduce_fp32_to_fp16_ = value_string; } else { RETURN_ERROR_IF_FALSE( false, TRITONSERVER_ERROR_INVALID_ARG, @@ -417,10 +527,8 @@ ModelState::LoadModel( } else if (param_key == "reduce_fp32_to_bf16") { RETURN_IF_ERROR(params.MemberAsString( param_key.c_str(), &value_string)); - if (value_string == "on") { - armnn_optimizer_options_gpu_.m_ReduceFp32ToBf16 = true; - } else if (value_string == "off") { - armnn_optimizer_options_gpu_.m_ReduceFp32ToBf16 = false; + if (value_string == "on" || value_string == "off") { + armnn_gpu_reduce_fp32_to_bf16_ = value_string; } else { RETURN_ERROR_IF_FALSE( false, TRITONSERVER_ERROR_INVALID_ARG, @@ -431,16 +539,8 @@ ModelState::LoadModel( } else if (param_key == "fast_math_enabled") { RETURN_IF_ERROR(params.MemberAsString( param_key.c_str(), &value_string)); - if (value_string == "on") { - armnn::BackendOptions option( - "GpuAcc", {{"FastMathEnabled", true}}); - armnn_optimizer_options_gpu_.m_ModelOptions.push_back( - option); - } else if (value_string == "off") { - armnn::BackendOptions option( - "GpuAcc", {{"FastMathEnabled", false}}); - armnn_optimizer_options_gpu_.m_ModelOptions.push_back( - option); + if (value_string == "on" || value_string == "off") { + armnn_gpu_fast_math_enabled_ = value_string; } else { RETURN_ERROR_IF_FALSE( false, TRITONSERVER_ERROR_INVALID_ARG, @@ -472,7 +572,46 @@ ModelState::LoadModel( } } - return nullptr; // success + return nullptr; +} + +TRITONSERVER_Error* +ModelState::LoadModel() +{ + std::string artifact_filename; + RETURN_IF_ERROR(ModelConfig().MemberAsString( + "default_model_filename", &artifact_filename)); + + // Find the TFLite model file that describes the model. If the model + // configuration doesn't have an explicit model file specified then + // use the default name ("model.tflite"). 
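+  // The file is resolved as <repository_path>/<version>/<filename>, and the flatbuffer is kept in memory so SendModel() can ship it to the model_instance process.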
+ std::string cc_model_filename = artifact_filename; + if (cc_model_filename.empty()) { + cc_model_filename = "model.tflite"; + } + + std::string model_path = JoinPath( + {RepositoryPath(), std::to_string(Version()), cc_model_filename}); + + { + bool exists; + RETURN_IF_ERROR(FileExists(model_path, &exists)); + RETURN_ERROR_IF_FALSE( + exists, TRITONSERVER_ERROR_UNAVAILABLE, + std::string("unable to find '") + model_path + + "' for model instance '" + Name() + "'"); + } + + // Load the Tflite FlatBufferModel into memory + model_ = tflite::FlatBufferModel::BuildFromFile((model_path).c_str()); + + if (!model_) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + ("failed to load model " + Name()).c_str()); + } + + return nullptr; } TRITONSERVER_Error* @@ -485,19 +624,37 @@ ModelState::ValidateModelConfig() TRITONSERVER_LOG_VERBOSE, (std::string("model configuration:\n") + buffer.Contents()).c_str()); +#ifdef PAPI_PROFILING_ENABLE + // Take this opportunity to handle papi events + triton::common::TritonJson::Value params; + if (ModelConfig().Find("parameters", ¶ms)) { + auto err = GetParameterValue(params, "papi_events", &papi_events_); + // papi_events is not required so clear error if not found + if (err != nullptr) { + if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) { + return err; + } else { + TRITONSERVER_ErrorDelete(err); + } + } + + err = GetParameterValue(params, "papi_uncore_events", &papi_uncore_events_); + // papi_events is not required so clear error if not found + if (err != nullptr) { + if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) { + return err; + } else { + TRITONSERVER_ErrorDelete(err); + } + } + } +#endif // PAPI_PROFILING_ENABLE + // To check input and output names we will load and release the model during // the validation process without allocating memory for inference - std::string model_path; - std::unique_ptr model; std::unique_ptr interpreter; - std::string artifact_filename; - RETURN_IF_ERROR(ModelConfig().MemberAsString( - "default_model_filename", &artifact_filename)); - RETURN_IF_ERROR( - LoadModel(artifact_filename, &model_path, ModelConfig(), &model)); - tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates resolver; - tflite::InterpreterBuilder builder(*model, resolver); + tflite::InterpreterBuilder builder(*model_, resolver); builder(&interpreter); if (!interpreter) { return TRITONSERVER_ErrorNew( @@ -516,22 +673,36 @@ ModelState::ValidateModelConfig() // Populate input name map for (size_t i = 0; i < num_inputs; i++) { - input_index_map_[interpreter->GetInputName(i)] = inputs[i]; - input_dtype_map_[interpreter->GetInputName(i)] = - ConvertTFLiteTypeToDataType(interpreter->tensor(inputs[i])->type); + TfLiteTensor* input_tensor = interpreter->tensor(inputs[i]); + if (input_tensor->allocation_type == kTfLiteArenaRw) { + // Only worry about inputs that require user input + input_index_map_[input_tensor->name] = inputs[i]; + input_dtype_map_[input_tensor->name] = + ConvertTFLiteTypeToDataType(input_tensor->type); + } } - // Populate output name and dtype map + // Populate output name, dtype, shape map for (size_t i = 0; i < num_outputs; i++) { - output_index_map_[interpreter->GetOutputName(i)] = outputs[i]; - output_dtype_map_[interpreter->GetOutputName(i)] = - ConvertTFLiteTypeToDataType(interpreter->tensor(outputs[i])->type); + TfLiteTensor* output_tensor = interpreter->tensor(outputs[i]); + TfLiteIntArray* tflite_output_tensor_dims = output_tensor->dims; + std::vector output_shape_vector = std::vector( 
+ tflite_output_tensor_dims->data, + (tflite_output_tensor_dims->data + tflite_output_tensor_dims->size)); + output_shape_map_[output_tensor->name] = output_shape_vector; + output_index_map_[output_tensor->name] = outputs[i]; + output_dtype_map_[output_tensor->name] = + ConvertTFLiteTypeToDataType(output_tensor->type); } triton::common::TritonJson::Value ios; // Validate model inputs RETURN_IF_ERROR(ModelConfig().MemberAsArray("input", &ios)); + RETURN_ERROR_IF_FALSE( + input_index_map_.size() == ios.ArraySize(), TRITONSERVER_ERROR_INTERNAL, + std::string( + "Number of required inputs for model does not match provided")); for (size_t i = 0; i < ios.ArraySize(); i++) { triton::common::TritonJson::Value io; @@ -542,42 +713,35 @@ ModelState::ValidateModelConfig() RETURN_IF_ERROR(io.MemberAsString("name", &io_name)); // Return an error if the input name within the model config DNE in model - if (input_index_map_.count(io_name) == 0) { - return TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_NOT_FOUND, - std::string( - "Model input: " + std::string(io_name) + - " is not a valid input name for '" + Name() + "'") - .c_str()); - } + RETURN_ERROR_IF_TRUE( + input_index_map_.count(io_name) == 0, TRITONSERVER_ERROR_NOT_FOUND, + std::string( + "Model input: " + std::string(io_name) + + " is not a valid input name for '" + Name() + "'")); + // Validate data type std::string io_dtype; RETURN_IF_ERROR(io.MemberAsString("data_type", &io_dtype)); const auto pr = ModelConfigDataTypeToTFLiteType(io_dtype); - if (!pr.first) { - return TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INTERNAL, - ("unsupported datatype " + io_dtype + " for input '" + io_name + - "' for model '" + Name() + "'") - .c_str()); - } + RETURN_ERROR_IF_TRUE( + !pr.first, TRITONSERVER_ERROR_INTERNAL, + ("unsupported datatype " + io_dtype + " for input '" + io_name + + "' for model '" + Name() + "'")); // Validate datatype matches expected from model TRITONSERVER_DataType config_dtype = TRITONSERVER_StringToDataType(io_dtype.substr(strlen("TYPE_")).c_str()); - if (config_dtype != input_dtype_map_[io_name]) { - return TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INTERNAL, - ("data type " + io_dtype + " for input '" + io_name + - "' does not match expected of '" + - TRITONSERVER_DataTypeString(input_dtype_map_[io_name]) + "'" + - "' for model '" + Name() + "'") - .c_str()); - } + RETURN_ERROR_IF_TRUE( + config_dtype != input_dtype_map_[io_name], TRITONSERVER_ERROR_INTERNAL, + ("data type " + io_dtype + " for input '" + io_name + + "' does not match expected of '" + + TRITONSERVER_DataTypeString(input_dtype_map_[io_name]) + "'" + + "' for model '" + Name() + "'")); // Validate input shape matches expected from model - TfLiteIntArray* tflite_dims = interpreter->tensor(inputs[i])->dims; + const TfLiteIntArray* tflite_dims = + interpreter->tensor(inputs[i])->dims_signature; std::vector model_input_shape( tflite_dims->data, tflite_dims->data + tflite_dims->size); @@ -591,10 +755,10 @@ ModelState::ValidateModelConfig() RETURN_IF_ERROR(ParseShape(io, "dims", &config_input_shape)); } if (max_batch_size_ > 0) { - // if batching is supported, you tflite doesn't encode -1 as - // the dim like tf does, it's just a 1. So just insert a 1 as the - // batch dim for the config input shape to see if it lines up - config_input_shape.insert(config_input_shape.begin(), 1); + // if batching is supported, tflite encodes -1 as the signature dim like + // tf does. 
So just insert a -1 as the batch dim for the config input + // shape to see if it lines up + config_input_shape.insert(config_input_shape.begin(), -1); } if (config_input_shape != model_input_shape) { return TRITONSERVER_ErrorNew( @@ -620,41 +784,33 @@ ModelState::ValidateModelConfig() RETURN_IF_ERROR(io.MemberAsString("name", &io_name)); // Return an error if the output name within the model config DNE in model - if (output_index_map_.count(io_name) == 0) { - return TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_NOT_FOUND, - std::string( - "Model output: " + std::string(io_name) + - " is not a valid output name for '" + Name() + "'") - .c_str()); - } + RETURN_ERROR_IF_TRUE( + output_index_map_.count(io_name) == 0, TRITONSERVER_ERROR_NOT_FOUND, + std::string( + "Model output: " + std::string(io_name) + + " is not a valid output name for '" + Name() + "'")); // Validate data type std::string io_dtype; RETURN_IF_ERROR(io.MemberAsString("data_type", &io_dtype)); const auto pr = ModelConfigDataTypeToTFLiteType(io_dtype); - if (!pr.first) { - return TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INTERNAL, - ("unsupported datatype " + io_dtype + " for output '" + io_name + - "' for model '" + Name() + "'") - .c_str()); - } + RETURN_ERROR_IF_TRUE( + !pr.first, TRITONSERVER_ERROR_INTERNAL, + ("unsupported datatype " + io_dtype + " for output '" + io_name + + "' for model '" + Name() + "'")); // Validate datatype matches expected from model TRITONSERVER_DataType config_dtype = TRITONSERVER_StringToDataType(io_dtype.substr(strlen("TYPE_")).c_str()); - if (config_dtype != output_dtype_map_[io_name]) { - return TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INTERNAL, - ("data type " + io_dtype + " for output '" + io_name + - "' does not match expected of '" + - TRITONSERVER_DataTypeString(output_dtype_map_[io_name]) + "'" + - "' for model '" + Name() + "'") - .c_str()); - } + RETURN_ERROR_IF_TRUE( + config_dtype != output_dtype_map_[io_name], TRITONSERVER_ERROR_INTERNAL, + ("data type " + io_dtype + " for output '" + io_name + + "' does not match expected of '" + + TRITONSERVER_DataTypeString(output_dtype_map_[io_name]) + "'" + + "' for model '" + Name() + "'")); // Validate output shape matches expected from model - TfLiteIntArray* tflite_dims = interpreter->tensor(outputs[i])->dims; + const TfLiteIntArray* tflite_dims = + interpreter->tensor(outputs[i])->dims_signature; std::vector model_output_shape( tflite_dims->data, tflite_dims->data + tflite_dims->size); @@ -668,17 +824,15 @@ ModelState::ValidateModelConfig() RETURN_IF_ERROR(ParseShape(io, "dims", &config_output_shape)); } if (max_batch_size_ > 0) { - config_output_shape.insert(config_output_shape.begin(), 1); - } - if (config_output_shape != model_output_shape) { - return TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INTERNAL, - ("shape " + VectorToString(config_output_shape) + " for output '" + - io_name + "' does not match expected of '" + - VectorToString(model_output_shape) + "'" + "' for model '" + - Name() + "'") - .c_str()); + config_output_shape.insert(config_output_shape.begin(), -1); } + RETURN_ERROR_IF_TRUE( + config_output_shape != model_output_shape, + TRITONSERVER_ERROR_INTERNAL, + ("shape " + VectorToString(config_output_shape) + " for output '" + + io_name + "' does not match expected of '" + + VectorToString(model_output_shape) + "'" + "' for model '" + Name() + + "'")); } } @@ -698,19 +852,32 @@ ModelState::AutoCompleteConfig() return nullptr; // success } +void +ModelState::InitTensorPipe() +{ + context_ = std::make_shared(); + auto 
transportContext = tensorpipe::transport::shm::create(); + // Consider here also registering tcp transport if shm not avail + context_->registerTransport(0 /* priority */, "shm", transportContext); + // Register cma shm channel + auto cmaChannel = tensorpipe::channel::cma::create(); + context_->registerChannel(0 /* low priority */, "cma", cmaChannel); +} // // ModelInstanceState // // State associated with a model instance. An object of this class is // created and associated with each TRITONBACKEND_ModelInstance. +// This class acts as a manager for a subprocess which handles the actual tflite +// inference. // class ModelInstanceState : public BackendModelInstance { public: static TRITONSERVER_Error* Create( ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance, - ModelInstanceState** state); + const std::string& model_instance_name, ModelInstanceState** state); virtual ~ModelInstanceState(); // Get the state of the model that corresponds to this instance. @@ -723,49 +890,54 @@ class ModelInstanceState : public BackendModelInstance { private: ModelInstanceState( ModelState* model_state, - TRITONBACKEND_ModelInstance* triton_model_instance); - TRITONSERVER_Error* BuildInterpreter(); - void LogDelegation(const std::string& delegate_name); - void Execute( - std::vector* responses, - const uint32_t response_count); - void SetInputTensors( + TRITONBACKEND_ModelInstance* triton_model_instance, + const std::string& model_instance_name); + TRITONSERVER_Error* ConnectModelInstance(); + TRITONSERVER_Error* SendModel(); + TRITONSERVER_Error* LaunchModelInstance(); + bool ModelInstanceRunning(); + TRITONSERVER_Error* SetInputTensors( size_t total_batch_size, TRITONBACKEND_Request** requests, const uint32_t request_count, std::vector* responses, BackendInputCollector* collector, - std::vector* input_memories); - void ReadOutputTensors( + std::vector* input_memories, tensorpipe::Message* tp_msg); + TRITONSERVER_Error* Execute( + std::vector* responses, + const uint32_t response_count, tensorpipe::Message* tp_msg, + std::unordered_map>& inference_output); + TRITONSERVER_Error* ReadOutputTensors( size_t total_batch_size, TRITONBACKEND_Request** requests, const uint32_t request_count, - std::vector* responses); + std::vector* responses, + const std::unordered_map>& + inference_output); + // Pointer to the model state shared between instances ModelState* model_state_; - // The full path to the TFLite model file. 
- std::string model_path_; + // Name of the model instance used as a unique indenfier for this + // instance + const std::string model_instance_name_; - // The pointer to the tflite network - std::unique_ptr model_; + // Tensorpipe listener to establish connection with child process + std::shared_ptr listener_{nullptr}; - // The pointer to the tflite interpreter instance - std::unique_ptr interpreter_; + // Tensorpipe to send input tensors over + std::shared_ptr pipe_; - // State variable to register whether inference has been called at least once - bool first_inference_ = true; - -#ifdef PAPI_PROFILING_ENABLE - std::unique_ptr papi_profiler_ = MaybeCreatePapiProfiler(); -#endif // PAPI_PROFILING_ENABLE + // Process object for our backend model instance + reproc::process model_instance_process_; }; TRITONSERVER_Error* ModelInstanceState::Create( ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance, - ModelInstanceState** state) + const std::string& model_instance_name, ModelInstanceState** state) { try { - *state = new ModelInstanceState(model_state, triton_model_instance); + *state = new ModelInstanceState( + model_state, triton_model_instance, model_instance_name); } catch (const BackendModelInstanceException& ex) { RETURN_ERROR_IF_TRUE( @@ -778,170 +950,358 @@ ModelInstanceState::Create( } ModelInstanceState::ModelInstanceState( - ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance) + ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance, + const std::string& model_instance_name) : BackendModelInstance(model_state, triton_model_instance), - model_state_(model_state) + model_state_(model_state), model_instance_name_(model_instance_name) { - // Load the TFLite network - THROW_IF_BACKEND_INSTANCE_ERROR(model_state->LoadModel( - ArtifactFilename(), &model_path_, model_state->ModelConfig(), &model_)); - - // Build interpreter - THROW_IF_BACKEND_INSTANCE_ERROR(BuildInterpreter()); - -#ifdef PAPI_PROFILING_ENABLE - interpreter_->AddProfiler(papi_profiler_.get()); -#endif // PAPI_PROFILING_ENABLE + THROW_IF_BACKEND_INSTANCE_ERROR(LaunchModelInstance()); } ModelInstanceState::~ModelInstanceState() { - // Consider the function ReleaseNonPersistentMemory here for our interpreter - interpreter_.reset(); + // Cleanup tensorpipe and reproc process + pipe_->close(); + listener_->close(); + reproc::stop_actions stop = { + {reproc::stop::terminate, reproc::milliseconds(10000)}, + {reproc::stop::kill, reproc::milliseconds(2000)}, + {}}; + reproc::options options; + options.stop = stop; + std::error_code ec; + int status = 0; + std::tie(status, ec) = model_instance_process_.stop(options.stop); + if (ec) { + LOG_MESSAGE(TRITONSERVER_LOG_ERROR, "Failed to stop child process"); + } + + // Give back cpus to avail_cpus backend state object + std::vector& avail_cpus = + model_state_->backend_state_ + ->avail_cpus_[model_state_->local_numa_node_id_]; + avail_cpus.insert( + avail_cpus.begin(), + model_state_->backend_state_->used_cpus_[model_instance_name_].begin(), + model_state_->backend_state_->used_cpus_[model_instance_name_].end()); } TRITONSERVER_Error* -ModelInstanceState::BuildInterpreter() +ModelInstanceState::LaunchModelInstance() { - // Build the tflite interpreter - tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates resolver; - tflite::InterpreterBuilder builder(*model_, resolver); - builder(&interpreter_); - if (!interpreter_) { - return TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INTERNAL, - ("failed to build tflite 
interpreter for model " + Name()).c_str()); - } + // Start listening for child process to connect to shm channel + listener_ = model_state_->context_->listen({"shm://" + model_instance_name_}); + auto done = std::make_shared>(); + listener_->accept([&, this]( + const tensorpipe::Error& error, + std::shared_ptr pipe) { + // When the child process connects, we act here in this lambda function + if (error) { + LOG_MESSAGE( + TRITONSERVER_LOG_ERROR, + (std::string("Unexpected error when accepting incoming pipe: ") + + error.what()) + .c_str()); - // Tell interpreter to use max threads available to system - if (interpreter_->SetNumThreads(model_state_->tflite_num_threads_) != - kTfLiteOk) { - return TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INTERNAL, - ("failed to set number of threads for interpreter for model " + Name()) - .c_str()); + done->set_value(false); + return; + } + pipe_ = std::move(pipe); + done->set_value(true); + }); + + std::vector model_instance_args = { + std::string(model_state_->model_instance_location_) + "/model_instance", + std::string("shm://") + model_instance_name_}; + +#ifdef LIBNUMA_ENABLE + // Model instance will always be pinned to numa node set as local, it's + // the membinding we change + switch (model_state_->numa_alloc_policy_) { + case AllocationPolicy::LOCAL: + case AllocationPolicy::WEIGHT_REMOTE_RESULT_LOCAL: + // In the case of local result tensors (heap), membind to local numa node + model_instance_args.insert( + model_instance_args.begin(), + {"numactl", "--membind", + std::to_string(model_state_->local_numa_node_id_), "--cpunodebind", + std::to_string(model_state_->local_numa_node_id_)}); + break; + case AllocationPolicy::WEIGHT_LOCAL_RESULT_REMOTE: + case AllocationPolicy::REMOTE: + // In the case of remote result tensors (heap), membind to remote numa + // node + model_instance_args.insert( + model_instance_args.begin(), + {"numactl", "--membind", + std::to_string(model_state_->remote_numa_node_id_), "--cpunodebind", + std::to_string(model_state_->local_numa_node_id_)}); + break; + default: { + break; + } } +#endif // LIBNUMA_ENABLE -#ifdef ARMNN_DELEGATE_ENABLE - bool armnn_gpu_delegate_enabled = - model_state_->use_armnn_delegate_gpu_ && - Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU; - bool armnn_cpu_delegate_enabled = - model_state_->use_armnn_delegate_cpu_ && - Kind() == TRITONSERVER_INSTANCEGROUPKIND_CPU; - if (armnn_cpu_delegate_enabled || armnn_gpu_delegate_enabled) { - armnnDelegate::DelegateOptions armnn_delegate_options = - armnnDelegate::TfLiteArmnnDelegateOptionsDefault(); - - // Set backend prefs based on gpu or cpu selection - if (armnn_gpu_delegate_enabled) { - armnn_delegate_options.SetBackends( - {armnn::Compute::GpuAcc, armnn::Compute::CpuAcc}); - armnn_delegate_options.SetOptimizerOptions( - model_state_->armnn_optimizer_options_gpu_); - } else { - // Set backend pref to Neon ACL backend - armnn_delegate_options.SetBackends({armnn::Compute::CpuAcc}); - armnn_delegate_options.SetOptimizerOptions( - model_state_->armnn_optimizer_options_cpu_); - } + if (model_state_->pin_threads_) { + // CPUS affinity always set to local node + std::vector& avail_cpus = + model_state_->backend_state_ + ->avail_cpus_[model_state_->local_numa_node_id_]; - // Create ArmNN Delegate with options registered in model state - std::unique_ptr< - TfLiteDelegate, decltype(&armnnDelegate::TfLiteArmnnDelegateDelete)> - armnn_delegate( - armnnDelegate::TfLiteArmnnDelegateCreate(armnn_delegate_options), - armnnDelegate::TfLiteArmnnDelegateDelete); + 
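+    // Claim up to tflite_num_threads_ CPUs from the local NUMA node for this instance; the destructor returns them to the backend-wide pool.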
RETURN_ERROR_IF_TRUE( + avail_cpus.empty(), TRITONSERVER_ERROR_INTERNAL, + std::string("not enough cpus left in system to pin on.")); + + // Assign cpus with max assignment being all cpus if thread count > num + // cores + int end_idx = std::min( + static_cast(model_state_->tflite_num_threads_), + static_cast(avail_cpus.size())); + model_state_->backend_state_->used_cpus_[model_instance_name_] = + std::vector(avail_cpus.begin(), avail_cpus.begin() + end_idx); + avail_cpus.erase(avail_cpus.begin(), avail_cpus.begin() + end_idx); + } - // Instruct the Interpreter to use the armnnDelegate - if (interpreter_->ModifyGraphWithDelegate(std::move(armnn_delegate)) != - kTfLiteOk) { - return TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INTERNAL, - ("failed to use armnn delegate for model " + Name()).c_str()); - } - LogDelegation("armnn"); - } else if ( - model_state_->use_xnnpack_delegate_ && - Kind() == TRITONSERVER_INSTANCEGROUPKIND_CPU) { -#else - if (model_state_->use_xnnpack_delegate_ && - Kind() == TRITONSERVER_INSTANCEGROUPKIND_CPU) { -#endif // ARMNN_DELEGATE_ENABLE - // Create the XNNPack Delegate - TfLiteXNNPackDelegateOptions options = - TfLiteXNNPackDelegateOptionsDefault(); + // We have the model_instance process inherit the parent's standard streams + // so the it reads directly from the stdin and writes directly to the + // stdout/stderr triton. + reproc::options options; + options.redirect.out.type = reproc::redirect::type::parent; + options.redirect.err.type = reproc::redirect::type::parent; + options.env.behavior = reproc::env::extend; + + // For the child process to use Triton logging infra, we have to give it the + // location of the actual tritonserver.so lib, as the backend is just linked + // against a stub + std::string* tritonserver_lib_path; + dl_iterate_phdr( + [](struct dl_phdr_info* info, size_t size, void* data) -> int { + if (std::string(info->dlpi_name).find("tritonserver.so") != + std::string::npos) { + *(reinterpret_cast(data)) = + new std::string(info->dlpi_name); + return 1; + } + return 0; + }, + &tritonserver_lib_path); - options.num_threads = model_state_->num_threads_xnnpack_; + auto base_path = [](const std::string& str) -> std::string { + size_t found; + found = str.find_last_of("/\\"); + return str.substr(0, found); + }; - tflite::Interpreter::TfLiteDelegatePtr xnnpack_delegate( - TfLiteXNNPackDelegateCreate(&options), - [](TfLiteDelegate* xnnpack_delegate) { - TfLiteXNNPackDelegateDelete(xnnpack_delegate); - }); + std::unordered_map model_instance_env{ + {"LD_LIBRARY_PATH", base_path(*tritonserver_lib_path)}}; - // Instruct the Interpreter to use the xnnpack - if (interpreter_->ModifyGraphWithDelegate(std::move(xnnpack_delegate)) != - kTfLiteOk) { - return TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INTERNAL, - ("failed to use xnnpack delegate for model " + Name()).c_str()); - } - LogDelegation("xnnpack"); +#ifdef PAPI_PROFILING_ENABLE + if (!model_state_->papi_events_.empty()) { + model_instance_env.insert({"PAPI_EVENTS", model_state_->papi_events_}); + } + if (!model_state_->papi_uncore_events_.empty()) { + model_instance_env.insert( + {"PAPI_UNCORE_EVENTS", model_state_->papi_uncore_events_}); } +#endif // PAPI_PROFILING_ENABLE + + options.env.extra = model_instance_env; + + std::error_code ec = + model_instance_process_.start(model_instance_args, options); + + RETURN_ERROR_IF_TRUE( + ec == std::errc::no_such_file_or_directory, TRITONSERVER_ERROR_INTERNAL, + std::string( + "model_instance binary not found. 
Make sure it's available from the " + "PATH.")); + RETURN_ERROR_IF_TRUE( + ec, TRITONSERVER_ERROR_INTERNAL, + (std::string("Failed to launch model instance process: ") + + ec.message())); + + LOG_MESSAGE( + TRITONSERVER_LOG_INFO, + (std::string("Launched model instance: ") + model_instance_name_) + .c_str()); + + // If the process did not come up in time something has gone wrong + RETURN_ERROR_IF_TRUE( + done->get_future().wait_for(std::chrono::seconds(5)) == + std::future_status::timeout, + TRITONSERVER_ERROR_INTERNAL, + std::string( + "Model instance failed: process did not connect back to parent")); + + // Send the model across the wire to the instance + SendModel(); return nullptr; } -void -ModelInstanceState::LogDelegation(const std::string& delegate_name) +bool +ModelInstanceState::ModelInstanceRunning() { - std::unordered_set checked_node_ids; - uint32_t num_delegated_kernels = 0; - for (uint64_t i = 0; i < interpreter_->execution_plan().size(); i++) { - int32_t node_id = interpreter_->execution_plan()[i]; - if (checked_node_ids.find(node_id) != checked_node_ids.end()) { - continue; - } - const TfLiteNode& node = - interpreter_->node_and_registration(node_id)->first; + int events = 0; + std::error_code ec; + std::tie(events, ec) = model_instance_process_.poll( + reproc::event::exit, reproc::milliseconds(1000)); + return !ec && ((events & reproc::event::exit) != 0); +} - if (node.delegate != nullptr) { - num_delegated_kernels++; - checked_node_ids.insert(node_id); - } +TRITONSERVER_Error* +ModelInstanceState::SendModel() +{ + tensorpipe::Message tp_msg; + tp_msg.metadata = "model_load"; + + // Size the payloads vector + tp_msg.payloads.resize(OptimizerOption::COUNT + 1); + + // Place deserialized flatbuffer model in msg payload field + const tflite::Allocation* model_allocation = + model_state_->model_->allocation(); + tensorpipe::Message::Payload model_payload{ + .data = const_cast(model_allocation->base()), + .length = model_allocation->bytes(), + .metadata = std::string(model_instance_name_), + }; + tp_msg.payloads[OptimizerOption::COUNT] = model_payload; + + // Define a helper function for generating payloads for our options + auto gen_metadata = [](std::string s) { + tensorpipe::Message::Payload result{.metadata = s}; + return result; + }; + + // Add in model configuration data to message + tp_msg.payloads[OptimizerOption::TFLITE_NUM_THREADS] = + gen_metadata(std::to_string(model_state_->tflite_num_threads_)); + + // Add in numa config data to message + tp_msg.payloads[OptimizerOption::NUMA_ALLOC_POLICY] = + gen_metadata(AllocationPolicyToString(model_state_->numa_alloc_policy_)); + + tp_msg.payloads[OptimizerOption::NUMA_LOCAL_NODE_ID] = + gen_metadata(std::to_string(model_state_->local_numa_node_id_)); + + tp_msg.payloads[OptimizerOption::NUMA_REMOTE_NODE_ID] = + gen_metadata(std::to_string(model_state_->remote_numa_node_id_)); + + // Add in use xnnpack + std::string use_xnnpack = std::string("n"); + if (model_state_->use_xnnpack_delegate_ && + Kind() == TRITONSERVER_INSTANCEGROUPKIND_CPU) { + use_xnnpack = std::string("y"); } - bool fully_delegated = - (num_delegated_kernels == 1 && - interpreter_->execution_plan().size() == 1); + tp_msg.payloads[OptimizerOption::XNNPACK_ENABLE] = gen_metadata(use_xnnpack); - if (fully_delegated) { - LOG_MESSAGE( - TRITONSERVER_LOG_INFO, ("Applied " + delegate_name + - " delegate, and the model graph will be " - "completely executed by the delegate.") - .c_str()); - } else if (num_delegated_kernels > 0) { - LOG_MESSAGE( - 
TRITONSERVER_LOG_INFO, - ("Applied " + delegate_name + - " delegate, and the model graph will be paritally executed by the " - "delegate w/ " + - std::to_string(num_delegated_kernels) + " delegate kernels.") - .c_str()); - } else { - LOG_MESSAGE( - TRITONSERVER_LOG_INFO, ("Though " + delegate_name + - " delegate is applied, the model graph will " - "not be executed by the delegate.") - .c_str()); + // Add in xnnpack threads + tp_msg.payloads[OptimizerOption::XNNPACK_CPU_NUM_THREADS] = + gen_metadata(std::to_string(model_state_->num_threads_xnnpack_)); + +#ifdef ARMNN_DELEGATE_ENABLE + // Add in use armnn cpu + std::string use_armnn_cpu = std::string("n"); + if (model_state_->use_armnn_delegate_cpu_ && + Kind() == TRITONSERVER_INSTANCEGROUPKIND_CPU) { + use_armnn_cpu = std::string("y"); + } + tp_msg.payloads[OptimizerOption::ARMNN_CPU_ENABLE] = + gen_metadata(use_armnn_cpu); + + // Add in use armnn gpu + std::string use_armnn_gpu = std::string("n"); + if (model_state_->use_armnn_delegate_gpu_ && + Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) { + use_armnn_gpu = std::string("y"); } + tp_msg.payloads[OptimizerOption::ARMNN_GPU_ENABLE] = + gen_metadata(use_armnn_gpu); + + // Add in armnn threads + tp_msg.payloads[OptimizerOption::ARMNN_CPU_NUM_THREADS] = + gen_metadata(std::to_string(model_state_->armnn_cpu_num_threads_)); + + // Add in armnn cpu and gpu options + tp_msg.payloads[OptimizerOption::ARMNN_CPU_FAST_MATH_ENABLED] = + gen_metadata(model_state_->armnn_cpu_fast_math_enabled_); + + tp_msg.payloads[OptimizerOption::ARMNN_CPU_REDUCE_FP32_TO_FP16] = + gen_metadata(model_state_->armnn_cpu_reduce_fp32_to_fp16_); + + tp_msg.payloads[OptimizerOption::ARMNN_CPU_REDUCE_FP32_TO_BF16] = + gen_metadata(model_state_->armnn_cpu_reduce_fp32_to_bf16_); + + tp_msg.payloads[OptimizerOption::ARMNN_GPU_FAST_MATH_ENABLED] = + gen_metadata(model_state_->armnn_gpu_fast_math_enabled_); + + tp_msg.payloads[OptimizerOption::ARMNN_GPU_REDUCE_FP32_TO_BF16] = + gen_metadata(model_state_->armnn_gpu_reduce_fp32_to_bf16_); + + tp_msg.payloads[OptimizerOption::ARMNN_GPU_REDUCE_FP32_TO_FP16] = + gen_metadata(model_state_->armnn_gpu_reduce_fp32_to_fp16_); +#endif // ARMNN_DELEGATE_ENABLE + + if (model_state_->pin_threads_) { + // The rest of the remaining spots will go to what cpus to use for inference + for (auto& cpuid : + model_state_->backend_state_->used_cpus_[model_instance_name_]) { + tp_msg.payloads.push_back(gen_metadata(std::to_string(cpuid))); + } + } + + // Write the message + auto done = std::make_shared>(); + pipe_->write(tp_msg, [this, done](const tensorpipe::Error& error) { + // We now listen for a message to come back indicating the model load was + // successful + if (error) { + LOG_MESSAGE( + TRITONSERVER_LOG_ERROR, + ("Failed to model load message. 
Details:" + error.what()).c_str()); + done->set_value(false); + return; + } + pipe_->readDescriptor([this, done]( + const tensorpipe::Error& error, + tensorpipe::Descriptor descriptor) { + if (error) { + LOG_MESSAGE( + TRITONSERVER_LOG_ERROR, + (std::string("Unexpected error when reading from accepted pipe: ") + + error.what()) + .c_str()); + done->set_value(false); + return; + } + tensorpipe::Allocation allocation; + pipe_->read( + allocation, [descriptor, done](const tensorpipe::Error& error) { + done->set_value(descriptor.metadata == "success"); + }); + }); + }); + RETURN_ERROR_IF_TRUE( + done->get_future().wait_for(std::chrono::seconds(30)) == + std::future_status::timeout, + TRITONSERVER_ERROR_INTERNAL, + std::string("Model instance failed: process did not send model load " + "acknowledgement")); + return nullptr; } void ModelInstanceState::ProcessRequests( TRITONBACKEND_Request** requests, const uint32_t request_count) { + LOG_MESSAGE( + TRITONSERVER_LOG_VERBOSE, + (std::string("TRITONBACKEND_ModelExecute: Running ") + Name() + " with " + + std::to_string(request_count) + " requests") + .c_str()); + uint64_t exec_start_ns = 0; SET_TIMESTAMP(exec_start_ns); @@ -964,31 +1324,6 @@ ModelInstanceState::ProcessRequests( .c_str())); return; } - - if (max_batch_size > 0) { - // Retrieve the batch size from one of the inputs, if the model - // supports batching, the first dimension size is batch size - TRITONBACKEND_Input* input; - TRITONSERVER_Error* err = - TRITONBACKEND_RequestInputByIndex(requests[i], 0 /* index */, &input); - if (err == nullptr) { - const int64_t* shape; - err = TRITONBACKEND_InputProperties( - input, nullptr, nullptr, &shape, nullptr, nullptr, nullptr); - total_batch_size += shape[0]; - } - if (err != nullptr) { - RequestsRespondWithError(requests, request_count, err); - return; - } - } else { - total_batch_size += 1; - } - } - - // If there are no valid payloads then no need to run the inference. - if (total_batch_size == 0) { - return; } // Make sure the maximum batch size is not exceeded. The @@ -1022,8 +1357,9 @@ ModelInstanceState::ProcessRequests( // can skip them in the output tensors). 
std::vector responses; responses.reserve(request_count); + bool all_response_failed = false; - for (size_t i = 0; i < request_count; i++) { + for (size_t i = 0; i < request_count; ++i) { TRITONBACKEND_Response* response; auto err = TRITONBACKEND_ResponseNew(&response, requests[i]); if (err == nullptr) { @@ -1035,22 +1371,86 @@ ModelInstanceState::ProcessRequests( } } - std::vector input_memories; - BackendInputCollector collector( - requests, request_count, &responses, model_state_->TritonMemoryManager(), - false, nullptr); + for (size_t i = 0; i < request_count; i++) { + if (max_batch_size > 0) { + // Retrieve the batch size from one of the inputs, if the model + // supports batching, the first dimension size is batch size + TRITONBACKEND_Input* input; + TRITONSERVER_Error* err = + TRITONBACKEND_RequestInputByIndex(requests[i], 0 /* index */, &input); + if (err == nullptr) { + const int64_t* shape; + err = TRITONBACKEND_InputProperties( + input, nullptr, nullptr, &shape, nullptr, nullptr, nullptr); + total_batch_size += shape[0]; + } + if (err != nullptr) { + RESPOND_ALL_AND_SET_TRUE_IF_ERROR( + responses, request_count, all_response_failed, err); + } + } else { + total_batch_size += 1; + } + } - // Note here we are copying the triton input buffers to the tflite allocated - // buffers - SetInputTensors( - total_batch_size, requests, request_count, &responses, &collector, - &input_memories); + // If there are no valid payloads then no need to run the inference. + if (total_batch_size == 0) { + return; + } + + // Make sure the maximum batch size is not exceeded. The + // total_batch_size must be 1 for models that don't support batching + // (i.e. max_batch_size == 0). If max_batch_size is exceeded then + // scheduler has done something badly wrong so fail and release all + // requests. + if (!all_response_failed) { + if ((total_batch_size != 1) && + (total_batch_size > (size_t)max_batch_size)) { + RESPOND_ALL_AND_SET_TRUE_IF_ERROR( + responses, request_count, all_response_failed, + TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + std::string( + "batch size " + std::to_string(total_batch_size) + " for '" + + Name() + "', max allowed is " + + std::to_string(max_batch_size)) + .c_str())); + } + } + + // Here we allocate the space for the tensorpipe message that's used to + // communicate with our backend ModelInstance process + tensorpipe::Message tp_msg; + + // Here we allocate the space for the tensorpipe allocation that the result of + // the inference is written to upon success + std::unordered_map> inference_output; + + std::vector input_memories; + std::unique_ptr collector; + + if (!all_response_failed) { + collector.reset(new BackendInputCollector( + requests, request_count, &responses, + model_state_->TritonMemoryManager(), false, nullptr)); + // Note here we are copying the triton input buffers to the tflite allocated + // buffers + RESPOND_ALL_AND_SET_TRUE_IF_ERROR( + responses, request_count, all_response_failed, + SetInputTensors( + total_batch_size, requests, request_count, &responses, + collector.get(), &input_memories, &tp_msg)); + } uint64_t compute_start_ns = 0; SET_TIMESTAMP(compute_start_ns); // Run... 
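+  // Execute() writes the batched inputs over the pipe and blocks until the child process either returns the result tensors or reports a failure.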
- Execute(&responses, request_count); + if (!all_response_failed) { + RESPOND_ALL_AND_SET_TRUE_IF_ERROR( + responses, request_count, all_response_failed, + Execute(&responses, request_count, &tp_msg, inference_output)); + } uint64_t compute_end_ns = 0; SET_TIMESTAMP(compute_end_ns); @@ -1061,7 +1461,13 @@ ModelInstanceState::ProcessRequests( } input_memories.clear(); - ReadOutputTensors(total_batch_size, requests, request_count, &responses); + if (!all_response_failed) { + RESPOND_ALL_AND_SET_TRUE_IF_ERROR( + responses, request_count, all_response_failed, + ReadOutputTensors( + total_batch_size, requests, request_count, &responses, + inference_output)); + } uint64_t exec_end_ns = 0; SET_TIMESTAMP(exec_end_ns); @@ -1080,7 +1486,7 @@ ModelInstanceState::ProcessRequests( } // Report statistics for each request. - for (uint32_t r = 0; r < request_count; ++r) { + for (uint64_t r = 0; r < request_count; ++r) { auto& request = requests[r]; LOG_IF_ERROR( TRITONBACKEND_ModelInstanceReportStatistics( @@ -1094,73 +1500,56 @@ ModelInstanceState::ProcessRequests( "failed releasing request"); } - // Report the entire batch statistics. - LOG_IF_ERROR( - TRITONBACKEND_ModelInstanceReportBatchStatistics( - TritonModelInstance(), total_batch_size, exec_start_ns, - compute_start_ns, compute_end_ns, exec_end_ns), - "failed reporting batch request statistics"); -} - -void -ModelInstanceState::Execute( - std::vector* responses, - const uint32_t response_count) -{ - static TfLiteStatus status; - status = interpreter_->Invoke(); - if (status != kTfLiteOk) { - SendErrorForResponses( - responses, response_count, - TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INTERNAL, ("TFLite execute failure"))); + if (!all_response_failed) { + // Report the entire batch statistics. + LOG_IF_ERROR( + TRITONBACKEND_ModelInstanceReportBatchStatistics( + TritonModelInstance(), total_batch_size, exec_start_ns, + compute_start_ns, compute_end_ns, exec_end_ns), + "failed reporting batch request statistics"); } - first_inference_ = false; } -void +TRITONSERVER_Error* ModelInstanceState::SetInputTensors( size_t total_batch_size, TRITONBACKEND_Request** requests, const uint32_t request_count, std::vector* responses, BackendInputCollector* collector, - std::vector* input_memories) + std::vector* input_memories, tensorpipe::Message* tp_msg) { const int32_t max_batch_size = model_state_->MaxBatchSize(); - bool allocate_tensors = false; + + // Construct tensorpipe message + tp_msg->metadata = "model_input"; + tp_msg->tensors.resize(model_state_->input_index_map_.size()); // All requests must have equally-sized input tensors so use any // request as the representative for the input tensors. 
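+  // Each Triton input becomes one tensor in the "model_input" tensorpipe message; the tensor's metadata string carries the corresponding TFLite input tensor index.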
uint32_t input_count; - RESPOND_ALL_AND_RETURN_IF_ERROR( - responses, request_count, - TRITONBACKEND_RequestInputCount(requests[0], &input_count)); - for (uint32_t input_idx = 0; input_idx < input_count; input_idx++) { + RETURN_IF_ERROR(TRITONBACKEND_RequestInputCount(requests[0], &input_count)); + for (uint64_t input_idx = 0; input_idx < input_count; input_idx++) { TRITONBACKEND_Input* input; - RESPOND_ALL_AND_RETURN_IF_ERROR( - responses, request_count, + RETURN_IF_ERROR( TRITONBACKEND_RequestInputByIndex(requests[0], input_idx, &input)); const char* input_name; TRITONSERVER_DataType input_datatype; const int64_t* input_shape; uint32_t input_dims_count; - RESPOND_ALL_AND_RETURN_IF_ERROR( - responses, request_count, - TRITONBACKEND_InputProperties( - input, &input_name, &input_datatype, &input_shape, - &input_dims_count, nullptr, nullptr)); + uint64_t byte_size; + RETURN_IF_ERROR(TRITONBACKEND_InputProperties( + input, &input_name, &input_datatype, &input_shape, &input_dims_count, + &byte_size, nullptr)); // Return an error if the input name within the request DNE in model if (model_state_->input_index_map_.count(input_name) == 0) { - SendErrorForResponses( - responses, request_count, - TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_NOT_FOUND, - std::string( - "Model input: " + std::string(input_name) + - " is not a valid input name for '" + Name() + "'") - .c_str())); + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_NOT_FOUND, + std::string( + "Model input: " + std::string(input_name) + + " is not a valid input name for '" + Name() + "'") + .c_str()); } // The shape for the entire input patch, [total_batch_size, ...] @@ -1177,159 +1566,169 @@ ModelInstanceState::SetInputTensors( " is: " + std::to_string(total_batch_size) + "\n")) .c_str()); - // Get the batch input tensor shape and compare against the shape of the - // input tensor as is registered with the current interpreter. 
If the size - // is different from the last call, tell the interpreter to resize the - // input tensor and note that we are going to have to make another call to - // AllocateTensors below - std::vector batchn_tflite_size_vector( - begin(batchn_shape), end(batchn_shape)); - TfLiteIntArray* tflite_input_tensor_dims = - interpreter_->tensor(model_state_->input_index_map_[input_name])->dims; - std::vector tflite_input_shape( - tflite_input_tensor_dims->data, - (tflite_input_tensor_dims->data + tflite_input_tensor_dims->size)); - if (batchn_tflite_size_vector != tflite_input_shape) { - // Resize input tensors based on current total batch size - allocate_tensors = true; - LOG_MESSAGE( - TRITONSERVER_LOG_VERBOSE, - (std::string( - "resizing input " + std::string(input_name) + - " with total batch size: " + std::to_string(total_batch_size) + - "\n")) - .c_str()); - interpreter_->ResizeInputTensor( - model_state_->input_index_map_[input_name], - batchn_tflite_size_vector); - } - } - - // Once we have resized all input tensors in the loop above, - // now we can allocate the memory plan within the tflite runtime if - // necessary - if (allocate_tensors || first_inference_) { - if (interpreter_->AllocateTensors() != kTfLiteOk) { - SendErrorForResponses( - responses, request_count, - TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INTERNAL, - "TfLite interpreter failed to allocate tensor inputs")); - } - } - - // With the memory now allocated appropriately for all input tensors, we can - // call process tensor for each - for (uint32_t input_idx = 0; input_idx < input_count; input_idx++) { - TRITONBACKEND_Input* input; - RESPOND_ALL_AND_RETURN_IF_ERROR( - responses, request_count, - TRITONBACKEND_RequestInputByIndex(requests[0], input_idx, &input)); - - const char* input_name; - RESPOND_ALL_AND_RETURN_IF_ERROR( - responses, request_count, - TRITONBACKEND_InputProperties( - input, &input_name, nullptr, nullptr, nullptr, nullptr, nullptr)); + // We use the metadata string field to pass the input tensor index. + tp_msg->tensors[input_idx].metadata = + std::to_string(model_state_->input_index_map_[input_name]); // Even if running on MALI GPU, we use CPU memory std::vector> alloc_perference; alloc_perference = {{TRITONSERVER_MEMORY_CPU, 0}}; - const char* input_buffer; size_t batchn_byte_size; TRITONSERVER_MemoryType memory_type; int64_t memory_type_id; - TfLiteTensor* tflite_input_tensor = - interpreter_->tensor(model_state_->input_index_map_[input_name]); - char* tflite_input_buffer = tflite_input_tensor->data.raw; - - // Here we use ProcessTensor to copy the data from triton into the buffer - // allocated by the tflite interpreter. I don't believe the data copy can - // be avoided using the tflite runtime - RESPOND_ALL_AND_RETURN_IF_ERROR( - responses, request_count, - collector->ProcessTensor( - input_name, tflite_input_buffer, tflite_input_tensor->bytes, - alloc_perference, &input_buffer, &batchn_byte_size, &memory_type, - &memory_type_id)); + + // Here we use ProcessTensor to manage the input buffer for the tensor. In + // the overload of this function, the backend input collector manages the + // memory, as opposed to copying it into the destination buffer we could + // pass, `buffer`. 
At the end of this call, cpu_buffer will point to the + // contiguous memory for the potentially batched input tensors + tensorpipe::CpuBuffer cpu_buffer; + RETURN_IF_ERROR(collector->ProcessTensor( + input_name, nullptr, 0, alloc_perference, + const_cast(reinterpret_cast(&cpu_buffer.ptr)), + &batchn_byte_size, &memory_type, &memory_type_id)); + + // Set the space for the tensors for tensorpipe message + tp_msg->tensors[input_idx].length = static_cast(batchn_byte_size); + tp_msg->tensors[input_idx].buffer = cpu_buffer; } // Finalize Backend Input Collector... collector->Finalize(); + + return nullptr; } -void +TRITONSERVER_Error* +ModelInstanceState::Execute( + std::vector* responses, + const uint32_t response_count, tensorpipe::Message* tp_msg, + std::unordered_map>& inference_output) +{ + // Write tensor across pipe and wait for completion asynchronously + auto done = std::make_shared>(); + pipe_->write( + *tp_msg, + [this, &inference_output, &done](const tensorpipe::Error& error) { + if (error) { + LOG_MESSAGE( + TRITONSERVER_LOG_ERROR, + (std::string( + "Failed to send model_input request to server. Details: ") + + error.what()) + .c_str()); + done->set_value(false); + return; + } + // Read a response from the client with description of incoming + // result tensors so we can get ready to write the data + pipe_->readDescriptor([this, &inference_output, &done]( + const tensorpipe::Error& error, + tensorpipe::Descriptor descriptor) { + if (error) { + LOG_MESSAGE( + TRITONSERVER_LOG_ERROR, + (std::string( + "Unexpected error when reading descriptor from accepted " + "pipe. Details: ") + + error.what()) + .c_str()); + done->set_value(false); + return; + } + + tensorpipe::Allocation allocation; + + // If there was a problem running the inference we get that back in + // the message metadata + if (descriptor.metadata == "f") { + LOG_MESSAGE(TRITONSERVER_LOG_ERROR, "Failed to run inference"); + pipe_->read(allocation, [&done](const tensorpipe::Error& error) {}); + done->set_value(false); + return; + } + + // Create a cpu buffer instance and assign its buffer + // pointer to that of the tflite allocated buffer for our + // output tensor + allocation.tensors.resize(descriptor.tensors.size()); + for (uint64_t i = 0; i < descriptor.tensors.size(); ++i) { + inference_output[descriptor.tensors[i].metadata].resize( + descriptor.tensors[i].length); + + allocation.tensors[i].buffer = tensorpipe::CpuBuffer{ + .ptr = static_cast( + inference_output[descriptor.tensors[i].metadata].data())}; + } + + // Read the data from the client response into the tensor + // buffer assigned above + pipe_->read(allocation, [&done](const tensorpipe::Error& error) { + if (error) { + LOG_MESSAGE( + TRITONSERVER_LOG_ERROR, + (std::string( + "Unexpected error when reading data from accepted " + "pipe. 
Details: ") + + error.what()) + .c_str()); + done->set_value(false); + return; + } + done->set_value(true); + }); + }); + }); + + RETURN_ERROR_IF_FALSE( + done->get_future().get(), TRITONSERVER_ERROR_INTERNAL, + std::string("TFLite execute failure")); + return nullptr; +} + +TRITONSERVER_Error* ModelInstanceState::ReadOutputTensors( size_t total_batch_size, TRITONBACKEND_Request** requests, const uint32_t request_count, - std::vector* responses) + std::vector* responses, + const std::unordered_map>& inference_output) { BackendOutputResponder responder( requests, request_count, responses, model_state_->MaxBatchSize(), model_state_->TritonMemoryManager(), false, nullptr); - for (const auto& map_entry : model_state_->output_index_map_) { - std::string output_name = map_entry.first; - int tensor_index = map_entry.second; - - TfLiteTensor* tflite_output_tensor = interpreter_->tensor(tensor_index); - - // Verify output datatype matches datatype from model config - TRITONSERVER_DataType output_dtype = - ConvertTFLiteTypeToDataType(tflite_output_tensor->type); - TRITONSERVER_DataType config_datatype = - model_state_->output_dtype_map_[output_name]; - if (config_datatype != output_dtype) { - RESPOND_ALL_AND_RETURN_IF_ERROR( - responses, request_count, - TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INVALID_ARG, - (std::string("unexpected datatype TYPE_") + - TRITONSERVER_DataTypeString(output_dtype) + - " for inference output '" + output_name + "', expecting TYPE_" + - TRITONSERVER_DataTypeString(config_datatype)) - .c_str())); - } - - // Assign data pointer to head of data container for output tensor - const char* output_buffer = - static_cast(tflite_output_tensor->data.raw); - - // Set output shape - std::vector batchn_shape; - TfLiteIntArray* dims = tflite_output_tensor->dims; - for (int32_t i = 0; i < dims->size; i++) { - batchn_shape.push_back(dims->data[i]); + // Respond to each output individually + try { + for (const auto& map_entry : model_state_->output_index_map_) { + const std::string& output_name = map_entry.first; + model_state_->output_shape_map_[output_name][0] = total_batch_size; + + responder.ProcessTensor( + output_name, model_state_->output_dtype_map_[output_name], + model_state_->output_shape_map_[output_name], + inference_output.at(output_name).data(), TRITONSERVER_MEMORY_CPU, 0); } - - responder.ProcessTensor( - output_name, output_dtype, batchn_shape, output_buffer, - TRITONSERVER_MEMORY_CPU, 0); + } + catch (std::out_of_range& err) { + responder.Finalize(); + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_NOT_FOUND, "Failed to process output tensor"); } // Finalize and wait for any pending buffer copies. responder.Finalize(); + + return nullptr; } ///////////// extern "C" { -int32_t armnn_threads = INT_MAX; - TRITONSERVER_Error* TRITONBACKEND_Initialize(TRITONBACKEND_Backend* backend) { -#ifdef PAPI_PROFILING_ENABLE - // Init PAPI library - RETURN_ERROR_IF_FALSE( - PAPI_library_init(PAPI_VER_CURRENT) == PAPI_VER_CURRENT, - TRITONSERVER_ERROR_UNAVAILABLE, std::string("Failed to init PAPI lib")); - RETURN_ERROR_IF_FALSE( - PAPI_thread_init(pthread_self) == PAPI_OK, TRITONSERVER_ERROR_UNAVAILABLE, - std::string("Failed to init PAPI thread lib")); - // The backend configuration may contain information needed by the // backend, such a command-line arguments. 
   TRITONSERVER_Message* backend_config_message;
@@ -1347,29 +1746,38 @@ TRITONBACKEND_Initialize(TRITONBACKEND_Backend* backend)
     RETURN_IF_ERROR(backend_config.Parse(buffer, byte_size));
   }

   triton::common::TritonJson::Value cmdline;
+  std::vector<int> cpus_to_use;
   if (backend_config.Find("cmdline", &cmdline)) {
     triton::common::TritonJson::Value value;
     std::string value_str;
-    if (cmdline.Find("papi-events", &value)) {
+    if (cmdline.Find("cpus", &value)) {
       RETURN_IF_ERROR(value.AsString(&value_str));
       std::stringstream ss(value_str);
+      std::vector<int> range(2);
+      int i = 0;
       while (ss.good()) {
         std::string substr;
-        std::getline(ss, substr, ',');
-        // Validate counter is a valid papi counter
-        RETURN_ERROR_IF_FALSE(
-            PAPIEventValid(substr), TRITONSERVER_ERROR_INVALID_ARG,
-            std::string("PAPI event '") + substr +
-                "' is requested but invalid");
+        std::getline(ss, substr, '-');
+        // Get range of cpu values
+        range[i++] = std::stoi(substr);
       }
-      // Set environment for papi to do high level op profiling
-      RETURN_ERROR_IF_TRUE(
-          setenv("PAPI_EVENTS", value_str.c_str(), 1),
-          TRITONSERVER_ERROR_INVALID_ARG,
-          std::string("Could not set PAPI_EVENTS env variable"));
+      cpus_to_use.resize(range[1] - range[0] + 1);
+      std::iota(cpus_to_use.begin(), cpus_to_use.end(), range[0]);
     }
   }
-#endif  // PAPI_PROFILING_ENABLE
+
+  // If we have any global backend state we create and set it here
+  try {
+    ArmNNTFLiteBackendState* state = new ArmNNTFLiteBackendState(cpus_to_use);
+    RETURN_IF_ERROR(
+        TRITONBACKEND_BackendSetState(backend, reinterpret_cast<void*>(state)));
+  }
+  catch (const BackendModelException& ex) {
+    RETURN_ERROR_IF_TRUE(
+        ex.err_ == nullptr, TRITONSERVER_ERROR_INTERNAL,
+        std::string("unexpected nullptr in BackendModelException"));
+    RETURN_IF_ERROR(ex.err_);
+  }

   const char* cname;
   RETURN_IF_ERROR(TRITONBACKEND_BackendName(backend, &cname));
@@ -1433,7 +1841,7 @@ TRITONBACKEND_ModelInitialize(TRITONBACKEND_Model* model)
   // Create a ModelState object and associate it with the
   // TRITONBACKEND_Model.
   ModelState* model_state;
-  RETURN_IF_ERROR(ModelState::Create(model, &model_state, &armnn_threads));
+  RETURN_IF_ERROR(ModelState::Create(model, &model_state));
   RETURN_IF_ERROR(
       TRITONBACKEND_ModelSetState(model, reinterpret_cast<void*>(model_state)));
@@ -1489,7 +1897,7 @@ TRITONBACKEND_ModelInstanceInitialize(TRITONBACKEND_ModelInstance* instance)
   // TRITONBACKEND_ModelInstance.
   ModelInstanceState* instance_state;
   RETURN_IF_ERROR(
-      ModelInstanceState::Create(model_state, instance, &instance_state));
+      ModelInstanceState::Create(model_state, instance, name, &instance_state));
   RETURN_IF_ERROR(TRITONBACKEND_ModelInstanceSetState(
       instance, reinterpret_cast<void*>(instance_state)));
diff --git a/src/tflite_utils.cc b/src/tflite_utils.cc
index 5be2fbd..b2ab3b8 100644
--- a/src/tflite_utils.cc
+++ b/src/tflite_utils.cc
@@ -1,8 +1,11 @@
+//
+// Copyright © 2023 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
 #include "tflite_utils.h"

-#ifdef PAPI_PROFILING_ENABLE
-#include <papi.h>
-#endif  // PAPI_PROFILING_ENABLE
+#include

 namespace triton { namespace backend { namespace tensorflowlite {

@@ -115,37 +118,47 @@ ModelConfigDataTypeToTFLiteType(const std::string& data_type_str)
   return std::make_pair(true, type);
 }

-#ifdef PAPI_PROFILING_ENABLE
-bool
-PAPIEventValid(std::string& event_name)
+std::vector<int>
+StringToIntVector(std::string const& s)
 {
-  int event_set = PAPI_NULL;
-  bool valid = false;
-  if (PAPI_create_eventset(&event_set) == PAPI_OK) {
-    valid = PAPI_add_named_event(event_set, event_name.c_str()) == PAPI_OK;
-    if (valid) {
-      if (PAPI_cleanup_eventset(event_set) != PAPI_OK) {
-        LOG_MESSAGE(
-            TRITONSERVER_LOG_WARN,
-            (std::string(
-                 "Call to cleanup event_set failed when trying to check "
-                 "event ") +
-             event_name)
-                .c_str());
-      }
-    }
-    if (PAPI_destroy_eventset(&event_set) != PAPI_OK) {
-      LOG_MESSAGE(
-          TRITONSERVER_LOG_WARN,
-          (std::string("Call to destroy event_set failed when trying to check "
-                       "event ") +
-           event_name)
-              .c_str());
-    }
+  std::stringstream iss(s);
+
+  int val;
+  std::vector<int> result;
+  while (iss >> val) {
+    result.push_back(val);
   }
-  return valid;
+  return result;
 }
-#endif  // PAPI_PROFILING_ENABLE
+void
+PopulateCpusMap(std::unordered_map<int, std::vector<int>>& cpus)
+{
+  hwloc_topology_t topology;
+  hwloc_topology_init(&topology);
+  hwloc_topology_load(topology);
+  int num_phys_cpus = hwloc_get_nbobjs_by_type(topology, HWLOC_OBJ_CORE);
+  for (int i = 0; i < num_phys_cpus; ++i) {
+    hwloc_obj_t core = hwloc_get_obj_by_type(topology, HWLOC_OBJ_CORE, i);
+    if (core) {
+      hwloc_bitmap_t nodeset = core->nodeset;
+      for (unsigned int j = 0; j < core->arity; ++j) {
+        unsigned int cpu_id = core->children[j]->os_index;
+        // First insert first thread of cpu near front of list, then push all
+        // its children back
+        if (j == 0) {
+          cpus[hwloc_bitmap_first(nodeset)].insert(
+              cpus[hwloc_bitmap_first(nodeset)].begin() +
+                  cpus[hwloc_bitmap_first(nodeset)].size() / core->arity,
+              cpu_id);
+        } else {
+          cpus[hwloc_bitmap_first(nodeset)].push_back(cpu_id);
+        }
+      }
+    }
+  }
+
+  hwloc_topology_destroy(topology);
+}
 }}}  // namespace triton::backend::tensorflowlite
diff --git a/src/tflite_utils.h b/src/tflite_utils.h
index 9438937..e227539 100644
--- a/src/tflite_utils.h
+++ b/src/tflite_utils.h
@@ -6,6 +6,7 @@
 #include
+#include "hwloc.h"
 #include "tensorflow/lite/model.h"
 #include "triton/backend/backend_model.h"
 #include "triton/core/tritonserver.h"

@@ -27,6 +28,8 @@ std::pair<bool, TfLiteType> ConvertDataTypeToTFLiteType(
 std::pair<bool, TfLiteType> ModelConfigDataTypeToTFLiteType(
     const std::string& data_type_str);

+std::vector<int> StringToIntVector(std::string const& s);
+
 template <typename T>
 std::string
@@ -34,15 +37,13 @@ VectorToString(std::vector<T> const& v)
 {
   std::stringstream ss;
   for (size_t i = 0; i < v.size(); i++) {
     if (i != 0) {
-      ss << ", ";
+      ss << ",";
     }
     ss << v[i];
   }
   return ss.str();
 }

-#ifdef PAPI_PROFILING_ENABLE
-bool PAPIEventValid(std::string& event_name);
-#endif  // PAPI_PROFILING_ENABLE
+void PopulateCpusMap(std::unordered_map<int, std::vector<int>>&);

 }}}  // namespace triton::backend::tensorflowlite
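The new `cpus` cmdline option handled in TRITONBACKEND_Initialize above is parsed as a single `low-high` range: the string is split on `-` and the bounds are expanded into explicit CPU ids with std::iota. Below is a minimal standalone sketch of that parsing; the ParseCpuRange name and the guard on the token count are illustrative additions, not backend code.

#include <iostream>
#include <numeric>
#include <sstream>
#include <string>
#include <vector>

// Parse a "low-high" CPU range (e.g. "0-3") into an explicit list of ids,
// mirroring the backend's cmdline handling. Illustration only, no validation.
std::vector<int> ParseCpuRange(const std::string& value_str)
{
  std::stringstream ss(value_str);
  std::vector<int> range(2);
  int i = 0;
  while (ss.good() && i < 2) {
    std::string substr;
    std::getline(ss, substr, '-');
    range[i++] = std::stoi(substr);
  }
  std::vector<int> cpus_to_use(range[1] - range[0] + 1);
  std::iota(cpus_to_use.begin(), cpus_to_use.end(), range[0]);
  return cpus_to_use;
}

int main()
{
  for (int cpu : ParseCpuRange("0-3")) {
    std::cout << cpu << " ";  // prints: 0 1 2 3
  }
  std::cout << std::endl;
  return 0;
}

A value of "0-3" therefore ends up as the list {0, 1, 2, 3} that is handed to the backend state constructor in the diff.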
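PopulateCpusMap in tflite_utils.cc walks the hwloc topology and groups hardware threads (PUs) under the NUMA node of their physical core, keeping the first SMT sibling of each core toward the front of that node's list. The following self-contained sketch uses the same hwloc calls that appear in the diff to print that grouping; it is only a reading aid and assumes the hwloc headers and library are installed.

#include <cstdio>
#include <hwloc.h>

// Enumerate physical cores and the hardware threads (PUs) they contain,
// using the same hwloc calls as PopulateCpusMap.
int main()
{
  hwloc_topology_t topology;
  hwloc_topology_init(&topology);
  hwloc_topology_load(topology);

  int num_cores = hwloc_get_nbobjs_by_type(topology, HWLOC_OBJ_CORE);
  for (int i = 0; i < num_cores; ++i) {
    hwloc_obj_t core = hwloc_get_obj_by_type(topology, HWLOC_OBJ_CORE, i);
    if (core == nullptr) {
      continue;
    }
    // The core's NUMA node (first bit of its nodeset) keys the group,
    // just as it keys the unordered_map in PopulateCpusMap.
    std::printf("core %d (node %d):", i, hwloc_bitmap_first(core->nodeset));
    for (unsigned int j = 0; j < core->arity; ++j) {
      std::printf(" pu %u", core->children[j]->os_index);
    }
    std::printf("\n");
  }

  hwloc_topology_destroy(topology);
  return 0;
}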
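ModelInstanceState::Execute above turns the callback-chained tensorpipe write → readDescriptor → read sequence into a blocking call by completing a shared std::promise<bool> from the innermost callback and waiting on its future. Below is a reduced sketch of just that synchronization pattern; RunAsync is a hypothetical stand-in for the asynchronous transport, not a tensorpipe API.

#include <functional>
#include <future>
#include <iostream>
#include <memory>
#include <thread>

// Hypothetical asynchronous call: invokes the callback from another thread.
std::thread RunAsync(std::function<void(bool)> callback)
{
  return std::thread([callback]() { callback(true); });
}

// Block the caller until the async work signals completion, mirroring how
// Execute waits on done->get_future().get() in the diff.
bool ExecuteBlocking(std::thread* worker)
{
  auto done = std::make_shared<std::promise<bool>>();
  *worker = RunAsync([done](bool ok) { done->set_value(ok); });
  return done->get_future().get();
}

int main()
{
  std::thread worker;
  bool ok = ExecuteBlocking(&worker);
  std::cout << (ok ? "success" : "failure") << std::endl;
  worker.join();
  return 0;
}

Sharing the promise through a shared_ptr keeps it alive for whichever callback in the chain ends up fulfilling it, which is why the diff captures `done` by reference to a shared_ptr rather than the promise itself.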