Skip to content

Commit

Permalink
multiply, negative, not_equal, positive, pow, real, remainder, round
Browse files Browse the repository at this point in the history
  • Loading branch information
jpivarski committed Dec 30, 2023
1 parent 89743e0 commit 1fea779
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 21 deletions.
2 changes: 1 addition & 1 deletion src/ragged/_spec_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def __init__(
msg = f"unsupported __dlpack_device__ type: {device_type}"
raise TypeError(msg)

elif isinstance(obj, (bool, numbers.Real)):
elif isinstance(obj, (bool, numbers.Complex)):
self._impl = np.array(obj)
self._shape, self._dtype = (), self._impl.dtype

Expand Down
52 changes: 32 additions & 20 deletions src/ragged/_spec_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -960,9 +960,7 @@ def multiply(x1: array, x2: array, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.multiply.html
"""

assert x1, "TODO"
assert x2, "TODO"
assert False, "TODO 92"
return _box(type(x1), np.multiply(*_unbox(x1, x2)))


def negative(x: array, /) -> array:
Expand All @@ -980,8 +978,7 @@ def negative(x: array, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.negative.html
"""

assert x, "TODO"
assert False, "TODO 93"
return _box(type(x), np.negative(*_unbox(x)))


def not_equal(x1: array, x2: array, /) -> array:
Expand All @@ -1001,9 +998,7 @@ def not_equal(x1: array, x2: array, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.not_equal.html
"""

assert x1, "TODO"
assert x2, "TODO"
assert False, "TODO 94"
return _box(type(x1), np.not_equal(*_unbox(x1, x2)))


def positive(x: array, /) -> array:
Expand All @@ -1021,8 +1016,7 @@ def positive(x: array, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.positive.html
"""

assert x, "TODO"
assert False, "TODO 95"
return _box(type(x), np.positive(*_unbox(x)))


def pow(x1: array, x2: array, /) -> array: # pylint: disable=W0622
Expand All @@ -1045,9 +1039,7 @@ def pow(x1: array, x2: array, /) -> array: # pylint: disable=W0622
https://data-apis.org/array-api/latest/API_specification/generated/array_api.pow.html
"""

assert x1, "TODO"
assert x2, "TODO"
assert False, "TODO 96"
return _box(type(x1), np.power(*_unbox(x1, x2)))


def real(x: array, /) -> array:
Expand All @@ -1067,8 +1059,14 @@ def real(x: array, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.real.html
"""

assert x, "TODO"
assert False, "TODO 97"
(a,) = _unbox(x)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
return _box(
type(x),
(a + np.conjugate(a)) / 2,
dtype=np.dtype(f"f{x.dtype.itemsize // 2}"),
)


def remainder(x1: array, x2: array, /) -> array:
Expand All @@ -1090,9 +1088,7 @@ def remainder(x1: array, x2: array, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.remainder.html
"""

assert x1, "TODO"
assert x2, "TODO"
assert False, "TODO 98"
return _box(type(x1), np.remainder(*_unbox(x1, x2)))


def round(x: array, /) -> array: # pylint: disable=W0622
Expand All @@ -1118,8 +1114,24 @@ def round(x: array, /) -> array: # pylint: disable=W0622
https://data-apis.org/array-api/latest/API_specification/generated/array_api.round.html
"""

assert x, "TODO"
assert False, "TODO 99"
(a,) = _unbox(x)
if x.dtype in (np.complex64, np.complex128):
with warnings.catch_warnings():
warnings.simplefilter("ignore")
a_conj = np.conjugate(a)
dt = np.dtype(f"f{x.dtype.itemsize // 2}")
re = _box(type(x), (a + a_conj) / 2, dtype=dt)
im = _box(type(x), (a - a_conj) / 2j, dtype=dt)
return add(round(re), multiply(round(im), array(1j, device=x.device)))

else:
frac, whole = np.modf(a)
abs_frac = np.absolute(frac)
return _box(
type(x),
whole
+ ((abs_frac == 0.5) * (whole % 2 != 0) + (abs_frac > 0.5)) * np.sign(frac),
)


def sign(x: array, /) -> array:
Expand Down
81 changes: 81 additions & 0 deletions tests/test_spec_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,3 +559,84 @@ def test_logical_xor(device, x_bool, y_bool):
assert result.shape in (x_bool.shape, y_bool.shape)
assert xp.logical_xor(first(x_bool), first(y_bool)) == first(result)
assert xp.logical_xor(first(x_bool), first(y_bool)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_multiply(device, x, y):
result = ragged.multiply(x.to_device(device), y.to_device(device))
assert type(result) is type(x) is type(y)
assert result.shape in (x.shape, y.shape)
assert xp.multiply(first(x), first(y)) == first(result)
assert xp.multiply(first(x), first(y)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_negative(device, x):
result = ragged.negative(x.to_device(device))
assert type(result) is type(x)
assert result.shape == x.shape
assert xp.negative(first(x)) == pytest.approx(first(result))
assert xp.negative(first(x)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_not_equal(device, x, y):
result = ragged.not_equal(x.to_device(device), y.to_device(device))
assert type(result) is type(x) is type(y)
assert result.shape in (x.shape, y.shape)
assert xp.not_equal(first(x), first(y)) == first(result)
assert xp.not_equal(first(x), first(y)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_positive(device, x):
result = ragged.positive(x.to_device(device))
assert type(result) is type(x)
assert result.shape == x.shape
assert xp.positive(first(x)) == pytest.approx(first(result))
assert xp.positive(first(x)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_pow(device, x, y):
result = ragged.pow(x.to_device(device), y.to_device(device))
assert type(result) is type(x) is type(y)
assert result.shape in (x.shape, y.shape)
assert xp.pow(first(x), first(y)) == first(result)
assert xp.pow(first(x), first(y)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_real(device, x_complex):
result = ragged.real(x_complex.to_device(device))
assert type(result) is type(x_complex)
assert result.shape == x_complex.shape
assert xp.real(first(x_complex)) == first(result)
assert xp.real(first(x_complex)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_remainder(device, x, y):
result = ragged.remainder(x.to_device(device), y.to_device(device))
assert type(result) is type(x) is type(y)
assert result.shape in (x.shape, y.shape)
assert xp.remainder(first(x), first(y)) == first(result)
assert xp.remainder(first(x), first(y)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_round(device, x):
result = ragged.round(x.to_device(device))
assert type(result) is type(x)
assert result.shape == x.shape
assert xp.round(first(x)) == first(result)
assert xp.round(first(x)).dtype == result.dtype


@pytest.mark.parametrize("device", devices)
def test_round_complex(device, x_complex):
result = ragged.round(x_complex.to_device(device))
assert type(result) is type(x_complex)
assert result.shape == x_complex.shape
assert xp.round(first(x_complex)) == first(result)
assert xp.round(first(x_complex)).dtype == result.dtype

0 comments on commit 1fea779

Please sign in to comment.