diff --git a/tests/test_constants.py b/tests/test_constants.py index 855d94a..e513a0b 100644 --- a/tests/test_constants.py +++ b/tests/test_constants.py @@ -2,8 +2,23 @@ import pytest -import arrayfire_wrapper.dtypes as dtypes import arrayfire_wrapper.lib as wrapper +from arrayfire_wrapper.dtypes import ( + Dtype, + c32, + c64, + c_api_value_to_dtype, + f16, + f32, + f64, + s16, + s32, + s64, + u8, + u16, + u32, + u64, +) invalid_shape = ( random.randint(1, 10), @@ -14,6 +29,9 @@ ) +all_types = [s16, s32, s64, u8, u16, u32, u64, f16, f32, f64, c32, c64] + + @pytest.mark.parametrize( "shape", [ @@ -27,7 +45,7 @@ def test_constant_shape(shape: tuple) -> None: """Test if constant creates an array with the correct shape.""" number = 5.0 - dtype = dtypes.s16 + dtype = s16 result = wrapper.constant(number, shape, dtype) @@ -46,9 +64,8 @@ def test_constant_shape(shape: tuple) -> None: ) def test_constant_complex_shape(shape: tuple) -> None: """Test if constant_complex creates an array with the correct shape.""" - dtype = dtypes.c32 + dtype = c32 - dtype = dtypes.c32 rand_array = wrapper.randu((1, 1), dtype) number = wrapper.get_scalar(rand_array, dtype) @@ -71,7 +88,7 @@ def test_constant_complex_shape(shape: tuple) -> None: ) def test_constant_long_shape(shape: tuple) -> None: """Test if constant_long creates an array with the correct shape.""" - dtype = dtypes.s64 + dtype = s64 rand_array = wrapper.randu((1, 1), dtype) number = wrapper.get_scalar(rand_array, dtype) @@ -93,7 +110,7 @@ def test_constant_long_shape(shape: tuple) -> None: ) def test_constant_ulong_shape(shape: tuple) -> None: """Test if constant_ulong creates an array with the correct shape.""" - dtype = dtypes.u64 + dtype = u64 rand_array = wrapper.randu((1, 1), dtype) number = wrapper.get_scalar(rand_array, dtype) @@ -109,7 +126,7 @@ def test_constant_shape_invalid() -> None: """Test if constant handles a shape with greater than 4 dimensions""" with pytest.raises(TypeError): number = 5.0 - dtype = dtypes.s16 + dtype = s16 wrapper.constant(number, invalid_shape, dtype) @@ -117,7 +134,7 @@ def test_constant_shape_invalid() -> None: def test_constant_complex_shape_invalid() -> None: """Test if constant_complex handles a shape with greater than 4 dimensions""" with pytest.raises(TypeError): - dtype = dtypes.c32 + dtype = c32 rand_array = wrapper.randu((1, 1), dtype) number = wrapper.get_scalar(rand_array, dtype) @@ -128,7 +145,7 @@ def test_constant_complex_shape_invalid() -> None: def test_constant_long_shape_invalid() -> None: """Test if constant_long handles a shape with greater than 4 dimensions""" with pytest.raises(TypeError): - dtype = dtypes.s64 + dtype = s64 rand_array = wrapper.randu((1, 1), dtype) number = wrapper.get_scalar(rand_array, dtype) @@ -139,7 +156,7 @@ def test_constant_long_shape_invalid() -> None: def test_constant_ulong_shape_invalid() -> None: """Test if constant_ulong handles a shape with greater than 4 dimensions""" with pytest.raises(TypeError): - dtype = dtypes.u64 + dtype = u64 rand_array = wrapper.randu((1, 1), dtype) number = wrapper.get_scalar(rand_array, dtype) @@ -148,50 +165,47 @@ def test_constant_ulong_shape_invalid() -> None: @pytest.mark.parametrize( - "dtype_index", - [i for i in range(13)], + "dtype", + all_types, ) -def test_constant_dtype(dtype_index: int) -> None: +def test_constant_dtype(dtype: Dtype) -> None: """Test if constant creates an array with the correct dtype.""" - if dtype_index in [1, 3] or (dtype_index == 2 and not wrapper.get_dbl_support()): + if is_cmplx_type(dtype) or not is_system_supported(dtype): pytest.skip() - dtype = dtypes.c_api_value_to_dtype(dtype_index) - rand_array = wrapper.randu((1, 1), dtype) value = wrapper.get_scalar(rand_array, dtype) shape = (2, 2) if isinstance(value, (int, float)): result = wrapper.constant(value, shape, dtype) - assert dtypes.c_api_value_to_dtype(wrapper.get_type(result)) == dtype + assert c_api_value_to_dtype(wrapper.get_type(result)) == dtype else: pytest.skip() @pytest.mark.parametrize( - "dtype_index", - [i for i in range(13)], + "dtype", + all_types, ) -def test_constant_complex_dtype(dtype_index: int) -> None: +def test_constant_complex_dtype(dtype: Dtype) -> None: """Test if constant_complex creates an array with the correct dtype.""" - if dtype_index not in [1, 3] or (dtype_index == 3 and not wrapper.get_dbl_support()): + if not is_cmplx_type(dtype) or not is_system_supported(dtype): pytest.skip() - dtype = dtypes.c_api_value_to_dtype(dtype_index) rand_array = wrapper.randu((1, 1), dtype) value = wrapper.get_scalar(rand_array, dtype) shape = (2, 2) if isinstance(value, (int, float, complex)): result = wrapper.constant_complex(value, shape, dtype) - assert dtypes.c_api_value_to_dtype(wrapper.get_type(result)) == dtype + assert c_api_value_to_dtype(wrapper.get_type(result)) == dtype else: pytest.skip() def test_constant_long_dtype() -> None: """Test if constant_long creates an array with the correct dtype.""" - dtype = dtypes.s64 + dtype = s64 rand_array = wrapper.randu((1, 1), dtype) value = wrapper.get_scalar(rand_array, dtype) @@ -200,14 +214,14 @@ def test_constant_long_dtype() -> None: if isinstance(value, (int, float)): result = wrapper.constant_long(value, shape, dtype) - assert dtypes.c_api_value_to_dtype(wrapper.get_type(result)) == dtype + assert c_api_value_to_dtype(wrapper.get_type(result)) == dtype else: pytest.skip() def test_constant_ulong_dtype() -> None: """Test if constant_ulong creates an array with the correct dtype.""" - dtype = dtypes.u64 + dtype = u64 rand_array = wrapper.randu((1, 1), dtype) value = wrapper.get_scalar(rand_array, dtype) @@ -216,6 +230,17 @@ def test_constant_ulong_dtype() -> None: if isinstance(value, (int, float)): result = wrapper.constant_ulong(value, shape, dtype) - assert dtypes.c_api_value_to_dtype(wrapper.get_type(result)) == dtype + assert c_api_value_to_dtype(wrapper.get_type(result)) == dtype else: pytest.skip() + + +def is_cmplx_type(dtype: Dtype) -> bool: + return dtype == c32 or dtype == c64 + + +def is_system_supported(dtype: Dtype) -> bool: + if dtype in [f64, c64] and not wrapper.get_dbl_support(): + return False + + return True diff --git a/tests/test_range.py b/tests/test_range.py new file mode 100644 index 0000000..1571698 --- /dev/null +++ b/tests/test_range.py @@ -0,0 +1,61 @@ +import random + +import pytest + +import arrayfire_wrapper.dtypes as dtypes +import arrayfire_wrapper.lib as wrapper + + +@pytest.mark.parametrize( + "shape", + [ + (), + (random.randint(1, 10), 1), + (random.randint(1, 10), random.randint(1, 10)), + (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), + (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), + ], +) +def test_range_shape(shape: tuple) -> None: + """Test if the range function output an AFArray with the correct shape""" + dim = 2 + dtype = dtypes.s16 + + result = wrapper.range(shape, dim, dtype) + + assert wrapper.get_dims(result)[0 : len(shape)] == shape # noqa: E203 + + +def test_range_invalid_shape() -> None: + """Test if range function correctly handles an invalid shape""" + with pytest.raises(TypeError): + shape = ( + random.randint(1, 10), + random.randint(1, 10), + random.randint(1, 10), + random.randint(1, 10), + random.randint(1, 10), + ) + dim = 2 + dtype = dtypes.s16 + + wrapper.range(shape, dim, dtype) + + +@pytest.mark.parametrize( + "shape", + [ + (), + (random.randint(1, 10), 1), + (random.randint(1, 10), random.randint(1, 10)), + (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), + (random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)), + ], +) +def test_range_invalid_dim(shape: tuple) -> None: + """Test if the range function can properly handle and invalid dimension given""" + with pytest.raises(RuntimeError): + dim = random.randint(4, 10) + dtype = dtypes.s16 + + wrapper.range(shape, dim, dtype)