-
Notifications
You must be signed in to change notification settings - Fork 1
Torch Function Implementation
Based on NEP 18, this document summarises the implementation of torch_function protocol for PyTorch APIs.
- Add a dispatch decorator to PyTorch methods.
- The dispatcher then verifies function signature. It checks if the args supplied and args implemented are same, then the code is further executed.
- Next step is generating the source of code compiling it to Python and injecting it as Public API.
- The code generation checks for overloaded args. This means that if an arg supplied has `__torch_function__` defined and it handles the torch operator API, that implementation is used; otherwise it falls back to the Torch implementation. Note that the return type depends on the type of the first arg supplied. See NEP 18 for more such details.
The dispatcher when used as a decorator around a Torch method treats it differently. It checks if the arguments provided by the dispatcher and those accepted by the Torch method match. Then it exposes it as a public API.
The implementation is actually handled by `implement_torch_function`, which checks whether an arg has been overloaded with a custom function. If no arg is overloaded, the torch API that is exposed as public is used; otherwise the implementation provided by the overloaded arg is used.
def torch_function_dispatch(dispatcher, module=None, verify=True,
                            docs_from_dispatcher=False):
    """Return a decorator that adds ``__torch_function__`` dispatch to a function.

    Parameters
    ----------
    dispatcher : callable
        Returns the tuple of arguments to scan for ``__torch_function__``.
    module : str, optional
        If given, overrides ``__module__`` on the generated public API.
    verify : bool
        Check that ``dispatcher`` and the implementation share a signature.
    docs_from_dispatcher : bool
        Copy the dispatcher's docstring onto the implementation.
    """
    def decorator(implementation):
        if verify:
            verify_matching_signatures(implementation, dispatcher)
        if docs_from_dispatcher:
            add_docstring(implementation, dispatcher.__doc__)
        # The wrapper is built via exec rather than defined inline so that
        # the helper carries an interpretable name; otherwise the original
        # function does not show up at all in many cases, e.g., if it's
        # written in C++ or if the dispatcher gets an invalid keyword
        # argument.
        wrapper_name = implementation.__name__
        wrapper_source = _wrapped_func_source.format(name=wrapper_name)
        compiled = compile(
            wrapper_source,
            filename='<__torch_function__ internals>',
            mode='exec')
        namespace = {
            'implementation': implementation,
            'dispatcher': dispatcher,
            'functools': functools,
            'implement_torch_function': implement_torch_function,
        }
        exec(compiled, namespace)
        public_api = namespace[wrapper_name]
        if module is not None:
            public_api.__module__ = module
        public_api._implementation = implementation
        return public_api

    return decorator
A dispatcher takes a list of args passed and returns them as the members of a tuple.
For example:
def gemm_dispatcher(input, mat2, out=None):
    """Dispatcher for a GEMM-style op: expose every argument that should be
    scanned for a ``__torch_function__`` override."""
    return input, mat2, out
`verify_matching_signatures`
checks whether the args passed by the dispatcher are the same as those expected by the original implementation in the methods provided by the Torch library.
# Normalized view of a function signature: positional args, *varargs name,
# **kwargs name, and default values, as returned by getargspec.
ArgSpec = collections.namedtuple('ArgSpec',
                                 ['args', 'varargs', 'keywords', 'defaults'])
def verify_matching_signatures(implementation, dispatcher):
    """Verify that a dispatcher function has the right signature.

    Raises ``RuntimeError`` if the dispatcher's args, varargs, keywords, or
    defaults do not line up with the implementation's, or if any dispatcher
    default is something other than ``None``.
    """
    impl_spec = ArgSpec(*getargspec(implementation))
    disp_spec = ArgSpec(*getargspec(dispatcher))

    # The two signatures must agree on names, *args/**kwargs, and whether
    # defaults exist at all.
    compatible = (
        impl_spec.args == disp_spec.args
        and impl_spec.varargs == disp_spec.varargs
        and impl_spec.keywords == disp_spec.keywords
        and bool(impl_spec.defaults) == bool(disp_spec.defaults)
    )
    # When both have defaults, they must also have the same number of them.
    if compatible and impl_spec.defaults is not None:
        compatible = len(impl_spec.defaults) == len(disp_spec.defaults)
    if not compatible:
        raise RuntimeError('implementation and dispatcher for %s have '
                           'different function signatures' % implementation)

    # Dispatchers only forward arguments; any real default belongs on the
    # implementation, so every dispatcher default must be None.
    if impl_spec.defaults is not None:
        if disp_spec.defaults != (None,) * len(disp_spec.defaults):
            raise RuntimeError('dispatcher functions can only use None for '
                               'default argument values')
