diff --git a/python/cuda_parallel/cuda/parallel/experimental/algorithms/reduce.py b/python/cuda_parallel/cuda/parallel/experimental/algorithms/reduce.py index 8877d9310e1..1178754fa0d 100644 --- a/python/cuda_parallel/cuda/parallel/experimental/algorithms/reduce.py +++ b/python/cuda_parallel/cuda/parallel/experimental/algorithms/reduce.py @@ -110,10 +110,7 @@ def __call__( else: self.d_out_cccl.state = d_out.state - if self.h_init_cccl.type.type.value == cccl.TypeEnum.STORAGE: - self.h_init_cccl.state = h_init._data.__array_interface__["data"][0] # type: ignore - else: - self.h_init_cccl.state = h_init.__array_interface__["data"][0] + self.h_init_cccl.state = h_init.__array_interface__["data"][0] stream_handle = protocols.validate_and_get_stream(stream) diff --git a/python/cuda_parallel/cuda/parallel/experimental/struct.py b/python/cuda_parallel/cuda/parallel/experimental/struct.py index 3ca09d39676..8338ddf84de 100644 --- a/python/cuda_parallel/cuda/parallel/experimental/struct.py +++ b/python/cuda_parallel/cuda/parallel/experimental/struct.py @@ -71,7 +71,11 @@ def __post_init__(self): [tuple(getattr(self, name) for name in anns)], dtype=self.dtype ) + def __array_interface__(self): + return self._data.__array_interface__ + setattr(this, "__post_init__", __post_init__) + setattr(this, "__array_interface__", property(__array_interface__)) # Wrap `this` in a dataclass for convenience: this = dataclass(this)