diff --git a/numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_atomic_ref_overloads.py b/numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_atomic_ref_overloads.py
index 63deb68ff3..57b820e1d2 100644
--- a/numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_atomic_ref_overloads.py
+++ b/numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_atomic_ref_overloads.py
@@ -118,6 +118,36 @@ def _intrinsic_fetch_add(ty_context, ty_atomic_ref, ty_val):
     return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_add")
 
 
+@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
+def _intrinsic_fetch_sub(ty_context, ty_atomic_ref, ty_val):
+    return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_sub")
+
+
+@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
+def _intrinsic_fetch_min(ty_context, ty_atomic_ref, ty_val):
+    return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_min")
+
+
+@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
+def _intrinsic_fetch_max(ty_context, ty_atomic_ref, ty_val):
+    return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_max")
+
+
+@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
+def _intrinsic_fetch_and(ty_context, ty_atomic_ref, ty_val):
+    return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_and")
+
+
+@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
+def _intrinsic_fetch_or(ty_context, ty_atomic_ref, ty_val):
+    return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_or")
+
+
+@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
+def _intrinsic_fetch_xor(ty_context, ty_atomic_ref, ty_val):
+    return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_xor")
+
+
 @intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
 def _intrinsic_atomic_ref_ctor(
     ty_context, ref, ty_index, ty_retty_ref  # pylint: disable=unused-argument
@@ -294,3 +324,171 @@ def ol_fetch_add_impl(atomic_ref, val):
         return _intrinsic_fetch_add(atomic_ref, val)
 
     return ol_fetch_add_impl
+
+
+@overload_method(AtomicRefType, "fetch_sub", target=DPEX_KERNEL_EXP_TARGET_NAME)
+def ol_fetch_sub(atomic_ref, val):
+    """SPIR-V overload for
+    :meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_sub`.
+
+    Generates the same LLVM IR instruction as dpcpp for the
+    `atomic_ref::fetch_sub` function.
+
+    Raises:
+        TypingError: When the dtype of the aggregator value does not match the
+        dtype of the AtomicRef type.
+    """
+    if atomic_ref.dtype != val:
+        raise errors.TypingError(
+            f"Type of value to sub: {val} does not match the type of the "
+            f"reference: {atomic_ref.dtype} stored in the atomic ref."
+        )
+
+    def ol_fetch_sub_impl(atomic_ref, val):
+        # pylint: disable=no-value-for-parameter
+        return _intrinsic_fetch_sub(atomic_ref, val)
+
+    return ol_fetch_sub_impl
+
+
+@overload_method(AtomicRefType, "fetch_min", target=DPEX_KERNEL_EXP_TARGET_NAME)
+def ol_fetch_min(atomic_ref, val):
+    """SPIR-V overload for
+    :meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_min`.
+
+    Generates the same LLVM IR instruction as dpcpp for the
+    `atomic_ref::fetch_min` function.
+
+    Raises:
+        TypingError: When the dtype of the aggregator value does not match the
+        dtype of the AtomicRef type.
+    """
+    if atomic_ref.dtype != val:
+        raise errors.TypingError(
+            f"Type of value to find min: {val} does not match the type of the "
+            f"reference: {atomic_ref.dtype} stored in the atomic ref."
+        )
+
+    def ol_fetch_min_impl(atomic_ref, val):
+        # pylint: disable=no-value-for-parameter
+        return _intrinsic_fetch_min(atomic_ref, val)
+
+    return ol_fetch_min_impl
+
+
+@overload_method(AtomicRefType, "fetch_max", target=DPEX_KERNEL_EXP_TARGET_NAME)
+def ol_fetch_max(atomic_ref, val):
+    """SPIR-V overload for
+    :meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_max`.
+
+    Generates the same LLVM IR instruction as dpcpp for the
+    `atomic_ref::fetch_max` function.
+
+    Raises:
+        TypingError: When the dtype of the aggregator value does not match the
+        dtype of the AtomicRef type.
+    """
+    if atomic_ref.dtype != val:
+        raise errors.TypingError(
+            f"Type of value to find max: {val} does not match the type of the "
+            f"reference: {atomic_ref.dtype} stored in the atomic ref."
+        )
+
+    def ol_fetch_max_impl(atomic_ref, val):
+        # pylint: disable=no-value-for-parameter
+        return _intrinsic_fetch_max(atomic_ref, val)
+
+    return ol_fetch_max_impl
+
+
+@overload_method(AtomicRefType, "fetch_and", target=DPEX_KERNEL_EXP_TARGET_NAME)
+def ol_fetch_and(atomic_ref, val):
+    """SPIR-V overload for
+    :meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_and`.
+
+    Generates the same LLVM IR instruction as dpcpp for the
+    `atomic_ref::fetch_and` function.
+
+    Raises:
+        TypingError: When the dtype of the aggregator value does not match the
+        dtype of the AtomicRef type.
+        TypingError: When the AtomicRef dtype is neither int32 nor int64.
+    """
+    if atomic_ref.dtype != val:
+        raise errors.TypingError(
+            f"Type of value to and: {val} does not match the type of the "
+            f"reference: {atomic_ref.dtype} stored in the atomic ref."
+        )
+
+    if atomic_ref.dtype not in (types.int32, types.int64):
+        raise errors.TypingError(
+            "fetch_and operation only supported on int32 and int64 dtypes."
+        )
+
+    def ol_fetch_and_impl(atomic_ref, val):
+        # pylint: disable=no-value-for-parameter
+        return _intrinsic_fetch_and(atomic_ref, val)
+
+    return ol_fetch_and_impl
+
+
+@overload_method(AtomicRefType, "fetch_or", target=DPEX_KERNEL_EXP_TARGET_NAME)
+def ol_fetch_or(atomic_ref, val):
+    """SPIR-V overload for
+    :meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_or`.
+
+    Generates the same LLVM IR instruction as dpcpp for the
+    `atomic_ref::fetch_or` function.
+
+    Raises:
+        TypingError: When the dtype of the aggregator value does not match the
+        dtype of the AtomicRef type.
+        TypingError: When the AtomicRef dtype is neither int32 nor int64.
+    """
+    if atomic_ref.dtype != val:
+        raise errors.TypingError(
+            f"Type of value to or: {val} does not match the type of the "
+            f"reference: {atomic_ref.dtype} stored in the atomic ref."
+        )
+
+    if atomic_ref.dtype not in (types.int32, types.int64):
+        raise errors.TypingError(
+            "fetch_or operation only supported on int32 and int64 dtypes."
+        )
+
+    def ol_fetch_or_impl(atomic_ref, val):
+        # pylint: disable=no-value-for-parameter
+        return _intrinsic_fetch_or(atomic_ref, val)
+
+    return ol_fetch_or_impl
+
+
+@overload_method(AtomicRefType, "fetch_xor", target=DPEX_KERNEL_EXP_TARGET_NAME)
+def ol_fetch_xor(atomic_ref, val):
+    """SPIR-V overload for
+    :meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_xor`.
+
+    Generates the same LLVM IR instruction as dpcpp for the
+    `atomic_ref::fetch_xor` function.
+
+    Raises:
+        TypingError: When the dtype of the aggregator value does not match the
+        dtype of the AtomicRef type.
+        TypingError: When the AtomicRef dtype is neither int32 nor int64.
+    """
+    if atomic_ref.dtype != val:
+        raise errors.TypingError(
+            f"Type of value to xor: {val} does not match the type of the "
+            f"reference: {atomic_ref.dtype} stored in the atomic ref."
+        )
+
+    if atomic_ref.dtype not in (types.int32, types.int64):
+        raise errors.TypingError(
+            "fetch_xor operation only supported on int32 and int64 dtypes."
+        )
+
+    def ol_fetch_xor_impl(atomic_ref, val):
+        # pylint: disable=no-value-for-parameter
+        return _intrinsic_fetch_xor(atomic_ref, val)
+
+    return ol_fetch_xor_impl