diff --git a/numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_atomic_ref_overloads.py b/numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_atomic_ref_overloads.py index 63deb68ff3..315d00c47e 100644 --- a/numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_atomic_ref_overloads.py +++ b/numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_atomic_ref_overloads.py @@ -69,6 +69,10 @@ def gen(context, builder, sig, args): "--spirv-ext=+SPV_EXT_shader_atomic_float_add" ] + context.extra_compile_options[LLVM_SPIRV_ARGS] = [ + "--spirv-ext=+SPV_EXT_shader_atomic_float_min_max" + ] + ptr_type = retty.as_pointer() ptr_type.addrspace = atomic_ref_ty.address_space @@ -118,6 +122,59 @@ def _intrinsic_fetch_add(ty_context, ty_atomic_ref, ty_val): return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_add") +def _atomic_sub_float_wrapper(gen_fn): + def gen(context, builder, sig, args): + # args is a tuple, which is immutable + # covert tuple to list obj first before replacing arg[1] + # with fneg and convert back to tuple again. + args_lst = list(args) + args_lst[1] = builder.fneg(args[1]) + args = tuple(args_lst) + + gen_fn(context, builder, sig, args) + + return gen + + +@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME) +def _intrinsic_fetch_sub(ty_context, ty_atomic_ref, ty_val): + if ty_atomic_ref.dtype in (types.float32, types.float64): + # dpcpp does not support ``__spirv_AtomicFSubEXT``. fetch_sub + # for floats is implemented by negating the value and calling fetch_add. + # For example, A.fetch_sub(A, val) is implemented as A.fetch_add(-val). + sig, gen = _intrinsic_helper( + ty_context, ty_atomic_ref, ty_val, "fetch_add" + ) + return sig, _atomic_sub_float_wrapper(gen) + + return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_sub") + + +@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME) +def _intrinsic_fetch_min(ty_context, ty_atomic_ref, ty_val): + return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_min") + + +@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME) +def _intrinsic_fetch_max(ty_context, ty_atomic_ref, ty_val): + return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_max") + + +@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME) +def _intrinsic_fetch_and(ty_context, ty_atomic_ref, ty_val): + return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_and") + + +@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME) +def _intrinsic_fetch_or(ty_context, ty_atomic_ref, ty_val): + return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_or") + + +@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME) +def _intrinsic_fetch_xor(ty_context, ty_atomic_ref, ty_val): + return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_xor") + + @intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME) def _intrinsic_atomic_ref_ctor( ty_context, ref, ty_index, ty_retty_ref # pylint: disable=unused-argument @@ -294,3 +351,168 @@ def ol_fetch_add_impl(atomic_ref, val): return _intrinsic_fetch_add(atomic_ref, val) return ol_fetch_add_impl + + +@overload_method(AtomicRefType, "fetch_sub", target=DPEX_KERNEL_EXP_TARGET_NAME) +def ol_fetch_sub(atomic_ref, val): + """SPIR-V overload for + :meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_sub`. + + Generates the same LLVM IR instruction as dpcpp for the + `atomic_ref::fetch_sub` function. + + Raises: + TypingError: When the dtype of the aggregator value does not match the + dtype of the AtomicRef type. + """ + if atomic_ref.dtype != val: + raise errors.TypingError( + f"Type of value to sub: {val} does not match the type of the " + f"reference: {atomic_ref.dtype} stored in the atomic ref." + ) + + def ol_fetch_sub_impl(atomic_ref, val): + # pylint: disable=no-value-for-parameter + return _intrinsic_fetch_sub(atomic_ref, val) + + return ol_fetch_sub_impl + + +@overload_method(AtomicRefType, "fetch_min", target=DPEX_KERNEL_EXP_TARGET_NAME) +def ol_fetch_min(atomic_ref, val): + """SPIR-V overload for + :meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_min`. + + Generates the same LLVM IR instruction as dpcpp for the + `atomic_ref::fetch_min` function. + + Raises: + TypingError: When the dtype of the aggregator value does not match the + dtype of the AtomicRef type. + """ + if atomic_ref.dtype != val: + raise errors.TypingError( + f"Type of value to find min: {val} does not match the type of the " + f"reference: {atomic_ref.dtype} stored in the atomic ref." + ) + + def ol_fetch_min_impl(atomic_ref, val): + # pylint: disable=no-value-for-parameter + return _intrinsic_fetch_min(atomic_ref, val) + + return ol_fetch_min_impl + + +@overload_method(AtomicRefType, "fetch_max", target=DPEX_KERNEL_EXP_TARGET_NAME) +def ol_fetch_max(atomic_ref, val): + """SPIR-V overload for + :meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_max`. + + Generates the same LLVM IR instruction as dpcpp for the + `atomic_ref::fetch_max` function. + + Raises: + TypingError: When the dtype of the aggregator value does not match the + dtype of the AtomicRef type. + """ + if atomic_ref.dtype != val: + raise errors.TypingError( + f"Type of value to find max: {val} does not match the type of the " + f"reference: {atomic_ref.dtype} stored in the atomic ref." + ) + + def ol_fetch_max_impl(atomic_ref, val): + # pylint: disable=no-value-for-parameter + return _intrinsic_fetch_max(atomic_ref, val) + + return ol_fetch_max_impl + + +@overload_method(AtomicRefType, "fetch_and", target=DPEX_KERNEL_EXP_TARGET_NAME) +def ol_fetch_and(atomic_ref, val): + """SPIR-V overload for + :meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_and`. + + Generates the same LLVM IR instruction as dpcpp for the + `atomic_ref::fetch_and` function. + + Raises: + TypingError: When the dtype of the aggregator value does not match the + dtype of the AtomicRef type. + """ + if atomic_ref.dtype != val: + raise errors.TypingError( + f"Type of value to and: {val} does not match the type of the " + f"reference: {atomic_ref.dtype} stored in the atomic ref." + ) + + if atomic_ref.dtype not in (types.int32, types.int64): + raise errors.TypingError( + "fetch_and operation only supported on int32 and int64 dtypes." + ) + + def ol_fetch_and_impl(atomic_ref, val): + # pylint: disable=no-value-for-parameter + return _intrinsic_fetch_and(atomic_ref, val) + + return ol_fetch_and_impl + + +@overload_method(AtomicRefType, "fetch_or", target=DPEX_KERNEL_EXP_TARGET_NAME) +def ol_fetch_or(atomic_ref, val): + """SPIR-V overload for + :meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_or`. + + Generates the same LLVM IR instruction as dpcpp for the + `atomic_ref::fetch_or` function. + + Raises: + TypingError: When the dtype of the aggregator value does not match the + dtype of the AtomicRef type. + """ + if atomic_ref.dtype != val: + raise errors.TypingError( + f"Type of value to or: {val} does not match the type of the " + f"reference: {atomic_ref.dtype} stored in the atomic ref." + ) + + if atomic_ref.dtype not in (types.int32, types.int64): + raise errors.TypingError( + "fetch_or operation only supported on int32 and int64 dtypes." + ) + + def ol_fetch_or_impl(atomic_ref, val): + # pylint: disable=no-value-for-parameter + return _intrinsic_fetch_or(atomic_ref, val) + + return ol_fetch_or_impl + + +@overload_method(AtomicRefType, "fetch_xor", target=DPEX_KERNEL_EXP_TARGET_NAME) +def ol_fetch_xor(atomic_ref, val): + """SPIR-V overload for + :meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_xor`. + + Generates the same LLVM IR instruction as dpcpp for the + `atomic_ref::fetch_xor` function. + + Raises: + TypingError: When the dtype of the aggregator value does not match the + dtype of the AtomicRef type. + """ + if atomic_ref.dtype != val: + raise errors.TypingError( + f"Type of value to xor: {val} does not match the type of the " + f"reference: {atomic_ref.dtype} stored in the atomic ref." + ) + + if atomic_ref.dtype not in (types.int32, types.int64): + raise errors.TypingError( + "fetch_xor operation only supported on int32 and int64 dtypes." + ) + + def ol_fetch_xor_impl(atomic_ref, val): + # pylint: disable=no-value-for-parameter + return _intrinsic_fetch_xor(atomic_ref, val) + + return ol_fetch_xor_impl diff --git a/numba_dpex/spirv_generator.py b/numba_dpex/spirv_generator.py index 36d3fef9f0..80b18ff93e 100644 --- a/numba_dpex/spirv_generator.py +++ b/numba_dpex/spirv_generator.py @@ -159,7 +159,10 @@ def finalize(self): # TODO: find better approach to set SPIRV compiler arguments. Workaround # against caching intrinsic that sets this argument. # https://github.com/IntelPython/numba-dpex/issues/1262 - llvm_spirv_args = ["--spirv-ext=+SPV_EXT_shader_atomic_float_add"] + llvm_spirv_args = [ + "--spirv-ext=+SPV_EXT_shader_atomic_float_add", + "--spirv-ext=+SPV_EXT_shader_atomic_float_min_max", + ] for key in list(self.context.extra_compile_options.keys()): if key == LLVM_SPIRV_ARGS: llvm_spirv_args = self.context.extra_compile_options[key] diff --git a/numba_dpex/tests/experimental/kernel_iface/spv_overloads/test_atomic_fetch_phi.py b/numba_dpex/tests/experimental/kernel_iface/spv_overloads/test_atomic_fetch_phi.py index ae101ad684..64dd384375 100644 --- a/numba_dpex/tests/experimental/kernel_iface/spv_overloads/test_atomic_fetch_phi.py +++ b/numba_dpex/tests/experimental/kernel_iface/spv_overloads/test_atomic_fetch_phi.py @@ -3,7 +3,9 @@ # SPDX-License-Identifier: Apache-2.0 import dpnp +import numpy as np import pytest +from numba.core.errors import TypingError import numba_dpex as dpex import numba_dpex.experimental as dpex_exp @@ -14,9 +16,17 @@ no_bool=True, no_float16=True, no_none=True, no_complex=True ) +list_of_int_dtypes = get_all_dtypes( + no_bool=True, no_float=True, no_none=True, no_complex=True +) + +list_of_float_dtypes = get_all_dtypes( + no_bool=True, no_int=True, no_float16=True, no_none=True, no_complex=True +) + @pytest.fixture(params=list_of_supported_dtypes) -def input_arrays(request): +def input_arrays_for_add(request): # The size of input and out arrays to be used N = 10 a = dpnp.ones(N, dtype=request.param) @@ -24,22 +34,279 @@ def input_arrays(request): return a, b +@pytest.fixture(params=list_of_supported_dtypes) +def input_arrays_for_sub(request): + # The size of input and out arrays to be used + N = 10 + a = dpnp.ones(N, dtype=request.param) + b_np = np.ones(N) * N + 1 + b = dpnp.asarray(b_np, dtype=request.param) + return a, b + + +@pytest.fixture(params=list_of_supported_dtypes) +def input_arrays_for_min_max(request): + # The size of input and out arrays to be used + N = 10 + a = dpnp.arange(N, dtype=request.param) + b = dpnp.ones(N, dtype=request.param) + return a, b + + +@pytest.fixture(params=list_of_int_dtypes) +def input_arrays_for_and_or_xor(request): + # The size of input and out arrays to be used + N = 10 + a = dpnp.arange(N, dtype=request.param) + b = dpnp.ones(N, dtype=request.param) + return a, b + + @pytest.mark.parametrize("ref_index", [0, 5]) -def test_fetch_add(input_arrays, ref_index): +def test_fetch_add(input_arrays_for_add, ref_index): @dpex_exp.kernel - def atomic_ref_kernel(a, b, ref_index): + def atomic_add_kernel(a, b, ref_index): i = dpex.get_global_id(0) v = AtomicRef(b, index=ref_index) v.fetch_add(a[i]) - a, b = input_arrays + a, b = input_arrays_for_add - dpex_exp.call_kernel(atomic_ref_kernel, dpex.Range(10), a, b, ref_index) + dpex_exp.call_kernel(atomic_add_kernel, dpex.Range(10), a, b, ref_index) # Verify that `a` was accumulated at b[ref_index] assert b[ref_index] == 10 +def test_fetch_add_diff_types(): + """A negative test that verifies that a TypingError is raised if + AtomicRef type and value to be added are of different types. + """ + + @dpex_exp.kernel + def atomic_add_kernel(a, b): + i = dpex.get_global_id(0) + v = AtomicRef(b, index=0) + v.fetch_add(a[i]) + + N = 10 + a = dpnp.ones(N, dtype=dpnp.float32) + b = dpnp.zeros(N, dtype=dpnp.int32) + + with pytest.raises(TypingError): + dpex_exp.call_kernel(atomic_add_kernel, dpex.Range(10), a, b) + + +@pytest.mark.parametrize("ref_index", [0, 5]) +def test_fetch_sub(input_arrays_for_sub, ref_index): + @dpex_exp.kernel + def atomic_sub_kernel(a, b, ref_index): + i = dpex.get_global_id(0) + v = AtomicRef(b, index=ref_index) + v.fetch_sub(a[i]) + + a, b = input_arrays_for_sub + + dpex_exp.call_kernel(atomic_sub_kernel, dpex.Range(10), a, b, ref_index) + + # Verify that b[ref_index] -= a[0:N] + assert b[ref_index] == 1 + + +def test_fetch_sub_diff_types(): + """A negative test that verifies that a TypingError is raised if + AtomicRef type and value to be subtracted are of different types. + """ + + @dpex_exp.kernel + def atomic_sub_kernel(a, b): + i = dpex.get_global_id(0) + v = AtomicRef(b, index=0) + v.fetch_sub(a[i]) + + N = 10 + a = dpnp.ones(N, dtype=dpnp.float32) + b = dpnp.zeros(N, dtype=dpnp.int32) + + with pytest.raises(TypingError): + dpex_exp.call_kernel(atomic_sub_kernel, dpex.Range(10), a, b) + + +@pytest.mark.parametrize("ref_index", [0, 5]) +def test_fetch_min(input_arrays_for_min_max, ref_index): + @dpex_exp.kernel + def atomic_min_kernel(a, b, ref_index): + i = dpex.get_global_id(0) + v = AtomicRef(b, index=ref_index) + v.fetch_min(a[i]) + + a, b = input_arrays_for_min_max + + dpex_exp.call_kernel(atomic_min_kernel, dpex.Range(10), a, b, ref_index) + + # Verify that b[ref_index] is min(a[0:N]) + assert b[ref_index] == 0 + + +def test_fetch_min_diff_types(): + """A negative test that verifies that a TypingError is raised if + AtomicRef type and compared value to compute min are of different types. + """ + + @dpex_exp.kernel + def atomic_min_kernel(a, b): + i = dpex.get_global_id(0) + v = AtomicRef(b, index=0) + v.fetch_min(a[i]) + + N = 10 + a = dpnp.ones(N, dtype=dpnp.float32) + b = dpnp.zeros(N, dtype=dpnp.int32) + + with pytest.raises(TypingError): + dpex_exp.call_kernel(atomic_min_kernel, dpex.Range(10), a, b) + + +@pytest.mark.parametrize("ref_index", [0, 5]) +def test_fetch_max(input_arrays_for_min_max, ref_index): + @dpex_exp.kernel + def atomic_max_kernel(a, b, ref_index): + i = dpex.get_global_id(0) + v = AtomicRef(b, index=ref_index) + v.fetch_max(a[i]) + + a, b = input_arrays_for_min_max + + dpex_exp.call_kernel(atomic_max_kernel, dpex.Range(10), a, b, ref_index) + + # Verify that b[ref_index] is max(a[0:N]) + assert b[ref_index] == 9 + + +def test_fetch_max_diff_types(): + """A negative test that verifies that a TypingError is raised if + AtomicRef type and compared value to compute max are of different types. + """ + + @dpex_exp.kernel + def atomic_max_kernel(a, b): + i = dpex.get_global_id(0) + v = AtomicRef(b, index=0) + v.fetch_max(a[i]) + + N = 10 + a = dpnp.ones(N, dtype=dpnp.float32) + b = dpnp.zeros(N, dtype=dpnp.int32) + + with pytest.raises(TypingError): + dpex_exp.call_kernel(atomic_max_kernel, dpex.Range(10), a, b) + + +@pytest.mark.parametrize("ref_index", [0, 5]) +def test_fetch_and(input_arrays_for_and_or_xor, ref_index): + @dpex_exp.kernel + def atomic_and_kernel(a, b, ref_index): + i = dpex.get_global_id(0) + v = AtomicRef(b, index=ref_index) + v.fetch_and(a[i]) + + a, b = input_arrays_for_and_or_xor + + dpex_exp.call_kernel(atomic_and_kernel, dpex.Range(10), a, b, ref_index) + + # Verify that b[ref_index] is bitwise and(a[0:N]) + assert b[ref_index] == 0 + + +def test_fetch_and_diff_types(): + """A negative test that verifies that a TypingError is raised if + AtomicRef type and value for bitwise and are of different types. + """ + + @dpex_exp.kernel + def atomic_and_kernel(a, b): + i = dpex.get_global_id(0) + v = AtomicRef(b, index=0) + v.fetch_and(a[i]) + + N = 10 + a = dpnp.ones(N, dtype=dpnp.float32) + b = dpnp.zeros(N, dtype=dpnp.int32) + + with pytest.raises(TypingError): + dpex_exp.call_kernel(atomic_and_kernel, dpex.Range(10), a, b) + + +@pytest.mark.parametrize("ref_index", [0, 5]) +def test_fetch_or(input_arrays_for_and_or_xor, ref_index): + @dpex_exp.kernel + def atomic_or_kernel(a, b, ref_index): + i = dpex.get_global_id(0) + v = AtomicRef(b, index=ref_index) + v.fetch_or(a[i]) + + a, b = input_arrays_for_and_or_xor + + dpex_exp.call_kernel(atomic_or_kernel, dpex.Range(10), a, b, ref_index) + + # Verify that b[ref_index] is bitwise or(a[0:N]) + assert b[ref_index] == 15 + + +def test_fetch_or_diff_types(): + """A negative test that verifies that a TypingError is raised if + AtomicRef type and value for bitwise or are of different types. + """ + + @dpex_exp.kernel + def atomic_or_kernel(a, b): + i = dpex.get_global_id(0) + v = AtomicRef(b, index=0) + v.fetch_or(a[i]) + + N = 10 + a = dpnp.ones(N, dtype=dpnp.float32) + b = dpnp.zeros(N, dtype=dpnp.int32) + + with pytest.raises(TypingError): + dpex_exp.call_kernel(atomic_or_kernel, dpex.Range(10), a, b) + + +@pytest.mark.parametrize("ref_index", [0, 5]) +def test_fetch_xor(input_arrays_for_and_or_xor, ref_index): + @dpex_exp.kernel + def atomic_xor_kernel(a, b, ref_index): + i = dpex.get_global_id(0) + v = AtomicRef(b, index=ref_index) + v.fetch_xor(a[i]) + + a, b = input_arrays_for_and_or_xor + + dpex_exp.call_kernel(atomic_xor_kernel, dpex.Range(10), a, b, ref_index) + + # Verify that b[ref_index] is bitwise xor(a[0:N]) + assert b[ref_index] == 0 + + +def test_fetch_xor_diff_types(): + """A negative test that verifies that a TypingError is raised if + AtomicRef type and value for bitwise xor are of different types. + """ + + @dpex_exp.kernel + def atomic_xor_kernel(a, b): + i = dpex.get_global_id(0) + v = AtomicRef(b, index=0) + v.fetch_xor(a[i]) + + N = 10 + a = dpnp.ones(N, dtype=dpnp.float32) + b = dpnp.zeros(N, dtype=dpnp.int32) + + with pytest.raises(TypingError): + dpex_exp.call_kernel(atomic_xor_kernel, dpex.Range(10), a, b) + + @dpex_exp.kernel def atomic_ref_0(a): i = dpex.get_global_id(0) @@ -54,7 +321,7 @@ def atomic_ref_1(a): v.fetch_add(a[i + 2]) -def test_spirv_compiler_flags(): +def test_spirv_compiler_flags_add(): """Check if float atomic flag is being populated from intrinsic for the second call. @@ -68,3 +335,36 @@ def test_spirv_compiler_flags(): assert a[0] == N - 1 assert a[1] == N - 1 + + +@dpex_exp.kernel +def atomic_max_0(a): + i = dpex.get_global_id(0) + v = AtomicRef(a, index=0) + if i != 0: + v.fetch_max(a[i]) + + +@dpex_exp.kernel +def atomic_max_1(a): + i = dpex.get_global_id(0) + v = AtomicRef(a, index=0) + if i != 0: + v.fetch_max(a[i]) + + +def test_spirv_compiler_flags_max(): + """Check if float atomic flag is being populated from intrinsic for the + second call. + + https://github.com/IntelPython/numba-dpex/issues/1262 + """ + N = 10 + a = dpnp.arange(N, dtype=dpnp.float32) + b = dpnp.arange(N, dtype=dpnp.float32) + + dpex_exp.call_kernel(atomic_max_0, dpex.Range(N), a) + dpex_exp.call_kernel(atomic_max_1, dpex.Range(N), b) + + assert a[0] == N - 1 + assert b[0] == N - 1