From f90ef6182c98561c45f5f27a5a64500fbc71f03e Mon Sep 17 00:00:00 2001 From: Saketh Chaluvadi <87680444+sakchal@users.noreply.github.com> Date: Wed, 6 Mar 2024 16:58:41 -0500 Subject: [PATCH] Added unit tests for the diagonal function (#23) * added diagonal tests * added random tests * modified random tests * fixed formatting issues for random tests * reformatted manage_array file * reformatted diag tests, manage array --------- Co-authored-by: Chaluvadi --- .../create_and_modify_array/manage_array.py | 2 +- tests/test_diag.py | 43 +++++++ tests/test_random.py | 111 ++++++++++++++++++ 3 files changed, 155 insertions(+), 1 deletion(-) create mode 100644 tests/test_diag.py create mode 100644 tests/test_random.py diff --git a/arrayfire_wrapper/lib/create_and_modify_array/manage_array.py b/arrayfire_wrapper/lib/create_and_modify_array/manage_array.py index 52d4cba..34ad512 100644 --- a/arrayfire_wrapper/lib/create_and_modify_array/manage_array.py +++ b/arrayfire_wrapper/lib/create_and_modify_array/manage_array.py @@ -166,7 +166,7 @@ def get_scalar(arr: AFArray, dtype: Dtype, /) -> int | float | complex | bool | out = dtype.c_type() call_from_clib(get_scalar.__name__, ctypes.pointer(out), arr) if dtype == c32 or dtype == c64: - return complex(out[0], out[1]) # type: ignore + return complex(out[0], out[1]) # type: ignore else: return cast(int | float | complex | bool | None, out.value) diff --git a/tests/test_diag.py b/tests/test_diag.py new file mode 100644 index 0000000..f666b08 --- /dev/null +++ b/tests/test_diag.py @@ -0,0 +1,43 @@ +import pytest + +import arrayfire_wrapper.dtypes as dtypes +import arrayfire_wrapper.lib as wrapper + + +@pytest.mark.parametrize("diagonal_shape", [(2,), (10,), (100,), (1000,)]) +def test_diagonal_shape(diagonal_shape: tuple) -> None: + """Test if diagonal array is keeping the shape of the passed into the input array""" + in_arr = wrapper.constant(1, diagonal_shape, dtypes.s16) + diag_array = wrapper.diag_create(in_arr, 0) + + extracted_diagonal = wrapper.diag_extract(diag_array, 0) + + assert wrapper.get_dims(extracted_diagonal)[0 : len(diagonal_shape)] == diagonal_shape # noqa: E203 + + +@pytest.mark.parametrize("diagonal_shape", [(2,), (10,), (100,), (1000,)]) +def test_diagonal_val(diagonal_shape: tuple) -> None: + """Test if diagonal array is keeping the same value as that of the values passed into the input array""" + dtype = dtypes.s16 + in_arr = wrapper.constant(1, diagonal_shape, dtype) + diag_array = wrapper.diag_create(in_arr, 0) + + extracted_diagonal = wrapper.diag_extract(diag_array, 0) + + assert wrapper.get_scalar(extracted_diagonal, dtype) == wrapper.get_scalar(in_arr, dtype) + + +@pytest.mark.parametrize( + "diagonal_shape", + [ + (10, 10, 10), + (100, 100, 100, 100), + ], +) +def test_invalid_diagonal(diagonal_shape: tuple) -> None: + """Test if an invalid diagonal shape is being properly handled""" + with pytest.raises(RuntimeError): + in_arr = wrapper.constant(1, diagonal_shape, dtypes.s16) + diag_array = wrapper.diag_create(in_arr, 0) + + wrapper.diag_extract(diag_array, 0) diff --git a/tests/test_random.py b/tests/test_random.py new file mode 100644 index 0000000..d46cd8e --- /dev/null +++ b/tests/test_random.py @@ -0,0 +1,111 @@ +import random + +import pytest + +import arrayfire_wrapper.dtypes as dtypes +import arrayfire_wrapper.lib as wrapper + + +@pytest.mark.parametrize( + "shape", + [ + (), + (random.randint(1, 10), 1), + (random.randint(1, 10), random.randint(1, 10)), + (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), + (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), + ], +) +def test_randu_shape(shape: tuple) -> None: + """Test if randu function creates an array with the correct shape.""" + dtype = dtypes.s16 + + result = wrapper.randu(shape, dtype) + + assert wrapper.get_dims(result)[0 : len(shape)] == shape # noqa: E203 + + +@pytest.mark.parametrize( + "shape", + [ + (), + (random.randint(1, 10), 1), + (random.randint(1, 10), random.randint(1, 10)), + (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), + (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), + ], +) +def test_random_uniform_shape(shape: tuple) -> None: + """Test if rand uniform function creates an array with the correct shape.""" + dtype = dtypes.s16 + engine = wrapper.create_random_engine(100, 10) + + result = wrapper.random_uniform(shape, dtype, engine) + + assert wrapper.get_dims(result)[0 : len(shape)] == shape # noqa: E203 + + +@pytest.mark.parametrize( + "shape", + [ + (), + (random.randint(1, 10), 1), + (random.randint(1, 10), random.randint(1, 10)), + (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), + (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), + ], +) +def test_randn_shape(shape: tuple) -> None: + """Test if randn function creates an array with the correct shape.""" + dtype = dtypes.f32 + + result = wrapper.randn(shape, dtype) + + assert wrapper.get_dims(result)[0 : len(shape)] == shape # noqa: E203 + + +@pytest.mark.parametrize( + "shape", + [ + (), + (random.randint(1, 10), 1), + (random.randint(1, 10), random.randint(1, 10)), + (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), + (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), + ], +) +def test_random_normal_shape(shape: tuple) -> None: + """Test if random normal function creates an array with the correct shape.""" + dtype = dtypes.f32 + engine = wrapper.create_random_engine(100, 10) + + result = wrapper.random_normal(shape, dtype, engine) + + assert wrapper.get_dims(result)[0 : len(shape)] == shape # noqa: E203 + + +@pytest.mark.parametrize( + "engine_index", + [100, 200, 300], +) +def test_create_random_engine(engine_index: int) -> None: + engine = wrapper.create_random_engine(engine_index, 10) + + engine_type = wrapper.random_engine_get_type(engine) + + assert engine_type == engine_index + + +@pytest.mark.parametrize( + "invalid_index", + [random.randint(301, 600), random.randint(301, 600), random.randint(301, 600)], +) +def test_invalid_random_engine(invalid_index: int) -> None: + "Test if invalid engine types are properly handled" + with pytest.raises(RuntimeError): + + invalid_engine = wrapper.create_random_engine(invalid_index, 10) + + engine_type = wrapper.random_engine_get_type(invalid_engine) + + assert engine_type == invalid_engine