From 3a4167f486b5058a34a19947fb0b17091d6e7fa9 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 25 Aug 2023 13:49:24 -0500 Subject: [PATCH] Propagate python errors through cython FFI handling --- include/tvm/runtime/c_runtime_api.h | 7 +++++++ python/tvm/_ffi/_cython/base.pxi | 5 +++-- python/tvm/_ffi/_cython/packed_func.pxi | 18 +++++++++++++++++- 3 files changed, 27 insertions(+), 3 deletions(-) diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index 36ae5c6b158e7..43cf499481080 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -244,6 +244,13 @@ typedef void* TVMObjectHandle; */ TVM_DLL void TVMAPISetLastError(const char* msg); +/*! + * \brief Used for implementing C API function. + * Set last exception before return. + * \param py_object The python exception to be set + */ +TVM_DLL void TVMAPISetLastPythonError(void* py_object); + /*! * \brief return str message of the last error * all function in this file will return 0 when success diff --git a/python/tvm/_ffi/_cython/base.pxi b/python/tvm/_ffi/_cython/base.pxi index c2c06674978dd..69e1355f7d130 100644 --- a/python/tvm/_ffi/_cython/base.pxi +++ b/python/tvm/_ffi/_cython/base.pxi @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from ..base import get_last_ffi_error +from ..base import raise_last_ffi_error from libcpp.vector cimport vector from cpython.version cimport PY_MAJOR_VERSION from cpython cimport pycapsule @@ -113,6 +113,7 @@ ctypedef void (*TVMPackedCFuncFinalizer)(void* resource_handle) # We mark the possibly long running function as nogil below. cdef extern from "tvm/runtime/c_runtime_api.h": void TVMAPISetLastError(const char* msg) + void TVMAPISetLastPythonError(void* py_object) except + const char *TVMGetLastError() int TVMFuncGetGlobal(const char* name, TVMPackedFuncHandle* out) @@ -178,7 +179,7 @@ cdef inline int CHECK_CALL(int ret) except -2: if ret == -2: return -2 if ret != 0: - raise get_last_ffi_error() + raise_last_ffi_error() return 0 diff --git a/python/tvm/_ffi/_cython/packed_func.pxi b/python/tvm/_ffi/_cython/packed_func.pxi index 7c9ef51bd6f89..24e4d877674d1 100644 --- a/python/tvm/_ffi/_cython/packed_func.pxi +++ b/python/tvm/_ffi/_cython/packed_func.pxi @@ -54,10 +54,12 @@ cdef int tvm_callback(TVMValue* args, pyargs.append(c_make_array(value.v_handle, True, False)) try: rv = local_pyfunc(*pyargs) - except Exception: + except Exception as err: msg = traceback.format_exc() msg = py2cerror(msg) TVMAPISetLastError(c_str(msg)) + TVMAPISetLastPythonError(err) + return -1 if rv is not None: if isinstance(rv, tuple): @@ -368,3 +370,17 @@ def _set_class_object_generic(object_generic_class, func_convert_to_object): global _FUNC_CONVERT_TO_OBJECT _CLASS_OBJECT_GENERIC = object_generic_class _FUNC_CONVERT_TO_OBJECT = func_convert_to_object + +# Py_INCREF and Py_DECREF are C macros, not function objects. +# Therefore, providing a wrapper function. +cdef void _py_incref_wrapper(void* py_object): + Py_INCREF(py_object) +cdef void _py_decref_wrapper(void* py_object): + Py_DECREF(py_object) + +def _init_pythonapi_inc_def_ref(): + register_func = TVMBackendRegisterEnvCAPI + register_func(c_str("Py_IncRef"), _py_incref_wrapper) + register_func(c_str("Py_DecRef"), _py_decref_wrapper) + +_init_pythonapi_inc_def_ref()