Skip to content

Commit

Permalink
Add an __array_interface__ attribute to GpuStruct
Browse files Browse the repository at this point in the history
  • Loading branch information
shwina committed Feb 6, 2025
1 parent e8bc8fc commit 0f404e7
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,7 @@ def __call__(
else:
self.d_out_cccl.state = d_out.state

if self.h_init_cccl.type.type.value == cccl.TypeEnum.STORAGE:
self.h_init_cccl.state = h_init._data.__array_interface__["data"][0] # type: ignore
else:
self.h_init_cccl.state = h_init.__array_interface__["data"][0]
self.h_init_cccl.state = h_init.__array_interface__["data"][0]

stream_handle = protocols.validate_and_get_stream(stream)

Expand Down
4 changes: 4 additions & 0 deletions python/cuda_parallel/cuda/parallel/experimental/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,11 @@ def __post_init__(self):
[tuple(getattr(self, name) for name in anns)], dtype=self.dtype
)

def __array_interface__(self):
return self._data.__array_interface__

setattr(this, "__post_init__", __post_init__)
setattr(this, "__array_interface__", property(__array_interface__))

# Wrap `this` in a dataclass for convenience:
this = dataclass(this)
Expand Down

0 comments on commit 0f404e7

Please sign in to comment.