Skip to content

Commit

Permalink
test: Make cuda tests pass (#2570)
Browse files Browse the repository at this point in the history
* Refactor cuda test to use importorskip

* Use array.size instead of len() for scalar

* Add import to generated tests

* Replace argtypes with ispointerlist

* Avoid indexes being passed to array_equal

* wipwip

* refactor: use `type_` instead of `t`

* refactor: don't call `asarray` in `broadcast_arrays`

* chore: ensure we import backends directly

* fix: revert kernel name change

* fix: don't wipe exception context on exit

instead, use weakref

* fix: don't use update_wrapper - it maintains a ref

* chore: delete wip kernel

* fix: only weakref-wrap methods

* test: add xfails and split tests

* Adjust for changes in Numba.

---------

Co-authored-by: Angus Hollands <[email protected]>
Co-authored-by: Jim Pivarski <[email protected]>
  • Loading branch information
3 people authored Nov 14, 2023
1 parent 2c75c0f commit e3e4874
Show file tree
Hide file tree
Showing 13 changed files with 208 additions and 75 deletions.
10 changes: 9 additions & 1 deletion dev/generate-kernel-signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,7 @@ def by_signature(cuda_kernel_templates):
special = [repr(spec["name"])]
[type_to_pytype(x["type"], special) for x in childfunc["args"]]
dirlist = [repr(x["dir"]) for x in childfunc["args"]]
ispointerlist = [repr("List" in x["type"]) for x in childfunc["args"]]
if spec["name"] in cuda_kernels_impl:
with open(
os.path.join(
Expand All @@ -404,10 +405,12 @@ def by_signature(cuda_kernel_templates):
def f(grid, block, args):
cuda_kernel_templates.get_function(fetch_specialization([{}]))(grid, block, args)
f.dir = [{}]
f.is_ptr = [{}]
out[{}] = f
""".format(
", ".join(special),
", ".join(dirlist),
", ".join(ispointerlist),
", ".join(special),
)
)
Expand All @@ -428,8 +431,13 @@ def f(grid, block, args):
file.write(python_code)
file.write(
""" f.dir = [{}]
f.is_ptr = [{}]
out[{}] = f
""".format(", ".join(dirlist), ", ".join(special))
""".format(
", ".join(dirlist),
", ".join(ispointerlist),
", ".join(special),
)
)
else:
file.write(
Expand Down
7 changes: 6 additions & 1 deletion dev/generate-tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,7 +655,12 @@ def gencudakerneltests(specdict):
)

f.write(
"import cupy\nimport pytest\n\nimport awkward as ak\nimport awkward._connect.cuda as ak_cu\n\ncupy_backend = ak._backends.CupyBackend.instance()\n\n"
"import cupy\n"
"import pytest\n\n"
"import awkward as ak\n"
"import awkward._connect.cuda as ak_cu\n"
"from awkward._backends.cupy import CupyBackend\n\n"
"cupy_backend = CupyBackend.instance()\n\n"
)
num = 1
if spec.tests == []:
Expand Down
3 changes: 2 additions & 1 deletion src/awkward/_connect/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,8 @@ def synchronize_cuda(stream=None):
[],
)
raise invoked_kernel.error_context.decorate_exception(
ValueError,
ValueError(
f"{kernel_errors[invoked_kernel.name][int(invocation_index % math.pow(2, ERROR_BITS))]} in compiled CUDA code ({invoked_kernel.name})"
)
),
)
36 changes: 24 additions & 12 deletions src/awkward/_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import warnings
from collections.abc import Callable, Collection, Iterable, Mapping
from functools import wraps
from weakref import ref as weak_ref

import numpy

Expand All @@ -18,6 +19,22 @@


E = TypeVar("E", bound=Exception)
T = TypeVar("T")
S = TypeVar("S")
P = ParamSpec("P")


class WeakMethodProxy:
    """A callable stand-in for a bound method that does not keep its instance alive.

    Only a weak reference to the method's instance (``method.__self__``) is
    stored, together with the underlying function (``method.__func__``). The
    bound method is re-created on every call, so holding the proxy does not
    hold a strong reference to the instance.

    Args:
        method: A bound method, i.e. an object exposing ``__self__`` and
            ``__func__``.

    Raises:
        ReferenceError: On call, if the instance has already been collected.
    """

    def __init__(self, method):
        # Split the bound method into its two halves: the instance is held
        # weakly, the function strongly.
        self._this = weak_ref(method.__self__)
        self._impl = method.__func__

    def __call__(self, *args, **kwargs):
        """Re-bind the function to the live instance and invoke it."""
        this = self._this()
        if this is None:
            # Binding to `None` would return the bare function, and the call
            # would then run with its arguments shifted by one position.
            # Fail loudly with the conventional dead-weakref error instead.
            raise ReferenceError("weakly-referenced object no longer exists")
        method = self._impl.__get__(this, type(this))
        return method(*args, **kwargs)


