diff --git a/src/ragged/_spec_elementwise_functions.py b/src/ragged/_spec_elementwise_functions.py index ba07e34..3357c6c 100644 --- a/src/ragged/_spec_elementwise_functions.py +++ b/src/ragged/_spec_elementwise_functions.py @@ -1156,8 +1156,7 @@ def sign(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.sign.html """ - assert x, "TODO" - assert False, "TODO 100" + return _box(type(x), np.sign(*_unbox(x))) def sin(x: array, /) -> array: @@ -1177,8 +1176,7 @@ def sin(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.sin.html """ - assert x, "TODO" - assert False, "TODO 101" + return _box(type(x), np.sin(*_unbox(x))) def sinh(x: array, /) -> array: @@ -1201,8 +1199,7 @@ def sinh(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.sinh.html """ - assert x, "TODO" - assert False, "TODO 102" + return _box(type(x), np.sinh(*_unbox(x))) def square(x: array, /) -> array: @@ -1223,8 +1220,7 @@ def square(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.square.html """ - assert x, "TODO" - assert False, "TODO 103" + return _box(type(x), np.square(*_unbox(x))) def sqrt(x: array, /) -> array: @@ -1243,8 +1239,7 @@ def sqrt(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.sqrt.html """ - assert x, "TODO" - assert False, "TODO 104" + return _box(type(x), np.sqrt(*_unbox(x))) def subtract(x1: array, x2: array, /) -> array: @@ -1266,9 +1261,7 @@ def subtract(x1: array, x2: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.subtract.html """ - assert x1, "TODO" - assert x2, "TODO" - assert False, "TODO 105" + return _box(type(x1), np.subtract(*_unbox(x1, x2))) def tan(x: array, /) -> array: @@ -1288,8 +1281,7 @@ def tan(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.tan.html """ - assert x, "TODO" - assert False, "TODO 106" + return _box(type(x), np.tan(*_unbox(x))) def tanh(x: array, /) -> array: @@ -1315,8 +1307,7 @@ def tanh(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.tanh.html """ - assert x, "TODO" - assert False, "TODO 107" + return _box(type(x), np.tanh(*_unbox(x))) def trunc(x: array, /) -> array: @@ -1334,5 +1325,4 @@ def trunc(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.trunc.html """ - assert x, "TODO" - assert False, "TODO 108" + return _box(type(x), np.trunc(*_unbox(x))) diff --git a/tests/test_spec_elementwise_functions.py b/tests/test_spec_elementwise_functions.py index f803e87..1ec6161 100644 --- a/tests/test_spec_elementwise_functions.py +++ b/tests/test_spec_elementwise_functions.py @@ -640,3 +640,84 @@ def test_round_complex(device, x_complex): assert result.shape == x_complex.shape assert xp.round(first(x_complex)) == first(result) assert xp.round(first(x_complex)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_sign(device, x): + result = ragged.sign(x.to_device(device)) + assert type(result) is type(x) + assert result.shape == x.shape + assert xp.sign(first(x)) == first(result) + assert xp.sign(first(x)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_sin(device, x): + result = ragged.sin(x.to_device(device)) + assert type(result) is type(x) + assert result.shape == x.shape + assert xp.sin(first(x)) == pytest.approx(first(result)) + assert xp.sin(first(x)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_sinh(device, x): + result = ragged.sinh(x.to_device(device)) + assert type(result) is type(x) + assert result.shape == x.shape + assert xp.sinh(first(x)) == pytest.approx(first(result)) + assert xp.sinh(first(x)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_square(device, x): + result = ragged.square(x.to_device(device)) + assert type(result) is type(x) + assert result.shape == x.shape + assert xp.square(first(x)) == first(result) + assert xp.square(first(x)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_sqrt(device, x): + result = ragged.sqrt(x.to_device(device)) + assert type(result) is type(x) + assert result.shape == x.shape + assert xp.sqrt(first(x)) == pytest.approx(first(result)) + assert xp.sqrt(first(x)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_subtract(device, x, y): + result = ragged.subtract(x.to_device(device), y.to_device(device)) + assert type(result) is type(x) is type(y) + assert result.shape in (x.shape, y.shape) + assert xp.subtract(first(x), first(y)) == first(result) + assert xp.subtract(first(x), first(y)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_tan(device, x): + result = ragged.tan(x.to_device(device)) + assert type(result) is type(x) + assert result.shape == x.shape + assert xp.tan(first(x)) == pytest.approx(first(result)) + assert xp.tan(first(x)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_tanh(device, x): + result = ragged.tanh(x.to_device(device)) + assert type(result) is type(x) + assert result.shape == x.shape + assert xp.tanh(first(x)) == pytest.approx(first(result)) + assert xp.tanh(first(x)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_trunc(device, x): + result = ragged.trunc(x.to_device(device)) + assert type(result) is type(x) + assert result.shape == x.shape + assert xp.trunc(first(x)) == first(result) + assert xp.trunc(first(x)).dtype == result.dtype