Skip to content

Commit

Permalink
Adding test cases for fetch_phi functions
Browse files Browse the repository at this point in the history
  • Loading branch information
adarshyoga committed Jan 8, 2024
1 parent 7042c6f commit 0ee15fc
Show file tree
Hide file tree
Showing 2 changed files with 299 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ def gen(context, builder, sig, args):
"--spirv-ext=+SPV_EXT_shader_atomic_float_add"
]

context.extra_compile_options[LLVM_SPIRV_ARGS] = [
"--spirv-ext=+SPV_EXT_shader_atomic_float_min_max"
]

ptr_type = retty.as_pointer()
ptr_type.addrspace = atomic_ref_ty.address_space

Expand Down Expand Up @@ -118,8 +122,31 @@ def _intrinsic_fetch_add(ty_context, ty_atomic_ref, ty_val):
return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_add")


def _atomic_sub_float_wrapper(gen_fn):
def gen(context, builder, sig, args):
# args is a tuple, which is immutable
# covert tuple to list obj first before replacing arg[1]
# with fneg and convert back to tuple again.
args_lst = list(args)
args_lst[1] = builder.fneg(args[1])
args = tuple(args_lst)

gen_fn(context, builder, sig, args)

return gen


@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_fetch_sub(ty_context, ty_atomic_ref, ty_val):
if ty_atomic_ref.dtype in (types.float32, types.float64):
# dpcpp does not support ``__spirv_AtomicFSubEXT``. fetch_sub
# for floats is implemented by negating the value and calling fetch_add.
# For example, A.fetch_sub(A, val) is implemented as A.fetch_add(-val).
sig, gen = _intrinsic_helper(
ty_context, ty_atomic_ref, ty_val, "fetch_add"
)
return sig, _atomic_sub_float_wrapper(gen)

return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_sub")


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
# SPDX-License-Identifier: Apache-2.0

import dpnp
import numpy as np
import pytest
from numba.core.errors import TypingError

import numba_dpex as dpex
import numba_dpex.experimental as dpex_exp
Expand All @@ -14,32 +16,297 @@
no_bool=True, no_float16=True, no_none=True, no_complex=True
)

list_of_int_dtypes = get_all_dtypes(
no_bool=True, no_float=True, no_none=True, no_complex=True
)

list_of_float_dtypes = get_all_dtypes(
no_bool=True, no_int=True, no_float16=True, no_none=True, no_complex=True
)


@pytest.fixture(params=list_of_supported_dtypes)
def input_arrays(request):
def input_arrays_for_add(request):
# The size of input and out arrays to be used
N = 10
a = dpnp.ones(N, dtype=request.param)
b = dpnp.zeros(N, dtype=request.param)
return a, b


@pytest.fixture(params=list_of_supported_dtypes)
def input_arrays_for_sub(request):
# The size of input and out arrays to be used
N = 10
a = dpnp.ones(N, dtype=request.param)
b_np = np.ones(N) * N + 1
b = dpnp.asarray(b_np, dtype=request.param)
return a, b


@pytest.fixture(params=list_of_supported_dtypes)
def input_arrays_for_min_max(request):
# The size of input and out arrays to be used
N = 10
a = dpnp.arange(N, dtype=request.param)
b = dpnp.ones(N, dtype=request.param)
return a, b


@pytest.fixture(params=list_of_int_dtypes)
def input_arrays_for_and_or_xor(request):
# The size of input and out arrays to be used
N = 10
a = dpnp.arange(N, dtype=request.param)
b = dpnp.ones(N, dtype=request.param)
return a, b


@pytest.mark.parametrize("ref_index", [0, 5])
def test_fetch_add(input_arrays, ref_index):
def test_fetch_add(input_arrays_for_add, ref_index):
@dpex_exp.kernel
def atomic_ref_kernel(a, b, ref_index):
def atomic_add_kernel(a, b, ref_index):
i = dpex.get_global_id(0)
v = AtomicRef(b, index=ref_index)
v.fetch_add(a[i])

a, b = input_arrays
a, b = input_arrays_for_add

