Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FFI] Propagate Python errors across FFI boundaries #15596

Merged
merged 20 commits into from
Sep 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions include/tvm/runtime/c_runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,13 @@ typedef void* TVMObjectHandle;
*/
TVM_DLL void TVMAPISetLastError(const char* msg);

/*!
* \brief Used for implementing C API function.
* Set last exception before return.
* \param py_object The python exception to be set
*/
TVM_DLL void TVMAPISetLastPythonError(void* py_object);

/*!
* \brief return str message of the last error
* all function in this file will return 0 when success
Expand Down
45 changes: 45 additions & 0 deletions include/tvm/runtime/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,51 @@ namespace runtime {
*/
TVM_DLL void EnvCheckSignals();

/*! \brief A class that wraps a Python object and preserves its ownership.

* This class is used to wrap a PyObject* from the Python API and preserve its ownership.
* Allows for the creation of strong references to Python objects, which prevent them from being
* garbage-collected as long as the wrapper object exists.
*/
class WrappedPythonObject {
public:
/*! \brief Construct a wrapper that doesn't own anything */
WrappedPythonObject() : python_obj_(nullptr) {}

/*! \brief Conversion constructor from nullptr */
explicit WrappedPythonObject(std::nullptr_t) : python_obj_(nullptr) {}

/*! \brief Take ownership of a python object
*
* A new strong reference is created for the underlying python
* object.
*
* \param python_obj A PyObject* from the Python.h API. A new
* strong reference is created using Py_IncRef.
*/
explicit WrappedPythonObject(void* python_obj);

/*! \brief Drop ownership of a python object
*
* Removes the strong reference held by the wrapper.
*/
~WrappedPythonObject();

WrappedPythonObject(WrappedPythonObject&&);
WrappedPythonObject& operator=(WrappedPythonObject&&);

WrappedPythonObject(const WrappedPythonObject&);
WrappedPythonObject& operator=(const WrappedPythonObject&);
WrappedPythonObject& operator=(std::nullptr_t);

operator bool() { return python_obj_; }

void* raw_pointer() { return python_obj_; }

private:
void* python_obj_ = nullptr;
};

