diff --git a/numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_atomic_ref_overloads.py b/numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_atomic_ref_overloads.py
index 57b820e1d2..315d00c47e 100644
--- a/numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_atomic_ref_overloads.py
+++ b/numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_atomic_ref_overloads.py
@@ -69,6 +69,10 @@ def gen(context, builder, sig, args):
             "--spirv-ext=+SPV_EXT_shader_atomic_float_add"
         ]
 
+        context.extra_compile_options[LLVM_SPIRV_ARGS] = [
+            "--spirv-ext=+SPV_EXT_shader_atomic_float_min_max"
+        ]
+
         ptr_type = retty.as_pointer()
         ptr_type.addrspace = atomic_ref_ty.address_space
 
@@ -118,8 +122,35 @@ def _intrinsic_fetch_add(ty_context, ty_atomic_ref, ty_val):
     return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_add")
 
 
+def _atomic_sub_float_wrapper(gen_fn):
+    """Wrap a ``fetch_add`` codegen so that it performs a float ``fetch_sub``.
+
+    The returned codegen function negates the value operand (``args[1]``)
+    with ``builder.fneg`` and forwards the call to ``gen_fn``, returning
+    whatever ``gen_fn`` returns.
+    """
+
+    def gen(context, builder, sig, args):
+        # args is an immutable tuple, so convert it to a list, replace the
+        # value operand (args[1]) with its negation, and convert it back.
+        args_lst = list(args)
+        args_lst[1] = builder.fneg(args[1])
+
+        return gen_fn(context, builder, sig, tuple(args_lst))
+
+    return gen
+
+
 @intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
 def _intrinsic_fetch_sub(ty_context, ty_atomic_ref, ty_val):
+    if ty_atomic_ref.dtype in (types.float32, types.float64):
+        # dpcpp does not support ``__spirv_AtomicFSubEXT``. fetch_sub
+        # for floats is implemented by negating the value and calling fetch_add.
+        # For example, A.fetch_sub(val) is implemented as A.fetch_add(-val).
+        sig, gen = _intrinsic_helper(
+            ty_context, ty_atomic_ref, ty_val, "fetch_add"
+        )
+        return sig, _atomic_sub_float_wrapper(gen)
+
     return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_sub")
 
 
diff --git a/numba_dpex/tests/experimental/kernel_iface/spv_overloads/test_atomic_fetch_phi.py b/numba_dpex/tests/experimental/kernel_iface/spv_overloads/test_atomic_fetch_phi.py
index ae101ad684..c1b7c9f8a8 100644
--- a/numba_dpex/tests/experimental/kernel_iface/spv_overloads/test_atomic_fetch_phi.py
+++ b/numba_dpex/tests/experimental/kernel_iface/spv_overloads/test_atomic_fetch_phi.py
@@ -3,7 +3,9 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import dpnp
+import numpy as np
 import pytest
+from numba.core.errors import TypingError
 
 import numba_dpex as dpex
 import numba_dpex.experimental as dpex_exp
@@ -14,9 +16,17 @@
     no_bool=True, no_float16=True, no_none=True, no_complex=True
 )
 
+list_of_int_dtypes = get_all_dtypes(
+    no_bool=True, no_float=True, no_none=True, no_complex=True
+)
+
+list_of_float_dtypes = get_all_dtypes(
+    no_bool=True, no_int=True, no_float16=True, no_none=True, no_complex=True
+)
+
 
 @pytest.fixture(params=list_of_supported_dtypes)
-def input_arrays(request):
+def input_arrays_for_add(request):
     # The size of input and out arrays to be used
     N = 10
     a = dpnp.ones(N, dtype=request.param)
@@ -24,22 +34,279 @@ def input_arrays(request):
     return a, b
 
 
+@pytest.fixture(params=list_of_supported_dtypes)
+def input_arrays_for_sub(request):
+    # The size of input and out arrays to be used
+    N = 10
+    a = dpnp.ones(N, dtype=request.param)
+    b_np = np.ones(N) * N + 1
+    b = dpnp.asarray(b_np, dtype=request.param)
+    return a, b
+
+
+@pytest.fixture(params=list_of_supported_dtypes)
+def input_arrays_for_min_max(request):
+    # The size of input and out arrays to be used
+    N = 10
+    a = dpnp.arange(N, dtype=request.param)
+    b = dpnp.ones(N, dtype=request.param)
+    return a, b
+
+
+@pytest.fixture(params=list_of_int_dtypes)
+def input_arrays_for_and_or_xor(request):
+    # The size of input and out arrays to be used
+    N = 10
+    a = dpnp.arange(N, dtype=request.param)
+    b = dpnp.ones(N, dtype=request.param)
+    return a, b
+
+
 @pytest.mark.parametrize("ref_index", [0, 5])
