Skip to content
This repository has been archived by the owner on Jun 15, 2023. It is now read-only.

[WIP] Generic jitclass #3

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions numba_extras/jitclass/boxing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import typing_extensions as typing
from typing import Type

from numba import types

from numba.core.extending import (
box,
unbox,
NativeValue,
)

from numba.experimental.structref import _Utils

# copied from numba.experimental.structref.
# Replaced c.pyapi.call_function_objargs with c.pyapi.call in box_struct_ref
def define_boxing(struct_type: Type, obj_class: Type):
"""Define the boxing & unboxing logic for `struct_type` to `obj_class`.

Defines both boxing and unboxing.

- boxing turns an instance of `struct_type` into a PyObject of `obj_class`
- unboxing turns an instance of `obj_class` into an instance of
`struct_type` in jit-code.


Use this directly instead of `define_proxy()` when the user does not
want any constructor to be defined.
"""
if struct_type is types.StructRef:
raise ValueError(f"cannot register {types.StructRef}")

obj_ctor = obj_class._numba_box_

@box(struct_type)
def box_struct_ref(typ, val, c):
"""
Convert a raw pointer to a Python int.
"""
# import pdb; pdb.set_trace()
utils = _Utils(c.context, c.builder, typ)
struct_ref = utils.get_struct_ref(val)
meminfo = struct_ref.meminfo

mip_type = types.MemInfoPointer(types.voidptr)
boxed_meminfo = c.box(mip_type, meminfo)

ctor_pyfunc = c.pyapi.unserialize(c.pyapi.serialize_object(obj_ctor))
ty_pyobj = c.pyapi.unserialize(c.pyapi.serialize_object(typ))

args = c.pyapi.tuple_pack([ty_pyobj, boxed_meminfo])
res = c.pyapi.call(ctor_pyfunc, args)

c.pyapi.decref(args)
c.pyapi.decref(ctor_pyfunc)
c.pyapi.decref(ty_pyobj)
c.pyapi.decref(boxed_meminfo)
return res

@unbox(struct_type)
def unbox_struct_ref(typ, obj, c):
# import pdb; pdb.set_trace()
mi_obj = c.pyapi.object_getattr_string(obj, "_meminfo")

mip_type = types.MemInfoPointer(types.voidptr)

mi = c.unbox(mip_type, mi_obj).value

utils = _Utils(c.context, c.builder, typ)
struct_ref = utils.new_struct_ref(mi)
out = struct_ref._getvalue()

c.pyapi.decref(mi_obj)
return NativeValue(out)
276 changes: 276 additions & 0 deletions numba_extras/jitclass/class_descriptor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,276 @@
from collections import namedtuple

import typing_extensions as typing
from typing import (
Dict,
Type,
Optional,
Any,
Callable,
ClassVar,
Generic,
List,
Mapping,
Tuple,
TypeVar,
)

from numba import types, njit, typed
from numba.experimental import structref
from numba.experimental.structref import _Utils

from numba.core.extending import box, overload
from numba.core.imputils import lower_constant

from numba_extras.jitclass.typing_utils import (
_GenericAlias,
get_annotated_members,
get_parameters,
resolve_members,
MappedParameters,
MembersDict,
MethodsDict,
ResolvedMembersList,
NType,
)
from numba_extras.jitclass.typemap import python_numba_type_map, _check_arguments
from numba_extras.jitclass.overload_utils import (
get_methods,
wrap_and_jit,
make_function,
overload_methods,
make_overload,
make_constructor,
)
from numba_extras.jitclass.boxing import define_boxing
from numba_extras.jitclass.common import _Params


class ClassDescriptor:
init: Callable
ref_type: Type
proxy_type: Type
parameters: Tuple[TypeVar, ...]
members: MembersDict
wrapped_methods: MethodsDict
original_methods: MethodsDict
specificized: Dict[Tuple[Type, ...], Callable]

def __init__(self, jitclass, cls: Type, params: _Params):
members = params.members

if not members:
members = get_annotated_members(cls)

methods = get_methods(cls)
if "__init__" not in methods:
methods["__init__"] = lambda self: None

init = methods["__init__"]

if "__new__" in methods:
new = methods["__new__"]
if new is Generic.__new__ or new is object.__new__:
del methods["__new__"]

if "__new__" in methods:
raise NotImplementedError("Custom __new__ is not supported")

wrapped_methods = wrap_and_jit(methods)
parameters = get_parameters(cls)
name = cls.__name__

# self is not properly initialized yet. Anything except capturing may result in wierd things
ref_meta, proxy_meta = jitclass.make_ref_and_proxy_metas(
name, cls, wrapped_methods, members, params, methods, self
)
define_boxing(ref_meta, proxy_meta)

from numba.core import cgutils

ref_meta_inst = ref_meta({})

@lower_constant(ref_meta_inst)
def lower(context, builder, ty, pyval):
import pdb

pdb.set_trace()
obj = cgutils.create_struct_proxy(typ)(context, builder)
return obj._getvalue()

