Skip to content

Commit

Permalink
exp, expm1, floor, floor_divide
Browse files Browse the repository at this point in the history
  • Loading branch information
jpivarski committed Dec 29, 2023
1 parent d6b5490 commit 9e40df0
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 9 deletions.
13 changes: 4 additions & 9 deletions src/ragged/_spec_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,8 +543,7 @@ def exp(x: array, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.exp.html
"""

assert x, "TODO"
assert False, "TODO 71"
return _box(type(x), np.exp(*_unbox(x)))


def expm1(x: array, /) -> array:
Expand All @@ -566,8 +565,7 @@ def expm1(x: array, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.expm1.html
"""

assert x, "TODO"
assert False, "TODO 72"
return _box(type(x), np.expm1(*_unbox(x)))


def floor(x: array, /) -> array:
Expand All @@ -586,8 +584,7 @@ def floor(x: array, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.floor.html
"""

assert x, "TODO"
assert False, "TODO 73"
return _box(type(x), np.floor(*_unbox(x)), dtype=x.dtype)


def floor_divide(x1: array, x2: array, /) -> array:
Expand All @@ -608,9 +605,7 @@ def floor_divide(x1: array, x2: array, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.floor_divide.html
"""

assert x1, "TODO"
assert x2, "TODO"
assert False, "TODO 74"
return _box(type(x1), np.floor_divide(*_unbox(x1, x2)))


def greater(x1: array, x2: array, /) -> array:
Expand Down
64 changes: 64 additions & 0 deletions tests/test_spec_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,15 @@ def test_ceil(device, x):
assert xp.ceil(first(x)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_ceil_int(device, x_int):
result = ragged.ceil(x_int.to_device(device))
assert type(result) is type(x_int)
assert result.shape == x_int.shape
assert xp.ceil(first(x_int)) == first(result)
assert xp.ceil(first(x_int)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_conj(device, x_complex):
result = ragged.conj(x_complex.to_device(device))
Expand Down Expand Up @@ -331,3 +340,58 @@ def test_equal(device, x, y):
assert result.shape in (x.shape, y.shape)
assert xp.equal(first(x), first(y)) == first(result)
assert xp.equal(first(x), first(y)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_exp(device, x):
result = ragged.exp(x.to_device(device))
assert type(result) is type(x)
assert result.shape == x.shape
assert xp.exp(first(x)) == pytest.approx(first(result))
assert xp.exp(first(x)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_expm1(device, x):
result = ragged.expm1(x.to_device(device))
assert type(result) is type(x)
assert result.shape == x.shape
assert xp.expm1(first(x)) == pytest.approx(first(result))
assert xp.expm1(first(x)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_floor(device, x):
result = ragged.floor(x.to_device(device))
assert type(result) is type(x)
assert result.shape == x.shape
assert xp.floor(first(x)) == first(result)
assert xp.floor(first(x)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_floor_int(device, x_int):
result = ragged.floor(x_int.to_device(device))
assert type(result) is type(x_int)
assert result.shape == x_int.shape
assert xp.floor(first(x_int)) == first(result)
assert xp.floor(first(x_int)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_floor_divide(device, x, y):
result = ragged.floor_divide(x.to_device(device), y.to_device(device))
assert type(result) is type(x) is type(y)
assert result.shape in (x.shape, y.shape)
assert xp.floor_divide(first(x), first(y)) == first(result)
assert xp.floor_divide(first(x), first(y)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_floor_divide_int(device, x_int, y_int):
with np.errstate(divide="ignore"):
result = ragged.floor_divide(x_int.to_device(device), y_int.to_device(device))
assert type(result) is type(x_int) is type(y_int)
assert result.shape in (x_int.shape, y_int.shape)
assert xp.floor_divide(first(x_int), first(y_int)) == first(result)
assert xp.floor_divide(first(x_int), first(y_int)).dtype == result.dtype

0 comments on commit 9e40df0

Please sign in to comment.