-def test_fetch_add(input_arrays, ref_index):
+def test_fetch_add(input_arrays_for_add, ref_index):
     @dpex_exp.kernel
-    def atomic_ref_kernel(a, b, ref_index):
+    def atomic_add_kernel(a, b, ref_index):
         i = dpex.get_global_id(0)
         v = AtomicRef(b, index=ref_index)
         v.fetch_add(a[i])
 
-    a, b = input_arrays
+    a, b = input_arrays_for_add
 
-    dpex_exp.call_kernel(atomic_ref_kernel, dpex.Range(10), a, b, ref_index)
+    dpex_exp.call_kernel(atomic_add_kernel, dpex.Range(10), a, b, ref_index)
 
     # Verify that `a` was accumulated at b[ref_index]
     assert b[ref_index] == 10
 
 
+def test_fetch_add_diff_types():
+    """A negative test that verifies that a TypingError is raised if
+    AtomicRef type and value to be added are of different types.
+    """
+
+    @dpex_exp.kernel
+    def atomic_add_kernel(a, b):
+        i = dpex.get_global_id(0)
+        v = AtomicRef(b, index=0)
+        v.fetch_add(a[i])
+
+    N = 10
+    a = dpnp.ones(N, dtype=dpnp.float32)
+    b = dpnp.zeros(N, dtype=dpnp.int32)
+
+    with pytest.raises(TypingError):
+        dpex_exp.call_kernel(atomic_add_kernel, dpex.Range(10), a, b)
+
+
+@pytest.mark.parametrize("ref_index", [0, 5])
+def test_fetch_sub(input_arrays_for_sub, ref_index):
+    @dpex_exp.kernel
+    def atomic_sub_kernel(a, b, ref_index):
+        i = dpex.get_global_id(0)
+        v = AtomicRef(b, index=ref_index)
+        v.fetch_sub(a[i])
+
+    a, b = input_arrays_for_sub
+
+    dpex_exp.call_kernel(atomic_sub_kernel, dpex.Range(10), a, b, ref_index)
+
+    # Verify that b[ref_index] -= a[0:N]
+    assert b[ref_index] == 1
+
+
+def test_fetch_sub_diff_types():
+    """A negative test that verifies that a TypingError is raised if
+    AtomicRef type and value to be subtracted are of different types.
+    """
+
+    @dpex_exp.kernel
+    def atomic_sub_kernel(a, b):
+        i = dpex.get_global_id(0)
+        v = AtomicRef(b, index=0)
+        v.fetch_sub(a[i])
+
+    N = 10
+    a = dpnp.ones(N, dtype=dpnp.float32)
+    b = dpnp.zeros(N, dtype=dpnp.int32)
+
+    with pytest.raises(TypingError):
+        dpex_exp.call_kernel(atomic_sub_kernel, dpex.Range(10), a, b)
+
+
+@pytest.mark.parametrize("ref_index", [0, 5])
+def test_fetch_min(input_arrays_for_min_max, ref_index):
+    @dpex_exp.kernel
+    def atomic_min_kernel(a, b, ref_index):
+        i = dpex.get_global_id(0)
+        v = AtomicRef(b, index=ref_index)
+        v.fetch_min(a[i])
+
+    a, b = input_arrays_for_min_max
+
+    dpex_exp.call_kernel(atomic_min_kernel, dpex.Range(10), a, b, ref_index)
+
+    # Verify that b[ref_index] is min(a[0:N])
+    assert b[ref_index] == 0
+
+
+def test_fetch_min_diff_types():
+    """A negative test that verifies that a TypingError is raised if
+    AtomicRef type and compared value to compute min are of different types.
+    """
+
+    @dpex_exp.kernel
+    def atomic_min_kernel(a, b):
+        i = dpex.get_global_id(0)
+        v = AtomicRef(b, index=0)
+        v.fetch_min(a[i])
+
+    N = 10
+    a = dpnp.ones(N, dtype=dpnp.float32)
+    b = dpnp.zeros(N, dtype=dpnp.int32)
+
+    with pytest.raises(TypingError):
+        dpex_exp.call_kernel(atomic_min_kernel, dpex.Range(10), a, b)
+
+
+@pytest.mark.parametrize("ref_index", [0, 5])
+def test_fetch_max(input_arrays_for_min_max, ref_index):
+    @dpex_exp.kernel
+    def atomic_max_kernel(a, b, ref_index):
+        i = dpex.get_global_id(0)
+        v = AtomicRef(b, index=ref_index)
+        v.fetch_max(a[i])
+
+    a, b = input_arrays_for_min_max
+
+    dpex_exp.call_kernel(atomic_max_kernel, dpex.Range(10), a, b, ref_index)
+
+    # Verify that b[ref_index] is max(a[0:N])
+    assert b[ref_index] == 9
+
+
+def test_fetch_max_diff_types():
+    """A negative test that verifies that a TypingError is raised if
+    AtomicRef type and compared value to compute max are of different types.
+    """
+
+    @dpex_exp.kernel
+    def atomic_max_kernel(a, b):
+        i = dpex.get_global_id(0)
+        v = AtomicRef(b, index=0)
+        v.fetch_max(a[i])
+
+    N = 10
+    a = dpnp.ones(N, dtype=dpnp.float32)
+    b = dpnp.zeros(N, dtype=dpnp.int32)
+
+    with pytest.raises(TypingError):
+        dpex_exp.call_kernel(atomic_max_kernel, dpex.Range(10), a, b)
+
+
+@pytest.mark.parametrize("ref_index", [0, 5])
+def test_fetch_and(input_arrays_for_and_or_xor, ref_index):
+    @dpex_exp.kernel
+    def atomic_and_kernel(a, b, ref_index):
+        i = dpex.get_global_id(0)
+        v = AtomicRef(b, index=ref_index)
+        v.fetch_and(a[i])
+
+    a, b = input_arrays_for_and_or_xor
+
+    dpex_exp.call_kernel(atomic_and_kernel, dpex.Range(10), a, b, ref_index)
+
+    # b[ref_index] starts at 1 and is and-ed with a[0] == 0, so it ends at 0
+    assert b[ref_index] == 0
+
+
+def test_fetch_and_diff_types():
+    """A negative test that verifies that a TypingError is raised if
+    AtomicRef type and value for bitwise and are of different types.
+    """
+
+    @dpex_exp.kernel
+    def atomic_and_kernel(a, b):
+        i = dpex.get_global_id(0)
+        v = AtomicRef(b, index=0)
+        v.fetch_and(a[i])
+
+    N = 10
+    a = dpnp.ones(N, dtype=dpnp.float32)
+    b = dpnp.zeros(N, dtype=dpnp.int32)
+
+    with pytest.raises(TypingError):
+        dpex_exp.call_kernel(atomic_and_kernel, dpex.Range(10), a, b)
+
+
+@pytest.mark.parametrize("ref_index", [0, 5])
+def test_fetch_or(input_arrays_for_and_or_xor, ref_index):
+    @dpex_exp.kernel
+    def atomic_or_kernel(a, b, ref_index):
+        i = dpex.get_global_id(0)
+        v = AtomicRef(b, index=ref_index)
+        v.fetch_or(a[i])
+
+    a, b = input_arrays_for_and_or_xor
+
+    dpex_exp.call_kernel(atomic_or_kernel, dpex.Range(10), a, b, ref_index)
+
+    # Verify that b[ref_index] is bitwise or(a[0:N])
+    assert b[ref_index] == 15
+
+
+def test_fetch_or_diff_types():
+    """A negative test that verifies that a TypingError is raised if
+    AtomicRef type and value for bitwise or are of different types.
+    """
+
+    @dpex_exp.kernel
+    def atomic_or_kernel(a, b):
+        i = dpex.get_global_id(0)
+        v = AtomicRef(b, index=0)
+        v.fetch_or(a[i])
+
+    N = 10
+    a = dpnp.ones(N, dtype=dpnp.float32)
+    b = dpnp.zeros(N, dtype=dpnp.int32)
+
+    with pytest.raises(TypingError):
+        dpex_exp.call_kernel(atomic_or_kernel, dpex.Range(10), a, b)
+
+
+@pytest.mark.parametrize("ref_index", [0, 5])
+def test_fetch_xor(input_arrays_for_and_or_xor, ref_index):
+    @dpex_exp.kernel
+    def atomic_xor_kernel(a, b, ref_index):
+        i = dpex.get_global_id(0)
+        v = AtomicRef(b, index=ref_index)
+        v.fetch_xor(a[i])
+
+    a, b = input_arrays_for_and_or_xor
+
+    dpex_exp.call_kernel(atomic_xor_kernel, dpex.Range(10), a, b, ref_index)
+
+    # Verify that b[ref_index] is bitwise xor(a[0:N])
+    assert b[ref_index] == 0
+
+
+def test_fetch_xor_diff_types():
+    """A negative test that verifies that a TypingError is raised if
+    AtomicRef type and value for bitwise xor are of different types.
+    """
+
+    @dpex_exp.kernel
+    def atomic_xor_kernel(a, b):
+        i = dpex.get_global_id(0)
+        v = AtomicRef(b, index=0)
+        v.fetch_xor(a[i])
+
+    N = 10
+    a = dpnp.ones(N, dtype=dpnp.float32)
+    b = dpnp.zeros(N, dtype=dpnp.int32)
+
+    with pytest.raises(TypingError):
+        dpex_exp.call_kernel(atomic_xor_kernel, dpex.Range(10), a, b)
+
+
 @dpex_exp.kernel
 def atomic_ref_0(a):
     i = dpex.get_global_id(0)