From d0bedd2496016293e10f6ef0d43a181d5f031f0f Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Sat, 11 Nov 2023 14:07:06 -0600 Subject: [PATCH 1/7] Adds a new literal type to store IntEnum as Literal types. - Adds a new IntEnumLiteral type with corresponding data model into the DpexExpKernelTargetContext. The type is used to pass in or define an IntEnum flag as an Integer literal inside a kernel function. --- numba_dpex/core/exceptions.py | 12 ++++ numba_dpex/experimental/__init__.py | 9 ++- numba_dpex/experimental/flag_enum.py | 25 +++++++++ .../experimental/literal_intenum_type.py | 55 +++++++++++++++++++ numba_dpex/experimental/models.py | 16 +++++- numba_dpex/experimental/target.py | 53 ++++++++++++++++++ 6 files changed, 168 insertions(+), 2 deletions(-) create mode 100644 numba_dpex/experimental/flag_enum.py create mode 100644 numba_dpex/experimental/literal_intenum_type.py diff --git a/numba_dpex/core/exceptions.py b/numba_dpex/core/exceptions.py index aa755c0f77..7c47c138b7 100644 --- a/numba_dpex/core/exceptions.py +++ b/numba_dpex/core/exceptions.py @@ -386,3 +386,15 @@ def __init__(self, extra_msg=None) -> None: if extra_msg: self.message += " due to " + extra_msg super().__init__(self.message) + + +class IllegalIntEnumLiteralValueError(Exception): + """Exception raised when an IntEnumLiteral is attempted to be created from + a non FlagEnum attribute. + """ + + def __init__(self) -> None: + self.message = ( + "An IntEnumLiteral can only be initialized from a FlagEnum member" + ) + super().__init__(self.message) diff --git a/numba_dpex/experimental/__init__.py b/numba_dpex/experimental/__init__.py index 7782fb0e76..97aad8a9fb 100644 --- a/numba_dpex/experimental/__init__.py +++ b/numba_dpex/experimental/__init__.py @@ -11,6 +11,7 @@ from .decorators import kernel from .kernel_dispatcher import KernelDispatcher from .launcher import call_kernel, call_kernel_async +from .literal_intenum_type import IntEnumLiteral from .models import * from .types import KernelDispatcherType @@ -26,4 +27,10 @@ def dpex_dispatcher_const(context): return context.get_dummy_value() -__all__ = ["kernel", "KernelDispatcher", "call_kernel", "call_kernel_async"] +__all__ = [ + "kernel", + "call_kernel", + "call_kernel_async", + "IntEnumLiteral", + "KernelDispatcher", +] diff --git a/numba_dpex/experimental/flag_enum.py b/numba_dpex/experimental/flag_enum.py new file mode 100644 index 0000000000..9d77fdd17b --- /dev/null +++ b/numba_dpex/experimental/flag_enum.py @@ -0,0 +1,25 @@ +# SPDX-FileCopyrightText: 2023 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 + +"""Provides a FlagEnum class to help distinguish IntEnum types that numba-dpex +intends to use as Integer literal types inside the compiler type inferring +infrastructure. +""" +from enum import IntEnum + + +class FlagEnum(IntEnum): + """Helper class to distinguish IntEnum types that numba-dpex should consider + as Numba Literal types. + """ + + @classmethod + def basetype(cls) -> int: + """Returns an dummy int object that helps numba-dpex infer the type of + an instance of a FlagEnum class. + + Returns: + int: Dummy int value + """ + return int(0) diff --git a/numba_dpex/experimental/literal_intenum_type.py b/numba_dpex/experimental/literal_intenum_type.py new file mode 100644 index 0000000000..7442294380 --- /dev/null +++ b/numba_dpex/experimental/literal_intenum_type.py @@ -0,0 +1,55 @@ +# SPDX-FileCopyrightText: 2023 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 + +"""Definition of a new Literal type in numba-dpex that allows treating IntEnum +members as integer literals inside a JIT compiled function. +""" +from enum import IntEnum + +from numba.core.pythonapi import box +from numba.core.typeconv import Conversion +from numba.core.types import Integer, Literal +from numba.core.typing.typeof import typeof + +from numba_dpex.core.exceptions import IllegalIntEnumLiteralValueError +from numba_dpex.experimental.flag_enum import FlagEnum + + +class IntEnumLiteral(Literal, Integer): + """A Literal type for IntEnum objects. The type contains the original Python + value of the IntEnum class in it. + """ + + # pylint: disable=W0231 + def __init__(self, value): + self._literal_init(value) + self.name = f"Literal[IntEnum]({value})" + if issubclass(value, FlagEnum): + basetype = typeof(value.basetype()) + Integer.__init__( + self, + name=self.name, + bitwidth=basetype.bitwidth, + signed=basetype.signed, + ) + else: + raise IllegalIntEnumLiteralValueError + + def can_convert_to(self, typingctx, other) -> bool: + conv = typingctx.can_convert(self.literal_type, other) + if conv is not None: + return max(conv, Conversion.promote) + return False + + +Literal.ctor_map[IntEnum] = IntEnumLiteral + + +@box(IntEnumLiteral) +def box_literal_integer(typ, val, c): + """Defines how a Numba representation for an IntEnumLiteral object should + be converted to a PyObject* object and returned back to Python. + """ + val = c.context.cast(c.builder, val, typ, typ.literal_type) + return c.box(typ.literal_type, val) diff --git a/numba_dpex/experimental/models.py b/numba_dpex/experimental/models.py index 3fb8a879fb..837ccde519 100644 --- a/numba_dpex/experimental/models.py +++ b/numba_dpex/experimental/models.py @@ -6,14 +6,28 @@ numba_dpex.experimental module. """ +from llvmlite import ir as llvmir from numba.core.datamodel import DataModelManager, models +from numba.core.datamodel.models import PrimitiveModel from numba.core.extending import register_model import numba_dpex.core.datamodel.models as dpex_core_models +from .literal_intenum_type import IntEnumLiteral from .types import KernelDispatcherType +class LiteralIntEnumModel(PrimitiveModel): + """Representation of an object of LiteralIntEnum type using Numba's + PrimitiveModel that can be represented natively in the target in all + usage contexts. + """ + + def __init__(self, dmm, fe_type): + be_type = llvmir.IntType(fe_type.bitwidth) + super().__init__(dmm, fe_type, be_type) + + def _init_exp_data_model_manager() -> DataModelManager: """Initializes a DpexExpKernelTarget-specific data model manager. @@ -28,7 +42,7 @@ def _init_exp_data_model_manager() -> DataModelManager: dmm = dpex_core_models.dpex_data_model_manager.copy() # Register the types and data model in the DpexExpTargetContext - # Add here... + dmm.register(IntEnumLiteral, LiteralIntEnumModel) return dmm diff --git a/numba_dpex/experimental/target.py b/numba_dpex/experimental/target.py index f38e1d9aea..e901d6b0f7 100644 --- a/numba_dpex/experimental/target.py +++ b/numba_dpex/experimental/target.py @@ -8,8 +8,11 @@ from functools import cached_property +from llvmlite import ir as llvmir +from numba.core import types from numba.core.descriptors import TargetDescriptor from numba.core.target_extension import GPU, target_registry +from numba.core.types.scalars import IntEnumClass from numba_dpex.core.descriptor import DpexTargetOptions from numba_dpex.core.targets.kernel_target import ( @@ -18,6 +21,9 @@ ) from numba_dpex.experimental.models import exp_dmm +from .flag_enum import FlagEnum +from .literal_intenum_type import IntEnumLiteral + # pylint: disable=R0903 class SyclDeviceExp(GPU): @@ -39,6 +45,37 @@ class DpexExpKernelTypingContext(DpexKernelTypingContext): are stable enough to be migrated to DpexKernelTypingContext. """ + def resolve_value_type(self, val): + """ + Return the numba type of a Python value that is being used + as a runtime constant. + ValueError is raised for unsupported types. + """ + + ty = super().resolve_value_type(val) + + if isinstance(ty, IntEnumClass) and issubclass(val, FlagEnum): + ty = IntEnumLiteral(val) + + return ty + + def resolve_getattr(self, typ, attr): + """ + Resolve getting the attribute *attr* (a string) on the Numba type. + The attribute's type is returned, or None if resolution failed. + """ + ty = None + + if isinstance(typ, IntEnumLiteral): + try: + attrval = getattr(typ.literal_value, attr).value + ty = types.IntegerLiteral(attrval) + except ValueError: + pass + else: + ty = super().resolve_getattr(typ, attr) + return ty + # pylint: disable=W0223 # FIXME: Remove the pylint disablement once we add an override for @@ -56,6 +93,22 @@ def __init__(self, typingctx, target=DPEX_KERNEL_EXP_TARGET_NAME): super().__init__(typingctx, target) self.data_model_manager = exp_dmm + def get_getattr(self, typ, attr): + """ + Overrides the get_getattr function to provide an implementation for + getattr call on an IntegerEnumLiteral type. + """ + + if isinstance(typ, IntEnumLiteral): + # pylint: disable=W0613 + def enum_literal_getattr_imp(context, builder, typ, val, attr): + enum_attr_value = getattr(typ.literal_value, attr).value + return llvmir.Constant(llvmir.IntType(64), enum_attr_value) + + return enum_literal_getattr_imp + + return super().get_getattr(typ, attr) + class DpexExpKernelTarget(TargetDescriptor): """ From ec22cc45a6f091756f1142f6a4b66d45327ae2a8 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Sat, 25 Nov 2023 11:26:17 -0600 Subject: [PATCH 2/7] Add utility function to return KernelCompileResult from KernelDispatcher --- numba_dpex/experimental/kernel_dispatcher.py | 6 +++--- numba_dpex/experimental/launcher.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/numba_dpex/experimental/kernel_dispatcher.py b/numba_dpex/experimental/kernel_dispatcher.py index 594f411c22..f5ef4fd360 100644 --- a/numba_dpex/experimental/kernel_dispatcher.py +++ b/numba_dpex/experimental/kernel_dispatcher.py @@ -262,12 +262,12 @@ def add_overload(self, cres): args = tuple(cres.signature.args) self.overloads[args] = cres - def get_overload_device_ir(self, sig): + def get_overload_kcres(self, sig) -> _KernelCompileResult: """ - Return the compiled device bitcode for the given signature. + Return the compiled function for the given signature. """ args, _ = sigutils.normalize_signature(sig) - return self.overloads[tuple(args)].kernel_device_ir_module + return self.overloads[tuple(args)] def compile(self, sig) -> any: disp = self._get_dispatcher_for_current_target() diff --git a/numba_dpex/experimental/launcher.py b/numba_dpex/experimental/launcher.py index f5fbb99649..420d7f5758 100644 --- a/numba_dpex/experimental/launcher.py +++ b/numba_dpex/experimental/launcher.py @@ -303,9 +303,9 @@ def _submit_kernel( # codegen kernel_dispatcher: KernelDispatcher = ty_kernel_fn.dispatcher kernel_dispatcher.compile(kernel_sig) - kernel_module: _KernelModule = kernel_dispatcher.get_overload_device_ir( + kernel_module: _KernelModule = kernel_dispatcher.get_overload_kcres( kernel_sig - ) + ).kernel_device_ir_module kernel_targetctx = kernel_dispatcher.targetctx def codegen(cgctx, builder, sig, llargs): From d01968175428cdcb738faad3b062b8b12a717548 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Tue, 21 Nov 2023 22:56:05 -0600 Subject: [PATCH 3/7] Unit tests --- .../experimental/IntEnumLiteral/__init__.py | 0 .../IntEnumLiteral/test_type_creation.py | 30 ++++++++++++++++ .../IntEnumLiteral/test_type_registration.py | 36 +++++++++++++++++++ 3 files changed, 66 insertions(+) create mode 100644 numba_dpex/tests/experimental/IntEnumLiteral/__init__.py create mode 100644 numba_dpex/tests/experimental/IntEnumLiteral/test_type_creation.py create mode 100644 numba_dpex/tests/experimental/IntEnumLiteral/test_type_registration.py diff --git a/numba_dpex/tests/experimental/IntEnumLiteral/__init__.py b/numba_dpex/tests/experimental/IntEnumLiteral/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/numba_dpex/tests/experimental/IntEnumLiteral/test_type_creation.py b/numba_dpex/tests/experimental/IntEnumLiteral/test_type_creation.py new file mode 100644 index 0000000000..bf19d6d816 --- /dev/null +++ b/numba_dpex/tests/experimental/IntEnumLiteral/test_type_creation.py @@ -0,0 +1,30 @@ +# SPDX-FileCopyrightText: 2023 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 + +from enum import IntEnum + +import pytest + +from numba_dpex.core.exceptions import IllegalIntEnumLiteralValueError +from numba_dpex.experimental import IntEnumLiteral +from numba_dpex.experimental.flag_enum import FlagEnum + + +def test_intenumliteral_creation(): + """Tests the creation of an IntEnumLiteral type.""" + + class DummyFlags(FlagEnum): + DUMMY = 0 + + try: + IntEnumLiteral(DummyFlags) + except: + pytest.fail("Unexpected failure in IntEnumLiteral initialization") + + with pytest.raises(IllegalIntEnumLiteralValueError): + + class SomeKindOfUnknownEnum(IntEnum): + UNKNOWN_FLAG = 1 + + IntEnumLiteral(SomeKindOfUnknownEnum) diff --git a/numba_dpex/tests/experimental/IntEnumLiteral/test_type_registration.py b/numba_dpex/tests/experimental/IntEnumLiteral/test_type_registration.py new file mode 100644 index 0000000000..a0c9700649 --- /dev/null +++ b/numba_dpex/tests/experimental/IntEnumLiteral/test_type_registration.py @@ -0,0 +1,36 @@ +# SPDX-FileCopyrightText: 2023 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 + +import pytest +from numba.core.datamodel import default_manager + +from numba_dpex.core.datamodel.models import dpex_data_model_manager +from numba_dpex.experimental import IntEnumLiteral +from numba_dpex.experimental.flag_enum import FlagEnum +from numba_dpex.experimental.models import exp_dmm + + +def test_data_model_registration(): + """Tests that the IntEnumLiteral type is only registered with the + DpexExpKernelTargetContext target. + """ + + class DummyFlags(FlagEnum): + DUMMY = 0 + + dummy = IntEnumLiteral(DummyFlags) + + with pytest.raises(KeyError): + default_manager.lookup(dummy) + + with pytest.raises(KeyError): + dpex_data_model_manager.lookup(dummy) + + try: + exp_dmm.lookup(dummy) + except: + pytest.fail( + "IntEnumLiteral type lookup failed in experimental " + "data model manager" + ) From 97707a5b1c394ddd8b15053bc9a3cf306b1acd78 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Sat, 25 Nov 2023 11:27:55 -0600 Subject: [PATCH 4/7] Unit test to compile IntEnumLiteral type object --- .../IntEnumLiteral/test_compilation.py | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 numba_dpex/tests/experimental/IntEnumLiteral/test_compilation.py diff --git a/numba_dpex/tests/experimental/IntEnumLiteral/test_compilation.py b/numba_dpex/tests/experimental/IntEnumLiteral/test_compilation.py new file mode 100644 index 0000000000..c4719462e3 --- /dev/null +++ b/numba_dpex/tests/experimental/IntEnumLiteral/test_compilation.py @@ -0,0 +1,38 @@ +# SPDX-FileCopyrightText: 2023 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 + +import dpnp +from numba.core import types +from numba.extending import intrinsic, overload + +import numba_dpex.experimental as exp_dpex +from numba_dpex import Range, dpjit +from numba_dpex.experimental.flag_enum import FlagEnum + + +class MockFlags(FlagEnum): + FLAG1 = 100 + FLAG2 = 200 + + +@exp_dpex.kernel( + release_gil=False, + no_compile=True, + no_cpython_wrapper=True, + no_cfunc_wrapper=True, +) +def update_with_flag(a): + a[0] = MockFlags.FLAG1 + a[1] = MockFlags.FLAG2 + + +def test_compilation_of_flag_enum(): + """Tests if a FlagEnum subclass can be used inside a kernel function.""" + a = dpnp.ones(10, dtype=dpnp.int64) + exp_dpex.call_kernel(update_with_flag, Range(10), a) + + assert a[0] == MockFlags.FLAG1 + assert a[1] == MockFlags.FLAG2 + for idx in range(2, 9): + assert a[idx] == 1 From 1ade40d0e749c0b4e2ee329ea660e54dc9009f03 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Wed, 6 Dec 2023 22:42:34 -0600 Subject: [PATCH 5/7] Address review comments. --- numba_dpex/experimental/models.py | 4 ++-- .../tests/experimental/IntEnumLiteral/test_compilation.py | 4 +--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/numba_dpex/experimental/models.py b/numba_dpex/experimental/models.py index 837ccde519..61f154a020 100644 --- a/numba_dpex/experimental/models.py +++ b/numba_dpex/experimental/models.py @@ -17,7 +17,7 @@ from .types import KernelDispatcherType -class LiteralIntEnumModel(PrimitiveModel): +class IntEnumLiteralModel(PrimitiveModel): """Representation of an object of LiteralIntEnum type using Numba's PrimitiveModel that can be represented natively in the target in all usage contexts. @@ -42,7 +42,7 @@ def _init_exp_data_model_manager() -> DataModelManager: dmm = dpex_core_models.dpex_data_model_manager.copy() # Register the types and data model in the DpexExpTargetContext - dmm.register(IntEnumLiteral, LiteralIntEnumModel) + dmm.register(IntEnumLiteral, IntEnumLiteralModel) return dmm diff --git a/numba_dpex/tests/experimental/IntEnumLiteral/test_compilation.py b/numba_dpex/tests/experimental/IntEnumLiteral/test_compilation.py index c4719462e3..37c190f458 100644 --- a/numba_dpex/tests/experimental/IntEnumLiteral/test_compilation.py +++ b/numba_dpex/tests/experimental/IntEnumLiteral/test_compilation.py @@ -3,11 +3,9 @@ # SPDX-License-Identifier: Apache-2.0 import dpnp -from numba.core import types -from numba.extending import intrinsic, overload import numba_dpex.experimental as exp_dpex -from numba_dpex import Range, dpjit +from numba_dpex import Range from numba_dpex.experimental.flag_enum import FlagEnum From 9ce99270695b6659935f0a35cf99a26d668ba94f Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Fri, 8 Dec 2023 17:45:34 -0600 Subject: [PATCH 6/7] Expose device_func decorator at numba_dpex.experimental module. - To get the experimental device_func decorator to be usable standalone the decorator is now available via the numba_dpex.experimental module. - The DpexExpKernelTarget allows dynamic_globals to make it possible to call device_func from kernel. --- numba_dpex/experimental/__init__.py | 3 ++- numba_dpex/experimental/target.py | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/numba_dpex/experimental/__init__.py b/numba_dpex/experimental/__init__.py index 97aad8a9fb..d1fac49dc8 100644 --- a/numba_dpex/experimental/__init__.py +++ b/numba_dpex/experimental/__init__.py @@ -8,7 +8,7 @@ from numba.core.imputils import Registry -from .decorators import kernel +from .decorators import device_func, kernel from .kernel_dispatcher import KernelDispatcher from .launcher import call_kernel, call_kernel_async from .literal_intenum_type import IntEnumLiteral @@ -28,6 +28,7 @@ def dpex_dispatcher_const(context): __all__ = [ + "device_func", "kernel", "call_kernel", "call_kernel_async", diff --git a/numba_dpex/experimental/target.py b/numba_dpex/experimental/target.py index e901d6b0f7..8f74963476 100644 --- a/numba_dpex/experimental/target.py +++ b/numba_dpex/experimental/target.py @@ -89,6 +89,8 @@ class DpexExpKernelTargetContext(DpexKernelTargetContext): they are stable enough to be migrated to DpexKernelTargetContext. """ + allow_dynamic_globals = True + def __init__(self, typingctx, target=DPEX_KERNEL_EXP_TARGET_NAME): super().__init__(typingctx, target) self.data_model_manager = exp_dmm From b1ac8d641e1b1cbcf25307fdfb422fbe524c15c7 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Fri, 8 Dec 2023 17:51:52 -0600 Subject: [PATCH 7/7] Unit test checking if FlagEnum values are lowered as constants in LLVM IR. --- .../codegen/test_intenum_literal_codegen.py | 53 +++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 numba_dpex/tests/experimental/codegen/test_intenum_literal_codegen.py diff --git a/numba_dpex/tests/experimental/codegen/test_intenum_literal_codegen.py b/numba_dpex/tests/experimental/codegen/test_intenum_literal_codegen.py new file mode 100644 index 0000000000..209cb6c787 --- /dev/null +++ b/numba_dpex/tests/experimental/codegen/test_intenum_literal_codegen.py @@ -0,0 +1,53 @@ +# SPDX-FileCopyrightText: 2023 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 + +import re + +import dpctl +from numba.core import types + +import numba_dpex.experimental as exp_dpex +from numba_dpex import DpctlSyclQueue, DpnpNdArray, int64 +from numba_dpex.experimental.flag_enum import FlagEnum + + +def test_compilation_as_literal_constant(): + """Tests if FlagEnum objects are treaded as scalar constants inside + numba-dpex generated code. + + The test case compiles the kernel `pass_flags_to_func` that includes a + call to the device_func `bitwise_or_flags`. The `bitwise_or_flags` function + is passed two FlagEnum arguments. The test case evaluates the generated + LLVM IR for `pass_flags_to_func` to see if the call to `bitwise_or_flags` + has the scalar arguments `i64 1` and `i64 2`. + """ + + class PseudoFlags(FlagEnum): + FLAG1 = 1 + FLAG2 = 2 + + @exp_dpex.device_func + def bitwise_or_flags(flag1, flag2): + return flag1 | flag2 + + def pass_flags_to_func(a): + f1 = PseudoFlags.FLAG1 + f2 = PseudoFlags.FLAG2 + a[0] = bitwise_or_flags(f1, f2) + + queue_ty = DpctlSyclQueue(dpctl.SyclQueue()) + i64arr_ty = DpnpNdArray(ndim=1, dtype=int64, layout="C", queue=queue_ty) + kernel_sig = types.void(i64arr_ty) + + disp = exp_dpex.kernel(pass_flags_to_func) + disp.compile(kernel_sig) + kcres = disp.overloads[kernel_sig.args] + llvm_ir_mod = kcres.library._final_module.__str__() + + pattern = re.compile( + r"call spir_func i32 @\_Z.*bitwise\_or" + r"\_flags.*\(i64\* nonnull %.*, i64 1, i64 2\)" + ) + + assert re.search(pattern, llvm_ir_mod) is not None