Skip to content

Commit

Permalink
Use 'numpy.array_api' to test these functions.
Browse files Browse the repository at this point in the history
  • Loading branch information
jpivarski committed Dec 29, 2023
1 parent dc9e7de commit 6abc202
Showing 1 changed file with 42 additions and 36 deletions.
78 changes: 42 additions & 36 deletions tests/test_spec_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,16 @@

from __future__ import annotations

import warnings
from typing import Any

import awkward as ak
import numpy as np

with warnings.catch_warnings():
warnings.simplefilter("ignore")
import numpy.array_api as xp

import pytest

import ragged
Expand All @@ -26,11 +32,11 @@
@pytest.fixture(params=["regular", "irregular", "scalar"])
def x(request):
if request.param == "regular":
return ragged.array(np.array([1, 2, 3], dtype=np.int64))
return ragged.array(np.array([1.0, 2.0, 3.0]))
elif request.param == "irregular":
return ragged.array(ak.Array([[1.1, 1.2, 1.3], [], [1.4, 1.5]]))
return ragged.array(ak.Array([[1.1, 2.2, 3.3], [], [4.4, 5.5]]))
else: # request.param == "scalar"
return ragged.array(np.array(10, dtype=np.int64))
return ragged.array(np.array(10.0))


@pytest.fixture(params=["regular", "irregular", "scalar"])
Expand Down Expand Up @@ -60,7 +66,7 @@ def x_int(request):

def first(x: ragged.array) -> Any:
out = ak.flatten(x._impl, axis=None)[0] if x.shape != () else x._impl
return out.item()
return xp.asarray(out.item(), dtype=x.dtype)


