Skip to content

Commit

Permalink
isfinite, isinf, isnan, less, less_equal, log, log1p, log2, log10, lo…
Browse files Browse the repository at this point in the history
…gaddexp, logical_and, logical_not, logical_or, logical_xor
  • Loading branch information
jpivarski committed Dec 29, 2023
1 parent eae61dd commit 89743e0
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 34 deletions.
48 changes: 14 additions & 34 deletions src/ragged/_spec_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,8 +691,7 @@ def isfinite(x: array, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.isfinite.html
"""

assert x, "TODO"
assert False, "TODO 78"
return _box(type(x), np.isfinite(*_unbox(x)))


def isinf(x: array, /) -> array:
Expand All @@ -710,8 +709,7 @@ def isinf(x: array, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.isinf.html
"""

assert x, "TODO"
assert False, "TODO 79"
return _box(type(x), np.isinf(*_unbox(x)))


def isnan(x: array, /) -> array:
Expand All @@ -729,8 +727,7 @@ def isnan(x: array, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.isnan.html
"""

assert x, "TODO"
assert False, "TODO 80"
return _box(type(x), np.isnan(*_unbox(x)))


def less(x1: array, x2: array, /) -> array:
Expand All @@ -750,9 +747,7 @@ def less(x1: array, x2: array, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.less.html
"""

assert x1, "TODO"
assert x2, "TODO"
assert False, "TODO 81"
return _box(type(x1), np.less(*_unbox(x1, x2)))


def less_equal(x1: array, x2: array, /) -> array:
Expand All @@ -772,9 +767,7 @@ def less_equal(x1: array, x2: array, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.less_equal.html
"""

assert x1, "TODO"
assert x2, "TODO"
assert False, "TODO 82"
return _box(type(x1), np.less_equal(*_unbox(x1, x2)))


def log(x: array, /) -> array:
Expand All @@ -793,8 +786,7 @@ def log(x: array, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.log.html
"""

assert x, "TODO"
assert False, "TODO 83"
return _box(type(x), np.log(*_unbox(x)))


def log1p(x: array, /) -> array:
Expand All @@ -817,8 +809,7 @@ def log1p(x: array, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.log1p.html
"""

assert x, "TODO"
assert False, "TODO 84"
return _box(type(x), np.log1p(*_unbox(x)))


def log2(x: array, /) -> array:
Expand All @@ -837,8 +828,7 @@ def log2(x: array, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.log2.html
"""

assert x, "TODO"
assert False, "TODO 85"
return _box(type(x), np.log2(*_unbox(x)))


def log10(x: array, /) -> array:
Expand All @@ -857,8 +847,7 @@ def log10(x: array, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.log10.html
"""

assert x, "TODO"
assert False, "TODO 86"
return _box(type(x), np.log10(*_unbox(x)))


def logaddexp(x1: array, x2: array, /) -> array:
Expand All @@ -878,9 +867,7 @@ def logaddexp(x1: array, x2: array, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.logaddexp.html
"""

assert x1, "TODO"
assert x2, "TODO"
assert False, "TODO 87"
return _box(type(x1), np.logaddexp(*_unbox(x1, x2)))


def logical_and(x1: array, x2: array, /) -> array:
Expand All @@ -899,9 +886,7 @@ def logical_and(x1: array, x2: array, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.logical_and.html
"""

assert x1, "TODO"
assert x2, "TODO"
assert False, "TODO 88"
return _box(type(x1), np.logical_and(*_unbox(x1, x2)))


def logical_not(x: array, /) -> array:
Expand All @@ -918,8 +903,7 @@ def logical_not(x: array, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.logical_not.html
"""

assert x, "TODO"
assert False, "TODO 89"
return _box(type(x), np.logical_not(*_unbox(x)))


def logical_or(x1: array, x2: array, /) -> array:
Expand All @@ -938,9 +922,7 @@ def logical_or(x1: array, x2: array, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.logical_or.html
"""

assert x1, "TODO"
assert x2, "TODO"
assert False, "TODO 90"
return _box(type(x1), np.logical_or(*_unbox(x1, x2)))


def logical_xor(x1: array, x2: array, /) -> array:
Expand All @@ -959,9 +941,7 @@ def logical_xor(x1: array, x2: array, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.logical_xor.html
"""

assert x1, "TODO"
assert x2, "TODO"
assert False, "TODO 91"
return _box(type(x1), np.logical_xor(*_unbox(x1, x2)))


def multiply(x1: array, x2: array, /) -> array:
Expand Down
137 changes: 137 additions & 0 deletions tests/test_spec_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,16 @@ def x_lt1(request):
return ragged.array(np.array(0.5))


@pytest.fixture(params=["regular", "irregular", "scalar"])
def x_bool(request):
if request.param == "regular":
return ragged.array(np.array([False, True, False]))
elif request.param == "irregular":
return ragged.array(ak.Array([[True, True, False], [], [False, False]]))
else: # request.param == "scalar"
return ragged.array(np.array(True))


@pytest.fixture(params=["regular", "irregular", "scalar"])
def x_int(request):
if request.param == "regular":
Expand All @@ -71,6 +81,7 @@ def x_complex(request):

y = x
y_lt1 = x_lt1
y_bool = x_bool
y_int = x_int
y_complex = x_complex

Expand Down Expand Up @@ -422,3 +433,129 @@ def test_imag(device, x_complex):
assert result.shape == x_complex.shape
assert xp.imag(first(x_complex)) == first(result)
assert xp.imag(first(x_complex)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_isfinite(device, x):
result = ragged.isfinite(x.to_device(device))
assert type(result) is type(x)
assert result.shape == x.shape
assert xp.isfinite(first(x)) == first(result)
assert xp.isfinite(first(x)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_isinf(device, x):
result = ragged.isinf(x.to_device(device))
assert type(result) is type(x)
assert result.shape == x.shape
assert xp.isinf(first(x)) == first(result)
assert xp.isinf(first(x)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_isnan(device, x):
result = ragged.isnan(x.to_device(device))
assert type(result) is type(x)
assert result.shape == x.shape
assert xp.isnan(first(x)) == first(result)
assert xp.isnan(first(x)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_less(device, x, y):
result = ragged.less(x.to_device(device), y.to_device(device))
assert type(result) is type(x) is type(y)
assert result.shape in (x.shape, y.shape)
assert xp.less(first(x), first(y)) == first(result)
assert xp.less(first(x), first(y)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_less_equal(device, x, y):
result = ragged.less_equal(x.to_device(device), y.to_device(device))
assert type(result) is type(x) is type(y)
assert result.shape in (x.shape, y.shape)
assert xp.less_equal(first(x), first(y)) == first(result)
assert xp.less_equal(first(x), first(y)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_log(device, x):
result = ragged.log(x.to_device(device))
assert type(result) is type(x)
assert result.shape == x.shape
assert xp.log(first(x)) == pytest.approx(first(result))
assert xp.log(first(x)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_log1p(device, x):
result = ragged.log1p(x.to_device(device))
assert type(result) is type(x)
assert result.shape == x.shape
assert xp.log1p(first(x)) == pytest.approx(first(result))
assert xp.log1p(first(x)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_log2(device, x):
result = ragged.log2(x.to_device(device))
assert type(result) is type(x)
assert result.shape == x.shape
assert xp.log2(first(x)) == pytest.approx(first(result))
assert xp.log2(first(x)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_log10(device, x):
result = ragged.log10(x.to_device(device))
assert type(result) is type(x)
assert result.shape == x.shape
assert xp.log10(first(x)) == pytest.approx(first(result))
assert xp.log10(first(x)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_logaddexp(device, x, y):
result = ragged.logaddexp(x.to_device(device), y.to_device(device))
assert type(result) is type(x) is type(y)
assert result.shape in (x.shape, y.shape)
assert xp.logaddexp(first(x), first(y)) == pytest.approx(first(result))
assert xp.logaddexp(first(x), first(y)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_logical_and(device, x_bool, y_bool):
result = ragged.logical_and(x_bool.to_device(device), y_bool.to_device(device))
assert type(result) is type(x_bool) is type(y_bool)
assert result.shape in (x_bool.shape, y_bool.shape)
assert xp.logical_and(first(x_bool), first(y_bool)) == first(result)
assert xp.logical_and(first(x_bool), first(y_bool)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_logical_not(device, x_bool):
result = ragged.logical_not(x_bool.to_device(device))
assert type(result) is type(x_bool)
assert result.shape == x_bool.shape
assert xp.logical_not(first(x_bool)) == first(result)
assert xp.logical_not(first(x_bool)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_logical_or(device, x_bool, y_bool):
result = ragged.logical_or(x_bool.to_device(device), y_bool.to_device(device))
assert type(result) is type(x_bool) is type(y_bool)
assert result.shape in (x_bool.shape, y_bool.shape)
assert xp.logical_or(first(x_bool), first(y_bool)) == first(result)
assert xp.logical_or(first(x_bool), first(y_bool)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_logical_xor(device, x_bool, y_bool):
result = ragged.logical_xor(x_bool.to_device(device), y_bool.to_device(device))
assert type(result) is type(x_bool) is type(y_bool)
assert result.shape in (x_bool.shape, y_bool.shape)
assert xp.logical_xor(first(x_bool), first(y_bool)) == first(result)
assert xp.logical_xor(first(x_bool), first(y_bool)).dtype == result.dtype

0 comments on commit 89743e0

Please sign in to comment.