diff --git a/src/ragged/_spec_elementwise_functions.py b/src/ragged/_spec_elementwise_functions.py index 5f80470..b40967b 100644 --- a/src/ragged/_spec_elementwise_functions.py +++ b/src/ragged/_spec_elementwise_functions.py @@ -691,8 +691,7 @@ def isfinite(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.isfinite.html """ - assert x, "TODO" - assert False, "TODO 78" + return _box(type(x), np.isfinite(*_unbox(x))) def isinf(x: array, /) -> array: @@ -710,8 +709,7 @@ def isinf(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.isinf.html """ - assert x, "TODO" - assert False, "TODO 79" + return _box(type(x), np.isinf(*_unbox(x))) def isnan(x: array, /) -> array: @@ -729,8 +727,7 @@ def isnan(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.isnan.html """ - assert x, "TODO" - assert False, "TODO 80" + return _box(type(x), np.isnan(*_unbox(x))) def less(x1: array, x2: array, /) -> array: @@ -750,9 +747,7 @@ def less(x1: array, x2: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.less.html """ - assert x1, "TODO" - assert x2, "TODO" - assert False, "TODO 81" + return _box(type(x1), np.less(*_unbox(x1, x2))) def less_equal(x1: array, x2: array, /) -> array: @@ -772,9 +767,7 @@ def less_equal(x1: array, x2: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.less_equal.html """ - assert x1, "TODO" - assert x2, "TODO" - assert False, "TODO 82" + return _box(type(x1), np.less_equal(*_unbox(x1, x2))) def log(x: array, /) -> array: @@ -793,8 +786,7 @@ def log(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.log.html """ - assert x, "TODO" - assert False, "TODO 83" + return _box(type(x), np.log(*_unbox(x))) def log1p(x: array, /) -> array: @@ -817,8 +809,7 @@ def log1p(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.log1p.html """ - assert x, "TODO" - assert False, "TODO 84" + return _box(type(x), np.log1p(*_unbox(x))) def log2(x: array, /) -> array: @@ -837,8 +828,7 @@ def log2(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.log2.html """ - assert x, "TODO" - assert False, "TODO 85" + return _box(type(x), np.log2(*_unbox(x))) def log10(x: array, /) -> array: @@ -857,8 +847,7 @@ def log10(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.log10.html """ - assert x, "TODO" - assert False, "TODO 86" + return _box(type(x), np.log10(*_unbox(x))) def logaddexp(x1: array, x2: array, /) -> array: @@ -878,9 +867,7 @@ def logaddexp(x1: array, x2: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.logaddexp.html """ - assert x1, "TODO" - assert x2, "TODO" - assert False, "TODO 87" + return _box(type(x1), np.logaddexp(*_unbox(x1, x2))) def logical_and(x1: array, x2: array, /) -> array: @@ -899,9 +886,7 @@ def logical_and(x1: array, x2: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.logical_and.html """ - assert x1, "TODO" - assert x2, "TODO" - assert False, "TODO 88" + return _box(type(x1), np.logical_and(*_unbox(x1, x2))) def logical_not(x: array, /) -> array: @@ -918,8 +903,7 @@ def logical_not(x: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.logical_not.html """ - assert x, "TODO" - assert False, "TODO 89" + return _box(type(x), np.logical_not(*_unbox(x))) def logical_or(x1: array, x2: array, /) -> array: @@ -938,9 +922,7 @@ def logical_or(x1: array, x2: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.logical_or.html """ - assert x1, "TODO" - assert x2, "TODO" - assert False, "TODO 90" + return _box(type(x1), np.logical_or(*_unbox(x1, x2))) def logical_xor(x1: array, x2: array, /) -> array: @@ -959,9 +941,7 @@ def logical_xor(x1: array, x2: array, /) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.logical_xor.html """ - assert x1, "TODO" - assert x2, "TODO" - assert False, "TODO 91" + return _box(type(x1), np.logical_xor(*_unbox(x1, x2))) def multiply(x1: array, x2: array, /) -> array: diff --git a/tests/test_spec_elementwise_functions.py b/tests/test_spec_elementwise_functions.py index 87f5a7a..c44cdd1 100644 --- a/tests/test_spec_elementwise_functions.py +++ b/tests/test_spec_elementwise_functions.py @@ -49,6 +49,16 @@ def x_lt1(request): return ragged.array(np.array(0.5)) +@pytest.fixture(params=["regular", "irregular", "scalar"]) +def x_bool(request): + if request.param == "regular": + return ragged.array(np.array([False, True, False])) + elif request.param == "irregular": + return ragged.array(ak.Array([[True, True, False], [], [False, False]])) + else: # request.param == "scalar" + return ragged.array(np.array(True)) + + @pytest.fixture(params=["regular", "irregular", "scalar"]) def x_int(request): if request.param == "regular": @@ -71,6 +81,7 @@ def x_complex(request): y = x y_lt1 = x_lt1 +y_bool = x_bool y_int = x_int y_complex = x_complex @@ -422,3 +433,129 @@ def test_imag(device, x_complex): assert result.shape == x_complex.shape assert xp.imag(first(x_complex)) == first(result) assert xp.imag(first(x_complex)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_isfinite(device, x): + result = ragged.isfinite(x.to_device(device)) + assert type(result) is type(x) + assert result.shape == x.shape + assert xp.isfinite(first(x)) == first(result) + assert xp.isfinite(first(x)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_isinf(device, x): + result = ragged.isinf(x.to_device(device)) + assert type(result) is type(x) + assert result.shape == x.shape + assert xp.isinf(first(x)) == first(result) + assert xp.isinf(first(x)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_isnan(device, x): + result = ragged.isnan(x.to_device(device)) + assert type(result) is type(x) + assert result.shape == x.shape + assert xp.isnan(first(x)) == first(result) + assert xp.isnan(first(x)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_less(device, x, y): + result = ragged.less(x.to_device(device), y.to_device(device)) + assert type(result) is type(x) is type(y) + assert result.shape in (x.shape, y.shape) + assert xp.less(first(x), first(y)) == first(result) + assert xp.less(first(x), first(y)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_less_equal(device, x, y): + result = ragged.less_equal(x.to_device(device), y.to_device(device)) + assert type(result) is type(x) is type(y) + assert result.shape in (x.shape, y.shape) + assert xp.less_equal(first(x), first(y)) == first(result) + assert xp.less_equal(first(x), first(y)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_log(device, x): + result = ragged.log(x.to_device(device)) + assert type(result) is type(x) + assert result.shape == x.shape + assert xp.log(first(x)) == pytest.approx(first(result)) + assert xp.log(first(x)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_log1p(device, x): + result = ragged.log1p(x.to_device(device)) + assert type(result) is type(x) + assert result.shape == x.shape + assert xp.log1p(first(x)) == pytest.approx(first(result)) + assert xp.log1p(first(x)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_log2(device, x): + result = ragged.log2(x.to_device(device)) + assert type(result) is type(x) + assert result.shape == x.shape + assert xp.log2(first(x)) == pytest.approx(first(result)) + assert xp.log2(first(x)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_log10(device, x): + result = ragged.log10(x.to_device(device)) + assert type(result) is type(x) + assert result.shape == x.shape + assert xp.log10(first(x)) == pytest.approx(first(result)) + assert xp.log10(first(x)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_logaddexp(device, x, y): + result = ragged.logaddexp(x.to_device(device), y.to_device(device)) + assert type(result) is type(x) is type(y) + assert result.shape in (x.shape, y.shape) + assert xp.logaddexp(first(x), first(y)) == pytest.approx(first(result)) + assert xp.logaddexp(first(x), first(y)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_logical_and(device, x_bool, y_bool): + result = ragged.logical_and(x_bool.to_device(device), y_bool.to_device(device)) + assert type(result) is type(x_bool) is type(y_bool) + assert result.shape in (x_bool.shape, y_bool.shape) + assert xp.logical_and(first(x_bool), first(y_bool)) == first(result) + assert xp.logical_and(first(x_bool), first(y_bool)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_logical_not(device, x_bool): + result = ragged.logical_not(x_bool.to_device(device)) + assert type(result) is type(x_bool) + assert result.shape == x_bool.shape + assert xp.logical_not(first(x_bool)) == first(result) + assert xp.logical_not(first(x_bool)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_logical_or(device, x_bool, y_bool): + result = ragged.logical_or(x_bool.to_device(device), y_bool.to_device(device)) + assert type(result) is type(x_bool) is type(y_bool) + assert result.shape in (x_bool.shape, y_bool.shape) + assert xp.logical_or(first(x_bool), first(y_bool)) == first(result) + assert xp.logical_or(first(x_bool), first(y_bool)).dtype == result.dtype + + +@pytest.mark.parametrize("device", devices) +def test_logical_xor(device, x_bool, y_bool): + result = ragged.logical_xor(x_bool.to_device(device), y_bool.to_device(device)) + assert type(result) is type(x_bool) is type(y_bool) + assert result.shape in (x_bool.shape, y_bool.shape) + assert xp.logical_xor(first(x_bool), first(y_bool)) == first(result) + assert xp.logical_xor(first(x_bool), first(y_bool)).dtype == result.dtype