From 69b0998bb86c3c80750e6cd0861e338cb33ef7fb Mon Sep 17 00:00:00 2001 From: Yevhenii Havrylko Date: Thu, 30 Nov 2023 14:32:39 -0500 Subject: [PATCH] Clean up experimental launcher --- numba_dpex/core/utils/kernel_launcher.py | 77 +++- numba_dpex/experimental/launcher.py | 476 ++++++++--------------- 2 files changed, 243 insertions(+), 310 deletions(-) diff --git a/numba_dpex/core/utils/kernel_launcher.py b/numba_dpex/core/utils/kernel_launcher.py index 143ba1b3d5..416bed3f84 100644 --- a/numba_dpex/core/utils/kernel_launcher.py +++ b/numba_dpex/core/utils/kernel_launcher.py @@ -5,7 +5,7 @@ from llvmlite import ir as llvmir from numba.core import cgutils, types -from numba_dpex import utils +from numba_dpex import config, utils from numba_dpex.core.runtime.context import DpexRTContext from numba_dpex.core.types import DpnpNdArray from numba_dpex.dpctl_iface import libsyclinterface_bindings as sycl @@ -361,6 +361,60 @@ def _create_sycl_range(self, idx_range): return self.builder.bitcast(range_list, intp_ptr_t) + def submit_kernel( + self, + kernel_ref: llvmir.CallInstr, + queue_ref: llvmir.PointerType, + kernel_args: list, + ty_kernel_args: list, + global_range_extents: list, + local_range_extents: list, + ): + if config.DEBUG_KERNEL_LAUNCHER: + cgutils.printf( + self.builder, + "DPEX-DEBUG: Populating kernel args and arg type arrays.\n", + ) + + num_flattened_kernel_args = self.get_num_flattened_kernel_args( + kernel_argtys=ty_kernel_args, + ) + + # Create LLVM values for the kernel args list and kernel arg types list + args_list = self.allocate_kernel_arg_array(num_flattened_kernel_args) + + args_ty_list = self.allocate_kernel_arg_ty_array( + num_flattened_kernel_args + ) + + kernel_args_ptrs = [] + for arg in kernel_args: + ptr = self.builder.alloca(arg.type) + self.builder.store(arg, ptr) + kernel_args_ptrs.append(ptr) + + # Populate the args_list and the args_ty_list LLVM arrays + self.populate_kernel_args_and_args_ty_arrays( + callargs_ptrs=kernel_args_ptrs, + kernel_argtys=ty_kernel_args, + args_list=args_list, + args_ty_list=args_ty_list, + ) + + if config.DEBUG_KERNEL_LAUNCHER: + cgutils.printf(self._builder, "DPEX-DEBUG: Submit kernel.\n") + + return self.submit_sycl_kernel( + sycl_kernel_ref=kernel_ref, + sycl_queue_ref=queue_ref, + total_kernel_args=num_flattened_kernel_args, + arg_list=args_list, + arg_ty_list=args_ty_list, + global_range=global_range_extents, + local_range=local_range_extents, + wait_before_return=False, + ) + def submit_sycl_kernel( self, sycl_kernel_ref, @@ -373,7 +427,7 @@ def submit_sycl_kernel( wait_before_return=True, ) -> llvmir.PointerType(llvmir.IntType(8)): """ - Submits the kernel to the specified queue, waits. + Submits the kernel to the specified queue, waits by default. """ eref = None gr = self._create_sycl_range(global_range) @@ -411,19 +465,34 @@ def submit_sycl_kernel( else: return eref + def get_num_flattened_kernel_args( + self, + kernel_argtys: tuple[types.Type, ...], + ): + num_flattened_kernel_args = 0 + for arg_type in kernel_argtys: + if isinstance(arg_type, DpnpNdArray): + datamodel = self.context.data_model_manager.lookup(arg_type) + num_flattened_kernel_args += datamodel.flattened_field_count + elif arg_type in [types.complex64, types.complex128]: + num_flattened_kernel_args += 2 + else: + num_flattened_kernel_args += 1 + + return num_flattened_kernel_args + def populate_kernel_args_and_args_ty_arrays( self, kernel_argtys, callargs_ptrs, args_list, args_ty_list, - datamodel_mgr, ): kernel_arg_num = 0 for arg_num, argtype in enumerate(kernel_argtys): llvm_val = callargs_ptrs[arg_num] if isinstance(argtype, DpnpNdArray): - datamodel = datamodel_mgr.lookup(argtype) + datamodel = self.context.data_model_manager.lookup(argtype) self.build_array_arg( array_val=llvm_val, array_data_model=datamodel, diff --git a/numba_dpex/experimental/launcher.py b/numba_dpex/experimental/launcher.py index 7c201be4a6..fed4dc6271 100644 --- a/numba_dpex/experimental/launcher.py +++ b/numba_dpex/experimental/launcher.py @@ -6,19 +6,18 @@ from either CPython or a numba_dpex.dpjit decorated function. """ -from collections import namedtuple from typing import Union import dpctl from llvmlite import ir as llvmir from numba.core import cgutils, cpu, types from numba.core.datamodel import default_manager as numba_default_dmm +from numba.core.types.functions import Dispatcher from numba.extending import intrinsic from numba_dpex import config, dpjit, utils from numba_dpex.core.exceptions import UnreachableError from numba_dpex.core.runtime.context import DpexRTContext -from numba_dpex.core.targets.kernel_target import DpexKernelTargetContext from numba_dpex.core.types import ( DpctlSyclEvent, DpnpNdArray, @@ -28,153 +27,49 @@ from numba_dpex.core.utils import kernel_launcher as kl from numba_dpex.dpctl_iface import libsyclinterface_bindings as sycl from numba_dpex.dpctl_iface.wrappers import wrap_event_reference -from numba_dpex.experimental.kernel_dispatcher import _KernelModule -from numba_dpex.utils import create_null_ptr - -_KernelArgs = namedtuple( - "_KernelArgs", - [ - "flattened_args_count", - "array_of_kernel_args", - "array_of_kernel_arg_types", - ], -) - -_KernelSubmissionArgs = namedtuple( - "_KernelSubmissionArgs", - [ - "kernel_ref", - "queue_ref", - "kernel_args", - "global_range_extents", - "local_range_extents", - ], -) - -_LLVMIRValuesForIndexSpace = namedtuple( - "_LLVMIRValuesForNdRange", ["global_range_extents", "local_range_extents"] +from numba_dpex.experimental.kernel_dispatcher import ( + KernelDispatcher, + _KernelModule, ) +from numba_dpex.utils import create_null_ptr -class _LaunchTrampolineFunctionBodyGenerator: +class _LaunchKernelIRGenerator: """ Helper class to generate the LLVM IR for the launch_trampoline intrinsic. """ - def _get_num_flattened_kernel_args( - self, - kernel_targetctx: DpexKernelTargetContext, - kernel_argtys: tuple[types.Type, ...], - ): - num_flattened_kernel_args = 0 - for arg_type in kernel_argtys: - if isinstance(arg_type, DpnpNdArray): - datamodel = kernel_targetctx.data_model_manager.lookup(arg_type) - num_flattened_kernel_args += datamodel.flattened_field_count - elif arg_type in [types.complex64, types.complex128]: - num_flattened_kernel_args += 2 - else: - num_flattened_kernel_args += 1 - - return num_flattened_kernel_args - def __init__( self, codegen_targetctx: cpu.CPUContext, - kernel_targetctx: DpexKernelTargetContext, builder: llvmir.IRBuilder, ): - self._cpu_codegen_targetctx = codegen_targetctx - self._kernel_targetctx = kernel_targetctx - self._builder = builder - if kernel_targetctx: - self._klbuilder = kl.KernelLaunchIRBuilder( - kernel_targetctx, builder - ) + # No computations here. ALl IR code that is generated by any builder + # must be passed as arguments, not attributes. + self.context = codegen_targetctx + self.builder = builder if config.DEBUG_KERNEL_LAUNCHER: cgutils.printf( - self._builder, + self.builder, "DPEX-DEBUG: Inside the kernel launcher function\n", ) - def insert_kernel_bitcode_as_byte_str( - self, kernel_module: _KernelModule - ) -> None: - """Inserts a global constant byte string in the current LLVM module to - store the passed in SPIR-V binary blob. - """ - return self._cpu_codegen_targetctx.insert_const_bytes( - self._builder.module, - bytes=kernel_module.kernel_bitcode, - ) - - def allocate_meminfos_array(self, num_meminfos): - """Allocates an array to store nrt memory infos. - - Args: - num_meminfos (int): The number of memory infos to allocate. - - Returns: An LLVM IR value pointing to an array to store the memory - infos. - """ - builder = self._builder - context = self._cpu_codegen_targetctx - - meminfo_list = cgutils.alloca_once( - builder, - utils.get_llvm_type(context=context, type=types.voidptr), - size=context.get_constant(types.uintp, num_meminfos), - ) - - return meminfo_list - - def populate_kernel_args_and_argsty_arrays( + def extract_arguments_from_tuple( self, - kernel_argtys: tuple[types.Type, ...], - kernel_args: [llvmir.Instruction, ...], - ) -> _KernelArgs: - """Allocates an LLVM array value to store each flattened kernel arg and - another LLVM array to store the typeid for each flattened kernel arg. - The arrays are the populated with the LLVM value for each arg. - """ - num_flattened_kernel_args = self._get_num_flattened_kernel_args( - kernel_targetctx=self._kernel_targetctx, kernel_argtys=kernel_argtys - ) - - # Create LLVM values for the kernel args list and kernel arg types list - args_list = self._klbuilder.allocate_kernel_arg_array( - num_flattened_kernel_args - ) - args_ty_list = self._klbuilder.allocate_kernel_arg_ty_array( - num_flattened_kernel_args - ) - kernel_args_ptrs = [] - for arg in kernel_args: - ptr = self._builder.alloca(arg.type) - self._builder.store(arg, ptr) - kernel_args_ptrs.append(ptr) - - # Populate the args_list and the args_ty_list LLVM arrays - self._klbuilder.populate_kernel_args_and_args_ty_arrays( - callargs_ptrs=kernel_args_ptrs, - kernel_argtys=kernel_argtys, - args_list=args_list, - args_ty_list=args_ty_list, - datamodel_mgr=self._kernel_targetctx.data_model_manager, - ) + ty_kernel_args_tuple, + ll_kernel_args_tuple, + ): + """Convert arguments to kernel arguments because kernel and dpjit use + different data models""" - if config.DEBUG_KERNEL_LAUNCHER: - cgutils.printf( - self._builder, - "DPEX-DEBUG: Populated kernel args and arg type arrays.\n", + kernel_args = [] + for pos in range(len(ty_kernel_args_tuple)): + kernel_args.append( + self.builder.extract_value(ll_kernel_args_tuple, pos) ) - return _KernelArgs( - flattened_args_count=num_flattened_kernel_args, - array_of_kernel_args=args_list, - array_of_kernel_arg_types=args_ty_list, - ) + return kernel_args def allocate_meminfo_array( self, @@ -185,32 +80,33 @@ def allocate_meminfo_array( kernel arguments. The array is the populated with the LLVM value for every meminfo of the kernel arguments. """ - builder = self._builder - context = self._cpu_codegen_targetctx - meminfos = [] for arg_num, argtype in enumerate(kernel_argtys): llvm_val = kernel_args[arg_num] meminfos += [ meminfo - for ty, meminfo in context.nrt.get_meminfos( - builder, argtype, llvm_val + for ty, meminfo in self.context.nrt.get_meminfos( + self.builder, argtype, llvm_val ) ] - meminfo_list = self.allocate_meminfos_array(len(meminfos)) + meminfo_list = cgutils.alloca_once( + self.builder, + utils.get_llvm_type(context=self.context, type=types.voidptr), + size=self.context.get_constant(types.uintp, len(meminfos)), + ) for meminfo_num, meminfo in enumerate(meminfos): - meminfo_arg_dst = builder.gep( + meminfo_arg_dst = self.builder.gep( meminfo_list, - [context.get_constant(types.int32, meminfo_num)], + [self.context.get_constant(types.int32, meminfo_num)], ) - meminfo_ptr = builder.bitcast( + meminfo_ptr = self.builder.bitcast( meminfo, - utils.get_llvm_type(context=context, type=types.voidptr), + utils.get_llvm_type(context=self.context, type=types.voidptr), ) - builder.store(meminfo_ptr, meminfo_arg_dst) + self.builder.store(meminfo_ptr, meminfo_arg_dst) return len(meminfos), meminfo_list @@ -229,70 +125,85 @@ def get_queue_ref_val( for arg_num, argty in enumerate(kernel_argtys): if isinstance(argty, DpnpNdArray): llvm_val = kernel_args[arg_num] - datamodel = ( - self._cpu_codegen_targetctx.data_model_manager.lookup(argty) - ) + datamodel = self.context.data_model_manager.lookup(argty) sycl_queue_attr_pos = datamodel.get_field_position("sycl_queue") - ptr_to_queue_ref = self._builder.extract_value( + ptr_to_queue_ref = self.builder.extract_value( llvm_val, sycl_queue_attr_pos ) break return ptr_to_queue_ref - def get_kernel(self, kernel_module, kbref): + def get_kernel(self, qref, kernel_module: _KernelModule): """Returns the pointer to the sycl::kernel object in a passed in sycl::kernel_bundle wrapper object. """ - kernel_name = self._cpu_codegen_targetctx.insert_const_string( - self._builder.module, kernel_module.kernel_name + # Inserts a global constant byte string in the current LLVM module to + # store the passed in SPIR-V binary blob. + kernel_bc_byte_str = self.context.insert_const_bytes( + self.builder.module, + bytes=kernel_module.kernel_bitcode, ) - return sycl.dpctl_kernel_bundle_get_kernel( - self._builder, kbref, kernel_name + + # Send it to dpctl, so it create this function on device and returns + # reference to it. + kbref = self.create_kernel_bundle_from_spirv( + queue_ref=qref, + kernel_bc=kernel_bc_byte_str, + kernel_bc_size_in_bytes=len(kernel_module.kernel_bitcode), ) + kernel_name = self.context.insert_const_string( + self.builder.module, kernel_module.kernel_name + ) + + kernel_ref = sycl.dpctl_kernel_bundle_get_kernel( + self.builder, kbref, kernel_name + ) + + sycl.dpctl_kernel_bundle_delete(self.builder, kbref) + + return kernel_ref + def create_llvm_values_for_index_space( self, - indexer_argty: Union[RangeType, NdRangeType], - index_space_arg: llvmir.BaseStructType, - ) -> _LLVMIRValuesForIndexSpace: - """Returns a list of LLVM IR Values that hold the unboxed extents of a - Python Range or NdRange object. + ty_indexer_arg: Union[RangeType, NdRangeType], + index_arg: llvmir.BaseStructType, + ) -> tuple[list, list]: + """Returns two lists of LLVM IR Values that hold the unboxed extents of + a Python Range or NdRange object. """ - ndim = indexer_argty.ndim - grange_extents = [] - lrange_extents = [] - indexer_datamodel = numba_default_dmm.lookup(indexer_argty) + ndim = ty_indexer_arg.ndim + global_range_extents = [] + local_range_extents = [] + indexer_datamodel = numba_default_dmm.lookup(ty_indexer_arg) - if isinstance(indexer_argty, RangeType): + if isinstance(ty_indexer_arg, RangeType): for dim_num in range(ndim): dim_pos = indexer_datamodel.get_field_position( "dim" + str(dim_num) ) - grange_extents.append( - self._builder.extract_value(index_space_arg, dim_pos) + global_range_extents.append( + self.builder.extract_value(index_arg, dim_pos) ) - elif isinstance(indexer_argty, NdRangeType): + elif isinstance(ty_indexer_arg, NdRangeType): for dim_num in range(ndim): gdim_pos = indexer_datamodel.get_field_position( "gdim" + str(dim_num) ) - grange_extents.append( - self._builder.extract_value(index_space_arg, gdim_pos) + global_range_extents.append( + self.builder.extract_value(index_arg, gdim_pos) ) ldim_pos = indexer_datamodel.get_field_position( "ldim" + str(dim_num) ) - lrange_extents.append( - self._builder.extract_value(index_space_arg, ldim_pos) + local_range_extents.append( + self.builder.extract_value(index_arg, ldim_pos) ) else: raise UnreachableError - return _LLVMIRValuesForIndexSpace( - global_range_extents=grange_extents, - local_range_extents=lrange_extents, - ) + return global_range_extents, local_range_extents def create_kernel_bundle_from_spirv( self, @@ -303,110 +214,72 @@ def create_kernel_bundle_from_spirv( """Calls DPCTLKernelBundle_CreateFromSpirv to create an opaque pointer to a sycl::kernel_bundle from the SPIR-V generated for a kernel. """ - dref = sycl.dpctl_queue_get_device(self._builder, queue_ref) - cref = sycl.dpctl_queue_get_context(self._builder, queue_ref) + device_ref = sycl.dpctl_queue_get_device(self.builder, queue_ref) + context_ref = sycl.dpctl_queue_get_context(self.builder, queue_ref) args = [ - cref, - dref, + context_ref, + device_ref, kernel_bc, llvmir.Constant(llvmir.IntType(64), kernel_bc_size_in_bytes), - self._builder.load( - create_null_ptr(self._builder, self._cpu_codegen_targetctx) - ), + self.builder.load(create_null_ptr(self.builder, self.context)), ] - kbref = sycl.dpctl_kernel_bundle_create_from_spirv(self._builder, *args) - sycl.dpctl_context_delete(self._builder, cref) - sycl.dpctl_device_delete(self._builder, dref) + kb_ref = sycl.dpctl_kernel_bundle_create_from_spirv(self.builder, *args) + sycl.dpctl_context_delete(self.builder, context_ref) + sycl.dpctl_device_delete(self.builder, device_ref) if config.DEBUG_KERNEL_LAUNCHER: cgutils.printf( - self._builder, + self.builder, "DPEX-DEBUG: Generated kernel_bundle from SPIR-V.\n", ) - return kbref - - def submit( - self, submit_call_args: _KernelSubmissionArgs - ) -> llvmir.PointerType(llvmir.IntType(8)): - """Generates LLVM IR CallInst to submit a kernel to specified SYCL - queue. - """ - if config.DEBUG_KERNEL_LAUNCHER: - cgutils.printf( - self._builder, "DPEX-DEBUG: Submit sync range kernel.\n" - ) - - eref = self._klbuilder.submit_sycl_kernel( - sycl_kernel_ref=submit_call_args.kernel_ref, - sycl_queue_ref=submit_call_args.queue_ref, - total_kernel_args=submit_call_args.kernel_args.flattened_args_count, - arg_list=submit_call_args.kernel_args.array_of_kernel_args, - arg_ty_list=submit_call_args.kernel_args.array_of_kernel_arg_types, - global_range=submit_call_args.global_range_extents, - local_range=submit_call_args.local_range_extents, - wait_before_return=False, - ) - if config.DEBUG_KERNEL_LAUNCHER: - cgutils.printf(self._builder, "DPEX-DEBUG: Wait on event.\n") - - return eref + return kb_ref def acquire_meminfo_and_schedule_release( self, - qref, - eref, - total_meminfos, - meminfo_list, + queue_ref, + event_ref, + ty_kernel_args, + kernel_args, ): """Schedule sycl host task to release nrt meminfo of the arguments used to run job. Use it to keep arguments alive during kernel execution.""" - ctx = self._cpu_codegen_targetctx - builder = self._builder + total_meminfos, meminfo_list = self.allocate_meminfo_array( + ty_kernel_args, kernel_args + ) - eref_ptr = builder.alloca(eref.type) - builder.store(eref, eref_ptr) + event_ref_ptr = self.builder.alloca(event_ref.type) + self.builder.store(event_ref, event_ref_ptr) status_ptr = cgutils.alloca_once( - builder, ctx.get_value_type(types.uint64) + self.builder, self.context.get_value_type(types.uint64) ) # TODO: get dpex RT from cached property once the PR is merged # https://github.com/IntelPython/numba-dpex/pull/1027 # host_eref = ctx.dpexrt.acquire_meminfo_and_schedule_release( # noqa: W0621 - host_eref = DpexRTContext(ctx).acquire_meminfo_and_schedule_release( - builder, + host_eref = DpexRTContext( + self.context + ).acquire_meminfo_and_schedule_release( + self.builder, [ - ctx.nrt.get_nrt_api(builder), - qref, + self.context.nrt.get_nrt_api(self.builder), + queue_ref, meminfo_list, - ctx.get_constant(types.uintp, total_meminfos), - eref_ptr, - ctx.get_constant(types.uintp, 1), + self.context.get_constant(types.uintp, total_meminfos), + event_ref_ptr, + self.context.get_constant(types.uintp, 1), status_ptr, ], ) return host_eref - def cleanup( - self, - kernel_ref: llvmir.Instruction, - kernel_bundle_ref: llvmir.Instruction, - ) -> None: - """Generates calls to free up temporary resources that were allocated in - the launch_trampoline body. - """ - # Delete the kernel ref - sycl.dpctl_kernel_delete(self._builder, kernel_ref) - # Delete the kernel bundle pointer - sycl.dpctl_kernel_bundle_delete(self._builder, kernel_bundle_ref) - @intrinsic(target="cpu") def _submit_kernel( typingctx, # pylint: disable=W0613 - kernel_fn, - index_space, - kernel_args, + ty_kernel_fn: Dispatcher, + ty_index_space, + ty_kernel_args_tuple, ): """Generates IR code for call_kernel dpjit function. @@ -416,74 +289,71 @@ def _submit_kernel( extracted from the args. Finally, the actual kernel is extracted from the kernel bundle and submitted to the sycl queue. """ - kernel_args_list = list(kernel_args) # signature of this intrinsic ty_event = DpctlSyclEvent() - sig = ty_event(kernel_fn, index_space, kernel_args) - # signature of the kernel_fn - kernel_sig = types.void(*kernel_args_list) - kernel_fn.dispatcher.compile(kernel_sig) - kernel_module: _KernelModule = kernel_fn.dispatcher.get_overload_device_ir( + sig = ty_event(ty_kernel_fn, ty_index_space, ty_kernel_args_tuple) + kernel_sig = types.void(*ty_kernel_args_tuple) + # ty_kernel_fn is type specific to exact function, so we can get function + # directly from type and compile it. Thats why we don't need to get it in + # codegen + kernel_dispatcher: KernelDispatcher = ty_kernel_fn.dispatcher + kernel_dispatcher.compile(kernel_sig) + kernel_module: _KernelModule = kernel_dispatcher.get_overload_device_ir( kernel_sig ) - kernel_targetctx = kernel_fn.dispatcher.targetctx + kernel_targetctx = kernel_dispatcher.targetctx - # TODO: refactor so there are no too many locals - def codegen(cgctx, builder, sig, llargs): # pylint: disable=R0914 - kernel_argtys = kernel_sig.args - kernel_args_unpacked = [] - for pos in range(len(kernel_args)): - kernel_args_unpacked.append(builder.extract_value(llargs[2], pos)) + def codegen(cgctx, builder, sig, llargs): + # llargs[0] is kernel function that we don't need anymore (see above) + ty_index_space = sig.args[1] + ll_index_space = llargs[1] + ty_kernel_args_tuple = sig.args[2] + ll_kernel_args_tuple = llargs[2] - fn_body_gen = _LaunchTrampolineFunctionBodyGenerator( + generator = _LaunchKernelIRGenerator( codegen_targetctx=cgctx, - kernel_targetctx=kernel_targetctx, builder=builder, ) - kernel_bc_byte_str = fn_body_gen.insert_kernel_bitcode_as_byte_str( - kernel_module + kernel_args = generator.extract_arguments_from_tuple( + ty_kernel_args_tuple=ty_kernel_args_tuple, + ll_kernel_args_tuple=ll_kernel_args_tuple, ) - populated_kernel_args = ( - fn_body_gen.populate_kernel_args_and_argsty_arrays( - kernel_argtys, kernel_args_unpacked - ) + # queue_ref is just a pointer to the attribute, so we don't have to + # clean it up + queue_ref = generator.get_queue_ref_val( + kernel_argtys=ty_kernel_args_tuple, + kernel_args=kernel_args, ) - qref = fn_body_gen.get_queue_ref_val( - kernel_argtys=kernel_args_list, - kernel_args=kernel_args_unpacked, - ) + # creates new object, so we must clean it up + kernel_ref = generator.get_kernel(queue_ref, kernel_module) - kbref = fn_body_gen.create_kernel_bundle_from_spirv( - queue_ref=qref, - kernel_bc=kernel_bc_byte_str, - kernel_bc_size_in_bytes=len(kernel_module.kernel_bitcode), + # What is it for? + ( + global_range_extents, + local_range_extents, + ) = generator.create_llvm_values_for_index_space( + ty_indexer_arg=ty_index_space, + index_arg=ll_index_space, ) - kref = fn_body_gen.get_kernel(kernel_module, kbref) - - index_space_values = fn_body_gen.create_llvm_values_for_index_space( - indexer_argty=sig.args[1], - index_space_arg=llargs[1], + device_event_ref = kl.KernelLaunchIRBuilder( + kernel_targetctx, builder + ).submit_kernel( + kernel_ref=kernel_ref, + queue_ref=queue_ref, + kernel_args=kernel_args, + ty_kernel_args=ty_kernel_args_tuple, + global_range_extents=global_range_extents, + local_range_extents=local_range_extents, ) - submit_call_args = _KernelSubmissionArgs( - kernel_ref=kref, - queue_ref=qref, - kernel_args=populated_kernel_args, - global_range_extents=index_space_values.global_range_extents, - local_range_extents=index_space_values.local_range_extents, - ) + # Clean up + sycl.dpctl_kernel_delete(builder, kernel_ref) - eref = fn_body_gen.submit(submit_call_args) - # We could've just wait and delete event here, but we want to reuse - # this function in async kernel submition and unfortunately numba does - # not support conditional returns: - # https://github.com/numba/numba/issues/9314 - device_event = wrap_event_reference(cgctx, builder, eref) - return device_event + return wrap_event_reference(cgctx, builder, device_event_ref) return sig, codegen @@ -508,38 +378,32 @@ def codegen(cgctx, builder, sig, llargs): cgctx, builder, value=llargs[0] ) - kernel_args_tuple = llargs[1] - ty_kernel_args = sig.args[1] - - kernel_args = [] - for pos in range(len(ty_kernel_args)): - kernel_args.append(builder.extract_value(kernel_args_tuple, pos)) + ll_kernel_args_tuple = llargs[1] + ty_kernel_args_tuple = sig.args[1] - fn_body_gen = _LaunchTrampolineFunctionBodyGenerator( + generator = _LaunchKernelIRGenerator( codegen_targetctx=cgctx, - kernel_targetctx=None, builder=builder, ) - total_meminfos, meminfo_list = fn_body_gen.allocate_meminfo_array( - ty_kernel_args, kernel_args + kernel_args = generator.extract_arguments_from_tuple( + ty_kernel_args_tuple=ty_kernel_args_tuple, + ll_kernel_args_tuple=ll_kernel_args_tuple, ) - qref = fn_body_gen.get_queue_ref_val( - kernel_argtys=ty_kernel_args, + qref = generator.get_queue_ref_val( + kernel_argtys=ty_kernel_args_tuple, kernel_args=kernel_args, ) - host_eref = fn_body_gen.acquire_meminfo_and_schedule_release( - qref=qref, - eref=device_event.event_ref, - total_meminfos=total_meminfos, - meminfo_list=meminfo_list, + host_eref = generator.acquire_meminfo_and_schedule_release( + queue_ref=qref, + event_ref=device_event.event_ref, + ty_kernel_args=ty_kernel_args_tuple, + kernel_args=kernel_args, ) - host_event = wrap_event_reference(cgctx, builder, host_eref) - - return host_event + return wrap_event_reference(cgctx, builder, host_eref) return sig, codegen