dpex_exp.call_kernel(atomic_ref_kernel, dpex.Range(10), a, b, ref_index)
dpex_exp.call_kernel(atomic_add_kernel, dpex.Range(10), a, b, ref_index)

# Verify that `a` was accumulated at b[ref_index]
assert b[ref_index] == 10


def test_fetch_add_diff_types():
"""A negative test that verifies that a TypingError is raised if
AtomicRef type and value to be added are of different types.
"""

@dpex_exp.kernel
def atomic_add_kernel(a, b):
i = dpex.get_global_id(0)
v = AtomicRef(b, index=0)
v.fetch_add(a[i])

N = 10
a = dpnp.ones(N, dtype=dpnp.float32)
b = dpnp.zeros(N, dtype=dpnp.int32)

with pytest.raises(TypingError):
dpex_exp.call_kernel(atomic_add_kernel, dpex.Range(10), a, b)


@pytest.mark.parametrize("ref_index", [0, 5])
def test_fetch_sub(input_arrays_for_sub, ref_index):
@dpex_exp.kernel
def atomic_sub_kernel(a, b, ref_index):
i = dpex.get_global_id(0)
v = AtomicRef(b, index=ref_index)
v.fetch_sub(a[i])

a, b = input_arrays_for_sub

dpex_exp.call_kernel(atomic_sub_kernel, dpex.Range(10), a, b, ref_index)

# Verify that b[ref_index] -= a[0:N]
assert b[ref_index] == 1


def test_fetch_sub_diff_types():
"""A negative test that verifies that a TypingError is raised if
AtomicRef type and value to be subtracted are of different types.
"""

@dpex_exp.kernel
def atomic_sub_kernel(a, b):
i = dpex.get_global_id(0)
v = AtomicRef(b, index=0)
v.fetch_sub(a[i])

N = 10
a = dpnp.ones(N, dtype=dpnp.float32)
b = dpnp.zeros(N, dtype=dpnp.int32)

with pytest.raises(TypingError):
dpex_exp.call_kernel(atomic_sub_kernel, dpex.Range(10), a, b)


@pytest.mark.parametrize("ref_index", [0, 5])
def test_fetch_min(input_arrays_for_min_max, ref_index):
@dpex_exp.kernel
def atomic_min_kernel(a, b, ref_index):
i = dpex.get_global_id(0)
v = AtomicRef(b, index=ref_index)
v.fetch_min(a[i])

a, b = input_arrays_for_min_max

dpex_exp.call_kernel(atomic_min_kernel, dpex.Range(10), a, b, ref_index)

# Verify that b[ref_index] is min(a[0:N])
assert b[ref_index] == 0


def test_fetch_min_diff_types():
"""A negative test that verifies that a TypingError is raised if
AtomicRef type and compared value to compute min are of different types.
"""

@dpex_exp.kernel
def atomic_min_kernel(a, b):
i = dpex.get_global_id(0)
v = AtomicRef(b, index=0)
v.fetch_min(a[i])

N = 10
a = dpnp.ones(N, dtype=dpnp.float32)
b = dpnp.zeros(N, dtype=dpnp.int32)

with pytest.raises(TypingError):
dpex_exp.call_kernel(atomic_min_kernel, dpex.Range(10), a, b)


@pytest.mark.parametrize("ref_index", [0, 5])
def test_fetch_max(input_arrays_for_min_max, ref_index):
@dpex_exp.kernel
def atomic_max_kernel(a, b, ref_index):
i = dpex.get_global_id(0)
v = AtomicRef(b, index=ref_index)
v.fetch_max(a[i])

a, b = input_arrays_for_min_max

dpex_exp.call_kernel(atomic_max_kernel, dpex.Range(10), a, b, ref_index)

# Verify that b[ref_index] is max(a[0:N])
assert b[ref_index] == 9


def test_fetch_max_diff_types():
"""A negative test that verifies that a TypingError is raised if
AtomicRef type and compared value to compute max are of different types.
"""

@dpex_exp.kernel
def atomic_max_kernel(a, b):
i = dpex.get_global_id(0)
v = AtomicRef(b, index=0)
v.fetch_max(a[i])