If the signatures match, `_wrapped_func_source` generates Python code corresponding to the implementation, compiles it, and exposes it as a public API.
_wrapped_func_source = textwrap.dedent("""
@functools.wraps(implementation)
def {name}(*args, **kwargs):
relevant_args = dispatcher(*args, **kwargs)
return implement_torch_function(
implementation, {name}, relevant_args, args, kwargs)
""")
The actual dispatch is performed by `implement_torch_function`:
def implement_torch_function(
        implementation, public_api, relevant_args, args, kwargs):
    """Call a Torch API, honouring any ``__torch_function__`` overrides.

    ``relevant_args`` are the arguments returned by the dispatcher; each
    overloaded one gets a chance to handle the call before falling back to
    the default ``implementation``.
    """
    # Check for __torch_function__ methods.
    types, overloaded_args = get_overloaded_types_and_args(relevant_args)

    # Short-cut for common cases: no overload at all, or only Tensor
    # (directly or via subclasses that do not override __torch_function__).
    use_default = (
        not overloaded_args
        or types == _TENSOR_ONLY
        or all(type(arg).__torch_function__ is _TORCH_FUNCTION
               for arg in overloaded_args)
    )
    if use_default:
        return implementation(*args, **kwargs)

    # Give each overloaded argument a chance to handle the call.
    for candidate in overloaded_args:
        # Pass `public_api` instead of `implementation` so that
        # __torch_function__ implementations can do equality/identity
        # comparisons against the public torch function.
        result = candidate.__torch_function__(public_api, types, args, kwargs)
        if result is not NotImplemented:
            return result

    func_name = '{}.{}'.format(public_api.__module__, public_api.__name__)
    raise TypeError("no implementation found for '{}' on types that implement "
                    '__torch_function__: {}'
                    .format(func_name, list(map(type, overloaded_args))))
