Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rebased master branch #47

Merged
merged 5 commits into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 52 additions & 27 deletions tests/test_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,23 @@

import pytest

import arrayfire_wrapper.dtypes as dtypes
import arrayfire_wrapper.lib as wrapper
from arrayfire_wrapper.dtypes import (
Dtype,
c32,
c64,
c_api_value_to_dtype,
f16,
f32,
f64,
s16,
s32,
s64,
u8,
u16,
u32,
u64,
)

invalid_shape = (
random.randint(1, 10),
Expand All @@ -14,6 +29,9 @@
)


all_types = [s16, s32, s64, u8, u16, u32, u64, f16, f32, f64, c32, c64]


@pytest.mark.parametrize(
"shape",
[
Expand All @@ -27,7 +45,7 @@
def test_constant_shape(shape: tuple) -> None:
"""Test if constant creates an array with the correct shape."""
number = 5.0
dtype = dtypes.s16
dtype = s16

result = wrapper.constant(number, shape, dtype)

Expand All @@ -46,9 +64,8 @@ def test_constant_shape(shape: tuple) -> None:
)
def test_constant_complex_shape(shape: tuple) -> None:
"""Test if constant_complex creates an array with the correct shape."""
dtype = dtypes.c32
dtype = c32

dtype = dtypes.c32
rand_array = wrapper.randu((1, 1), dtype)
number = wrapper.get_scalar(rand_array, dtype)

Expand All @@ -71,7 +88,7 @@ def test_constant_complex_shape(shape: tuple) -> None:
)
def test_constant_long_shape(shape: tuple) -> None:
"""Test if constant_long creates an array with the correct shape."""
dtype = dtypes.s64
dtype = s64
rand_array = wrapper.randu((1, 1), dtype)
number = wrapper.get_scalar(rand_array, dtype)

Expand All @@ -93,7 +110,7 @@ def test_constant_long_shape(shape: tuple) -> None:
)
def test_constant_ulong_shape(shape: tuple) -> None:
"""Test if constant_ulong creates an array with the correct shape."""
dtype = dtypes.u64
dtype = u64
rand_array = wrapper.randu((1, 1), dtype)
number = wrapper.get_scalar(rand_array, dtype)

Expand All @@ -109,15 +126,15 @@ def test_constant_shape_invalid() -> None:
"""Test if constant handles a shape with greater than 4 dimensions"""
with pytest.raises(TypeError):
number = 5.0
dtype = dtypes.s16
dtype = s16

wrapper.constant(number, invalid_shape, dtype)


def test_constant_complex_shape_invalid() -> None:
"""Test if constant_complex handles a shape with greater than 4 dimensions"""
with pytest.raises(TypeError):
dtype = dtypes.c32
dtype = c32
rand_array = wrapper.randu((1, 1), dtype)
number = wrapper.get_scalar(rand_array, dtype)

Expand All @@ -128,7 +145,7 @@ def test_constant_complex_shape_invalid() -> None:
def test_constant_long_shape_invalid() -> None:
"""Test if constant_long handles a shape with greater than 4 dimensions"""
with pytest.raises(TypeError):
dtype = dtypes.s64
dtype = s64
rand_array = wrapper.randu((1, 1), dtype)
number = wrapper.get_scalar(rand_array, dtype)

Expand All @@ -139,7 +156,7 @@ def test_constant_long_shape_invalid() -> None:
def test_constant_ulong_shape_invalid() -> None:
"""Test if constant_ulong handles a shape with greater than 4 dimensions"""
with pytest.raises(TypeError):
dtype = dtypes.u64
dtype = u64
rand_array = wrapper.randu((1, 1), dtype)
number = wrapper.get_scalar(rand_array, dtype)

Expand All @@ -148,50 +165,47 @@ def test_constant_ulong_shape_invalid() -> None:


@pytest.mark.parametrize(
"dtype_index",
[i for i in range(13)],
"dtype",
all_types,
)
def test_constant_dtype(dtype_index: int) -> None:
def test_constant_dtype(dtype: Dtype) -> None:
"""Test if constant creates an array with the correct dtype."""
if dtype_index in [1, 3] or (dtype_index == 2 and not wrapper.get_dbl_support()):
if is_cmplx_type(dtype) or not is_system_supported(dtype):
pytest.skip()

dtype = dtypes.c_api_value_to_dtype(dtype_index)