N = 10
a = dpnp.ones(N, dtype=dpnp.float32)
b = dpnp.zeros(N, dtype=dpnp.int32)

with pytest.raises(TypingError):
dpex_exp.call_kernel(atomic_max_kernel, dpex.Range(10), a, b)


@pytest.mark.parametrize("ref_index", [0, 5])
def test_fetch_and(input_arrays_for_and_or_xor, ref_index):
@dpex_exp.kernel
def atomic_and_kernel(a, b, ref_index):
i = dpex.get_global_id(0)
v = AtomicRef(b, index=ref_index)
v.fetch_and(a[i])

a, b = input_arrays_for_and_or_xor

dpex_exp.call_kernel(atomic_and_kernel, dpex.Range(10), a, b, ref_index)

# Verify that b[ref_index] is bitwise and(a[0:N])
assert b[ref_index] == 1


def test_fetch_and_diff_types():
"""A negative test that verifies that a TypingError is raised if
AtomicRef type and value for bitwise and are of different types.
"""

@dpex_exp.kernel
def atomic_and_kernel(a, b):
i = dpex.get_global_id(0)
v = AtomicRef(b, index=0)
v.fetch_and(a[i])

N = 10
a = dpnp.ones(N, dtype=dpnp.float32)
b = dpnp.zeros(N, dtype=dpnp.int32)

with pytest.raises(TypingError):
dpex_exp.call_kernel(atomic_and_kernel, dpex.Range(10), a, b)


@pytest.mark.parametrize("ref_index", [0, 5])
def test_fetch_or(input_arrays_for_and_or_xor, ref_index):
@dpex_exp.kernel
def atomic_or_kernel(a, b, ref_index):
i = dpex.get_global_id(0)
v = AtomicRef(b, index=ref_index)
v.fetch_or(a[i])

a, b = input_arrays_for_and_or_xor

dpex_exp.call_kernel(atomic_or_kernel, dpex.Range(10), a, b, ref_index)

# Verify that b[ref_index] is bitwise or(a[0:N])
assert b[ref_index] == 15


def test_fetch_or_diff_types():
"""A negative test that verifies that a TypingError is raised if
AtomicRef type and value for bitwise or are of different types.
"""

@dpex_exp.kernel
def atomic_or_kernel(a, b):
i = dpex.get_global_id(0)
v = AtomicRef(b, index=0)
v.fetch_or(a[i])

N = 10
a = dpnp.ones(N, dtype=dpnp.float32)
b = dpnp.zeros(N, dtype=dpnp.int32)

with pytest.raises(TypingError):
dpex_exp.call_kernel(atomic_or_kernel, dpex.Range(10), a, b)


@pytest.mark.parametrize("ref_index", [0, 5])
def test_fetch_xor(input_arrays_for_and_or_xor, ref_index):
@dpex_exp.kernel
def atomic_xor_kernel(a, b, ref_index):
i = dpex.get_global_id(0)
v = AtomicRef(b, index=ref_index)
v.fetch_xor(a[i])

a, b = input_arrays_for_and_or_xor

dpex_exp.call_kernel(atomic_xor_kernel, dpex.Range(10), a, b, ref_index)

# Verify that b[ref_index] is bitwise xor(a[0:N])
assert b[ref_index] == 0


def test_fetch_xor_diff_types():
"""A negative test that verifies that a TypingError is raised if
AtomicRef type and value for bitwise xor are of different types.
"""

@dpex_exp.kernel
def atomic_xor_kernel(a, b):
i = dpex.get_global_id(0)
v = AtomicRef(b, index=0)
v.fetch_xor(a[i])

N = 10
a = dpnp.ones(N, dtype=dpnp.float32)
b = dpnp.zeros(N, dtype=dpnp.int32)

with pytest.raises(TypingError):
dpex_exp.call_kernel(atomic_xor_kernel, dpex.Range(10), a, b)


@dpex_exp.kernel
def atomic_ref_0(a):
i = dpex.get_global_id(0)
Expand Down

0 comments on commit 0ee15fc

Please sign in to comment.