Skip to content

Commit

Permalink
sign, sin, sinh, square, sqrt, subtract, tan, tanh, trunc; finished a…
Browse files Browse the repository at this point in the history
…ll of the free elementwise functions
  • Loading branch information
jpivarski committed Dec 30, 2023
1 parent 1fea779 commit 00194d4
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 19 deletions.
28 changes: 9 additions & 19 deletions src/ragged/_spec_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1156,8 +1156,7 @@ def sign(x: array, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.sign.html
"""

assert x, "TODO"
assert False, "TODO 100"
return _box(type(x), np.sign(*_unbox(x)))


def sin(x: array, /) -> array:
Expand All @@ -1177,8 +1176,7 @@ def sin(x: array, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.sin.html
"""

assert x, "TODO"
assert False, "TODO 101"
return _box(type(x), np.sin(*_unbox(x)))


def sinh(x: array, /) -> array:
Expand All @@ -1201,8 +1199,7 @@ def sinh(x: array, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.sinh.html
"""

assert x, "TODO"
assert False, "TODO 102"
return _box(type(x), np.sinh(*_unbox(x)))


def square(x: array, /) -> array:
Expand All @@ -1223,8 +1220,7 @@ def square(x: array, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.square.html
"""

assert x, "TODO"
assert False, "TODO 103"
return _box(type(x), np.square(*_unbox(x)))


def sqrt(x: array, /) -> array:
Expand All @@ -1243,8 +1239,7 @@ def sqrt(x: array, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.sqrt.html
"""

assert x, "TODO"
assert False, "TODO 104"
return _box(type(x), np.sqrt(*_unbox(x)))


def subtract(x1: array, x2: array, /) -> array:
Expand All @@ -1266,9 +1261,7 @@ def subtract(x1: array, x2: array, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.subtract.html
"""

assert x1, "TODO"
assert x2, "TODO"
assert False, "TODO 105"
return _box(type(x1), np.subtract(*_unbox(x1, x2)))


def tan(x: array, /) -> array:
Expand All @@ -1288,8 +1281,7 @@ def tan(x: array, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.tan.html
"""

assert x, "TODO"
assert False, "TODO 106"
return _box(type(x), np.tan(*_unbox(x)))


def tanh(x: array, /) -> array:
Expand All @@ -1315,8 +1307,7 @@ def tanh(x: array, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.tanh.html
"""

assert x, "TODO"
assert False, "TODO 107"
return _box(type(x), np.tanh(*_unbox(x)))


def trunc(x: array, /) -> array:
Expand All @@ -1334,5 +1325,4 @@ def trunc(x: array, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.trunc.html
"""

assert x, "TODO"
assert False, "TODO 108"
return _box(type(x), np.trunc(*_unbox(x)))
81 changes: 81 additions & 0 deletions tests/test_spec_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,3 +640,84 @@ def test_round_complex(device, x_complex):
assert result.shape == x_complex.shape
assert xp.round(first(x_complex)) == first(result)
assert xp.round(first(x_complex)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_sign(device, x):
result = ragged.sign(x.to_device(device))
assert type(result) is type(x)
assert result.shape == x.shape
assert xp.sign(first(x)) == first(result)
assert xp.sign(first(x)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_sin(device, x):
result = ragged.sin(x.to_device(device))
assert type(result) is type(x)
assert result.shape == x.shape
assert xp.sin(first(x)) == pytest.approx(first(result))
assert xp.sin(first(x)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_sinh(device, x):
result = ragged.sinh(x.to_device(device))
assert type(result) is type(x)
assert result.shape == x.shape
assert xp.sinh(first(x)) == pytest.approx(first(result))
assert xp.sinh(first(x)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_square(device, x):
result = ragged.square(x.to_device(device))
assert type(result) is type(x)
assert result.shape == x.shape
assert xp.square(first(x)) == first(result)
assert xp.square(first(x)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_sqrt(device, x):
result = ragged.sqrt(x.to_device(device))
assert type(result) is type(x)
assert result.shape == x.shape
assert xp.sqrt(first(x)) == pytest.approx(first(result))
assert xp.sqrt(first(x)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_subtract(device, x, y):
result = ragged.subtract(x.to_device(device), y.to_device(device))
assert type(result) is type(x) is type(y)
assert result.shape in (x.shape, y.shape)
assert xp.subtract(first(x), first(y)) == first(result)
assert xp.subtract(first(x), first(y)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_tan(device, x):
result = ragged.tan(x.to_device(device))
assert type(result) is type(x)
assert result.shape == x.shape
assert xp.tan(first(x)) == pytest.approx(first(result))
assert xp.tan(first(x)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_tanh(device, x):
result = ragged.tanh(x.to_device(device))
assert type(result) is type(x)
assert result.shape == x.shape
assert xp.tanh(first(x)) == pytest.approx(first(result))
assert xp.tanh(first(x)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_trunc(device, x):
result = ragged.trunc(x.to_device(device))
assert type(result) is type(x)
assert result.shape == x.shape
assert xp.trunc(first(x)) == first(result)
assert xp.trunc(first(x)).dtype == result.dtype

0 comments on commit 00194d4

Please sign in to comment.