rand_array = wrapper.randu((1, 1), dtype)
value = wrapper.get_scalar(rand_array, dtype)
shape = (2, 2)
if isinstance(value, (int, float)):
result = wrapper.constant(value, shape, dtype)
assert dtypes.c_api_value_to_dtype(wrapper.get_type(result)) == dtype
assert c_api_value_to_dtype(wrapper.get_type(result)) == dtype
else:
pytest.skip()


@pytest.mark.parametrize(
"dtype_index",
[i for i in range(13)],
"dtype",
all_types,
)
def test_constant_complex_dtype(dtype_index: int) -> None:
def test_constant_complex_dtype(dtype: Dtype) -> None:
"""Test if constant_complex creates an array with the correct dtype."""
if dtype_index not in [1, 3] or (dtype_index == 3 and not wrapper.get_dbl_support()):
if not is_cmplx_type(dtype) or not is_system_supported(dtype):
pytest.skip()

dtype = dtypes.c_api_value_to_dtype(dtype_index)
rand_array = wrapper.randu((1, 1), dtype)
value = wrapper.get_scalar(rand_array, dtype)
shape = (2, 2)

if isinstance(value, (int, float, complex)):
result = wrapper.constant_complex(value, shape, dtype)
assert dtypes.c_api_value_to_dtype(wrapper.get_type(result)) == dtype
assert c_api_value_to_dtype(wrapper.get_type(result)) == dtype
else:
pytest.skip()


def test_constant_long_dtype() -> None:
"""Test if constant_long creates an array with the correct dtype."""
dtype = dtypes.s64
dtype = s64

rand_array = wrapper.randu((1, 1), dtype)
value = wrapper.get_scalar(rand_array, dtype)
Expand All @@ -200,14 +214,14 @@ def test_constant_long_dtype() -> None:
if isinstance(value, (int, float)):
result = wrapper.constant_long(value, shape, dtype)

assert dtypes.c_api_value_to_dtype(wrapper.get_type(result)) == dtype
assert c_api_value_to_dtype(wrapper.get_type(result)) == dtype
else:
pytest.skip()


def test_constant_ulong_dtype() -> None:
"""Test if constant_ulong creates an array with the correct dtype."""
dtype = dtypes.u64
dtype = u64

rand_array = wrapper.randu((1, 1), dtype)
value = wrapper.get_scalar(rand_array, dtype)
Expand All @@ -216,6 +230,17 @@ def test_constant_ulong_dtype() -> None:
if isinstance(value, (int, float)):
result = wrapper.constant_ulong(value, shape, dtype)

assert dtypes.c_api_value_to_dtype(wrapper.get_type(result)) == dtype
assert c_api_value_to_dtype(wrapper.get_type(result)) == dtype
else:
pytest.skip()


def is_cmplx_type(dtype: Dtype) -> bool:
return dtype == c32 or dtype == c64


def is_system_supported(dtype: Dtype) -> bool:
if dtype in [f64, c64] and not wrapper.get_dbl_support():
return False

return True
61 changes: 61 additions & 0 deletions tests/test_range.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import random

import pytest

import arrayfire_wrapper.dtypes as dtypes
import arrayfire_wrapper.lib as wrapper


@pytest.mark.parametrize(
"shape",
[
(),
(random.randint(1, 10), 1),
(random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
],
)
def test_range_shape(shape: tuple) -> None:
"""Test if the range function output an AFArray with the correct shape"""
dim = 2
dtype = dtypes.s16

result = wrapper.range(shape, dim, dtype)

assert wrapper.get_dims(result)[0 : len(shape)] == shape # noqa: E203


def test_range_invalid_shape() -> None:
"""Test if range function correctly handles an invalid shape"""
with pytest.raises(TypeError):
shape = (
random.randint(1, 10),
random.randint(1, 10),
random.randint(1, 10),
random.randint(1, 10),
random.randint(1, 10),
)
dim = 2
dtype = dtypes.s16

wrapper.range(shape, dim, dtype)


@pytest.mark.parametrize(
"shape",
[
(),
(random.randint(1, 10), 1),
(random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
],
)
def test_range_invalid_dim(shape: tuple) -> None:
"""Test if the range function can properly handle and invalid dimension given"""
with pytest.raises(RuntimeError):
dim = random.randint(4, 10)
dtype = dtypes.s16

wrapper.range(shape, dim, dtype)
Loading