From 6bd6c177ab27bf480597fe5f4d5fdba35931adcd Mon Sep 17 00:00:00 2001
From: Jim Pivarski
Date: Tue, 9 Jan 2024 18:11:18 -0600
Subject: [PATCH] feat: add all direct from Awkward functions

---
Reviewer notes (commentary only; the patch payload below is unchanged).

Formatting:
  * The copy of this patch under review had its line breaks and indentation
    collapsed. They are restored here from Python syntax and from the hunk
    line counts; every hunk's old/new counts and the diffstat totals were
    re-counted and agree with the layout below.
  * Indentation of context lines and the exact position of blank lines are
    inferred, not observed -- TODO confirm with `git apply --check` against
    the parent commit before relying on this file.

What the patch does, as visible in the hunks:
  * src/ragged/_spec_array_object.py: `array.__init__` now assigns
    `self._device` (the requested `device`, or else a value derived from the
    type of `self._impl`), and a module-level helper `_is_shared` is added.
  * src/ragged/_spec_data_type_functions.py: `astype` stops raising
    NotImplementedError("TODO 50") and returns
    `_box(type(x), *_unbox(x), dtype=dtype)`; `copy` is still accepted but is
    only evaluated as a bare expression and otherwise ignored.
  * tests/conftest.py: receives the `x`, `x_lt1`, `x_bool`, `x_int` and
    `x_complex` fixtures and their `y*` aliases, which are deleted from
    tests/test_spec_elementwise_functions.py with identical text.
  * tests/test_spec_data_type_functions.py: adds a `devices` list, a `first`
    helper and `test_astype`.

Points to confirm (not verifiable from this patch alone):
  * `_is_shared` uses `NumpyArray`, but no hunk adds an import for it --
    presumably it is already imported in _spec_array_object.py; verify.
  * `_is_shared` is not called anywhere in this patch.
  * `_is_shared` reads `.base` on the unwrapped buffers (note the
    `type: ignore[union-attr]` markers); an object without a `.base`
    attribute would raise AttributeError there -- confirm the callers only
    pass NumPy/CuPy-like buffers.
  * In `__init__`, any `_impl` that is neither `ak.Array` nor `np.ndarray` is
    labelled "cuda" -- presumably only CuPy arrays reach that branch; verify.
  * `_box` is imported by `astype` but its definition is not part of this
    patch; whether it copies when `dtype` already matches is not shown here.
  * tests/test_spec_elementwise_functions.py keeps its own `first` (using
    `xp.asarray`), while the new test module defines a second `first` (using
    `np.asarray`).

 src/ragged/_spec_array_object.py         | 34 ++++++++++++++
 src/ragged/_spec_data_type_functions.py  | 15 ++----
 tests/conftest.py                        | 59 ++++++++++++++++++++++++
 tests/test_spec_data_type_functions.py   | 27 +++++++++++
 tests/test_spec_elementwise_functions.py | 57 -----------------------
 5 files changed, 125 insertions(+), 67 deletions(-)

