Skip to content

Commit

Permalink
Added unit tests for the diagonal function (#23)
Browse files Browse the repository at this point in the history
* added diagonal tests

* added random tests

* modified random tests

* fixed formatting issues for random tests

* reformatted manage_array file

* reformatted diag tests, manage array

---------

Co-authored-by: Chaluvadi <[email protected]>
  • Loading branch information
sakchal and Chaluvadi authored Mar 6, 2024
1 parent a657583 commit f90ef61
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def get_scalar(arr: AFArray, dtype: Dtype, /) -> int | float | complex | bool |
out = dtype.c_type()
call_from_clib(get_scalar.__name__, ctypes.pointer(out), arr)
if dtype == c32 or dtype == c64:
return complex(out[0], out[1]) # type: ignore
return complex(out[0], out[1]) # type: ignore
else:
return cast(int | float | complex | bool | None, out.value)

Expand Down
43 changes: 43 additions & 0 deletions tests/test_diag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import pytest

import arrayfire_wrapper.dtypes as dtypes
import arrayfire_wrapper.lib as wrapper


@pytest.mark.parametrize("diagonal_shape", [(2,), (10,), (100,), (1000,)])
def test_diagonal_shape(diagonal_shape: tuple) -> None:
"""Test if diagonal array is keeping the shape of the passed into the input array"""
in_arr = wrapper.constant(1, diagonal_shape, dtypes.s16)
diag_array = wrapper.diag_create(in_arr, 0)

extracted_diagonal = wrapper.diag_extract(diag_array, 0)

assert wrapper.get_dims(extracted_diagonal)[0 : len(diagonal_shape)] == diagonal_shape # noqa: E203


@pytest.mark.parametrize("diagonal_shape", [(2,), (10,), (100,), (1000,)])
def test_diagonal_val(diagonal_shape: tuple) -> None:
"""Test if diagonal array is keeping the same value as that of the values passed into the input array"""
dtype = dtypes.s16
in_arr = wrapper.constant(1, diagonal_shape, dtype)
diag_array = wrapper.diag_create(in_arr, 0)

extracted_diagonal = wrapper.diag_extract(diag_array, 0)

assert wrapper.get_scalar(extracted_diagonal, dtype) == wrapper.get_scalar(in_arr, dtype)


@pytest.mark.parametrize(
"diagonal_shape",
[
(10, 10, 10),
(100, 100, 100, 100),
],
)
def test_invalid_diagonal(diagonal_shape: tuple) -> None:
"""Test if an invalid diagonal shape is being properly handled"""
with pytest.raises(RuntimeError):
in_arr = wrapper.constant(1, diagonal_shape, dtypes.s16)
diag_array = wrapper.diag_create(in_arr, 0)

wrapper.diag_extract(diag_array, 0)
111 changes: 111 additions & 0 deletions tests/test_random.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import random

import pytest

import arrayfire_wrapper.dtypes as dtypes
import arrayfire_wrapper.lib as wrapper


@pytest.mark.parametrize(
"shape",
[
(),
(random.randint(1, 10), 1),
(random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
],
)
def test_randu_shape(shape: tuple) -> None:
"""Test if randu function creates an array with the correct shape."""
dtype = dtypes.s16

result = wrapper.randu(shape, dtype)

assert wrapper.get_dims(result)[0 : len(shape)] == shape # noqa: E203


@pytest.mark.parametrize(
"shape",
[
(),
(random.randint(1, 10), 1),
(random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
],
)
def test_random_uniform_shape(shape: tuple) -> None:
"""Test if rand uniform function creates an array with the correct shape."""
dtype = dtypes.s16
engine = wrapper.create_random_engine(100, 10)

result = wrapper.random_uniform(shape, dtype, engine)

assert wrapper.get_dims(result)[0 : len(shape)] == shape # noqa: E203


@pytest.mark.parametrize(
"shape",
[
(),
(random.randint(1, 10), 1),
(random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
],
)
def test_randn_shape(shape: tuple) -> None:
"""Test if randn function creates an array with the correct shape."""
dtype = dtypes.f32

result = wrapper.randn(shape, dtype)

assert wrapper.get_dims(result)[0 : len(shape)] == shape # noqa: E203


@pytest.mark.parametrize(
"shape",
[
(),
(random.randint(1, 10), 1),
(random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
],
)
def test_random_normal_shape(shape: tuple) -> None:
"""Test if random normal function creates an array with the correct shape."""
dtype = dtypes.f32
engine = wrapper.create_random_engine(100, 10)

result = wrapper.random_normal(shape, dtype, engine)

assert wrapper.get_dims(result)[0 : len(shape)] == shape # noqa: E203


@pytest.mark.parametrize(
"engine_index",
[100, 200, 300],
)
def test_create_random_engine(engine_index: int) -> None:
engine = wrapper.create_random_engine(engine_index, 10)

engine_type = wrapper.random_engine_get_type(engine)

assert engine_type == engine_index


@pytest.mark.parametrize(
"invalid_index",
[random.randint(301, 600), random.randint(301, 600), random.randint(301, 600)],
)
def test_invalid_random_engine(invalid_index: int) -> None:
"Test if invalid engine types are properly handled"
with pytest.raises(RuntimeError):

invalid_engine = wrapper.create_random_engine(invalid_index, 10)

engine_type = wrapper.random_engine_get_type(invalid_engine)

assert engine_type == invalid_engine

0 comments on commit f90ef61

Please sign in to comment.