Skip to content

Commit

Permalink
fix: generalize Index.ptr (#3206)
Browse files Browse the repository at this point in the history
* fix: generalize 'Index.ptr'

* include a test that checks TypeTracer
  • Loading branch information
jpivarski authored Aug 15, 2024
1 parent e15518f commit 0923d22
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 2 deletions.
10 changes: 8 additions & 2 deletions src/awkward/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,16 @@ def metadata(self) -> dict:

@property
def ptr(self):
if self._nplike == Numpy.instance():
if isinstance(self._nplike, Numpy):
return self._data.ctypes.data
elif self._nplike == Cupy.instance():
elif isinstance(self._nplike, Cupy):
return self._data.data.ptr
elif isinstance(self._nplike, TypeTracer):
return 0
else:
raise NotImplementedError(
f"this function hasn't been implemented for the {type(self._nplike).__name__} backend"
)

@property
def length(self) -> ShapeItem:
Expand Down
22 changes: 22 additions & 0 deletions tests/test_3206_generalize_index_ptr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE

from __future__ import annotations

import pytest

import awkward as ak


def test_1():
arr = ak.Array([[1, 3, 4], 5])
tarr = arr.layout.to_typetracer()

with pytest.raises(ak.errors.AxisError, match="exceeds the depth of this array"):
ak.flatten(tarr)


def test_2():
arr = ak.Array([[[1, 3, 4]], [5]])
tarr = arr.layout.to_typetracer()

assert ak.flatten(tarr).type == ak.flatten(arr).type

0 comments on commit 0923d22

Please sign in to comment.