Skip to content

Commit

Permalink
Revert "Convert d_in/d_out to StridedMemoryView and get ptr from that"
Browse files Browse the repository at this point in the history
This reverts commit 9cc2902.
  • Loading branch information
shwina committed Feb 6, 2025
1 parent 9cc2902 commit 8bae237
Showing 1 changed file with 17 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
from numba import cuda
from numba.cuda.cudadrv import enums

from cuda.core.experimental.utils import StridedMemoryView

from .. import _cccl as cccl
from .._bindings import get_bindings, get_paths
from .._caching import CachableFunction, cache_with_key
Expand Down Expand Up @@ -62,6 +60,7 @@ def __init__(
):
# Referenced from __del__:
self.build_result = None

self.d_in_cccl = cccl.to_cccl_iter(d_in)
self.d_out_cccl = cccl.to_cccl_iter(d_out)
self.h_init_cccl = cccl.to_cccl_value(h_init)
Expand Down Expand Up @@ -101,14 +100,26 @@ def __call__(
h_init: np.ndarray | GpuStruct,
stream=None,
):
d_in = StridedMemoryView(d_in, stream_ptr=1)
d_out = StridedMemoryView(d_out, stream_ptr=1)
if self.d_in_cccl.type.value == cccl.IteratorKind.ITERATOR:
assert num_items is not None
else:
assert self.d_in_cccl.type.value == cccl.IteratorKind.POINTER
if num_items is None:
num_items = d_in.size
else:
assert num_items == d_in.size
_dtype_validation(
self._ctor_d_in_cccl_type_enum_name,
cccl.type_enum_as_name(self.d_in_cccl.value_type.type.value),
)
_dtype_validation(self._ctor_d_out_dtype, protocols.get_dtype(d_out))
_dtype_validation(self._ctor_init_dtype, h_init.dtype)
if self.d_in_cccl.type.value == 0:
self.d_in_cccl.state = d_in.ptr
self.d_in_cccl.state = protocols.get_data_pointer(d_in)
else:
self.d_in_cccl.state = d_in.state
if self.d_in_cccl.type.value == 0:
self.d_out_cccl.state = d_out.ptr
self.d_out_cccl.state = protocols.get_data_pointer(d_out)
else:
self.d_out_cccl.state = d_out.state
self.h_init_cccl.state = h_init.__array_interface__["data"][0]
Expand Down

0 comments on commit 8bae237

Please sign in to comment.