Skip to content

Commit 5b7b5e8

Browse files
committed
Adding test cases for fetch_phi functions
1 parent b1e6b2a commit 5b7b5e8

File tree

2 files changed

+299
-5
lines changed

2 files changed

+299
-5
lines changed

numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_atomic_ref_overloads.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@ def gen(context, builder, sig, args):
6969
"--spirv-ext=+SPV_EXT_shader_atomic_float_add"
7070
]
7171

72+
context.extra_compile_options[LLVM_SPIRV_ARGS] = [
73+
"--spirv-ext=+SPV_EXT_shader_atomic_float_min_max"
74+
]
75+
7276
ptr_type = retty.as_pointer()
7377
ptr_type.addrspace = atomic_ref_ty.address_space
7478

@@ -118,8 +122,31 @@ def _intrinsic_fetch_add(ty_context, ty_atomic_ref, ty_val):
118122
return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_add")
119123

120124

125+
def _atomic_sub_float_wrapper(gen_fn):
126+
def gen(context, builder, sig, args):
127+
# args is a tuple, which is immutable
128+
# covert tuple to list obj first before replacing arg[1]
129+
# with fneg and convert back to tuple again.
130+
args_lst = list(args)
131+
args_lst[1] = builder.fneg(args[1])
132+
args = tuple(args_lst)
133+
134+
gen_fn(context, builder, sig, args)
135+
136+
return gen
137+
138+
121139
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
122140
def _intrinsic_fetch_sub(ty_context, ty_atomic_ref, ty_val):
141+
if ty_atomic_ref.dtype in (types.float32, types.float64):
142+
# dpcpp does not support ``__spirv_AtomicFSubEXT``. fetch_sub
143+
# for floats is implemented by negating the value and calling fetch_add.
144+
# For example, A.fetch_sub(A, val) is implemented as A.fetch_add(-val).
145+
sig, gen = _intrinsic_helper(
146+
ty_context, ty_atomic_ref, ty_val, "fetch_add"
147+
)
148+
return sig, _atomic_sub_float_wrapper(gen)
149+
123150
return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_sub")
124151

125152

numba_dpex/tests/experimental/kernel_iface/spv_overloads/test_atomic_fetch_phi.py

Lines changed: 272 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
import dpnp
6+
import numpy as np
67
import pytest
8+
from numba.core.errors import TypingError
79

810
import numba_dpex as dpex
911
import numba_dpex.experimental as dpex_exp
@@ -14,32 +16,297 @@
1416
no_bool=True, no_float16=True, no_none=True, no_complex=True
1517
)
1618

19+
list_of_int_dtypes = get_all_dtypes(
20+
no_bool=True, no_float=True, no_none=True, no_complex=True
21+
)
22+
23+
list_of_float_dtypes = get_all_dtypes(
24+
no_bool=True, no_int=True, no_float16=True, no_none=True, no_complex=True
25+
)
26+
1727

1828
@pytest.fixture(params=list_of_supported_dtypes)
19-
def input_arrays(request):
29+
def input_arrays_for_add(request):
2030
# The size of input and out arrays to be used
2131
N = 10
2232
a = dpnp.ones(N, dtype=request.param)
2333
b = dpnp.zeros(N, dtype=request.param)
2434
return a, b
2535

2636

37+
@pytest.fixture(params=list_of_supported_dtypes)
38+
def input_arrays_for_sub(request):
39+
# The size of input and out arrays to be used
40+
N = 10
41+
a = dpnp.ones(N, dtype=request.param)
42+
b_np = np.ones(N) * N + 1
43+
b = dpnp.asarray(b_np, dtype=request.param)
44+
return a, b
45+
46+
47+
@pytest.fixture(params=list_of_supported_dtypes)
48+
def input_arrays_for_min_max(request):
49+
# The size of input and out arrays to be used
50+
N = 10
51+
a = dpnp.arange(N, dtype=request.param)
52+
b = dpnp.ones(N, dtype=request.param)
53+
return a, b
54+
55+
56+
@pytest.fixture(params=list_of_int_dtypes)
57+
def input_arrays_for_and_or_xor(request):
58+
# The size of input and out arrays to be used
59+
N = 10
60+
a = dpnp.arange(N, dtype=request.param)
61+
b = dpnp.ones(N, dtype=request.param)
62+
return a, b
63+
64+
2765
@pytest.mark.parametrize("ref_index", [0, 5])
28-
def test_fetch_add(input_arrays, ref_index):
66+
def test_fetch_add(input_arrays_for_add, ref_index):
2967
@dpex_exp.kernel
30-
def atomic_ref_kernel(a, b, ref_index):
68+
def atomic_add_kernel(a, b, ref_index):
3169
i = dpex.get_global_id(0)
3270
v = AtomicRef(b, index=ref_index)
3371
v.fetch_add(a[i])
3472

35-
a, b = input_arrays
73+
a, b = input_arrays_for_add
3674

