Skip to content

Commit

Permalink
MultiFab: to_numpy/cupy (#192)
Browse files Browse the repository at this point in the history
* MultiFab: to_numpy/cupy

Add numpy & cupy helpers for MultiFab.

* Update Stub Files

---------

Co-authored-by: ax3l <[email protected]>
  • Loading branch information
ax3l and ax3l authored Sep 22, 2023
1 parent 596f0e7 commit 9aa7cb8
Show file tree
Hide file tree
Showing 11 changed files with 228 additions and 16 deletions.
39 changes: 23 additions & 16 deletions src/amrex/MultiFab.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"""


def mf_to_numpy(self, copy=False, order="F"):
def mf_to_numpy(amr, self, copy=False, order="F"):
"""
Provide a Numpy view into a MultiFab.
Expand All @@ -29,13 +29,24 @@ def mf_to_numpy(self, copy=False, order="F"):
Returns
-------
list of np.array
list of numpy.array
A list of numpy n-dimensional arrays, for each local block in the
MultiFab.
"""
mf = self
if copy:
mf = amr.MultiFab(
self.box_array(),
self.dm(),
self.n_comp(),
self.n_grow_vect(),
amr.MFInfo().set_arena(amr.The_Pinned_Arena()),
)
amr.dtoh_memcpy(mf, self)

views = []
for mfi in self:
views.append(self.array(mfi).to_numpy(copy, order))
for mfi in mf:
views.append(mf.array(mfi).to_numpy(copy=False, order=order))

return views

Expand Down Expand Up @@ -80,15 +91,11 @@ def mf_to_cupy(self, copy=False, order="F"):

def register_MultiFab_extension(amr):
"""MultiFab helper methods"""
import inspect
import sys

# register member functions for every MultiFab* type
for _, MultiFab_type in inspect.getmembers(
sys.modules[amr.__name__],
lambda member: inspect.isclass(member)
and member.__module__ == amr.__name__
and member.__name__.startswith("MultiFab"),
):
MultiFab_type.to_numpy = mf_to_numpy
MultiFab_type.to_cupy = mf_to_cupy

# register member functions for the MultiFab type
amr.MultiFab.to_numpy = lambda self, copy=False, order="F": mf_to_numpy(
amr, self, copy, order
)
amr.MultiFab.to_numpy.__doc__ = mf_to_numpy.__doc__

amr.MultiFab.to_cupy = mf_to_cupy
2 changes: 2 additions & 0 deletions src/amrex/space1d/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,12 @@ def Print(*args, **kwargs):

from ..Array4 import register_Array4_extension
from ..ArrayOfStructs import register_AoS_extension
from ..MultiFab import register_MultiFab_extension
from ..PODVector import register_PODVector_extension
from ..StructOfArrays import register_SoA_extension

register_Array4_extension(amrex_1d_pybind)
register_MultiFab_extension(amrex_1d_pybind)
register_PODVector_extension(amrex_1d_pybind)
register_SoA_extension(amrex_1d_pybind)
register_AoS_extension(amrex_1d_pybind)
2 changes: 2 additions & 0 deletions src/amrex/space1d/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import os as os

from amrex.Array4 import register_Array4_extension
from amrex.ArrayOfStructs import register_AoS_extension
from amrex.MultiFab import register_MultiFab_extension
from amrex.PODVector import register_PODVector_extension
from amrex.StructOfArrays import register_SoA_extension
from amrex.space1d.amrex_1d_pybind import (
Expand Down Expand Up @@ -461,6 +462,7 @@ __all__ = [
"refine",
"register_AoS_extension",
"register_Array4_extension",
"register_MultiFab_extension",
"register_PODVector_extension",
"register_SoA_extension",
"size",
Expand Down
61 changes: 61 additions & 0 deletions src/amrex/space1d/amrex_1d_pybind/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -4043,6 +4043,67 @@ class MultiFab(FabArray_FArrayBox):
"""
Same as sum with local=false, but for non-cell-centered data, thisskips non-unique points that are owned by multiple boxes.
"""
def to_cupy(self, copy=False, order="F"):
"""
Provide a Cupy view into a MultiFab.
Note on the order of indices:
By default, this is as in AMReX in Fortran contiguous order, indexing as
x,y,z. This has performance implications for use in external libraries such
as cupy.
The order="C" option will index as z,y,x and perform better with cupy.
https://github.com/AMReX-Codes/pyamrex/issues/55#issuecomment-1579610074
Parameters
----------
self : amrex.MultiFab
A MultiFab class in pyAMReX
copy : bool, optional
Copy the data if true, otherwise create a view (default).
order : string, optional
F order (default) or C. C is faster with external libraries.
Returns
-------
list of cupy.array
A list of cupy n-dimensional arrays, for each local block in the
MultiFab.
Raises
------
ImportError
Raises an exception if cupy is not installed
"""
def to_numpy(self, copy=False, order="F"):
"""
Provide a Numpy view into a MultiFab.
Note on the order of indices:
By default, this is as in AMReX in Fortran contiguous order, indexing as
x,y,z. This has performance implications for use in external libraries such
as cupy.
The order="C" option will index as z,y,x and perform better with cupy.
https://github.com/AMReX-Codes/pyamrex/issues/55#issuecomment-1579610074
Parameters
----------
self : amrex.MultiFab
A MultiFab class in pyAMReX
copy : bool, optional
Copy the data if true, otherwise create a view (default).
order : string, optional
F order (default) or C. C is faster with external libraries.
Returns
-------
list of numpy.array
A list of numpy n-dimensional arrays, for each local block in the
MultiFab.
"""
def weighted_sync(self, arg0: MultiFab, arg1: Periodicity) -> None: ...

class PIdx:
Expand Down
2 changes: 2 additions & 0 deletions src/amrex/space2d/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,12 @@ def Print(*args, **kwargs):

from ..Array4 import register_Array4_extension
from ..ArrayOfStructs import register_AoS_extension
from ..MultiFab import register_MultiFab_extension
from ..PODVector import register_PODVector_extension
from ..StructOfArrays import register_SoA_extension

register_Array4_extension(amrex_2d_pybind)
register_MultiFab_extension(amrex_2d_pybind)
register_PODVector_extension(amrex_2d_pybind)
register_SoA_extension(amrex_2d_pybind)
register_AoS_extension(amrex_2d_pybind)
2 changes: 2 additions & 0 deletions src/amrex/space2d/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import os as os

from amrex.Array4 import register_Array4_extension
from amrex.ArrayOfStructs import register_AoS_extension
from amrex.MultiFab import register_MultiFab_extension
from amrex.PODVector import register_PODVector_extension
from amrex.StructOfArrays import register_SoA_extension
from amrex.space2d.amrex_2d_pybind import (
Expand Down Expand Up @@ -461,6 +462,7 @@ __all__ = [
"refine",
"register_AoS_extension",
"register_Array4_extension",
"register_MultiFab_extension",
"register_PODVector_extension",
"register_SoA_extension",
"size",
Expand Down
61 changes: 61 additions & 0 deletions src/amrex/space2d/amrex_2d_pybind/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -4049,6 +4049,67 @@ class MultiFab(FabArray_FArrayBox):
"""
Same as sum with local=false, but for non-cell-centered data, thisskips non-unique points that are owned by multiple boxes.
"""
def to_cupy(self, copy=False, order="F"):
"""
Provide a Cupy view into a MultiFab.
Note on the order of indices:
By default, this is as in AMReX in Fortran contiguous order, indexing as
x,y,z. This has performance implications for use in external libraries such
as cupy.
The order="C" option will index as z,y,x and perform better with cupy.
https://github.com/AMReX-Codes/pyamrex/issues/55#issuecomment-1579610074
Parameters
----------
self : amrex.MultiFab
A MultiFab class in pyAMReX
copy : bool, optional
Copy the data if true, otherwise create a view (default).
order : string, optional
F order (default) or C. C is faster with external libraries.
Returns
-------
list of cupy.array
A list of cupy n-dimensional arrays, for each local block in the
MultiFab.
Raises
------
ImportError
Raises an exception if cupy is not installed
"""
def to_numpy(self, copy=False, order="F"):
"""
Provide a Numpy view into a MultiFab.
Note on the order of indices:
By default, this is as in AMReX in Fortran contiguous order, indexing as
x,y,z. This has performance implications for use in external libraries such
as cupy.
The order="C" option will index as z,y,x and perform better with cupy.
https://github.com/AMReX-Codes/pyamrex/issues/55#issuecomment-1579610074
Parameters
----------
self : amrex.MultiFab
A MultiFab class in pyAMReX
copy : bool, optional
Copy the data if true, otherwise create a view (default).
order : string, optional
F order (default) or C. C is faster with external libraries.
Returns
-------
list of numpy.array
A list of numpy n-dimensional arrays, for each local block in the
MultiFab.
"""
def weighted_sync(self, arg0: MultiFab, arg1: Periodicity) -> None: ...

class PIdx:
Expand Down
2 changes: 2 additions & 0 deletions src/amrex/space3d/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,12 @@ def Print(*args, **kwargs):

from ..Array4 import register_Array4_extension
from ..ArrayOfStructs import register_AoS_extension
from ..MultiFab import register_MultiFab_extension
from ..PODVector import register_PODVector_extension
from ..StructOfArrays import register_SoA_extension

register_Array4_extension(amrex_3d_pybind)
register_MultiFab_extension(amrex_3d_pybind)
register_PODVector_extension(amrex_3d_pybind)
register_SoA_extension(amrex_3d_pybind)
register_AoS_extension(amrex_3d_pybind)
2 changes: 2 additions & 0 deletions src/amrex/space3d/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import os as os

from amrex.Array4 import register_Array4_extension
from amrex.ArrayOfStructs import register_AoS_extension
from amrex.MultiFab import register_MultiFab_extension
from amrex.PODVector import register_PODVector_extension
from amrex.StructOfArrays import register_SoA_extension
from amrex.space3d.amrex_3d_pybind import (
Expand Down Expand Up @@ -461,6 +462,7 @@ __all__ = [
"refine",
"register_AoS_extension",
"register_Array4_extension",
"register_MultiFab_extension",
"register_PODVector_extension",
"register_SoA_extension",
"size",
Expand Down
61 changes: 61 additions & 0 deletions src/amrex/space3d/amrex_3d_pybind/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -4052,6 +4052,67 @@ class MultiFab(FabArray_FArrayBox):
"""
Same as sum with local=false, but for non-cell-centered data, thisskips non-unique points that are owned by multiple boxes.
"""
def to_cupy(self, copy=False, order="F"):
"""
Provide a Cupy view into a MultiFab.
Note on the order of indices:
By default, this is as in AMReX in Fortran contiguous order, indexing as
x,y,z. This has performance implications for use in external libraries such
as cupy.
The order="C" option will index as z,y,x and perform better with cupy.
https://github.com/AMReX-Codes/pyamrex/issues/55#issuecomment-1579610074
Parameters
----------
self : amrex.MultiFab
A MultiFab class in pyAMReX
copy : bool, optional
Copy the data if true, otherwise create a view (default).
order : string, optional
F order (default) or C. C is faster with external libraries.
Returns
-------
list of cupy.array
A list of cupy n-dimensional arrays, for each local block in the
MultiFab.
Raises
------
ImportError
Raises an exception if cupy is not installed
"""
def to_numpy(self, copy=False, order="F"):
"""
Provide a Numpy view into a MultiFab.
Note on the order of indices:
By default, this is as in AMReX in Fortran contiguous order, indexing as
x,y,z. This has performance implications for use in external libraries such
as cupy.
The order="C" option will index as z,y,x and perform better with cupy.
https://github.com/AMReX-Codes/pyamrex/issues/55#issuecomment-1579610074
Parameters
----------
self : amrex.MultiFab
A MultiFab class in pyAMReX
copy : bool, optional
Copy the data if true, otherwise create a view (default).
order : string, optional
F order (default) or C. C is faster with external libraries.
Returns
-------
list of numpy.array
A list of numpy n-dimensional arrays, for each local block in the
MultiFab.
"""
def weighted_sync(self, arg0: MultiFab, arg1: Periodicity) -> None: ...

class PIdx:
Expand Down
10 changes: 10 additions & 0 deletions tests/test_multifab.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,3 +350,13 @@ def test_mfab_dtoh_copy(make_mfab_device):
device_max = mfab_device.max(0)
assert device_min == device_max
assert device_max == 11.0

# numpy bindings (w/ copy)
local_boxes_host = mfab_device.to_numpy(copy=True)
assert max([np.max(box) for box in local_boxes_host]) == device_max

# cupy bindings (w/o copy)
import cupy as cp

local_boxes_device = mfab_device.to_cupy()
assert max([cp.max(box) for box in local_boxes_device]) == device_max

0 comments on commit 9aa7cb8

Please sign in to comment.