diff --git a/CMakeLists.txt b/CMakeLists.txt index 130455d7b..e43a8c2c1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -415,3 +415,7 @@ export( ) export(PACKAGE TritonCore) + +if(NOT TRITON_CORE_HEADERS_ONLY) + add_subdirectory(python python) +endif() diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt new file mode 100644 index 000000000..71eeeaba4 --- /dev/null +++ b/python/CMakeLists.txt @@ -0,0 +1,69 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +cmake_minimum_required(VERSION 3.18) + +add_subdirectory(tritonserver) + +file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/TRITON_VERSION ${TRITON_VERSION}) +configure_file(../LICENSE LICENSE.txt COPYONLY) +configure_file(setup.py setup.py @ONLY) +file(COPY test/test_binding.py DESTINATION ./test/.) + +set(WHEEL_DEPENDS + ${CMAKE_CURRENT_BINARY_DIR}/TRITON_VERSION + ${CMAKE_CURRENT_BINARY_DIR}/LICENSE.txt + ${CMAKE_CURRENT_BINARY_DIR}/setup.py + ${CMAKE_CURRENT_BINARY_DIR}/tritonserver + python-bindings +) + +set(wheel_stamp_file "stamp.whl") + +add_custom_command( + OUTPUT "${wheel_stamp_file}" + COMMAND python3 + ARGS + "${CMAKE_CURRENT_SOURCE_DIR}/build_wheel.py" + --dest-dir "${CMAKE_CURRENT_BINARY_DIR}/generic" + --binding-path $ + DEPENDS ${WHEEL_DEPENDS} +) + +add_custom_target( + generic-server-wheel ALL + DEPENDS + "${wheel_stamp_file}" +) + +install( + CODE "file(GLOB _Wheel \"${CMAKE_CURRENT_BINARY_DIR}/generic/triton*.whl\")" + CODE "file(INSTALL \${_Wheel} DESTINATION \"${CMAKE_INSTALL_PREFIX}/python\")" +) + +# Test +install( + CODE "file(INSTALL ${CMAKE_CURRENT_BINARY_DIR}/test/test_binding.py DESTINATION \"${CMAKE_INSTALL_PREFIX}/python\")" +) diff --git a/python/build_wheel.py b/python/build_wheel.py new file mode 100755 index 000000000..45885479e --- /dev/null +++ b/python/build_wheel.py @@ -0,0 +1,124 @@ +#!/usr/bin/env python3 +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +import os +import pathlib +import re +import shutil +import subprocess +import sys +from distutils.dir_util import copy_tree +from tempfile import mkstemp + + +def fail_if(p, msg): + if p: + print("error: {}".format(msg), file=sys.stderr) + sys.exit(1) + + +def mkdir(path): + pathlib.Path(path).mkdir(parents=True, exist_ok=True) + + +def touch(path): + pathlib.Path(path).touch() + + +def cpdir(src, dest): + copy_tree(src, dest, preserve_symlinks=1) + + +def sed(pattern, replace, source, dest=None): + fin = open(source, "r") + if dest: + fout = open(dest, "w") + else: + fd, name = mkstemp() + fout = open(name, "w") + + for line in fin: + out = re.sub(pattern, replace, line) + fout.write(out) + + fin.close() + fout.close() + if not dest: + shutil.copyfile(name, source) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--dest-dir", type=str, required=True, help="Destination directory." + ) + parser.add_argument( + "--binding-path", type=str, required=True, help="Path to Triton Python binding." 
+ ) + + FLAGS = parser.parse_args() + + FLAGS.triton_version = None + with open("TRITON_VERSION", "r") as vfile: + FLAGS.triton_version = vfile.readline().strip() + + FLAGS.whl_dir = os.path.join(FLAGS.dest_dir, "wheel") + + print("=== Building in: {}".format(os.getcwd())) + print("=== Using builddir: {}".format(FLAGS.whl_dir)) + print("Adding package files") + + mkdir(os.path.join(FLAGS.whl_dir, "tritonserver")) + shutil.copy("tritonserver/__init__.py", os.path.join(FLAGS.whl_dir, "tritonserver")) + + cpdir("tritonserver/_c", os.path.join(FLAGS.whl_dir, "tritonserver", "_c")) + PYBIND_LIB = os.path.basename(FLAGS.binding_path) + shutil.copyfile( + FLAGS.binding_path, + os.path.join(FLAGS.whl_dir, "tritonserver", "_c", PYBIND_LIB), + ) + + shutil.copyfile("LICENSE.txt", os.path.join(FLAGS.whl_dir, "LICENSE.txt")) + shutil.copyfile("setup.py", os.path.join(FLAGS.whl_dir, "setup.py")) + + os.chdir(FLAGS.whl_dir) + print("=== Building wheel") + args = ["python3", "setup.py", "bdist_wheel"] + + wenv = os.environ.copy() + wenv["VERSION"] = FLAGS.triton_version + wenv["TRITON_PYBIND"] = PYBIND_LIB + p = subprocess.Popen(args, env=wenv) + p.wait() + fail_if(p.returncode != 0, "setup.py failed") + + cpdir("dist", FLAGS.dest_dir) + + print("=== Output wheel file is in: {}".format(FLAGS.dest_dir)) + touch(os.path.join(FLAGS.dest_dir, "stamp.whl")) diff --git a/python/setup.py b/python/setup.py new file mode 100755 index 000000000..0c32be89b --- /dev/null +++ b/python/setup.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import os +import sys +from itertools import chain + +from setuptools import find_packages, setup + +if "--plat-name" in sys.argv: + PLATFORM_FLAG = sys.argv[sys.argv.index("--plat-name") + 1] +else: + PLATFORM_FLAG = "any" + +if "VERSION" not in os.environ: + raise Exception("envvar VERSION must be specified") + +VERSION = os.environ["VERSION"] + +try: + from wheel.bdist_wheel import bdist_wheel as _bdist_wheel + + class bdist_wheel(_bdist_wheel): + def finalize_options(self): + _bdist_wheel.finalize_options(self) + self.root_is_pure = False + + def get_tag(self): + pyver, abi, plat = "py3", "none", PLATFORM_FLAG + return pyver, abi, plat + +except ImportError: + bdist_wheel = None + +this_directory = os.path.abspath(os.path.dirname(__file__)) + +data_files = [ + ("", ["LICENSE.txt"]), +] +platform_package_data = [os.environ["TRITON_PYBIND"]] + +setup( + name="tritonserver", + version=VERSION, + author="NVIDIA Inc.", + author_email="sw-dl-triton@nvidia.com", + description="Python API of the Triton In-Process Server", + license="BSD", + url="https://developer.nvidia.com/nvidia-triton-inference-server", + classifiers=[ + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "Intended Audience :: Information Technology", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Image Recognition", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development :: Libraries", + "Topic :: Utilities", + "License :: OSI Approved :: BSD License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Environment :: Console", + "Natural Language :: English", + "Operating System :: OS Independent", + ], + packages=find_packages(), + package_data={ + "": platform_package_data, + }, + zip_safe=False, + cmdclass={"bdist_wheel": bdist_wheel}, + data_files=data_files, +) diff --git a/python/test/test_binding.py b/python/test/test_binding.py new file mode 100644 index 000000000..8f084bec5 --- /dev/null +++ b/python/test/test_binding.py @@ -0,0 +1,1133 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import gc +import json +import os +import queue +import shutil +import unittest + +import numpy +from tritonserver import _c as triton_bindings + + +# Callback functions used in inference pipeline +# 'user_object' is a per-request counter of how many times the +# callback is invoked +def g_alloc_fn( + allocator, tensor_name, byte_size, memory_type, memory_type_id, user_object +): + if "alloc" not in user_object: + user_object["alloc"] = 0 + user_object["alloc"] += 1 + buffer = numpy.empty(byte_size, numpy.byte) + return (buffer.ctypes.data, buffer, triton_bindings.TRITONSERVER_MemoryType.CPU, 0) + + +def g_release_fn( + allocator, buffer, buffer_user_object, byte_size, memory_type, memory_type_id +): + # No-op, buffer ('buffer_user_object') will be garbage collected + # only sanity check that the objects are expected + if (not isinstance(buffer_user_object, numpy.ndarray)) or ( + buffer_user_object.ctypes.data != buffer + ): + raise Exception("Misaligned parameters in allocator release callback") + pass + + +def g_start_fn(allocator, user_object): + if "start" not in user_object: + user_object["start"] = 0 + user_object["start"] += 1 + pass + + +def g_query_fn( + allocator, user_object, tensor_name, byte_size, memory_type, memory_type_id +): + if "query" not in user_object: + user_object["query"] = 0 + user_object["query"] += 1 + return (triton_bindings.TRITONSERVER_MemoryType.CPU, 0) + + +def g_buffer_fn( + allocator, tensor_name, buffer_attribute, user_object, buffer_user_object +): + if "buffer" not in user_object: + user_object["buffer"] = 0 + user_object["buffer"] += 1 + buffer_attribute.memory_type = triton_bindings.TRITONSERVER_MemoryType.CPU + buffer_attribute.memory_type_id = 0 + buffer_attribute.byte_size = buffer_user_object.size + return buffer_attribute + + +def g_timestamp_fn(trace, activity, timestamp_ns, user_object): + if trace.id not in user_object: + user_object[trace.id] = [] + # not owning trace, so must read property out + trace_log = { + "id": trace.id, + "parent_id": trace.parent_id, + "model_name": trace.model_name, + "model_version": trace.model_version, + "request_id": trace.request_id, + "activity": activity, + "timestamp": timestamp_ns, + } + user_object[trace.id].append(trace_log) + + +def g_tensor_fn( + trace, + activity, + tensor_name, + data_type, + buffer, + byte_size, + shape, + memory_type, + memory_type_id, + user_object, +): + if trace.id not in user_object: + user_object[trace.id] = [] + + # not owning trace, so must read property out + trace_log = { + "id": trace.id, + "parent_id": trace.parent_id, + "model_name": trace.model_name, + "model_version": trace.model_version, + "request_id": trace.request_id, + "activity": activity, + "tensor": { + "name": tensor_name, + "data_type": data_type, + # skip 'buffer' + "byte_size": byte_size, + "shape": shape, + "memory_type": memory_type, + "memory_type_id": memory_type_id, + }, + } + user_object[trace.id].append(trace_log) + + +def g_trace_release_fn(trace, user_object): + # 
sanity check that 'trace' has been tracked, the object + # will be released on garbage collection + if trace.id not in user_object: + raise Exception("Releasing unseen trace") + user_object["signal_queue"].put("TRACE_RELEASED") + + +def g_response_fn(response, flags, user_object): + user_object.put((flags, response)) + + +def g_request_fn(request, flags, user_object): + if flags != 1: + raise Exception("Unexpected request release flag") + # counter of "inflight" requests + user_object.put(request) + + +# Python model file string to fastly deploy test model, depends on +# 'TRITONSERVER_Server' operates properly to load model with content passed +# through the load API. +g_python_addsub = b''' +import json +import numpy as np +import triton_python_backend_utils as pb_utils + + +class TritonPythonModel: + @staticmethod + def auto_complete_config(auto_complete_model_config): + input0 = {"name": "INPUT0", "data_type": "TYPE_FP32", "dims": [4]} + input1 = {"name": "INPUT1", "data_type": "TYPE_FP32", "dims": [4]} + output0 = {"name": "OUTPUT0", "data_type": "TYPE_FP32", "dims": [4]} + output1 = {"name": "OUTPUT1", "data_type": "TYPE_FP32", "dims": [4]} + + auto_complete_model_config.set_max_batch_size(0) + auto_complete_model_config.add_input(input0) + auto_complete_model_config.add_input(input1) + auto_complete_model_config.add_output(output0) + auto_complete_model_config.add_output(output1) + + # [WARNING] Specify specific dynamic batching field by knowing + # the implementation detail + auto_complete_model_config.set_dynamic_batching() + auto_complete_model_config._model_config["dynamic_batching"]["priority_levels"] = 20 + auto_complete_model_config._model_config["dynamic_batching"]["default_priority_level"] = 10 + + return auto_complete_model_config + + def initialize(self, args): + self.model_config = model_config = json.loads(args["model_config"]) + + output0_config = pb_utils.get_output_config_by_name(model_config, "OUTPUT0") + output1_config = pb_utils.get_output_config_by_name(model_config, "OUTPUT1") + + self.output0_dtype = pb_utils.triton_string_to_numpy( + output0_config["data_type"] + ) + self.output1_dtype = pb_utils.triton_string_to_numpy( + output1_config["data_type"] + ) + + def execute(self, requests): + """This function is called on inference request.""" + + output0_dtype = self.output0_dtype + output1_dtype = self.output1_dtype + + responses = [] + for request in requests: + in_0 = pb_utils.get_input_tensor_by_name(request, "INPUT0") + in_1 = pb_utils.get_input_tensor_by_name(request, "INPUT1") + out_0, out_1 = ( + in_0.as_numpy() + in_1.as_numpy(), + in_0.as_numpy() - in_1.as_numpy(), + ) + + out_tensor_0 = pb_utils.Tensor("OUTPUT0", out_0.astype(output0_dtype)) + out_tensor_1 = pb_utils.Tensor("OUTPUT1", out_1.astype(output1_dtype)) + responses.append(pb_utils.InferenceResponse([out_tensor_0, out_tensor_1])) + return responses +''' + + +# ======================================= Test cases =========================== +class BindingTest(unittest.TestCase): + def setUp(self): + self._test_model_repo = os.path.join(os.getcwd(), "binding_test_repo") + # clear model repository that may be created for testing. + if os.path.exists(self._test_model_repo): + shutil.rmtree(self._test_model_repo) + os.makedirs(self._test_model_repo) + self._model_name = "addsub" + self._version = "1" + self._file_name = "model.py" + + def tearDown(self): + gc.collect() + # clear model repository that may be created for testing. 
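For reference, the repository that _create_model_repository (below) builds, and that this tearDown removes, follows the standard Triton layout using the names defined in setUp; sketched here for illustration only:

binding_test_repo/
    addsub/
        1/
            model.py    (written from the g_python_addsub source above; no config.pbtxt, auto-complete fills in the config)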
+ if os.path.exists(self._test_model_repo): + shutil.rmtree(self._test_model_repo) + + # helper functions + def _to_pyobject(self, triton_message): + return json.loads(triton_message.serialize_to_json()) + + # prepare a model repository with "addsub" model + def _create_model_repository(self): + os.makedirs( + os.path.join(self._test_model_repo, self._model_name, self._version) + ) + with open( + os.path.join( + self._test_model_repo, self._model_name, self._version, self._file_name + ), + "wb", + ) as f: + f.write(g_python_addsub) + + # create a Triton instance with POLL mode on repository prepared by + # '_create_model_repository' + def _start_polling_server(self): + # prepare model repository + self._create_model_repository() + + options = triton_bindings.TRITONSERVER_ServerOptions() + options.set_model_repository_path(self._test_model_repo) + options.set_model_control_mode( + triton_bindings.TRITONSERVER_ModelControlMode.POLL + ) + # enable "auto-complete" to skip providing config.pbtxt + options.set_strict_model_config(False) + options.set_server_id("testing_server") + # [FIXME] Need to fix coupling of response and server + options.set_exit_timeout(5) + return triton_bindings.TRITONSERVER_Server(options) + + def _prepare_inference_request(self, server): + allocator = triton_bindings.TRITONSERVER_ResponseAllocator( + g_alloc_fn, g_release_fn, g_start_fn + ) + allocator.set_buffer_attributes_function(g_buffer_fn) + allocator.set_query_function(g_query_fn) + + request_counter = queue.Queue() + response_queue = queue.Queue() + allocator_counter = {} + request = triton_bindings.TRITONSERVER_InferenceRequest( + server, self._model_name, -1 + ) + request.id = "req_0" + request.set_release_callback(g_request_fn, request_counter) + request.set_response_callback( + allocator, allocator_counter, g_response_fn, response_queue + ) + + input = numpy.ones([4], dtype=numpy.float32) + input_buffer = input.ctypes.data + ba = triton_bindings.TRITONSERVER_BufferAttributes() + ba.memory_type = triton_bindings.TRITONSERVER_MemoryType.CPU + ba.memory_type_id = 0 + ba.byte_size = input.itemsize * input.size + + request.add_input( + "INPUT0", triton_bindings.TRITONSERVER_DataType.FP32, input.shape + ) + request.add_input( + "INPUT1", triton_bindings.TRITONSERVER_DataType.FP32, input.shape + ) + request.append_input_data_with_buffer_attributes("INPUT0", input_buffer, ba) + request.append_input_data_with_buffer_attributes("INPUT1", input_buffer, ba) + + return request, allocator, response_queue, request_counter + + def test_exceptions(self): + ex_list = [ + triton_bindings.UnknownError, + triton_bindings.InternalError, + triton_bindings.NotFoundError, + triton_bindings.InvalidArgumentError, + triton_bindings.UnavailableError, + triton_bindings.UnsupportedError, + triton_bindings.AlreadyExistsError, + ] + for ex_type in ex_list: + with self.assertRaises(triton_bindings.TritonError) as ctx: + raise ex_type("Error message") + self.assertTrue(isinstance(ctx.exception, ex_type)) + self.assertEqual(str(ctx.exception), "Error message") + + def test_data_type(self): + t_list = [ + (triton_bindings.TRITONSERVER_DataType.INVALID, "", 0), + (triton_bindings.TRITONSERVER_DataType.BOOL, "BOOL", 1), + (triton_bindings.TRITONSERVER_DataType.UINT8, "UINT8", 1), + (triton_bindings.TRITONSERVER_DataType.UINT16, "UINT16", 2), + (triton_bindings.TRITONSERVER_DataType.UINT32, "UINT32", 4), + (triton_bindings.TRITONSERVER_DataType.UINT64, "UINT64", 8), + (triton_bindings.TRITONSERVER_DataType.INT8, "INT8", 1), + 
(triton_bindings.TRITONSERVER_DataType.INT16, "INT16", 2), + (triton_bindings.TRITONSERVER_DataType.INT32, "INT32", 4), + (triton_bindings.TRITONSERVER_DataType.INT64, "INT64", 8), + (triton_bindings.TRITONSERVER_DataType.FP16, "FP16", 2), + (triton_bindings.TRITONSERVER_DataType.FP32, "FP32", 4), + (triton_bindings.TRITONSERVER_DataType.FP64, "FP64", 8), + (triton_bindings.TRITONSERVER_DataType.BYTES, "BYTES", 0), + (triton_bindings.TRITONSERVER_DataType.BF16, "BF16", 2), + ] + + for t, t_str, t_size in t_list: + self.assertEqual(triton_bindings.TRITONSERVER_DataTypeString(t), t_str) + self.assertEqual(triton_bindings.TRITONSERVER_StringToDataType(t_str), t) + self.assertEqual(triton_bindings.TRITONSERVER_DataTypeByteSize(t), t_size) + + def test_memory_type(self): + t_list = [ + (triton_bindings.TRITONSERVER_MemoryType.CPU, "CPU"), + (triton_bindings.TRITONSERVER_MemoryType.CPU_PINNED, "CPU_PINNED"), + (triton_bindings.TRITONSERVER_MemoryType.GPU, "GPU"), + ] + for t, t_str in t_list: + self.assertEqual(triton_bindings.TRITONSERVER_MemoryTypeString(t), t_str) + + def test_parameter_type(self): + t_list = [ + (triton_bindings.TRITONSERVER_ParameterType.STRING, "STRING"), + (triton_bindings.TRITONSERVER_ParameterType.INT, "INT"), + (triton_bindings.TRITONSERVER_ParameterType.BOOL, "BOOL"), + (triton_bindings.TRITONSERVER_ParameterType.BYTES, "BYTES"), + ] + for t, t_str in t_list: + self.assertEqual(triton_bindings.TRITONSERVER_ParameterTypeString(t), t_str) + + def test_parameter(self): + # C API doesn't provide additional API for parameter, can only test + # New/Delete unless we mock the implementation to expose more info. + str_param = triton_bindings.TRITONSERVER_Parameter("str_key", "str_value") + int_param = triton_bindings.TRITONSERVER_Parameter("int_key", 123) + bool_param = triton_bindings.TRITONSERVER_Parameter("bool_key", True) + # bytes parameter doesn't own the buffer + b = bytes("abc", "utf-8") + bytes_param = triton_bindings.TRITONSERVER_Parameter("bytes_key", b) + del str_param + del int_param + del bool_param + del bytes_param + gc.collect() + + def test_instance_kind(self): + t_list = [ + (triton_bindings.TRITONSERVER_InstanceGroupKind.AUTO, "AUTO"), + (triton_bindings.TRITONSERVER_InstanceGroupKind.CPU, "CPU"), + (triton_bindings.TRITONSERVER_InstanceGroupKind.GPU, "GPU"), + (triton_bindings.TRITONSERVER_InstanceGroupKind.MODEL, "MODEL"), + ] + for t, t_str in t_list: + self.assertEqual( + triton_bindings.TRITONSERVER_InstanceGroupKindString(t), t_str + ) + + def test_log(self): + # This test depends on 'TRITONSERVER_ServerOptions' operates properly + # to modify log settings. + + # Direct Triton to log message into a file so that the log may be + # retrieved on the Python side. Otherwise the log will be default + # on stderr and Python utils can not redirect the pipe on Triton side. 
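The same redirection works outside the test for ad-hoc debugging; a minimal self-contained sketch, assuming only a writable path of your choosing:

from tritonserver import _c as triton_bindings

options = triton_bindings.TRITONSERVER_ServerOptions()
options.set_log_file("/tmp/triton_binding.log")  # illustrative path; log settings are process-wide
options.set_log_info(True)
options.set_log_verbose(1)
triton_bindings.TRITONSERVER_LogMessage(
    triton_bindings.TRITONSERVER_LogLevel.INFO, "example.py", 1, "hello from the binding"
)

The test below follows this pattern and additionally asserts on the file contents and on the DEFAULT versus ISO8601 timestamp formats.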
+ log_file = "triton_binding_test_log_output.txt" + default_format_regex = r"[0-9][0-9][0-9][0-9] [0-9][0-9]:[0-9][0-9]:[0-9][0-9].[0-9][0-9][0-9][0-9][0-9][0-9]" + iso8601_format_regex = r"[0-9][0-9][0-9][0-9]-[0-9][0-9]-[0-9][0-9]T[0-9][0-9]:[0-9][0-9]:[0-9][0-9]Z" + try: + options = triton_bindings.TRITONSERVER_ServerOptions() + # Enable subset of log levels + options.set_log_file(log_file) + options.set_log_info(True) + options.set_log_warn(False) + options.set_log_error(True) + options.set_log_verbose(0) + options.set_log_format(triton_bindings.TRITONSERVER_LogFormat.DEFAULT) + for ll, enabled in [ + (triton_bindings.TRITONSERVER_LogLevel.INFO, True), + (triton_bindings.TRITONSERVER_LogLevel.WARN, False), + (triton_bindings.TRITONSERVER_LogLevel.ERROR, True), + (triton_bindings.TRITONSERVER_LogLevel.VERBOSE, False), + ]: + self.assertEqual(triton_bindings.TRITONSERVER_LogIsEnabled(ll), enabled) + # Write message to each of the log level + triton_bindings.TRITONSERVER_LogMessage( + triton_bindings.TRITONSERVER_LogLevel.INFO, + "filename", + 123, + "info_message", + ) + triton_bindings.TRITONSERVER_LogMessage( + triton_bindings.TRITONSERVER_LogLevel.WARN, + "filename", + 456, + "warn_message", + ) + triton_bindings.TRITONSERVER_LogMessage( + triton_bindings.TRITONSERVER_LogLevel.ERROR, + "filename", + 789, + "error_message", + ) + triton_bindings.TRITONSERVER_LogMessage( + triton_bindings.TRITONSERVER_LogLevel.VERBOSE, + "filename", + 147, + "verbose_message", + ) + with open(log_file, "r") as f: + log = f.read() + # Check level + self.assertRegex(log, r"filename:123.*info_message") + self.assertNotRegex(log, r"filename:456.*warn_message") + self.assertRegex(log, r"filename:789.*error_message") + self.assertNotRegex(log, r"filename:147.*verbose_message") + # Check format "MMDD hh:mm:ss.ssssss". + self.assertRegex(log, default_format_regex) + # sanity check that there is no log with other format "YYYY-MM-DDThh:mm:ssZ L" + self.assertNotRegex(log, iso8601_format_regex) + # Test different format + options.set_log_format(triton_bindings.TRITONSERVER_LogFormat.ISO8601) + triton_bindings.TRITONSERVER_LogMessage( + triton_bindings.TRITONSERVER_LogLevel.INFO, "fn", 258, "info_message" + ) + with open(log_file, "r") as f: + log = f.read() + self.assertRegex(log, r"fn:258.*info_message") + self.assertRegex(log, iso8601_format_regex) + finally: + # Must make sure the log settings are reset as the logger is unique + # within the process + options.set_log_file("") + options.set_log_info(False) + options.set_log_warn(False) + options.set_log_error(False) + options.set_log_verbose(0) + options.set_log_format(triton_bindings.TRITONSERVER_LogFormat.DEFAULT) + os.remove(log_file) + + def test_buffer_attributes(self): + expected_memory_type = triton_bindings.TRITONSERVER_MemoryType.CPU_PINNED + expected_memory_type_id = 4 + expected_byte_size = 1024 + buffer_attributes = triton_bindings.TRITONSERVER_BufferAttributes() + buffer_attributes.memory_type_id = expected_memory_type_id + self.assertEqual(buffer_attributes.memory_type_id, expected_memory_type_id) + buffer_attributes.memory_type = expected_memory_type + self.assertEqual(buffer_attributes.memory_type, expected_memory_type) + buffer_attributes.byte_size = expected_byte_size + self.assertEqual(buffer_attributes.byte_size, expected_byte_size) + # cuda_ipc_handle is supposed to be cudaIpcMemHandle_t, must initialize buffer + # of that size to avoid segfault. 
The handle getter/setter is different from other + # attributes that different pointers may be returned from the getter, but the byte + # content pointed by the pointer should be the same + import ctypes + from array import array + + handle_byte_size = 64 + mock_handle = array("b", [i for i in range(handle_byte_size)]) + buffer_attributes.cuda_ipc_handle = mock_handle.buffer_info()[0] + res_arr = (ctypes.c_char * handle_byte_size).from_address( + buffer_attributes.cuda_ipc_handle + ) + for i in range(handle_byte_size): + self.assertEqual(int.from_bytes(res_arr[i], "big"), mock_handle[i]) + + def test_allocator(self): + def alloc_fn( + allocator, tensor_name, byte_size, memory_type, memory_type_id, user_object + ): + return (123, None, triton_bindings.TRITONSERVER_MemoryType.GPU, 1) + + def release_fn( + allocator, + buffer, + buffer_user_object, + byte_size, + memory_type, + memory_type_id, + ): + pass + + def start_fn(allocator, user_object): + pass + + def query_fn( + allocator, user_object, tensor_name, byte_size, memory_type, memory_type_id + ): + return (triton_bindings.TRITONSERVER_MemoryType.GPU, 1) + + def buffer_fn( + allocator, tensor_name, buffer_attribute, user_object, buffer_user_object + ): + return buffer_attribute + + # allocator without start_fn + allocator = triton_bindings.TRITONSERVER_ResponseAllocator(alloc_fn, release_fn) + del allocator + gc.collect() + + # allocator with start_fn + allocator = triton_bindings.TRITONSERVER_ResponseAllocator( + alloc_fn, release_fn, start_fn + ) + allocator.set_buffer_attributes_function(buffer_fn) + allocator.set_query_function(query_fn) + + def test_message(self): + expected_dict = {"key_0": [1, 2, "3"], "key_1": {"nested_key": "nested_value"}} + message = triton_bindings.TRITONSERVER_Message(json.dumps(expected_dict)) + self.assertEqual(expected_dict, json.loads(message.serialize_to_json())) + + def test_metrics(self): + # This test depends on 'TRITONSERVER_Server' operates properly + # to access metrics. 
+ + # Create server in EXPLICIT mode so we don't need to ensure + # a model repository is proper repository + options = triton_bindings.TRITONSERVER_ServerOptions() + options.set_model_repository_path(self._test_model_repo) + options.set_model_control_mode( + triton_bindings.TRITONSERVER_ModelControlMode.EXPLICIT + ) + server = triton_bindings.TRITONSERVER_Server(options) + metrics = server.metrics() + # Check one of the metrics is reported + self.assertTrue( + "nv_cpu_memory_used_bytes" + in metrics.formatted(triton_bindings.TRITONSERVER_MetricFormat.PROMETHEUS) + ) + + def test_trace_enum(self): + t_list = [ + (triton_bindings.TRITONSERVER_InferenceTraceLevel.DISABLED, "DISABLED"), + (triton_bindings.TRITONSERVER_InferenceTraceLevel.MIN, "MIN"), + (triton_bindings.TRITONSERVER_InferenceTraceLevel.MAX, "MAX"), + (triton_bindings.TRITONSERVER_InferenceTraceLevel.TIMESTAMPS, "TIMESTAMPS"), + (triton_bindings.TRITONSERVER_InferenceTraceLevel.TENSORS, "TENSORS"), + ] + for t, t_str in t_list: + self.assertEqual( + triton_bindings.TRITONSERVER_InferenceTraceLevelString(t), t_str + ) + # bit-wise operation + level = int(triton_bindings.TRITONSERVER_InferenceTraceLevel.TIMESTAMPS) | int( + triton_bindings.TRITONSERVER_InferenceTraceLevel.TENSORS + ) + self.assertNotEqual( + level & int(triton_bindings.TRITONSERVER_InferenceTraceLevel.TIMESTAMPS), 0 + ) + self.assertNotEqual( + level & int(triton_bindings.TRITONSERVER_InferenceTraceLevel.TENSORS), 0 + ) + + t_list = [ + ( + triton_bindings.TRITONSERVER_InferenceTraceActivity.REQUEST_START, + "REQUEST_START", + ), + ( + triton_bindings.TRITONSERVER_InferenceTraceActivity.QUEUE_START, + "QUEUE_START", + ), + ( + triton_bindings.TRITONSERVER_InferenceTraceActivity.COMPUTE_START, + "COMPUTE_START", + ), + ( + triton_bindings.TRITONSERVER_InferenceTraceActivity.COMPUTE_INPUT_END, + "COMPUTE_INPUT_END", + ), + ( + triton_bindings.TRITONSERVER_InferenceTraceActivity.COMPUTE_OUTPUT_START, + "COMPUTE_OUTPUT_START", + ), + ( + triton_bindings.TRITONSERVER_InferenceTraceActivity.COMPUTE_END, + "COMPUTE_END", + ), + ( + triton_bindings.TRITONSERVER_InferenceTraceActivity.REQUEST_END, + "REQUEST_END", + ), + ( + triton_bindings.TRITONSERVER_InferenceTraceActivity.TENSOR_QUEUE_INPUT, + "TENSOR_QUEUE_INPUT", + ), + ( + triton_bindings.TRITONSERVER_InferenceTraceActivity.TENSOR_BACKEND_INPUT, + "TENSOR_BACKEND_INPUT", + ), + ( + triton_bindings.TRITONSERVER_InferenceTraceActivity.TENSOR_BACKEND_OUTPUT, + "TENSOR_BACKEND_OUTPUT", + ), + ] + for t, t_str in t_list: + self.assertEqual( + triton_bindings.TRITONSERVER_InferenceTraceActivityString(t), t_str + ) + + def test_trace(self): + # This test depends on 'test_infer_async' test to capture + # the trace + level = int(triton_bindings.TRITONSERVER_InferenceTraceLevel.TIMESTAMPS) | int( + triton_bindings.TRITONSERVER_InferenceTraceLevel.TENSORS + ) + trace_dict = {"signal_queue": queue.Queue()} + trace = triton_bindings.TRITONSERVER_InferenceTrace( + level, 123, g_timestamp_fn, g_tensor_fn, g_trace_release_fn, trace_dict + ) + # [FIXME] get a copy of trace id due to potential issue of 'trace' + # lifecycle + trace_id = trace.id + + # Send and wait for inference, not care about result. + server = self._start_polling_server() + ( + request, + allocator, + response_queue, + request_counter, + ) = self._prepare_inference_request(server) + server.infer_async(request, trace) + + # [FIXME] WAR due to trace lifecycle is tied to response in Triton core, + # trace reference should drop on response send.. 
+ res = response_queue.get(block=True, timeout=10) + del res + gc.collect() + + _ = trace_dict["signal_queue"].get(block=True, timeout=10) + + # check 'trace_dict' + self.assertTrue(trace_id in trace_dict) + + # check activity are logged correctly, + # value of 0 indicate it is timestamp trace, + # non-zero is tensor trace and the value is how many times this + # particular activity should be logged + expected_activities = { + # timestamp + triton_bindings.TRITONSERVER_InferenceTraceActivity.REQUEST_START: 0, + triton_bindings.TRITONSERVER_InferenceTraceActivity.QUEUE_START: 0, + triton_bindings.TRITONSERVER_InferenceTraceActivity.COMPUTE_START: 0, + triton_bindings.TRITONSERVER_InferenceTraceActivity.COMPUTE_INPUT_END: 0, + triton_bindings.TRITONSERVER_InferenceTraceActivity.COMPUTE_OUTPUT_START: 0, + triton_bindings.TRITONSERVER_InferenceTraceActivity.COMPUTE_END: 0, + triton_bindings.TRITONSERVER_InferenceTraceActivity.REQUEST_END: 0, + # not timestamp + triton_bindings.TRITONSERVER_InferenceTraceActivity.TENSOR_QUEUE_INPUT: 2, + # TENSOR_BACKEND_INPUT never get called with in Triton core + # triton_bindings.TRITONSERVER_InferenceTraceActivity.TENSOR_BACKEND_INPUT : 2, + triton_bindings.TRITONSERVER_InferenceTraceActivity.TENSOR_BACKEND_OUTPUT: 2, + } + for tl in trace_dict[trace_id]: + # basic check + self.assertEqual(tl["id"], trace_id) + self.assertEqual(tl["parent_id"], 123) + self.assertEqual(tl["model_name"], self._model_name) + self.assertEqual(tl["model_version"], 1) + self.assertEqual(tl["request_id"], "req_0") + self.assertTrue(tl["activity"] in expected_activities) + if expected_activities[tl["activity"]] == 0: + self.assertTrue("timestamp" in tl) + else: + self.assertTrue("tensor" in tl) + expected_activities[tl["activity"]] -= 1 + if expected_activities[tl["activity"]] == 0: + del expected_activities[tl["activity"]] + # check if dict is empty to ensure the activity are logged in correct + # amount. 
+ self.assertFalse(bool(expected_activities)) + request_counter.get() + + def test_options(self): + options = triton_bindings.TRITONSERVER_ServerOptions() + + # Generic + options.set_server_id("server_id") + options.set_min_supported_compute_capability(7.0) + options.set_exit_on_error(False) + options.set_strict_readiness(False) + options.set_exit_timeout(30) + + # Models + options.set_model_repository_path("model_repo_0") + options.set_model_repository_path("model_repo_1") + for m in [ + triton_bindings.TRITONSERVER_ModelControlMode.NONE, + triton_bindings.TRITONSERVER_ModelControlMode.POLL, + triton_bindings.TRITONSERVER_ModelControlMode.EXPLICIT, + ]: + options.set_model_control_mode(m) + options.set_startup_model("*") + options.set_strict_model_config(True) + options.set_model_load_thread_count(2) + options.set_model_namespacing(True) + # Only support Kind GPU for now + options.set_model_load_device_limit( + triton_bindings.TRITONSERVER_InstanceGroupKind.GPU, 0, 0.5 + ) + for k in [ + triton_bindings.TRITONSERVER_InstanceGroupKind.AUTO, + triton_bindings.TRITONSERVER_InstanceGroupKind.CPU, + triton_bindings.TRITONSERVER_InstanceGroupKind.MODEL, + ]: + with self.assertRaises(triton_bindings.TritonError) as context: + options.set_model_load_device_limit(k, 0, 0) + self.assertTrue("not supported" in str(context.exception)) + + # Backend + options.set_backend_directory("backend_dir_0") + options.set_backend_directory("backend_dir_1") + options.set_backend_config("backend_name", "setting", "value") + + # Rate limiter + for r in [ + triton_bindings.TRITONSERVER_RateLimitMode.OFF, + triton_bindings.TRITONSERVER_RateLimitMode.EXEC_COUNT, + ]: + options.set_rate_limiter_mode(r) + options.add_rate_limiter_resource("shared_resource", 4, -1) + options.add_rate_limiter_resource("device_resource", 1, 0) + # memory pools + options.set_pinned_memory_pool_byte_size(1024) + options.set_cuda_memory_pool_byte_size(0, 2048) + # cache + options.set_response_cache_byte_size(4096) + options.set_cache_config( + "cache_name", json.dumps({"config_0": "value_0", "config_1": "value_1"}) + ) + options.set_cache_directory("cache_dir_0") + options.set_cache_directory("cache_dir_1") + # Log + try: + options.set_log_file("some_file") + options.set_log_info(True) + options.set_log_warn(True) + options.set_log_error(True) + options.set_log_verbose(2) + for f in [ + triton_bindings.TRITONSERVER_LogFormat.DEFAULT, + triton_bindings.TRITONSERVER_LogFormat.ISO8601, + ]: + options.set_log_format(f) + finally: + # Must make sure the log settings are reset as the logger is unique + # within the process + options.set_log_file("") + options.set_log_info(False) + options.set_log_warn(False) + options.set_log_error(False) + options.set_log_verbose(0) + options.set_log_format(triton_bindings.TRITONSERVER_LogFormat.DEFAULT) + + # Metrics + options.set_gpu_metrics(True) + options.set_cpu_metrics(True) + options.set_metrics_interval(5) + options.set_metrics_config("metrics_group", "setting", "value") + + # Misc.. 
+ with self.assertRaises(triton_bindings.TritonError) as context: + options.set_host_policy("policy_name", "setting", "value") + self.assertTrue("Unsupported host policy setting" in str(context.exception)) + options.set_repo_agent_directory("repo_agent_dir_0") + options.set_repo_agent_directory("repo_agent_dir_1") + options.set_buffer_manager_thread_count(4) + + def test_server(self): + server = self._start_polling_server() + # is_live + self.assertTrue(server.is_live()) + # is_ready + self.assertTrue(server.is_ready()) + # model_is_ready + self.assertTrue(server.model_is_ready(self._model_name, -1)) + # model_batch_properties + expected_batch_properties = ( + int(triton_bindings.TRITONSERVER_ModelBatchFlag.UNKNOWN), + 0, + ) + self.assertEqual( + server.model_batch_properties(self._model_name, -1), + expected_batch_properties, + ) + # model_transaction_properties + expected_transaction_policy = ( + int(triton_bindings.TRITONSERVER_ModelTxnPropertyFlag.ONE_TO_ONE), + 0, + ) + self.assertEqual( + server.model_transaction_properties(self._model_name, -1), + expected_transaction_policy, + ) + # metadata + server_meta_data = self._to_pyobject(server.metadata()) + self.assertTrue("name" in server_meta_data) + self.assertEqual(server_meta_data["name"], "testing_server") + # model_metadata + model_meta_data = self._to_pyobject(server.model_metadata(self._model_name, -1)) + self.assertTrue("name" in model_meta_data) + self.assertEqual(model_meta_data["name"], self._model_name) + # model_statistics + model_statistics = self._to_pyobject( + server.model_statistics(self._model_name, -1) + ) + self.assertTrue("model_stats" in model_statistics) + # model_config + model_config = self._to_pyobject(server.model_config(self._model_name, -1, 1)) + self.assertTrue("input" in model_config) + # model_index + model_index = self._to_pyobject(server.model_index(0)) + self.assertEqual(model_index[0]["name"], self._model_name) + # metrics (see test_metrics) + # infer_async (see test_infer_async) + + def test_request(self): + # This test depends on 'TRITONSERVER_Server' operates properly to initialize + # the request + server = self._start_polling_server() + + with self.assertRaises(triton_bindings.NotFoundError) as ctx: + _ = triton_bindings.TRITONSERVER_InferenceRequest( + server, "not_existing_model", -1 + ) + self.assertTrue("unknown model" in str(ctx.exception)) + + expected_request_id = "request" + expected_flags = int( + triton_bindings.TRITONSERVER_RequestFlag.SEQUENCE_START + ) | int(triton_bindings.TRITONSERVER_RequestFlag.SEQUENCE_END) + expected_correlation_id = 2 + expected_correlation_id_string = "123" + expected_priority = 19 + # larger value than model max priority level, + # will be set to default (10, see 'g_python_addsub' for config detail) + expected_priority_uint64 = 67 + expected_timeout_microseconds = 222 + + request = triton_bindings.TRITONSERVER_InferenceRequest(server, "addsub", -1) + + # request metadata + request.id = expected_request_id + self.assertEqual(request.id, expected_request_id) + request.flags = expected_flags + self.assertEqual(request.flags, expected_flags) + request.correlation_id = expected_correlation_id + self.assertEqual(request.correlation_id, expected_correlation_id) + request.correlation_id_string = expected_correlation_id_string + self.assertEqual(request.correlation_id_string, expected_correlation_id_string) + # Expect error from retrieving correlation id in a wrong type, + # wrap in lambda function to avoid early evaluation that raises + # exception before assert + 
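That is, the call below uses assertRaises(Err, lambda: request.correlation_id) rather than passing the attribute directly, because the attribute access itself is what raises. A small standalone illustration of the pattern, using a made-up property name:

import unittest

class LambdaAssertExample(unittest.TestCase):
    def test_property_raises(self):
        class Obj:
            @property
            def attr(self):  # hypothetical property that always raises
                raise ValueError("bad access")

        # Writing self.assertRaises(ValueError, Obj().attr) would raise before
        # assertRaises ever runs; deferring the access in a lambda lets it be caught.
        self.assertRaises(ValueError, lambda: Obj().attr)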
self.assertRaises(triton_bindings.TritonError, lambda: request.correlation_id) + request.priority = expected_priority + self.assertEqual(request.priority, expected_priority) + request.priority_uint64 = expected_priority_uint64 + self.assertEqual(request.priority_uint64, 10) + request.timeout_microseconds = expected_timeout_microseconds + self.assertEqual(request.timeout_microseconds, expected_timeout_microseconds) + + request.set_string_parameter("str_key", "str_val") + request.set_int_parameter("int_key", 567) + request.set_bool_parameter("bool_key", False) + + # I/O + input = numpy.ones([2, 3], dtype=numpy.float32) + buffer = input.ctypes.data + ba = triton_bindings.TRITONSERVER_BufferAttributes() + ba.memory_type = triton_bindings.TRITONSERVER_MemoryType.CPU + ba.memory_type_id = 0 + ba.byte_size = input.itemsize * input.size + + request.add_input( + "INPUT0", triton_bindings.TRITONSERVER_DataType.FP32, input.shape + ) + self.assertRaises(triton_bindings.TritonError, request.remove_input, "INPUT2") + # raw input assumes single input + self.assertRaises(triton_bindings.TritonError, request.add_raw_input, "INPUT1") + request.remove_input("INPUT0") + request.add_raw_input("INPUT1") + request.remove_all_inputs() + # all inputs are removed, all 'append' functions should raise exceptions + aid_args = ["INPUT0", buffer, ba.byte_size, ba.memory_type, ba.memory_type_id] + self.assertRaises( + triton_bindings.TritonError, request.append_input_data, *aid_args + ) + self.assertRaises( + triton_bindings.TritonError, + request.append_input_data_with_host_policy, + *aid_args, + "host_policy_name" + ) + self.assertRaises( + triton_bindings.TritonError, + request.append_input_data_with_buffer_attributes, + "INPUT0", + buffer, + ba, + ) + self.assertRaises( + triton_bindings.TritonError, request.remove_all_input_data, "INPUT0" + ) + # Add back input + request.add_input( + "INPUT0", triton_bindings.TRITONSERVER_DataType.FP32, input.shape + ) + request.append_input_data(*aid_args) + request.remove_all_input_data("INPUT0") + + request.add_requested_output("OUTPUT0") + request.remove_requested_output("OUTPUT1") + request.remove_all_requested_outputs() + + def test_infer_async(self): + # start server + server = self._start_polling_server() + + # prepare for infer + allocator = triton_bindings.TRITONSERVER_ResponseAllocator( + g_alloc_fn, g_release_fn, g_start_fn + ) + allocator.set_buffer_attributes_function(g_buffer_fn) + allocator.set_query_function(g_query_fn) + + request_counter = queue.Queue() + response_queue = queue.Queue() + allocator_counter = {} + request = triton_bindings.TRITONSERVER_InferenceRequest( + server, self._model_name, -1 + ) + request.id = "req_0" + request.set_release_callback(g_request_fn, request_counter) + request.set_response_callback( + allocator, allocator_counter, g_response_fn, response_queue + ) + + input = numpy.ones([4], dtype=numpy.float32) + input_buffer = input.ctypes.data + ba = triton_bindings.TRITONSERVER_BufferAttributes() + ba.memory_type = triton_bindings.TRITONSERVER_MemoryType.CPU + ba.memory_type_id = 0 + ba.byte_size = input.itemsize * input.size + + request.add_input( + "INPUT0", triton_bindings.TRITONSERVER_DataType.FP32, input.shape + ) + request.add_input( + "INPUT1", triton_bindings.TRITONSERVER_DataType.FP32, input.shape + ) + request.append_input_data_with_buffer_attributes("INPUT0", input_buffer, ba) + request.append_input_data_with_buffer_attributes("INPUT1", input_buffer, ba) + + # non-blocking, wait on response complete + server.infer_async(request) + 
+ # Expect every response to be returned in 10 seconds + flags, res = response_queue.get(block=True, timeout=10) + self.assertEqual( + flags, int(triton_bindings.TRITONSERVER_ResponseCompleteFlag.FINAL) + ) + # expect no error + res.throw_if_response_error() + # version will be actual model version + self.assertEqual(res.model, (self._model_name, 1)) + self.assertEqual(res.id, request.id) + self.assertEqual(res.parameter_count, 0) + # out of range access + self.assertRaises(triton_bindings.TritonError, res.parameter, 0) + + # read output tensor + self.assertEqual(res.output_count, 2) + for out, expected_name, expected_data in [ + (res.output(0), "OUTPUT0", input + input), + (res.output(1), "OUTPUT1", input - input), + ]: + ( + name, + data_type, + shape, + out_buffer, + byte_size, + memory_type, + memory_type_id, + numpy_buffer, + ) = out + self.assertEqual(name, expected_name) + self.assertEqual(data_type, triton_bindings.TRITONSERVER_DataType.FP32) + self.assertEqual(shape, expected_data.shape) + self.assertEqual(out_buffer, numpy_buffer.ctypes.data) + # buffer attribute used for input doesn't necessarily to + # match output buffer attributes, this is just knowing the detail. + self.assertEqual(byte_size, ba.byte_size) + self.assertEqual(memory_type, ba.memory_type) + self.assertEqual(memory_type_id, ba.memory_type_id) + self.assertTrue( + numpy.allclose( + numpy_buffer.view(dtype=expected_data.dtype).reshape(shape), + expected_data, + ) + ) + + # label (no label so empty) + self.assertEqual(len(res.output_classification_label(0, 1)), 0) + # [FIXME] keep alive behavior is not established between response + # and server, so must explicitly handle the destruction order for now. + del res + + # sanity check on user objects + self.assertEqual(allocator_counter["start"], 1) + self.assertEqual(allocator_counter["alloc"], 2) + # Knowing implementation detail that the backend doesn't use query API + self.assertTrue("query" not in allocator_counter) + self.assertEqual(allocator_counter["buffer"], 2) + # Expect request to be released in 10 seconds + request = request_counter.get(block=True, timeout=10) + + def test_server_explicit(self): + self._create_model_repository() + # explicit : load with params + options = triton_bindings.TRITONSERVER_ServerOptions() + options.set_model_repository_path(self._test_model_repo) + options.set_model_control_mode( + triton_bindings.TRITONSERVER_ModelControlMode.EXPLICIT + ) + options.set_strict_model_config(False) + server = triton_bindings.TRITONSERVER_Server(options) + load_file_params = [ + triton_bindings.TRITONSERVER_Parameter("config", r"{}"), + triton_bindings.TRITONSERVER_Parameter( + "file:" + os.path.join(self._version, self._file_name), g_python_addsub + ), + ] + server.load_model_with_parameters("wired_addsub", load_file_params) + self.assertTrue(server.model_is_ready("wired_addsub", -1)) + + # Model Repository + self.assertFalse(server.model_is_ready(self._model_name, -1)) + # unregister + server.unregister_model_repository(self._test_model_repo) + self.assertRaises( + triton_bindings.TritonError, server.load_model, self._model_name + ) + # register + server.register_model_repository(self._test_model_repo, []) + server.load_model(self._model_name) + self.assertTrue(server.model_is_ready(self._model_name, -1)) + + # unload + server.unload_model("wired_addsub") + self.assertFalse(server.model_is_ready("wired_addsub", -1)) + server.unload_model_and_dependents(self._model_name) + self.assertFalse(server.model_is_ready(self._model_name, -1)) + + def 
test_custom_metric(self): + options = triton_bindings.TRITONSERVER_ServerOptions() + options.set_model_repository_path(self._test_model_repo) + options.set_model_control_mode( + triton_bindings.TRITONSERVER_ModelControlMode.EXPLICIT + ) + server = triton_bindings.TRITONSERVER_Server(options) + + # create custom metric + mf = triton_bindings.TRITONSERVER_MetricFamily( + triton_bindings.TRITONSERVER_MetricKind.COUNTER, + "custom_metric_familiy", + "custom metric example", + ) + m = triton_bindings.TRITONSERVER_Metric(mf, []) + m.increment(2) + self.assertEqual(m.kind, triton_bindings.TRITONSERVER_MetricKind.COUNTER) + self.assertEqual(m.value, 2) + # can't use 'set_value' due to wrong kind + self.assertRaises(triton_bindings.TritonError, m.set_value, 5) + + # Check custom metric is reported + metrics = server.metrics() + self.assertTrue( + "custom_metric_familiy" + in metrics.formatted(triton_bindings.TRITONSERVER_MetricFormat.PROMETHEUS) + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/tritonserver/CMakeLists.txt b/python/tritonserver/CMakeLists.txt new file mode 100644 index 000000000..252849d38 --- /dev/null +++ b/python/tritonserver/CMakeLists.txt @@ -0,0 +1,61 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +cmake_minimum_required (VERSION 3.18) + +project(triton-bindings LANGUAGES C CXX) + +# Top level module entry point +file(COPY __init__.py DESTINATION .) +# Copy the '__init__.py' for the '_c' module +file(COPY _c/__init__.py DESTINATION ./_c/.) 
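These copies reproduce the tritonserver/ and tritonserver/_c/ package layout in the build tree. Once the triton_bindings extension built below is placed next to _c/__init__.py (which re-exports it), the module is importable as in the tests; a minimal sketch:

from tritonserver import _c as triton_bindings

# Enum/string helpers work without a running server instance.
print(triton_bindings.TRITONSERVER_DataTypeString(triton_bindings.TRITONSERVER_DataType.FP32))  # prints "FP32"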
+ +include(FetchContent) +FetchContent_Declare( + pybind11 + GIT_REPOSITORY "https://github.com/pybind/pybind11" + # COMMIT ID for v2.10.0 + GIT_TAG "aa304c9c7d725ffb9d10af08a3b34cb372307020" + GIT_SHALLOW ON +) +FetchContent_MakeAvailable(pybind11) +set( + PYTHON_BINDING_SRCS + _c/tritonserver_pybind.cc +) + +pybind11_add_module(python-bindings SHARED ${PYTHON_BINDING_SRCS}) +target_link_libraries( + python-bindings + PRIVATE + triton-core-serverapi # from repo-core + triton-core-serverstub # from repo-core +) +target_compile_features(python-bindings PRIVATE cxx_std_17) + +set_property(TARGET python-bindings PROPERTY OUTPUT_NAME triton_bindings) +# Add Triton library default path in 'rpath' for runtime library lookup +set_target_properties(python-bindings PROPERTIES BUILD_RPATH "$ORIGIN:/opt/tritonserver/lib") diff --git a/python/tritonserver/__init__.py b/python/tritonserver/__init__.py new file mode 100644 index 000000000..f3a6dc3f0 --- /dev/null +++ b/python/tritonserver/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/python/tritonserver/_c/__init__.py b/python/tritonserver/_c/__init__.py new file mode 100644 index 000000000..c8ec7c698 --- /dev/null +++ b/python/tritonserver/_c/__init__.py @@ -0,0 +1,26 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +from .triton_bindings import * diff --git a/python/tritonserver/_c/tritonserver_pybind.cc b/python/tritonserver/_c/tritonserver_pybind.cc new file mode 100644 index 000000000..6e0f39842 --- /dev/null +++ b/python/tritonserver/_c/tritonserver_pybind.cc @@ -0,0 +1,2107 @@ +// Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include +#include +#include +#include +#include + +#include + +// This binding is merely used to map Triton C API into Python equivalent, +// and therefore, the naming will be the same as the one used in corresponding +// sections. However, there are a few exceptions to better transit to Python: +// Structs: +// * Triton structs are encapsulated in a thin wrapper to isolate raw pointer +// operations which is not supported in pure Python. A thin 'PyWrapper' base +// class is defined with common utilities +// * Trivial getters and setters are grouped to be a Python class property. +// However, this creates asymmetry that some APIs are called like function +// while some like member variables. So I am open to expose getter / setter +// if it may be more intuitive. 
+// * The wrapper is only served as communication between Python and C, it will +// be unwrapped when control reaches C API and the C struct will be wrapped +// when control reaches Python side. Python binding user should respect the +// "ownership" and lifetime of the wrapper in the same way as described in +// the C API. Python binding user must not assume the same C struct will +// always be referred through the same wrapper object. +// Enums: +// * In C API, the enum values are prefixed by the enum name. The Python +// equivalent is an enum class and thus the prefix is removed to avoid +// duplication, i.e. Python user may specify a value by +// 'TRITONSERVER_ResponseCompleteFlag.FINAL'. +// Functions / Callbacks: +// * Output parameters are converted to return value. APIs that return an error +// will be thrown as an exception. The same applies to callbacks. +// ** Note that in the C API, the inference response may carry an error object +// that represent an inference failure. The equivalent Python API will raise +// the corresponding exception if the response contains error object. +// * The function parameters and return values are exposed in Python style, +// for example, object pointer becomes py::object, C array and length +// condenses into Python array. + +namespace py = pybind11; +namespace triton { namespace core { namespace python { + +// Macro used by PyWrapper +#define DISALLOW_COPY(TypeName) TypeName(const TypeName&) = delete; +#define DISALLOW_ASSIGN(TypeName) void operator=(const TypeName&) = delete; +#define DISALLOW_COPY_AND_ASSIGN(TypeName) \ + DISALLOW_COPY(TypeName) \ + DISALLOW_ASSIGN(TypeName) +#define DESTRUCTOR_WITH_LOG(TypeName, DeleteFunction) \ + ~TypeName() \ + { \ + if (owned_ && triton_object_) { \ + auto err__ = (DeleteFunction(triton_object_)); \ + if (err__) { \ + std::shared_ptr managed_err( \ + err__, TRITONSERVER_ErrorDelete); \ + py::print(TRITONSERVER_ErrorMessage(err__)); \ + } \ + }} +// base exception for all Triton error code +struct TritonError : public std::runtime_error { + explicit TritonError(const std::string& what) : std::runtime_error(what) {} +}; + +// triton::core::python exceptions map 1:1 to TRITONSERVER_Error_Code. +struct UnknownError : public TritonError { + explicit UnknownError(const std::string& what) : TritonError(what) {} +}; +struct InternalError : public TritonError { + explicit InternalError(const std::string& what) : TritonError(what) {} +}; +struct NotFoundError : public TritonError { + explicit NotFoundError(const std::string& what) : TritonError(what) {} +}; +struct InvalidArgumentError : public TritonError { + explicit InvalidArgumentError(const std::string& what) : TritonError(what) {} +}; +struct UnavailableError : public TritonError { + explicit UnavailableError(const std::string& what) : TritonError(what) {} +}; +struct UnsupportedError : public TritonError { + explicit UnsupportedError(const std::string& what) : TritonError(what) {} +}; +struct AlreadyExistsError : public TritonError { + explicit AlreadyExistsError(const std::string& what) : TritonError(what) {} +}; + +TRITONSERVER_Error* +CreateTRITONSERVER_ErrorFrom(const py::error_already_set& ex) +{ + // Reserved lookup to get Python type of the exceptions, + // 'TRITONSERVER_ERROR_UNKNOWN' is the fallback error code. 
+ // static auto uk = + // py::module::import("triton_bindings").attr("UnknownError"); + static auto it = py::module::import("triton_bindings").attr("InternalError"); + static auto nf = py::module::import("triton_bindings").attr("NotFoundError"); + static auto ia = + py::module::import("triton_bindings").attr("InvalidArgumentError"); + static auto ua = + py::module::import("triton_bindings").attr("UnavailableError"); + static auto us = + py::module::import("triton_bindings").attr("UnsupportedError"); + static auto ae = + py::module::import("triton_bindings").attr("AlreadyExistsError"); + TRITONSERVER_Error_Code code = TRITONSERVER_ERROR_UNKNOWN; + if (ex.matches(it.ptr())) { + code = TRITONSERVER_ERROR_INTERNAL; + } else if (ex.matches(nf.ptr())) { + code = TRITONSERVER_ERROR_NOT_FOUND; + } else if (ex.matches(ia.ptr())) { + code = TRITONSERVER_ERROR_INVALID_ARG; + } else if (ex.matches(ua.ptr())) { + code = TRITONSERVER_ERROR_UNAVAILABLE; + } else if (ex.matches(us.ptr())) { + code = TRITONSERVER_ERROR_UNSUPPORTED; + } else if (ex.matches(ae.ptr())) { + code = TRITONSERVER_ERROR_ALREADY_EXISTS; + } + return TRITONSERVER_ErrorNew(code, ex.what()); +} + +void +ThrowIfError(TRITONSERVER_Error* err) +{ + if (err == nullptr) { + return; + } + std::shared_ptr managed_err( + err, TRITONSERVER_ErrorDelete); + std::string msg = TRITONSERVER_ErrorMessage(err); + switch (TRITONSERVER_ErrorCode(err)) { + case TRITONSERVER_ERROR_INTERNAL: + throw InternalError(std::move(msg)); + case TRITONSERVER_ERROR_NOT_FOUND: + throw NotFoundError(std::move(msg)); + case TRITONSERVER_ERROR_INVALID_ARG: + throw InvalidArgumentError(std::move(msg)); + case TRITONSERVER_ERROR_UNAVAILABLE: + throw UnavailableError(std::move(msg)); + case TRITONSERVER_ERROR_UNSUPPORTED: + throw UnsupportedError(std::move(msg)); + case TRITONSERVER_ERROR_ALREADY_EXISTS: + throw AlreadyExistsError(std::move(msg)); + default: + throw UnknownError(std::move(msg)); + } +} + +template +class PyWrapper { + public: + explicit PyWrapper(TritonStruct* triton_object, bool owned) + : triton_object_(triton_object), owned_(owned) + { + } + PyWrapper() = default; + // Destructor will be defined per specialization for now as a few + // Triton object delete functions have different signatures, which + // requires a function wrapper to generalize the destructor. 
+ + // Use internally to get the pointer of the underlying Triton object + TritonStruct* Ptr() { return triton_object_; } + + DISALLOW_COPY_AND_ASSIGN(PyWrapper); + + protected: + TritonStruct* triton_object_{nullptr}; + bool owned_{false}; +}; + +class PyParameter : public PyWrapper { + public: + explicit PyParameter(struct TRITONSERVER_Parameter* p, const bool owned) + : PyWrapper(p, owned) + { + } + + PyParameter(const char* name, const std::string& val) + : PyWrapper( + TRITONSERVER_ParameterNew( + name, TRITONSERVER_PARAMETER_STRING, val.c_str()), + true) + { + } + + PyParameter(const char* name, int64_t val) + : PyWrapper( + TRITONSERVER_ParameterNew(name, TRITONSERVER_PARAMETER_INT, &val), + true) + { + } + + PyParameter(const char* name, bool val) + : PyWrapper( + TRITONSERVER_ParameterNew(name, TRITONSERVER_PARAMETER_BOOL, &val), + true) + { + } + + PyParameter(const char* name, const void* byte_ptr, uint64_t size) + : PyWrapper(TRITONSERVER_ParameterBytesNew(name, byte_ptr, size), true) + { + } + + ~PyParameter() + { + if (owned_ && triton_object_) { + TRITONSERVER_ParameterDelete(triton_object_); + } + } +}; + +class PyBufferAttributes + : public PyWrapper { + public: + DESTRUCTOR_WITH_LOG(PyBufferAttributes, TRITONSERVER_BufferAttributesDelete); + + PyBufferAttributes() + { + ThrowIfError(TRITONSERVER_BufferAttributesNew(&triton_object_)); + owned_ = true; + } + + explicit PyBufferAttributes( + struct TRITONSERVER_BufferAttributes* ba, const bool owned) + : PyWrapper(ba, owned) + { + } + + void SetMemoryTypeId(int64_t memory_type_id) + { + ThrowIfError(TRITONSERVER_BufferAttributesSetMemoryTypeId( + triton_object_, memory_type_id)); + } + + void SetMemoryType(TRITONSERVER_MemoryType memory_type) + { + ThrowIfError(TRITONSERVER_BufferAttributesSetMemoryType( + triton_object_, memory_type)); + } + + void SetCudaIpcHandle(uintptr_t cuda_ipc_handle) + { + ThrowIfError(TRITONSERVER_BufferAttributesSetCudaIpcHandle( + triton_object_, reinterpret_cast(cuda_ipc_handle))); + } + + void SetByteSize(size_t byte_size) + { + ThrowIfError( + TRITONSERVER_BufferAttributesSetByteSize(triton_object_, byte_size)); + } + + // Define methods to get buffer attribute fields + int64_t MemoryTypeId() + { + int64_t memory_type_id = 0; + ThrowIfError(TRITONSERVER_BufferAttributesMemoryTypeId( + triton_object_, &memory_type_id)); + return memory_type_id; + } + + TRITONSERVER_MemoryType MemoryType() + { + TRITONSERVER_MemoryType memory_type = TRITONSERVER_MEMORY_CPU; + ThrowIfError( + TRITONSERVER_BufferAttributesMemoryType(triton_object_, &memory_type)); + return memory_type; + } + + uintptr_t CudaIpcHandle() + { + void* cuda_ipc_handle = nullptr; + ThrowIfError(TRITONSERVER_BufferAttributesCudaIpcHandle( + triton_object_, &cuda_ipc_handle)); + return reinterpret_cast(cuda_ipc_handle); + } + + size_t ByteSize() + { + size_t byte_size; + ThrowIfError( + TRITONSERVER_BufferAttributesByteSize(triton_object_, &byte_size)); + return byte_size; + } +}; + +class PyResponseAllocator + : public PyWrapper { + public: + DESTRUCTOR_WITH_LOG( + PyResponseAllocator, TRITONSERVER_ResponseAllocatorDelete); + + // Callback resource that holds Python user provided buffer and + // Triton C callback wrappers. This struct will be used for both + // 'allocator_userp' and 'buffer_userp' + struct CallbackResource { + CallbackResource(const py::object& a, const py::object& uo) + : allocator(a), user_object(uo) + { + } + // Storing the py::object of PyResponseAllocator to have convenient access + // to callbacks. 
+ py::object allocator; + py::object user_object; + }; + using AllocFn = std::function< + std::tuple( + py::object, std::string, size_t, TRITONSERVER_MemoryType, int64_t, + py::object)>; + using ReleaseFn = std::function; + using StartFn = std::function; + + // size as input, optional? + using QueryFn = std::function( + py::object, py::object, std::string, std::optional, + TRITONSERVER_MemoryType, int64_t)>; + using BufferAttributesFn = std::function; + + PyResponseAllocator(AllocFn alloc, ReleaseFn release) + : alloc_fn_(alloc), release_fn_(release) + { + ThrowIfError(TRITONSERVER_ResponseAllocatorNew( + &triton_object_, PyTritonAllocFn, PyTritonReleaseFn, nullptr)); + owned_ = true; + } + + PyResponseAllocator(AllocFn alloc, ReleaseFn release, StartFn start) + : alloc_fn_(alloc), release_fn_(release), start_fn_(start) + { + ThrowIfError(TRITONSERVER_ResponseAllocatorNew( + &triton_object_, PyTritonAllocFn, PyTritonReleaseFn, PyTritonStartFn)); + owned_ = true; + } + + // Below implements the Triton callbacks, note that when registering the + // callbacks in Triton, an wrapped 'CallbackResource' must be used to bridge + // the gap between the Python API and C API. + static TRITONSERVER_Error* PyTritonAllocFn( + struct TRITONSERVER_ResponseAllocator* allocator, const char* tensor_name, + size_t byte_size, TRITONSERVER_MemoryType memory_type, + int64_t memory_type_id, void* userp, void** buffer, void** buffer_userp, + TRITONSERVER_MemoryType* actual_memory_type, + int64_t* actual_memory_type_id) + { + py::gil_scoped_acquire gil; + struct TRITONSERVER_Error* err = nullptr; + auto cr = reinterpret_cast(userp); + try { + auto res = cr->allocator.cast()->alloc_fn_( + cr->allocator, tensor_name, byte_size, memory_type, memory_type_id, + cr->user_object); + *buffer = reinterpret_cast(std::get<0>(res)); + { + // In C API usage, its typical to allocate user object within the + // callback and place the release logic in release callback. The same + // logic can't trivially ported to Python as user object is scoped, + // therefore the binding needs to wrap the object to ensure the user + // object will not be garbage collected until after release callback. 
+ *buffer_userp = new CallbackResource(cr->allocator, std::get<1>(res)); + } + *actual_memory_type = std::get<2>(res); + *actual_memory_type_id = std::get<3>(res); + } + catch (py::error_already_set& ex) { + err = CreateTRITONSERVER_ErrorFrom(ex); + } + return err; + } + + static TRITONSERVER_Error* PyTritonReleaseFn( + struct TRITONSERVER_ResponseAllocator* allocator, void* buffer, + void* buffer_userp, size_t byte_size, TRITONSERVER_MemoryType memory_type, + int64_t memory_type_id) + { + py::gil_scoped_acquire gil; + struct TRITONSERVER_Error* err = nullptr; + auto cr = reinterpret_cast(buffer_userp); + try { + cr->allocator.cast()->release_fn_( + cr->allocator, reinterpret_cast(buffer), cr->user_object, + byte_size, memory_type, memory_type_id); + } + catch (py::error_already_set& ex) { + err = CreateTRITONSERVER_ErrorFrom(ex); + } + // Done with CallbackResource associated with this buffer + delete cr; + return err; + } + + static TRITONSERVER_Error* PyTritonStartFn( + struct TRITONSERVER_ResponseAllocator* allocator, void* userp) + { + py::gil_scoped_acquire gil; + struct TRITONSERVER_Error* err = nullptr; + auto cr = reinterpret_cast(userp); + try { + cr->allocator.cast()->start_fn_( + cr->allocator, cr->user_object); + } + catch (py::error_already_set& ex) { + err = CreateTRITONSERVER_ErrorFrom(ex); + } + return err; + } + + static TRITONSERVER_Error* PyTritonQueryFn( + struct TRITONSERVER_ResponseAllocator* allocator, void* userp, + const char* tensor_name, size_t* byte_size, + TRITONSERVER_MemoryType* memory_type, int64_t* memory_type_id) + { + py::gil_scoped_acquire gil; + struct TRITONSERVER_Error* err = nullptr; + auto cr = reinterpret_cast(userp); + try { + std::optional bs; + if (byte_size) { + bs = *byte_size; + } + auto res = cr->allocator.cast()->query_fn_( + cr->allocator, cr->user_object, tensor_name, bs, *memory_type, + *memory_type_id); + *memory_type = std::get<0>(res); + *memory_type_id = std::get<1>(res); + } + catch (py::error_already_set& ex) { + err = CreateTRITONSERVER_ErrorFrom(ex); + } + return err; + } + + static TRITONSERVER_Error* PyTritonBufferAttributesFn( + struct TRITONSERVER_ResponseAllocator* allocator, const char* tensor_name, + struct TRITONSERVER_BufferAttributes* buffer_attributes, void* userp, + void* buffer_userp) + { + py::gil_scoped_acquire gil; + struct TRITONSERVER_Error* err = nullptr; + auto cr = reinterpret_cast(userp); + auto bcr = reinterpret_cast(buffer_userp); + PyBufferAttributes pba{buffer_attributes, false /* owned_ */}; + try { + // Python version of BufferAttributes callback has return value + // to be the filled buffer attributes. The callback implementation + // should modify the passed PyBufferAttributes object and return it. + // However, the implementation may construct new PyBufferAttributes + // which requires additional checking to properly return the attributes + // through C API. + auto res = + cr->allocator.cast()->buffer_attributes_fn_( + cr->allocator, tensor_name, + py::cast(pba, py::return_value_policy::reference), + cr->user_object, bcr->user_object); + // Copy if 'res' is new object, otherwise the attributes have been set. 
+ auto res_pba = res.cast(); + if (res_pba->Ptr() != buffer_attributes) { + pba.SetMemoryTypeId(res_pba->MemoryTypeId()); + pba.SetMemoryType(res_pba->MemoryType()); + pba.SetCudaIpcHandle(res_pba->CudaIpcHandle()); + pba.SetByteSize(res_pba->ByteSize()); + } + } + catch (py::error_already_set& ex) { + err = CreateTRITONSERVER_ErrorFrom(ex); + } + return err; + } + + void SetBufferAttributesFunction(BufferAttributesFn baf) + { + buffer_attributes_fn_ = baf; + ThrowIfError(TRITONSERVER_ResponseAllocatorSetBufferAttributesFunction( + triton_object_, PyTritonBufferAttributesFn)); + } + + void SetQueryFunction(QueryFn qf) + { + query_fn_ = qf; + ThrowIfError(TRITONSERVER_ResponseAllocatorSetQueryFunction( + triton_object_, PyTritonQueryFn)); + } + + private: + AllocFn alloc_fn_{nullptr}; + ReleaseFn release_fn_{nullptr}; + StartFn start_fn_{nullptr}; + QueryFn query_fn_{nullptr}; + BufferAttributesFn buffer_attributes_fn_{nullptr}; +}; + +class PyMessage : public PyWrapper { + public: + DESTRUCTOR_WITH_LOG(PyMessage, TRITONSERVER_MessageDelete); + + PyMessage(const std::string& serialized_json) + { + ThrowIfError(TRITONSERVER_MessageNewFromSerializedJson( + &triton_object_, serialized_json.c_str(), serialized_json.size())); + owned_ = true; + } + + explicit PyMessage(struct TRITONSERVER_Message* m, const bool owned) + : PyWrapper(m, owned) + { + } + + std::string SerializeToJson() + { + const char* base = nullptr; + size_t byte_size = 0; + ThrowIfError( + TRITONSERVER_MessageSerializeToJson(triton_object_, &base, &byte_size)); + return std::string(base, byte_size); + } +}; + +class PyMetrics : public PyWrapper { + public: + DESTRUCTOR_WITH_LOG(PyMetrics, TRITONSERVER_MetricsDelete); + + explicit PyMetrics(struct TRITONSERVER_Metrics* metrics, bool owned) + : PyWrapper(metrics, owned) + { + } + + std::string Formatted(TRITONSERVER_MetricFormat format) + { + const char* base = nullptr; + size_t byte_size = 0; + ThrowIfError(TRITONSERVER_MetricsFormatted( + triton_object_, format, &base, &byte_size)); + return std::string(base, byte_size); + } +}; + +class PyTrace : public PyWrapper { + public: + DESTRUCTOR_WITH_LOG(PyTrace, TRITONSERVER_InferenceTraceDelete); + + using TimestampActivityFn = std::function; + using TensorActivityFn = std::function, + TRITONSERVER_MemoryType, int64_t, py::object)>; + using ReleaseFn = std::function, py::object)>; + + struct CallbackResource { + CallbackResource( + TimestampActivityFn ts, TensorActivityFn t, ReleaseFn r, + const py::object& uo) + : timestamp_fn(ts), tensor_fn(t), release_fn(r), user_object(uo) + { + } + TimestampActivityFn timestamp_fn{nullptr}; + TensorActivityFn tensor_fn{nullptr}; + ReleaseFn release_fn{nullptr}; + py::object user_object; + // The trace API will use the same 'trace_userp' for all traces associated + // with the request, and because there is no guarantee that the root trace + // must be released last, need to track all trace seen / released to + // determine whether this CallbackResource may be released. + std::set seen_traces; + }; + + // Use internally when interacting with C APIs that takes ownership, + // this function will also release the ownership of the callback resource + // because once the ownership is transferred, the callback resource + // will be accessed in the callback pipeline and should not be tied to the + // PyWrapper's lifecycle. The callback resource will be released in the + // Triton C callback wrapper. 
+ struct TRITONSERVER_InferenceTrace* Release() + { + owned_ = false; + callback_resource_.release(); + return triton_object_; + } + + PyTrace( + int level, uint64_t parent_id, TimestampActivityFn timestamp, + ReleaseFn release, const py::object& user_object) + : callback_resource_( + new CallbackResource(timestamp, nullptr, release, user_object)) + { + ThrowIfError(TRITONSERVER_InferenceTraceNew( + &triton_object_, static_cast(level), + parent_id, PyTritonTraceTimestampActivityFn, PyTritonTraceRelease, + callback_resource_.get())); + owned_ = true; + } + + PyTrace( + int level, uint64_t parent_id, TimestampActivityFn timestamp, + TensorActivityFn tensor, ReleaseFn release, const py::object& user_object) + : callback_resource_( + new CallbackResource(timestamp, tensor, release, user_object)) + { + ThrowIfError(TRITONSERVER_InferenceTraceTensorNew( + &triton_object_, static_cast(level), + parent_id, PyTritonTraceTimestampActivityFn, + PyTritonTraceTensorActivityFn, PyTritonTraceRelease, + callback_resource_.get())); + owned_ = true; + } + + explicit PyTrace(struct TRITONSERVER_InferenceTrace* t, const bool owned) + : PyWrapper(t, owned) + { + } + + CallbackResource* ReleaseCallbackResource() + { + return callback_resource_.release(); + } + + uint64_t Id() + { + uint64_t val = 0; + ThrowIfError(TRITONSERVER_InferenceTraceId(triton_object_, &val)); + return val; + } + + uint64_t ParentId() + { + uint64_t val = 0; + ThrowIfError(TRITONSERVER_InferenceTraceParentId(triton_object_, &val)); + return val; + } + + std::string ModelName() + { + const char* val = nullptr; + ThrowIfError(TRITONSERVER_InferenceTraceModelName(triton_object_, &val)); + return val; + } + + int64_t ModelVersion() + { + int64_t val = 0; + ThrowIfError(TRITONSERVER_InferenceTraceModelVersion(triton_object_, &val)); + return val; + } + + std::string RequestId() + { + const char* val = nullptr; + ThrowIfError(TRITONSERVER_InferenceTraceRequestId(triton_object_, &val)); + return val; + } + + // Below implements the Triton callbacks, note that when registering the + // callbacks in Triton, an wrapped 'CallbackResource' must be used to bridge + // the gap between the Python API and C API. + static void PyTritonTraceTimestampActivityFn( + struct TRITONSERVER_InferenceTrace* trace, + TRITONSERVER_InferenceTraceActivity activity, uint64_t timestamp_ns, + void* userp) + { + py::gil_scoped_acquire gil; + // Note that 'trace' associated with the activity is not necessary the + // root trace captured in Callback Resource, so need to always wrap 'trace' + // in PyTrace for the Python callback to interact with the correct trace. + PyTrace pt(trace, false /* owned */); + auto cr = reinterpret_cast(userp); + cr->seen_traces.insert(reinterpret_cast(trace)); + cr->timestamp_fn( + py::cast(pt, py::return_value_policy::reference), activity, + timestamp_ns, cr->user_object); + } + + static void PyTritonTraceTensorActivityFn( + struct TRITONSERVER_InferenceTrace* trace, + TRITONSERVER_InferenceTraceActivity activity, const char* name, + TRITONSERVER_DataType datatype, const void* base, size_t byte_size, + const int64_t* shape, uint64_t dim_count, + TRITONSERVER_MemoryType memory_type, int64_t memory_type_id, void* userp) + { + py::gil_scoped_acquire gil; + // See 'PyTritonTraceTimestampActivityFn' for 'pt' explanation. 
+ PyTrace pt(trace, false /* owned */); + auto cr = reinterpret_cast(userp); + cr->seen_traces.insert(reinterpret_cast(trace)); + cr->tensor_fn( + py::cast(pt, py::return_value_policy::reference), activity, name, + datatype, reinterpret_cast(base), byte_size, + py::array_t(dim_count, shape), memory_type, memory_type_id, + cr->user_object); + } + + static void PyTritonTraceRelease( + struct TRITONSERVER_InferenceTrace* trace, void* userp) + { + py::gil_scoped_acquire gil; + // See 'PyTritonTraceTimestampActivityFn' for 'pt' explanation. + // wrap in shared_ptr to transfer ownership to Python + auto managed_pt = std::make_shared(trace, true /* owned */); + auto cr = reinterpret_cast(userp); + cr->release_fn(managed_pt, cr->user_object); + cr->seen_traces.erase(reinterpret_cast(trace)); + if (cr->seen_traces.empty()) { + delete cr; + } + } + + private: + std::unique_ptr callback_resource_{nullptr}; +}; + +class PyInferenceResponse + : public PyWrapper { + public: + DESTRUCTOR_WITH_LOG( + PyInferenceResponse, TRITONSERVER_InferenceResponseDelete); + + using CompleteFn = std::function; + struct CallbackResource { + CallbackResource( + CompleteFn c, PyResponseAllocator::CallbackResource* a, + const py::object& u) + : complete_fn(c), allocator_resource(a), user_object(u) + { + } + CompleteFn complete_fn; + // During 'TRITONSERVER_InferenceRequestSetResponseCallback', a + // PyResponseAllocator::CallbackResource is allocated and passed as + // 'response_allocator_userp', which is used during any output buffer + // allocation of the requests. However, unlike other 'userp', there is no + // dedicated release callback to signal that the allocator resource may be + // released. So we deduce the point of time is deduced based on the + // following: 'TRITONSERVER_InferenceResponseCompleteFn_t' invoked with + // 'TRITONSERVER_RESPONSE_COMPLETE_FINAL' flag indicates there is no more + // responses to be generated and so does output allocation, therefore + // 'allocator_resource' may be released as part of releasing + // 'PyInferenceResponse::CallbackResource' + PyResponseAllocator::CallbackResource* allocator_resource; + py::object user_object; + }; + + explicit PyInferenceResponse( + struct TRITONSERVER_InferenceResponse* response, bool owned) + : PyWrapper(response, owned) + { + } + + + void ThrowIfResponseError() + { + ThrowIfError(TRITONSERVER_InferenceResponseError(triton_object_)); + } + + std::tuple Model() + { + const char* model_name = nullptr; + int64_t model_version = 0; + ThrowIfError(TRITONSERVER_InferenceResponseModel( + triton_object_, &model_name, &model_version)); + return {model_name, model_version}; + } + + std::string Id() + { + const char* val = nullptr; + ThrowIfError(TRITONSERVER_InferenceResponseId(triton_object_, &val)); + return val; + } + + uint32_t ParameterCount() + { + uint32_t val = 0; + ThrowIfError( + TRITONSERVER_InferenceResponseParameterCount(triton_object_, &val)); + return val; + } + + std::tuple Parameter( + uint32_t index) + { + const char* name = nullptr; + TRITONSERVER_ParameterType type = TRITONSERVER_PARAMETER_STRING; + const void* value = nullptr; + ThrowIfError(TRITONSERVER_InferenceResponseParameter( + triton_object_, index, &name, &type, &value)); + py::object py_value; + switch (type) { + case TRITONSERVER_PARAMETER_STRING: + py_value = py::str(reinterpret_cast(value)); + break; + case TRITONSERVER_PARAMETER_INT: + py_value = py::int_(*reinterpret_cast(value)); + break; + case TRITONSERVER_PARAMETER_BOOL: + py_value = py::bool_(*reinterpret_cast(value)); + 
break; + default: + throw UnsupportedError( + std::string("Unexpected type '") + + TRITONSERVER_ParameterTypeString(type) + + "' received as response parameter"); + break; + } + return {name, type, py_value}; + } + + uint32_t OutputCount() + { + uint32_t val = 0; + ThrowIfError( + TRITONSERVER_InferenceResponseOutputCount(triton_object_, &val)); + return val; + } + + std::tuple< + std::string, TRITONSERVER_DataType, py::array_t, uintptr_t, + size_t, TRITONSERVER_MemoryType, int64_t, py::object> + Output(uint32_t index) + { + const char* name = nullptr; + TRITONSERVER_DataType datatype = TRITONSERVER_TYPE_INVALID; + const int64_t* shape = nullptr; + uint64_t dim_count = 0; + const void* base = nullptr; + size_t byte_size = 0; + TRITONSERVER_MemoryType memory_type = TRITONSERVER_MEMORY_CPU; + int64_t memory_type_id = 0; + void* userp = nullptr; + ThrowIfError(TRITONSERVER_InferenceResponseOutput( + triton_object_, index, &name, &datatype, &shape, &dim_count, &base, + &byte_size, &memory_type, &memory_type_id, &userp)); + return { + name, + datatype, + py::array_t(dim_count, shape), + reinterpret_cast(base), + byte_size, + memory_type, + memory_type_id, + reinterpret_cast(userp) + ->user_object}; + } + + std::string OutputClassificationLabel(uint32_t index, size_t class_index) + { + const char* val = nullptr; + ThrowIfError(TRITONSERVER_InferenceResponseOutputClassificationLabel( + triton_object_, index, class_index, &val)); + return (val == nullptr) ? "" : val; + } +}; + +// forward declaration +class PyServer; + +class PyInferenceRequest + : public PyWrapper { + public: + DESTRUCTOR_WITH_LOG(PyInferenceRequest, TRITONSERVER_InferenceRequestDelete); + + using ReleaseFn = std::function, uint32_t, py::object)>; + + // Defer definition until PyServer is defined + PyInferenceRequest( + PyServer& server, const std::string& model_name, + const int64_t model_version); + + explicit PyInferenceRequest( + struct TRITONSERVER_InferenceRequest* r, const bool owned) + : PyWrapper(r, owned) + { + } + + + // Use internally when interacting with C APIs that takes ownership, + // this function will also release the ownership of the callback resource + // because once the ownership is transferred, the callback resource + // will be accessed in the callback pipeline and should not be tied to the + // PyWrapper's lifecycle. The callback resource will be released in the + // Triton C callback wrapper. + struct TRITONSERVER_InferenceRequest* Release() + { + // Note that Release() doesn't change ownership as the + // same PyInferenceRequest will be passed along the life cycle. + allocator_callback_resource_.release(); + response_callback_resource_.release(); + return triton_object_; + } + + struct CallbackResource { + CallbackResource(ReleaseFn r, const py::object& uo) + : release_fn(r), user_object(uo) + { + } + ReleaseFn release_fn; + py::object user_object; + // Unsafe handling to ensure the same PyInferenceRequest object + // goes through the request release cycle. This is due to + // a 'keep_alive' relationship is built between 'PyInferenceRequest' + // and 'PyServer': a request is associated with a server and the server + // should be kept alive until all associated requests is properly released. + // And here we exploit the 'keep_alive' utility in PyBind to guarantee so. + // See PyServer::InferAsync on how this field is set to avoid potential + // circular inclusion. 
+ std::shared_ptr request; + }; + + + void SetReleaseCallback(ReleaseFn release, const py::object& user_object) + { + request_callback_resource_.reset( + new CallbackResource(release, user_object)); + ThrowIfError(TRITONSERVER_InferenceRequestSetReleaseCallback( + triton_object_, PyTritonRequestReleaseCallback, + request_callback_resource_.get())); + } + + static void PyTritonRequestReleaseCallback( + struct TRITONSERVER_InferenceRequest* request, const uint32_t flags, + void* userp) + { + py::gil_scoped_acquire gil; + auto cr = reinterpret_cast(userp); + cr->release_fn(cr->request, flags, cr->user_object); + delete cr; + } + + void SetResponseCallback( + const py::object& allocator, const py::object& allocater_user_object, + PyInferenceResponse::CompleteFn response, + const py::object& response_user_object) + { + allocator_callback_resource_.reset( + new PyResponseAllocator::CallbackResource( + allocator, allocater_user_object)); + response_callback_resource_.reset(new PyInferenceResponse::CallbackResource( + response, allocator_callback_resource_.get(), response_user_object)); + ThrowIfError(TRITONSERVER_InferenceRequestSetResponseCallback( + triton_object_, allocator.cast()->Ptr(), + allocator_callback_resource_.get(), PyTritonResponseCompleteCallback, + response_callback_resource_.get())); + } + static void PyTritonResponseCompleteCallback( + struct TRITONSERVER_InferenceResponse* response, const uint32_t flags, + void* userp) + { + py::gil_scoped_acquire gil; + auto managed_pt = + std::make_shared(response, true /* owned */); + auto cr = reinterpret_cast(userp); + cr->complete_fn(py::cast(managed_pt), flags, cr->user_object); + if (flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL) { + delete cr->allocator_resource; + delete cr; + } + } + + // Trivial setters / getters + void SetId(const std::string& id) + { + ThrowIfError( + TRITONSERVER_InferenceRequestSetId(triton_object_, id.c_str())); + } + std::string Id() + { + const char* val = nullptr; + ThrowIfError(TRITONSERVER_InferenceRequestId(triton_object_, &val)); + return val; + } + + void SetFlags(uint32_t flags) + { + ThrowIfError(TRITONSERVER_InferenceRequestSetFlags(triton_object_, flags)); + } + + uint32_t Flags() + { + uint32_t val = 0; + ThrowIfError(TRITONSERVER_InferenceRequestFlags(triton_object_, &val)); + return val; + } + + void SetCorrelationId(uint64_t correlation_id) + { + ThrowIfError(TRITONSERVER_InferenceRequestSetCorrelationId( + triton_object_, correlation_id)); + } + uint64_t CorrelationId() + { + uint64_t val = 0; + ThrowIfError( + TRITONSERVER_InferenceRequestCorrelationId(triton_object_, &val)); + return val; + } + void SetCorrelationIdString(const std::string& correlation_id) + { + ThrowIfError(TRITONSERVER_InferenceRequestSetCorrelationIdString( + triton_object_, correlation_id.c_str())); + } + std::string CorrelationIdString() + { + const char* val = nullptr; + ThrowIfError( + TRITONSERVER_InferenceRequestCorrelationIdString(triton_object_, &val)); + return val; + } + + void SetPriority(uint32_t priority) + { + ThrowIfError( + TRITONSERVER_InferenceRequestSetPriority(triton_object_, priority)); + } + void SetPriorityUint64(uint64_t priority) + { + ThrowIfError(TRITONSERVER_InferenceRequestSetPriorityUInt64( + triton_object_, priority)); + } + uint32_t Priority() + { + uint32_t val = 0; + ThrowIfError(TRITONSERVER_InferenceRequestPriority(triton_object_, &val)); + return val; + } + uint64_t PriorityUint64() + { + uint64_t val = 0; + ThrowIfError( + TRITONSERVER_InferenceRequestPriorityUInt64(triton_object_, 
&val)); + return val; + } + + void SetTimeoutMicroseconds(uint64_t timeout_us) + { + ThrowIfError(TRITONSERVER_InferenceRequestSetTimeoutMicroseconds( + triton_object_, timeout_us)); + } + uint64_t TimeoutMicroseconds() + { + uint64_t val = 0; + ThrowIfError( + TRITONSERVER_InferenceRequestTimeoutMicroseconds(triton_object_, &val)); + return val; + } + + void AddInput( + const std::string& name, TRITONSERVER_DataType data_type, + std::vector shape) + { + ThrowIfError(TRITONSERVER_InferenceRequestAddInput( + triton_object_, name.c_str(), data_type, shape.data(), shape.size())); + } + void AddRawInput(const std::string& name) + { + ThrowIfError( + TRITONSERVER_InferenceRequestAddRawInput(triton_object_, name.c_str())); + } + void RemoveInput(const std::string& name) + { + ThrowIfError( + TRITONSERVER_InferenceRequestRemoveInput(triton_object_, name.c_str())); + } + void RemoveAllInputs() + { + ThrowIfError(TRITONSERVER_InferenceRequestRemoveAllInputs(triton_object_)); + } + void AppendInputData( + const std::string& name, uintptr_t base, size_t byte_size, + TRITONSERVER_MemoryType memory_type, int64_t memory_type_id) + { + ThrowIfError(TRITONSERVER_InferenceRequestAppendInputData( + triton_object_, name.c_str(), reinterpret_cast(base), + byte_size, memory_type, memory_type_id)); + } + void AppendInputDataWithHostPolicy( + const std::string name, uintptr_t base, size_t byte_size, + TRITONSERVER_MemoryType memory_type, int64_t memory_type_id, + const std::string& host_policy_name) + { + ThrowIfError(TRITONSERVER_InferenceRequestAppendInputDataWithHostPolicy( + triton_object_, name.c_str(), reinterpret_cast(base), + byte_size, memory_type, memory_type_id, host_policy_name.c_str())); + } + void AppendInputDataWithBufferAttributes( + const std::string& name, uintptr_t base, + PyBufferAttributes* buffer_attributes) + { + ThrowIfError( + TRITONSERVER_InferenceRequestAppendInputDataWithBufferAttributes( + triton_object_, name.c_str(), reinterpret_cast(base), + buffer_attributes->Ptr())); + } + void RemoveAllInputData(const std::string& name) + { + ThrowIfError(TRITONSERVER_InferenceRequestRemoveAllInputData( + triton_object_, name.c_str())); + } + + void AddRequestedOutput(const std::string& name) + { + ThrowIfError(TRITONSERVER_InferenceRequestAddRequestedOutput( + triton_object_, name.c_str())); + } + void RemoveRequestedOutput(const std::string& name) + { + ThrowIfError(TRITONSERVER_InferenceRequestRemoveRequestedOutput( + triton_object_, name.c_str())); + } + void RemoveAllRequestedOutputs() + { + ThrowIfError( + TRITONSERVER_InferenceRequestRemoveAllRequestedOutputs(triton_object_)); + } + + void SetStringParameter(const std::string& key, const std::string& value) + { + ThrowIfError(TRITONSERVER_InferenceRequestSetStringParameter( + triton_object_, key.c_str(), value.c_str())); + } + void SetIntParameter(const std::string& key, int64_t value) + { + ThrowIfError(TRITONSERVER_InferenceRequestSetIntParameter( + triton_object_, key.c_str(), value)); + } + void SetBoolParameter(const std::string& key, bool value) + { + ThrowIfError(TRITONSERVER_InferenceRequestSetBoolParameter( + triton_object_, key.c_str(), value)); + } + + public: + std::unique_ptr request_callback_resource_{nullptr}; + + private: + std::unique_ptr + allocator_callback_resource_{nullptr}; + std::unique_ptr + response_callback_resource_{nullptr}; +}; + +class PyServerOptions : public PyWrapper { + public: + DESTRUCTOR_WITH_LOG(PyServerOptions, TRITONSERVER_ServerOptionsDelete); + PyServerOptions() + { + 
ThrowIfError(TRITONSERVER_ServerOptionsNew(&triton_object_)); + owned_ = true; + } + + void SetServerId(const std::string& server_id) + { + ThrowIfError(TRITONSERVER_ServerOptionsSetServerId( + triton_object_, server_id.c_str())); + } + + void SetModelRepositoryPath(const std::string& model_repository_path) + { + ThrowIfError(TRITONSERVER_ServerOptionsSetModelRepositoryPath( + triton_object_, model_repository_path.c_str())); + } + + void SetModelControlMode(TRITONSERVER_ModelControlMode mode) + { + ThrowIfError( + TRITONSERVER_ServerOptionsSetModelControlMode(triton_object_, mode)); + } + + void SetStartupModel(const std::string& model_name) + { + ThrowIfError(TRITONSERVER_ServerOptionsSetStartupModel( + triton_object_, model_name.c_str())); + } + + void SetStrictModelConfig(bool strict) + { + ThrowIfError( + TRITONSERVER_ServerOptionsSetStrictModelConfig(triton_object_, strict)); + } + void SetRateLimiterMode(TRITONSERVER_RateLimitMode mode) + { + ThrowIfError( + TRITONSERVER_ServerOptionsSetRateLimiterMode(triton_object_, mode)); + } + + void AddRateLimiterResource( + const std::string& resource_name, size_t resource_count, int device) + { + ThrowIfError(TRITONSERVER_ServerOptionsAddRateLimiterResource( + triton_object_, resource_name.c_str(), resource_count, device)); + } + + void SetPinnedMemoryPoolByteSize(uint64_t size) + { + ThrowIfError(TRITONSERVER_ServerOptionsSetPinnedMemoryPoolByteSize( + triton_object_, size)); + } + + void SetCudaMemoryPoolByteSize(int gpu_device, uint64_t size) + { + ThrowIfError(TRITONSERVER_ServerOptionsSetCudaMemoryPoolByteSize( + triton_object_, gpu_device, size)); + } + void SetResponseCacheByteSize(uint64_t size) + { + ThrowIfError(TRITONSERVER_ServerOptionsSetResponseCacheByteSize( + triton_object_, size)); + } + + void SetCacheConfig( + const std::string& cache_name, const std::string& config_json) + { + ThrowIfError(TRITONSERVER_ServerOptionsSetCacheConfig( + triton_object_, cache_name.c_str(), config_json.c_str())); + } + + void SetCacheDirectory(const std::string& cache_dir) + { + ThrowIfError(TRITONSERVER_ServerOptionsSetCacheDirectory( + triton_object_, cache_dir.c_str())); + } + + void SetMinSupportedComputeCapability(double cc) + { + ThrowIfError(TRITONSERVER_ServerOptionsSetMinSupportedComputeCapability( + triton_object_, cc)); + } + + void SetExitOnError(bool exit) + { + ThrowIfError( + TRITONSERVER_ServerOptionsSetExitOnError(triton_object_, exit)); + } + + void SetStrictReadiness(bool strict) + { + ThrowIfError( + TRITONSERVER_ServerOptionsSetStrictReadiness(triton_object_, strict)); + } + + void SetExitTimeout(unsigned int timeout) + { + ThrowIfError( + TRITONSERVER_ServerOptionsSetExitTimeout(triton_object_, timeout)); + } + void SetBufferManagerThreadCount(unsigned int thread_count) + { + ThrowIfError(TRITONSERVER_ServerOptionsSetBufferManagerThreadCount( + triton_object_, thread_count)); + } + + void SetModelLoadThreadCount(unsigned int thread_count) + { + ThrowIfError(TRITONSERVER_ServerOptionsSetModelLoadThreadCount( + triton_object_, thread_count)); + } + + void SetModelNamespacing(bool enable_namespace) + { + ThrowIfError(TRITONSERVER_ServerOptionsSetModelNamespacing( + triton_object_, enable_namespace)); + } + + void SetLogFile(const std::string& file) + { + ThrowIfError( + TRITONSERVER_ServerOptionsSetLogFile(triton_object_, file.c_str())); + } + + void SetLogInfo(bool log) + { + ThrowIfError(TRITONSERVER_ServerOptionsSetLogInfo(triton_object_, log)); + } + + void SetLogWarn(bool log) + { + 
ThrowIfError(TRITONSERVER_ServerOptionsSetLogWarn(triton_object_, log)); + } + + void SetLogError(bool log) + { + ThrowIfError(TRITONSERVER_ServerOptionsSetLogError(triton_object_, log)); + } + + void SetLogFormat(TRITONSERVER_LogFormat format) + { + ThrowIfError( + TRITONSERVER_ServerOptionsSetLogFormat(triton_object_, format)); + } + + void SetLogVerbose(int level) + { + ThrowIfError( + TRITONSERVER_ServerOptionsSetLogVerbose(triton_object_, level)); + } + void SetMetrics(bool metrics) + { + ThrowIfError(TRITONSERVER_ServerOptionsSetMetrics(triton_object_, metrics)); + } + + void SetGpuMetrics(bool gpu_metrics) + { + ThrowIfError( + TRITONSERVER_ServerOptionsSetGpuMetrics(triton_object_, gpu_metrics)); + } + + void SetCpuMetrics(bool cpu_metrics) + { + ThrowIfError( + TRITONSERVER_ServerOptionsSetCpuMetrics(triton_object_, cpu_metrics)); + } + + void SetMetricsInterval(uint64_t metrics_interval_ms) + { + ThrowIfError(TRITONSERVER_ServerOptionsSetMetricsInterval( + triton_object_, metrics_interval_ms)); + } + + void SetBackendDirectory(const std::string& backend_dir) + { + ThrowIfError(TRITONSERVER_ServerOptionsSetBackendDirectory( + triton_object_, backend_dir.c_str())); + } + + void SetRepoAgentDirectory(const std::string& repoagent_dir) + { + ThrowIfError(TRITONSERVER_ServerOptionsSetRepoAgentDirectory( + triton_object_, repoagent_dir.c_str())); + } + + void SetModelLoadDeviceLimit( + TRITONSERVER_InstanceGroupKind kind, int device_id, double fraction) + { + ThrowIfError(TRITONSERVER_ServerOptionsSetModelLoadDeviceLimit( + triton_object_, kind, device_id, fraction)); + } + + void SetBackendConfig( + const std::string& backend_name, const std::string& setting, + const std::string& value) + { + ThrowIfError(TRITONSERVER_ServerOptionsSetBackendConfig( + triton_object_, backend_name.c_str(), setting.c_str(), value.c_str())); + } + + void SetHostPolicy( + const std::string& policy_name, const std::string& setting, + const std::string& value) + { + ThrowIfError(TRITONSERVER_ServerOptionsSetHostPolicy( + triton_object_, policy_name.c_str(), setting.c_str(), value.c_str())); + } + + void SetMetricsConfig( + const std::string& name, const std::string& setting, + const std::string& value) + { + ThrowIfError(TRITONSERVER_ServerOptionsSetMetricsConfig( + triton_object_, name.c_str(), setting.c_str(), value.c_str())); + } +}; + +class PyServer : public PyWrapper { + public: + DESTRUCTOR_WITH_LOG(PyServer, TRITONSERVER_ServerDelete); + + PyServer(PyServerOptions& options) + { + ThrowIfError(TRITONSERVER_ServerNew(&triton_object_, options.Ptr())); + owned_ = true; + } + + void Stop() const { ThrowIfError(TRITONSERVER_ServerStop(triton_object_)); } + + void RegisterModelRepository( + const std::string& repository_path, + const std::vector>& name_mapping) const + { + std::vector params; + for (const auto& nm : name_mapping) { + params.emplace_back(nm->Ptr()); + } + ThrowIfError(TRITONSERVER_ServerRegisterModelRepository( + triton_object_, repository_path.c_str(), params.data(), params.size())); + } + + void UnregisterModelRepository(const std::string& repository_path) const + { + ThrowIfError(TRITONSERVER_ServerUnregisterModelRepository( + triton_object_, repository_path.c_str())); + } + + void PollModelRepository() const + { + ThrowIfError(TRITONSERVER_ServerPollModelRepository(triton_object_)); + } + + bool IsLive() const + { + bool live; + ThrowIfError(TRITONSERVER_ServerIsLive(triton_object_, &live)); + return live; + } + + bool IsReady() const + { + bool ready; + 
ThrowIfError(TRITONSERVER_ServerIsReady(triton_object_, &ready)); + return ready; + } + + bool ModelIsReady(const std::string& model_name, int64_t model_version) const + { + bool ready; + ThrowIfError(TRITONSERVER_ServerModelIsReady( + triton_object_, model_name.c_str(), model_version, &ready)); + return ready; + } + + std::tuple ModelBatchProperties( + const std::string& model_name, int64_t model_version) const + { + uint32_t flags; + void* voidp; + ThrowIfError(TRITONSERVER_ServerModelBatchProperties( + triton_object_, model_name.c_str(), model_version, &flags, &voidp)); + return {flags, reinterpret_cast(voidp)}; + } + + std::tuple ModelTransactionProperties( + const std::string& model_name, int64_t model_version) const + { + uint32_t txn_flags; + void* voidp; + ThrowIfError(TRITONSERVER_ServerModelTransactionProperties( + triton_object_, model_name.c_str(), model_version, &txn_flags, &voidp)); + return {txn_flags, reinterpret_cast(voidp)}; + } + + std::shared_ptr Metadata() const + { + struct TRITONSERVER_Message* server_metadata; + ThrowIfError(TRITONSERVER_ServerMetadata(triton_object_, &server_metadata)); + return std::make_shared(server_metadata, true /* owned */); + } + + std::shared_ptr ModelMetadata( + const std::string& model_name, int64_t model_version) const + { + struct TRITONSERVER_Message* model_metadata; + ThrowIfError(TRITONSERVER_ServerModelMetadata( + triton_object_, model_name.c_str(), model_version, &model_metadata)); + return std::make_shared(model_metadata, true /* owned */); + } + + std::shared_ptr ModelStatistics( + const std::string& model_name, int64_t model_version) const + { + struct TRITONSERVER_Message* model_stats; + ThrowIfError(TRITONSERVER_ServerModelStatistics( + triton_object_, model_name.c_str(), model_version, &model_stats)); + return std::make_shared(model_stats, true /* owned */); + } + + std::shared_ptr ModelConfig( + const std::string& model_name, int64_t model_version, + uint32_t config_version = 1) const + { + struct TRITONSERVER_Message* model_config; + ThrowIfError(TRITONSERVER_ServerModelConfig( + triton_object_, model_name.c_str(), model_version, config_version, + &model_config)); + return std::make_shared(model_config, true /* owned */); + } + + std::shared_ptr ModelIndex(uint32_t flags) const + { + struct TRITONSERVER_Message* model_index; + ThrowIfError( + TRITONSERVER_ServerModelIndex(triton_object_, flags, &model_index)); + return std::make_shared(model_index, true /* owned */); + } + + void LoadModel(const std::string& model_name) + { + // load model is blocking, ensure to release GIL + py::gil_scoped_release release; + ThrowIfError( + TRITONSERVER_ServerLoadModel(triton_object_, model_name.c_str())); + } + + void LoadModelWithParameters( + const std::string& model_name, + const std::vector>& parameters) const + { + std::vector params; + for (const auto& p : parameters) { + params.emplace_back(p->Ptr()); + } + // load model is blocking, ensure to release GIL + py::gil_scoped_release release; + ThrowIfError(TRITONSERVER_ServerLoadModelWithParameters( + triton_object_, model_name.c_str(), params.data(), params.size())); + } + + void UnloadModel(const std::string& model_name) + { + ThrowIfError( + TRITONSERVER_ServerUnloadModel(triton_object_, model_name.c_str())); + } + + void UnloadModelAndDependents(const std::string& model_name) + { + ThrowIfError(TRITONSERVER_ServerUnloadModelAndDependents( + triton_object_, model_name.c_str())); + } + + std::shared_ptr Metrics() const + { + struct TRITONSERVER_Metrics* metrics; + 
ThrowIfError(TRITONSERVER_ServerMetrics(triton_object_, &metrics)); + return std::make_shared(metrics, true /* owned */); + } + + void InferAsync( + const std::shared_ptr& request, PyTrace& trace) + { + // Extra handling to avoid circular inclusion: + // request -> request_callback_resource_ -> request + // 1. extract 'request_callback_resource_' out and provide + // scoped handler to place resource back to request if not released, + // TRITONSERVER_ServerInferAsync failed in other words. + // 2. add 'request' into resource so request release callback can access it. + // 3. call TRITONSERVER_ServerInferAsync. + // 4. release the extracted resource if TRITONSERVER_ServerInferAsync + // returns. + static auto resource_handler = + [](PyInferenceRequest::CallbackResource* cr) { + if (cr != nullptr) { + cr->request->request_callback_resource_.reset(cr); + cr->request.reset(); + } + }; + std::unique_ptr< + PyInferenceRequest::CallbackResource, decltype(resource_handler)> + scoped_rh( + request->request_callback_resource_.release(), resource_handler); + scoped_rh->request = request; + + ThrowIfError(TRITONSERVER_ServerInferAsync( + triton_object_, request->Ptr(), trace.Ptr())); + // Ownership of the internal C object is transferred. + scoped_rh.release(); + request->Release(); + trace.Release(); + } + + void InferAsync(const std::shared_ptr& request) + { + static auto resource_handler = + [](PyInferenceRequest::CallbackResource* cr) { + if (cr != nullptr) { + cr->request->request_callback_resource_.reset(cr); + cr->request.reset(); + } + }; + std::unique_ptr< + PyInferenceRequest::CallbackResource, decltype(resource_handler)> + scoped_rh( + request->request_callback_resource_.release(), resource_handler); + scoped_rh->request = request; + + ThrowIfError( + TRITONSERVER_ServerInferAsync(triton_object_, request->Ptr(), nullptr)); + // Ownership of the internal C object is transferred. + scoped_rh.release(); + request->Release(); + } +}; + +class PyMetricFamily : public PyWrapper { + public: + DESTRUCTOR_WITH_LOG(PyMetricFamily, TRITONSERVER_MetricFamilyDelete); + + PyMetricFamily( + TRITONSERVER_MetricKind kind, const std::string& name, + const std::string& description) + { + TRITONSERVER_MetricFamilyNew( + &triton_object_, kind, name.c_str(), description.c_str()); + owned_ = true; + } +}; + +class PyMetric : public PyWrapper { + public: + DESTRUCTOR_WITH_LOG(PyMetric, TRITONSERVER_MetricDelete); + PyMetric( + PyMetricFamily& family, + const std::vector>& labels) + { + std::vector params; + for (const auto& label : labels) { + params.emplace_back(label->Ptr()); + } + ThrowIfError(TRITONSERVER_MetricNew( + &triton_object_, family.Ptr(), params.data(), params.size())); + owned_ = true; + } + + double Value() const + { + double val = 0; + ThrowIfError(TRITONSERVER_MetricValue(triton_object_, &val)); + return val; + } + + void Increment(double val) const + { + ThrowIfError(TRITONSERVER_MetricIncrement(triton_object_, val)); + } + + void SetValue(double val) const + { + ThrowIfError(TRITONSERVER_MetricSet(triton_object_, val)); + } + + TRITONSERVER_MetricKind Kind() const + { + TRITONSERVER_MetricKind val = TRITONSERVER_METRIC_KIND_COUNTER; + ThrowIfError(TRITONSERVER_GetMetricKind(triton_object_, &val)); + return val; + } +}; + +// Deferred definitions.. 
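+// The PyInferenceRequest constructor is defined here, after PyServer, because
+// it needs the complete PyServer type (it calls server.Ptr()); at the point of
+// declaration above, PyServer is only forward-declared.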
+PyInferenceRequest::PyInferenceRequest( + PyServer& server, const std::string& model_name, + const int64_t model_version) +{ + ThrowIfError(TRITONSERVER_InferenceRequestNew( + &triton_object_, server.Ptr(), model_name.c_str(), model_version)); + owned_ = true; +} + +// [FIXME] module name? +PYBIND11_MODULE(triton_bindings, m) +{ + m.doc() = "Python bindings for Triton Inference Server"; + + // [FIXME] if dynamic linking, should have version check here as well to + // make sure the binding is compatible with the Triton library loaded + m.def("api_version", []() { + uint32_t major = 0, minor = 0; + ThrowIfError(TRITONSERVER_ApiVersion(&major, &minor)); + return py::make_tuple(major, minor); + }); + + // TRITONSERVER_Error... converted to 'TritonError' exception + // Implement exception inheritance in PyBind: + // https://github.com/jagerman/pybind11/blob/master/tests/test_exceptions.cpp#L149-L152 + auto te = pybind11::register_exception(m, "TritonError"); + pybind11::register_exception(m, "UnknownError", te.ptr()); + pybind11::register_exception(m, "InternalError", te.ptr()); + pybind11::register_exception(m, "NotFoundError", te.ptr()); + pybind11::register_exception( + m, "InvalidArgumentError", te.ptr()); + pybind11::register_exception( + m, "UnavailableError", te.ptr()); + pybind11::register_exception( + m, "UnsupportedError", te.ptr()); + pybind11::register_exception( + m, "AlreadyExistsError", te.ptr()); + + // TRITONSERVER_DataType + py::enum_(m, "TRITONSERVER_DataType") + .value("INVALID", TRITONSERVER_TYPE_INVALID) + .value("BOOL", TRITONSERVER_TYPE_BOOL) + .value("UINT8", TRITONSERVER_TYPE_UINT8) + .value("UINT16", TRITONSERVER_TYPE_UINT16) + .value("UINT32", TRITONSERVER_TYPE_UINT32) + .value("UINT64", TRITONSERVER_TYPE_UINT64) + .value("INT8", TRITONSERVER_TYPE_INT8) + .value("INT16", TRITONSERVER_TYPE_INT16) + .value("INT32", TRITONSERVER_TYPE_INT32) + .value("INT64", TRITONSERVER_TYPE_INT64) + .value("FP16", TRITONSERVER_TYPE_FP16) + .value("FP32", TRITONSERVER_TYPE_FP32) + .value("FP64", TRITONSERVER_TYPE_FP64) + .value("BYTES", TRITONSERVER_TYPE_BYTES) + .value("BF16", TRITONSERVER_TYPE_BF16); + // helper functions + m.def("TRITONSERVER_DataTypeString", [](TRITONSERVER_DataType datatype) { + return TRITONSERVER_DataTypeString(datatype); + }); + m.def("TRITONSERVER_StringToDataType", [](const char* dtype) { + return TRITONSERVER_StringToDataType(dtype); + }); + m.def("TRITONSERVER_DataTypeByteSize", [](TRITONSERVER_DataType datatype) { + return TRITONSERVER_DataTypeByteSize(datatype); + }); + + // TRITONSERVER_MemoryType + py::enum_(m, "TRITONSERVER_MemoryType") + .value("CPU", TRITONSERVER_MEMORY_CPU) + .value("CPU_PINNED", TRITONSERVER_MEMORY_CPU_PINNED) + .value("GPU", TRITONSERVER_MEMORY_GPU); + // helper functions + m.def("TRITONSERVER_MemoryTypeString", [](TRITONSERVER_MemoryType memtype) { + return TRITONSERVER_MemoryTypeString(memtype); + }); + + // TRITONSERVER_ParameterType + py::enum_(m, "TRITONSERVER_ParameterType") + .value("STRING", TRITONSERVER_PARAMETER_STRING) + .value("INT", TRITONSERVER_PARAMETER_INT) + .value("BOOL", TRITONSERVER_PARAMETER_BOOL) + .value("BYTES", TRITONSERVER_PARAMETER_BYTES); + // helper functions + m.def( + "TRITONSERVER_ParameterTypeString", + [](TRITONSERVER_ParameterType paramtype) { + return TRITONSERVER_ParameterTypeString(paramtype); + }); + // TRITONSERVER_Parameter + py::class_>( + m, "TRITONSERVER_Parameter") + // Python bytes can be consumed by function accepting string, so order + // the py::bytes constructor before string to 
ensure correct overload + // constructor is used + .def(py::init([](const char* name, py::bytes bytes) { + // [FIXME] does not own 'bytes' in the same way as C API, but can also + // hold 'bytes' to make sure it will not be invalidated while in use. + // i.e. safe to perform + // a = triton_bindings.TRITONSERVER_Parameter("abc", b'abc') + // # 'a' still points to valid buffer at this line. + // Note that even holding 'bytes', it is the user's responsibility not + // to modify 'bytes' while the parameter is in use. + py::buffer_info info(py::buffer(bytes).request()); + return std::make_unique(name, info.ptr, info.size); + })) + .def(py::init()) + .def(py::init()) + .def(py::init()); + + // TRITONSERVER_InstanceGroupKind + py::enum_(m, "TRITONSERVER_InstanceGroupKind") + .value("AUTO", TRITONSERVER_INSTANCEGROUPKIND_AUTO) + .value("CPU", TRITONSERVER_INSTANCEGROUPKIND_CPU) + .value("GPU", TRITONSERVER_INSTANCEGROUPKIND_GPU) + .value("MODEL", TRITONSERVER_INSTANCEGROUPKIND_MODEL); + m.def( + "TRITONSERVER_InstanceGroupKindString", + [](TRITONSERVER_InstanceGroupKind kind) { + return TRITONSERVER_InstanceGroupKindString(kind); + }); + + // TRITONSERVER_Log + py::enum_(m, "TRITONSERVER_LogLevel") + .value("INFO", TRITONSERVER_LOG_INFO) + .value("WARN", TRITONSERVER_LOG_WARN) + .value("ERROR", TRITONSERVER_LOG_ERROR) + .value("VERBOSE", TRITONSERVER_LOG_VERBOSE); + + py::enum_(m, "TRITONSERVER_LogFormat") + .value("DEFAULT", TRITONSERVER_LOG_DEFAULT) + .value("ISO8601", TRITONSERVER_LOG_ISO8601); + + m.def("TRITONSERVER_LogIsEnabled", [](TRITONSERVER_LogLevel level) { + return TRITONSERVER_LogIsEnabled(level); + }); + m.def( + "TRITONSERVER_LogMessage", + [](TRITONSERVER_LogLevel level, const char* filename, const int line, + const char* msg) { + ThrowIfError(TRITONSERVER_LogMessage(level, filename, line, msg)); + }); + + py::class_(m, "TRITONSERVER_BufferAttributes") + .def(py::init<>()) + .def_property( + "memory_type_id", &PyBufferAttributes::MemoryTypeId, + &PyBufferAttributes::SetMemoryTypeId) + .def_property( + "memory_type", &PyBufferAttributes::MemoryType, + &PyBufferAttributes::SetMemoryType) + .def_property( + "cuda_ipc_handle", &PyBufferAttributes::CudaIpcHandle, + &PyBufferAttributes::SetCudaIpcHandle) + .def_property( + "byte_size", &PyBufferAttributes::ByteSize, + &PyBufferAttributes::SetByteSize); + + py::class_(m, "TRITONSERVER_ResponseAllocator") + .def( + py::init< + PyResponseAllocator::AllocFn, PyResponseAllocator::ReleaseFn, + PyResponseAllocator::StartFn>(), + py::arg("alloc_function"), py::arg("release_function"), + py::arg("start_function")) + .def( + py::init< + PyResponseAllocator::AllocFn, PyResponseAllocator::ReleaseFn>(), + py::arg("alloc_function"), py::arg("release_function")) + .def( + "set_buffer_attributes_function", + &PyResponseAllocator::SetBufferAttributesFunction, + py::arg("buffer_attributes_function")) + .def( + "set_query_function", &PyResponseAllocator::SetQueryFunction, + py::arg("query_function")); + + // TRITONSERVER_Message + py::class_>(m, "TRITONSERVER_Message") + .def(py::init()) + .def("serialize_to_json", &PyMessage::SerializeToJson); + + // TRITONSERVER_Metrics + py::enum_(m, "TRITONSERVER_MetricFormat") + .value("PROMETHEUS", TRITONSERVER_METRIC_PROMETHEUS); + py::class_>(m, "TRITONSERVER_Metrics") + .def("formatted", &PyMetrics::Formatted); + + // TRITONSERVER_InferenceTrace + py::enum_( + m, "TRITONSERVER_InferenceTraceLevel") + .value("DISABLED", TRITONSERVER_TRACE_LEVEL_DISABLED) + .value("MIN", TRITONSERVER_TRACE_LEVEL_MIN) + 
.value("MAX", TRITONSERVER_TRACE_LEVEL_MAX) + .value("TIMESTAMPS", TRITONSERVER_TRACE_LEVEL_TIMESTAMPS) + .value("TENSORS", TRITONSERVER_TRACE_LEVEL_TENSORS) + .export_values(); + m.def( + "TRITONSERVER_InferenceTraceLevelString", + &TRITONSERVER_InferenceTraceLevelString); + py::enum_( + m, "TRITONSERVER_InferenceTraceActivity") + .value("REQUEST_START", TRITONSERVER_TRACE_REQUEST_START) + .value("QUEUE_START", TRITONSERVER_TRACE_QUEUE_START) + .value("COMPUTE_START", TRITONSERVER_TRACE_COMPUTE_START) + .value("COMPUTE_INPUT_END", TRITONSERVER_TRACE_COMPUTE_INPUT_END) + .value("COMPUTE_OUTPUT_START", TRITONSERVER_TRACE_COMPUTE_OUTPUT_START) + .value("COMPUTE_END", TRITONSERVER_TRACE_COMPUTE_END) + .value("REQUEST_END", TRITONSERVER_TRACE_REQUEST_END) + .value("TENSOR_QUEUE_INPUT", TRITONSERVER_TRACE_TENSOR_QUEUE_INPUT) + .value("TENSOR_BACKEND_INPUT", TRITONSERVER_TRACE_TENSOR_BACKEND_INPUT) + .value("TENSOR_BACKEND_OUTPUT", TRITONSERVER_TRACE_TENSOR_BACKEND_OUTPUT) + .export_values(); + m.def( + "TRITONSERVER_InferenceTraceActivityString", + &TRITONSERVER_InferenceTraceActivityString); + py::class_>( + m, "TRITONSERVER_InferenceTrace") + .def( + py::init< + int, uint64_t, PyTrace::TimestampActivityFn, + PyTrace::TensorActivityFn, PyTrace::ReleaseFn, + const py::object&>(), + py::arg("level"), py::arg("parent_id"), py::arg("activity_function"), + py::arg("tensor_activity_function"), py::arg("release_function"), + py::arg("trace_userp")) + .def( + py::init< + int, uint64_t, PyTrace::TimestampActivityFn, PyTrace::ReleaseFn, + const py::object&>(), + py::arg("level"), py::arg("parent_id"), py::arg("activity_function"), + py::arg("release_function"), py::arg("trace_userp")) + .def_property_readonly("id", &PyTrace::Id) + .def_property_readonly("parent_id", &PyTrace::ParentId) + .def_property_readonly("model_name", &PyTrace::ModelName) + .def_property_readonly("model_version", &PyTrace::ModelVersion) + .def_property_readonly("request_id", &PyTrace::RequestId); + + // TRITONSERVER_InferenceRequest + py::enum_(m, "TRITONSERVER_RequestFlag") + .value("SEQUENCE_START", TRITONSERVER_REQUEST_FLAG_SEQUENCE_START) + .value("SEQUENCE_END", TRITONSERVER_REQUEST_FLAG_SEQUENCE_END) + .export_values(); + py::enum_( + m, "TRITONSERVER_RequestReleaseFlag") + .value("ALL", TRITONSERVER_REQUEST_RELEASE_ALL) + .export_values(); + + py::class_>( + m, "TRITONSERVER_InferenceRequest") + .def( + py::init(), + py::keep_alive<1, 2>()) + .def("set_release_callback", &PyInferenceRequest::SetReleaseCallback) + .def("set_response_callback", &PyInferenceRequest::SetResponseCallback) + .def_property("id", &PyInferenceRequest::Id, &PyInferenceRequest::SetId) + .def_property( + "flags", &PyInferenceRequest::Flags, &PyInferenceRequest::SetFlags) + .def_property( + "correlation_id", &PyInferenceRequest::CorrelationId, + &PyInferenceRequest::SetCorrelationId) + .def_property( + "correlation_id_string", &PyInferenceRequest::CorrelationIdString, + &PyInferenceRequest::SetCorrelationIdString) + .def_property( + "priority", &PyInferenceRequest::Priority, + &PyInferenceRequest::SetPriority) + .def_property( + "priority_uint64", &PyInferenceRequest::PriorityUint64, + &PyInferenceRequest::SetPriorityUint64) + .def_property( + "timeout_microseconds", &PyInferenceRequest::TimeoutMicroseconds, + &PyInferenceRequest::SetTimeoutMicroseconds) + .def("add_input", &PyInferenceRequest::AddInput) + .def("add_raw_input", &PyInferenceRequest::AddRawInput) + .def("remove_input", &PyInferenceRequest::RemoveInput) + .def("remove_all_inputs", 
&PyInferenceRequest::RemoveAllInputs) + .def("append_input_data", &PyInferenceRequest::AppendInputData) + .def( + "append_input_data_with_host_policy", + &PyInferenceRequest::AppendInputDataWithHostPolicy) + .def( + "append_input_data_with_buffer_attributes", + &PyInferenceRequest::AppendInputDataWithBufferAttributes) + .def("remove_all_input_data", &PyInferenceRequest::RemoveAllInputData) + .def("add_requested_output", &PyInferenceRequest::AddRequestedOutput) + .def( + "remove_requested_output", &PyInferenceRequest::RemoveRequestedOutput) + .def( + "remove_all_requested_outputs", + &PyInferenceRequest::RemoveAllRequestedOutputs) + .def("set_string_parameter", &PyInferenceRequest::SetStringParameter) + .def("set_int_parameter", &PyInferenceRequest::SetIntParameter) + .def("set_bool_parameter", &PyInferenceRequest::SetBoolParameter); + + // TRITONSERVER_InferenceResponse + py::enum_( + m, "TRITONSERVER_ResponseCompleteFlag") + .value("FINAL", TRITONSERVER_RESPONSE_COMPLETE_FINAL) + .export_values(); + py::class_>( + m, "TRITONSERVER_InferenceResponse") + .def( + "throw_if_response_error", &PyInferenceResponse::ThrowIfResponseError) + .def_property_readonly("model", &PyInferenceResponse::Model) + .def_property_readonly("id", &PyInferenceResponse::Id) + .def_property_readonly( + "parameter_count", &PyInferenceResponse::ParameterCount) + .def("parameter", &PyInferenceResponse::Parameter) + .def_property_readonly("output_count", &PyInferenceResponse::OutputCount) + .def("output", &PyInferenceResponse::Output) + .def( + "output_classification_label", + &PyInferenceResponse::OutputClassificationLabel); + + // TRITONSERVER_ServerOptions + py::enum_(m, "TRITONSERVER_ModelControlMode") + .value("NONE", TRITONSERVER_MODEL_CONTROL_NONE) + .value("POLL", TRITONSERVER_MODEL_CONTROL_POLL) + .value("EXPLICIT", TRITONSERVER_MODEL_CONTROL_EXPLICIT); + py::enum_(m, "TRITONSERVER_RateLimitMode") + .value("OFF", TRITONSERVER_RATE_LIMIT_OFF) + .value("EXEC_COUNT", TRITONSERVER_RATE_LIMIT_EXEC_COUNT); + py::class_(m, "TRITONSERVER_ServerOptions") + .def(py::init<>()) + .def("set_server_id", &PyServerOptions::SetServerId) + .def( + "set_model_repository_path", &PyServerOptions::SetModelRepositoryPath) + .def("set_model_control_mode", &PyServerOptions::SetModelControlMode) + .def("set_startup_model", &PyServerOptions::SetStartupModel) + .def("set_strict_model_config", &PyServerOptions::SetStrictModelConfig) + .def("set_rate_limiter_mode", &PyServerOptions::SetRateLimiterMode) + .def( + "add_rate_limiter_resource", &PyServerOptions::AddRateLimiterResource) + .def( + "set_pinned_memory_pool_byte_size", + &PyServerOptions::SetPinnedMemoryPoolByteSize) + .def( + "set_cuda_memory_pool_byte_size", + &PyServerOptions::SetCudaMemoryPoolByteSize) + .def( + "set_response_cache_byte_size", + &PyServerOptions::SetResponseCacheByteSize) + .def("set_cache_config", &PyServerOptions::SetCacheConfig) + .def("set_cache_directory", &PyServerOptions::SetCacheDirectory) + .def( + "set_min_supported_compute_capability", + &PyServerOptions::SetMinSupportedComputeCapability) + .def("set_exit_on_error", &PyServerOptions::SetExitOnError) + .def("set_strict_readiness", &PyServerOptions::SetStrictReadiness) + .def("set_exit_timeout", &PyServerOptions::SetExitTimeout) + .def( + "set_buffer_manager_thread_count", + &PyServerOptions::SetBufferManagerThreadCount) + .def( + "set_model_load_thread_count", + &PyServerOptions::SetModelLoadThreadCount) + .def("set_model_namespacing", &PyServerOptions::SetModelNamespacing) + .def("set_log_file", 
&PyServerOptions::SetLogFile) + .def("set_log_info", &PyServerOptions::SetLogInfo) + .def("set_log_warn", &PyServerOptions::SetLogWarn) + .def("set_log_error", &PyServerOptions::SetLogError) + .def("set_log_format", &PyServerOptions::SetLogFormat) + .def("set_log_verbose", &PyServerOptions::SetLogVerbose) + .def("set_metrics", &PyServerOptions::SetMetrics) + .def("set_gpu_metrics", &PyServerOptions::SetGpuMetrics) + .def("set_cpu_metrics", &PyServerOptions::SetCpuMetrics) + .def("set_metrics_interval", &PyServerOptions::SetMetricsInterval) + .def("set_backend_directory", &PyServerOptions::SetBackendDirectory) + .def("set_repo_agent_directory", &PyServerOptions::SetRepoAgentDirectory) + .def( + "set_model_load_device_limit", + &PyServerOptions::SetModelLoadDeviceLimit) + .def("set_backend_config", &PyServerOptions::SetBackendConfig) + .def("set_host_policy", &PyServerOptions::SetHostPolicy) + .def("set_metrics_config", &PyServerOptions::SetMetricsConfig); + + // TRITONSERVER_Server + py::enum_(m, "TRITONSERVER_ModelBatchFlag") + .value("UNKNOWN", TRITONSERVER_BATCH_UNKNOWN) + .value("FIRST_DIM", TRITONSERVER_BATCH_FIRST_DIM) + .export_values(); + py::enum_(m, "TRITONSERVER_ModelIndexFlag") + .value("READY", TRITONSERVER_INDEX_FLAG_READY) + .export_values(); + py::enum_( + m, "TRITONSERVER_ModelTxnPropertyFlag") + .value("ONE_TO_ONE", TRITONSERVER_TXN_ONE_TO_ONE) + .value("DECOUPLED", TRITONSERVER_TXN_DECOUPLED) + .export_values(); + py::class_(m, "TRITONSERVER_Server") + .def(py::init()) + .def("stop", &PyServer::Stop) + .def("register_model_repository", &PyServer::RegisterModelRepository) + .def("unregister_model_repository", &PyServer::UnregisterModelRepository) + .def("poll_model_repository", &PyServer::PollModelRepository) + .def("poll_model_repository", &PyServer::PollModelRepository) + .def("is_live", &PyServer::IsLive) + .def("is_ready", &PyServer::IsReady) + .def("model_is_ready", &PyServer::ModelIsReady) + .def("model_batch_properties", &PyServer::ModelBatchProperties) + .def( + "model_transaction_properties", &PyServer::ModelTransactionProperties) + .def("metadata", &PyServer::Metadata) + .def("model_metadata", &PyServer::ModelMetadata) + .def("model_statistics", &PyServer::ModelStatistics) + .def("model_config", &PyServer::ModelConfig) + .def("model_index", &PyServer::ModelIndex) + .def("load_model", &PyServer::LoadModel) + .def("load_model_with_parameters", &PyServer::LoadModelWithParameters) + .def("unload_model", &PyServer::UnloadModel) + .def("unload_model_and_dependents", &PyServer::UnloadModelAndDependents) + .def("metrics", &PyServer::Metrics) + .def( + "infer_async", + py::overload_cast< + const std::shared_ptr&, PyTrace&>( + &PyServer::InferAsync)) + .def( + "infer_async", + py::overload_cast&>( + &PyServer::InferAsync)); + + // TRITONSERVER_MetricKind + py::enum_(m, "TRITONSERVER_MetricKind") + .value("COUNTER", TRITONSERVER_METRIC_KIND_COUNTER) + .value("GAUGE", TRITONSERVER_METRIC_KIND_GAUGE); + // TRITONSERVER_MetricFamily + py::class_(m, "TRITONSERVER_MetricFamily") + .def(py::init< + TRITONSERVER_MetricKind, const std::string&, const std::string&>()); + // TRITONSERVER_Metric + py::class_(m, "TRITONSERVER_Metric") + .def( + py::init< + PyMetricFamily&, + const std::vector>&>(), + py::keep_alive<1, 2>()) + .def_property_readonly("value", &PyMetric::Value) + .def("increment", &PyMetric::Increment) + .def("set_value", &PyMetric::SetValue) + .def_property_readonly("kind", &PyMetric::Kind); +} + +}}} // namespace triton::core::python diff --git a/src/tritonserver.cc 
+
+  // TRITONSERVER_MetricKind
+  py::enum_<TRITONSERVER_MetricKind>(m, "TRITONSERVER_MetricKind")
+      .value("COUNTER", TRITONSERVER_METRIC_KIND_COUNTER)
+      .value("GAUGE", TRITONSERVER_METRIC_KIND_GAUGE);
+  // TRITONSERVER_MetricFamily
+  py::class_<PyMetricFamily>(m, "TRITONSERVER_MetricFamily")
+      .def(py::init<
+           TRITONSERVER_MetricKind, const std::string&, const std::string&>());
+  // TRITONSERVER_Metric
+  py::class_<PyMetric>(m, "TRITONSERVER_Metric")
+      .def(
+          py::init<
+              PyMetricFamily&,
+              const std::vector<std::shared_ptr<PyParameter>>&>(),
+          py::keep_alive<1, 2>())
+      .def_property_readonly("value", &PyMetric::Value)
+      .def("increment", &PyMetric::Increment)
+      .def("set_value", &PyMetric::SetValue)
+      .def_property_readonly("kind", &PyMetric::Kind);
+}
+
+}}}  // namespace triton::core::python
diff --git a/src/tritonserver.cc b/src/tritonserver.cc
index c9fc49fc4..a4fbf4e52 100644
--- a/src/tritonserver.cc
+++ b/src/tritonserver.cc
@@ -606,6 +606,8 @@ TRITONSERVER_ParameterTypeString(TRITONSERVER_ParameterType paramtype)
       return "INT";
     case TRITONSERVER_PARAMETER_BOOL:
       return "BOOL";
+    case TRITONSERVER_PARAMETER_BYTES:
+      return "BYTES";
     default:
       break;
   }