/*! \brief Registry for global function */
class Registry {
public:
Expand Down
26 changes: 18 additions & 8 deletions python/tvm/_ffi/_ctypes/packed_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import traceback
from numbers import Number, Integral

from ..base import _LIB, get_last_ffi_error, py2cerror, check_call
from ..base import _LIB, get_last_ffi_error, py2cerror, check_call, raise_last_ffi_error
from ..base import c_str, string_types
from ..runtime_ctypes import DataType, TVMByteArray, Device, ObjectRValueRef
from . import ndarray as _nd
Expand Down Expand Up @@ -80,10 +80,11 @@ def cfun(args, type_codes, num_args, ret, _):
# pylint: disable=broad-except
try:
rv = local_pyfunc(*pyargs)
except Exception:
except Exception as err:
msg = traceback.format_exc()
msg = py2cerror(msg)
_LIB.TVMAPISetLastError(c_str(msg))
_LIB.TVMAPISetLastPythonError(ctypes.py_object(err))

return -1

if rv is not None:
Expand All @@ -94,7 +95,7 @@ def cfun(args, type_codes, num_args, ret, _):
if not isinstance(ret, TVMRetValueHandle):
ret = TVMRetValueHandle(ret)
if _LIB.TVMCFuncSetReturn(ret, values, tcodes, ctypes.c_int(1)) != 0:
raise get_last_ffi_error()
raise_last_ffi_error()
_ = temp_args
_ = rv
return 0
Expand All @@ -106,7 +107,7 @@ def cfun(args, type_codes, num_args, ret, _):
pyobj = ctypes.py_object(f)
ctypes.pythonapi.Py_IncRef(pyobj)
if _LIB.TVMFuncCreateFromCFunc(f, pyobj, TVM_FREE_PYOBJ, ctypes.byref(handle)) != 0:
raise get_last_ffi_error()
raise_last_ffi_error()
return _make_packed_func(handle, False)


Expand Down Expand Up @@ -212,7 +213,7 @@ def __init__(self, handle, is_global):
def __del__(self):
if not self.is_global and _LIB is not None:
if _LIB.TVMFuncFree(self.handle) != 0:
raise get_last_ffi_error()
raise_last_ffi_error()

def __call__(self, *args):
"""Call the function with positional arguments
Expand All @@ -235,7 +236,7 @@ def __call__(self, *args):
)
!= 0
):
raise get_last_ffi_error()
raise_last_ffi_error()
_ = temp_args
_ = args
return RETURN_SWITCH[ret_tcode.value](ret_val)
Expand All @@ -258,7 +259,7 @@ def __init_handle_by_constructor__(fconstructor, args):
)
!= 0
):
raise get_last_ffi_error()
raise_last_ffi_error()
_ = temp_args
_ = args
assert ret_tcode.value == ArgTypeCode.OBJECT_HANDLE
Expand Down Expand Up @@ -333,3 +334,12 @@ def _set_class_object_generic(object_generic_class, func_convert_to_object):
global _FUNC_CONVERT_TO_OBJECT
_CLASS_OBJECT_GENERIC = object_generic_class
_FUNC_CONVERT_TO_OBJECT = func_convert_to_object


def _init_pythonapi_inc_def_ref():
register_func = _LIB.TVMBackendRegisterEnvCAPI
register_func(c_str("Py_IncRef"), ctypes.pythonapi.Py_IncRef)
register_func(c_str("Py_DecRef"), ctypes.pythonapi.Py_DecRef)


_init_pythonapi_inc_def_ref()
5 changes: 3 additions & 2 deletions python/tvm/_ffi/_cython/base.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.

from ..base import get_last_ffi_error
from ..base import raise_last_ffi_error
from libcpp.vector cimport vector
from cpython.version cimport PY_MAJOR_VERSION
from cpython cimport pycapsule
Expand Down Expand Up @@ -113,6 +113,7 @@ ctypedef void (*TVMPackedCFuncFinalizer)(void* resource_handle)
# We mark the possibly long running function as nogil below.
cdef extern from "tvm/runtime/c_runtime_api.h":
void TVMAPISetLastError(const char* msg)
void TVMAPISetLastPythonError(void* py_object) except +
const char *TVMGetLastError()
int TVMFuncGetGlobal(const char* name,
TVMPackedFuncHandle* out)
Expand Down Expand Up @@ -178,7 +179,7 @@ cdef inline int CHECK_CALL(int ret) except -2:
if ret == -2:
return -2
if ret != 0:
raise get_last_ffi_error()
raise_last_ffi_error()
return 0


Expand Down
19 changes: 17 additions & 2 deletions python/tvm/_ffi/_cython/packed_func.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,11 @@ cdef int tvm_callback(TVMValue* args,
pyargs.append(c_make_array(value.v_handle, True, False))
try:
rv = local_pyfunc(*pyargs)
except Exception:
except Exception as err:
msg = traceback.format_exc()
msg = py2cerror(msg)
TVMAPISetLastError(c_str(msg))
TVMAPISetLastPythonError(<void*>err)

return -1
if rv is not None:
if isinstance(rv, tuple):
Expand Down Expand Up @@ -368,3 +369,17 @@ def _set_class_object_generic(object_generic_class, func_convert_to_object):
global _FUNC_CONVERT_TO_OBJECT
_CLASS_OBJECT_GENERIC = object_generic_class
_FUNC_CONVERT_TO_OBJECT = func_convert_to_object

# Py_INCREF and Py_DECREF are C macros, not function objects.
# Therefore, providing a wrapper function.
cdef void _py_incref_wrapper(void* py_object):
Py_INCREF(<object>py_object)
cdef void _py_decref_wrapper(void* py_object):
Py_DECREF(<object>py_object)

def _init_pythonapi_inc_def_ref():
register_func = TVMBackendRegisterEnvCAPI
register_func(c_str("Py_IncRef"), <void*>_py_incref_wrapper)
register_func(c_str("Py_DecRef"), <void*>_py_decref_wrapper)

_init_pythonapi_inc_def_ref()
149 changes: 146 additions & 3 deletions python/tvm/_ffi/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,17 @@
# coding: utf-8
# pylint: disable=invalid-name, import-outside-toplevel
"""Base library for TVM FFI."""
import sys
import os
import ctypes
import functools
import os
import re
import sys
import types

from typing import Callable, Sequence

import numpy as np

from . import libinfo

# ----------------------------
Expand Down Expand Up @@ -333,6 +340,142 @@ def get_last_ffi_error():
return ERROR_TYPE.get(err_type, TVMError)(py_err_msg)


def _append_traceback_frame(tb, func_name, filepath, lineno):
"""Append a dummy frame to appear in the Python traceback"""

# Compile a dummy function to Python bytecode, so that with the
# filepath that we want to appear in the traceback. Any external
# debugger (e.g. pdb) that catches the exception will use the
# filepath to show code snippets from that FFI file.
code = compile(
"{}def dummy_func(): raise NotImplementedError()".format("\n" * (lineno - 1)),
filepath,
"exec",
)

# Replacing the name by updating the bytecode allows the function
# name to be values that would normally be forbidden by python
# syntax. For example, "operator()".
code = code.replace(co_consts=(code.co_consts[0].replace(co_name=func_name), func_name, None))
namespace = {}
exec(code, namespace) # pylint: disable=exec-used
dummy_func = namespace["dummy_func"]

# Execute the dummy function in order to generate a stack frame.
dummy_tb = None
try:
dummy_func()
except NotImplementedError as err:
dummy_tb = err.__traceback__

# Insert the dummy function into the stack trace.
new_frame = dummy_tb.tb_next
return types.TracebackType(tb, new_frame.tb_frame, new_frame.tb_lasti, new_frame.tb_lineno)


def _filter_traceback_frames(tb, filter_funcs: Sequence[Callable[[types.CodeType], bool]]):
orig = tb
filtered_at_least_one = False
temp_all_frames = []
filtered_frames = []

while tb is not None:
frame_code = tb.tb_frame.f_code
should_remove = any(filter_func(frame_code) for filter_func in filter_funcs)
if not should_remove:
filtered_at_least_one = True
filtered_frames.append(tb)
temp_all_frames.append(tb)
tb = tb.tb_next

if not filtered_at_least_one:
return orig

def _append_frame(tb, next_tb_frame):
return types.TracebackType(
tb, next_tb_frame.tb_frame, next_tb_frame.tb_lasti, next_tb_frame.tb_lineno
)

new_tb = functools.reduce(_append_frame, reversed(filtered_frames))

return new_tb


def raise_last_ffi_error():
"""Raise the previous error from FFI

This should be used instead of `raise get_last_ffi_error()`, as it
handle propagation of errors across an FFI boundary. For example,
if Python passes a callback to a C++ function, and the callback
raises an exception, the re-thrown exception should contain the
full stack trace, not just the stack frames that are above the
outermost FFI call.
"""

_LIB.TVMGetLastPythonError.restype = ctypes.c_void_p
_LIB.TVMGetLastBacktrace.restype = ctypes.c_char_p
py_err = _LIB.TVMGetLastPythonError()
if py_err is None:
c_err_msg = py_str(_LIB.TVMGetLastError())
py_err_msg, err_type = c2pyerror(c_err_msg)
if err_type is not None and err_type.startswith("tvm.error."):
err_type = err_type[10:]
py_err = ERROR_TYPE.get(err_type, TVMError)(py_err_msg)

else:
# TVMGetLastPythonError returns a PyObject*, with NULL when
# there is no such value. If we annotated the restype as
# ctypes.py_object, we would need to return Py_None from the
# C++ implementation. This would require introducing a
# dependency on libpython that we want to avoid when not in a
# Python environment. Therefore, casting the resulting void*
# pointer to PyObject* using ctypes.
py_err = ctypes.cast(ctypes.c_void_p(py_err), ctypes.py_object).value

tb = py_err.__traceback__

# The py_err.__traceback__ only goes from the location thrown
# up to the next FFI handoff. To have the stacktrace also
# include the C++ side, we need to adjust the __traceback__
# before re-throwing.
backtrace = _LIB.TVMGetLastBacktrace()
if backtrace:
frames = re.split(r"\n\W+\d+:\W+", py_str(backtrace))
frames = frames[1:] # Skip "Stack trace: "

for frame in frames:
if " at " in frame:
func_name, frame = frame.split(" at ", 1)
filename, lineno = frame.rsplit(":", 1)
func_name = func_name.strip()
filename = filename.strip()
lineno = int(lineno.strip())

tb = _append_traceback_frame(tb, func_name, filename, lineno)

# Remove stack frames that provide little benefit to
# debugging. These are only removed from the stack frames
# contained within the exception we are re-raising, and not to
# the stack frames that it will continue to collect.
# Therefore, there may still be a single instance of these
# frames in the outermost Python-to-FFI call.
filter_funcs = [
lambda code: "tvm/_ffi/_ctypes/packed_func.py" in code.co_filename,
lambda code: "tvm/_ffi/base.py" in code.co_filename,
]
tb = _filter_traceback_frames(tb, filter_funcs)

py_err = py_err.with_traceback(tb)

# The exception PyObject may contain a large amount of state,
# including all stack frames that may be inspected in a later
# PDB post-mortem. Therefore, we must make sure to remove the
# underlying PyObject* from the C++ side after we retrieve it.
_LIB.TVMDropLastPythonError()

raise py_err


def check_call(ret):
"""Check the return value of C API call

Expand All @@ -345,4 +488,4 @@ def check_call(ret):
return value from API calls
"""
if ret != 0:
raise get_last_ffi_error()
raise_last_ffi_error()
Loading