# init = self.init
self.init = None
meta_ctor = self.__make_constructor(ref_meta_inst)
self._meta_ctor = lambda cls, *args, **kwargs: meta_ctor()
# self.init = None

ref_cls, proxy_cls = jitclass.make_ref_and_proxy_types(
name, cls, wrapped_methods, members, params, methods, proxy_meta, self
)

define_boxing(ref_cls, proxy_cls)
# TODO more args to 'init'

self.init = init # type: ignore
self.ref_type = ref_cls
self.proxy_type = proxy_cls
self.parameters = parameters
self.members = members
self.wrapped_methods = wrapped_methods
self.original_methods = methods
self.specificized = {}

if len(parameters) > 0:
overload(proxy_cls)(
make_function(
"init", "self", "raise NotImplementedError('Not implemented')", {}
)
)
else:
ctor = self.specificize([])
typ = ctor.__numba_class_type__
from numba.core.imputils import lower_builtin

from numba.core.typing.templates import builtin_registry

def get(glbls, func):
for key, value in glbls:
if key == func:
return value

return None

ref_cls.__numba_call_impl__ = get(
builtin_registry.globals, methods["__call__"]
)

overload(proxy_cls, strict=False)(
make_function(
"init",
"*args, **kwargs",
"return ctor",
{"ctor": ctor.__original_func__},
)
)

overload_methods(methods, ref_cls)

if "__call__" in methods:
from numba.core.typing.templates import builtin_registry

def get(glbls, func):
for key, value in glbls:
if key == func:
return value

return None

ref_cls.__numba_call_impl__ = get(
builtin_registry.globals, methods["__call__"]
)
from numba.core import utils

def _get_signature(ovld, typingctx, fnty, args, kws):
sig = fnty.get_call_type(typingctx, args, kws)
sig = sig.replace(pysig=utils.pysignature(ovld))
return sig

@lower_builtin(typ, typ, types.VarArg(types.Any))
def method_impl(context, builder, sig, args):
typ = sig.args[0]
typing_context = context.typing_context
func = typ.__numba_call_impl__
fnty = typing_context.resolve_value_type(func)
sig = _get_signature(func, typing_context, fnty, sig.args, {})
call = context.get_function(fnty, sig)
# Link dependent library
context.add_linking_libs(getattr(call, "libs", ()))
return call(builder, args)

if "__len__" in methods:
ref_cls.__numba_len_impl__ = methods["__len__"]

self.__add_type_mapping(cls, proxy_cls)

def __add_type_mapping(self, cls, proxy_cls):
def construct(args):
_check_arguments(str(cls) + " constructor", len(self.parameters), args)
typ = self.__specificize_type(args)
return typ

python_numba_type_map.add(cls, construct)
python_numba_type_map.add(proxy_cls, construct)

if len(self.parameters) > 0:

@overload(cls)
def _ovl():
def impl():
# TODO raise an error
pass

return impl

else:
ctor = self.specificize([])
overload(cls, strict=False)(
make_function(
"init", "*args", "return ctor", {"ctor": ctor.__original_func__}
)
)

def __resolve_members(
self, mapped_parameters: MappedParameters
) -> ResolvedMembersList:
return resolve_members(self.members, mapped_parameters)

def __map_parameters(self, args: Tuple[Type, ...]) -> MappedParameters:
return {param: typ for param, typ in zip(self.parameters, args)}

def __name_to_type_mapping(self, mapped_parameters):
return {var.__name__: typ for var, typ in mapped_parameters.items()}

def __make_constructor(self, struct_type: types.StructRef) -> Callable:
ctor = make_constructor(self.init, struct_type)
ctor_impl = njit(ctor)
ctor_impl.__numba_class_type__ = struct_type
ctor_impl.__original_func__ = ctor

return ctor_impl

def __specificize_type(
self, args: Tuple[Type, ...]
) -> Tuple[NType, MappedParameters]:
mapped_parameters = self.__map_parameters(args)
members = self.__resolve_members(mapped_parameters)
struct_type = self.ref_type(members)
struct_type.mapped_parameters = self.__name_to_type_mapping(mapped_parameters)
self.proxy_type._type.instance_type = struct_type

return struct_type

def __specificize_ctor(self, args: Tuple[Type, ...]) -> Callable:
struct_type = self.__specificize_type(args)
ctor_impl = self.__make_constructor(struct_type)

return ctor_impl

def specificize(self, args: List[Type]) -> Callable:
if len(args) != len(self.parameters):
msg = f"Wrong number of args. "
msg += f"Expected {len(self.parameters)}({self.parameters})."
msg += f"Got {len(args)}({args})"

raise RuntimeError(msg)

_args = tuple(args)
specificized = self.specificized.get(_args)

if not specificized:
specificized = self.__specificize_ctor(_args)
self.specificized[_args] = specificized

return specificized

def _meta_constructor(self, cls, *args, **kwargs):
return self._meta_ctor(*args, **kwargs)
3 changes: 3 additions & 0 deletions numba_extras/jitclass/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from collections import namedtuple

_Params = namedtuple("_Params", ["compile_methods", "jit_options", "members"])
Loading