It checks that the args supplied and the args implemented are the same.
The test can be found here .
The benchmark code was added in this commit.
In torch/csrc/autograd/generated/python_torch_functions.cpp
we try to inject our code.
// Hand-modified binding for torch.mean used to prototype __torch_function__
// dispatch: parse2 matches the call against the two mean() overloads.
static PyObject * THPVariable_mean(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  std::cout << "hello world from mean!" << std::endl;
  static PythonArgParser parser({
    "mean(Tensor input, *, ScalarType? dtype=None)",
    "mean(Tensor input, IntArrayRef[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor out=None)",
  }, /*traceable=*/true);
  ParsedArgs<5> parsed_args;
  auto r = parser.parse2(args, kwargs, parsed_args);
  // BUG FIX: these planning notes used Python-style '#' comments, which are
  // invalid preprocessor directives in C++; they must be '//' comments.
  // TODO: check if r.torch_function_dispatch == true and then look for
  //       r.tensor_like.HANDLED_FUNCTIONS[r.function_name]
  //       return call(r.tensor_like[r.function_name], args, kwargs);
  if (r.idx == 0) {
    return wrap(dispatch_mean(r.tensor(0), r.scalartypeOptional(1)));
  } else if (r.idx == 1) {
    // Overload with dim/keepdim; slot 4 is the optional `out` tensor.
    if (r.isNone(4)) {
      return wrap(dispatch_mean(r.tensor(0), r.intlist(1), r.toBool(2), r.scalartypeOptional(3)));
    } else {
      return wrap(dispatch_mean(r.tensor(0), r.intlist(1), r.toBool(2), r.scalartypeOptional(3), r.tensor(4)));
    }
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
// Templated front-end for argument parsing: validates that the caller's
// ParsedArgs<N> buffer is large enough before delegating to the
// non-templated raw_parse2.
template<int N>
inline PythonArgs PythonArgParser::parse2(PyObject* args, PyObject* kwargs, ParsedArgs<N>& dst) {
  // max_args is the largest parameter count across all registered
  // signatures; a smaller buffer would overflow when dst.args is filled.
  if (N < max_args) {
    throw ValueError("PythonArgParser: dst ParsedArgs buffer does not have enough capacity, expected %d (got %d)",
        (int)max_args, N);
  }
  return raw_parse2(args, kwargs, dst.args);
}
// Construct a parser from one format string per overload. max_args becomes
// the maximum parameter count across all signatures; function_name is taken
// from the first signature (all overloads share one name).
PythonArgParser::PythonArgParser(std::vector<std::string> fmts, bool traceable)
  : max_args(0)
  , traceable(traceable)
{
  // BUG FIX: std::vector<std::string> has no operator<<, so streaming
  // `fmts` directly does not compile; print the count here and each format
  // string individually in the loop below.
  std::cout << "fmts is => " << fmts.size() << " signature(s)" << std::endl;
  for (auto& fmt : fmts) {
    std::cout << "fmt=> " << fmt << std::endl;
    signatures_.emplace_back(fmt);
  }
  for (auto& signature : signatures_) {
    if (signature.max_args > max_args) {
      max_args = signature.max_args;
    }
  }
  if (signatures_.size() > 0) {
    function_name = signatures_[0].name;
    std::cout << "function_name is => " << function_name << std::endl;
  }
}
// Try each registered signature against (args, kwargs) and return a
// PythonArgs bound to the first one that matches.
PythonArgs PythonArgParser::raw_parse2(PyObject* args, PyObject* kwargs, PyObject* parsed_args[]) {
  std::cout << "In PythonArgParser::raw_parse" << std::endl;
  // Single-overload fast path: parse with raise_exception=true so a
  // mismatch produces a precise error for that one signature.
  if (signatures_.size() == 1) {
    auto& signature = signatures_[0];
    signature.parse2(args, kwargs, parsed_args, true);
    auto x = PythonArgs(0, traceable, signature, parsed_args);
    return x;
  }
  int i = 0;
  for (auto& signature : signatures_) {
    // raise_exception=false: a mismatch just moves on to the next overload.
    if (signature.parse2(args, kwargs, parsed_args, false)) {
      auto x = PythonArgs(i, traceable, signature, parsed_args);
      return x;
    }
    i++;
  }
  // NOTE(review): there is no return after this call — presumably
  // print_error throws and never returns; confirm, otherwise falling off
  // the end of a non-void function is undefined behavior.
  print_error(args, kwargs, parsed_args);
}
Here all the signatures are validated against args and kwargs. Add a check to detect a tensor-like PyObject with `__torch_function__` defined, then collect all such arguments in `overloaded_args` and `overloaded_types` lists. Note that inserting into the `overloaded_args` list needs to check for subclasses.
// Attempt to bind (args, kwargs) to this signature, writing the matched
// PyObject* values into dst[] in parameter-declaration order (nullptr for
// omitted optionals). Returns true on a full match; on mismatch returns
// false, or throws a descriptive TypeError when raise_exception is set.
bool FunctionSignature::parse2(PyObject* args, PyObject* kwargs, PyObject* dst[],
                               bool raise_exception) {
  std::cout << "FunctionSignature::parse2, Trying to find out torch_Function" << std::endl;
  auto nargs = PyTuple_GET_SIZE(args);
  std::cout << "nargs ->" << nargs << std::endl;
  // Count of keywords not yet consumed by a parameter; any surplus at the
  // end means the caller passed an unexpected keyword.
  ssize_t remaining_kwargs = kwargs ? PyDict_Size(kwargs) : 0;
  ssize_t arg_pos = 0;
  bool allow_varargs_intlist = false;

  // if there is a single positional IntArrayRef argument, i.e. expand(..), view(...),
  // allow a var-args style IntArrayRef, so expand(5,3) behaves as expand((5,3))
  if (max_pos_args == 1 && params[0].type_ == ParameterType::INT_LIST) {
    allow_varargs_intlist = true;
  }

  if (nargs > max_pos_args && !allow_varargs_intlist) {
    if (raise_exception) {
      // foo() takes takes 2 positional arguments but 3 were given
      extra_args(*this, nargs);
    }
    return false;
  }

  int i = 0;
  for (auto& param : params) {
    PyObject* obj = nullptr;
    bool is_kwd = false;
    if (arg_pos < nargs) {
      // extra positional args given after single positional IntArrayRef arg
      if (param.keyword_only) {
        if (raise_exception) {
          extra_args(*this, nargs);
        }
        return false;
      }
      obj = PyTuple_GET_ITEM(args, arg_pos);
    } else if (kwargs) {
      // Look up by the canonical name first, then fall back to any
      // NumPy-compatible aliases registered for this parameter.
      obj = PyDict_GetItem(kwargs, param.python_name);
      for (PyObject *numpy_name: param.numpy_python_names) {
        if (obj) {
          break;
        }
        obj = PyDict_GetItem(kwargs, numpy_name);
      }
      is_kwd = true;
    }

    if ((!obj && param.optional) || (obj == Py_None && param.allow_none)) {
      // Omitted optional (or explicit None where allowed): record a hole.
      dst[i++] = nullptr;
    } else if (!obj) {
      if (raise_exception) {
        // foo() missing 1 required positional argument: "b"
        missing_args(*this, i);
      }
      return false;
    } else if (param.check2(obj, args, kwargs)) {
      // Type check passed: bind the argument to this slot.
      dst[i++] = obj;
      // XXX: the Variable check is necessary because sizes become tensors when
      // tracer is enabled. This behavior easily leads to ambiguities, and we
      // should avoid having complex signatures that make use of it...
    } else if (allow_varargs_intlist && arg_pos == 0 && !is_kwd &&
               THPUtils_checkIndex(obj)) {
      // take all positional arguments as this parameter
      // e.g. permute(1, 2, 3) -> permute((1, 2, 3))
      dst[i++] = args;
      arg_pos = nargs;
      continue;
    } else if (raise_exception) {
      if (is_kwd) {
        // foo(): argument 'other' must be str, not int
        throw TypeError("%s(): argument '%s' must be %s, not %s",
            name.c_str(), param.name.c_str(), param.type_name().c_str(),
            Py_TYPE(obj)->tp_name);
      } else {
        // foo(): argument 'other' (position 2) must be str, not int
        throw TypeError("%s(): argument '%s' (position %d) must be %s, not %s",
            name.c_str(), param.name.c_str(), arg_pos + 1,
            param.type_name().c_str(), Py_TYPE(obj)->tp_name);
      }
    } else {
      return false;
    }

    if (!is_kwd) {
      arg_pos++;
    } else if (obj) {
      // A keyword was consumed; one fewer left to account for.
      remaining_kwargs--;
    }
  }

  if (remaining_kwargs > 0) {
    if (raise_exception) {
      // foo() got an unexpected keyword argument "b"
      extra_kwargs(*this, kwargs, nargs);
    }
    return false;
  }
  return true;
}
Once you have it figured out you need to look into gen_python_functions.py
# Code-generation template for a varargs torch binding: the generated C++
# parses (args, kwargs) against the registered signatures and dispatches.
# ${dispatch} is where the __torch_function__ check would be inserted.
PY_VARIABLE_METHOD_VARARGS = CodeTemplate("""\
static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs)
{
HANDLE_TH_ERRORS
std::cout << "hello world!" << std::endl; // added this to check
static PythonArgParser parser({
${signatures}
}, /*traceable=*/${traceable});
${unpack_self}
ParsedArgs<${max_args}> parsed_args;
auto r = parser.parse(args, kwargs, parsed_args);
${declare_namedtuple_return_types}
${dispatch} // modify dispatch to check for torch function and call torch
// function or use central dispatch machinery. See the example
// below
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
""")
# Code-generation template for a no-argument torch binding: no parser is
# needed, the call is forwarded straight to the dispatch function.
PY_VARIABLE_METHOD_NOARGS = CodeTemplate("""\
static PyObject * ${pycname}(PyObject* self_, PyObject* args)
{
HANDLE_TH_ERRORS
${declare_namedtuple_return_types}
${unpack_self}
return wrap(${namedtuple_return_type}${dispatch_name}(${actuals}));
END_HANDLE_TH_ERRORS
}
""")
Now we would need to modify `${dispatch}`.
// Prototype of the desired generated code: after parsing, check whether any
// argument carries a __torch_function__ override and, if so, delegate the
// whole call to it instead of the built-in dispatch.
static PyObject * THPVariable_mean(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  std::cout << "hello world from mean!" << std::endl;
  static PythonArgParser parser({
    "mean(Tensor input, *, ScalarType? dtype=None)",
    "mean(Tensor input, IntArrayRef[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor out=None)",
  }, /*traceable=*/true);
  ParsedArgs<5> parsed_args;
  auto r = parser.parse2(args, kwargs, parsed_args);
  std::cout << "parsed and got r" << std::endl;
  if (r.has_torch_function()) {
    std::cout << "Found torch_function" << std::endl;
    PyObject* handled_functions = maybe_get_attr(r.get_overloaded_arg(0), "__torch_function__");
    // TODO: look up handled_functions[torch.mean] rather than calling the
    // override with a bare name string.
    // BUG FIX: PyUnicode_FromString returns a new reference and
    // PyObject_CallFunctionObjArgs does not steal it, so hold it and
    // DECREF after the call to avoid leaking one string per dispatch.
    // NOTE(review): if maybe_get_attr also returns a new reference,
    // handled_functions leaks too — confirm its contract.
    PyObject* func_name = PyUnicode_FromString(r.get_func_name().data());
    PyObject* result = PyObject_CallFunctionObjArgs(handled_functions, func_name, args, kwargs, NULL);
    Py_XDECREF(func_name);
    return result;
  }
  else {
    std::cout << "Not found torch_function" << std::endl;
    if (r.idx == 0) {
      return wrap(dispatch_mean(r.tensor(0), r.scalartypeOptional(1)));
    } else if (r.idx == 1) {
      // Overload with dim/keepdim; slot 4 is the optional `out` tensor.
      if (r.isNone(4)) {
        return wrap(dispatch_mean(r.tensor(0), r.intlist(1), r.toBool(2), r.scalartypeOptional(3)));
      } else {
        return wrap(dispatch_mean(r.tensor(0), r.intlist(1), r.toBool(2), r.scalartypeOptional(3), r.tensor(4)));
      }
    }
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
The goal here is to generate something like the code snippet provided above.