Skip to content

Commit

Permalink
greater, greater_equal, imag
Browse files Browse the repository at this point in the history
  • Loading branch information
jpivarski committed Dec 29, 2023
1 parent 9e40df0 commit eae61dd
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 12 deletions.
22 changes: 13 additions & 9 deletions src/ragged/_spec_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from __future__ import annotations

import warnings

import numpy as np

from ._spec_array_object import _box, _unbox, array
Expand Down Expand Up @@ -441,7 +443,7 @@ def conj(x: array, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.conj.html
"""

return _box(type(x), np.conj(*_unbox(x)))
return _box(type(x), np.conjugate(*_unbox(x)))


def cos(x: array, /) -> array:
Expand Down Expand Up @@ -625,9 +627,7 @@ def greater(x1: array, x2: array, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.greater.html
"""

assert x1, "TODO"
assert x2, "TODO"
assert False, "TODO 75"
return _box(type(x1), np.greater(*_unbox(x1, x2)))


def greater_equal(x1: array, x2: array, /) -> array:
Expand All @@ -647,9 +647,7 @@ def greater_equal(x1: array, x2: array, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.greater_equal.html
"""

assert x1, "TODO"
assert x2, "TODO"
assert False, "TODO 76"
return _box(type(x1), np.greater_equal(*_unbox(x1, x2)))


def imag(x: array, /) -> array:
Expand All @@ -669,8 +667,14 @@ def imag(x: array, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.imag.html
"""

assert x, "TODO"
assert False, "TODO 77"
(a,) = _unbox(x)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
return _box(
type(x),
(a - np.conjugate(a)) / 2j,
dtype=np.dtype(f"f{x.dtype.itemsize // 2}"),
)


def isfinite(x: array, /) -> array:
Expand Down
33 changes: 30 additions & 3 deletions tests/test_spec_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,11 @@ def x_int(request):
@pytest.fixture(params=["regular", "irregular", "scalar"])
def x_complex(request):
if request.param == "regular":
return ragged.array(np.array([1+0.1j, 2+0.2j, 3+0.3j]))
return ragged.array(np.array([1 + 0.1j, 2 + 0.2j, 3 + 0.3j]))
elif request.param == "irregular":
return ragged.array(ak.Array([[1+0j, 2+0j, 3+0j], [], [4+0j, 5+0j]]))
return ragged.array(ak.Array([[1 + 0j, 2 + 0j, 3 + 0j], [], [4 + 0j, 5 + 0j]]))
else: # request.param == "scalar"
return ragged.array(np.array(10+1j))
return ragged.array(np.array(10 + 1j))


y = x
Expand Down Expand Up @@ -395,3 +395,30 @@ def test_floor_divide_int(device, x_int, y_int):
assert result.shape in (x_int.shape, y_int.shape)
assert xp.floor_divide(first(x_int), first(y_int)) == first(result)
assert xp.floor_divide(first(x_int), first(y_int)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_greater(device, x, y):
result = ragged.greater(x.to_device(device), y.to_device(device))
assert type(result) is type(x) is type(y)
assert result.shape in (x.shape, y.shape)
assert xp.greater(first(x), first(y)) == first(result)
assert xp.greater(first(x), first(y)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_greater_equal(device, x, y):
result = ragged.greater_equal(x.to_device(device), y.to_device(device))
assert type(result) is type(x) is type(y)
assert result.shape in (x.shape, y.shape)
assert xp.greater_equal(first(x), first(y)) == first(result)
assert xp.greater_equal(first(x), first(y)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_imag(device, x_complex):
result = ragged.imag(x_complex.to_device(device))
assert type(result) is type(x_complex)
assert result.shape == x_complex.shape
assert xp.imag(first(x_complex)) == first(result)
assert xp.imag(first(x_complex)).dtype == result.dtype

0 comments on commit eae61dd

Please sign in to comment.