Skip to content

Commit

Permalink
Clean up experimental launcher
Browse files Browse the repository at this point in the history
  • Loading branch information
ZzEeKkAa committed Dec 1, 2023
1 parent 04c18bf commit b7db7f7
Show file tree
Hide file tree
Showing 3 changed files with 245 additions and 310 deletions.
1 change: 0 additions & 1 deletion numba_dpex/core/parfors/parfor_lowerer.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,6 @@ def _build_kernel_arglist(self, kernel_fn, lowerer, kernel_builder):
callargs_ptrs=callargs_ptrs,
args_list=args_list,
args_ty_list=args_ty_list,
datamodel_mgr=dpex_dmm,
)

return _KernelArgs(
Expand Down
77 changes: 73 additions & 4 deletions numba_dpex/core/utils/kernel_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from llvmlite import ir as llvmir
from numba.core import cgutils, types

from numba_dpex import utils
from numba_dpex import config, utils
from numba_dpex.core.runtime.context import DpexRTContext
from numba_dpex.core.types import DpnpNdArray
from numba_dpex.dpctl_iface import libsyclinterface_bindings as sycl
Expand Down Expand Up @@ -361,6 +361,60 @@ def _create_sycl_range(self, idx_range):

return self.builder.bitcast(range_list, intp_ptr_t)

def submit_kernel(
self,
kernel_ref: llvmir.CallInstr,
queue_ref: llvmir.PointerType,
kernel_args: list,
ty_kernel_args: list,
global_range_extents: list,
local_range_extents: list,
):
if config.DEBUG_KERNEL_LAUNCHER:
cgutils.printf(
self.builder,
"DPEX-DEBUG: Populating kernel args and arg type arrays.\n",
)

num_flattened_kernel_args = self.get_num_flattened_kernel_args(
kernel_argtys=ty_kernel_args,
)

# Create LLVM values for the kernel args list and kernel arg types list
args_list = self.allocate_kernel_arg_array(num_flattened_kernel_args)

args_ty_list = self.allocate_kernel_arg_ty_array(
num_flattened_kernel_args
)

kernel_args_ptrs = []
for arg in kernel_args:
ptr = self.builder.alloca(arg.type)
self.builder.store(arg, ptr)
kernel_args_ptrs.append(ptr)

# Populate the args_list and the args_ty_list LLVM arrays
self.populate_kernel_args_and_args_ty_arrays(
callargs_ptrs=kernel_args_ptrs,
kernel_argtys=ty_kernel_args,
args_list=args_list,
args_ty_list=args_ty_list,
)

if config.DEBUG_KERNEL_LAUNCHER:
cgutils.printf(self._builder, "DPEX-DEBUG: Submit kernel.\n")

return self.submit_sycl_kernel(
sycl_kernel_ref=kernel_ref,
sycl_queue_ref=queue_ref,
total_kernel_args=num_flattened_kernel_args,
arg_list=args_list,
arg_ty_list=args_ty_list,
global_range=global_range_extents,
local_range=local_range_extents,
wait_before_return=False,
)

def submit_sycl_kernel(
self,
sycl_kernel_ref,
Expand All @@ -373,7 +427,7 @@ def submit_sycl_kernel(
wait_before_return=True,
) -> llvmir.PointerType(llvmir.IntType(8)):
"""
Submits the kernel to the specified queue, waits.
Submits the kernel to the specified queue, waits by default.
"""
eref = None
gr = self._create_sycl_range(global_range)
Expand Down Expand Up @@ -411,19 +465,34 @@ def submit_sycl_kernel(
else:
return eref

def get_num_flattened_kernel_args(
self,
kernel_argtys: tuple[types.Type, ...],
):
num_flattened_kernel_args = 0
for arg_type in kernel_argtys:
if isinstance(arg_type, DpnpNdArray):
datamodel = self.context.data_model_manager.lookup(arg_type)
num_flattened_kernel_args += datamodel.flattened_field_count
elif arg_type in [types.complex64, types.complex128]:
num_flattened_kernel_args += 2
else:
num_flattened_kernel_args += 1

return num_flattened_kernel_args

def populate_kernel_args_and_args_ty_arrays(
self,
kernel_argtys,
callargs_ptrs,
args_list,
args_ty_list,
datamodel_mgr,
):
kernel_arg_num = 0
for arg_num, argtype in enumerate(kernel_argtys):
llvm_val = callargs_ptrs[arg_num]
if isinstance(argtype, DpnpNdArray):
datamodel = datamodel_mgr.lookup(argtype)
datamodel = self.context.data_model_manager.lookup(argtype)
self.build_array_arg(
array_val=llvm_val,
array_data_model=datamodel,
Expand Down
Loading

0 comments on commit b7db7f7

Please sign in to comment.