Skip to content

Commit

Permalink
Adding overloads and tests for fetch_* atomic functions
Browse files Browse the repository at this point in the history
  • Loading branch information
adarshyoga committed Jan 9, 2024
1 parent d12331e commit 4f309af
Show file tree
Hide file tree
Showing 3 changed files with 532 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ def gen(context, builder, sig, args):
"--spirv-ext=+SPV_EXT_shader_atomic_float_add"
]

context.extra_compile_options[LLVM_SPIRV_ARGS] = [
"--spirv-ext=+SPV_EXT_shader_atomic_float_min_max"
]

ptr_type = retty.as_pointer()
ptr_type.addrspace = atomic_ref_ty.address_space

Expand Down Expand Up @@ -118,6 +122,59 @@ def _intrinsic_fetch_add(ty_context, ty_atomic_ref, ty_val):
return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_add")


def _atomic_sub_float_wrapper(gen_fn):
    """Adapt a code generator so that its second argument is negated first.

    Used to express a floating point ``fetch_sub`` in terms of ``fetch_add``:
    subtracting ``val`` is performed as adding ``-val``.

    Args:
        gen_fn: Code generation callable invoked as
            ``gen_fn(context, builder, sig, args)``.

    Returns:
        A callable with the same ``(context, builder, sig, args)`` signature
        that replaces ``args[1]`` with ``builder.fneg(args[1])``, delegates
        to ``gen_fn`` and returns whatever ``gen_fn`` returns.
    """

    def gen(context, builder, sig, args):
        # ``args`` is a tuple, which is immutable. Convert it to a list
        # before replacing ``args[1]`` with its negation, then convert the
        # result back to a tuple for the wrapped generator.
        negated_args = list(args)
        negated_args[1] = builder.fneg(args[1])

        # Propagate the wrapped generator's result instead of dropping it
        # (presumably the LLVM value holding the fetched result -- confirm
        # against ``_intrinsic_helper``). Without the ``return`` this wrapper
        # would always produce ``None``.
        return gen_fn(context, builder, sig, tuple(negated_args))

    return gen


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_fetch_sub(ty_context, ty_atomic_ref, ty_val):
    """Intrinsic backing ``fetch_sub``.

    For ``float32`` and ``float64`` references the ``fetch_add`` helper is
    requested and its code generator is wrapped with
    ``_atomic_sub_float_wrapper``. For every other dtype the ``fetch_sub``
    helper is returned unchanged.
    """
    is_float_ref = ty_atomic_ref.dtype in (types.float32, types.float64)
    if not is_float_ref:
        return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_sub")

    # dpcpp does not support ``__spirv_AtomicFSubEXT``, so fetch_sub for
    # floats is implemented by negating the value and calling fetch_add.
    # For example, A.fetch_sub(val) is implemented as A.fetch_add(-val).
    add_sig, add_gen = _intrinsic_helper(
        ty_context, ty_atomic_ref, ty_val, "fetch_add"
    )
    return add_sig, _atomic_sub_float_wrapper(add_gen)


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_fetch_min(ty_context, ty_atomic_ref, ty_val):
    """Intrinsic backing ``fetch_min``; defers to ``_intrinsic_helper``."""
    helper_args = (ty_context, ty_atomic_ref, ty_val, "fetch_min")
    return _intrinsic_helper(*helper_args)


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_fetch_max(ty_context, ty_atomic_ref, ty_val):
    """Intrinsic backing ``fetch_max``; defers to ``_intrinsic_helper``."""
    helper_args = (ty_context, ty_atomic_ref, ty_val, "fetch_max")
    return _intrinsic_helper(*helper_args)


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_fetch_and(ty_context, ty_atomic_ref, ty_val):
    """Intrinsic backing ``fetch_and``; defers to ``_intrinsic_helper``."""
    helper_args = (ty_context, ty_atomic_ref, ty_val, "fetch_and")
    return _intrinsic_helper(*helper_args)


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_fetch_or(ty_context, ty_atomic_ref, ty_val):
    """Intrinsic backing ``fetch_or``; defers to ``_intrinsic_helper``."""
    helper_args = (ty_context, ty_atomic_ref, ty_val, "fetch_or")
    return _intrinsic_helper(*helper_args)


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_fetch_xor(ty_context, ty_atomic_ref, ty_val):
    """Intrinsic backing ``fetch_xor``; defers to ``_intrinsic_helper``."""
    helper_args = (ty_context, ty_atomic_ref, ty_val, "fetch_xor")
    return _intrinsic_helper(*helper_args)


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_atomic_ref_ctor(
ty_context, ref, ty_index, ty_retty_ref # pylint: disable=unused-argument
Expand Down Expand Up @@ -294,3 +351,168 @@ def ol_fetch_add_impl(atomic_ref, val):
return _intrinsic_fetch_add(atomic_ref, val)

return ol_fetch_add_impl


@overload_method(AtomicRefType, "fetch_sub", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_fetch_sub(atomic_ref, val):
    """SPIR-V overload for
    :meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_sub`.

    Generates the same LLVM IR instruction as dpcpp for the
    `atomic_ref::fetch_sub` function. The returned implementation forwards
    its arguments to ``_intrinsic_fetch_sub``.

    Raises:
        TypingError: When the type of ``val`` does not match the dtype of
            the AtomicRef type.
    """
    ref_dtype = atomic_ref.dtype
    if ref_dtype != val:
        mismatch_msg = (
            f"Type of value to sub: {val} does not match the type of the "
            f"reference: {ref_dtype} stored in the atomic ref."
        )
        raise errors.TypingError(mismatch_msg)

    def ol_fetch_sub_impl(atomic_ref, val):
        # pylint: disable=no-value-for-parameter
        return _intrinsic_fetch_sub(atomic_ref, val)

    return ol_fetch_sub_impl


@overload_method(AtomicRefType, "fetch_min", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_fetch_min(atomic_ref, val):
    """SPIR-V overload for
    :meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_min`.

    Generates the same LLVM IR instruction as dpcpp for the
    `atomic_ref::fetch_min` function. The returned implementation forwards
    its arguments to ``_intrinsic_fetch_min``.

    Raises:
        TypingError: When the type of ``val`` does not match the dtype of
            the AtomicRef type.
    """
    ref_dtype = atomic_ref.dtype
    if ref_dtype != val:
        mismatch_msg = (
            f"Type of value to find min: {val} does not match the type of the "
            f"reference: {ref_dtype} stored in the atomic ref."
        )
        raise errors.TypingError(mismatch_msg)

    def ol_fetch_min_impl(atomic_ref, val):
        # pylint: disable=no-value-for-parameter
        return _intrinsic_fetch_min(atomic_ref, val)

    return ol_fetch_min_impl


@overload_method(AtomicRefType, "fetch_max", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_fetch_max(atomic_ref, val):
    """SPIR-V overload for
    :meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_max`.

    Generates the same LLVM IR instruction as dpcpp for the
    `atomic_ref::fetch_max` function. The returned implementation forwards
    its arguments to ``_intrinsic_fetch_max``.

    Raises:
        TypingError: When the type of ``val`` does not match the dtype of
            the AtomicRef type.
    """
    ref_dtype = atomic_ref.dtype
    if ref_dtype != val:
        mismatch_msg = (
            f"Type of value to find max: {val} does not match the type of the "
            f"reference: {ref_dtype} stored in the atomic ref."
        )
        raise errors.TypingError(mismatch_msg)

    def ol_fetch_max_impl(atomic_ref, val):
        # pylint: disable=no-value-for-parameter
        return _intrinsic_fetch_max(atomic_ref, val)

    return ol_fetch_max_impl


@overload_method(AtomicRefType, "fetch_and", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_fetch_and(atomic_ref, val):
    """SPIR-V overload for
    :meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_and`.

    Generates the same LLVM IR instruction as dpcpp for the
    `atomic_ref::fetch_and` function. The returned implementation forwards
    its arguments to ``_intrinsic_fetch_and``.

    Raises:
        TypingError: When the type of ``val`` does not match the dtype of
            the AtomicRef type, or when that dtype is neither ``int32`` nor
            ``int64``.
    """
    ref_dtype = atomic_ref.dtype
    if ref_dtype != val:
        mismatch_msg = (
            f"Type of value to and: {val} does not match the type of the "
            f"reference: {ref_dtype} stored in the atomic ref."
        )
        raise errors.TypingError(mismatch_msg)

    # Bitwise atomics are only accepted for these integer dtypes.
    supported_dtypes = (types.int32, types.int64)
    if atomic_ref.dtype not in supported_dtypes:
        raise errors.TypingError(
            "fetch_and operation only supported on int32 and int64 dtypes."
        )

    def ol_fetch_and_impl(atomic_ref, val):
        # pylint: disable=no-value-for-parameter
        return _intrinsic_fetch_and(atomic_ref, val)

    return ol_fetch_and_impl


@overload_method(AtomicRefType, "fetch_or", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_fetch_or(atomic_ref, val):
    """SPIR-V overload for
    :meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_or`.

    Generates the same LLVM IR instruction as dpcpp for the
    `atomic_ref::fetch_or` function. The returned implementation forwards
    its arguments to ``_intrinsic_fetch_or``.

    Raises:
        TypingError: When the type of ``val`` does not match the dtype of
            the AtomicRef type, or when that dtype is neither ``int32`` nor
            ``int64``.
    """
    ref_dtype = atomic_ref.dtype
    if ref_dtype != val:
        mismatch_msg = (
            f"Type of value to or: {val} does not match the type of the "
            f"reference: {ref_dtype} stored in the atomic ref."
        )
        raise errors.TypingError(mismatch_msg)

    # Bitwise atomics are only accepted for these integer dtypes.
    supported_dtypes = (types.int32, types.int64)
    if atomic_ref.dtype not in supported_dtypes:
        raise errors.TypingError(
            "fetch_or operation only supported on int32 and int64 dtypes."
        )

    def ol_fetch_or_impl(atomic_ref, val):
        # pylint: disable=no-value-for-parameter
        return _intrinsic_fetch_or(atomic_ref, val)

    return ol_fetch_or_impl


@overload_method(AtomicRefType, "fetch_xor", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_fetch_xor(atomic_ref, val):
    """SPIR-V overload for
    :meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_xor`.

    Generates the same LLVM IR instruction as dpcpp for the
    `atomic_ref::fetch_xor` function. The returned implementation forwards
    its arguments to ``_intrinsic_fetch_xor``.

    Raises:
        TypingError: When the type of ``val`` does not match the dtype of
            the AtomicRef type, or when that dtype is neither ``int32`` nor
            ``int64``.
    """
    ref_dtype = atomic_ref.dtype
    if ref_dtype != val:
        mismatch_msg = (
            f"Type of value to xor: {val} does not match the type of the "
            f"reference: {ref_dtype} stored in the atomic ref."
        )
        raise errors.TypingError(mismatch_msg)

    # Bitwise atomics are only accepted for these integer dtypes.
    supported_dtypes = (types.int32, types.int64)
    if atomic_ref.dtype not in supported_dtypes:
        raise errors.TypingError(
            "fetch_xor operation only supported on int32 and int64 dtypes."
        )

    def ol_fetch_xor_impl(atomic_ref, val):
        # pylint: disable=no-value-for-parameter
        return _intrinsic_fetch_xor(atomic_ref, val)

    return ol_fetch_xor_impl
5 changes: 4 additions & 1 deletion numba_dpex/spirv_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,10 @@ def finalize(self):
# TODO: find a better approach to set SPIRV compiler arguments. This is a
# workaround for the caching of the intrinsic that sets this argument.
# https://github.com/IntelPython/numba-dpex/issues/1262
llvm_spirv_args = ["--spirv-ext=+SPV_EXT_shader_atomic_float_add"]
llvm_spirv_args = [
"--spirv-ext=+SPV_EXT_shader_atomic_float_add",
"--spirv-ext=+SPV_EXT_shader_atomic_float_min_max",
]
for key in list(self.context.extra_compile_options.keys()):
if key == LLVM_SPIRV_ARGS:
llvm_spirv_args = self.context.extra_compile_options[key]
Expand Down
Loading

0 comments on commit 4f309af

Please sign in to comment.