Skip to content

Commit

Permalink
MultiFab: to_numpy/cupy
Browse files Browse the repository at this point in the history
Add numpy & cupy helpers for MultiFab.
  • Loading branch information
ax3l committed Sep 22, 2023
1 parent 596f0e7 commit d256160
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 16 deletions.
39 changes: 23 additions & 16 deletions src/amrex/MultiFab.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"""


def mf_to_numpy(self, copy=False, order="F"):
def mf_to_numpy(amr, self, copy=False, order="F"):
"""
Provide a Numpy view into a MultiFab.
Expand All @@ -29,13 +29,24 @@ def mf_to_numpy(self, copy=False, order="F"):
Returns
-------
list of np.array
list of numpy.array
A list of numpy n-dimensional arrays, for each local block in the
MultiFab.
"""
mf = self
if copy:
mf = amr.MultiFab(
self.box_array(),
self.dm(),
self.n_comp(),
self.n_grow_vect(),
amr.MFInfo().set_arena(amr.The_Pinned_Arena()),
)
amr.dtoh_memcpy(mf, self)

views = []
for mfi in self:
views.append(self.array(mfi).to_numpy(copy, order))
for mfi in mf:
views.append(mf.array(mfi).to_numpy(copy=False, order=order))

return views

Expand Down Expand Up @@ -80,15 +91,11 @@ def mf_to_cupy(self, copy=False, order="F"):

def register_MultiFab_extension(amr):
"""MultiFab helper methods"""
import inspect
import sys

# register member functions for every MultiFab* type
for _, MultiFab_type in inspect.getmembers(
sys.modules[amr.__name__],
lambda member: inspect.isclass(member)
and member.__module__ == amr.__name__
and member.__name__.startswith("MultiFab"),
):
MultiFab_type.to_numpy = mf_to_numpy
MultiFab_type.to_cupy = mf_to_cupy

# register member functions for the MultiFab type
amr.MultiFab.to_numpy = lambda self, copy=False, order="F": mf_to_numpy(
amr, self, copy, order
)
amr.MultiFab.to_numpy.__doc__ = mf_to_numpy.__doc__

amr.MultiFab.to_cupy = mf_to_cupy
2 changes: 2 additions & 0 deletions src/amrex/space1d/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,12 @@ def Print(*args, **kwargs):

from ..Array4 import register_Array4_extension
from ..ArrayOfStructs import register_AoS_extension
from ..MultiFab import register_MultiFab_extension
from ..PODVector import register_PODVector_extension
from ..StructOfArrays import register_SoA_extension

register_Array4_extension(amrex_1d_pybind)
register_MultiFab_extension(amrex_1d_pybind)
register_PODVector_extension(amrex_1d_pybind)
register_SoA_extension(amrex_1d_pybind)
register_AoS_extension(amrex_1d_pybind)
2 changes: 2 additions & 0 deletions src/amrex/space2d/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,12 @@ def Print(*args, **kwargs):

from ..Array4 import register_Array4_extension
from ..ArrayOfStructs import register_AoS_extension
from ..MultiFab import register_MultiFab_extension
from ..PODVector import register_PODVector_extension
from ..StructOfArrays import register_SoA_extension

register_Array4_extension(amrex_2d_pybind)
register_MultiFab_extension(amrex_2d_pybind)
register_PODVector_extension(amrex_2d_pybind)
register_SoA_extension(amrex_2d_pybind)
register_AoS_extension(amrex_2d_pybind)
2 changes: 2 additions & 0 deletions src/amrex/space3d/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,12 @@ def Print(*args, **kwargs):

from ..Array4 import register_Array4_extension
from ..ArrayOfStructs import register_AoS_extension
from ..MultiFab import register_MultiFab_extension
from ..PODVector import register_PODVector_extension
from ..StructOfArrays import register_SoA_extension

register_Array4_extension(amrex_3d_pybind)
register_MultiFab_extension(amrex_3d_pybind)
register_PODVector_extension(amrex_3d_pybind)
register_SoA_extension(amrex_3d_pybind)
register_AoS_extension(amrex_3d_pybind)
10 changes: 10 additions & 0 deletions tests/test_multifab.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,3 +350,13 @@ def test_mfab_dtoh_copy(make_mfab_device):
device_max = mfab_device.max(0)
assert device_min == device_max
assert device_max == 11.0

# numpy bindings (w/ copy)
local_boxes_host = mfab_device.to_numpy(copy=True)
assert max([np.max(box) for box in local_boxes_host]) == device_max

# cupy bindings (w/o copy)
import cupy as cp

local_boxes_device = mfab_device.to_cupy()
assert max([cp.max(box) for box in local_boxes_device]) == device_max

0 comments on commit d256160

Please sign in to comment.