diff --git a/src/ragged/_spec_array_object.py b/src/ragged/_spec_array_object.py
index b300fc5..843bd98 100644
--- a/src/ragged/_spec_array_object.py
+++ b/src/ragged/_spec_array_object.py
@@ -216,6 +216,14 @@ def __init__(
             elif isinstance(self._impl, np.ndarray) and device == "cuda":
                 cp = _import.cupy()
                 self._impl = cp.array(self._impl)
+            self._device = device
+        else:
+            if isinstance(self._impl, ak.Array):
+                self._device = ak.backend(self._impl)
+            elif isinstance(self._impl, np.ndarray):
+                self._device = "cpu"
+            else:
+                self._device = "cuda"
 
         if copy is not None:
             raise NotImplementedError("TODO 1")  # noqa: EM101
@@ -1101,6 +1109,32 @@ def __irshift__(self, other: int | array, /) -> array:
     __rrshift__ = __rshift__
 
 
+def _is_shared(
+    x1: array | ak.Array | SupportsDLPack, x2: array | ak.Array | SupportsDLPack
+) -> bool:
+    x1_buf = x1._impl if isinstance(x1, array) else x1  # pylint: disable=W0212
+    x2_buf = x2._impl if isinstance(x2, array) else x2  # pylint: disable=W0212
+
+    if isinstance(x1_buf, ak.Array):
+        x1_buf = x1_buf.layout
+        while not isinstance(x1_buf, NumpyArray):
+            x1_buf = x1_buf.content
+        x1_buf = x1_buf.data
+
+    if isinstance(x2_buf, ak.Array):
+        x2_buf = x2_buf.layout
+        while not isinstance(x2_buf, NumpyArray):
+            x2_buf = x2_buf.content
+        x2_buf = x2_buf.data
+
+    while x1_buf.base is not None:  # type: ignore[union-attr]
+        x1_buf = x1_buf.base  # type: ignore[union-attr]
+    while x2_buf.base is not None:  # type: ignore[union-attr]
+        x2_buf = x2_buf.base  # type: ignore[union-attr]
+
+    return x1_buf is x2_buf
+
+
 def _unbox(*inputs: array) -> tuple[ak.Array | SupportsDLPack, ...]:
     if len(inputs) > 1 and any(type(inputs[0]) is not type(x) for x in inputs):
         types = "\n".join(f"{type(x).__module__}.{type(x).__name__}" for x in inputs)
diff --git a/src/ragged/_spec_data_type_functions.py b/src/ragged/_spec_data_type_functions.py
index 4106ef8..9c9858b 100644
--- a/src/ragged/_spec_data_type_functions.py
+++ b/src/ragged/_spec_data_type_functions.py
@@ -10,7 +10,7 @@
 
 import numpy as np
 
-from ._spec_array_object import array
+from ._spec_array_object import _box, _unbox, array
 from ._typing import Dtype
 
 _type = type
@@ -23,11 +23,7 @@ def astype(x: array, dtype: Dtype, /, *, copy: bool = True) -> array:
     Args:
         x: Array to cast.
         dtype: Desired data type.
-        copy: Specifies whether to copy an array when the specified `dtype`
-            matches the data type of the input array `x`. If `True`, a newly
-            allocated array is always returned. If `False` and the specified
-            `dtype` matches the data type of the input array, the input array
-            is returned; otherwise, a newly allocated array is returned.
+        copy: Ignored because `ragged.array` data buffers are immutable.
 
     Returns:
         An array having the specified data type. The returned array has the
@@ -36,10 +32,9 @@ def astype(x: array, dtype: Dtype, /, *, copy: bool = True) -> array:
     https://data-apis.org/array-api/latest/API_specification/generated/array_api.astype.html
     """
 
-    x  # noqa: B018, pylint: disable=W0104
-    dtype  # noqa: B018, pylint: disable=W0104
-    copy  # noqa: B018, pylint: disable=W0104
-    raise NotImplementedError("TODO 50")  # noqa: EM101
+    copy  # noqa: B018, argument is ignored, pylint: disable=W0104
+
+    return _box(type(x), *_unbox(x), dtype=dtype)
 
 
 def can_cast(from_: Dtype | array, to: Dtype, /) -> bool:
diff --git a/tests/conftest.py b/tests/conftest.py
index 4dd6d6e..0bebee7 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -4,6 +4,8 @@
 
 import reprlib
 
+import awkward as ak
+import numpy as np
 import pytest
 
 import ragged
@@ -20,3 +22,60 @@ def repr1(self, x, level):
 
 reprlib.Repr.repr1_original = reprlib.Repr.repr1  # type: ignore[attr-defined]
 reprlib.Repr.repr1 = repr1  # type: ignore[method-assign]
+
+
+@pytest.fixture(params=["regular", "irregular", "scalar"])
+def x(request):
+    if request.param == "regular":
+        return ragged.array(np.array([1.0, 2.0, 3.0]))
+    elif request.param == "irregular":
+        return ragged.array(ak.Array([[1.1, 2.2, 3.3], [], [4.4, 5.5]]))
+    else:  # request.param == "scalar"
+        return ragged.array(np.array(10.0))
+
+
+@pytest.fixture(params=["regular", "irregular", "scalar"])
+def x_lt1(request):
+    if request.param == "regular":
+        return ragged.array(np.array([0.1, 0.2, 0.3]))
+    elif request.param == "irregular":
+        return ragged.array(ak.Array([[0.1, 0.2, 0.3], [], [0.4, 0.5]]))
+    else:  # request.param == "scalar"
+        return ragged.array(np.array(0.5))
+
+
+@pytest.fixture(params=["regular", "irregular", "scalar"])
+def x_bool(request):
+    if request.param == "regular":
+        return ragged.array(np.array([False, True, False]))
+    elif request.param == "irregular":
+        return ragged.array(ak.Array([[True, True, False], [], [False, False]]))
+    else:  # request.param == "scalar"
+        return ragged.array(np.array(True))
+
+
+@pytest.fixture(params=["regular", "irregular", "scalar"])
+def x_int(request):
+    if request.param == "regular":
+        return ragged.array(np.array([0, 1, 2], dtype=np.int64))
+    elif request.param == "irregular":
+        return ragged.array(ak.Array([[1, 2, 3], [], [4, 5]]))
+    else:  # request.param == "scalar"
+        return ragged.array(np.array(10, dtype=np.int64))
+
+
+@pytest.fixture(params=["regular", "irregular", "scalar"])
+def x_complex(request):
+    if request.param == "regular":
+        return ragged.array(np.array([1 + 0.1j, 2 + 0.2j, 3 + 0.3j]))
+    elif request.param == "irregular":
+        return ragged.array(ak.Array([[1 + 0j, 2 + 0j, 3 + 0j], [], [4 + 0j, 5 + 0j]]))
+    else:  # request.param == "scalar"
+        return ragged.array(np.array(10 + 1j))
+
+
+y = x
+y_lt1 = x_lt1
+y_bool = x_bool
+y_int = x_int
+y_complex = x_complex
diff --git a/tests/test_spec_data_type_functions.py b/tests/test_spec_data_type_functions.py
index 526181f..53ac14d 100644
--- a/tests/test_spec_data_type_functions.py
+++ b/tests/test_spec_data_type_functions.py
@@ -6,10 +6,27 @@
 
 from __future__ import annotations
 
+from typing import Any
+
+import awkward as ak
 import numpy as np
+import pytest
 
 import ragged
 
+devices = ["cpu"]
+try:
+    import cupy as cp
+
+    devices.append("cuda")
+except ModuleNotFoundError:
+    cp = None
+
+
+def first(x: ragged.array) -> Any:
+    out = ak.flatten(x._impl, axis=None)[0] if x.shape != () else x._impl
+    return np.asarray(out.item(), dtype=x.dtype)
+
 
 def test_existence():
     assert ragged.astype is not None
@@ -20,6 +37,16 @@ def test_existence():
     assert ragged.result_type is not None
 
 
+@pytest.mark.parametrize("device", devices)
+@pytest.mark.parametrize("dt", ["float64", np.float64, np.dtype(np.float64)])
+def test_astype(device, x_int, dt):
+    x = x_int.to_device(device)
+    y = ragged.astype(x, dt)
+    assert first(y) == first(x)
+    assert y.dtype == np.dtype(np.float64)
+    assert y.device == x.device
+
+
 def test_can_cast():
     assert ragged.can_cast(np.float32, np.complex128)
     assert not ragged.can_cast(np.complex128, np.float32)
diff --git a/tests/test_spec_elementwise_functions.py b/tests/test_spec_elementwise_functions.py
index 3838778..fdac4a9 100644
--- a/tests/test_spec_elementwise_functions.py
+++ b/tests/test_spec_elementwise_functions.py
@@ -29,63 +29,6 @@
     cp = None
 
 
-@pytest.fixture(params=["regular", "irregular", "scalar"])
-def x(request):
-    if request.param == "regular":
-        return ragged.array(np.array([1.0, 2.0, 3.0]))
-    elif request.param == "irregular":
-        return ragged.array(ak.Array([[1.1, 2.2, 3.3], [], [4.4, 5.5]]))
-    else:  # request.param == "scalar"
-        return ragged.array(np.array(10.0))
-
-
-@pytest.fixture(params=["regular", "irregular", "scalar"])
-def x_lt1(request):
-    if request.param == "regular":
-        return ragged.array(np.array([0.1, 0.2, 0.3]))
-    elif request.param == "irregular":
-        return ragged.array(ak.Array([[0.1, 0.2, 0.3], [], [0.4, 0.5]]))
-    else:  # request.param == "scalar"
-        return ragged.array(np.array(0.5))
-
-
-@pytest.fixture(params=["regular", "irregular", "scalar"])
-def x_bool(request):
-    if request.param == "regular":
-        return ragged.array(np.array([False, True, False]))
-    elif request.param == "irregular":
-        return ragged.array(ak.Array([[True, True, False], [], [False, False]]))
-    else:  # request.param == "scalar"
-        return ragged.array(np.array(True))
-
-
-@pytest.fixture(params=["regular", "irregular", "scalar"])
-def x_int(request):
-    if request.param == "regular":
-        return ragged.array(np.array([0, 1, 2], dtype=np.int64))
-    elif request.param == "irregular":
-        return ragged.array(ak.Array([[1, 2, 3], [], [4, 5]]))
-    else:  # request.param == "scalar"
-        return ragged.array(np.array(10, dtype=np.int64))
-
-
-@pytest.fixture(params=["regular", "irregular", "scalar"])
-def x_complex(request):
-    if request.param == "regular":
-        return ragged.array(np.array([1 + 0.1j, 2 + 0.2j, 3 + 0.3j]))
-    elif request.param == "irregular":
-        return ragged.array(ak.Array([[1 + 0j, 2 + 0j, 3 + 0j], [], [4 + 0j, 5 + 0j]]))
-    else:  # request.param == "scalar"
-        return ragged.array(np.array(10 + 1j))
-
-
-y = x
-y_lt1 = x_lt1
-y_bool = x_bool
-y_int = x_int
-y_complex = x_complex
-
-
 def first(x: ragged.array) -> Any:
     out = ak.flatten(x._impl, axis=None)[0] if x.shape != () else x._impl
     return xp.asarray(out.item(), dtype=x.dtype)