class PartialFunction:
Expand Down Expand Up @@ -67,11 +84,6 @@ def __exit__(self, exception_type, exception_value, traceback):
):
self.handle_exception(exception_type, exception_value)
finally:
# `_kwargs` may hold cyclic references, that we really want to avoid
# as this can lead to large buffers remaining in memory for longer than absolutely necessary
# Let's just clear this, now.
self._kwargs.clear()

# Step out of the way so that another ErrorContext can become primary.
if self.primary() is self:
self._slate.__dict__.clear()
Expand Down Expand Up @@ -226,8 +238,10 @@ def __init__(self, name, args: Iterable[Any], kwargs: Mapping[str, Any]):
# if primary is not None: we won't be setting an ErrorContext
# if all nplikes are eager: no accumulation of large arrays
# --> in either case, delay string generation
string_args = PartialFunction(self._format_args, args)
string_kwargs = PartialFunction(self._format_kwargs, kwargs)
string_args = PartialFunction(WeakMethodProxy(self._format_args), args)
string_kwargs = PartialFunction(
WeakMethodProxy(self._format_kwargs), kwargs
)

super().__init__(
name=name,
Expand Down Expand Up @@ -307,7 +321,9 @@ def __init__(self, array, where):
# if primary is not None: we won't be setting an ErrorContext
# if all nplikes are eager: no accumulation of large arrays
# --> in either case, delay string generation
formatted_array = PartialFunction(self.format_argument, self._width, array)
formatted_array = PartialFunction(
WeakMethodProxy(self.format_argument), self._width, array
)
formatted_slice = PartialFunction(self.format_slice, where)
else:
formatted_array = self.format_argument(self._width, array)
Expand Down Expand Up @@ -423,10 +439,6 @@ def deprecate(
warnings.warn(warning, category, stacklevel=stacklevel + 1)


T = TypeVar("T")
P = ParamSpec("P")


def with_operation_context(func: Callable[P, T]) -> Callable[P, T]:
@wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
Expand Down
25 changes: 13 additions & 12 deletions src/awkward/_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,22 +114,23 @@ def max_length(self, args):
# TODO should kernels strip nplike wrapper? Probably
for array in args:
if self._cupy.is_own_array(array):
max_length = max(max_length, len(array))
max_length = max(max_length, array.size)
return max_length

def calc_grid(self, length):
if length > 1024:
return -(length // -1024), 1, 1
return 1, 1, 1
# CUDA blocks are limited to 1024 threads per block, so to
# have more than one block, we have at least `length // 1024` blocks
# of size 1024.
return (length // 1024) + 1, 1, 1

def calc_blocks(self, length):
if length > 1024:
return 1024, 1, 1
return length, 1, 1
# CUDA blocks are limited to 1024 threads per block
# Number of threads are given by `length`
return min(length, 1024), 1, 1

def _cast(self, x, t):
if issubclass(t, ctypes._Pointer):
# Do we have a NumPy-owned array?
def _cast(self, x, type_):
if type_:
# Do we have a CuPy-owned array?
if self._cupy.is_own_array(x):
assert self._cupy.is_c_contiguous(x)
return x
Expand All @@ -149,9 +150,9 @@ def __call__(self, *args) -> None:
cupy.array(ak_cuda.NO_ERROR),
[],
)
assert len(args) == len(self._impl.argtypes)
assert len(args) == len(self._impl.is_ptr)

args = [self._cast(x, t) for x, t in zip(args, self._impl.argtypes)]
args = [self._cast(x, t) for x, t in zip(args, self._impl.is_ptr)]

# The first arg is the invocation index which raises itself by 8 in the kernel if there was no error before.
# The second arg is the error_code.
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/contents/listoffsetarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ def _broadcast_tooffsets64(self, offsets: Index) -> ListOffsetArray:
next_content = self._content[this_start:]

if index_nplike.known_data and not index_nplike.array_equal(
this_zero_offsets, offsets
this_zero_offsets, offsets.data
):
raise ValueError("cannot broadcast nested list")

Expand Down
5 changes: 5 additions & 0 deletions tests-cuda/test_1276_cuda_num.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@

import awkward as ak

try:
ak.numba.register_and_check()
except ImportError:
pytest.skip(reason="too old Numba version", allow_module_level=True)


@pytest.mark.xfail(reason="unimplemented CUDA Kernels (awkward_ByteMaskedArray_numnull")
def test_num_1():
Expand Down
Loading

0 comments on commit e3e4874

Please sign in to comment.