diff --git a/Python/pywarpx/fields.py b/Python/pywarpx/fields.py index dc0de9d9491..0f680595ef4 100644 --- a/Python/pywarpx/fields.py +++ b/Python/pywarpx/fields.py @@ -47,6 +47,7 @@ import cupy as cp except ImportError: cp = None +from .LoadThirdParty import load_cupy try: from mpi4py import MPI as mpi @@ -516,6 +517,12 @@ def __setitem__(self, index, value): if not isinstance(ic , slice) or len(global_shape) < 4: global_shape[3:3] = [1] value3d.shape = global_shape + if libwarpx.libwarpx_so.Config.have_gpu: + # check if cupy is available for use + xp, cupy_status = load_cupy() + if cupy_status is not None: + libwarpx.amr.Print(cupy_status) + starts = [ixstart, iystart, izstart] stops = [ixstop, iystop, izstop] for mfi in self.mf: @@ -526,7 +533,7 @@ def __setitem__(self, index, value): slice_value = value3d[global_slices] if libwarpx.libwarpx_so.Config.have_gpu: # Copy data from host to device - slice_value = cp.asarray(slice_value) + slice_value = xp.asarray(slice_value) mf_arr[block_slices] = slice_value else: mf_arr[block_slices] = value