Skip to content

Commit

Permalink
Remove kernel_interface.arrayobj
Browse files Browse the repository at this point in the history
    - Merges core.kernel_interface.arrayobj into
      kernel_api_impl.spirv.arrayobj
    - Moves arrayobj.populate_array into target.py
    - Updated populate_array to use USMNdArray type.
    - Remove core.kernel_api
  • Loading branch information
Diptorup Deb authored and ZzEeKkAa committed Apr 1, 2024
1 parent 9a34d0a commit bea7e3f
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 154 deletions.
6 changes: 0 additions & 6 deletions numba_dpex/core/kernel_interface/__init__.py

This file was deleted.

137 changes: 0 additions & 137 deletions numba_dpex/core/kernel_interface/arrayobj.py

This file was deleted.

4 changes: 2 additions & 2 deletions numba_dpex/dpnp_iface/arrayobj.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
from numba.np.arrayobj import make_array
from numba.np.numpy_support import is_nonelike

from numba_dpex.core.kernel_interface.arrayobj import (
from numba_dpex.core.types import DpnpNdArray
from numba_dpex.kernel_api_impl.spirv.arrayobj import (
_getitem_array_generic as kernel_getitem_array_generic,
)
from numba_dpex.core.types import DpnpNdArray
from numba_dpex.kernel_api_impl.spirv.target import SPIRVTargetContext

from ._intrinsic import (
Expand Down
62 changes: 61 additions & 1 deletion numba_dpex/kernel_api_impl/spirv/arrayobj.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# SPDX-FileCopyrightText: 2012 - 2024 Anaconda Inc.
# SPDX-FileCopyrightText: 2024 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: BSD-2-Clause

"""Contains SPIR-V specific array functions."""

Expand All @@ -12,10 +14,68 @@
from llvmlite.ir.builder import IRBuilder
from numba.core import cgutils, errors, types
from numba.core.base import BaseContext
from numba.np import arrayobj as np_arrayobj

from numba_dpex.kernel_api_impl.spirv.target import SPIRVTargetContext
from numba_dpex.ocl.oclimpl import _get_target_data

from .target import SPIRVTargetContext


def make_view(
context, builder, ary, return_type, data, shapes, strides
): # pylint: disable=too-many-arguments
"""
Build a view over the given array with the given parameters.
This is analog of numpy.np.arrayobj.make_view without parent and
meminfo fields, because they don't make sense on device. This function
intended to be used only in kernel targets.
"""
retary = np_arrayobj.make_array(return_type)(context, builder)
context.populate_array(
retary, data=data, shape=shapes, strides=strides, itemsize=ary.itemsize
)
return retary


def _getitem_array_generic(
context, builder, return_type, aryty, ary, index_types, indices
): # pylint: disable=too-many-arguments
"""
Return the result of indexing *ary* with the given *indices*,
returning either a scalar or a view.
This is analog of numpy.np.arrayobj._getitem_array_generic without parent
and meminfo fields, because they don't make sense on device. This function
intended to be used only in kernel targets.
"""
dataptr, view_shapes, view_strides = np_arrayobj.basic_indexing(
context,
builder,
aryty,
ary,
index_types,
indices,
boundscheck=context.enable_boundscheck,
)

if isinstance(return_type, types.Buffer):
# Build array view
retary = make_view(
context,
builder,
ary,
return_type,
dataptr,
view_shapes,
view_strides,
)
return retary._getvalue() # pylint: disable=protected-access

# Load scalar from 0-d result
assert not view_shapes
return np_arrayobj.load_item(context, builder, aryty, dataptr)


def get_itemsize(context: SPIRVTargetContext, array_type: types.Array):
"""
Expand Down
80 changes: 72 additions & 8 deletions numba_dpex/kernel_api_impl/spirv/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from numba.core.typing import cmathdecl, enumdecl

from numba_dpex.core.datamodel.models import _init_kernel_data_model_manager
from numba_dpex.core.types import IntEnumLiteral
from numba_dpex.core.types import IntEnumLiteral, USMNdArray
from numba_dpex.core.typing import dpnpdecl
from numba_dpex.kernel_api.flag_enum import FlagEnum
from numba_dpex.ocl.mathimpl import lower_ocl_impl, sig_mapper
Expand All @@ -37,9 +37,7 @@

class CompilationMode(IntEnum):
"""Flags used to determine how a function should be compiled by the
numba_dpex.experimental.dispatcher.KernelDispatcher. Note the functionality
will be merged into numba_dpex.core.kernel_interface.dispatcher in the
future.
numba_dpex.kernel_api_impl_spirv.dispatcher.KernelDispatcher.
KERNEL : Indicates that the function will be compiled into an
LLVM function that has ``spir_kernel`` calling
Expand Down Expand Up @@ -220,6 +218,75 @@ def _generate_spir_kernel_wrapper(self, func, argtypes):
module.get_function(func.name).linkage = "internal"
return wrapper

def _populate_array(
self, arraystruct, data, shape, strides, itemsize
): # pylint: disable=too-many-arguments,too-many-locals
"""
Helper function for populating array structures.
The function is copied from upstream Numba and modified to support the
USMNdArray data type that uses a different data model on SYCL devices
than the upstream types.Array data type. USMNdArray data model does not
have the ``parent`` and ``meminfo`` fields. This function intended to be
used only in the SPIRVKernelTarget.
*shape* and *strides* can be Python tuples or LLVM arrays.
"""
context = arraystruct._context # pylint: disable=protected-access
builder = arraystruct._builder # pylint: disable=protected-access
datamodel = arraystruct._datamodel # pylint: disable=protected-access
# doesn't matter what this array type instance is, it's just to get the
# fields for the data model of the standard array type in this context
standard_array = USMNdArray(ndim=1, layout="C", dtype=nb_types.float64)
standard_array_type_datamodel = context.data_model_manager[
standard_array
]
required_fields = set(standard_array_type_datamodel._fields)
datamodel_fields = set(datamodel._fields)
# Make sure that the presented array object has a data model that is
# close enough to an array for this function to proceed.
if (required_fields & datamodel_fields) != required_fields:
missing = required_fields - datamodel_fields
msg = (
f"The datamodel for type {arraystruct} is missing "
f"field{'s' if len(missing) > 1 else ''} {missing}."
)
raise ValueError(msg)

intp_t = context.get_value_type(nb_types.intp)
if isinstance(shape, (tuple, list)):
shape = cgutils.pack_array(builder, shape, intp_t)
if isinstance(strides, (tuple, list)):
strides = cgutils.pack_array(builder, strides, intp_t)
if isinstance(itemsize, int):
itemsize = intp_t(itemsize)

attrs = {
"shape": shape,
"strides": strides,
"data": data,
"itemsize": itemsize,
}

# Calc num of items from shape
nitems = context.get_constant(nb_types.intp, 1)
unpacked_shape = cgutils.unpack_tuple(builder, shape, shape.type.count)
# (note empty shape => 0d array therefore nitems = 1)
for axlen in unpacked_shape:
nitems = builder.mul(nitems, axlen, flags=["nsw"])
attrs["nitems"] = nitems

# Make sure that we have all the fields
got_fields = set(attrs.keys())
if got_fields != required_fields:
raise ValueError(f"missing {required_fields - got_fields}")

# Set field value
for k, v in attrs.items():
setattr(arraystruct, k, v)

return arraystruct

def get_getattr(self, typ, attr):
"""
Overrides the get_getattr function to provide an implementation for
Expand Down Expand Up @@ -419,10 +486,7 @@ def populate_array(self, arr, **kwargs):
"""
Populate array structure.
"""
# pylint: disable=import-outside-toplevel
from numba_dpex.core.kernel_interface import arrayobj

return arrayobj.populate_array(arr, **kwargs)
return self._populate_array(arr, **kwargs)

def get_executable(self, func, fndesc, env):
"""Not implemented for SPIRVTargetContext"""
Expand Down

0 comments on commit bea7e3f

Please sign in to comment.