Skip to content

Commit 59601ff

Browse files
committed
TESTS: Reorganizing the tests folder for easier testing
- Fixing the af.display function when called with an alias
1 parent 303bea3 commit 59601ff

26 files changed

+975
-755
lines changed

arrayfire/array.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -1041,6 +1041,12 @@ def display(a):
10411041
Multi dimensional arrayfire array
10421042
"""
10431043
expr = inspect.stack()[1][-2]
1044-
if (expr is not None):
1045-
print('%s' % expr[0].split('display(')[1][:-2])
1046-
safe_call(backend.get().af_print_array(a.arr))
1044+
1045+
try:
1046+
if (expr is not None):
1047+
st = expr[0].find('(') + 1
1048+
en = expr[0].rfind(')')
1049+
print('%s' % expr[0][st:en])
1050+
safe_call(backend.get().af_print_array(a.arr))
1051+
except:
1052+
safe_call(backend.get().af_print_array(a.arr))

tests/__main__.py

+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
#######################################################
2+
# Copyright (c) 2015, ArrayFire
3+
# All rights reserved.
4+
#
5+
# This file is distributed under 3-clause BSD license.
6+
# The complete license agreement can be obtained at:
7+
# http://arrayfire.com/licenses/BSD-3-Clause
8+
########################################################
9+
10+
import sys
11+
from simple_tests import *
12+
13+
tests = {}
14+
tests['simple'] = simple.tests
15+
16+
def assert_valid(name, name_list, name_str):
17+
is_valid = any([name == val for val in name_list])
18+
if not is_valid:
19+
err_str = "The first argument needs to be a %s name\n" % name_str
20+
err_str += "List of supported %ss: %s" % (name_str, str(list(name_list)))
21+
raise RuntimeError(err_str)
22+
23+
if __name__ == "__main__":
24+
25+
module_name = None
26+
num_args = len(sys.argv)
27+
28+
if (num_args > 1):
29+
module_name = sys.argv[1].lower()
30+
assert_valid(sys.argv[1].lower(), tests.keys(), "module")
31+
32+
if (module_name is None):
33+
for name in tests:
34+
tests[name].run()
35+
else:
36+
test = tests[module_name]
37+
test_list = None
38+
39+
if (num_args > 2):
40+
test_list = sys.argv[2:]
41+
for test_name in test_list:
42+
assert_valid(test_name.lower(), test.keys(), "test")
43+
44+
test.run(test_list)

tests/simple/__init__.py

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#######################################################
2+
# Copyright (c) 2015, ArrayFire
3+
# All rights reserved.
4+
#
5+
# This file is distributed under 3-clause BSD license.
6+
# The complete license agreement can be obtained at:
7+
# http://arrayfire.com/licenses/BSD-3-Clause
8+
########################################################
9+
10+
from .algorithm import *
11+
from .arith import *
12+
from .array_test import *
13+
from .blas import *
14+
from .data import *
15+
from .device import *
16+
from .image import *
17+
from .index import *
18+
from .lapack import *
19+
from .signal import *
20+
from .statistics import *
21+
from ._util import tests

tests/simple/_util.py

+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
#######################################################
2+
# Copyright (c) 2015, ArrayFire
3+
# All rights reserved.
4+
#
5+
# This file is distributed under 3-clause BSD license.
6+
# The complete license agreement can be obtained at:
7+
# http://arrayfire.com/licenses/BSD-3-Clause
8+
########################################################
9+
10+
import arrayfire as af
11+
12+
def display_func(verbose):
13+
if (verbose):
14+
return af.display
15+
else:
16+
def eval_func(foo):
17+
res = foo
18+
return eval_func
19+
20+
def print_func(verbose):
21+
def print_func_impl(*args):
22+
if (verbose):
23+
print(args)
24+
else:
25+
res = [args]
26+
return print_func_impl
27+
28+
class _simple_test_dict(dict):
29+
30+
def __init__(self):
31+
self.print_str = "Simple %16s: %s"
32+
super(_simple_test_dict, self).__init__()
33+
34+
def run(self, name_list=None, verbose=False):
35+
test_list = name_list if name_list is not None else self.keys()
36+
for key in test_list:
37+
38+
try:
39+
test = self[key]
40+
except:
41+
print(self.print_str % (key, "NOTFOUND"))
42+
continue
43+
44+
try:
45+
test(verbose)
46+
print(self.print_str % (key, "PASSED"))
47+
except:
48+
print(self.print_str % (key, "FAILED"))
49+
50+
tests = _simple_test_dict()

tests/simple/algorithm.py

+78
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
#!/usr/bin/python
2+
#######################################################
3+
# Copyright (c) 2015, ArrayFire
4+
# All rights reserved.
5+
#
6+
# This file is distributed under 3-clause BSD license.
7+
# The complete license agreement can be obtained at:
8+
# http://arrayfire.com/licenses/BSD-3-Clause
9+
########################################################
10+
11+
import arrayfire as af
12+
from . import _util
13+
14+
def simple_algorithm(verbose = False):
15+
display_func = _util.display_func(verbose)
16+
print_func = _util.print_func(verbose)
17+
18+
a = af.randu(3, 3)
19+
20+
print_func(af.sum(a), af.product(a), af.min(a), af.max(a),
21+
af.count(a), af.any_true(a), af.all_true(a))
22+
23+
display_func(af.sum(a, 0))
24+
display_func(af.sum(a, 1))
25+
26+
display_func(af.product(a, 0))
27+
display_func(af.product(a, 1))
28+
29+
display_func(af.min(a, 0))
30+
display_func(af.min(a, 1))
31+
32+
display_func(af.max(a, 0))
33+
display_func(af.max(a, 1))
34+
35+
display_func(af.count(a, 0))
36+
display_func(af.count(a, 1))
37+
38+
display_func(af.any_true(a, 0))
39+
display_func(af.any_true(a, 1))
40+
41+
display_func(af.all_true(a, 0))
42+
display_func(af.all_true(a, 1))
43+
44+
display_func(af.accum(a, 0))
45+
display_func(af.accum(a, 1))
46+
47+
display_func(af.sort(a, is_ascending=True))
48+
display_func(af.sort(a, is_ascending=False))
49+
50+
val,idx = af.sort_index(a, is_ascending=True)
51+
display_func(val)
52+
display_func(idx)
53+
val,idx = af.sort_index(a, is_ascending=False)
54+
display_func(val)
55+
display_func(idx)
56+
57+
b = af.randu(3,3)
58+
keys,vals = af.sort_by_key(a, b, is_ascending=True)
59+
display_func(keys)
60+
display_func(vals)
61+
keys,vals = af.sort_by_key(a, b, is_ascending=False)
62+
display_func(keys)
63+
display_func(vals)
64+
65+
c = af.randu(5,1)
66+
d = af.randu(5,1)
67+
cc = af.set_unique(c, is_sorted=False)
68+
dd = af.set_unique(af.sort(d), is_sorted=True)
69+
display_func(cc)
70+
display_func(dd)
71+
72+
display_func(af.set_union(cc, dd, is_unique=True))
73+
display_func(af.set_union(cc, dd, is_unique=False))
74+
75+
display_func(af.set_intersect(cc, cc, is_unique=True))
76+
display_func(af.set_intersect(cc, cc, is_unique=False))
77+
78+
_util.tests['algorithm'] = simple_algorithm

0 commit comments

Comments
 (0)