Skip to content

Commit

Permalink
Update atomic example to be more stable
Browse files Browse the repository at this point in the history
  • Loading branch information
ZzEeKkAa committed Jul 31, 2023
1 parent 32b0da3 commit f323f73
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 62 deletions.
6 changes: 4 additions & 2 deletions numba_dpex/examples/kernel/atomic_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@
@ndpx.kernel
def atomic_reduction(a):
idx = ndpx.get_global_id(0)
ndpx.atomic.add(a, 0, a[idx])
ndpx.atomic.add(a, 0, a[idx + 1])


def main():
N = 10
a = np.arange(N)

# We are storing sum to the first element
a = np.arange(0, N + 1)

print("Using device ...")
print(a.device)
Expand Down
2 changes: 0 additions & 2 deletions numba_dpex/ocl/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

from .atomics import atomic_support_present
29 changes: 0 additions & 29 deletions numba_dpex/ocl/atomics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,3 @@
# SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0

import os
import os.path


def atomic_support_present():
if os.path.isfile(
os.path.join(os.path.dirname(__file__), "atomic_ops.spir")
):
return True
else:
return False


def get_atomic_spirv_path():
if atomic_support_present():
return os.path.join(os.path.dirname(__file__), "atomic_ops.spir")
else:
return None


def read_atomic_spirv_file():
path = get_atomic_spirv_path()
if path:
with open(path, "rb") as fin:
spirv = fin.read()
return spirv
else:
return None
10 changes: 4 additions & 6 deletions numba_dpex/ocl/oclimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,16 +228,14 @@ def native_atomic_add(context, builder, sig, args):
def support_atomic(dtype: types.Type) -> bool:
# This check should be the same as described in sycl documentation:
# https://registry.khronos.org/SYCL/specs/sycl-2020/html/sycl-2020.html#sec:atomic-references
# If atomic is not supported, it will be emulated by the sycl compiler.
return (
dtype == types.int32
or dtype == types.uint32
or dtype == types.float32
or (
dtype == types.int64
or dtype == types.uint64
or dtype == types.float64
)
and dpctl.get_current_queue().sycl_device.has_aspect_atomic64
or dtype == types.int64
or dtype == types.uint64
or dtype == types.float64
)


Expand Down
25 changes: 2 additions & 23 deletions numba_dpex/tests/kernel_tests/test_atomic_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@
#
# SPDX-License-Identifier: Apache-2.0

import dpctl
import dpnp as np
import pytest

import numba_dpex as dpex
from numba_dpex import config
from numba_dpex.core.descriptor import dpex_kernel_target
from numba_dpex.tests._helper import override_config

Expand Down Expand Up @@ -66,13 +64,6 @@ def f(a):
return dpex.kernel(f), request.param[1]


skip_no_atomic_support = pytest.mark.skipif(
not dpex.ocl.atomic_support_present(),
reason="No atomic support",
)


@skip_no_atomic_support
def test_kernel_atomic_simple(input_arrays, kernel_result_pair):
a, dtype = input_arrays()
kernel, expected = kernel_result_pair
Expand Down Expand Up @@ -111,7 +102,6 @@ def f(a):
return f


@skip_no_atomic_support
def test_kernel_atomic_local(input_arrays, return_list_of_op):
a, dtype = input_arrays()
op_type, expected = return_list_of_op
Expand Down Expand Up @@ -148,7 +138,6 @@ def f(a):
return dpex.kernel(f)


@skip_no_atomic_support
def test_kernel_atomic_multi_dim(
return_list_of_op, return_list_of_dim, return_dtype
):
Expand All @@ -160,13 +149,6 @@ def test_kernel_atomic_multi_dim(
assert a[0] == expected


@skip_no_atomic_support
@pytest.mark.parametrize(
"expected_native_atomic_for_device",
[
lambda device: True,
],
)
@pytest.mark.parametrize(
"function_generator", [get_func_global, get_func_local]
)
Expand All @@ -179,7 +161,6 @@ def test_kernel_atomic_multi_dim(
)
@pytest.mark.parametrize("dtype", list_of_f_dtypes)
def test_atomic_fp_native(
expected_native_atomic_for_device,
function_generator,
operator_name,
expected_spirv_function,
Expand All @@ -203,7 +184,5 @@ def test_atomic_fp_native(
typing_ctx=dpex_kernel_target.typing_context,
)

is_native_atomic = expected_spirv_function in kernel._llvm_module
assert is_native_atomic == expected_native_atomic_for_device(
dpctl.select_default_device().filter_string
)
# TODO: this may fail if code is generated for platform that emulates atomic support?
assert expected_spirv_function in kernel._llvm_module

0 comments on commit f323f73

Please sign in to comment.