Skip to content

Commit

Permalink
feat: add all direct from Awkward functions
Browse files Browse the repository at this point in the history
  • Loading branch information
jpivarski committed Jan 10, 2024
1 parent 96bb313 commit 6bd6c17
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 67 deletions.
34 changes: 34 additions & 0 deletions src/ragged/_spec_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,14 @@ def __init__(
elif isinstance(self._impl, np.ndarray) and device == "cuda":
cp = _import.cupy()
self._impl = cp.array(self._impl)
self._device = device
else:
if isinstance(self._impl, ak.Array):
self._device = ak.backend(self._impl)
elif isinstance(self._impl, np.ndarray):
self._device = "cpu"
else:
self._device = "cuda"

if copy is not None:
raise NotImplementedError("TODO 1") # noqa: EM101
Expand Down Expand Up @@ -1101,6 +1109,32 @@ def __irshift__(self, other: int | array, /) -> array:
__rrshift__ = __rshift__


def _is_shared(
x1: array | ak.Array | SupportsDLPack, x2: array | ak.Array | SupportsDLPack
) -> bool:
x1_buf = x1._impl if isinstance(x1, array) else x1 # pylint: disable=W0212
x2_buf = x2._impl if isinstance(x2, array) else x2 # pylint: disable=W0212

if isinstance(x1_buf, ak.Array):
x1_buf = x1_buf.layout
while not isinstance(x1_buf, NumpyArray):
x1_buf = x1_buf.content
x1_buf = x1_buf.data

if isinstance(x2_buf, ak.Array):
x2_buf = x2_buf.layout
while not isinstance(x2_buf, NumpyArray):
x2_buf = x2_buf.content
x2_buf = x2_buf.data

while x1_buf.base is not None: # type: ignore[union-attr]
x1_buf = x1_buf.base # type: ignore[union-attr]
while x2_buf.base is not None: # type: ignore[union-attr]
x2_buf = x2_buf.base # type: ignore[union-attr]

return x1_buf is x2_buf


def _unbox(*inputs: array) -> tuple[ak.Array | SupportsDLPack, ...]:
if len(inputs) > 1 and any(type(inputs[0]) is not type(x) for x in inputs):
types = "\n".join(f"{type(x).__module__}.{type(x).__name__}" for x in inputs)
Expand Down
15 changes: 5 additions & 10 deletions src/ragged/_spec_data_type_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import numpy as np

from ._spec_array_object import array
from ._spec_array_object import _box, _unbox, array
from ._typing import Dtype

_type = type
Expand All @@ -23,11 +23,7 @@ def astype(x: array, dtype: Dtype, /, *, copy: bool = True) -> array:
Args:
x: Array to cast.
dtype: Desired data type.
copy: Specifies whether to copy an array when the specified `dtype`
matches the data type of the input array `x`. If `True`, a newly
allocated array is always returned. If `False` and the specified
`dtype` matches the data type of the input array, the input array
is returned; otherwise, a newly allocated array is returned.
copy: Ignored because `ragged.array` data buffers are immutable.
Returns:
An array having the specified data type. The returned array has the
Expand All @@ -36,10 +32,9 @@ def astype(x: array, dtype: Dtype, /, *, copy: bool = True) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.astype.html
"""

x # noqa: B018, pylint: disable=W0104
dtype # noqa: B018, pylint: disable=W0104
copy # noqa: B018, pylint: disable=W0104
raise NotImplementedError("TODO 50") # noqa: EM101
copy # noqa: B018, argument is ignored, pylint: disable=W0104

return _box(type(x), *_unbox(x), dtype=dtype)


def can_cast(from_: Dtype | array, to: Dtype, /) -> bool:
Expand Down
59 changes: 59 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import reprlib

import awkward as ak
import numpy as np
import pytest

import ragged
Expand All @@ -20,3 +22,60 @@ def repr1(self, x, level):

reprlib.Repr.repr1_original = reprlib.Repr.repr1 # type: ignore[attr-defined]
reprlib.Repr.repr1 = repr1 # type: ignore[method-assign]


@pytest.fixture(params=["regular", "irregular", "scalar"])
def x(request):
if request.param == "regular":
return ragged.array(np.array([1.0, 2.0, 3.0]))
elif request.param == "irregular":
return ragged.array(ak.Array([[1.1, 2.2, 3.3], [], [4.4, 5.5]]))
else: # request.param == "scalar"
return ragged.array(np.array(10.0))


@pytest.fixture(params=["regular", "irregular", "scalar"])
def x_lt1(request):
if request.param == "regular":
return ragged.array(np.array([0.1, 0.2, 0.3]))
elif request.param == "irregular":
return ragged.array(ak.Array([[0.1, 0.2, 0.3], [], [0.4, 0.5]]))
else: # request.param == "scalar"
return ragged.array(np.array(0.5))


@pytest.fixture(params=["regular", "irregular", "scalar"])
def x_bool(request):
if request.param == "regular":
return ragged.array(np.array([False, True, False]))
elif request.param == "irregular":
return ragged.array(ak.Array([[True, True, False], [], [False, False]]))
else: # request.param == "scalar"
return ragged.array(np.array(True))


@pytest.fixture(params=["regular", "irregular", "scalar"])
def x_int(request):
if request.param == "regular":
return ragged.array(np.array([0, 1, 2], dtype=np.int64))
elif request.param == "irregular":
return ragged.array(ak.Array([[1, 2, 3], [], [4, 5]]))
else: # request.param == "scalar"
return ragged.array(np.array(10, dtype=np.int64))


@pytest.fixture(params=["regular", "irregular", "scalar"])
def x_complex(request):
if request.param == "regular":
return ragged.array(np.array([1 + 0.1j, 2 + 0.2j, 3 + 0.3j]))
elif request.param == "irregular":
return ragged.array(ak.Array([[1 + 0j, 2 + 0j, 3 + 0j], [], [4 + 0j, 5 + 0j]]))
else: # request.param == "scalar"
return ragged.array(np.array(10 + 1j))


y = x
y_lt1 = x_lt1
y_bool = x_bool
y_int = x_int
y_complex = x_complex
27 changes: 27 additions & 0 deletions tests/test_spec_data_type_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,27 @@

from __future__ import annotations

from typing import Any

import awkward as ak
import numpy as np
import pytest

import ragged

devices = ["cpu"]
try:
import cupy as cp

devices.append("cuda")
except ModuleNotFoundError:
cp = None


def first(x: ragged.array) -> Any:
out = ak.flatten(x._impl, axis=None)[0] if x.shape != () else x._impl
return np.asarray(out.item(), dtype=x.dtype)


def test_existence():
assert ragged.astype is not None
Expand All @@ -20,6 +37,16 @@ def test_existence():
assert ragged.result_type is not None


@pytest.mark.parametrize("device", devices)
@pytest.mark.parametrize("dt", ["float64", np.float64, np.dtype(np.float64)])
def test_astype(device, x_int, dt):
x = x_int.to_device(device)
y = ragged.astype(x, dt)
assert first(y) == first(x)
assert y.dtype == np.dtype(np.float64)
assert y.device == x.device


def test_can_cast():
assert ragged.can_cast(np.float32, np.complex128)
assert not ragged.can_cast(np.complex128, np.float32)
Expand Down
57 changes: 0 additions & 57 deletions tests/test_spec_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,63 +29,6 @@
cp = None


@pytest.fixture(params=["regular", "irregular", "scalar"])
def x(request):
if request.param == "regular":
return ragged.array(np.array([1.0, 2.0, 3.0]))
elif request.param == "irregular":
return ragged.array(ak.Array([[1.1, 2.2, 3.3], [], [4.4, 5.5]]))
else: # request.param == "scalar"
return ragged.array(np.array(10.0))


@pytest.fixture(params=["regular", "irregular", "scalar"])
def x_lt1(request):
if request.param == "regular":
return ragged.array(np.array([0.1, 0.2, 0.3]))
elif request.param == "irregular":
return ragged.array(ak.Array([[0.1, 0.2, 0.3], [], [0.4, 0.5]]))
else: # request.param == "scalar"
return ragged.array(np.array(0.5))


@pytest.fixture(params=["regular", "irregular", "scalar"])
def x_bool(request):
if request.param == "regular":
return ragged.array(np.array([False, True, False]))
elif request.param == "irregular":
return ragged.array(ak.Array([[True, True, False], [], [False, False]]))
else: # request.param == "scalar"
return ragged.array(np.array(True))


@pytest.fixture(params=["regular", "irregular", "scalar"])
def x_int(request):
if request.param == "regular":
return ragged.array(np.array([0, 1, 2], dtype=np.int64))
elif request.param == "irregular":
return ragged.array(ak.Array([[1, 2, 3], [], [4, 5]]))
else: # request.param == "scalar"
return ragged.array(np.array(10, dtype=np.int64))


@pytest.fixture(params=["regular", "irregular", "scalar"])
def x_complex(request):
if request.param == "regular":
return ragged.array(np.array([1 + 0.1j, 2 + 0.2j, 3 + 0.3j]))
elif request.param == "irregular":
return ragged.array(ak.Array([[1 + 0j, 2 + 0j, 3 + 0j], [], [4 + 0j, 5 + 0j]]))
else: # request.param == "scalar"
return ragged.array(np.array(10 + 1j))


y = x
y_lt1 = x_lt1
y_bool = x_bool
y_int = x_int
y_complex = x_complex


def first(x: ragged.array) -> Any:
out = ak.flatten(x._impl, axis=None)[0] if x.shape != () else x._impl
return xp.asarray(out.item(), dtype=x.dtype)
Expand Down

0 comments on commit 6bd6c17

Please sign in to comment.