Skip to content

Commit

Permalink
Bugfix in fields.py for GPU run without cupy (#4750)
Browse files Browse the repository at this point in the history
* Bugfix in `fields.py` for GPU run without `cupy`

* apply suggestion from code review
  • Loading branch information
roelof-groenewald committed Mar 8, 2024
1 parent 510f016 commit 887a167
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion Python/pywarpx/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import cupy as cp
except ImportError:
cp = None
from .LoadThirdParty import load_cupy

try:
from mpi4py import MPI as mpi
Expand Down Expand Up @@ -516,6 +517,12 @@ def __setitem__(self, index, value):
if not isinstance(ic , slice) or len(global_shape) < 4: global_shape[3:3] = [1]
value3d.shape = global_shape

if libwarpx.libwarpx_so.Config.have_gpu:
# check if cupy is available for use
xp, cupy_status = load_cupy()
if cupy_status is not None:
libwarpx.amr.Print(cupy_status)

starts = [ixstart, iystart, izstart]
stops = [ixstop, iystop, izstop]
for mfi in self.mf:
Expand All @@ -526,7 +533,7 @@ def __setitem__(self, index, value):
slice_value = value3d[global_slices]
if libwarpx.libwarpx_so.Config.have_gpu:
# Copy data from host to device
slice_value = cp.asarray(slice_value)
slice_value = xp.asarray(slice_value)
mf_arr[block_slices] = slice_value
else:
mf_arr[block_slices] = value
Expand Down

0 comments on commit 887a167

Please sign in to comment.