37-
dpex_exp.call_kernel(atomic_ref_kernel, dpex.Range(10), a, b, ref_index)
75+
dpex_exp.call_kernel(atomic_add_kernel, dpex.Range(10), a, b, ref_index)
3876

3977
# Verify that `a` was accumulated at b[ref_index]
4078
assert b[ref_index] == 10
4179

4280

81+
def test_fetch_add_diff_types():
82+
"""A negative test that verifies that a TypingError is raised if
83+
AtomicRef type and value to be added are of different types.
84+
"""
85+
86+
@dpex_exp.kernel
87+
def atomic_add_kernel(a, b):
88+
i = dpex.get_global_id(0)
89+
v = AtomicRef(b, index=0)
90+
v.fetch_add(a[i])
91+
92+
N = 10
93+
a = dpnp.ones(N, dtype=dpnp.float32)
94+
b = dpnp.zeros(N, dtype=dpnp.int32)
95+
96+
with pytest.raises(TypingError):
97+
dpex_exp.call_kernel(atomic_add_kernel, dpex.Range(10), a, b)
98+
99+
100+
@pytest.mark.parametrize("ref_index", [0, 5])
101+
def test_fetch_sub(input_arrays_for_sub, ref_index):
102+
@dpex_exp.kernel
103+
def atomic_sub_kernel(a, b, ref_index):
104+
i = dpex.get_global_id(0)
105+
v = AtomicRef(b, index=ref_index)
106+
v.fetch_sub(a[i])
107+
108+
a, b = input_arrays_for_sub
109+
110+
dpex_exp.call_kernel(atomic_sub_kernel, dpex.Range(10), a, b, ref_index)
111+
112+
# Verify that b[ref_index] -= a[0:N]
113+
assert b[ref_index] == 1
114+
115+
116+
def test_fetch_sub_diff_types():
117+
"""A negative test that verifies that a TypingError is raised if
118+
AtomicRef type and value to be subtracted are of different types.
119+
"""
120+
121+
@dpex_exp.kernel
122+
def atomic_sub_kernel(a, b):
123+
i = dpex.get_global_id(0)
124+
v = AtomicRef(b, index=0)
125+
v.fetch_sub(a[i])
126+
127+
N = 10
128+
a = dpnp.ones(N, dtype=dpnp.float32)
129+
b = dpnp.zeros(N, dtype=dpnp.int32)
130+
131+
with pytest.raises(TypingError):
132+
dpex_exp.call_kernel(atomic_sub_kernel, dpex.Range(10), a, b)
133+
134+
135+
@pytest.mark.parametrize("ref_index", [0, 5])
136+
def test_fetch_min(input_arrays_for_min_max, ref_index):
137+
@dpex_exp.kernel
138+
def atomic_min_kernel(a, b, ref_index):
139+
i = dpex.get_global_id(0)
140+
v = AtomicRef(b, index=ref_index)
141+
v.fetch_min(a[i])
142+
143+
a, b = input_arrays_for_min_max
144+
145+
dpex_exp.call_kernel(atomic_min_kernel, dpex.Range(10), a, b, ref_index)
146+
147+
# Verify that b[ref_index] is min(a[0:N])
148+
assert b[ref_index] == 0
149+
150+
151+
def test_fetch_min_diff_types():
152+
"""A negative test that verifies that a TypingError is raised if
153+
AtomicRef type and compared value to compute min are of different types.
154+
"""
155+
156+
@dpex_exp.kernel
157+
def atomic_min_kernel(a, b):
158+
i = dpex.get_global_id(0)
159+
v = AtomicRef(b, index=0)
160+
v.fetch_min(a[i])
161+
162+
N = 10
163+
a = dpnp.ones(N, dtype=dpnp.float32)
164+
b = dpnp.zeros(N, dtype=dpnp.int32)
165+
166+
with pytest.raises(TypingError):
167+
dpex_exp.call_kernel(atomic_min_kernel, dpex.Range(10), a, b)
168+
169+
170+
@pytest.mark.parametrize("ref_index", [0, 5])
171+
def test_fetch_max(input_arrays_for_min_max, ref_index):
172+
@dpex_exp.kernel
173+
def atomic_max_kernel(a, b, ref_index):
174+
i = dpex.get_global_id(0)
175+
v = AtomicRef(b, index=ref_index)
176+
v.fetch_max(a[i])
177+
178+
a, b = input_arrays_for_min_max
179+
180+
dpex_exp.call_kernel(atomic_max_kernel, dpex.Range(10), a, b, ref_index)
181+
182+
# Verify that b[ref_index] is max(a[0:N])
183+
assert b[ref_index] == 9
184+
185+
186+
def test_fetch_max_diff_types():
187+
"""A negative test that verifies that a TypingError is raised if
188+
AtomicRef type and compared value to compute max are of different types.
189+
"""
190+
191+
@dpex_exp.kernel
192+
def atomic_max_kernel(a, b):
193+
i = dpex.get_global_id(0)
194+
v = AtomicRef(b, index=0)
195+
v.fetch_max(a[i])
196+
197+
N = 10
198+
a = dpnp.ones(N, dtype=dpnp.float32)
199+
b = dpnp.zeros(N, dtype=dpnp.int32)
200+
201+
with pytest.raises(TypingError):
202+
dpex_exp.call_kernel(atomic_max_kernel, dpex.Range(10), a, b)
203+
204+
205+
@pytest.mark.parametrize("ref_index", [0, 5])
206+
def test_fetch_and(input_arrays_for_and_or_xor, ref_index):
207+
@dpex_exp.kernel
208+
def atomic_and_kernel(a, b, ref_index):
209+
i = dpex.get_global_id(0)
210+
v = AtomicRef(b, index=ref_index)
211+
v.fetch_and(a[i])
212+
213+
a, b = input_arrays_for_and_or_xor
214+
215+
dpex_exp.call_kernel(atomic_and_kernel, dpex.Range(10), a, b, ref_index)
216+
217+
# Verify that b[ref_index] is bitwise and(a[0:N])
218+
assert b[ref_index] == 1
219+
220+
221+
def test_fetch_and_diff_types():
222+
"""A negative test that verifies that a TypingError is raised if
223+
AtomicRef type and value for bitwise and are of different types.
224+
"""
225+
226+
@dpex_exp.kernel
227+
def atomic_and_kernel(a, b):
228+
i = dpex.get_global_id(0)
229+
v = AtomicRef(b, index=0)
230+
v.fetch_and(a[i])
231+
232+
N = 10
233+
a = dpnp.ones(N, dtype=dpnp.float32)
234+
b = dpnp.zeros(N, dtype=dpnp.int32)
235+
236+
with pytest.raises(TypingError):
237+
dpex_exp.call_kernel(atomic_and_kernel, dpex.Range(10), a, b)
238+
239+
240+
@pytest.mark.parametrize("ref_index", [0, 5])
241+
def test_fetch_or(input_arrays_for_and_or_xor, ref_index):
242+
@dpex_exp.kernel
243+
def atomic_or_kernel(a, b, ref_index):
244+
i = dpex.get_global_id(0)
245+
v = AtomicRef(b, index=ref_index)
246+
v.fetch_or(a[i])
247+
248+
a, b = input_arrays_for_and_or_xor
249+
250+
dpex_exp.call_kernel(atomic_or_kernel, dpex.Range(10), a, b, ref_index)
251+
252+
# Verify that b[ref_index] is bitwise or(a[0:N])
253+
assert b[ref_index] == 15
254+
255+
256+
def test_fetch_or_diff_types():
257+
"""A negative test that verifies that a TypingError is raised if
258+
AtomicRef type and value for bitwise or are of different types.
259+
"""
260+
261+
@dpex_exp.kernel
262+
def atomic_or_kernel(a, b):
263+
i = dpex.get_global_id(0)
264+
v = AtomicRef(b, index=0)
265+
v.fetch_or(a[i])
266+
267+
N = 10
268+
a = dpnp.ones(N, dtype=dpnp.float32)
269+
b = dpnp.zeros(N, dtype=dpnp.int32)
270+
271+
with pytest.raises(TypingError):
272+
dpex_exp.call_kernel(atomic_or_kernel, dpex.Range(10), a, b)
273+
274+
275+
@pytest.mark.parametrize("ref_index", [0, 5])
276+
def test_fetch_xor(input_arrays_for_and_or_xor, ref_index):
277+
@dpex_exp.kernel
278+
def atomic_xor_kernel(a, b, ref_index):
279+
i = dpex.get_global_id(0)
280+
v = AtomicRef(b, index=ref_index)
281+
v.fetch_xor(a[i])
282+
283+
a, b = input_arrays_for_and_or_xor
284+
285+
dpex_exp.call_kernel(atomic_xor_kernel, dpex.Range(10), a, b, ref_index)
286+
287+
# Verify that b[ref_index] is bitwise xor(a[0:N])
288+
assert b[ref_index] == 0
289+
290+
291+
def test_fetch_xor_diff_types():
292+
"""A negative test that verifies that a TypingError is raised if
293+
AtomicRef type and value for bitwise xor are of different types.
294+
"""
295+
296+
@dpex_exp.kernel
297+
def atomic_xor_kernel(a, b):
298+
i = dpex.get_global_id(0)
299+
v = AtomicRef(b, index=0)
300+
v.fetch_xor(a[i])
301+
302+
N = 10
303+
a = dpnp.ones(N, dtype=dpnp.float32)
304+
b = dpnp.zeros(N, dtype=dpnp.int32)
305+
306+
with pytest.raises(TypingError):
307+
dpex_exp.call_kernel(atomic_xor_kernel, dpex.Range(10), a, b)
308+
309+
43310
@dpex_exp.kernel
44311
def atomic_ref_0(a):
45312
i = dpex.get_global_id(0)

0 commit comments

Comments
 (0)