diff --git a/src/ragged/_spec_array_object.py b/src/ragged/_spec_array_object.py index c430872..e212b29 100644 --- a/src/ragged/_spec_array_object.py +++ b/src/ragged/_spec_array_object.py @@ -179,7 +179,7 @@ def __init__( msg = f"unsupported __dlpack_device__ type: {device_type}" raise TypeError(msg) - elif isinstance(obj, (bool, numbers.Real)): + elif isinstance(obj, (bool, numbers.Complex)): self._impl = np.array(obj) self._shape, self._dtype = (), self._impl.dtype diff --git a/src/ragged/_spec_elementwise_functions.py b/src/ragged/_spec_elementwise_functions.py index b40967b..ba07e34 100644 --- a/src/ragged/_spec_elementwise_functions.py +++ b/src/ragged/_spec_elementwise_functions.py @@ -960,9 +960,7 @@ def multiply(x1: array, x2: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.multiply.html """ - assert x1, "TODO" - assert x2, "TODO" - assert False, "TODO 92" + return _box(type(x1), np.multiply(*_unbox(x1, x2))) def negative(x: array, /) -> array: @@ -980,8 +978,7 @@ def negative(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.negative.html """ - assert x, "TODO" - assert False, "TODO 93" + return _box(type(x), np.negative(*_unbox(x))) def not_equal(x1: array, x2: array, /) -> array: @@ -1001,9 +998,7 @@ def not_equal(x1: array, x2: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.not_equal.html """ - assert x1, "TODO" - assert x2, "TODO" - assert False, "TODO 94" + return _box(type(x1), np.not_equal(*_unbox(x1, x2))) def positive(x: array, /) -> array: @@ -1021,8 +1016,7 @@ def positive(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.positive.html """ - assert x, "TODO" - assert False, "TODO 95" + return _box(type(x), np.positive(*_unbox(x))) def pow(x1: array, x2: array, /) -> array: # pylint: disable=W0622 @@ -1045,9 +1039,7 @@ def pow(x1: array, x2: array, /) -> array: # pylint: disable=W0622 https://data-apis.org/array-api/latest/API_specification/generated/array_api.pow.html """ - assert x1, "TODO" - assert x2, "TODO" - assert False, "TODO 96" + return _box(type(x1), np.power(*_unbox(x1, x2))) def real(x: array, /) -> array: @@ -1067,8 +1059,14 @@ def real(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.real.html """ - assert x, "TODO" - assert False, "TODO 97" + (a,) = _unbox(x) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + return _box( + type(x), + (a + np.conjugate(a)) / 2, + dtype=np.dtype(f"f{x.dtype.itemsize // 2}"), + ) def remainder(x1: array, x2: array, /) -> array: @@ -1090,9 +1088,7 @@ def remainder(x1: array, x2: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.remainder.html """ - assert x1, "TODO" - assert x2, "TODO" - assert False, "TODO 98" + return _box(type(x1), np.remainder(*_unbox(x1, x2))) def round(x: array, /) -> array: # pylint: disable=W0622 @@ -1118,8 +1114,24 @@ def round(x: array, /) -> array: # pylint: disable=W0622 https://data-apis.org/array-api/latest/API_specification/generated/array_api.round.html """ - assert x, "TODO" - assert False, "TODO 99" + (a,) = _unbox(x) + if x.dtype in (np.complex64, np.complex128): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + a_conj = np.conjugate(a) + dt = np.dtype(f"f{x.dtype.itemsize // 2}") + re = _box(type(x), (a + a_conj) / 2, dtype=dt) + im = _box(type(x), (a - a_conj) / 2j, dtype=dt) + return add(round(re), multiply(round(im), array(1j, device=x.device))) + + else: + frac, whole = np.modf(a) + abs_frac = np.absolute(frac) + return _box( + type(x), + whole + + ((abs_frac == 0.5) * (whole % 2 != 0) + (abs_frac > 0.5)) * np.sign(frac), + ) def sign(x: array, /) -> array: diff --git a/tests/test_spec_elementwise_functions.py b/tests/test_spec_elementwise_functions.py index c44cdd1..f803e87 100644 --- a/tests/test_spec_elementwise_functions.py +++ b/tests/test_spec_elementwise_functions.py @@ -559,3 +559,84 @@ def test_logical_xor(device, x_bool, y_bool): assert result.shape in (x_bool.shape, y_bool.shape) assert xp.logical_xor(first(x_bool), first(y_bool)) == first(result) assert xp.logical_xor(first(x_bool), first(y_bool)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_multiply(device, x, y): + result = ragged.multiply(x.to_device(device), y.to_device(device)) + assert type(result) is type(x) is type(y) + assert result.shape in (x.shape, y.shape) + assert xp.multiply(first(x), first(y)) == first(result) + assert xp.multiply(first(x), first(y)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_negative(device, x): + result = ragged.negative(x.to_device(device)) + assert type(result) is type(x) + assert result.shape == x.shape + assert xp.negative(first(x)) == pytest.approx(first(result)) + assert xp.negative(first(x)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_not_equal(device, x, y): + result = ragged.not_equal(x.to_device(device), y.to_device(device)) + assert type(result) is type(x) is type(y) + assert result.shape in (x.shape, y.shape) + assert xp.not_equal(first(x), first(y)) == first(result) + assert xp.not_equal(first(x), first(y)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_positive(device, x): + result = ragged.positive(x.to_device(device)) + assert type(result) is type(x) + assert result.shape == x.shape + assert xp.positive(first(x)) == pytest.approx(first(result)) + assert xp.positive(first(x)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_pow(device, x, y): + result = ragged.pow(x.to_device(device), y.to_device(device)) + assert type(result) is type(x) is type(y) + assert result.shape in (x.shape, y.shape) + assert xp.pow(first(x), first(y)) == first(result) + assert xp.pow(first(x), first(y)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_real(device, x_complex): + result = ragged.real(x_complex.to_device(device)) + assert type(result) is type(x_complex) + assert result.shape == x_complex.shape + assert xp.real(first(x_complex)) == first(result) + assert xp.real(first(x_complex)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_remainder(device, x, y): + result = ragged.remainder(x.to_device(device), y.to_device(device)) + assert type(result) is type(x) is type(y) + assert result.shape in (x.shape, y.shape) + assert xp.remainder(first(x), first(y)) == first(result) + assert xp.remainder(first(x), first(y)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_round(device, x): + result = ragged.round(x.to_device(device)) + assert type(result) is type(x) + assert result.shape == x.shape + assert xp.round(first(x)) == first(result) + assert xp.round(first(x)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_round_complex(device, x_complex): + result = ragged.round(x_complex.to_device(device)) + assert type(result) is type(x_complex) + assert result.shape == x_complex.shape + assert xp.round(first(x_complex)) == first(result) + assert xp.round(first(x_complex)).dtype == result.dtype