def test_existence():
Expand Down Expand Up @@ -130,116 +136,116 @@ def test_abs(device, x):
result = ragged.abs(x.to_device(device))
assert type(result) is type(x)
assert result.shape == x.shape
assert result.dtype == x.dtype
assert np.absolute(first(x)) == first(result)
assert xp.abs(first(x)) == first(result)
assert xp.abs(first(x)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_acos(device, x_lt1):
result = ragged.acos(x_lt1.to_device(device))
assert type(result) is type(x_lt1)
assert result.shape == x_lt1.shape
assert result.dtype == np.dtype(np.float64)
assert np.arccos(first(x_lt1)) == pytest.approx(first(result))
assert xp.acos(first(x_lt1)) == pytest.approx(first(result))
assert xp.acos(first(x_lt1)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_acosh(device, x):
result = ragged.acosh(x.to_device(device))
assert type(result) is type(x)
assert result.shape == x.shape
assert result.dtype == np.dtype(np.float64)
assert np.arccosh(first(x)) == pytest.approx(first(result))
assert xp.acosh(first(x)) == pytest.approx(first(result))
assert xp.acosh(first(x)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_add(device, x, y):
result = ragged.add(x.to_device(device), y.to_device(device))
assert type(result) is type(x) is type(y)
assert result.shape in (x.shape, y.shape)
assert result.dtype in (x.dtype, y.dtype)
assert np.add(first(x), first(y)) == first(result)
assert xp.add(first(x), first(y)) == first(result)
assert xp.add(first(x), first(y)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_asin(device, x_lt1):
result = ragged.asin(x_lt1.to_device(device))
assert type(result) is type(x_lt1)
assert result.shape == x_lt1.shape
assert result.dtype == np.dtype(np.float64)
assert np.arcsin(first(x_lt1)) == pytest.approx(first(result))
assert xp.asin(first(x_lt1)) == pytest.approx(first(result))
assert xp.asin(first(x_lt1)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_asinh(device, x):
result = ragged.asinh(x.to_device(device))
assert type(result) is type(x)
assert result.shape == x.shape
assert result.dtype == np.dtype(np.float64)
assert np.arcsinh(first(x)) == pytest.approx(first(result))
assert xp.asinh(first(x)) == pytest.approx(first(result))
assert xp.asinh(first(x)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_atan(device, x):
result = ragged.atan(x.to_device(device))
assert type(result) is type(x)
assert result.shape == x.shape
assert result.dtype == np.dtype(np.float64)
assert np.arctan(first(x)) == pytest.approx(first(result))
assert xp.atan(first(x)) == pytest.approx(first(result))
assert xp.atan(first(x)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_atan2(device, x, y):
result = ragged.atan2(y.to_device(device), x.to_device(device))
assert type(result) is type(x) is type(y)
assert result.shape in (x.shape, y.shape)
assert result.dtype == np.dtype(np.float64)
assert np.arctan2(first(y), first(x)) == pytest.approx(first(result))
assert xp.atan2(first(y), first(x)) == pytest.approx(first(result))
assert xp.atan2(first(y), first(x)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_atanh(device, x_lt1):
result = ragged.atanh(x_lt1.to_device(device))
assert type(result) is type(x_lt1)
assert result.shape == x_lt1.shape
assert result.dtype == np.dtype(np.float64)
assert np.arctanh(first(x_lt1)) == pytest.approx(first(result))
assert xp.atanh(first(x_lt1)) == pytest.approx(first(result))
assert xp.atanh(first(x_lt1)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_bitwise_and(device, x_int, y_int):
result = ragged.bitwise_and(x_int.to_device(device), y_int.to_device(device))
assert type(result) is type(x_int) is type(y_int)
assert result.shape in (x_int.shape, y_int.shape)
assert result.dtype == np.dtype(np.int64)
assert np.bitwise_and(first(x_int), first(y_int)) == first(result)
assert xp.bitwise_and(first(x_int), first(y_int)) == first(result)
assert xp.bitwise_and(first(x_int), first(y_int)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_bitwise_invert(device, x_int):
result = ragged.bitwise_invert(x_int.to_device(device))
assert type(result) is type(x_int)
assert result.shape == x_int.shape
assert result.dtype == np.dtype(np.int64)
assert np.invert(first(x_int)) == first(result)
assert xp.bitwise_invert(first(x_int)) == first(result)
assert xp.bitwise_invert(first(x_int)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_bitwise_left_shift(device, x_int, y_int):
result = ragged.bitwise_left_shift(x_int.to_device(device), y_int.to_device(device))
assert type(result) is type(x_int) is type(y_int)
assert result.shape in (x_int.shape, y_int.shape)
assert result.dtype == np.dtype(np.int64)
assert np.left_shift(first(x_int), first(y_int)) == first(result)
assert xp.bitwise_left_shift(first(x_int), first(y_int)) == first(result)
assert xp.bitwise_left_shift(first(x_int), first(y_int)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_bitwise_or(device, x_int, y_int):
result = ragged.bitwise_or(x_int.to_device(device), y_int.to_device(device))
assert type(result) is type(x_int) is type(y_int)
assert result.shape in (x_int.shape, y_int.shape)
assert result.dtype == np.dtype(np.int64)
assert np.bitwise_or(first(x_int), first(y_int)) == first(result)
assert xp.bitwise_or(first(x_int), first(y_int)) == first(result)
assert xp.bitwise_or(first(x_int), first(y_int)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
Expand All @@ -249,23 +255,23 @@ def test_bitwise_right_shift(device, x_int, y_int):
)
assert type(result) is type(x_int) is type(y_int)
assert result.shape in (x_int.shape, y_int.shape)
assert result.dtype == np.dtype(np.int64)
assert np.right_shift(first(x_int), first(y_int)) == first(result)
assert xp.bitwise_right_shift(first(x_int), first(y_int)) == first(result)
assert xp.bitwise_right_shift(first(x_int), first(y_int)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_bitwise_xor(device, x_int, y_int):
result = ragged.bitwise_xor(x_int.to_device(device), y_int.to_device(device))
assert type(result) is type(x_int) is type(y_int)
assert result.shape in (x_int.shape, y_int.shape)
assert result.dtype == np.dtype(np.int64)
assert np.bitwise_xor(first(x_int), first(y_int)) == first(result)
assert xp.bitwise_xor(first(x_int), first(y_int)) == first(result)
assert xp.bitwise_xor(first(x_int), first(y_int)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_ceil(device, x):
result = ragged.ceil(x.to_device(device))
assert type(result) is type(x)
assert result.shape == x.shape
assert result.dtype == x.dtype
assert np.ceil(first(x)) == first(result)
assert xp.ceil(first(x)) == first(result)
assert xp.ceil(first(x)).dtype == result.dtype

0 comments on commit 6abc202

Please sign in to comment.