Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up experimental launcher #1231

Merged
merged 1 commit into from
Dec 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion numba_dpex/core/parfors/parfor_lowerer.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,6 @@ def _build_kernel_arglist(self, kernel_fn, lowerer, kernel_builder):
callargs_ptrs=callargs_ptrs,
args_list=args_list,
args_ty_list=args_ty_list,
datamodel_mgr=dpex_dmm,
)

return _KernelArgs(
Expand Down
77 changes: 73 additions & 4 deletions numba_dpex/core/utils/kernel_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from llvmlite import ir as llvmir
from numba.core import cgutils, types

from numba_dpex import utils
from numba_dpex import config, utils
from numba_dpex.core.runtime.context import DpexRTContext
from numba_dpex.core.types import DpnpNdArray
from numba_dpex.dpctl_iface import libsyclinterface_bindings as sycl
Expand Down Expand Up @@ -361,6 +361,60 @@ def _create_sycl_range(self, idx_range):

return self.builder.bitcast(range_list, intp_ptr_t)

def submit_kernel(
    self,
    kernel_ref: llvmir.CallInstr,
    queue_ref: llvmir.PointerType,
    kernel_args: list,
    ty_kernel_args: list,
    global_range_extents: list,
    local_range_extents: list,
):
    """Packs the kernel arguments and submits the kernel without waiting.

    Generates the LLVM IR that flattens ``kernel_args`` into an argument
    array and a matching argument-type array, then forwards everything to
    ``submit_sycl_kernel`` with ``wait_before_return=False``.

    Args:
        kernel_ref: LLVM value referencing the kernel to submit.
        queue_ref: LLVM value referencing the queue to submit on.
        kernel_args: LLVM values of the (unflattened) kernel arguments.
        ty_kernel_args: Types of the kernel arguments, in the same order
            as ``kernel_args``.
        global_range_extents: Extents of the global range.
        local_range_extents: Extents of the local range.

    Returns:
        Whatever ``submit_sycl_kernel`` returns for a non-waiting
        submission (presumably the event reference -- confirm against
        ``submit_sycl_kernel``).
    """
    if config.DEBUG_KERNEL_LAUNCHER:
        cgutils.printf(
            self.builder,
            "DPEX-DEBUG: Populating kernel args and arg type arrays.\n",
        )

    num_flattened_kernel_args = self.get_num_flattened_kernel_args(
        kernel_argtys=ty_kernel_args,
    )

    # Create LLVM values for the kernel args list and kernel arg types list
    args_list = self.allocate_kernel_arg_array(num_flattened_kernel_args)

    args_ty_list = self.allocate_kernel_arg_ty_array(
        num_flattened_kernel_args
    )

    # The populate helper is handed pointers, so every argument value is
    # first spilled into its own stack slot.
    kernel_args_ptrs = []
    for arg in kernel_args:
        ptr = self.builder.alloca(arg.type)
        self.builder.store(arg, ptr)
        kernel_args_ptrs.append(ptr)

    # Populate the args_list and the args_ty_list LLVM arrays
    self.populate_kernel_args_and_args_ty_arrays(
        callargs_ptrs=kernel_args_ptrs,
        kernel_argtys=ty_kernel_args,
        args_list=args_list,
        args_ty_list=args_ty_list,
    )

    if config.DEBUG_KERNEL_LAUNCHER:
        # Use the same builder attribute as the rest of this method.
        cgutils.printf(self.builder, "DPEX-DEBUG: Submit kernel.\n")

    return self.submit_sycl_kernel(
        sycl_kernel_ref=kernel_ref,
        sycl_queue_ref=queue_ref,
        total_kernel_args=num_flattened_kernel_args,
        arg_list=args_list,
        arg_ty_list=args_ty_list,
        global_range=global_range_extents,
        local_range=local_range_extents,
        wait_before_return=False,
    )

def submit_sycl_kernel(
self,
sycl_kernel_ref,
Expand All @@ -373,7 +427,7 @@ def submit_sycl_kernel(
wait_before_return=True,
) -> llvmir.PointerType(llvmir.IntType(8)):
"""
Submits the kernel to the specified queue, waits.
Submits the kernel to the specified queue, waits by default.
"""
eref = None
gr = self._create_sycl_range(global_range)
Expand Down Expand Up @@ -411,19 +465,34 @@ def submit_sycl_kernel(
else:
return eref

def get_num_flattened_kernel_args(
    self,
    kernel_argtys: tuple[types.Type, ...],
):
    """Returns the number of kernel arguments after flattening.

    Each ``DpnpNdArray`` argument contributes the ``flattened_field_count``
    of its data model (looked up through the context's data model
    manager), each ``complex64``/``complex128`` argument contributes two
    entries, and every other argument contributes one.
    """
    complex_types = (types.complex64, types.complex128)

    total = 0
    for ty in kernel_argtys:
        if isinstance(ty, DpnpNdArray):
            model = self.context.data_model_manager.lookup(ty)
            total += model.flattened_field_count
            continue
        total += 2 if ty in complex_types else 1

    return total

def populate_kernel_args_and_args_ty_arrays(
self,
kernel_argtys,
callargs_ptrs,
args_list,
args_ty_list,
datamodel_mgr,
):
kernel_arg_num = 0
for arg_num, argtype in enumerate(kernel_argtys):
llvm_val = callargs_ptrs[arg_num]
if isinstance(argtype, DpnpNdArray):
datamodel = datamodel_mgr.lookup(argtype)
datamodel = self.context.data_model_manager.lookup(argtype)
self.build_array_arg(
array_val=llvm_val,
array_data_model=datamodel,
Expand Down
Loading
Loading